Package org.apache.mahout.classifier.sgd

Source Code of org.apache.mahout.classifier.sgd.OnlineBaseTest

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.sgd;

import com.google.common.base.CharMatcher;
import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.CharStreams;
import com.google.common.io.Resources;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

import java.io.IOException;
import java.io.InputStreamReader;
import java.util.List;
import java.util.Map;
import java.util.Random;

public abstract class OnlineBaseTest extends MahoutTestCase {

  private Matrix input;

  Matrix getInput() {
    return input;
  }

  Vector readStandardData() throws IOException {
    // 60 test samples.  First column is constant.  Second and third are normally distributed from
    // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59).  The first 30 rows have a
    // target variable of 0, the last 30 a target of 1.  The remaining columns are are random noise.
    input = readCsv("sgd.csv");

    // regenerate the target variable
    Vector target = new DenseVector(60);
    target.assign(0);
    target.viewPart(30, 30).assign(1);
    return target;
  }

  static void train(Matrix input, Vector target, OnlineLearner lr) {
    RandomUtils.useTestSeed();
    Random gen = RandomUtils.getRandom();

    // train on samples in random order (but only one pass)
    for (int row : permute(gen, 60)) {
      lr.train((int) target.get(row), input.viewRow(row));
    }
    lr.close();
  }

  static void test(Matrix input, Vector target, AbstractVectorClassifier lr,
                   double expected_mean_error, double expected_absolute_error) {
    // now test the accuracy
    Matrix tmp = lr.classify(input);
    // mean(abs(tmp - target))
    double meanAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60;

    // max(abs(tmp - target)
    double maxAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS);

    System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
    assertEquals(0, meanAbsoluteError , expected_mean_error);
    assertEquals(0, maxAbsoluteError, expected_absolute_error);

    // convenience methods should give the same results
    Vector v = lr.classifyScalar(input);
    assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-5);
    v = lr.classifyFull(input).viewColumn(1);
    assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-4);
  }

  /**
   * Permute the integers from 0 ... max-1
   *
   * @param gen The random number generator to use.
   * @param max The number of integers to permute
   * @return An array of jumbled integer values
   */
  static int[] permute(Random gen, int max) {
    int[] permutation = new int[max];
    permutation[0] = 0;
    for (int i = 1; i < max; i++) {
      int n = gen.nextInt(i + 1);
      if (n == i) {
        permutation[i] = i;
      } else {
        permutation[i] = permutation[n];
        permutation[n] = i;
      }
    }
    return permutation;
  }


  /**
   * Reads a file containing CSV data.  This isn't implemented quite the way you might like for a
   * real program, but does the job for reading test data.  Most notably, it will only read numbers,
   * not quoted strings.
   *
   * @param resourceName Where to get the data.
   * @return A matrix of the results.
   * @throws IOException If there is an error reading the data
   */
  static Matrix readCsv(String resourceName) throws IOException {
    Splitter onCommas = Splitter.on(',').trimResults(CharMatcher.anyOf(" \""));

    Readable isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8);
    List<String> data = CharStreams.readLines(isr);
    String first = data.get(0);
    data = data.subList(1, data.size());

    List<String> values = Lists.newArrayList(onCommas.split(first));
    Matrix r = new DenseMatrix(data.size(), values.size());

    int column = 0;
    Map<String, Integer> labels = Maps.newHashMap();
    for (String value : values) {
      labels.put(value, column);
      column++;
    }
    r.setColumnLabelBindings(labels);

    int row = 0;
    for (String line : data) {
      column = 0;
      values = Lists.newArrayList(onCommas.split(line));
      for (String value : values) {
        r.set(row, column, Double.parseDouble(value));
        column++;
      }
      row++;
    }

    return r;
  }
}
TOP

Related Classes of org.apache.mahout.classifier.sgd.OnlineBaseTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.