/**
 * Story: the mapper routes every input point to its nearest cluster.
 *
 * <p>For each cluster count from 1 up to the number of reference points, the clusters are seeded
 * from the leading reference points (taken in order, not at random), every reference point is run
 * through the mapper, and each emitted observation is checked to be no farther from the cluster it
 * was keyed under than from any of the seeded clusters.
 *
 * @throws Exception propagated from building the map context or from the mapper itself
 */
@Test
public void testKMeansMapper() throws Exception {
  KMeansMapper mapper = new KMeansMapper();
  EuclideanDistanceMeasure measure = new EuclideanDistanceMeasure();

  Configuration conf = new Configuration();
  conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, EuclideanDistanceMeasure.class.getName());
  conf.set(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY, "0.001");
  conf.set(KMeansConfigKeys.CLUSTER_PATH_KEY, "");

  List<VectorWritable> points = getPointsWritable(REFERENCE);
  int pointCount = points.size();

  for (int numClusters = 1; numClusters <= pointCount; numClusters++) {
    // Fresh writer and context per round so results from earlier rounds do not leak in.
    DummyRecordWriter<Text, ClusterObservations> mapWriter =
        new DummyRecordWriter<Text, ClusterObservations>();
    Mapper<WritableComparable<?>, VectorWritable, Text, ClusterObservations>.Context mapContext =
        DummyRecordWriter.build(mapper, conf, mapWriter);

    // Seed the clusters from the first numClusters reference points.
    List<Cluster> clusters = new ArrayList<Cluster>();
    for (int id = 0; id < numClusters; id++) {
      Cluster seeded = new Cluster(points.get(id).get(), id, measure);
      // Observe the center once (weight 1); per the original author's note this keeps the
      // centroid correct upon output.
      seeded.observe(seeded.getCenter(), 1);
      clusters.add(seeded);
    }
    mapper.setup(clusters, measure);

    // Feed every reference point through the mapper.
    for (int p = 0; p < pointCount; p++) {
      mapper.map(new Text(), points.get(p), mapContext);
    }

    assertEquals("Number of map results", numClusters, mapWriter.getData().size());

    // Each emitted observation must be at least as close to its own cluster as to any other.
    Map<String, Cluster> clusterMap = loadClusterMap(clusters);
    for (Text key : mapWriter.getKeys()) {
      AbstractCluster assigned = clusterMap.get(key.toString());
      for (ClusterObservations observation : mapWriter.getValue(key)) {
        double assignedDistance = measure.distance(assigned.getCenter(), observation.getS1());
        for (AbstractCluster candidate : clusters) {
          double candidateDistance = measure.distance(observation.getS1(), candidate.getCenter());
          assertTrue("distance error", assignedDistance <= candidateDistance);
        }
      }
    }
  }
}