Package org.apache.mahout.clustering.streaming.tools

Source Code of org.apache.mahout.clustering.streaming.tools.ClusterQualitySummarizer

package org.apache.mahout.clustering.streaming.tools;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.List;
import java.util.Locale;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.stats.OnlineSummarizer;

public class ClusterQualitySummarizer {
  private String outputFile;


  private PrintWriter fileOut;

  private String trainFile;
  private String testFile;
  private String centroidFile;
  private String centroidCompareFile;
  private boolean mahoutKMeansFormat;
  private boolean mahoutKMeansFormatCompare;

  private DistanceMeasure distanceMeasure = new SquaredEuclideanDistanceMeasure();

  public void printSummaries(List<OnlineSummarizer> summarizers, String type) {
    printSummaries(summarizers, type, fileOut);
  }

  public static void printSummaries(List<OnlineSummarizer> summarizers, String type, PrintWriter fileOut) {
    double maxDistance = 0;
    for (int i = 0; i < summarizers.size(); ++i) {
      OnlineSummarizer summarizer = summarizers.get(i);
      if (summarizer.getCount() == 0) {
        System.out.printf("Cluster %d is empty\n", i);
        continue;
      }
      maxDistance = Math.max(maxDistance, summarizer.getMax());
      System.out.printf("Average distance in cluster %d [%d]: %f\n", i, summarizer.getCount(), summarizer.getMean());
      // If there is just one point in the cluster, quartiles cannot be estimated. We'll just assume all the quartiles
      // equal the only value.
      boolean moreThanOne = summarizer.getCount() > 1;
      if (fileOut != null) {
        fileOut.printf("%d,%f,%f,%f,%f,%f,%f,%f,%d,%s\n", i, summarizer.getMean(),
            summarizer.getSD(),
            summarizer.getQuartile(0),
            moreThanOne ? summarizer.getQuartile(1) : summarizer.getQuartile(0),
            moreThanOne ? summarizer.getQuartile(2) : summarizer.getQuartile(0),
            moreThanOne ? summarizer.getQuartile(3) : summarizer.getQuartile(0),
            summarizer.getQuartile(4), summarizer.getCount(), type);
      }
    }
    System.out.printf("Num clusters: %d; maxDistance: %f\n", summarizers.size(), maxDistance);
  }

  public void run(String[] args) {
    if (!parseArgs(args)) {
      return;
    }

    Configuration conf = new Configuration();
    try {
      Configuration.dumpConfiguration(conf, new OutputStreamWriter(System.out));

      fileOut = new PrintWriter(new FileOutputStream(outputFile));
      fileOut.printf("cluster,distance.mean,distance.sd,distance.q0,distance.q1,distance.q2,distance.q3,"
          + "distance.q4,count,is.train\n");

      // Reading in the centroids (both pairs, if they exist).
      List<Centroid> centroids;
      List<Centroid> centroidsCompare = null;
      if (mahoutKMeansFormat) {
        SequenceFileDirValueIterable<ClusterWritable> clusterIterable =
            new SequenceFileDirValueIterable<ClusterWritable>(new Path(centroidFile), PathType.GLOB, conf);
        centroids = Lists.newArrayList(IOUtils.getCentroidsFromClusterWritableIterable(clusterIterable));
      } else {
        SequenceFileDirValueIterable<CentroidWritable> centroidIterable =
            new SequenceFileDirValueIterable<CentroidWritable>(new Path(centroidFile), PathType.GLOB, conf);
        centroids = Lists.newArrayList(IOUtils.getCentroidsFromCentroidWritableIterable(centroidIterable));
      }

      if (centroidCompareFile != null) {
        if (mahoutKMeansFormatCompare) {
          SequenceFileDirValueIterable<ClusterWritable> clusterCompareIterable =
              new SequenceFileDirValueIterable<ClusterWritable>(new Path(centroidCompareFile), PathType.GLOB, conf);
          centroidsCompare = Lists.newArrayList(
              IOUtils.getCentroidsFromClusterWritableIterable(clusterCompareIterable));
        } else {
          SequenceFileDirValueIterable<CentroidWritable> centroidCompareIterable =
              new SequenceFileDirValueIterable<CentroidWritable>(new Path(centroidCompareFile), PathType.GLOB, conf);
          centroidsCompare = Lists.newArrayList(
              IOUtils.getCentroidsFromCentroidWritableIterable(centroidCompareIterable));
        }
      }

      // Reading in the "training" set.
      SequenceFileDirValueIterable<VectorWritable> trainIterable =
          new SequenceFileDirValueIterable<VectorWritable>(new Path(trainFile), PathType.GLOB, conf);
      Iterable<Vector> trainDatapoints = IOUtils.getVectorsFromVectorWritableIterable(trainIterable);
      Iterable<Vector> datapoints = trainDatapoints;

      printSummaries(ClusteringUtils.summarizeClusterDistances(trainDatapoints, centroids,
          new SquaredEuclideanDistanceMeasure()), "train");

      // Also adding in the "test" set.
      if (testFile != null) {
        SequenceFileDirValueIterable<VectorWritable> testIterable =
            new SequenceFileDirValueIterable<VectorWritable>(new Path(testFile), PathType.GLOB, conf);
        Iterable<Vector> testDatapoints = IOUtils.getVectorsFromVectorWritableIterable(testIterable);

        printSummaries(ClusteringUtils.summarizeClusterDistances(testDatapoints, centroids,
            new SquaredEuclideanDistanceMeasure()), "test");

        datapoints = Iterables.concat(trainDatapoints, testDatapoints);
      }

      // At this point, all train/test CSVs have been written. We now compute quality metrics.
      List<OnlineSummarizer> summaries =
          ClusteringUtils.summarizeClusterDistances(datapoints, centroids, distanceMeasure);
      List<OnlineSummarizer> compareSummaries = null;
      if (centroidsCompare != null) {
        compareSummaries = ClusteringUtils.summarizeClusterDistances(datapoints, centroidsCompare, distanceMeasure);
      }
      System.out.printf("[Dunn Index] First: %f", ClusteringUtils.dunnIndex(centroids, distanceMeasure, summaries));
      if (compareSummaries != null) {
        System.out.printf(" Second: %f\n",
            ClusteringUtils.dunnIndex(centroidsCompare, distanceMeasure, compareSummaries));
      } else {
        System.out.printf("\n");
      }
      System.out.printf("[Davies-Bouldin Index] First: %f",
          ClusteringUtils.daviesBouldinIndex(centroids, distanceMeasure, summaries));
      if (compareSummaries != null) {
        System.out.printf(" Second: %f\n",
          ClusteringUtils.daviesBouldinIndex(centroidsCompare, distanceMeasure, compareSummaries));
      } else {
        System.out.printf("\n");
      }

      if (outputFile != null) {
        fileOut.close();
      }
    } catch (IOException e) {
      System.out.println(e.getMessage());
    }
  }

