Source Code of org.apache.spark.network.shuffle.OneForOneBlockFetcherSuite

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.shuffle;

import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import com.google.common.collect.Maps;
import io.netty.buffer.Unpooled;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import static org.junit.Assert.*;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NettyManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;

public class OneForOneBlockFetcherSuite {
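  // Fetches a single zero-byte shuffle block and expects one success callback for it.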
  @Test
  public void testFetchOne() {
    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
    blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));

    BlockFetchingListener listener = fetchBlocks(blocks);

    verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
  }

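  // Fetches three blocks backed by a mix of NIO and Netty buffers; each succeeds exactly once.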
  @Test
  public void testFetchThree() {
    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
    blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
    blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
    blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));

    BlockFetchingListener listener = fetchBlocks(blocks);

    for (int i = 0; i < 3; i++) {
      verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i));
    }
  }

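  // Two of the three blocks have null buffers, so the mocked server fails their chunk fetches.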
  @Test
  public void testFailure() {
    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
    blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
    blocks.put("b1", null);
    blocks.put("b2", null);

    BlockFetchingListener listener = fetchBlocks(blocks);

    // Each failure also fails all blocks that have not yet been fetched, so "b2" sees two
    // failures: one propagated from "b1" and one from its own fetch.
    verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
    verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any());
    verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any());
  }

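  // Mixes one failing block ("b1") between two successful ones.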
  @Test
  public void testFailureAndSuccess() {
    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
    blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
    blocks.put("b1", null);
    blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21])));

    BlockFetchingListener listener = fetchBlocks(blocks);

    // Both success and failure may be reported for the same block: "b2" succeeds on its own
    // fetch but is also failed when "b1" fails.
    verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
    verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any());
    verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2"));
    verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any());
  }

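  // Requesting zero blocks should fail immediately with an IllegalArgumentException.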
  @Test
  public void testEmptyBlockFetch() {
    try {
      fetchBlocks(Maps.<String, ManagedBuffer>newLinkedHashMap());
      fail();
    } catch (IllegalArgumentException e) {
      assertEquals("Zero-sized blockIds array", e.getMessage());
    }
  }

  /**
   * Begins a fetch on the given set of blocks by mocking out the server side of the RPC, which
   * simply returns the given (BlockId, Block) pairs.
   * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the same order
   * that they were inserted.
   *
   * If a block's buffer is "null", the fetch for that chunk is failed with a RuntimeException
   * rather than succeeding.
   */
  private BlockFetchingListener fetchBlocks(final LinkedHashMap<String, ManagedBuffer> blocks) {
    TransportClient client = mock(TransportClient.class);
    BlockFetchingListener listener = mock(BlockFetchingListener.class);
    final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
    OneForOneBlockFetcher fetcher =
      new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener);

    // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123
    doAnswer(new Answer<Void>() {
      @Override
      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
        BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray(
          (byte[]) invocationOnMock.getArguments()[0]);
        RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
        callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray());
        assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message);
        return null;
      }
    }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());

    // Respond to each chunk request with a single buffer from our blocks array.
    final AtomicInteger expectedChunkIndex = new AtomicInteger(0);
    final Iterator<ManagedBuffer> blockIterator = blocks.values().iterator();
    doAnswer(new Answer<Void>() {
      @Override
      public Void answer(InvocationOnMock invocation) throws Throwable {
        try {
          long streamId = (Long) invocation.getArguments()[0];
          int myChunkIndex = (Integer) invocation.getArguments()[1];
          assertEquals(123, streamId);
          assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex);

          ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2];
          ManagedBuffer result = blockIterator.next();
          if (result != null) {
            callback.onSuccess(myChunkIndex, result);
          } else {
            callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex));
          }
        } catch (Exception e) {
          e.printStackTrace();
          fail("Unexpected failure");
        }
        return null;
      }
    }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any());

    fetcher.start();
    return listener;
  }
}