Package org.apache.flink.api.common.typeutils

Source Code of org.apache.flink.api.common.typeutils.ComparatorTestBase$TestInputView

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.common.typeutils;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;

import static org.junit.Assert.*;

import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.MemorySegment;
import org.junit.Assert;
import org.junit.Test;

/**
* Abstract test base for comparators.
*
* @param <T> the type of the elements compared by the comparator under test
*/
public abstract class ComparatorTestBase<T> {

  // Upper bound (in bytes) applied to the comparator's normalized key length in
  // these tests. Same as in the NormalizedKeySorter.
  private static final int DEFAULT_MAX_NORMALIZED_KEY_LEN = 8;

  /**
   * Creates the comparator under test.
   * <p>
   * The result must not be {@code null}: {@link #getComparator(boolean)} rejects a
   * {@code null} comparator as a corrupt test case. The tests in this class obtain
   * several comparators and set a different reference on each, so implementations
   * presumably should return an independent instance per call.
   *
   * @param ascending {@code true} to request an ascending comparator,
   *                  {@code false} for a descending one
   * @return the comparator to test
   */
  protected abstract TypeComparator<T> createComparator(boolean ascending);

  /**
   * Creates the serializer used to write the test data into the serialized views
   * that the comparator is run against.
   * <p>
   * The result must not be {@code null}: {@link #getSerializer()} rejects a
   * {@code null} serializer as a corrupt test case.
   *
   * @return the serializer matching the type of the test data
   */
  protected abstract TypeSerializer<T> createSerializer();

  /**
   * Returns the sorted data set.
   * <p>
   * Note: every element needs to be *strictly greater* than the previous element.
   * <p>
   * The array must not be {@code null} and must contain at least two elements;
   * {@link #getSortedData()} throws a {@link RuntimeException} otherwise.
   *
   * @return sorted test data set
   */
  protected abstract T[] getSortedTestData();

  // -------------------------------- test duplication ------------------------------------------
 
  @Test
  public void testDuplicate() {
    try {
      TypeComparator<T> comparator = getComparator(true);
      TypeComparator<T> clone = comparator.duplicate();
     
      T[] data = getSortedData();
      comparator.setReference(data[0]);
      clone.setReference(data[1]);
     
      assertTrue("Comparator duplication does not work: Altering the reference in a duplicated comparator alters the original comparator's reference.",
          comparator.equalToReference(data[0]) && clone.equalToReference(data[1]));
    }
    catch (Exception e) {
      System.err.println(e.getMessage());
      e.printStackTrace();
      Assert.fail(e.getMessage());
    }
  }
 
  // --------------------------------- equality tests -------------------------------------------
 
  @Test
  public void testEquality() {
    testEquals(true);
    testEquals(false);
  }

  protected void testEquals(boolean ascending) {
    try {
      // Just setup two identical output/inputViews and go over their data to see if compare works
      TestOutputView out1;
      TestOutputView out2;
      TestInputView in1;
      TestInputView in2;

      // Now use comparator and compare
      TypeComparator<T> comparator = getComparator(ascending);
      T[] data = getSortedData();
      for (T d : data) {

        out2 = new TestOutputView();
        writeSortedData(d, out2);
        in2 = out2.getInputView();

        out1 = new TestOutputView();
        writeSortedData(d, out1);
        in1 = out1.getInputView();

        assertTrue(comparator.compareSerialized(in1, in2) == 0);
      }
    } catch (Exception e) {
      System.err.println(e.getMessage());
      e.printStackTrace();
      fail("Exception in test: " + e.getMessage());
    }
  }

  @Test
  public void testEqualityWithReference() {
    try {
      TypeSerializer<T> serializer = createSerializer();
      TypeComparator<T> comparator = getComparator(true);
      TypeComparator<T> comparator2 = getComparator(true);
      T[] data = getSortedData();
      for (T d : data) {
        comparator.setReference(d);
        // Make a copy to compare
        T copy = serializer.copy(d, serializer.createInstance());

        // And then test equalTo and compareToReference method of comparator
        assertTrue(comparator.equalToReference(d));
        comparator2.setReference(copy);
        assertTrue(comparator.compareToReference(comparator2) == 0);
      }
    } catch (Exception e) {
      System.err.println(e.getMessage());
      e.printStackTrace();
      fail("Exception in test: " + e.getMessage());
    }
  }

