/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.server;
import java.util.Set;
import com.google.common.base.Throwables;
import com.google.common.collect.Sets;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.protocol.Encodable;
import org.apache.spark.network.protocol.RequestMessage;
import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.util.NettyUtils;
/**
* A handler that processes requests from clients and writes chunk data back. Each handler is
* attached to a single Netty channel, and keeps track of which streams have been fetched via this
* channel, in order to clean them up if the channel is terminated
* (see {@link #channelUnregistered()}).
*
* The messages should have been processed by the pipeline set up by {@link TransportServer}.
*/
public class TransportRequestHandler extends MessageHandler<RequestMessage> {
  /** Shared by all handler instances; one handler is created per channel. */
  private static final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class);

  /** The Netty channel that this handler is associated with. */
  private final Channel channel;

  /** Client on the same channel allowing us to talk back to the requester. */
  private final TransportClient reverseClient;

  /** Handles all RPC messages. */
  private final RpcHandler rpcHandler;

  /** Returns each chunk part of a stream. */
  private final StreamManager streamManager;

  /** Set of all stream ids that have been read on this handler, used for cleanup. */
  private final Set<Long> streamIds;

  /**
   * Creates a handler bound to a single channel.
   *
   * @param channel the Netty channel that responses are written to
   * @param reverseClient client passed to the {@code rpcHandler} callbacks
   * @param rpcHandler receives RPC requests and supplies the {@link StreamManager} used for
   *                   chunk fetches
   */
  public TransportRequestHandler(
      Channel channel,
      TransportClient reverseClient,
      RpcHandler rpcHandler) {
    this.channel = channel;
    this.reverseClient = reverseClient;
    this.rpcHandler = rpcHandler;
    this.streamManager = rpcHandler.getStreamManager();
    this.streamIds = Sets.newHashSet();
  }

  /**
   * Deliberately takes no action; stream cleanup is performed in {@link #channelUnregistered()}.
   *
   * NOTE(review): this assumes the component invoking this method already logs {@code cause};
   * confirm, otherwise the error is silently dropped here.
   */
  @Override
  public void exceptionCaught(Throwable cause) {
    // Intentionally empty.
  }

  /**
   * Notifies the StreamManager, once per stream id recorded by this handler, that the stream
   * will no longer be read from, and then notifies the RpcHandler that the connection is gone.
   *
   * A failure while cleaning up one stream is logged and does not prevent the remaining streams
   * or the RpcHandler from being notified.
   */
  @Override
  public void channelUnregistered() {
    // Inform the StreamManager that these streams will no longer be read from.
    for (long streamId : streamIds) {
      try {
        streamManager.connectionTerminated(streamId);
      } catch (RuntimeException e) {
        // Keep going: one misbehaving stream must not leak every other stream on this channel.
        logger.error("Error while cleaning up stream " + streamId + " on channel termination", e);
      }
    }
    rpcHandler.connectionTerminated(reverseClient);
  }

  /**
   * Dispatches an incoming request to the matching processing routine.
   *
   * @param request the decoded request message
   * @throws IllegalArgumentException if the request is neither a {@link ChunkFetchRequest} nor
   *                                  an {@link RpcRequest}
   */
  @Override
  public void handle(RequestMessage request) {
    if (request instanceof ChunkFetchRequest) {
      processFetchRequest((ChunkFetchRequest) request);
    } else if (request instanceof RpcRequest) {
      processRpcRequest((RpcRequest) request);
    } else {
      throw new IllegalArgumentException("Unknown request type: " + request);
    }
  }

  /**
   * Looks up the requested chunk and responds with either a {@link ChunkFetchSuccess} carrying
   * the buffer or a {@link ChunkFetchFailure} carrying the stack trace of the lookup error.
   */
  private void processFetchRequest(final ChunkFetchRequest req) {
    final String client = NettyUtils.getRemoteAddress(channel);

    // Record the stream before attempting the fetch so that it is cleaned up on channel
    // termination even if the lookup below fails.
    streamIds.add(req.streamChunkId.streamId);

    logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId);

    ManagedBuffer buf;
    try {
      buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
    } catch (Exception e) {
      logger.error(String.format(
        "Error opening block %s for request from %s", req.streamChunkId, client), e);
      respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
      return;
    }

    respond(new ChunkFetchSuccess(req.streamChunkId, buf));
  }

  /**
   * Hands the RPC payload to the RpcHandler together with a callback that writes the outcome
   * back on this channel. An exception thrown by {@code receive()} itself is logged and reported
   * to the requester as an {@link RpcFailure}.
   */
  private void processRpcRequest(final RpcRequest req) {
    try {
      rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() {
        @Override
        public void onSuccess(byte[] response) {
          respond(new RpcResponse(req.requestId, response));
        }

        @Override
        public void onFailure(Throwable e) {
          respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
        }
      });
    } catch (Exception e) {
      logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
      respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
    }
  }

  /**
   * Responds to a single message with some Encodable object. If a failure occurs while sending,
   * it will be logged and the channel closed.
   */
  private void respond(final Encodable result) {
    // Captured before the write so the log line still names the peer after a failure. The
    // address may be null when the channel is not connected (this method can be reached from an
    // RPC callback that completes late), so String.valueOf is used instead of toString() to
    // avoid a NullPointerException that would prevent the write from being attempted at all.
    final String remoteAddress = String.valueOf(channel.remoteAddress());
    channel.writeAndFlush(result).addListener(
      new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
          if (future.isSuccess()) {
            // Parameterized form: nothing is formatted unless TRACE is actually enabled.
            logger.trace("Sent result {} to client {}", result, remoteAddress);
          } else {
            logger.error(String.format("Error sending result %s to %s; closing connection",
              result, remoteAddress), future.cause());
            channel.close();
          }
        }
      }
    );
  }
}