Package org.apache.mahout.classifier.sgd

Source Code of org.apache.mahout.classifier.sgd.OnlineLogisticRegressionTest

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.sgd;

import com.google.common.base.CharMatcher;
import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.CharStreams;
import com.google.common.io.Resources;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.junit.Test;

import java.io.IOException;
import java.io.InputStreamReader;
import java.util.List;
import java.util.Map;
import java.util.Random;

public final class OnlineLogisticRegressionTest extends MahoutTestCase {

  private Matrix input;

  /**
   * The CrossFoldLearner is probably the best learner to use for new applications.
    * @throws IOException If test resources aren't readable.
   */
  @Test
  public void crossValidation() throws IOException {
    Vector target = readStandardData();

    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
            .lambda(1 * 1.0e-3)
            .learningRate(50);


    train(input, target, lr);

    System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood());
    test(input, target, lr);

  }

  @Test
  public void crossValidatedAuc() throws IOException {
    RandomUtils.useTestSeed();
    Random gen = RandomUtils.getRandom();

    Matrix data = readCsv("cancer.csv");
    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
            .stepOffset(10)
            .decayExponent(0.7)
            .lambda(1 * 1.0e-3)
            .learningRate(5);
    int k = 0;
    int[] ordering = permute(gen, data.numRows());
    for (int epoch = 0; epoch < 100; epoch++) {
      for (int row : ordering) {
        lr.train(row, (int) data.get(row, 9), data.viewRow(row));
        System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc());
      }
      assertEquals(1, lr.auc(), 0.2);
    }
    assertEquals(1, lr.auc(), 0.1);
  }

  /**
   * Verifies that a classifier with known coefficients does the right thing.
   */
  @Test
  public void testClassify() {
    OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
    // set up some internal coefficients as if we had learned them
    lr.setBeta(0, 0, -1);
    lr.setBeta(1, 0, -2);

    // zero vector gives no information.  All classes are equal.
    Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
    assertEquals(1 / 3.0, v.get(0), 1.0e-8);
    assertEquals(1 / 3.0, v.get(1), 1.0e-8);

    v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / 3.0, v.get(0), 1.0e-8);
    assertEquals(1 / 3.0, v.get(1), 1.0e-8);
    assertEquals(1 / 3.0, v.get(2), 1.0e-8);

    // weights for second vector component are still zero so all classifications are equally likely
    v = lr.classify(new DenseVector(new double[]{0, 1}));
    assertEquals(1 / 3.0, v.get(0), 1.0e-3);
    assertEquals(1 / 3.0, v.get(1), 1.0e-3);

    v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / 3.0, v.get(0), 1.0e-3);
    assertEquals(1 / 3.0, v.get(1), 1.0e-3);
    assertEquals(1 / 3.0, v.get(2), 1.0e-3);

    // but the weights on the first component are non-zero
    v = lr.classify(new DenseVector(new double[]{1, 0}));
    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);

    v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1.0e-8);

    lr.setBeta(0, 1, 1);

    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3);
    assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1.0e-3);
    assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1.0e-3);

    lr.setBeta(1, 1, 3);

    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
    assertEquals(1.0, v.zSum(), 1.0e-8);
    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8);
    assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1.0e-8);
    assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1.0e-8);
  }

  @Test
  public void testTrain() throws Exception {
    Vector target = readStandardData();


    // lambda here needs to be relatively small to avoid swamping the actual signal, but can be
    // larger than usual because the data are dense.  The learning rate doesn't matter too much
    // for this example, but should generally be < 1
    // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias
    //   --target y --categories 2 --predictors  V2 V3 V4 V5 V6 V7 --types n
    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
            .lambda(1 * 1.0e-3)
            .learningRate(50);

    train(input, target, lr);
    test(input, target, lr);
  }

  private Vector readStandardData() throws IOException {
    // 60 test samples.  First column is constant.  Second and third are normally distributed from
    // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59).  The first 30 rows have a
    // target variable of 0, the last 30 a target of 1.  The remaining columns are are random noise.
    input = readCsv("sgd.csv");

    // regenerate the target variable
    Vector target = new DenseVector(60);
    target.assign(0);
    target.viewPart(30, 30).assign(1);
    return target;
  }

  private static void train(Matrix input, Vector target, OnlineLearner lr) {
    RandomUtils.useTestSeed();
    Random gen = RandomUtils.getRandom();

    // train on samples in random order (but only one pass)
    for (int row : permute(gen, 60)) {
      lr.train((int) target.get(row), input.getRow(row));
    }
    lr.close();
  }

  private static void test(Matrix input, Vector target, AbstractVectorClassifier lr) {
    // now test the accuracy
    Matrix tmp = lr.classify(input);
    // mean(abs(tmp - target))
    double meanAbsoluteError = tmp.getColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60;

    // max(abs(tmp - target)
    double maxAbsoluteError = tmp.getColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS);

    System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
    assertEquals(0, meanAbsoluteError , 0.05);
    assertEquals(0, maxAbsoluteError, 0.3);

    // convenience methods should give the same results
    Vector v = lr.classifyScalar(input);
    assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-5);
    v = lr.classifyFull(input).getColumn(1);
    assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-4);
  }

  /**
   * Permute the integers from 0 ... max-1
   *
   * @param gen The random number generator to use.
   * @param max The number of integers to permute
   * @return An array of jumbled integer values
   */
  private static int[] permute(Random gen, int max) {
    int[] permutation = new int[max];
    permutation[0] = 0;
    for (int i = 1; i < max; i++) {
      int n = gen.nextInt(i + 1);
      if (n == i) {
        permutation[i] = i;
      } else {
        permutation[i] = permutation[n];
        permutation[n] = i;
      }
    }
    return permutation;
  }


  /**
   * Reads a file containing CSV data.  This isn't implemented quite the way you might like for a
   * real program, but does the job for reading test data.  Most notably, it will only read numbers,
   * not quoted strings.
   *
   * @param resourceName Where to get the data.
   * @return A matrix of the results.
   * @throws IOException If there is an error reading the data
   */
  private static Matrix readCsv(String resourceName) throws IOException {
    Splitter onCommas = Splitter.on(",").trimResults(CharMatcher.anyOf(" \""));

    Readable isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8);
    List<String> data = CharStreams.readLines(isr);
    String first = data.get(0);
    data = data.subList(1, data.size());

    List<String> values = Lists.newArrayList(onCommas.split(first));
    Matrix r = new DenseMatrix(data.size(), values.size());

    int column = 0;
    Map<String, Integer> labels = Maps.newHashMap();
    for (String value : values) {
      labels.put(value, column);
      column++;
    }
    r.setColumnLabelBindings(labels);

    int row = 0;
    for (String line : data) {
      column = 0;
      values = Lists.newArrayList(onCommas.split(line));
      for (String value : values) {
        r.set(row, column, Double.parseDouble(value));
        column++;
      }
      row++;
    }

    return r;
  }
}
TOP

Related Classes of org.apache.mahout.classifier.sgd.OnlineLogisticRegressionTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.