Package org.apache.hadoop.ipc

Source Code of org.apache.hadoop.ipc.TestIPC$IOEOnWriteWritable

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hadoop.ipc;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;

import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

import javax.net.SocketFactory;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.retry.RetryPolicies;
import org.apache.hadoop.io.retry.RetryProxy;
import org.apache.hadoop.ipc.Client.ConnectionId;
import org.apache.hadoop.ipc.RPC.RpcKind;
import org.apache.hadoop.ipc.Server.Connection;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
import org.apache.hadoop.net.ConnectTimeoutException;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.util.StringUtils;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import com.google.common.primitives.Bytes;
import com.google.common.primitives.Ints;

/** Unit tests for IPC. */
public class TestIPC {
  public static final Log LOG =
    LogFactory.getLog(TestIPC.class);
 
  final private static Configuration conf = new Configuration();
  final static private int PING_INTERVAL = 1000;
  final static private int MIN_SLEEP_TIME = 1000;
  /**
   * Flag used to turn off the fault injection behavior
   * of the various writables.
   **/
  static boolean WRITABLE_FAULTS_ENABLED = true;
  static int WRITABLE_FAULTS_SLEEP = 0;
 
  static {
    Client.setPingInterval(conf, PING_INTERVAL);
  }

  private static final Random RANDOM = new Random();

  private static final String ADDRESS = "0.0.0.0";

  /** Directory where we can count open file descriptors on Linux */
  private static final File FD_DIR = new File("/proc/self/fd");

  private static class TestServer extends Server {
    // Tests can set callListener to run a piece of code each time the server
    // receives a call.  This code executes on the server thread, so it has
    // visibility of that thread's thread-local storage.
    private Runnable callListener;
    private boolean sleep;
    private Class<? extends Writable> responseClass;

    public TestServer(int handlerCount, boolean sleep) throws IOException {
      this(handlerCount, sleep, LongWritable.class, null);
    }
   
    public TestServer(int handlerCount, boolean sleep,
        Class<? extends Writable> paramClass,
        Class<? extends Writable> responseClass)
      throws IOException {
      super(ADDRESS, 0, paramClass, handlerCount, conf);
      this.sleep = sleep;
      this.responseClass = responseClass;
    }

    @Override
    public Writable call(RPC.RpcKind rpcKind, String protocol, Writable param,
        long receiveTime) throws IOException {
      if (sleep) {
        // sleep a bit
        try {
          Thread.sleep(RANDOM.nextInt(PING_INTERVAL) + MIN_SLEEP_TIME);
        } catch (InterruptedException e) {}
      }
      if (callListener != null) {
        callListener.run();
      }
      if (responseClass != null) {
        try {
          return responseClass.newInstance();
        } catch (Exception e) {
          throw new RuntimeException(e);
       
      } else {
        return param;                               // echo param as result
      }
    }
  }

  private static class SerialCaller extends Thread {
    private Client client;
    private InetSocketAddress server;
    private int count;
    private boolean failed;

    public SerialCaller(Client client, InetSocketAddress server, int count) {
      this.client = client;
      this.server = server;
      this.count = count;
    }

    @Override
    public void run() {
      for (int i = 0; i < count; i++) {
        try {
          LongWritable param = new LongWritable(RANDOM.nextLong());
          LongWritable value =
            (LongWritable)client.call(param, server, null, null, 0, conf);
          if (!param.equals(value)) {
            LOG.fatal("Call failed!");
            failed = true;
            break;
          }
        } catch (Exception e) {
          LOG.fatal("Caught: " + StringUtils.stringifyException(e));
          failed = true;
        }
      }
    }
  }

  /**
   * A RpcInvocationHandler instance for test. Its invoke function uses the same
   * {@link Client} instance, and will fail the first totalRetry times (by
   * throwing an IOException).
   */
  private static class TestInvocationHandler implements RpcInvocationHandler {
    private static int retry = 0;
    private final Client client;
    private final Server server;
    private final int total;
   
    TestInvocationHandler(Client client, Server server, int total) {
      this.client = client;
      this.server = server;
      this.total = total;
    }
   
    @Override
    public Object invoke(Object proxy, Method method, Object[] args)
        throws Throwable {
      LongWritable param = new LongWritable(RANDOM.nextLong());
      LongWritable value = (LongWritable) client.call(param,
          NetUtils.getConnectAddress(server), null, null, 0, conf);
      if (retry++ < total) {
        throw new IOException("Fake IOException");
      } else {
        return value;
      }
    }

    @Override
    public void close() throws IOException {}
   
    @Override
    public ConnectionId getConnectionId() {
      return null;
    }
  }
 
  @Test(timeout=60000)
  public void testSerial() throws IOException, InterruptedException {
    internalTestSerial(3, false, 2, 5, 100);
    internalTestSerial(3, true, 2, 5, 10);
  }