  private boolean parseArgs(String[] args) {
    DefaultOptionBuilder builder = new DefaultOptionBuilder();

    Option help = builder.withLongName("help").withDescription("print this list").create();

    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
    Option inputFileOption = builder.withLongName("input")
        .withShortName("i")
        .withRequired(true)
        .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
        .withDescription("where to get seq files with the vectors (training set)")
        .create();

    Option testInputFileOption = builder.withLongName("testInput")
        .withShortName("itest")
        .withArgument(argumentBuilder.withName("testInput").withMaximum(1).create())
        .withDescription("where to get seq files with the vectors (test set)")
        .create();

    Option centroidsFileOption = builder.withLongName("centroids")
        .withShortName("c")
        .withRequired(true)
        .withArgument(argumentBuilder.withName("centroids").withMaximum(1).create())
        .withDescription("where to get seq files with the centroids (from Mahout KMeans or StreamingKMeansDriver)")
        .create();

    Option centroidsCompareFileOption = builder.withLongName("centroidsCompare")
        .withShortName("cc")
        .withRequired(false)
        .withArgument(argumentBuilder.withName("centroidsCompare").withMaximum(1).create())
        .withDescription("where to get seq files with the second set of centroids (from Mahout KMeans or "
            + "StreamingKMeansDriver)")
        .create();

    Option outputFileOption = builder.withLongName("output")
        .withShortName("o")
        .withRequired(true)
        .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
        .withDescription("where to dump the CSV file with the results")
        .create();

    Option mahoutKMeansFormatOption = builder.withLongName("mahoutkmeansformat")
        .withShortName("mkm")
        .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs")
        .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create())
        .create();

    Option mahoutKMeansCompareFormatOption = builder.withLongName("mahoutkmeansformatCompare")
        .withShortName("mkmc")
        .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs")
        .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create())
        .create();

    Group normalArgs = new GroupBuilder()
        .withOption(help)
        .withOption(inputFileOption)
        .withOption(testInputFileOption)
        .withOption(outputFileOption)
        .withOption(centroidsFileOption)
        .withOption(centroidsCompareFileOption)
        .withOption(mahoutKMeansFormatOption)
        .withOption(mahoutKMeansCompareFormatOption)
        .create();

    Parser parser = new Parser();
    parser.setHelpOption(help);
    parser.setHelpTrigger("--help");
    parser.setGroup(normalArgs);
    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 150));

    CommandLine cmdLine = parser.parseAndHelp(args);
    if (cmdLine == null) {
      return false;
    }

    trainFile = (String) cmdLine.getValue(inputFileOption);
    if (cmdLine.hasOption(testInputFileOption)) {
      testFile = (String) cmdLine.getValue(testInputFileOption);
    }
    centroidFile = (String) cmdLine.getValue(centroidsFileOption);
    if (cmdLine.hasOption(centroidsCompareFileOption)) {
      centroidCompareFile = (String) cmdLine.getValue(centroidsCompareFileOption);
    }
    outputFile = (String) cmdLine.getValue(outputFileOption);
    if (cmdLine.hasOption(mahoutKMeansFormatOption)) {
      mahoutKMeansFormat = true;
    }
    if (cmdLine.hasOption(mahoutKMeansCompareFormatOption)) {
      mahoutKMeansFormatCompare = true;
    }
    return true;
  }

  public static void main(String[] args) {
    new ClusterQualitySummarizer().run(args);
  }
}
TOP

Related Classes of org.apache.mahout.clustering.streaming.tools.ClusterQualitySummarizer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.