Package org.apache.mahout.benchmark

Source Code of org.apache.mahout.benchmark.VectorBenchmarks

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.benchmark;

import java.io.IOException;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Map.Entry;
import java.util.regex.Pattern;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.TimingStatistics;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.distance.CosineDistanceMeasure;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.common.distance.TanimotoDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Micro-benchmark harness that times common operations (creation, cloning, serialization,
 * dot products and distance measures) on dense, random-access sparse and sequential-access
 * sparse vectors, and collects the results into a printable table (see {@link #toString()}).
 */
public class VectorBenchmarks {

  private static final Logger log = LoggerFactory.getLogger(VectorBenchmarks.class);
  // Splits the tab-joined statistics line built in printStats into table cells.
  private static final Pattern TAB_PATTERN = Pattern.compile("\t");

  // Vectors under test, allocated as [3][numVectors]. createBenchmark fills row 0 with
  // DenseVector, row 1 with RandomAccessSparseVector and row 2 with SequentialAccessSparseVector.
  private final Vector[][] vectors;
  // Cluster centres, populated by incrementalCreateBenchmark.
  private final Vector[] clusters;
  // numClusters x numClusters matrix of centre-to-centre distances, filled by closestCentroidBenchmark.
  private final SparseMatrix clusterDistances;
  // Randomly generated source vectors, plus the same data as parallel index/value arrays
  // (entry i of each list describes the same vector).
  private final List<Vector> randomVectors = new ArrayList<Vector>();
  private final List<int[]> randomVectorIndices = new ArrayList<int[]>();
  private final List<double[]> randomVectorValues = new ArrayList<double[]>();
  // Dimension of every generated vector.
  private final int cardinality;
  // Number of distinct non-zero entries generated per vector.
  private final int sparsity;
  // Number of vectors generated per implementation.
  private final int numVectors;
  // Number of times each benchmark repeats its pass over the vectors.
  private final int loop;
  // Number of distance computations performed inside one timed call of distanceMeasureBenchmark.
  private final int opsPerUnit;
  // Maps an implementation name to its column id, assigned in order of first appearance.
  private final Map<String,Integer> implType = new HashMap<String,Integer>();
  // Maps a benchmark name to its per-implementation statistics, indexed by column id.
  private final Map<String,List<String[]>> statsMap = new HashMap<String,List<String[]>>();
  // Number of cluster centres used by the closest-centroid benchmark.
  private final int numClusters;
 
  public VectorBenchmarks(int cardinality, int sparsity, int numVectors, int numClusters, int loop, int opsPerUnit) {
    Random r = RandomUtils.getRandom();
    this.cardinality = cardinality;
    this.sparsity = sparsity;
    this.numVectors = numVectors;
    this.numClusters = numClusters;
    this.loop = loop;
    this.opsPerUnit = opsPerUnit;
    for (int i = 0; i < numVectors; i++) {
      Vector v = new SequentialAccessSparseVector(cardinality, sparsity); // sparsity!
      BitSet featureSpace = new BitSet(cardinality);
      int[] indexes = new int[sparsity];
      double[] values = new double[sparsity];
      int j = 0;
      while (j < sparsity) {
        double value = r.nextGaussian();
        int index = r.nextInt(cardinality);
        if (!featureSpace.get(index)) {
          featureSpace.set(index);
          indexes[j] = index;
          values[j++] = value;
          v.set(index, value);
        }
      }
      randomVectorIndices.add(indexes);
      randomVectorValues.add(values);
      randomVectors.add(v);
    }
    vectors = new Vector[3][numVectors];
    clusters = new Vector[numClusters];
    clusterDistances = new SparseMatrix(numClusters, numClusters);
  }
 
  private void printStats(TimingStatistics stats, String benchmarkName, String implName, String content) {
    printStats(stats, benchmarkName, implName, content, 1);
  }
 
  private void printStats(TimingStatistics stats, String benchmarkName, String implName) {
    printStats(stats, benchmarkName, implName, "", 1);
  }
 
  private void printStats(TimingStatistics stats,
                          String benchmarkName,
                          String implName,
                          String content,
                          int multiplier) {
    float speed = multiplier * loop * numVectors * sparsity * 1000.0f * 12 / stats.getSumTime();
    float opsPerSec = loop * numVectors * 1000000000.0f / stats.getSumTime();
    log.info("{} {} \n{} {} \nSpeed: {} UnitsProcessed/sec {} MBytes/sec                                   ",
      new Object[] {benchmarkName, implName, content, stats.toString(), opsPerSec, speed});
    String info = stats.toString().replaceAll("\n", "\t") + "\tSpeed = " + opsPerSec + " /sec\tRate = "
                  + speed + " MB/s";
    if (!implType.containsKey(implName)) {
      implType.put(implName, implType.size());
    }
    int implId = implType.get(implName);
    if (!statsMap.containsKey(benchmarkName)) {
      statsMap.put(benchmarkName, new ArrayList<String[]>());
    }
    List<String[]> implStats = statsMap.get(benchmarkName);
    while (implStats.size() < implId + 1) {
      implStats.add(new String[] {});
    }
    implStats.set(implId, TAB_PATTERN.split(info));
  }
 
  public void createBenchmark() {
    TimingStatistics stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        vectors[0][i] = new DenseVector(randomVectors.get(i));
        call.end();
      }
    }
    printStats(stats, "Create (copy)", "DenseVector");
   
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        vectors[1][i] = new RandomAccessSparseVector(randomVectors.get(i));
        call.end();
      }
    }
    printStats(stats, "Create (copy)", "RandSparseVector");
   
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        vectors[2][i] = new SequentialAccessSparseVector(randomVectors.get(i));
        call.end();
      }
    }
    printStats(stats, "Create (copy)", "SeqSparseVector");
   
  }

  private void buildVectorIncrementally(TimingStatistics stats, int randomIndex, Vector v, boolean useSetQuick) {
    int[] indexes = randomVectorIndices.get(randomIndex);
    double[] values = randomVectorValues.get(randomIndex);
    List<Integer> randomOrder = new ArrayList<Integer>();
    for (int i = 0; i < indexes.length; i++) {
      randomOrder.add(i);
    }
    Collections.shuffle(randomOrder);
    int[] permutation = new int[randomOrder.size()];
    for (int i = 0; i < randomOrder.size(); i++) {
      permutation[i] = randomOrder.get(i);
    }

    TimingStatistics.Call call = stats.newCall();
    if (useSetQuick) {
      for (int i : permutation) {
        v.setQuick(indexes[i], values[i]);
      }
    } else {
      for (int i : permutation) {
        v.set(indexes[i], values[i]);
      }
    }
    call.end();
  }

  public void incrementalCreateBenchmark() {
    TimingStatistics stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        vectors[0][i] = new DenseVector(cardinality);
        buildVectorIncrementally(stats, i, vectors[0][i], false);
      }
    }
    printStats(stats, "Create (incrementally)", "DenseVector");

    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        vectors[1][i] = new RandomAccessSparseVector(cardinality);
        buildVectorIncrementally(stats, i, vectors[1][i], false);
      }
    }
    printStats(stats, "Create (incrementally)", "RandSparseVector");