  public void internalTestSerial(int handlerCount, boolean handlerSleep,
                         int clientCount, int callerCount, int callCount)
    throws IOException, InterruptedException {
    Server server = new TestServer(handlerCount, handlerSleep);
    InetSocketAddress addr = NetUtils.getConnectAddress(server);
    server.start();

    Client[] clients = new Client[clientCount];
    for (int i = 0; i < clientCount; i++) {
      clients[i] = new Client(LongWritable.class, conf);
    }
   
    SerialCaller[] callers = new SerialCaller[callerCount];
    for (int i = 0; i < callerCount; i++) {
      callers[i] = new SerialCaller(clients[i%clientCount], addr, callCount);
      callers[i].start();
    }
    for (int i = 0; i < callerCount; i++) {
      callers[i].join();
      assertFalse(callers[i].failed);
    }
    for (int i = 0; i < clientCount; i++) {
      clients[i].stop();
    }
    server.stop();
  }
 
  @Test(timeout=60000)
  public void testStandAloneClient() throws IOException {
    Client client = new Client(LongWritable.class, conf);
    InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10);
    try {
      client.call(new LongWritable(RANDOM.nextLong()),
              address, null, null, 0, conf);
      fail("Expected an exception to have been thrown");
    } catch (IOException e) {
      String message = e.getMessage();
      String addressText = address.getHostName() + ":" + address.getPort();
      assertTrue("Did not find "+addressText+" in "+message,
              message.contains(addressText));
      Throwable cause=e.getCause();
      assertNotNull("No nested exception in "+e,cause);
      String causeText=cause.getMessage();
      assertTrue("Did not find " + causeText + " in " + message,
              message.contains(causeText));
    }
  }
 
  static void maybeThrowIOE() throws IOException {
    if (WRITABLE_FAULTS_ENABLED) {
      maybeSleep();
      throw new IOException("Injected fault");
    }
  }

  static void maybeThrowRTE() {
    if (WRITABLE_FAULTS_ENABLED) {
      maybeSleep();
      throw new RuntimeException("Injected fault");
    }
  }

  private static void maybeSleep() {
    if (WRITABLE_FAULTS_SLEEP > 0) {
      try {
        Thread.sleep(WRITABLE_FAULTS_SLEEP);
      } catch (InterruptedException ie) {
      }
    }
  }

  @SuppressWarnings("unused")
  private static class IOEOnReadWritable extends LongWritable {
    public IOEOnReadWritable() {}

    @Override
    public void readFields(DataInput in) throws IOException {
      super.readFields(in);
      maybeThrowIOE();
    }
  }
 
  @SuppressWarnings("unused")
  private static class RTEOnReadWritable extends LongWritable {
    public RTEOnReadWritable() {}
   
    @Override
    public void readFields(DataInput in) throws IOException {
      super.readFields(in);
      maybeThrowRTE();
    }
  }
 
  @SuppressWarnings("unused")
  private static class IOEOnWriteWritable extends LongWritable {
    public IOEOnWriteWritable() {}

    @Override
    public void write(DataOutput out) throws IOException {
      super.write(out);
      maybeThrowIOE();
    }
  }

  @SuppressWarnings("unused")
  private static class RTEOnWriteWritable extends LongWritable {
    public RTEOnWriteWritable() {}

    @Override
    public void write(DataOutput out) throws IOException {
      super.write(out);
      maybeThrowRTE();
    }
  }
 
  /**
   * Generic test case for exceptions thrown at some point in the IPC
   * process.
   *
   * @param clientParamClass - client writes this writable for parameter
   * @param serverParamClass - server reads this writable for parameter
   * @param serverResponseClass - server writes this writable for response
   * @param clientResponseClass - client reads this writable for response
   */
  private void doErrorTest(
      Class<? extends LongWritable> clientParamClass,
      Class<? extends LongWritable> serverParamClass,
      Class<? extends LongWritable> serverResponseClass,
      Class<? extends LongWritable> clientResponseClass)
      throws IOException, InstantiationException, IllegalAccessException {
   
    // start server
    Server server = new TestServer(1, false,
        serverParamClass, serverResponseClass);
    InetSocketAddress addr = NetUtils.getConnectAddress(server);
    server.start();

    // start client
    WRITABLE_FAULTS_ENABLED = true;
    Client client = new Client(clientResponseClass, conf);
    try {
      LongWritable param = clientParamClass.newInstance();

      try {
        client.call(param, addr, null, null, 0, conf);
        fail("Expected an exception to have been thrown");
      } catch (Throwable t) {
        assertExceptionContains(t, "Injected fault");
      }
     
      // Doing a second call with faults disabled should return fine --
      // ie the internal state of the client or server should not be broken
      // by the failed call
      WRITABLE_FAULTS_ENABLED = false;
      client.call(param, addr, null, null, 0, conf);
     
    } finally {
      server.stop();
    }
  }

