package org.apache.s4.comm.tcp;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.util.ArrayDeque;
import java.util.Hashtable;
import java.util.Queue;
import java.util.concurrent.Executors;
import org.apache.s4.base.Emitter;
import org.apache.s4.comm.topology.ClusterNode;
import org.apache.s4.comm.topology.Topology;
import org.apache.s4.comm.topology.TopologyChangeListener;
import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFactory;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.SimpleChannelHandler;
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory;
import org.jboss.netty.handler.codec.frame.LengthFieldPrepender;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.collect.HashBiMap;
import com.google.inject.Inject;
public class TCPEmitter implements Emitter, ChannelFutureListener, TopologyChangeListener {
// Class-wide SLF4J logger.
private static final Logger logger = LoggerFactory.getLogger(TCPEmitter.class);
// Maximum number of messages a bounded MessageQueuesPerPartition holds per partition before add() rejects.
private static final int BUFFER_SIZE = 10;
// Number of connection attempts connectTo() makes before giving up on a partition.
private static final int NUM_RETRIES = 10;
// Source of cluster membership: nodes are looked up from it and this emitter registers as its change listener.
private Topology topology;
// Netty client bootstrap used to open one outbound channel per destination partition.
private final ClientBootstrap bootstrap;
/**
 * FIFO buffers of raw messages, one queue per partition id.
 *
 * <p>
 * When constructed as bounded, each partition's queue holds at most {@code BUFFER_SIZE} messages and further
 * additions are rejected.
 *
 * <p>
 * NOTE(review): the {@link Hashtable} synchronizes individual map calls, but the per-partition
 * {@link ArrayDeque}s and the check-then-act sequences below are not synchronized; confirm how callers
 * serialize access.
 */
static class MessageQueuesPerPartition {
    private final Hashtable<Integer, Queue<byte[]>> queues = new Hashtable<Integer, Queue<byte[]>>();
    private final boolean bounded;

    /**
     * @param bounded
     *            whether each partition's queue is capped at {@code BUFFER_SIZE} entries
     */
    MessageQueuesPerPartition(boolean bounded) {
        this.bounded = bounded;
    }

    /**
     * Appends a message to the queue of the given partition, creating the queue on first use.
     *
     * @return {@code true} if the message was queued, {@code false} if the queue is bounded and already full
     */
    private boolean add(int partitionId, byte[] message) {
        Queue<byte[]> messages = queues.get(partitionId);
        if (messages == null) {
            messages = new ArrayDeque<byte[]>();
            queues.put(partitionId, messages);
        }
        if (bounded && messages.size() >= BUFFER_SIZE) {
            // Too many messages already queued
            return false;
        }
        messages.offer(message);
        return true;
    }

    /**
     * Returns the oldest queued message for the partition without removing it.
     *
     * @return the head of the queue, or {@code null} if the partition has no queue or its queue is empty
     */
    private byte[] peek(int partitionId) {
        Queue<byte[]> messages = queues.get(partitionId);
        // Explicit null check instead of catching NullPointerException for a partition never seen before.
        return (messages == null) ? null : messages.peek();
    }