//    stats = new TimingStatistics();
//    for (int l = 0; l < loop; l++) {
//      for (int i = 0; i < numVectors; i++) {
//        vectors[2][i] = new SequentialAccessSparseVector(cardinality);
//        buildVectorIncrementally(stats, i, vectors[2][i], false);
//      }
//    }
//    printStats(stats, "Create (incrementally)", "SeqSparseVector");
   
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numClusters; i++) {
        clusters[i] = new RandomAccessSparseVector(cardinality);
        buildVectorIncrementally(stats, i, clusters[i], false);
      }
    }
    printStats(stats, "Create (incrementally)", "Clusters");
  }
 
  public void cloneBenchmark() {
    TimingStatistics stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        vectors[0][i] = vectors[0][i].clone();
        call.end();
      }
    }
    printStats(stats, "Clone", "DenseVector");
   
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        vectors[1][i] = vectors[1][i].clone();
        call.end();
      }
    }
    printStats(stats, "Clone", "RandSparseVector");
   
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        vectors[2][i] = vectors[2][i].clone();
        call.end();
      }
    }
    printStats(stats, "Clone", "SeqSparseVector");
   
  }
 
  public void serializeBenchmark() throws IOException {
    Configuration conf = new Configuration();
    FileSystem fs = FileSystem.get(conf);
    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf,
      new Path("/tmp/dense-vector"), IntWritable.class, VectorWritable.class);

    Writable one = new IntWritable(0);
    VectorWritable vec = new VectorWritable();
   
    TimingStatistics stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        vec.set(vectors[0][i]);
        writer.append(one, vec);
        call.end();
      }
    }
    writer.close();
    printStats(stats, "Serialize", "DenseVector");
   
    writer = new SequenceFile.Writer(fs, conf,
      new Path("/tmp/randsparse-vector"), IntWritable.class, VectorWritable.class);
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        vec.set(vectors[1][i]);
        writer.append(one, vec);
        call.end();
      }
    }
    writer.close();
    printStats(stats, "Serialize", "RandSparseVector");
   
    writer = new SequenceFile.Writer(fs, conf,
      new Path("/tmp/seqsparse-vector"), IntWritable.class, VectorWritable.class);
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        vec.set(vectors[2][i]);
        writer.append(one, vec);
        call.end();
      }
    }
    writer.close();
    printStats(stats, "Serialize", "SeqSparseVector");
   
  }
 
  public void deserializeBenchmark() throws IOException {
    doDeserializeBenchmark("DenseVector", "/tmp/dense-vector");
    doDeserializeBenchmark("RandSparseVector", "/tmp/randsparse-vector");
    doDeserializeBenchmark("SeqSparseVector", "/tmp/seqsparse-vector");
  }

  private void doDeserializeBenchmark(String name, String pathString) throws IOException {
    TimingStatistics stats = new TimingStatistics();
    TimingStatistics.Call call = stats.newCall();
    Iterator<?> iterator = new SequenceFileValueIterator<Writable>(new Path(pathString), true, new Configuration());
    while (iterator.hasNext()) {
      iterator.next();
      call.end();
      call = stats.newCall();
    }
    printStats(stats, "Deserialize", name);
  }
 
  public void dotBenchmark() {
    double result = 0;
    TimingStatistics stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        result += vectors[0][i].dot(vectors[0][(i + 1) % numVectors]);
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, "DotProduct", "DenseVector", "sum = " + result + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        result += vectors[1][i].dot(vectors[1][(i + 1) % numVectors]);
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, "DotProduct", "RandSparseVector", "sum = " + result + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        result += vectors[2][i].dot(vectors[2][(i + 1) % numVectors]);
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, "DotProduct", "SeqSparseVector", "sum = " + result + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        result += vectors[0][i].dot(vectors[1][(i + 1) % numVectors]);
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, "DotProduct", "Dense.fn(Rand)", "sum = " + result + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        result += vectors[0][i].dot(vectors[2][(i + 1) % numVectors]);
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, "DotProduct", "Dense.fn(Seq)", "sum = " + result + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        result += vectors[1][i].dot(vectors[0][(i + 1) % numVectors]);
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, "DotProduct", "Rand.fn(Dense)", "sum = " + result + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        result += vectors[1][i].dot(vectors[2][(i + 1) % numVectors]);
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, "DotProduct", "Rand.fn(Seq)", "sum = " + result + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        result += vectors[2][i].dot(vectors[0][(i + 1) % numVectors]);
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, "DotProduct", "Seq.fn(Dense)", "sum = " + result + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        result += vectors[2][i].dot(vectors[1][(i + 1) % numVectors]);
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, "DotProduct", "Seq.fn(Rand)", "sum = " + result + ' ');


  }


  public void closestCentroidBenchmark(DistanceMeasure measure) {

    for (int i = 0; i < numClusters; i++) {
      for (int j = 0; j < numClusters; j++) {
        double distance = Double.POSITIVE_INFINITY;
        if (i != j) {
          distance = measure.distance(clusters[i], clusters[j]);
        }
        clusterDistances.setQuick(i, j, distance);
      }
    }

    long distanceCalculations = 0;
    TimingStatistics stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      TimingStatistics.Call call = stats.newCall();
      for (int i = 0; i < numVectors; i++) {
        Vector vector = vectors[1][i];
        double minDistance = Double.MAX_VALUE;
        for (int k = 0; k < numClusters; k++) {
          double distance = measure.distance(vector, clusters[k]);
          distanceCalculations++;
          if (distance < minDistance) {
            minDistance = distance;
          }
        }
      }
      call.end();
    }
    printStats(stats,
               measure.getClass().getName(),
               "Closest center without Elkan's trick",
               "distanceCalculations = " + distanceCalculations);


    distanceCalculations = 0;
    stats = new TimingStatistics();
    Random rand = RandomUtils.getRandom();
    //rand.setSeed(System.currentTimeMillis());
    for (int l = 0; l < loop; l++) {
      TimingStatistics.Call call = stats.newCall();
      for (int i = 0; i < numVectors; i++) {
        Vector vector = vectors[1][i];
        int closestCentroid = rand.nextInt(numClusters);
        double dist = measure.distance(vector, clusters[closestCentroid]);
        distanceCalculations++;
        for (int k = 0; k < numClusters; k++) {
          if (closestCentroid != k) {
            double centroidDist = clusterDistances.getQuick(k, closestCentroid);
            if (centroidDist < 2 * dist) {
              dist = measure.distance(vector, clusters[k]);
              closestCentroid = k;
              distanceCalculations++;
            }
          }
        }
      }
      call.end();
    }
    printStats(stats,
               measure.getClass().getName(),
               "Closest center with Elkan's trick",
               "distanceCalculations = " + distanceCalculations);
  }

  public void distanceMeasureBenchmark(DistanceMeasure measure) {
    double result = 0;
    TimingStatistics stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        double minDistance = Double.MAX_VALUE;
        for (int u = 0; u < opsPerUnit; u++) {
          double distance = measure.distance(vectors[0][i], vectors[0][u]);
          if (distance < minDistance) {
            minDistance = distance;
          }
        }
        result += minDistance;
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, measure.getClass().getName(), "DenseVector", "minDistance = " + result + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        double minDistance = Double.MAX_VALUE;
        for (int u = 0; u < opsPerUnit; u++) {
          double distance = measure.distance(vectors[1][i], vectors[1][u]);
          if (distance < minDistance) {
            minDistance = distance;
          }
        }
        result += minDistance;
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, measure.getClass().getName(), "RandSparseVector", "minDistance = " + result
                                                                                + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        double minDistance = Double.MAX_VALUE;
        for (int u = 0; u < opsPerUnit; u++) {
          double distance = measure.distance(vectors[2][i], vectors[2][u]);
          if (distance < minDistance) {
            minDistance = distance;
          }
        }
        result += minDistance;
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, measure.getClass().getName(), "SeqSparseVector", "minDistance = " + result
                                                                                    + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        double minDistance = Double.MAX_VALUE;
        for (int u = 0; u < opsPerUnit; u++) {
          double distance = measure.distance(vectors[0][i], vectors[1][u]);
          if (distance < minDistance) {
            minDistance = distance;
          }
        }
        result += minDistance;
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, measure.getClass().getName(), "Dense.fn(Rand)", "minDistance = " + result + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        double minDistance = Double.MAX_VALUE;
        for (int u = 0; u < opsPerUnit; u++) {
          double distance = measure.distance(vectors[0][i], vectors[2][u]);
          if (distance < minDistance) {
            minDistance = distance;
          }
        }
        result += minDistance;
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, measure.getClass().getName(), "Dense.fn(Seq)", "minDistance = " + result
                                                                                + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        double minDistance = Double.MAX_VALUE;
        for (int u = 0; u < opsPerUnit; u++) {
          double distance = measure.distance(vectors[1][i], vectors[0][u]);
          if (distance < minDistance) {
            minDistance = distance;
          }
        }
        result += minDistance;
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, measure.getClass().getName(), "Rand.fn(Dense)", "minDistance = " + result
                                                                                    + ' ');
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        double minDistance = Double.MAX_VALUE;
        for (int u = 0; u < opsPerUnit; u++) {
          double distance = measure.distance(vectors[1][i], vectors[2][u]);
          if (distance < minDistance) {
            minDistance = distance;
          }
        }
        result += minDistance;
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, measure.getClass().getName(), "Rand.fn(Seq)", "minDistance = " + result + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        double minDistance = Double.MAX_VALUE;
        for (int u = 0; u < opsPerUnit; u++) {
          double distance = measure.distance(vectors[2][i], vectors[0][u]);
          if (distance < minDistance) {
            minDistance = distance;
          }
        }
        result += minDistance;
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, measure.getClass().getName(), "Seq.fn(Dense)", "minDistance = " + result
                                                                                + ' ');
    result = 0;
    stats = new TimingStatistics();
    for (int l = 0; l < loop; l++) {
      for (int i = 0; i < numVectors; i++) {
        TimingStatistics.Call call = stats.newCall();
        double minDistance = Double.MAX_VALUE;
        for (int u = 0; u < opsPerUnit; u++) {
          double distance = measure.distance(vectors[2][i], vectors[1][u]);
          if (distance < minDistance) {
            minDistance = distance;
          }
        }
        result += minDistance;
        call.end();
      }
    }
    // print result to prevent hotspot from eliminating deadcode
    printStats(stats, measure.getClass().getName(), "Seq.fn(Rand)", "minDistance = " + result
                                                                                    + ' ');
   
  }
 
  public static void main(String[] args) throws IOException {
   
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();
   
    Option vectorSizeOpt = obuilder.withLongName("vectorSize").withRequired(false).withArgument(
      abuilder.withName("vs").withMinimum(1).withMaximum(1).create()).withDescription(
      "Cardinality of the vector. Default 1000").withShortName("vs").create();
   
    Option vectorSparsityOpt = obuilder.withLongName("sparsity").withRequired(false).withArgument(
      abuilder.withName("sp").withMinimum(1).withMaximum(1).create()).withDescription(
      "Sparsity of the vector. Default 1000").withShortName("sp").create();
    Option numVectorsOpt = obuilder.withLongName("numVectors").withRequired(false).withArgument(
      abuilder.withName("nv").withMinimum(1).withMaximum(1).create()).withDescription(
      "Number of Vectors to create. Default: 100").withShortName("nv").create();
    Option numClustersOpt = obuilder.withLongName("numClusters").withRequired(false).withArgument(
          abuilder.withName("vs").withMinimum(1).withMaximum(1).create()).withDescription(
          "Number of Vectors to create. Default: 10").withShortName("vs").create();
    Option loopOpt = obuilder.withLongName("loop").withRequired(false).withArgument(
      abuilder.withName("loop").withMinimum(1).withMaximum(1).create()).withDescription(
      "Number of times to loop. Default: 200").withShortName("l").create();
    Option numOpsOpt = obuilder.withLongName("numOps").withRequired(false).withArgument(
      abuilder.withName("numOps").withMinimum(1).withMaximum(1).create()).withDescription(
      "Number of operations to do per timer. "
          + "E.g In distance measure, the distance is calculated numOps times"
          + " and the total time is measured. Default: 10").withShortName("no").create();
   
    Option helpOpt = DefaultOptionCreator.helpOption();
   
    Group group = gbuilder.withName("Options").withOption(vectorSizeOpt).withOption(vectorSparsityOpt)
        .withOption(numVectorsOpt).withOption(loopOpt).withOption(numOpsOpt).withOption(helpOpt).create();
   
    try {
      Parser parser = new Parser();
      parser.setGroup(group);
      CommandLine cmdLine = parser.parse(args);
     
      if (cmdLine.hasOption(helpOpt)) {
        CommandLineUtil.printHelp(group);
        return;
      }
     
      int cardinality = 1000;
      if (cmdLine.hasOption(vectorSizeOpt)) {
        cardinality = Integer.parseInt((String) cmdLine.getValue(vectorSizeOpt));
       
      }   
     
      int numClusters=25;
      if (cmdLine.hasOption(numClustersOpt)) {
        numClusters = Integer.parseInt((String) cmdLine.getValue(numClustersOpt));
      }

      int sparsity = 1000;
      if (cmdLine.hasOption(vectorSparsityOpt)) {
        sparsity = Integer.parseInt((String) cmdLine.getValue(vectorSparsityOpt));
      }

      int numVectors = 100;
      if (cmdLine.hasOption(numVectorsOpt)) {
        numVectors = Integer.parseInt((String) cmdLine.getValue(numVectorsOpt));
       
      }
      int loop = 200;
      if (cmdLine.hasOption(loopOpt)) {
        loop = Integer.parseInt((String) cmdLine.getValue(loopOpt));
       
      }
      int numOps = 10;
      if (cmdLine.hasOption(numOpsOpt)) {
        numOps = Integer.parseInt((String) cmdLine.getValue(numOpsOpt));
       
      }
      VectorBenchmarks mark = new VectorBenchmarks(cardinality, sparsity, numVectors, numClusters, loop, numOps);
      mark.createBenchmark();
      mark.incrementalCreateBenchmark();
      mark.cloneBenchmark();
      mark.dotBenchmark();
      mark.serializeBenchmark();
      mark.deserializeBenchmark();
      mark.distanceMeasureBenchmark(new CosineDistanceMeasure());
      mark.distanceMeasureBenchmark(new SquaredEuclideanDistanceMeasure());
      mark.distanceMeasureBenchmark(new EuclideanDistanceMeasure());
      mark.distanceMeasureBenchmark(new ManhattanDistanceMeasure());
      mark.distanceMeasureBenchmark(new TanimotoDistanceMeasure());
     
      mark.closestCentroidBenchmark(new CosineDistanceMeasure());
      mark.closestCentroidBenchmark(new SquaredEuclideanDistanceMeasure());
      mark.closestCentroidBenchmark(new EuclideanDistanceMeasure());
      mark.closestCentroidBenchmark(new ManhattanDistanceMeasure());
      mark.closestCentroidBenchmark(new TanimotoDistanceMeasure());
     
      log.info("\n{}", mark);
    } catch (OptionException e) {
      CommandLineUtil.printHelp(group);
    }
   
  }
 
  @Override
  public String toString() {
    int pad = 24;
    StringBuilder sb = new StringBuilder(1000);
    sb.append(StringUtils.rightPad("BenchMarks", pad));
    for (int i = 0; i < implType.size(); i++) {
      for (Entry<String,Integer> e : implType.entrySet()) {
        if (e.getValue() == i) {
          sb.append(StringUtils.rightPad(e.getKey(), pad).substring(0, pad));
          break;
        }
      }
    }
    sb.append('\n');
    List<String> keys = new ArrayList<String>(statsMap.keySet());
    Collections.sort(keys);
    for (String benchmarkName : keys) {
      List<String[]> implTokenizedStats = statsMap.get(benchmarkName);
      int maxStats = 0;
      for (String[] stat : implTokenizedStats) {
        maxStats = Math.max(maxStats, stat.length);
      }
     
      for (int i = 0; i < maxStats; i++) {
        boolean printedName = false;
        for (String[] stats : implTokenizedStats) {
          if (i == 0 && !printedName) {
            sb.append(StringUtils.rightPad(benchmarkName, pad));
            printedName = true;
          } else if (!printedName) {
            printedName = true;
            sb.append(StringUtils.rightPad("", pad));
          }
          if (stats.length > i) {
            sb.append(StringUtils.rightPad(stats[i], pad));
          } else {
            sb.append(StringUtils.rightPad("", pad));
          }

        }
        sb.append('\n');
      }
      sb.append('\n');
    }
    return sb.toString();
  }
 
}
TOP

Related Classes of org.apache.mahout.benchmark.VectorBenchmarks

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.