Package org.grouplens.lenskit.scored

Source Code of org.grouplens.lenskit.scored.ScoredIdListBuilder$ChannelStorage

/*
* LensKit, an open source recommender systems toolkit.
* Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
* Work on LensKit has been funded by the National Science Foundation under
* grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* this program; if not, write to the Free Software Foundation, Inc., 51
* Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package org.grouplens.lenskit.scored;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import it.unimi.dsi.fastutil.Swapper;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.ints.AbstractIntComparator;
import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
import it.unimi.dsi.fastutil.longs.LongList;
import it.unimi.dsi.fastutil.objects.Reference2ObjectArrayMap;
import org.apache.commons.lang3.builder.Builder;
import org.grouplens.lenskit.collections.CompactableLongArrayList;
import org.grouplens.lenskit.symbols.DoubleSymbolValue;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.symbols.SymbolValue;
import org.grouplens.lenskit.symbols.TypedSymbol;
import org.grouplens.lenskit.vectors.ImmutableSparseVector;
import org.grouplens.lenskit.vectors.MutableSparseVector;

import java.util.*;

import static it.unimi.dsi.fastutil.Arrays.quickSort;

/**
* Builder for packed lists of scored ids.  All ids in the resulting list will have the same set
* of side channels.
*
* @since 1.4
* @compat Public
* @author <a href="http://www.grouplens.org">GroupLens Research</a>
*/
public class ScoredIdListBuilder implements Builder<PackedScoredIdList> {
    // INVARIANT: all arrays (including channel arrays) have same size, which is capacity
    // INVARIANT: all arrays are non-null unless finish() has been called
    private CompactableLongArrayList ids;
    private DoubleArrayList scores;
    private boolean ignoreUnknown = false;
    private Map<Symbol,ChannelStorage> channels;
    private Map<TypedSymbol<?>,TypedChannelStorage<?>> typedChannels;

    public ScoredIdListBuilder() {
        this(10);
    }

    public ScoredIdListBuilder(int cap) {
        initialize(cap);
    }

    private void initialize(int cap) {
        ids = new CompactableLongArrayList(cap);
        scores = new DoubleArrayList(cap);
        channels = new Reference2ObjectArrayMap<Symbol, ChannelStorage>();
        typedChannels = new Reference2ObjectArrayMap<TypedSymbol<?>, TypedChannelStorage<?>>();
    }

    @Override
    public PackedScoredIdList build() {
        return finish(false);
    }

    /**
     * Destructive version of {@link #build()}, re-using storage if possible.  Future use of the
     * builder is impossible, and all memory used by it is released.
     *
     * @return The scored ID list.
     */
    public PackedScoredIdList finish() {
        return finish(true);
    }

    /**
     * Implementation of {@link #build()} and {@link #finish()}.
     * @param reuse Whether we should try to reuse the builder's storage for the packed list.
     *                 If {@code true}, the builder will be invalid after finishing and the packed
     *                 list will use the same arrays as the builder if they are full.
     * @return The packed ID list.
     */
    private PackedScoredIdList finish(boolean reuse) {
        Preconditions.checkState(ids != null, "builder has been finished");
        Map<Symbol, DoubleList> chans;
        Map<TypedSymbol<?>, List<?>> typedChans;
        if (size() > 0) {
            ImmutableMap.Builder<Symbol, DoubleList> cbld = ImmutableMap.builder();
            for (ChannelStorage chan: channels.values()) {
                DoubleArrayList built;
                if (reuse) {
                    built = chan.values;
                    built.trim();
                } else {
                    built = new DoubleArrayList(chan.values);
                }
                cbld.put(chan.symbol, built);
            }
            chans = cbld.build();
            ImmutableMap.Builder<TypedSymbol<?>, List<?>> tcbld = ImmutableMap.builder();
            for (TypedChannelStorage<?> chan: typedChannels.values()) {
                List<?> built;
                if (reuse) {
                    chan.values.trimToSize();
                    built = chan.values;
                } else {
                    built = new ArrayList<Object>(chan.values);
                }
                tcbld.put(chan.symbol, built);
            }
            typedChans = tcbld.build();
        } else {
            chans = Collections.emptyMap();
            typedChans = Collections.emptyMap();
        }
        LongList builtIds;
        DoubleList builtScores;
        if (reuse) {
            ids.trim();
            builtIds = ids;
            scores.trim();
            builtScores = scores;
            clear();
        } else {
            builtIds = new CompactableLongArrayList(ids);
            builtScores = new DoubleArrayList(scores);
        }
        return new PackedScoredIdList(builtIds, builtScores, typedChans, chans);
    }