  @Test(timeout=60000)
  public void testIOEOnClientWriteParam() throws Exception {
    doErrorTest(IOEOnWriteWritable.class,
        LongWritable.class,
        LongWritable.class,
        LongWritable.class);
  }
 
  @Test(timeout=60000)
  public void testRTEOnClientWriteParam() throws Exception {
    doErrorTest(RTEOnWriteWritable.class,
        LongWritable.class,
        LongWritable.class,
        LongWritable.class);
  }

  @Test(timeout=60000)
  public void testIOEOnServerReadParam() throws Exception {
    doErrorTest(LongWritable.class,
        IOEOnReadWritable.class,
        LongWritable.class,
        LongWritable.class);
  }
 
  @Test(timeout=60000)
  public void testRTEOnServerReadParam() throws Exception {
    doErrorTest(LongWritable.class,
        RTEOnReadWritable.class,
        LongWritable.class,
        LongWritable.class);
  }

 
  @Test(timeout=60000)
  public void testIOEOnServerWriteResponse() throws Exception {
    doErrorTest(LongWritable.class,
        LongWritable.class,
        IOEOnWriteWritable.class,
        LongWritable.class);
  }
 
  @Test(timeout=60000)
  public void testRTEOnServerWriteResponse() throws Exception {
    doErrorTest(LongWritable.class,
        LongWritable.class,
        RTEOnWriteWritable.class,
        LongWritable.class);
  }
 
  @Test(timeout=60000)
  public void testIOEOnClientReadResponse() throws Exception {
    doErrorTest(LongWritable.class,
        LongWritable.class,
        LongWritable.class,
        IOEOnReadWritable.class);
  }
 
  @Test(timeout=60000)
  public void testRTEOnClientReadResponse() throws Exception {
    doErrorTest(LongWritable.class,
        LongWritable.class,
        LongWritable.class,
        RTEOnReadWritable.class);
  }
 
  /**
   * Test case that fails a write, but only after taking enough time
   * that a ping should have been sent. This is a reproducer for a
   * deadlock seen in one iteration of HADOOP-6762.
   */
  @Test(timeout=60000)
  public void testIOEOnWriteAfterPingClient() throws Exception {
    // start server
    Client.setPingInterval(conf, 100);

    try {
      WRITABLE_FAULTS_SLEEP = 1000;
      doErrorTest(IOEOnWriteWritable.class,
          LongWritable.class,
          LongWritable.class,
          LongWritable.class);
    } finally {
      WRITABLE_FAULTS_SLEEP = 0;
    }
  }
 
  private static void assertExceptionContains(
      Throwable t, String substring) {
    String msg = StringUtils.stringifyException(t);
    assertTrue("Exception should contain substring '" + substring + "':\n" +
        msg, msg.contains(substring));
    LOG.info("Got expected exception", t);
  }
 
  /**
   * Test that, if the socket factory throws an IOE, it properly propagates
   * to the client.
   */
  @Test(timeout=60000)
  public void testSocketFactoryException() throws IOException {
    SocketFactory mockFactory = mock(SocketFactory.class);
    doThrow(new IOException("Injected fault")).when(mockFactory).createSocket();
    Client client = new Client(LongWritable.class, conf, mockFactory);
   
    InetSocketAddress address = new InetSocketAddress("127.0.0.1", 10);
    try {
      client.call(new LongWritable(RANDOM.nextLong()),
              address, null, null, 0, conf);
      fail("Expected an exception to have been thrown");
    } catch (IOException e) {
      assertTrue(e.getMessage().contains("Injected fault"));
    }
  }

  /**
   * Test that, if a RuntimeException is thrown after creating a socket
   * but before successfully connecting to the IPC server, that the
   * failure is handled properly. This is a regression test for
   * HADOOP-7428.
   */
  @Test(timeout=60000)
  public void testRTEDuringConnectionSetup() throws IOException {
    // Set up a socket factory which returns sockets which
    // throw an RTE when setSoTimeout is called.
    SocketFactory spyFactory = spy(NetUtils.getDefaultSocketFactory(conf));
    Mockito.doAnswer(new Answer<Socket>() {
      @Override
      public Socket answer(InvocationOnMock invocation) throws Throwable {
        Socket s = spy((Socket)invocation.callRealMethod());
        doThrow(new RuntimeException("Injected fault")).when(s)
          .setSoTimeout(anyInt());
        return s;
      }
    }).when(spyFactory).createSocket();
     
    Server server = new TestServer(1, true);
    server.start();
    try {
      // Call should fail due to injected exception.
      InetSocketAddress address = NetUtils.getConnectAddress(server);
      Client client = new Client(LongWritable.class, conf, spyFactory);
      try {
        client.call(new LongWritable(RANDOM.nextLong()),
                address, null, null, 0, conf);
        fail("Expected an exception to have been thrown");
      } catch (Exception e) {
        LOG.info("caught expected exception", e);
        assertTrue(StringUtils.stringifyException(e).contains(
            "Injected fault"));
      }
      // Resetting to the normal socket behavior should succeed
      // (i.e. it should not have cached a half-constructed connection)
 
      Mockito.reset(spyFactory);
      client.call(new LongWritable(RANDOM.nextLong()),
          address, null, null, 0, conf);
    } finally {
      server.stop();
    }
  }
 