    /**
     * Drops the oldest queued message for the partition. Logs an error and does nothing when the partition has
     * no queue or its queue is empty.
     */
    private void remove(int partitionId) {
        Queue<byte[]> messages = queues.get(partitionId);
        // The null check must come first: a partition that never had a message queued has no queue at all.
        if (messages == null || messages.isEmpty()) {
            logger.error("Trying to remove messages from an empty queue for partition" + partitionId);
            return;
        }
        messages.remove();
    }
}
// Open channel per partition id; the inverse view is used to find the partition a channel belongs to.
// NOTE(review): HashBiMap is not synchronized, and these maps are touched from send(), the write-completion
// callback, the topology listener and the channel handler - confirm which threads those run on.
private HashBiMap<Integer, Channel> partitionChannelMap;
// Last known cluster node per partition id; refreshed in onChange() and filled lazily in connectTo().
private HashBiMap<Integer, ClusterNode> partitionNodeMap;
// Bounded per-partition buffer for messages that could not be sent because no connection could be opened.
private MessageQueuesPerPartition queuedMessages = new MessageQueuesPerPartition(true);
/**
 * Creates an emitter for the given topology and prepares (but does not open) the Netty client used to reach
 * other partitions. Connections are opened lazily on the first send to each partition.
 *
 * @param topology
 *            cluster topology used to size the internal maps, to look up nodes, and to receive change
 *            notifications from
 * @throws InterruptedException
 *             declared on the signature and kept for compatibility with existing callers
 */
@Inject
public TCPEmitter(Topology topology) throws InterruptedException {
    this.topology = topology;

    int clusterSize = this.topology.getTopology().getNodes().size();
    partitionChannelMap = HashBiMap.create(clusterSize);
    partitionNodeMap = HashBiMap.create(clusterSize);

    ChannelFactory factory = new NioClientSocketChannelFactory(Executors.newCachedThreadPool(),
            Executors.newCachedThreadPool());
    bootstrap = new ClientBootstrap(factory);
    bootstrap.setPipelineFactory(new ChannelPipelineFactory() {
        @Override
        public ChannelPipeline getPipeline() {
            ChannelPipeline p = Channels.pipeline();
            // Every outgoing frame is prefixed with a 4-byte length field.
            p.addLast("1", new LengthFieldPrepender(4));
            p.addLast("2", new TestHandler());
            return p;
        }
    });
    bootstrap.setOption("tcpNoDelay", true);
    bootstrap.setOption("keepAlive", true);

    // Register for topology changes only once every field onChange() reads has been initialized; registering
    // first would let an early notification run against null maps.
    topology.addListener(this);
}
/**
 * Opens a channel to the node that owns the given partition and records it in {@code partitionChannelMap}.
 *
 * <p>
 * The node is taken from {@code partitionNodeMap} if known, otherwise looked up in the topology and cached.
 * Up to {@code NUM_RETRIES} connection attempts are made, pausing briefly between attempts.
 *
 * @param partitionId
 *            partition to connect to
 * @return {@code true} if a channel was opened and registered; {@code false} if no node is known for the
 *         partition, every attempt failed, or the calling thread was interrupted while pausing between attempts
 */
private boolean connectTo(Integer partitionId) {
    ClusterNode clusterNode = partitionNodeMap.get(partitionId);
    if (clusterNode == null) {
        clusterNode = topology.getTopology().getNodes().get(partitionId);
        if (clusterNode == null) {
            logger.error("No ClusterNode exists for partitionId " + partitionId);
            return false;
        }
        // Cache only real nodes; storing a null value would pollute the bimap (forcePut evicts whichever other
        // key currently holds the same value).
        partitionNodeMap.forcePut(partitionId, clusterNode);
    }

    for (int retries = 0; retries < NUM_RETRIES; retries++) {
        // The address is rebuilt on every attempt, as before, so each attempt resolves the host name afresh.
        ChannelFuture f = this.bootstrap.connect(new InetSocketAddress(clusterNode.getMachineName(), clusterNode
                .getPort()));
        f.awaitUninterruptibly();
        if (f.isSuccess()) {
            partitionChannelMap.forcePut(partitionId, f.getChannel());
            return true;
        }
        try {
            Thread.sleep(10);
        } catch (InterruptedException ie) {
            logger.error(String.format("Interrupted while connecting to %s:%d", clusterNode.getMachineName(),
                    clusterNode.getPort()));
            // Preserve the interrupt for the caller and stop retrying instead of swallowing it.
            Thread.currentThread().interrupt();
            return false;
        }
    }
    return false;
}
/**
 * Copies the message into a freshly allocated buffer, writes it to the channel asynchronously and registers
 * this emitter as listener for the outcome of the write.
 *
 * @param channel
 *            channel to write to
 * @param partitionId
 *            destination partition; not used by this method
 * @param message
 *            payload bytes to send
 */
private void writeMessageToChannel(Channel channel, int partitionId, byte[] message) {
    final ChannelBuffer payload = ChannelBuffers.buffer(message.length);
    payload.writeBytes(message);
    channel.write(payload).addListener(this);
}
/** Monitor on which senders wait until a channel reports itself writable again. */
private final Object sendLock = new Object();

/** Upper bound, in milliseconds, on one wait for writability, so a closed channel or a missed wake-up is noticed. */
private static final long WRITABLE_WAIT_MILLIS = 100;

/**
 * Sends a message to the given partition, connecting first if no channel is open yet.
 *
 * <p>
 * If no usable connection is available the message is buffered and flushed, in order, ahead of the next message
 * that is successfully written to that partition. If the channel is not currently writable the calling thread
 * blocks until it becomes writable, is closed, or the thread is interrupted.
 *
 * @param partitionId
 *            destination partition
 * @param message
 *            payload bytes
 * @return {@code true} if the message was handed to the channel or buffered; {@code false} if it had to be
 *         buffered but the buffer for that partition is full, or if the thread was interrupted while waiting
 *         (the interrupt flag is left set in that case)
 */
@Override
public boolean send(int partitionId, byte[] message) {
    Channel channel = partitionChannelMap.get(partitionId);
    if (channel == null) {
        if (!connectTo(partitionId)) {
            // could not connect, queue to the partitionBuffer
            return queuedMessages.add(partitionId, message);
        }
        channel = partitionChannelMap.get(partitionId);
        if (channel == null) {
            // The mapping was dropped again (e.g. by a topology change) before it could be used.
            return queuedMessages.add(partitionId, message);
        }
    }

    /*
     * Try limiting the size of the send queue inside Netty: block while the channel is not writable.
     */
    if (!channel.isWritable()) {
        synchronized (sendLock) {
            // Check again now that we have the lock. A closed channel will not become writable, so the loop
            // also stops when the channel is no longer open; the timed wait bounds how long that takes to notice.
            while (channel.isOpen() && !channel.isWritable()) {
                try {
                    sendLock.wait(WRITABLE_WAIT_MILLIS);
                } catch (InterruptedException ie) {
                    Thread.currentThread().interrupt();
                    return false;
                }
            }
        }
    }

    if (!channel.isOpen()) {
        // The connection is gone: forget it (only if it is still the registered one) so that the next send
        // reconnects, and keep the message for later delivery.
        if (partitionChannelMap.get(partitionId) == channel) {
            partitionChannelMap.remove(partitionId);
        }
        return queuedMessages.add(partitionId, message);
    }

    /*
     * Channel is available. Write previously buffered messages first, then the current message, so that
     * per-partition ordering is kept.
     */
    byte[] messageBeingSent = null;
    while ((messageBeingSent = queuedMessages.peek(partitionId)) != null) {
        writeMessageToChannel(channel, partitionId, messageBeingSent);
        queuedMessages.remove(partitionId);
    }
    writeMessageToChannel(channel, partitionId, message);
    return true;
}
/**
 * Callback invoked when an asynchronous write started by this emitter finishes.
 *
 * <p>
 * Successful writes need no action. A cancelled write is logged. A failed write is logged and the channel's
 * partition mapping is dropped, so that the next send to that partition opens a new connection.
 *
 * @param f
 *            the completed write future
 */
@Override
public void operationComplete(ChannelFuture f) {
    if (f.isSuccess()) {
        return;
    }
    if (f.isCancelled()) {
        logger.error("Send I/O was cancelled!! " + f.getChannel().getRemoteAddress());
        return;
    }

    logger.error("Exception on I/O operation", f.getCause());
    // The channel may already have been unmapped (by an earlier failed write on the same channel or by a
    // topology change), so the lookup can yield null and must not be unboxed blindly.
    Integer partitionId = partitionChannelMap.inverse().get(f.getChannel());
    if (partitionId == null) {
        logger.error("I/O failed on a channel that is no longer mapped to a partition");
        return;
    }
    logger.error(String.format("I/O on partition %d failed!", partitionId));
    partitionChannelMap.remove(partitionId);
}
/**
 * Topology change callback: for every node in the current topology, closes the channel of a partition whose
 * node differs from the one previously recorded, then records the current node for that partition.
 */
@Override
public void onChange() {
    for (ClusterNode clusterNode : topology.getTopology().getNodes()) {
        Integer partition = clusterNode.getPartition();
        ClusterNode oldNode = partitionNodeMap.get(partition);
        if (oldNode != null && !oldNode.equals(clusterNode)) {
            // A node can be recorded for a partition without a channel ever having been opened to it (or the
            // channel may already have been dropped after a failure), so there may be nothing to close.
            Channel staleChannel = partitionChannelMap.remove(partition);
            if (staleChannel != null) {
                staleChannel.close();
            }
        }
        partitionNodeMap.forcePut(partition, clusterNode);
    }
}
/**
 * Reports how many partitions the cluster has, as stated by the topology. This is the partition count, which is
 * not necessarily the same as the number of nodes.
 *
 * @return the partition count taken from the current topology
 */
@Override
public int getPartitionCount() {
    final int partitionCount = topology.getTopology().getPartitionCount();
    return partitionCount;
}
class TestHandler extends SimpleChannelHandler {
@Override
public void channelInterestChanged(ChannelHandlerContext ctx, ChannelStateEvent e) {
// logger.info(String.format("%08x %08x %08x", e.getValue(),
// e.getChannel().getInterestOps(), Channel.OP_WRITE));
synchronized (sendLock) {
if (e.getChannel().isWritable()) {
sendLock.notify();
}
}
ctx.sendUpstream(e);
}
@Override
public void exceptionCaught(ChannelHandlerContext context, ExceptionEvent event) {
Integer partitionId = partitionChannelMap.inverse().get(context.getChannel());
if (partitionId == null) {
logger.error("Error on mystery channel!!");
}
logger.error("Error on channel to partition " + partitionId);
try {
throw event.getCause();
} catch (ConnectException ce) {
logger.error(ce.getMessage(), ce);
} catch (Throwable err) {
logger.error("Error", err);
if (context.getChannel().isOpen()) {
logger.error("Closing channel due to exception");
context.getChannel().close();
}
}
}
}
}