Package org.apache.mahout.math.neighborhood

Source Code of org.apache.mahout.math.neighborhood.SearchSanityTest

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.math.neighborhood;

import java.util.Arrays;
import java.util.List;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.jet.math.Constants;
import org.apache.mahout.math.random.MultiNormal;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(Parameterized.class)
public class SearchSanityTest extends MahoutTestCase {
  private static final int NUM_DATA_POINTS = 1 << 13;
  private static final int NUM_DIMENSIONS = 20;
  private static final int NUM_PROJECTIONS = 3;
  private static final int SEARCH_SIZE = 30;

  private UpdatableSearcher searcher;
  private Matrix dataPoints;

  protected static Matrix multiNormalRandomData(int numDataPoints, int numDimensions) {
    Matrix data = new DenseMatrix(numDataPoints, numDimensions);
    MultiNormal gen = new MultiNormal(20);
    for (MatrixSlice slice : data) {
      slice.vector().assign(gen.sample());
    }
    return data;
  }

  @Parameterized.Parameters
  public static List<Object[]> generateData() {
    Matrix dataPoints = multiNormalRandomData(NUM_DATA_POINTS, NUM_DIMENSIONS);
    return Arrays.asList(new Object[][]{
        {new ProjectionSearch(new EuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), dataPoints},
        {new FastProjectionSearch(new EuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE),
            dataPoints},
        {new LocalitySensitiveHashSearch(new EuclideanDistanceMeasure(), SEARCH_SIZE), dataPoints},
    });
  }

  public SearchSanityTest(UpdatableSearcher searcher, Matrix dataPoints) {
    this.searcher = searcher;
    this.dataPoints = dataPoints;
  }

  @Test
  public void testExactMatch() {
    searcher.clear();
    Iterable<MatrixSlice> data = dataPoints;

    final Iterable<MatrixSlice> batch1 = Iterables.limit(data, 300);
    List<MatrixSlice> queries = Lists.newArrayList(Iterables.limit(batch1, 100));

    // adding the data in multiple batches triggers special code in some searchers
    searcher.addAllMatrixSlices(batch1);
    assertEquals(300, searcher.size());

    Vector q = Iterables.get(data, 0).vector();
    List<WeightedThing<Vector>> r = searcher.search(q, 2);
    assertEquals(0, r.get(0).getValue().minus(q).norm(1), 1.0e-8);

    final Iterable<MatrixSlice> batch2 = Iterables.limit(Iterables.skip(data, 300), 10);
    searcher.addAllMatrixSlices(batch2);
    assertEquals(310, searcher.size());

    q = Iterables.get(data, 302).vector();
    r = searcher.search(q, 2);
    assertEquals(0, r.get(0).getValue().minus(q).norm(1), 1.0e-8);

    searcher.addAllMatrixSlices(Iterables.skip(data, 310));
    assertEquals(dataPoints.numRows(), searcher.size());

    for (MatrixSlice query : queries) {
      r = searcher.search(query.vector(), 2);
      assertEquals("Distance has to be about zero", 0, r.get(0).getWeight(), 1.0e-6);
      assertEquals("Answer must be substantially the same as query", 0,
          r.get(0).getValue().minus(query.vector()).norm(1), 1.0e-8);
      assertTrue("Wrong answer must have non-zero distance",
          r.get(1).getWeight() > r.get(0).getWeight());
    }
  }

  @Test
  public void testNearMatch() {
    searcher.clear();
    List<MatrixSlice> queries = Lists.newArrayList(Iterables.limit(dataPoints, 100));
    searcher.addAllMatrixSlicesAsWeightedVectors(dataPoints);

    MultiNormal noise = new MultiNormal(0.01, new DenseVector(20));
    for (MatrixSlice slice : queries) {
      Vector query = slice.vector();
      final Vector epsilon = noise.sample();
      List<WeightedThing<Vector>> r = searcher.search(query, 2);
      query = query.plus(epsilon);
      assertEquals("Distance has to be small", epsilon.norm(2), r.get(0).getWeight(), 1.0e-1);
      assertEquals("Answer must be substantially the same as query", epsilon.norm(2),
          r.get(0).getValue().minus(query).norm(2), 1.0e-1);
      assertTrue("Wrong answer must be further away", r.get(1).getWeight() > r.get(0).getWeight());
    }
  }