  @Test(timeout=60000)
  public void testIpcTimeout() throws IOException {
    // start server
    Server server = new TestServer(1, true);
    InetSocketAddress addr = NetUtils.getConnectAddress(server);
    server.start();

    // start client
    Client client = new Client(LongWritable.class, conf);
    // set timeout to be less than MIN_SLEEP_TIME
    try {
      client.call(new LongWritable(RANDOM.nextLong()),
              addr, null, null, MIN_SLEEP_TIME/2, conf);
      fail("Expected an exception to have been thrown");
    } catch (SocketTimeoutException e) {
      LOG.info("Get a SocketTimeoutException ", e);
    }
    // set timeout to be bigger than 3*ping interval
    client.call(new LongWritable(RANDOM.nextLong()),
        addr, null, null, 3*PING_INTERVAL+MIN_SLEEP_TIME, conf);
  }

  @Test(timeout=60000)
  public void testIpcConnectTimeout() throws IOException {
    // start server
    Server server = new TestServer(1, true);
    InetSocketAddress addr = NetUtils.getConnectAddress(server);
    //Intentionally do not start server to get a connection timeout

    // start client
    Client.setConnectTimeout(conf, 100);
    Client client = new Client(LongWritable.class, conf);
    // set the rpc timeout to twice the MIN_SLEEP_TIME
    try {
      client.call(new LongWritable(RANDOM.nextLong()),
              addr, null, null, MIN_SLEEP_TIME*2, conf);
      fail("Expected an exception to have been thrown");
    } catch (SocketTimeoutException e) {
      LOG.info("Get a SocketTimeoutException ", e);
    }
  }
 
  /**
   * Check service class byte in IPC header is correct on wire.
   */
  @Test(timeout=60000)
  public void testIpcWithServiceClass() throws IOException {
    // start server
    Server server = new TestServer(5, false);
    InetSocketAddress addr = NetUtils.getConnectAddress(server);
    server.start();

    // start client
    Client.setConnectTimeout(conf, 10000);

    callAndVerify(server, addr, 0, true);
    // Service Class is low to -128 as byte on wire.
    // -128 shouldn't be casted on wire but -129 should.
    callAndVerify(server, addr, -128, true);
    callAndVerify(server, addr, -129, false);

    // Service Class is up to 127.
    // 127 shouldn't be casted on wire but 128 should.
    callAndVerify(server, addr, 127, true);
    callAndVerify(server, addr, 128, false);

    server.stop();
  }

  /**
   * Make a call from a client and verify if header info is changed in server side
   */
  private void callAndVerify(Server server, InetSocketAddress addr,
      int serviceClass, boolean noChanged) throws IOException{
    Client client = new Client(LongWritable.class, conf);

    client.call(new LongWritable(RANDOM.nextLong()),
        addr, null, null, MIN_SLEEP_TIME, serviceClass, conf);
    Connection connection = server.getConnections().get(0);
    int serviceClass2 = connection.getServiceClass();
    assertFalse(noChanged ^ serviceClass == serviceClass2);
    client.stop();
  }

  /**
   * Check that file descriptors aren't leaked by starting
   * and stopping IPC servers.
   */
  @Test(timeout=60000)
  public void testSocketLeak() throws IOException {
    Assume.assumeTrue(FD_DIR.exists()); // only run on Linux

    long startFds = countOpenFileDescriptors();
    for (int i = 0; i < 50; i++) {
      Server server = new TestServer(1, true);
      server.start();
      server.stop();
    }
    long endFds = countOpenFileDescriptors();
   
    assertTrue("Leaked " + (endFds - startFds) + " file descriptors",
        endFds - startFds < 20);
  }
 
  private long countOpenFileDescriptors() {
    return FD_DIR.list().length;
  }

  @Test(timeout=60000)
  public void testIpcFromHadoop_0_18_13() throws IOException {
    doIpcVersionTest(NetworkTraces.HADOOP_0_18_3_RPC_DUMP,
        NetworkTraces.RESPONSE_TO_HADOOP_0_18_3_RPC);
  }
 
  @Test(timeout=60000)
  public void testIpcFromHadoop0_20_3() throws IOException {
    doIpcVersionTest(NetworkTraces.HADOOP_0_20_3_RPC_DUMP,
        NetworkTraces.RESPONSE_TO_HADOOP_0_20_3_RPC);
  }
 
  @Test(timeout=60000)
  public void testIpcFromHadoop0_21_0() throws IOException {
    doIpcVersionTest(NetworkTraces.HADOOP_0_21_0_RPC_DUMP,
        NetworkTraces.RESPONSE_TO_HADOOP_0_21_0_RPC);
  }
 
