package co.tomlee.nifty;
import com.facebook.nifty.client.FramedClientConnector;
import com.facebook.nifty.client.TNiftyClientChannelTransport;
import com.facebook.nifty.duplex.TDuplexProtocolFactory;
import com.google.common.net.HostAndPort;
import org.apache.commons.pool2.BaseKeyedPooledObjectFactory;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.commons.pool2.impl.GenericKeyedObjectPool;
import org.apache.thrift.TServiceClient;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Closeable;
import java.lang.reflect.Constructor;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.time.Duration;
/**
 * Keyed connection pool for Nifty client connections.
 * <p>
 * Connections are pooled per remote {@link InetSocketAddress} in an Apache Commons
 * {@link GenericKeyedObjectPool}. Each {@code getTransport(...)} call borrows a connection and
 * wraps it in a {@link TTransport}. Calling {@link TTransport#close()} on that wrapper hands the
 * connection back to the pool. If the connection is no longer open, it is discarded instead.
 */
public final class TNiftyClientChannelTransportPool implements Closeable {
    private static final Logger log = LoggerFactory.getLogger(TNiftyClientChannelTransportPool.class);

    private final GenericKeyedObjectPool<InetSocketAddress, TNiftyClientChannelTransport> pool;
    private final TProtocolFactory protocolFactory;
    private final Constructor<? extends TServiceClient> clientConstructor;

    /**
     * Initialize a new connection pool with the specified configuration.
     *
     * @param config the connection pool configuration; {@code config.validate()} is invoked first
     * @throws IllegalStateException if {@code config.clientClass} has no accessible public
     *         constructor taking a single {@link TProtocol}
     */
    @SuppressWarnings("unchecked")
    public TNiftyClientChannelTransportPool(final TNiftyClientChannelTransportPoolConfig config) {
        config.validate();
        this.protocolFactory = config.protocolFactory;
        try {
            this.clientConstructor = config.clientClass.getConstructor(TProtocol.class);
        }
        catch (NoSuchMethodException | SecurityException e) {
            throw new IllegalStateException(
                "Client class " + config.clientClass.getName()
                    + " must have a public constructor taking a single TProtocol", e);
        }
        final TDuplexProtocolFactory duplexProtocolFactory =
            TDuplexProtocolFactory.fromSingleFactory(config.protocolFactory);
        pool = new GenericKeyedObjectPool<>(new BaseKeyedPooledObjectFactory<InetSocketAddress, TNiftyClientChannelTransport>() {
            @Override
            public TNiftyClientChannelTransport create(final InetSocketAddress key) throws Exception {
                log.debug("Connecting to {}", key);
                // Pool keys are usually unresolved (see getTransport); resolve here so DNS is
                // looked up each time a new connection is opened.
                final FramedClientConnector connector =
                    new FramedClientConnector(resolve(key), duplexProtocolFactory);
                final TNiftyClientChannelTransport transport =
                    config.niftyClient.connectSync(
                        config.clientClass,
                        connector,
                        config.connectTimeout,
                        config.receiveTimeout,
                        config.readTimeout,
                        config.sendTimeout,
                        config.maxFrameSize);
                log.debug("Connected to {}", key);
                return transport;
            }

            @Override
            public PooledObject<TNiftyClientChannelTransport> wrap(final TNiftyClientChannelTransport value) {
                return new DefaultPooledObject<>(value);
            }

            @Override
            public boolean validateObject(final InetSocketAddress key, final PooledObject<TNiftyClientChannelTransport> p) {
                // Validation is optional; without a configured check every connection is accepted.
                if (config.checkTransport != null) {
                    log.debug("Validating connection to {}", key);
                    return config.checkTransport.validate(p.getObject());
                }
                return true;
            }

            @Override
            public void destroyObject(final InetSocketAddress key, final PooledObject<TNiftyClientChannelTransport> p) throws Exception {
                log.debug("Closing connection to {}", key);
                p.getObject().close();
            }
        }, config.poolConfig);
    }

