Package org.apache.mahout.utils.eval

Source Code of org.apache.mahout.utils.eval.InMemoryFactorizationEvaluator

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.utils.eval;

import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.model.GenericPreference;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.IOUtils;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.Charset;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
* <p>Measures the root-mean-squared error of a ratring matrix factorization against a test set.</p>
*
* <p>the factorization matrices are read into memory, which makes this job pretty fast, if you get OutOfMemoryErrors,
* use {@link ParallelFactorizationEvaluator} instead</p>
*
  * <p>Command line arguments specific to this class are:</p>
*
* <ol>
* <li>--output (path): path where output should go</li>
* <li>--pairs (path): path containing the test ratings, each line must be userID,itemID,rating</li>
* <li>--userFeatures (path): path to the user feature matrix</li>
* <li>--itemFeatures (path): path to the item feature matrix</li>
* </ol>
*/
public class InMemoryFactorizationEvaluator extends AbstractJob {

  public static void main(String[] args) throws Exception {
    ToolRunner.run(new InMemoryFactorizationEvaluator(), args);
  }

  @Override
  public int run(String[] args) throws Exception {

    addOption("pairs", "p", "path containing the test ratings, each line must be userID,itemID,rating", true);
    addOption("userFeatures", "u", "path to the user feature matrix", true);
    addOption("itemFeatures", "i", "path to the item feature matrix", true);
    addOutputOption();

    Map<String,String> parsedArgs = parseArguments(args);
    if (parsedArgs == null) {
      return -1;
    }

    Path pairs = new Path(parsedArgs.get("--pairs"));
    Path userFeatures = new Path(parsedArgs.get("--userFeatures"));
    Path itemFeatures = new Path(parsedArgs.get("--itemFeatures"));

    Matrix u = readMatrix(userFeatures);
    Matrix m = readMatrix(itemFeatures);

    FullRunningAverage rmseAvg = new FullRunningAverage();
    FullRunningAverage maeAvg = new FullRunningAverage();
    int pairsUsed = 1;
    Writer writer = new OutputStreamWriter(System.out);
    try {
      for (Preference pref : readProbePreferences(pairs)) {
        int userID = (int) pref.getUserID();
        int itemID = (int) pref.getItemID();

        double rating = pref.getValue();
        double estimate = u.getRow(userID).dot(m.getRow(itemID));
        double err = rating - estimate;
        rmseAvg.addDatum(err * err);
        maeAvg.addDatum(Math.abs(err));
        writer.write("Probe [" + pairsUsed + "], rating of user [" + userID + "] towards item [" + itemID + "], " +
            "[" + rating + "] estimated [" + estimate + "]\n");
        pairsUsed++;
      }
      double rmse = Math.sqrt(rmseAvg.getAverage());
      double mae = maeAvg.getAverage();
      writer.write("RMSE: " + rmse + ", MAE: " + mae + "\n");
    } finally {
      IOUtils.quietClose(writer);
    }
    return 0;
  }

  private Matrix readMatrix(Path dir) throws IOException {

    Matrix matrix = new SparseMatrix(new int[] { Integer.MAX_VALUE, Integer.MAX_VALUE });

    FileSystem fs = dir.getFileSystem(getConf());
    for (FileStatus seqFile : fs.globStatus(new Path(dir, "part-*"))) {
      Path path = seqFile.getPath();
      SequenceFile.Reader reader = null;
      try {
        reader = new SequenceFile.Reader(fs, path, getConf());
        IntWritable key = new IntWritable();
        VectorWritable value = new VectorWritable();
        while (reader.next(key, value)) {
          int row = key.get();
          Iterator<Vector.Element> elementsIterator = value.get().iterateNonZero();
          while (elementsIterator.hasNext()) {
            Vector.Element element = elementsIterator.next();
            matrix.set(row, element.index(), element.get());
          }
        }
      } finally {
        IOUtils.quietClose(reader);
      }
    }
    return matrix;
  }

  private List<Preference> readProbePreferences(Path dir) throws IOException {

    List<Preference> preferences = new LinkedList<Preference>();
    FileSystem fs = dir.getFileSystem(getConf());
    for (FileStatus seqFile : fs.globStatus(new Path(dir, "part-*"))) {
      Path path = seqFile.getPath();
      InputStream in = null;
      try  {
        in = fs.open(path);
        BufferedReader reader = new BufferedReader(new InputStreamReader(in, Charset.forName("UTF-8")));
        String line;
        while ((line = reader.readLine()) != null) {
          String[] tokens = TasteHadoopUtils.splitPrefTokens(line);
          long userID = Long.parseLong(tokens[0]);
          long itemID = Long.parseLong(tokens[1]);
          float value = Float.parseFloat(tokens[2]);
          preferences.add(new GenericPreference(userID, itemID, value));
        }
      } finally {
        IOUtils.quietClose(in);
      }
    }
    return preferences;
  }
}
TOP

Related Classes of org.apache.mahout.utils.eval.InMemoryFactorizationEvaluator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.