  @Test(timeout=60000)
  public void testHttpGetResponse() throws IOException {
    doIpcVersionTest("GET / HTTP/1.0\r\n\r\n".getBytes(),
        Server.RECEIVED_HTTP_REQ_RESPONSE.getBytes());
  }
 
  @Test(timeout=60000)
  public void testConnectionRetriesOnSocketTimeoutExceptions() throws IOException {
    Configuration conf = new Configuration();
    // set max retries to 0
    conf.setInt(
      CommonConfigurationKeysPublic.IPC_CLIENT_CONNECT_MAX_RETRIES_ON_SOCKET_TIMEOUTS_KEY,
      0);
    assertRetriesOnSocketTimeouts(conf, 1);

    // set max retries to 3
    conf.setInt(
      CommonConfigurationKeysPublic.IPC_CLIENT_CONNECT_MAX_RETRIES_ON_SOCKET_TIMEOUTS_KEY,
      3);
    assertRetriesOnSocketTimeouts(conf, 4);
  }

  private static class CallInfo {
    int id = RpcConstants.INVALID_CALL_ID;
    int retry = RpcConstants.INVALID_RETRY_COUNT;
  }

  /**
   * Test if
   * (1) the rpc server uses the call id/retry provided by the rpc client, and
   * (2) the rpc client receives the same call id/retry from the rpc server.
   */
  @Test(timeout=60000)
  public void testCallIdAndRetry() throws IOException {
    final CallInfo info = new CallInfo();

    // Override client to store the call info and check response
    final Client client = new Client(LongWritable.class, conf) {
      @Override
      Call createCall(RpcKind rpcKind, Writable rpcRequest) {
        final Call call = super.createCall(rpcKind, rpcRequest);
        info.id = call.id;
        info.retry = call.retry;
        return call;
      }
     
      @Override
      void checkResponse(RpcResponseHeaderProto header) throws IOException {
        super.checkResponse(header);
        Assert.assertEquals(info.id, header.getCallId());
        Assert.assertEquals(info.retry, header.getRetryCount());
      }
    };

    // Attach a listener that tracks every call received by the server.
    final TestServer server = new TestServer(1, false);
    server.callListener = new Runnable() {
      @Override
      public void run() {
        Assert.assertEquals(info.id, Server.getCallId());
        Assert.assertEquals(info.retry, Server.getCallRetryCount());
      }
    };

    try {
      InetSocketAddress addr = NetUtils.getConnectAddress(server);
      server.start();
      final SerialCaller caller = new SerialCaller(client, addr, 10);
      caller.run();
      assertFalse(caller.failed);
    } finally {
      client.stop();
      server.stop();
    }
  }
 
  /** A dummy protocol */
  private interface DummyProtocol {
    public void dummyRun();
  }
 
  /**
   * Test the retry count while used in a retry proxy.
   */
  @Test(timeout=60000)
  public void testRetryProxy() throws IOException {
    final Client client = new Client(LongWritable.class, conf);
   
    final TestServer server = new TestServer(1, false);
    server.callListener = new Runnable() {
      private int retryCount = 0;
      @Override
      public void run() {
        Assert.assertEquals(retryCount++, Server.getCallRetryCount());
      }
    };

    // try more times, so it is easier to find race condition bug
    // 10000 times runs about 6s on a core i7 machine
    final int totalRetry = 10000;
    DummyProtocol proxy = (DummyProtocol) Proxy.newProxyInstance(
        DummyProtocol.class.getClassLoader(),
        new Class[] { DummyProtocol.class }, new TestInvocationHandler(client,
            server, totalRetry));
    DummyProtocol retryProxy = (DummyProtocol) RetryProxy.create(
        DummyProtocol.class, proxy, RetryPolicies.RETRY_FOREVER);
   
    try {
      server.start();
      retryProxy.dummyRun();
      Assert.assertEquals(TestInvocationHandler.retry, totalRetry + 1);
    } finally {
      Client.setCallIdAndRetryCount(0, 0);
      client.stop();
      server.stop();
    }
  }
 
  /**
   * Test if the rpc server gets the default retry count (0) from client.
   */
  @Test(timeout=60000)
  public void testInitialCallRetryCount() throws IOException {
    // Override client to store the call id
    final Client client = new Client(LongWritable.class, conf);

    // Attach a listener that tracks every call ID received by the server.
    final TestServer server = new TestServer(1, false);
    server.callListener = new Runnable() {
      @Override
      public void run() {
        // we have not set the retry count for the client, thus on the server
        // side we should see retry count as 0
        Assert.assertEquals(0, Server.getCallRetryCount());
      }
    };

    try {
      InetSocketAddress addr = NetUtils.getConnectAddress(server);
      server.start();
      final SerialCaller caller = new SerialCaller(client, addr, 10);
      caller.run();
      assertFalse(caller.failed);
    } finally {
      client.stop();
      server.stop();
    }
  }
 