  // --------------------------------- inequality tests ----------------------------------------
  @Test
  public void testInequality() {
    testGreatSmallAscDesc(true, true);
    testGreatSmallAscDesc(false, true);
    testGreatSmallAscDesc(true, false);
    testGreatSmallAscDesc(false, false);
  }

  protected void testGreatSmallAscDesc(boolean ascending, boolean greater) {
    try {
      //split data into low and high part
      T[] data = getSortedData();

      TypeComparator<T> comparator = getComparator(ascending);
      TestOutputView out1;
      TestOutputView out2;
      TestInputView in1;
      TestInputView in2;

      //compares every element in high with every element in low
      for (int x = 0; x < data.length - 1; x++) {
        for (int y = x + 1; y < data.length; y++) {
          out1 = new TestOutputView();
          writeSortedData(data[x], out1);
          in1 = out1.getInputView();

          out2 = new TestOutputView();
          writeSortedData(data[y], out2);
          in2 = out2.getInputView();

          if (greater && ascending) {
            assertTrue(comparator.compareSerialized(in1, in2) < 0);
          }
          if (greater && !ascending) {
            assertTrue(comparator.compareSerialized(in1, in2) > 0);
          }
          if (!greater && ascending) {
            assertTrue(comparator.compareSerialized(in2, in1) > 0);
          }
          if (!greater && !ascending) {
            assertTrue(comparator.compareSerialized(in2, in1) < 0);
          }
        }
      }
    } catch (Exception e) {
      System.err.println(e.getMessage());
      e.printStackTrace();
      fail("Exception in test: " + e.getMessage());
    }
  }

  @Test
  public void testInequalityWithReference() {
    testGreatSmallAscDescWithReference(true, true);
    testGreatSmallAscDescWithReference(true, false);
    testGreatSmallAscDescWithReference(false, true);
    testGreatSmallAscDescWithReference(false, false);
  }

  protected void testGreatSmallAscDescWithReference(boolean ascending, boolean greater) {
    try {
      T[] data = getSortedData();

      TypeComparator<T> comparatorLow = getComparator(ascending);
      TypeComparator<T> comparatorHigh = getComparator(ascending);

      //compares every element in high with every element in low
      for (int x = 0; x < data.length - 1; x++) {
        for (int y = x + 1; y < data.length; y++) {
          comparatorLow.setReference(data[x]);
          comparatorHigh.setReference(data[y]);

          if (greater && ascending) {
            assertTrue(comparatorLow.compareToReference(comparatorHigh) > 0);
          }
          if (greater && !ascending) {
            assertTrue(comparatorLow.compareToReference(comparatorHigh) < 0);
          }
          if (!greater && ascending) {
            assertTrue(comparatorHigh.compareToReference(comparatorLow) < 0);
          }
          if (!greater && !ascending) {
            assertTrue(comparatorHigh.compareToReference(comparatorLow) > 0);
          }
        }
      }
    } catch (Exception e) {
      System.err.println(e.getMessage());
      e.printStackTrace();
      fail("Exception in test: " + e.getMessage());
    }
  }

  // --------------------------------- Normalized key tests -------------------------------------
 
  // Help Function for setting up a memory segment and normalize the keys of the data array in it
  public MemorySegment setupNormalizedKeysMemSegment(T[] data, int normKeyLen, TypeComparator<T> comparator) {
    MemorySegment memSeg = new MemorySegment(new byte[2048]);

    // Setup normalized Keys in the memory segment
    int offset = 0;
    for (T e : data) {
      comparator.putNormalizedKey(e, memSeg, offset, normKeyLen);
      offset += normKeyLen;
    }
    return memSeg;
  }
 