    /**
     * Returns a resolved equivalent of {@code address}.
     * <p>
     * Addresses built with {@link InetSocketAddress#createUnresolved} carry only a host name.
     * NOTE(review): this assumes Nifty/Netty passes the address straight to a socket connect.
     * A java.nio socket connect rejects unresolved addresses, so the address is resolved here first.
     *
     * @param address the (possibly unresolved) remote address
     * @return {@code address} itself if already resolved, otherwise a freshly resolved copy
     * @throws UnknownHostException if the host name cannot be resolved
     */
    private static InetSocketAddress resolve(final InetSocketAddress address) throws UnknownHostException {
        if (!address.isUnresolved()) {
            return address;
        }
        final InetSocketAddress resolved = new InetSocketAddress(address.getHostString(), address.getPort());
        if (resolved.isUnresolved()) {
            throw new UnknownHostException(address.getHostString());
        }
        return resolved;
    }

    /**
     * @return the public {@code (TProtocol)} constructor of the configured client class
     */
    Constructor<? extends TServiceClient> clientConstructor() {
        return clientConstructor;
    }

    /**
     * Close the connection pool.
     * <p>
     * Note that this does not automatically close all borrowed connections.
     */
    @Override
    public void close() {
        pool.close();
    }

    /**
     * Get a connection to the remote Nifty endpoint identified by the given hostname/IP & port.
     *
     * @param host the hostname or IP of the remote host
     * @param port the remote port
     * @return the borrowed connection wrapped in a Thrift transport; close it to return it to the pool
     * @throws Exception if a connection cannot be borrowed or created (e.g. the pool is closed or
     *         exhausted, or connecting fails)
     */
    public TTransport getTransport(final String host, final int port) throws Exception {
        return getTransport(InetSocketAddress.createUnresolved(host, port));
    }

    /**
     * Get a connection to the remote Nifty endpoint identified by the given hostname & port.
     *
     * @param hostAndPort the hostname/IP and port of the remote endpoint; must include a port
     * @return the borrowed connection wrapped in a Thrift transport; close it to return it to the pool
     * @throws Exception if a connection cannot be borrowed or created (e.g. the pool is closed or
     *         exhausted, or connecting fails)
     */
    public TTransport getTransport(final HostAndPort hostAndPort) throws Exception {
        return getTransport(InetSocketAddress.createUnresolved(hostAndPort.getHostText(), hostAndPort.getPort()));
    }

    /**
     * Get a connection to the remote Nifty endpoint identified by the given hostname & port,
     * waiting at most {@code timeout} to borrow it.
     *
     * @param hostAndPort the remote hostname/IP and port of the remote endpoint; must include a port
     * @param timeout the borrow timeout
     * @return the borrowed connection wrapped in a Thrift transport; close it to return it to the pool
     * @throws Exception if a connection cannot be borrowed within the timeout or created
     */
    public TTransport getTransport(final HostAndPort hostAndPort, final Duration timeout) throws Exception {
        return getTransport(InetSocketAddress.createUnresolved(hostAndPort.getHostText(), hostAndPort.getPort()), timeout);
    }

    /**
     * Get a connection to the remote Nifty endpoint identified by the given socket address.
     *
     * @param socketAddress the remote socket address; also used as the pool key
     * @return the borrowed connection wrapped in a Thrift transport; close it to return it to the pool
     * @throws Exception if a connection cannot be borrowed or created (e.g. the pool is closed or
     *         exhausted, or connecting fails)
     */
    public TTransport getTransport(final InetSocketAddress socketAddress) throws Exception {
        final TNiftyClientChannelTransport transport = pool.borrowObject(socketAddress);
        return new TTransportShell(socketAddress, transport);
    }