  /**
   * Test if the rpc server gets the retry count from client.
   */
  @Test(timeout=60000)
  public void testCallRetryCount() throws IOException {
    final int retryCount = 255;
    // Override client to store the call id
    final Client client = new Client(LongWritable.class, conf);
    Client.setCallIdAndRetryCount(Client.nextCallId(), 255);

    // Attach a listener that tracks every call ID received by the server.
    final TestServer server = new TestServer(1, false);
    server.callListener = new Runnable() {
      @Override
      public void run() {
        // we have not set the retry count for the client, thus on the server
        // side we should see retry count as 0
        Assert.assertEquals(retryCount, Server.getCallRetryCount());
      }
    };

    try {
      InetSocketAddress addr = NetUtils.getConnectAddress(server);
      server.start();
      final SerialCaller caller = new SerialCaller(client, addr, 10);
      caller.run();
      assertFalse(caller.failed);
    } finally {
      client.stop();
      server.stop();
    }
  }

  /**
   * Tests that client generates a unique sequential call ID for each RPC call,
   * even if multiple threads are using the same client.
* @throws InterruptedException
   */
  @Test(timeout=60000)
  public void testUniqueSequentialCallIds()
      throws IOException, InterruptedException {
    int serverThreads = 10, callerCount = 100, perCallerCallCount = 100;
    TestServer server = new TestServer(serverThreads, false);

    // Attach a listener that tracks every call ID received by the server.  This
    // list must be synchronized, because multiple server threads will add to it.
    final List<Integer> callIds = Collections.synchronizedList(
      new ArrayList<Integer>());
    server.callListener = new Runnable() {
      @Override
      public void run() {
        callIds.add(Server.getCallId());
      }
    };

    Client client = new Client(LongWritable.class, conf);

    try {
      InetSocketAddress addr = NetUtils.getConnectAddress(server);
      server.start();
      SerialCaller[] callers = new SerialCaller[callerCount];
      for (int i = 0; i < callerCount; ++i) {
        callers[i] = new SerialCaller(client, addr, perCallerCallCount);
        callers[i].start();
      }
      for (int i = 0; i < callerCount; ++i) {
        callers[i].join();
        assertFalse(callers[i].failed);
      }
    } finally {
      client.stop();
      server.stop();
    }

    int expectedCallCount = callerCount * perCallerCallCount;
    assertEquals(expectedCallCount, callIds.size());

    // It is not guaranteed that the server executes requests in sequential order
    // of client call ID, so we must sort the call IDs before checking that it
    // contains every expected value.
    Collections.sort(callIds);
    final int startID = callIds.get(0).intValue();
    for (int i = 0; i < expectedCallCount; ++i) {
      assertEquals(startID + i, callIds.get(i).intValue());
    }
  }

  private void assertRetriesOnSocketTimeouts(Configuration conf,
      int maxTimeoutRetries) throws IOException {
    SocketFactory mockFactory = Mockito.mock(SocketFactory.class);
    doThrow(new ConnectTimeoutException("fake")).when(mockFactory).createSocket();
    Client client = new Client(IntWritable.class, conf, mockFactory);
    InetSocketAddress address = new InetSocketAddress("127.0.0.1", 9090);
    try {
      client.call(new IntWritable(RANDOM.nextInt()), address, null, null, 0,
          conf);
      fail("Not throwing the SocketTimeoutException");
    } catch (SocketTimeoutException e) {
      Mockito.verify(mockFactory, Mockito.times(maxTimeoutRetries))
          .createSocket();
    }
  }
 
  private void doIpcVersionTest(
      byte[] requestData,
      byte[] expectedResponse) throws IOException {
    Server server = new TestServer(1, true);
    InetSocketAddress addr = NetUtils.getConnectAddress(server);
    server.start();
    Socket socket = new Socket();

    try {
      NetUtils.connect(socket, addr, 5000);
     
      OutputStream out = socket.getOutputStream();
      InputStream in = socket.getInputStream();
      out.write(requestData, 0, requestData.length);
      out.flush();
      ByteArrayOutputStream baos = new ByteArrayOutputStream();
      IOUtils.copyBytes(in, baos, 256);
     
      byte[] responseData = baos.toByteArray();
     
      assertEquals(
          StringUtils.byteToHexString(expectedResponse),
          StringUtils.byteToHexString(responseData));
    } finally {
      IOUtils.closeSocket(socket);
      server.stop();
    }
  }
 