    /**
     * Clear the builder. After it is cleared, it can no longer be used.
     */
    public void clear() {
        ids = null;
        scores = null;
        channels = null;
        typedChannels = null;
    }

    /**
     * Build a sparse vector directly from the list of IDs. This allows a scored ID list builder to
     * be used to efficiently accumulate a sparse vector.  If the same ID is added multiple times,
     * the first instance is used.
     *
     * @return A sparse vector containing the data accumulated.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public ImmutableSparseVector buildVector() {
        MutableSparseVector msv = MutableSparseVector.create(ids);
        final int size = size();
        for (int i = 0; i < size; i++) {
            msv.set(ids.get(i), scores.get(i));
        }

        for (ChannelStorage chan: channels.values()) {
            MutableSparseVector vchan = msv.getOrAddChannelVector(chan.symbol);
            for (int i = 0; i < size; i++) {
                vchan.set(ids.get(i), chan.values.get(i));
            }
        }

        for (TypedChannelStorage<?> chan: typedChannels.values()) {
            Long2ObjectMap vchan = msv.getOrAddChannel(chan.symbol);
            for (int i = 0; i < size; i++) {
                vchan.put(ids.get(i), chan.values.get(i));
            }
        }

        return msv.freeze();
    }

    /**
     * Get the number of items currently in the builder.
     * @return The number of items in the builder.
     */
    public int size() {
        assert ids.size() == scores.size();
        return ids.size();
    }

    /**
     * Add a scored ID without boxing.  The default value will be used for each channel.
     * @param id The ID to add.
     * @param score The score for the ID.
     * @return The builder (for chaining).
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public ScoredIdListBuilder add(long id, double score) {
        Preconditions.checkState(ids != null, "builder has been finished");
        ids.add(id);
        scores.add(score);
        for (ChannelStorage chan: channels.values()) {
            assert chan.values.size() == ids.size() - 1;
            chan.values.add(chan.defaultValue);
        }
        for (TypedChannelStorage chan: typedChannels.values()) {
            assert chan.values.size() == ids.size() - 1;
            chan.values.add(chan.defaultValue);
        }
        return this;
    }

    /**
     * Add a scored ID.  The ID is copied into the builder, not referenced.  All side channels on
     * the ID must have already been added with one of the {@code addChannel} methods.
     * @param id The ID.
     * @return The builder (for chaining).
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public ScoredIdListBuilder add(ScoredId id) {
        Preconditions.checkState(ids != null, "builder has been finished");
        // check whether all symbols are valid
        Collection<SymbolValue<?>> chans = id.getChannels();
        if (!ignoreUnknown) {
            for (SymbolValue<?> chan: chans) {
                TypedSymbol<?> sym = chan.getSymbol();
                boolean good = sym.getType().equals(Double.class)
                        ? channels.containsKey(sym.getRawSymbol())
                        : typedChannels.containsKey(sym);
                if (!good) {
                    throw new IllegalArgumentException("channel " + sym + " not known");
                }
            }
        }

        // now we're ready to add
        int idx = ids.size();
        add(id.getId(), id.getScore());
        assert ids.size() == idx + 1;
        assert scores.size() == idx + 1;
        for (SymbolValue<?> sv: chans) {
            TypedSymbol<?> sym = sv.getSymbol();
            if (sym.getType().equals(Double.class) && channels.containsKey(sym.getRawSymbol())) {
                ChannelStorage chan = channels.get(sym.getRawSymbol());
                assert chan.values.size() == idx + 1;
                if (sv instanceof DoubleSymbolValue) {
                    chan.values.set(idx, ((DoubleSymbolValue) sv).getDoubleValue());
                } else {
                    Object v = sv.getValue();
                    chan.values.set(idx, (Double) v);
                }
            } else {
                TypedChannelStorage chan = typedChannels.get(sv.getSymbol());
                if (chan != null) {
                    assert chan.values.size() == idx + 1;
                    chan.values.set(idx, sv.getValue());
                }
            }
        }
        return this;
    }

    /**
     * Add a collection of IDs. The IDs are copied into the builder, not referenced.
     * @param ids The IDs to add.
     * @return The builder (for chaining)
     */
    public ScoredIdListBuilder addAll(Iterable<ScoredId> ids) {
        Preconditions.checkState(ids != null, "builder has been finished");
        for (ScoredId id: ids) {
            add(id);
        }
        return this;
    }