    /**
     * Get a connection to the remote Nifty endpoint identified by the given socket address,
     * waiting at most {@code timeout} to borrow it.
     *
     * @param socketAddress the remote socket address; also used as the pool key
     * @param timeout the borrow timeout
     * @return the borrowed connection wrapped in a Thrift transport; close it to return it to the pool
     * @throws Exception if a connection cannot be borrowed within the timeout or created
     */
    public TTransport getTransport(final InetSocketAddress socketAddress, final Duration timeout) throws Exception {
        final TNiftyClientChannelTransport transport = pool.borrowObject(socketAddress, timeout.toMillis());
        return new TTransportShell(socketAddress, transport);
    }

    /**
     * @return the protocol factory from the pool configuration
     */
    TProtocolFactory protocolFactory() {
        return protocolFactory;
    }

    /**
     * Wrapper around a borrowed pooled transport.
     * <p>
     * I/O calls delegate to the pooled transport. {@link #close()} releases the connection back to
     * the pool instead of closing the socket, and only the first call has any effect. After release
     * the wrapper reports itself as not open and rejects further I/O, so a stale caller cannot use a
     * connection that may already belong to another borrower.
     */
    private final class TTransportShell extends TTransport {
        private final InetSocketAddress socketAddress;
        private final TNiftyClientChannelTransport transport;
        // Set once the connection has been handed back to (or discarded from) the pool.
        private boolean released;

        TTransportShell(final InetSocketAddress socketAddress, final TNiftyClientChannelTransport transport) {
            this.socketAddress = socketAddress;
            this.transport = transport;
        }

        private void ensureNotReleased() throws TTransportException {
            if (released) {
                throw new TTransportException(TTransportException.NOT_OPEN,
                    "Connection to " + socketAddress + " has already been returned to the pool");
            }
        }

        @Override
        public boolean isOpen() {
            return !released && transport.isOpen();
        }

        @Override
        public void open() throws TTransportException {
            ensureNotReleased();
            transport.open();
        }

        @Override
        public void close() {
            if (released) {
                return;
            }
            released = true;
            if (transport.isOpen()) {
                pool.returnObject(socketAddress, transport);
                return;
            }
            // A dead connection must not be handed to the next borrower.
            log.debug("Discarding closed connection to {}", socketAddress);
            try {
                pool.invalidateObject(socketAddress, transport);
            }
            catch (Exception e) {
                log.warn("Failed to invalidate connection to {}", socketAddress, e);
            }
        }

        @Override
        public int read(byte[] bytes, int offset, int size) throws TTransportException {
            ensureNotReleased();
            log.debug("Waiting to receive up to {} bytes from {}", size, socketAddress);
            final int bytesRead = transport.read(bytes, offset, size);
            log.debug("Received {} bytes from {}", bytesRead, socketAddress);
            return bytesRead;
        }

        @Override
        public void write(byte[] bytes, int offset, int size) throws TTransportException {
            ensureNotReleased();
            log.debug("Writing {} bytes to {}", size, socketAddress);
            transport.write(bytes, offset, size);
        }

        @Override
        public void flush() throws TTransportException {
            ensureNotReleased();
            log.debug("Flushing write buffer to {}", socketAddress);
            transport.flush();
        }

        @Override
        public boolean peek() {
            return !released && transport.peek();
        }

        @Override
        public int readAll(byte[] buf, int off, int len) throws TTransportException {
            ensureNotReleased();
            log.debug("Waiting to receive all {} bytes from {}", len, socketAddress);
            return transport.readAll(buf, off, len);
        }

        @Override
        public void write(byte[] buf) throws TTransportException {
            ensureNotReleased();
            log.debug("Writing {} bytes to {}", buf.length, socketAddress);
            transport.write(buf);
        }

        @Override
        public byte[] getBuffer() {
            return transport.getBuffer();
        }

        @Override
        public int getBufferPosition() {
            return transport.getBufferPosition();
        }

        @Override
        public int getBytesRemainingInBuffer() {
            return transport.getBytesRemainingInBuffer();
        }

        @Override
        public void consumeBuffer(int len) {
            transport.consumeBuffer(len);
        }
    }
}