  /**
   * Convert a string of lines that look like:
   *   "68 72 70 63 02 00 00 00  82 00 1d 6f 72 67 2e 61 hrpc.... ...org.a"
   * .. into an array of bytes.
   */
  private static byte[] hexDumpToBytes(String hexdump) {
    final int LAST_HEX_COL = 3 * 16;
   
    StringBuilder hexString = new StringBuilder();
   
    for (String line : hexdump.toUpperCase().split("\n")) {
      hexString.append(line.substring(0, LAST_HEX_COL).replace(" ", ""));
    }
    return StringUtils.hexStringToByte(hexString.toString());
  }
 
  /**
   * Wireshark traces collected from various client versions. These enable
   * us to test that old versions of the IPC stack will receive the correct
   * responses so that they will throw a meaningful error message back
   * to the user.
   */
  private static abstract class NetworkTraces {
    /**
     * Wireshark dump of an RPC request from Hadoop 0.18.3
     */
    final static byte[] HADOOP_0_18_3_RPC_DUMP =
      hexDumpToBytes(
      "68 72 70 63 02 00 00 00  82 00 1d 6f 72 67 2e 61 hrpc.... ...org.a\n" +
      "70 61 63 68 65 2e 68 61  64 6f 6f 70 2e 69 6f 2e pache.ha doop.io.\n" +
      "57 72 69 74 61 62 6c 65  00 30 6f 72 67 2e 61 70 Writable .0org.ap\n" +
      "61 63 68 65 2e 68 61 64  6f 6f 70 2e 69 6f 2e 4f ache.had oop.io.O\n" +
      "62 6a 65 63 74 57 72 69  74 61 62 6c 65 24 4e 75 bjectWri table$Nu\n" +
      "6c 6c 49 6e 73 74 61 6e  63 65 00 2f 6f 72 67 2e llInstan ce./org.\n" +
      "61 70 61 63 68 65 2e 68  61 64 6f 6f 70 2e 73 65 apache.h adoop.se\n" +
      "63 75 72 69 74 79 2e 55  73 65 72 47 72 6f 75 70 curity.U serGroup\n" +
      "49 6e 66 6f 72 6d 61 74  69 6f 6e 00 00 00 6c 00 Informat ion...l.\n" +
      "00 00 00 00 12 67 65 74  50 72 6f 74 6f 63 6f 6c .....get Protocol\n" +
      "56 65 72 73 69 6f 6e 00  00 00 02 00 10 6a 61 76 Version. .....jav\n" +
      "61 2e 6c 61 6e 67 2e 53  74 72 69 6e 67 00 2e 6f a.lang.S tring..o\n" +
      "72 67 2e 61 70 61 63 68  65 2e 68 61 64 6f 6f 70 rg.apach e.hadoop\n" +
      "2e 6d 61 70 72 65 64 2e  4a 6f 62 53 75 62 6d 69 .mapred. JobSubmi\n" +
      "73 73 69 6f 6e 50 72 6f  74 6f 63 6f 6c 00 04 6c ssionPro tocol..l\n" +
      "6f 6e 67 00 00 00 00 00  00 00 0a                ong..... ...     \n");

    final static String HADOOP0_18_ERROR_MSG =
      "Server IPC version " + RpcConstants.CURRENT_VERSION +
      " cannot communicate with client version 2";
   
    /**
     * Wireshark dump of the correct response that triggers an error message
     * on an 0.18.3 client.
     */
    final static byte[] RESPONSE_TO_HADOOP_0_18_3_RPC =
      Bytes.concat(hexDumpToBytes(
      "00 00 00 00 01 00 00 00  29 6f 72 67 2e 61 70 61 ........ )org.apa\n" +
      "63 68 65 2e 68 61 64 6f  6f 70 2e 69 70 63 2e 52 che.hado op.ipc.R\n" +
      "50 43 24 56 65 72 73 69  6f 6e 4d 69 73 6d 61 74 PC$Versi onMismat\n" +
      "63 68                                            ch               \n"),
       Ints.toByteArray(HADOOP0_18_ERROR_MSG.length()),
       HADOOP0_18_ERROR_MSG.getBytes());
   
