/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.cf.taste.impl.recommender.svd;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
/**
 * Biased matrix factorization for rating prediction, trained with plain stochastic gradient descent (SGD).
 *
 * <p>Every user and every item is represented by a vector of {@code numFeatures} entries. The first
 * {@link #FEATURE_OFFSET} slots of each vector are reserved:</p>
 * <ul>
 *   <li>slot 0: the global average rating on the user side, paired with a constant 1 on the item side</li>
 *   <li>slot {@link #USER_BIAS_INDEX}: the user bias, paired with a constant 1 on the item side</li>
 *   <li>slot {@link #ITEM_BIAS_INDEX}: a constant 1 on the user side, paired with the item bias</li>
 * </ul>
 * <p>The remaining slots hold the latent features. With this layout a predicted rating is simply the dot
 * product of a user vector and an item vector.</p>
 */
public class RatingSGDFactorizer extends AbstractFactorizer {

  /** Number of reserved leading slots (global average, user bias, item bias) in every vector */
  protected static final int FEATURE_OFFSET = 3;
  /** Slot of a user vector that stores the user bias */
  protected static final int USER_BIAS_INDEX = 1;
  /** Slot of an item vector that stores the item bias */
  protected static final int ITEM_BIAS_INDEX = 2;

  /** Source of the preferences to factorize */
  protected final DataModel dataModel;
  /** Total vector length: the requested latent features plus the {@link #FEATURE_OFFSET} reserved slots */
  protected final int numFeatures;
  /** Number of passes over the cached preferences */
  private final int numIterations;
  /** Initial learning rate (step size) */
  protected final double learningRate;
  /** Factor the learning rate is multiplied with after every pass */
  protected final double learningRateDecay;
  /** Regularization strength, used to prevent overfitting */
  protected final double preventOverfitting;
  /** Standard deviation of the Gaussian noise used to initialize the latent features */
  protected final double randomNoise;

  /** Extra factor applied to the learning rate when a bias term is updated */
  protected double biasLearningRate = 0.5;
  /** Extra factor applied to the regularization strength when a bias term is updated */
  protected double biasReg = 0.1;

  /** One vector per user index */
  protected double[][] userVectors;
  /** One vector per item index */
  protected double[][] itemVectors;

  /** User IDs of all cached preferences; entry k belongs together with entry k of {@link #cachedItemIDs} */
  private long[] cachedUserIDs;
  /** Item IDs of all cached preferences; entry k belongs together with entry k of {@link #cachedUserIDs} */
  private long[] cachedItemIDs;