    /**
     * Add a side channel to the list builder with a default value of 0.  It is an error
     * to add the same symbol multiple times.  All side channels that will be used must be added
     * prior to calling {@link #add(ScoredId)}.
     *
     * @param sym The symbol to add.
     * @return The builder (for chaining).
     * @see #addChannel(Symbol, double)
     */
    public ScoredIdListBuilder addChannel(Symbol sym) {
        return addChannel(sym, 0);
    }

    /**
     * Add a side channel to the list builder.  It is an error to add the same symbol multiple times.
     * All side channels that will be used must be added prior to calling {@link #add(ScoredId)}.
     *
     * @param sym The symbol to add.
     * @param dft The default value when adding IDs that lack this channel.
     * @return The builder (for chaining).
     */
    public ScoredIdListBuilder addChannel(Symbol sym, double dft) {
        Preconditions.checkState(ids != null, "builder has been finished");
        if (channels.containsKey(sym)) {
            throw new IllegalArgumentException(sym + " already in the builder");
        } else {
            channels.put(sym, new ChannelStorage(sym, dft));
        }
        return this;
    }

    /**
     * Add multiple unboxed channels with a default value of 0.
     * @param channels The channels to add.
     * @return The builder (for chaining).
     */
    public ScoredIdListBuilder addChannels(Iterable<Symbol> channels) {
        for (Symbol sym: channels) {
            addChannel(sym);
        }
        return this;
    }

    /**
     * Add a side channel to the list builder.  It is an error
     * to add the same symbol multiple times.  All side channels that will be used must be added
     * prior to calling {@link #add(ScoredId)}.
     *
     * @param sym The symbol to add.
     * @return The builder (for chaining).
     * @see #addChannel(TypedSymbol, Object)
     */
    public ScoredIdListBuilder addChannel(TypedSymbol<?> sym) {
        return addChannel(sym, null);
    }

    /**
     * Add a typed side channel to the list builder.  It is an error to add the same symbol multiple
     * times. All side channels that will be used must be added prior to calling {@link
     * #add(ScoredId)}.
     *
     * @param sym The symbol to add.
     * @param dft The default value when adding ids that lack this channel.  If {@code null},
     *            it will be omitted from such ids.
     * @return The builder (for chaining).
     */
    public <T> ScoredIdListBuilder addChannel(TypedSymbol<T> sym, T dft) {
        Preconditions.checkState(ids != null, "builder has been finished");
        if (typedChannels.containsKey(sym)) {
            throw new IllegalArgumentException(sym + " already in the builder");
        } else {
            typedChannels.put(sym, new TypedChannelStorage<T>(sym, dft));
        }
        return this;
    }

    /**
     * Add multiple channels with a default value of {@code null}.
     * @param channels The channels to add.
     * @return The builder (for chaining).
     */
    public ScoredIdListBuilder addTypedChannels(Iterable<? extends TypedSymbol<?>> channels) {
        for (TypedSymbol<?> sym: channels) {
            addChannel(sym);
        }
        return this;
    }

