Package org.apache.mahout.clustering.streaming.cluster

Source Code of org.apache.mahout.clustering.streaming.cluster.StreamingKMeansTest

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.clustering.streaming.cluster;


import java.util.Arrays;
import java.util.List;

import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.FastProjectionSearch;
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.neighborhood.Searcher;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.WeightedThing;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.runners.Parameterized.Parameters;


@RunWith(Parameterized.class)
public class StreamingKMeansTest {
  private static final int NUM_DATA_POINTS = 1 << 16;
  private static final int NUM_DIMENSIONS = 6;
  private static final int NUM_PROJECTIONS = 2;
  private static final int SEARCH_SIZE = 10;

  private static Pair<List<Centroid>, List<Centroid>> syntheticData ;

  @Before
  public void setUp() {
    RandomUtils.useTestSeed();
    syntheticData =
      DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS);
  }

  private UpdatableSearcher searcher;
  private boolean allAtOnce;

  public StreamingKMeansTest(UpdatableSearcher searcher, boolean allAtOnce) {
    this.searcher = searcher;
    this.allAtOnce = allAtOnce;
  }

  @Parameters
  public static List<Object[]> generateData() {
    return Arrays.asList(new Object[][] {
        {new ProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), true},
        {new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE),
            true},
        {new ProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), false},
        {new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE),
            false},
    });
  }

  @Test
  public void testAverageDistanceCutoff() {
    double avgDistanceCutoff = 0;
    double avgNumClusters = 0;
    int numTests = 1;
    System.out.printf("Distance cutoff for %s\n", searcher.getClass().getName());
    for (int i = 0; i < numTests; ++i) {
      searcher.clear();
      int numStreamingClusters = (int)Math.log(syntheticData.getFirst().size()) * (1 <<
          NUM_DIMENSIONS);
      double distanceCutoff = 1.0e-6;
      double estimatedCutoff = ClusteringUtils.estimateDistanceCutoff(syntheticData.getFirst(),
          searcher.getDistanceMeasure(), 100);
      System.out.printf("[%d] Generated synthetic data [magic] %f [estimate] %f\n", i, distanceCutoff, estimatedCutoff);
      StreamingKMeans clusterer =
          new StreamingKMeans(searcher, numStreamingClusters, estimatedCutoff);
      clusterer.cluster(syntheticData.getFirst());
      avgDistanceCutoff += clusterer.getDistanceCutoff();
      avgNumClusters += clusterer.getNumClusters();
      System.out.printf("[%d] %f\n", i, clusterer.getDistanceCutoff());
    }
    avgDistanceCutoff /= numTests;
    avgNumClusters /= numTests;
    System.out.printf("Final: distanceCutoff: %f estNumClusters: %f\n", avgDistanceCutoff, avgNumClusters);
  }

  @Test
  public void testClustering() {
    searcher.clear();
    int numStreamingClusters = (int)Math.log(syntheticData.getFirst().size()) * (1 << NUM_DIMENSIONS);
    System.out.printf("k log n = %d\n", numStreamingClusters);
    double estimatedCutoff = ClusteringUtils.estimateDistanceCutoff(syntheticData.getFirst(),
        searcher.getDistanceMeasure(), 100);
    StreamingKMeans clusterer =
        new StreamingKMeans(searcher, numStreamingClusters, estimatedCutoff);

    long startTime = System.currentTimeMillis();
    if (allAtOnce) {
      clusterer.cluster(syntheticData.getFirst());
    } else {
      for (Centroid datapoint : syntheticData.getFirst()) {
        clusterer.cluster(datapoint);
      }
    }
    long endTime = System.currentTimeMillis();

    System.out.printf("%s %s\n", searcher.getClass().getName(), searcher.getDistanceMeasure()
        .getClass().getName());
    System.out.printf("Total number of clusters %d\n", clusterer.getNumClusters());

    System.out.printf("Weights: %f %f\n", ClusteringUtils.totalWeight(syntheticData.getFirst()),
        ClusteringUtils.totalWeight(clusterer));
    assertEquals("Total weight not preserved", ClusteringUtils.totalWeight(syntheticData.getFirst()),
        ClusteringUtils.totalWeight(clusterer), 1.0e-9);

    // and verify that each corner of the cube has a centroid very nearby
    double maxWeight = 0;
    for (Vector mean : syntheticData.getSecond()) {
      WeightedThing<Vector> v = searcher.search(mean, 1).get(0);
      maxWeight = Math.max(v.getWeight(), maxWeight);
    }
    assertTrue("Maximum weight too large " + maxWeight, maxWeight < 0.05);
    double clusterTime = (endTime - startTime) / 1000.0;
    System.out.printf("%s\n%.2f for clustering\n%.1f us per row\n\n",
        searcher.getClass().getName(), clusterTime,
        clusterTime / syntheticData.getFirst().size() * 1.0e6);

    // verify that the total weight of the centroids near each corner is correct
    double[] cornerWeights = new double[1 << NUM_DIMENSIONS];
    Searcher trueFinder = new BruteSearch(new EuclideanDistanceMeasure());
    for (Vector trueCluster : syntheticData.getSecond()) {
      trueFinder.add(trueCluster);
    }
    for (Centroid centroid : clusterer) {
      WeightedThing<Vector> closest = trueFinder.search(centroid, 1).get(0);
      cornerWeights[((Centroid)closest.getValue()).getIndex()] += centroid.getWeight();
    }
    int expectedNumPoints = NUM_DATA_POINTS / (1 << NUM_DIMENSIONS);
    for (double v : cornerWeights) {
      System.out.printf("%f ", v);
    }
    System.out.println();
    for (double v : cornerWeights) {
      assertEquals(expectedNumPoints, v, 0);
    }
  }
}
TOP

Related Classes of org.apache.mahout.clustering.streaming.cluster.StreamingKMeansTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is now owned by Oracle Inc. Contact coftware#gmail.com.