Package org.apache.mahout.cf.taste.impl.eval

Source Code of org.apache.mahout.cf.taste.impl.eval.GenericRecommenderIRStatsEvaluator

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.cf.taste.impl.eval;

import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.eval.IRStatistics;
import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
import org.apache.mahout.cf.taste.eval.RecommenderIRStatsEvaluator;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.RandomUtils;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.common.FastSet;
import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
import org.apache.mahout.cf.taste.impl.model.GenericUser;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Item;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.User;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.cf.taste.recommender.Recommender;
import org.apache.mahout.cf.taste.recommender.Rescorer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
* <p>For each {@link User}, this implementation determines the top <code>n</code> preferences,
* then evaluates the IR statistics based on a {@link DataModel} that does not have these values.
* This number <code>n</code> is the "at" value, as in "precision at 5". For example, this would mean precision
* evaluated by removing the top 5 preferences for a {@link User} and then finding the percentage of those 5
* {@link Item}s included in the top 5 recommendations for that user.</p>
*/
public final class GenericRecommenderIRStatsEvaluator implements RecommenderIRStatsEvaluator {

  private static final Logger log = LoggerFactory.getLogger(GenericRecommenderIRStatsEvaluator.class);

  /**
   * Pass as "relevanceThreshold" argument to
   * {@link #evaluate(RecommenderBuilder, DataModel, Rescorer, int, double, double)} to have it attempt
   * to compute a reasonable threshold. Note that this will impact performance.
   */
  public static final double CHOOSE_THRESHOLD = Double.NaN;

  private final Random random;

  public GenericRecommenderIRStatsEvaluator() {
    random = RandomUtils.getRandom();
  }

  @Override
  public IRStatistics evaluate(RecommenderBuilder recommenderBuilder,
                               DataModel dataModel,
                               Rescorer<Item> rescorer,
                               int at,
                               double relevanceThreshold,
                               double evaluationPercentage) throws TasteException {

    if (recommenderBuilder == null) {
      throw new IllegalArgumentException("recommenderBuilder is null");
    }
    if (dataModel == null) {
      throw new IllegalArgumentException("dataModel is null");
    }
    if (at < 1) {
      throw new IllegalArgumentException("at must be at least 1");
    }
    if (Double.isNaN(evaluationPercentage) || evaluationPercentage <= 0.0 || evaluationPercentage > 1.0) {
      throw new IllegalArgumentException("Invalid evaluationPercentage: " + evaluationPercentage);
    }
    if (Double.isNaN(relevanceThreshold)) {
      throw new IllegalArgumentException("Invalid relevanceThreshold: " + evaluationPercentage);
    }

    int numItems = dataModel.getNumItems();
    RunningAverage precision = new FullRunningAverage();
    RunningAverage recall = new FullRunningAverage();
    RunningAverage fallOut = new FullRunningAverage();
    for (User user : dataModel.getUsers()) {
      if (random.nextDouble() < evaluationPercentage) {
        Object id = user.getID();
        Collection<Item> relevantItems = new FastSet<Item>(at);
        Preference[] prefs = user.getPreferencesAsArray();
        double theRelevanceThreshold = Double.isNaN(relevanceThreshold) ? computeThreshold(prefs) : relevanceThreshold;
        for (Preference pref : prefs) {
          if (pref.getValue() >= theRelevanceThreshold) {
            relevantItems.add(pref.getItem());
          }
        }
        int numRelevantItems = relevantItems.size();
        if (numRelevantItems > 0) {
          List<User> trainingUsers = new ArrayList<User>(dataModel.getNumUsers());
          for (User user2 : dataModel.getUsers()) {
            processOtherUser(id, relevantItems, trainingUsers, user2);
          }
          DataModel trainingModel = new GenericDataModel(trainingUsers);
          Recommender recommender = recommenderBuilder.buildRecommender(trainingModel);

          try {
            trainingModel.getUser(id);
          } catch (NoSuchUserException nsee) {
            continue; // Oops we excluded all prefs for the user -- just move on
          }

          int intersectionSize = 0;
          List<RecommendedItem> recommendedItems = recommender.recommend(id, at, rescorer);
          for (RecommendedItem recommendedItem : recommendedItems) {
            if (relevantItems.contains(recommendedItem.getItem())) {
              intersectionSize++;
            }
          }
          int numRecommendedItems = recommendedItems.size();
          if (numRecommendedItems > 0) {
            precision.addDatum((double) intersectionSize / (double) numRecommendedItems);
          }
          recall.addDatum((double) intersectionSize / (double) numRelevantItems);
          if (numRelevantItems < prefs.length) {
            fallOut.addDatum((double) (numRecommendedItems - intersectionSize) /
                             (double) (numItems - numRelevantItems));
          }

          log.info("Precision/recall/fall-out: {} / {} / {}", new Object[] {
              precision.getAverage(), recall.getAverage(), fallOut.getAverage()
          });
        }
      }
    }

    return new IRStatisticsImpl(precision.getAverage(), recall.getAverage(), fallOut.getAverage());
  }

  private void processOtherUser(Object id,
                                Collection<Item> relevantItems,
                                Collection<User> trainingUsers,
                                User user2) {
    if (id.equals(user2.getID())) {
      List<Preference> trainingPrefs = new ArrayList<Preference>();
      Preference[] prefs2 = user2.getPreferencesAsArray();
      for (Preference pref : prefs2) {
        if (!relevantItems.contains(pref.getItem())) {
          trainingPrefs.add(pref);
        }
      }
      if (!trainingPrefs.isEmpty()) {
        User trainingUser = new GenericUser<String>(id.toString(), trainingPrefs);
        trainingUsers.add(trainingUser);
      }
    } else {
      trainingUsers.add(user2);
    }
  }

  private static double computeThreshold(Preference[] prefs) {
    if (prefs.length < 2) {
      // Not enough data points -- return a threshold that allows everything
      return Double.NEGATIVE_INFINITY;
    }
    RunningAverageAndStdDev stdDev = new FullRunningAverageAndStdDev();
    for (Preference pref : prefs) {
      stdDev.addDatum(pref.getValue());
    }
    return stdDev.getAverage() + stdDev.getStandardDeviation();
  }

}
TOP

Related Classes of org.apache.mahout.cf.taste.impl.eval.GenericRecommenderIRStatsEvaluator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.