    /**
     * Set the builder to ignore unknown channels on IDs passed to {@link #add(ScoredId)}.
     * @return The builder (for chaining).
     */
    public ScoredIdListBuilder ignoreUnknownChannels() {
        ignoreUnknown = true;
        return this;
    }

    /**
     * Set the builder to fail on unknown channels.  This is the default response to unknown
     * channels.
     * @return The builder (for chaining).
     */
    public ScoredIdListBuilder failOnUnknownChannels() {
        ignoreUnknown = false;
        return this;
    }

    /**
     * Sort the list-in-progress by the specified comparator.
     * @param order The comparator.
     * @return The buidler (for chaining).
     */
    public ScoredIdListBuilder sort(Comparator<ScoredId> order) {
        Preconditions.checkState(ids != null, "builder has been finished");
        quickSort(0, size(), new SortComp(order), new SortSwap());
        return this;
    }

    /**
     * Comparator for sorting the list.  This comparator internally uses a packed list over the
     * entire capacity of the builder to provide ids for the real comparator to use.
     */
    private class SortComp extends AbstractIntComparator {
        private final Comparator<ScoredId> order;

        private PackedScoredIdList.IndirectScoredId id1;
        private PackedScoredIdList.IndirectScoredId id2;

        public SortComp(Comparator<ScoredId> o) {
            order = o;

            // make an internal list
            Map<Symbol,DoubleList> chanMap = Maps.newHashMap();
            for (ChannelStorage chan: channels.values()) {
                chanMap.put(chan.symbol, chan.values);
            }
            Map<TypedSymbol<?>,List<?>> typedMap = Maps.newHashMap();
            for (TypedChannelStorage<?> chan: typedChannels.values()) {
                typedMap.put(chan.symbol, chan.values);
            }
            PackedScoredIdList list = new PackedScoredIdList(ids, scores, typedMap, chanMap);

            id1 = list.getFlyweight(0);
            id2 = list.getFlyweight(0);
        }

        @Override
        public int compare(int i1, int i2) {
            id1.setIndex(i1);
            id2.setIndex(i2);
            return order.compare(id1, id2);
        }
    }

    /**
     * Swapper for sorting the list.
     */
    private class SortSwap implements Swapper {
        @Override
        public void swap(int i, int j) {
            doSwap(ids, i, j);
            doSwap(scores, i, j);
            for (ChannelStorage chan: channels.values()) {
                doSwap(chan.values, i, j);
            }
            for (TypedChannelStorage<?> chan: typedChannels.values()) {
                doSwap(chan.values, i, j);
            }
        }
    }

    private static void doSwap(LongList longs, int i, int j) {
        longs.set(i, longs.set(j, longs.get(i)));
    }
    private static void doSwap(DoubleList doubles, int i, int j) {
        doubles.set(i, doubles.set(j, doubles.get(i)));
    }
    private static <T> void doSwap(List<T> objs, int i, int j) {
        objs.set(i, objs .set(j, objs .get(i)));
    }

    /**
     * Storage for a side channel.
     */
    private class ChannelStorage {
        private final Symbol symbol;
        private final double defaultValue;
        private DoubleArrayList values;

        public ChannelStorage(Symbol sym, double dft) {
            symbol = sym;
            defaultValue = dft;
            values = new DoubleArrayList(scores.elements().length);
            for (int i = size() - 1; i >= 0; i--) {
                values.add(defaultValue);
            }
        }
    }

    /**
     * Storage for a typed side channel.
     */
    private class TypedChannelStorage<T> {
        private final TypedSymbol<T> symbol;
        private final T defaultValue;
        private ArrayList<T> values;

        @SuppressWarnings("unchecked")
        private TypedChannelStorage(TypedSymbol<T> sym, T dft) {
            symbol = sym;
            defaultValue = dft;
            values = new ArrayList<T>(scores.elements().length);
            for (int i = size() - 1; i >= 0; i--) {
                values.add(defaultValue);
            }
        }
    }
}
TOP

Related Classes of org.grouplens.lenskit.scored.ScoredIdListBuilder$ChannelStorage

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.