  // Help Function which return a normalizedKeyLength, either as done in the NormalizedKeySorter or it's half
  private int getNormKeyLen(boolean halfLength, T[] data,
      TypeComparator<T> comparator) throws Exception {
    // Same as in the NormalizedKeySorter
    int keyLen = Math.min(comparator.getNormalizeKeyLen(), DEFAULT_MAX_NORMALIZED_KEY_LEN);
    if (keyLen < comparator.getNormalizeKeyLen()) {
      assertTrue(comparator.isNormalizedKeyPrefixOnly(keyLen));
    }

    if (halfLength) {
      keyLen = keyLen / 2;
      assertTrue(comparator.isNormalizedKeyPrefixOnly(keyLen));
    }
    return keyLen;
  }

  @Test
  public void testNormalizedKeysEqualsFullLength() {
    // Ascending or descending does not matter in this case
    TypeComparator<T> comparator = getComparator(true);
    if (!comparator.supportsNormalizedKey()) {
      return;
    }
    testNormalizedKeysEquals(false);
  }

  @Test
  public void testNormalizedKeysEqualsHalfLength() {
    TypeComparator<T> comparator = getComparator(true);
    if (!comparator.supportsNormalizedKey()) {
      return;
    }
    testNormalizedKeysEquals(true);
  }
 
  public void testNormalizedKeysEquals(boolean halfLength) {
    try {
      TypeComparator<T> comparator = getComparator(true);
      T[] data = getSortedData();
      int normKeyLen = getNormKeyLen(halfLength, data, comparator);

      MemorySegment memSeg1 = setupNormalizedKeysMemSegment(data, normKeyLen, comparator);
      MemorySegment memSeg2 = setupNormalizedKeysMemSegment(data, normKeyLen, comparator);

      for (int i = 0; i < data.length; i++) {
        assertTrue(MemorySegment.compare(memSeg1, memSeg2, i * normKeyLen, i * normKeyLen, normKeyLen) == 0);
      }
    } catch (Exception e) {
      System.err.println(e.getMessage());
      e.printStackTrace();
      fail("Exception in test: " + e.getMessage());
    }
  }


  @Test
  public void testNormalizedKeysGreatSmallFullLength() {
    // ascending/descending in comparator doesn't matter for normalized keys
    TypeComparator<T> comparator = getComparator(true);
    if (!comparator.supportsNormalizedKey()) {
      return;
    }
    testNormalizedKeysGreatSmall(true, comparator, false);
    testNormalizedKeysGreatSmall(false, comparator, false);
  }

  @Test
  public void testNormalizedKeysGreatSmallAscDescHalfLength() {
    // ascending/descending in comparator doesn't matter for normalized keys
    TypeComparator<T> comparator = getComparator(true);
    if (!comparator.supportsNormalizedKey()) {
      return;
    }
    testNormalizedKeysGreatSmall(true, comparator, true);
    testNormalizedKeysGreatSmall(false, comparator, true);
  }

  protected void testNormalizedKeysGreatSmall(boolean greater, TypeComparator<T> comparator, boolean halfLength) {
    try {
      T[] data = getSortedData();
      // Get the normKeyLen on which we are testing
      int normKeyLen = getNormKeyLen(halfLength, data, comparator);

      // Write the data into different 2 memory segements
      MemorySegment memSegLow = setupNormalizedKeysMemSegment(data, normKeyLen, comparator);
      MemorySegment memSegHigh = setupNormalizedKeysMemSegment(data, normKeyLen, comparator);

      boolean fullyDetermines = !comparator.isNormalizedKeyPrefixOnly(normKeyLen);
     
      // Compare every element with every bigger element
      for (int l = 0; l < data.length - 1; l++) {
        for (int h = l + 1; h < data.length; h++) {
          int cmp;
          if (greater) {
            cmp = MemorySegment.compare(memSegLow, memSegHigh, l * normKeyLen, h * normKeyLen, normKeyLen);
            if (fullyDetermines) {
              assertTrue(cmp < 0);
            } else {
              assertTrue(cmp <= 0);
            }
          } else {
            cmp = MemorySegment.compare(memSegHigh, memSegLow, h * normKeyLen, l * normKeyLen, normKeyLen);
            if (fullyDetermines) {
              assertTrue(cmp > 0);
            } else {
              assertTrue(cmp >= 0);
            }
          }
        }
      }
    } catch (Exception e) {
      System.err.println(e.getMessage());
      e.printStackTrace();
      fail("Exception in test: " + e.getMessage());
    }
  }