  @Test
  public void testOrdering() {
    searcher.clear();
    Matrix queries = new DenseMatrix(100, 20);
    MultiNormal gen = new MultiNormal(20);
    for (int i = 0; i < 100; i++) {
      queries.viewRow(i).assign(gen.sample());
    }
    searcher.addAllMatrixSlices(dataPoints);

    for (MatrixSlice query : queries) {
      List<WeightedThing<Vector>> r = searcher.search(query.vector(), 200);
      double x = 0;
      for (WeightedThing<Vector> thing : r) {
        assertTrue("Scores must be monotonic increasing", thing.getWeight() >= x);
        x = thing.getWeight();
      }
    }
  }

  @Test
  public void testRemoval() {
    searcher.clear();
    searcher.addAllMatrixSlices(dataPoints);
    //noinspection ConstantConditions
    if (searcher instanceof UpdatableSearcher) {
      List<Vector> x = Lists.newArrayList(Iterables.limit(searcher, 2));
      int size0 = searcher.size();

      List<WeightedThing<Vector>> r0 = searcher.search(x.get(0), 2);

      searcher.remove(x.get(0), 1.0e-7);
      assertEquals(size0 - 1, searcher.size());

      List<WeightedThing<Vector>> r = searcher.search(x.get(0), 1);
      assertTrue("Vector should be gone", r.get(0).getWeight() > 0);
      assertEquals("Previous second neighbor should be first", 0,
          r.get(0).getValue().minus(r0.get(1).getValue()).norm (1), 1.0e-8);

      searcher.remove(x.get(1), 1.0e-7);
      assertEquals(size0 - 2, searcher.size());

      r = searcher.search(x.get(1), 1);
      assertTrue("Vector should be gone", r.get(0).getWeight() > 0);

      // Vectors don't show up in iterator.
      for (Vector v : searcher) {
        assertTrue(x.get(0).minus(v).norm(1) > 1.0e-6);
        assertTrue(x.get(1).minus(v).norm(1) > 1.0e-6);
      }
    } else {
      try {
        List<Vector> x = Lists.newArrayList(Iterables.limit(searcher, 2));
        searcher.remove(x.get(0), 1.0e-7);
        fail("Shouldn't be able to delete from " + searcher.getClass().getName());
      } catch (UnsupportedOperationException e) {
        // good enough that UOE is thrown
      }
    }
  }

  @Test
  public void testSearchFirst() {
    searcher.clear();
    searcher.addAll(dataPoints);
    for (Vector datapoint : dataPoints) {
      WeightedThing<Vector> first = searcher.searchFirst(datapoint, false);
      WeightedThing<Vector> second = searcher.searchFirst(datapoint, true);
      List<WeightedThing<Vector>> firstTwo = searcher.search(datapoint, 2);

      assertEquals("First isn't self", 0, first.getWeight(), 0);
      assertEquals("First isn't self", datapoint, first.getValue());
      assertEquals("First doesn't match", first, firstTwo.get(0));
      assertEquals("Second doesn't match", second, firstTwo.get(1));
    }
  }

  @Test
  public void testRemove() {
    searcher.clear();
    for (int i = 0; i < dataPoints.rowSize(); ++i) {
      Vector datapoint = dataPoints.viewRow(i);
      searcher.add(datapoint);
      // As long as points are not searched for right after being added, in FastProjectionSearch, points are not
      // merged with the main list right away, so if a search for a point occurs before it's merged the pendingAdditions
      // list also needs to be looked at.
      // This used to not be the case for searchFirst(), thereby causing removal failures.
      if (i % 2 == 0) {
        assertTrue("Failed to find self [search]",
            searcher.search(datapoint, 1).get(0).getWeight() < Constants.EPSILON);
        assertTrue("Failed to find self [searchFirst]",
            searcher.searchFirst(datapoint, false).getWeight() < Constants.EPSILON);
        assertTrue("Failed to remove self", searcher.remove(datapoint, Constants.EPSILON));
      }
    }
  }
}
TOP

Related Classes of org.apache.mahout.math.neighborhood.SearchSanityTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.