    /**
     * Wireshark dump of an RPC request from Hadoop 0.20.3
     */
    final static byte[] HADOOP_0_20_3_RPC_DUMP =
      hexDumpToBytes(
      "68 72 70 63 03 00 00 00  79 27 6f 72 67 2e 61 70 hrpc.... y'org.ap\n" +
      "61 63 68 65 2e 68 61 64  6f 6f 70 2e 69 70 63 2e ache.had oop.ipc.\n" +
      "56 65 72 73 69 6f 6e 65  64 50 72 6f 74 6f 63 6f Versione dProtoco\n" +
      "6c 01 0a 53 54 52 49 4e  47 5f 55 47 49 04 74 6f l..STRIN G_UGI.to\n" +
      "64 64 09 04 74 6f 64 64  03 61 64 6d 07 64 69 61 dd..todd .adm.dia\n" +
      "6c 6f 75 74 05 63 64 72  6f 6d 07 70 6c 75 67 64 lout.cdr om.plugd\n" +
      "65 76 07 6c 70 61 64 6d  69 6e 05 61 64 6d 69 6e ev.lpadm in.admin\n" +
      "0a 73 61 6d 62 61 73 68  61 72 65 06 6d 72 74 65 .sambash are.mrte\n" +
      "73 74 00 00 00 6c 00 00  00 00 00 12 67 65 74 50 st...l.. ....getP\n" +
      "72 6f 74 6f 63 6f 6c 56  65 72 73 69 6f 6e 00 00 rotocolV ersion..\n" +
      "00 02 00 10 6a 61 76 61  2e 6c 61 6e 67 2e 53 74 ....java .lang.St\n" +
      "72 69 6e 67 00 2e 6f 72  67 2e 61 70 61 63 68 65 ring..or g.apache\n" +
      "2e 68 61 64 6f 6f 70 2e  6d 61 70 72 65 64 2e 4a .hadoop. mapred.J\n" +
      "6f 62 53 75 62 6d 69 73  73 69 6f 6e 50 72 6f 74 obSubmis sionProt\n" +
      "6f 63 6f 6c 00 04 6c 6f  6e 67 00 00 00 00 00 00 ocol..lo ng......\n" +
      "00 14                                            ..               \n");

    final static String HADOOP0_20_ERROR_MSG =
      "Server IPC version " + RpcConstants.CURRENT_VERSION +
      " cannot communicate with client version 3";
   

    final static byte[] RESPONSE_TO_HADOOP_0_20_3_RPC =
      Bytes.concat(hexDumpToBytes(
      "ff ff ff ff ff ff ff ff  00 00 00 29 6f 72 67 2e ........ ...)org.\n" +
      "61 70 61 63 68 65 2e 68  61 64 6f 6f 70 2e 69 70 apache.h adoop.ip\n" +
      "63 2e 52 50 43 24 56 65  72 73 69 6f 6e 4d 69 73 c.RPC$Ve rsionMis\n" +
      "6d 61 74 63 68                                   match            \n"),
      Ints.toByteArray(HADOOP0_20_ERROR_MSG.length()),
      HADOOP0_20_ERROR_MSG.getBytes());
   
   
    final static String HADOOP0_21_ERROR_MSG =
      "Server IPC version " + RpcConstants.CURRENT_VERSION +
      " cannot communicate with client version 4";

    final static byte[] HADOOP_0_21_0_RPC_DUMP =
      hexDumpToBytes(
      "68 72 70 63 04 50                                hrpc.P" +
      // in 0.21 it comes in two separate TCP packets
      "00 00 00 3c 33 6f 72 67  2e 61 70 61 63 68 65 2e ...<3org .apache.\n" +
      "68 61 64 6f 6f 70 2e 6d  61 70 72 65 64 75 63 65 hadoop.m apreduce\n" +
      "2e 70 72 6f 74 6f 63 6f  6c 2e 43 6c 69 65 6e 74 .protoco l.Client\n" +
      "50 72 6f 74 6f 63 6f 6c  01 00 04 74 6f 64 64 00 Protocol ...todd.\n" +
      "00 00 00 71 00 00 00 00  00 12 67 65 74 50 72 6f ...q.... ..getPro\n" +
      "74 6f 63 6f 6c 56 65 72  73 69 6f 6e 00 00 00 02 tocolVer sion....\n" +
      "00 10 6a 61 76 61 2e 6c  61 6e 67 2e 53 74 72 69 ..java.l ang.Stri\n" +
      "6e 67 00 33 6f 72 67 2e  61 70 61 63 68 65 2e 68 ng.3org. apache.h\n" +
      "61 64 6f 6f 70 2e 6d 61  70 72 65 64 75 63 65 2e adoop.ma preduce.\n" +
      "70 72 6f 74 6f 63 6f 6c  2e 43 6c 69 65 6e 74 50 protocol .ClientP\n" +
      "72 6f 74 6f 63 6f 6c 00  04 6c 6f 6e 67 00 00 00 rotocol. .long...\n" +
      "00 00 00 00 21                                   ....!            \n");
   
    final static byte[] RESPONSE_TO_HADOOP_0_21_0_RPC =
      Bytes.concat(hexDumpToBytes(
      "ff ff ff ff ff ff ff ff  00 00 00 29 6f 72 67 2e ........ ...)org.\n" +
      "61 70 61 63 68 65 2e 68  61 64 6f 6f 70 2e 69 70 apache.h adoop.ip\n" +
      "63 2e 52 50 43 24 56 65  72 73 69 6f 6e 4d 69 73 c.RPC$Ve rsionMis\n" +
      "6d 61 74 63 68                                   match            \n"),
      Ints.toByteArray(HADOOP0_21_ERROR_MSG.length()),
      HADOOP0_21_ERROR_MSG.getBytes());
  }
}
TOP

Related Classes of org.apache.hadoop.ipc.TestIPC$IOEOnWriteWritable

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.