  @Test
  public void testNormalizedKeyReadWriter() {
    try {
      T[] data = getSortedData();
      T reuse = getSortedData()[0];

      TypeComparator<T> comp1 = getComparator(true);
      if(!comp1.supportsSerializationWithKeyNormalization()){
        return;
      }
      TypeComparator<T> comp2 = comp1.duplicate();
      comp2.setReference(reuse);

      TestOutputView out = new TestOutputView();
      TestInputView in;

      for (T value : data) {
        comp1.setReference(value);
        comp1.writeWithKeyNormalization(value, out);
        in = out.getInputView();
        comp1.readWithKeyDenormalization(reuse, in);
       
        assertTrue(comp1.compareToReference(comp2) == 0);
      }
    } catch (Exception e) {
      System.err.println(e.getMessage());
      e.printStackTrace();
      fail("Exception in test: " + e.getMessage());
    }
  }

  // --------------------------------------------------------------------------------------------

  protected void deepEquals(String message, T should, T is) {
    assertEquals(message, should, is);
  }

  // --------------------------------------------------------------------------------------------
  protected TypeComparator<T> getComparator(boolean ascending) {
    TypeComparator<T> comparator = createComparator(ascending);
    if (comparator == null) {
      throw new RuntimeException("Test case corrupt. Returns null as comparator.");
    }
    return comparator;
  }

  protected T[] getSortedData() {
    T[] data = getSortedTestData();
    if (data == null) {
      throw new RuntimeException("Test case corrupt. Returns null as test data.");
    }
    if (data.length < 2) {
      throw new RuntimeException("Test case does not provide enough sorted test data.");
    }
   
    return data;
  }

  protected TypeSerializer<T> getSerializer() {
    TypeSerializer<T> serializer = createSerializer();
    if (serializer == null) {
      throw new RuntimeException("Test case corrupt. Returns null as serializer.");
    }
    return serializer;
  }

  protected void writeSortedData(T value, TestOutputView out) throws IOException {
    TypeSerializer<T> serializer = getSerializer();

    // Write data into a outputView
    serializer.serialize(value, out);

    // This are the same tests like in the serializer
    // Just look if the data is really there after serialization, before testing comparator on it
    TestInputView in = out.getInputView();
    assertTrue("No data available during deserialization.", in.available() > 0);

    T deserialized = serializer.deserialize(serializer.createInstance(), in);
    deepEquals("Deserialized value is wrong.", value, deserialized);

  }

  // --------------------------------------------------------------------------------------------
  protected static final class TestOutputView extends DataOutputStream implements DataOutputView {

    public TestOutputView() {
      super(new ByteArrayOutputStream(4096));
    }

    public TestInputView getInputView() {
      ByteArrayOutputStream baos = (ByteArrayOutputStream) out;
      return new TestInputView(baos.toByteArray());
    }

    @Override
    public void skipBytesToWrite(int numBytes) throws IOException {
      for (int i = 0; i < numBytes; i++) {
        write(0);
      }
    }

    @Override
    public void write(DataInputView source, int numBytes) throws IOException {
      byte[] buffer = new byte[numBytes];
      source.readFully(buffer);
      write(buffer);
    }
  }

  protected static final class TestInputView extends DataInputStream implements DataInputView {

    public TestInputView(byte[] data) {
      super(new ByteArrayInputStream(data));
    }

    @Override
    public void skipBytesToRead(int numBytes) throws IOException {
      while (numBytes > 0) {
        int skipped = skipBytes(numBytes);
        numBytes -= skipped;
      }
    }
  }
}
TOP

Related Classes of org.apache.flink.api.common.typeutils.ComparatorTestBase$TestInputView

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.