Package org.grouplens.lenskit.knn.user

Source Code of org.grouplens.lenskit.knn.user.UserUserItemScorer

/*
* LensKit, an open source recommender systems toolkit.
* Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
* Work on LensKit has been funded by the National Science Foundation under
* grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* this program; if not, write to the Free Software Foundation, Inc., 51
* Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package org.grouplens.lenskit.knn.user;

import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongSet;
import org.grouplens.lenskit.basic.AbstractItemScorer;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Event;
import org.grouplens.lenskit.data.history.History;
import org.grouplens.lenskit.data.history.RatingVectorUserHistorySummarizer;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.knn.MinNeighbors;
import org.grouplens.lenskit.knn.NeighborhoodSize;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.transform.normalize.UserVectorNormalizer;
import org.grouplens.lenskit.transform.normalize.VectorTransformation;
import org.grouplens.lenskit.transform.threshold.Threshold;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nonnull;
import javax.inject.Inject;
import java.util.Collection;
import java.util.PriorityQueue;

import static java.lang.Math.abs;

/**
* @author <a href="http://www.grouplens.org">GroupLens Research</a>
*/
public class UserUserItemScorer extends AbstractItemScorer {
    private static final Logger logger = LoggerFactory.getLogger(UserUserItemScorer.class);

    public static final Symbol NEIGHBORHOOD_SIZE_SYMBOL =
            Symbol.of("org.grouplens.lenskit.knn.user.NeighborhoodSize");
    public static final Symbol NEIGHBORHOOD_WEIGHT_SYMBOL =
            Symbol.of("org.grouplens.lenskit.knn.user.NeighborhoodWeight");

    private final UserEventDAO dao;
    protected final NeighborFinder neighborFinder;
    protected final UserVectorNormalizer normalizer;
    private final int neighborhoodSize;
    private final int minNeighborCount;
    private final Threshold userThreshold;

    @Inject
    public UserUserItemScorer(UserEventDAO dao, NeighborFinder nf,
                              UserVectorNormalizer norm,
                              @NeighborhoodSize int nnbrs,
                              @MinNeighbors int minNbrs,
                              @UserSimilarityThreshold Threshold thresh) {
        this.dao = dao;
        neighborFinder = nf;
        normalizer = norm;
        neighborhoodSize = nnbrs;
        minNeighborCount = minNbrs;
        userThreshold = thresh;
    }

    /**
     * Normalize all neighbor rating vectors, taking care to normalize each one
     * only once.
     *
     * FIXME: MDE does not like this method.
     *
     * @param neighborhoods
     */
    protected Long2ObjectMap<SparseVector> normalizeNeighborRatings(Collection<? extends Collection<Neighbor>> neighborhoods) {
        Long2ObjectMap<SparseVector> normedVectors =
                new Long2ObjectOpenHashMap<SparseVector>();
        for (Neighbor n : Iterables.concat(neighborhoods)) {
            if (!normedVectors.containsKey(n.user)) {
                normedVectors.put(n.user, normalizer.normalize(n.user, n.vector, null));
            }
        }
        return normedVectors;
    }

    @Override
    public void score(long user, @Nonnull MutableSparseVector scores) {
        UserHistory<Event> history = dao.getEventsForUser(user);
        if (history == null) {
            history = History.forUser(user);
        }
        logger.debug("Predicting for {} items for user {} with {} events",
                     scores.size(), user, history.size());

        Long2ObjectMap<? extends Collection<Neighbor>> neighborhoods =
                findNeighbors(history, scores.keyDomain());
        Long2ObjectMap<SparseVector> normedUsers =
                normalizeNeighborRatings(neighborhoods.values());

        MutableSparseVector sizeChan = scores.addChannelVector(NEIGHBORHOOD_SIZE_SYMBOL);
        MutableSparseVector weightChan = scores.addChannelVector(NEIGHBORHOOD_WEIGHT_SYMBOL);
        for (VectorEntry e : scores.view(VectorEntry.State.EITHER)) {
            final long item = e.getKey();
            double sum = 0;
            double weight = 0;
            int count = 0;
            Collection<Neighbor> nbrs = neighborhoods.get(item);
            if (nbrs != null) {
                for (Neighbor n : nbrs) {
                    weight += abs(n.similarity);
                    sum += n.similarity * normedUsers.get(n.user).get(item);
                    count += 1;
                }
            }
           
            if (count >= minNeighborCount) {
                if (logger.isTraceEnabled()) {
                    logger.trace("Total neighbor weight for item {} is {} from {} neighbors",
                                 item, weight, count);
                }
                scores.set(e, sum / weight);
            } else {
                scores.unset(e);
            }
            sizeChan.set(e, count);
            weightChan.set(e,weight);
        }

        // Denormalize and return the results
        SparseVector urv = RatingVectorUserHistorySummarizer.makeRatingVector(history);
        VectorTransformation vo = normalizer.makeTransformation(history.getUserId(), urv);
        vo.unapply(scores);
    }

    /**
     * Find the neighbors for a user with respect to a collection of items.
     * For each item, the {@var neighborhoodSize} users closest to the
     * provided user are returned.
     *
     * @param user  The user's rating vector.
     * @param items The items for which neighborhoods are requested.
     * @return A mapping of item IDs to neighborhoods.
     */
    protected Long2ObjectMap<? extends Collection<Neighbor>>
    findNeighbors(@Nonnull UserHistory<? extends Event> user, @Nonnull LongSet items) {
        Preconditions.checkNotNull(user, "user profile");
        Preconditions.checkNotNull(user, "item set");

        Long2ObjectMap<PriorityQueue<Neighbor>> heaps = new Long2ObjectOpenHashMap<PriorityQueue<Neighbor>>(items.size());
        for (LongIterator iter = items.iterator(); iter.hasNext();) {
            long item = iter.nextLong();
            heaps.put(item, new PriorityQueue<Neighbor>(neighborhoodSize + 1,
                                                        Neighbor.SIMILARITY_COMPARATOR));
        }

        int neighborsUsed = 0;
        for (Neighbor nbr: neighborFinder.getCandidateNeighbors(user, items)) {
            for (VectorEntry e: nbr.vector) {
                final long item = e.getKey();
                PriorityQueue<Neighbor> heap = heaps.get(item);
                if (heap != null) {
                    heap.add(nbr);
                    if (heap.size() > neighborhoodSize) {
                        assert heap.size() == neighborhoodSize + 1;
                        heap.remove();
                    } else {
                        neighborsUsed += 1;
                    }
                }
            }
        }
        logger.debug("using {} neighbors across {} items",
                     neighborsUsed, items.size());
        return heaps;
    }
}
TOP

Related Classes of org.grouplens.lenskit.knn.user.UserUserItemScorer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.