Package com.facebook.swift.service

Source Code of com.facebook.swift.service.ThriftMethodHandler$AsyncMethodCallFuture

/*
* Copyright (C) 2012 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License. You may obtain
* a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.facebook.swift.service;

import com.facebook.nifty.client.RequestChannel;
import com.facebook.nifty.core.RequestContext;
import com.facebook.nifty.core.RequestContexts;
import com.facebook.nifty.core.TChannelBufferInputTransport;
import com.facebook.nifty.core.TChannelBufferOutputTransport;
import com.facebook.swift.codec.ThriftCodec;
import com.facebook.swift.codec.ThriftCodecManager;
import com.facebook.swift.codec.internal.TProtocolReader;
import com.facebook.swift.codec.internal.TProtocolWriter;
import com.facebook.swift.codec.metadata.ThriftFieldMetadata;
import com.facebook.swift.codec.metadata.ThriftParameterInjection;
import com.facebook.swift.codec.metadata.ThriftType;
import com.facebook.swift.service.metadata.ThriftMethodMetadata;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TProtocol;
import org.jboss.netty.buffer.ChannelBuffer;
import org.weakref.jmx.Managed;

import javax.annotation.Nullable;
import javax.annotation.concurrent.ThreadSafe;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.apache.thrift.TApplicationException.BAD_SEQUENCE_ID;
import static org.apache.thrift.TApplicationException.INVALID_MESSAGE_TYPE;
import static org.apache.thrift.TApplicationException.WRONG_METHOD_NAME;
import static org.apache.thrift.protocol.TMessageType.CALL;
import static org.apache.thrift.protocol.TMessageType.EXCEPTION;
import static org.apache.thrift.protocol.TMessageType.ONEWAY;
import static org.apache.thrift.protocol.TMessageType.REPLY;

@ThreadSafe
public class ThriftMethodHandler
{
    private final String name;
    private final String qualifiedName;
    private final List<ParameterHandler> parameterCodecs;
    private final ThriftCodec<Object> successCodec;
    private final Map<Short, ThriftCodec<Object>> exceptionCodecs;
    private final boolean oneway;
    private static final Executor SAME_THREAD_EXECUTOR = MoreExecutors.sameThreadExecutor();

    private final boolean invokeAsynchronously;

    public ThriftMethodHandler(ThriftMethodMetadata methodMetadata, ThriftCodecManager codecManager)
    {
        name = methodMetadata.getName();
        qualifiedName = methodMetadata.getQualifiedName();
        invokeAsynchronously = methodMetadata.isAsync();

        oneway = methodMetadata.getOneway();

        // get the thrift codecs for the parameters
        ParameterHandler[] parameters = new ParameterHandler[methodMetadata.getParameters().size()];
        for (ThriftFieldMetadata fieldMetadata : methodMetadata.getParameters()) {
            ThriftParameterInjection parameter = (ThriftParameterInjection) fieldMetadata.getInjections().get(0);

            ParameterHandler handler = new ParameterHandler(
                    fieldMetadata.getId(),
                    fieldMetadata.getName(),
                    (ThriftCodec<Object>) codecManager.getCodec(fieldMetadata.getThriftType()));

            parameters[parameter.getParameterIndex()] = handler;
        }
        parameterCodecs = ImmutableList.copyOf(parameters);

        // get the thrift codecs for the exceptions
        ImmutableMap.Builder<Short, ThriftCodec<Object>> exceptions = ImmutableMap.builder();
        for (Map.Entry<Short, ThriftType> entry : methodMetadata.getExceptions().entrySet()) {
            exceptions.put(entry.getKey(), (ThriftCodec<Object>) codecManager.getCodec(entry.getValue()));
        }
        exceptionCodecs = exceptions.build();

        // get the thrift codec for the return value
        successCodec = (ThriftCodec<Object>) codecManager.getCodec(methodMetadata.getReturnType());
    }

    @Managed
    public String getName()
    {
        return name;
    }

    public String getQualifiedName()
    {
        return qualifiedName;
    }

    public Object invoke(
            final RequestChannel channel,
            final TChannelBufferInputTransport inputTransport,
            final TChannelBufferOutputTransport outputTransport,
            final TProtocol inputProtocol,
            final TProtocol outputProtocol,
            final int sequenceId,
            final ClientContextChain contextChain,
            final Object... args)
            throws Exception
    {
        if (invokeAsynchronously)
        {
            // This method declares a Future return value: run it asynchronously
            return asynchronousInvoke(channel, inputTransport, outputTransport, inputProtocol, outputProtocol, sequenceId, contextChain, args);
        }
        else
        {
            try {
                // This method declares an immediate return value: run it synchronously
                return synchronousInvoke(channel, inputTransport, outputTransport, inputProtocol, outputProtocol, sequenceId, contextChain, args);
            }
            finally {
                contextChain.done();
            }
        }
    }

    private Object synchronousInvoke(
            RequestChannel channel,
            TChannelBufferInputTransport inputTransport,
            TChannelBufferOutputTransport outputTransport,
            TProtocol inputProtocol,
            TProtocol outputProtocol,
            int sequenceId,
            ClientContextChain contextChain,
            Object[] args)
            throws Exception
    {
        Object results = null;

        // write request
        contextChain.preWrite(args);
        outputTransport.resetOutputBuffer();
        writeArguments(outputProtocol, sequenceId, args);
        // Don't need to copy the output buffer for sync case
        ChannelBuffer requestBuffer = outputTransport.getOutputBuffer();
        contextChain.postWrite(args);

        if (!this.oneway) {
            ChannelBuffer responseBuffer;

            try {
                responseBuffer = SyncClientHelpers.sendSynchronousTwoWayMessage(channel, requestBuffer);
            } catch (Exception e) {
                contextChain.preReadException(e);
                throw e;
            }

            // read results
            contextChain.preRead();
            try {
                inputTransport.setInputBuffer(responseBuffer);
                waitForResponse(inputProtocol, sequenceId);
                results = readResponse(inputProtocol);
                contextChain.postRead(results);
            } catch (Exception e) {
                contextChain.postReadException(e);
                throw e;
            }
        } else {
            try {
                SyncClientHelpers.sendSynchronousOneWayMessage(channel, requestBuffer);
            } catch (Exception e) {
                throw e;
            }
        }

        return results;
    }

    public ListenableFuture<Object> asynchronousInvoke(
            final RequestChannel channel,
            final TChannelBufferInputTransport inputTransport,
            final TChannelBufferOutputTransport outputTransport,
            final TProtocol inputProtocol,
            final TProtocol outputProtocol,
            final int sequenceId,
            final ClientContextChain contextChain,
            final Object[] args)
        throws Exception
    {
        final AsyncMethodCallFuture<Object> future = AsyncMethodCallFuture.create(contextChain);
        final RequestContext requestContext = RequestContexts.getCurrentContext();

        contextChain.preWrite(args);
        outputTransport.resetOutputBuffer();
        writeArguments(outputProtocol, sequenceId, args);
        ChannelBuffer requestBuffer = outputTransport.getOutputBuffer().copy();
        contextChain.postWrite(args);

        // send message and setup listener to handle the response
        channel.sendAsynchronousRequest(requestBuffer, false, new RequestChannel.Listener() {
            @Override
            public void onRequestSent() {
                if (oneway) {
                    try {
                        future.set(null);
                    }
                    catch (Exception e) {
                        future.setException(e);
                    }
                }
            }

            @Override
            public void onResponseReceived(ChannelBuffer message) {
                RequestContext oldRequestContext = RequestContexts.getCurrentContext();
                RequestContexts.setCurrentContext(requestContext);
                try {
                    contextChain.preRead();
                    inputTransport.setInputBuffer(message);
                    waitForResponse(inputProtocol, sequenceId);
                    Object results = readResponse(inputProtocol);
                    contextChain.postRead(results);
                    future.set(results);
                }
                catch (Exception e) {
                    contextChain.postReadException(e);
                    future.setException(e);
                }
                finally {
                    RequestContexts.setCurrentContext(oldRequestContext);
                }
            }

            @Override
            public void onChannelError(TException e) {
                RequestContext oldRequestContext = RequestContexts.getCurrentContext();
                RequestContexts.setCurrentContext(requestContext);
                try {
                    contextChain.preReadException(e);
                    future.setException(e);
                } finally {
                    RequestContexts.setCurrentContext(oldRequestContext);
                }
            }
        });

        return future;
    }

    private Object readResponse(TProtocol in)
            throws Exception
    {
        TProtocolReader reader = new TProtocolReader(in);
        reader.readStructBegin();
        Object results = null;
        Exception exception = null;
        while (reader.nextField()) {
            if (reader.getFieldId() == 0) {
                results = reader.readField(successCodec);
            }
            else {
                ThriftCodec<Object> exceptionCodec = exceptionCodecs.get(reader.getFieldId());
                if (exceptionCodec != null) {
                    exception = (Exception) reader.readField(exceptionCodec);
                }
                else {
                    reader.skipFieldData();
                }
            }
        }
        reader.readStructEnd();
        in.readMessageEnd();

        if (exception != null) {
            throw exception;
        }

        if (successCodec.getType() == ThriftType.VOID) {
            // TODO: check for non-null return from a void function?
            return null;
        }

        if (results == null) {
            throw new TApplicationException(TApplicationException.MISSING_RESULT, name + " failed: unknown result");
        }
        return results;
    }

    private void writeArguments(TProtocol out, int sequenceId, Object[] args)
            throws Exception
    {
        // Note that though setting message type to ONEWAY can be helpful when looking at packet
        // captures, some clients always send CALL and so servers are forced to rely on the "oneway"
        // attribute on thrift method in the interface definition, rather than checking the message
        // type.
        out.writeMessageBegin(new TMessage(name, oneway ? ONEWAY : CALL, sequenceId));

        // write the parameters
        TProtocolWriter writer = new TProtocolWriter(out);
        writer.writeStructBegin(name + "_args");
        for (int i = 0; i < args.length; i++) {
            Object value = args[i];
            ParameterHandler parameter = parameterCodecs.get(i);
            writer.writeField(parameter.getName(), parameter.getId(), parameter.getCodec(), value);
        }
        writer.writeStructEnd();

        out.writeMessageEnd();
        out.getTransport().flush();
    }

    private void waitForResponse(TProtocol in, int sequenceId)
            throws TException
    {
        TMessage message = in.readMessageBegin();
        if (message.type == EXCEPTION) {
            TApplicationException exception = TApplicationException.read(in);
            in.readMessageEnd();
            throw exception;
        }
        if (message.type != REPLY) {
            throw new TApplicationException(INVALID_MESSAGE_TYPE,
                                            "Received invalid message type " + message.type + " from server");
        }
        if (!message.name.equals(this.name)) {
            throw new TApplicationException(WRONG_METHOD_NAME,
                                            "Wrong method name in reply: expected " + this.name + " but received " + message.name);
        }
        if (message.seqid != sequenceId) {
            throw new TApplicationException(BAD_SEQUENCE_ID, name + " failed: out of sequence response");
        }
    }

    private static final class ParameterHandler
    {
        private final short id;
        private final String name;
        private final ThriftCodec<Object> codec;

        private ParameterHandler(short id, String name, ThriftCodec<Object> codec)
        {
            this.id = id;
            this.name = name;
            this.codec = codec;
        }

        public short getId()
        {
            return id;
        }

        public String getName()
        {
            return name;
        }

        public ThriftCodec<Object> getCodec()
        {
            return codec;
        }
    }

    private static final class AsyncMethodCallFuture<T> extends AbstractFuture<T>
    {
        private final ClientContextChain contextChain;

        public static <T> AsyncMethodCallFuture<T> create(ClientContextChain contextChain)
        {
            return new AsyncMethodCallFuture<>(contextChain);
        }

        private AsyncMethodCallFuture(ClientContextChain contextChain) {
            this.contextChain = contextChain;
        }

        @Override
        public boolean set(@Nullable T value)
        {
            contextChain.done();
            return super.set(value);
        }

        @Override
        public boolean setException(Throwable throwable)
        {
            contextChain.done();
            return super.setException(throwable);
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning)
        {
            // Async call futures represent requests running on some other service,
            // there is no way to cancel the request once it has been sent.
            return false;
        }
    }
}
TOP

Related Classes of com.facebook.swift.service.ThriftMethodHandler$AsyncMethodCallFuture

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., and is owned by ORACLE Inc. Contact coftware#gmail.com.