  /**
   * Creates a factorizer with a learning rate of 0.01, a regularization strength of 0.1, an initialization
   * noise of 0.01 and no learning rate decay (decay factor 1.0).
   *
   * @param dataModel     source of the preferences
   * @param numFeatures   number of latent features (the reserved slots are added on top)
   * @param numIterations number of passes over the preferences
   */
  public RatingSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
    this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0);
  }

  /**
   * Creates a fully parameterized factorizer.
   *
   * @param dataModel          source of the preferences
   * @param numFeatures        number of latent features (the reserved slots are added on top)
   * @param learningRate       initial SGD step size
   * @param preventOverfitting regularization strength
   * @param randomNoise        standard deviation used for the random initialization of the latent features
   * @param numIterations      number of passes over the preferences
   * @param learningRateDecay  factor the step size is multiplied with after every pass
   */
  public RatingSGDFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting,
      double randomNoise, int numIterations, double learningRateDecay) throws TasteException {
    super(dataModel);
    this.dataModel = dataModel;
    this.numFeatures = FEATURE_OFFSET + numFeatures;
    this.numIterations = numIterations;
    this.learningRate = learningRate;
    this.learningRateDecay = learningRateDecay;
    this.preventOverfitting = preventOverfitting;
    this.randomNoise = randomNoise;
  }

  /**
   * Allocates and initializes the user and item vectors, then caches all preferences and shuffles them.
   */
  protected void prepareTraining() throws TasteException {
    RandomWrapper noiseSource = RandomUtils.getRandom();

    int userCount = dataModel.getNumUsers();
    userVectors = new double[userCount][numFeatures];
    int itemCount = dataModel.getNumItems();
    itemVectors = new double[itemCount][numFeatures];

    double globalAverage = getAveragePreference();

    for (double[] userVector : userVectors) {
      userVector[0] = globalAverage;
      userVector[USER_BIAS_INDEX] = 0; // the user bias is learned here
      userVector[ITEM_BIAS_INDEX] = 1; // multiplies the item bias stored on the item side
      fillLatentFeatures(userVector, noiseSource);
    }

    for (double[] itemVector : itemVectors) {
      itemVector[0] = 1; // multiplies the global average stored on the user side
      itemVector[USER_BIAS_INDEX] = 1; // multiplies the user bias stored on the user side
      itemVector[ITEM_BIAS_INDEX] = 0; // the item bias is learned here
      fillLatentFeatures(itemVector, noiseSource);
    }

    cachePreferences();
    shufflePreferences();
  }

  /** Overwrites all non-reserved slots of the vector with Gaussian noise scaled by {@link #randomNoise}. */
  private void fillLatentFeatures(double[] vector, RandomWrapper random) {
    for (int slot = FEATURE_OFFSET; slot < numFeatures; slot++) {
      vector[slot] = random.nextGaussian() * randomNoise;
    }
  }

  /** Sums up the number of preferences of all users in the data model. */
  private int countPreferences() throws TasteException {
    int total = 0;
    for (LongPrimitiveIterator users = dataModel.getUserIDs(); users.hasNext();) {
      total += dataModel.getPreferencesFromUser(users.nextLong()).length();
    }
    return total;
  }

  /** Copies the (user ID, item ID) pair of every preference in the data model into the two cache arrays. */
  private void cachePreferences() throws TasteException {
    int total = countPreferences();
    cachedUserIDs = new long[total];
    cachedItemIDs = new long[total];

    LongPrimitiveIterator users = dataModel.getUserIDs();
    int slot = 0;
    while (users.hasNext()) {
      long userID = users.nextLong();
      for (Preference preference : dataModel.getPreferencesFromUser(userID)) {
        cachedUserIDs[slot] = userID;
        cachedItemIDs[slot] = preference.getItemID();
        slot++;
      }
    }
  }

  /** Brings the cached preferences into random order (Durstenfeld shuffle, walking from the back). */
  protected void shufflePreferences() {
    RandomWrapper random = RandomUtils.getRandom();
    int position = cachedUserIDs.length - 1;
    while (position > 0) {
      // the partner is drawn from [0, position], so an entry may stay where it is
      int partner = random.nextInt(position + 1);
      swapCachedPreferences(position, partner);
      position--;
    }
  }

  /** Exchanges the cached preferences at the two positions, keeping user and item IDs paired. */
  private void swapCachedPreferences(int posA, int posB) {
    swap(cachedUserIDs, posA, posB);
    swap(cachedItemIDs, posA, posB);
  }

  /** Exchanges two entries of an array. */
  private static void swap(long[] values, int first, int second) {
    long held = values[first];
    values[first] = values[second];
    values[second] = held;
  }

  /**
   * Trains the model: after preparation, every pass visits each cached preference once and applies an SGD
   * update; the learning rate is multiplied with {@link #learningRateDecay} at the end of every pass.
   *
   * @return the factorization built from the trained user and item vectors
   */
  @Override
  public Factorization factorize() throws TasteException {
    prepareTraining();

    double stepSize = learningRate;
    for (int pass = 0; pass < numIterations; pass++) {
      for (int sample = 0; sample < cachedUserIDs.length; sample++) {
        long userID = cachedUserIDs[sample];
        long itemID = cachedItemIDs[sample];
        float observed = dataModel.getPreferenceValue(userID, itemID);
        updateParameters(userID, itemID, observed, stepSize);
      }
      stepSize *= learningRateDecay;
    }

    return createFactorization(userVectors, itemVectors);
  }

  /** Computes the average over the values of all preferences in the data model. */
  double getAveragePreference() throws TasteException {
    RunningAverage mean = new FullRunningAverage();
    for (LongPrimitiveIterator users = dataModel.getUserIDs(); users.hasNext();) {
      PreferenceArray preferences = dataModel.getPreferencesFromUser(users.nextLong());
      for (Preference preference : preferences) {
        mean.addDatum(preference.getValue());
      }
    }
    return mean.getAverage();
  }

  /**
   * Performs one SGD step for a single observed rating: adjusts the user bias, the item bias and all latent
   * features of the affected user and item according to the prediction error.
   *
   * @param userID              ID of the rating user
   * @param itemID              ID of the rated item
   * @param rating              the observed rating
   * @param currentLearningRate step size to use for this update
   */
  protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) {
    int userPos = userIndex(userID);
    int itemPos = itemIndex(itemID);
    double[] userVector = userVectors[userPos];
    double[] itemVector = itemVectors[itemPos];

    // the error is computed once, before any parameter is touched
    double err = rating - predictRating(userPos, itemPos);

    // biases use their own scaled step size and scaled regularization
    double biasStep = biasLearningRate * currentLearningRate;
    double biasPenalty = biasReg * preventOverfitting;

    double userBias = userVector[USER_BIAS_INDEX];
    userVector[USER_BIAS_INDEX] = userBias + biasStep * (err - biasPenalty * userBias);

    double itemBias = itemVector[ITEM_BIAS_INDEX];
    itemVector[ITEM_BIAS_INDEX] = itemBias + biasStep * (err - biasPenalty * itemBias);

    // latent features: both updates are based on the values from before this step
    for (int slot = FEATURE_OFFSET; slot < numFeatures; slot++) {
      double userValue = userVector[slot];
      double itemValue = itemVector[slot];
      userVector[slot] = userValue + currentLearningRate * (err * itemValue - preventOverfitting * userValue);
      itemVector[slot] = itemValue + currentLearningRate * (err * userValue - preventOverfitting * itemValue);
    }
  }

  /**
   * Predicts a rating as the dot product of a user vector and an item vector over all slots, including the
   * reserved ones.
   *
   * @param userIndex index of the user vector in {@link #userVectors}
   * @param itemIndex index of the item vector in {@link #itemVectors}
   */
  private double predictRating(int userIndex, int itemIndex) {
    double[] userVector = userVectors[userIndex];
    double[] itemVector = itemVectors[itemIndex];
    double dotProduct = 0;
    for (int slot = 0; slot < numFeatures; slot++) {
      dotProduct += userVector[slot] * itemVector[slot];
    }
    return dotProduct;
  }
}