Package com.guokr.simbase.store

Source Code of com.guokr.simbase.store.Recommendation

package com.guokr.simbase.store;

import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;

import java.util.ArrayList;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.guokr.simbase.SimScore;
import com.guokr.simbase.events.RecommendationListener;
import com.guokr.simbase.events.VectorSetListener;
import com.guokr.simbase.score.CosineSquareSimilarity;

public class Recommendation implements VectorSetListener {

    private static final Logger          logger         = LoggerFactory.getLogger(Recommendation.class);

    public VectorSet                     source;
    public VectorSet                     target;
    public SimScore                      scoring;

    int                                  limit;

    TIntList                             sorterKeys     = new TIntArrayList();
    TIntObjectMap<Sorter>                sorters        = new TIntObjectHashMap<Sorter>();

    private TIntObjectMap<TIntSet>       reverseIndexer = new TIntObjectHashMap<TIntSet>();
    private List<RecommendationListener> listeners      = new ArrayList<RecommendationListener>();

    public Recommendation(VectorSet source, VectorSet target) {
        this(source, target, new CosineSquareSimilarity(), 20);
    }

    public Recommendation(VectorSet source, VectorSet target, SimScore scoring) {
        this(source, target, scoring, 20);
    }

    public Recommendation(VectorSet source, VectorSet target, SimScore scoring, int limits) {
        this.source = source;
        this.target = target;
        this.limit = limits;
        this.scoring = scoring;

        scoring.onAttached(source.key());
        source.addListener(scoring);

        if (source.type().equals("dense")) {
            for (int id : source.ids()) {
                scoring.onVectorAdded(source, id, source.get(id));
            }
        } else {
            for (int id : source.ids()) {
                scoring.onVectorAdded(source, id, source._get(id));
            }
        }

        if (source != target) {
            scoring.onAttached(target.key());
            target.addListener(scoring);

            if (target.type().equals("dense")) {
                for (int id : target.ids()) {
                    scoring.onVectorAdded(target, id, target.get(id));
                }
            } else {
                for (int id : target.ids()) {
                    scoring.onVectorAdded(target, id, target._get(id));
                }
            }
        }
    }

    public void clean() {
        TIntObjectMap<Sorter> oldSorters = sorters;
        TIntList oldSorterKeys = sorterKeys;
        sorterKeys = new TIntArrayList();
        sorters = new TIntObjectHashMap<Sorter>();
        reverseIndexer = new TIntObjectHashMap<TIntSet>();

        TIntIterator iter = oldSorterKeys.iterator();
        while (iter.hasNext()) {
            int srcId = iter.next();
            if (this.source.contains(srcId)) {
                sorterKeys.add(srcId);
                Sorter oldSorter = oldSorters.get(srcId);
                Sorter sorter = new Sorter(this, srcId, oldSorter.order, oldSorter.limits);
                sorters.put(srcId, sorter);
                for (int tgtId : oldSorter.vecids()) {
                    if (this.target.contains(tgtId)) {
                        sorter.add(tgtId, oldSorter.get(tgtId));
                        if (!reverseIndexer.containsKey(tgtId)) {
                            reverseIndexer.put(tgtId, new TIntHashSet());
                        }
                        reverseIndexer.get(tgtId).add(srcId);
                    } else {
                        logger.warn(String.format("target vector[%d] not in sorter[%d] of rec[%s][%s]", tgtId, srcId,
                                this.source.key(), this.target.key()));
                    }
                }
            } else {
                logger.warn(String.format("source vector[%d] not in rec[%s][%s]", srcId, this.source.key(),
                        this.target.key()));
            }
        }
    }

    public void addReverseIndex(int srcVecId, int tgtVecId) {
        if (!this.reverseIndexer.containsKey(tgtVecId)) {
            this.reverseIndexer.put(tgtVecId, new TIntHashSet());
        }
        reverseIndexer.get(tgtVecId).add(srcVecId);
    }

    public void deleteReverseIndex(int srcVecId, int tgtVecId) {
        reverseIndexer.get(tgtVecId).remove(srcVecId);
    }

    private void processDenseChangedEvt(VectorSet evtSrc, int vecid, float[] vector) {
        if (evtSrc == this.source) {
            target.rescore(source.key(), vecid, vector, this);
        } else if (evtSrc == this.target) {
            int tgtVecId = vecid;
            TIntIterator iter = sorterKeys.iterator();
            while (iter.hasNext()) {
                int srcVecId = iter.next();
                float score = scoring.score(source.key(), srcVecId, source.get(srcVecId), target.key(), tgtVecId,
                        vector);
                add(srcVecId, tgtVecId, score);
            }
        }
    }

    private void processSparseChangedEvt(VectorSet evtSrc, int vecid, int[] vector) {
        if (evtSrc == this.source) {
            target.rescore(source.key(), vecid, vector, this);
        } else if (evtSrc == this.target) {
            int tgtVecId = vecid;
            TIntIterator iter = sorterKeys.iterator();
            while (iter.hasNext()) {
                int srcVecId = iter.next();
                float score = scoring.score(source.key(), srcVecId, source._get(srcVecId), source.length(srcVecId),
                        target.key(), tgtVecId, vector, vector.length);
                add(srcVecId, tgtVecId, score);
            }
        }
    }

    private void processDeletedEvt(int tgtVecId) {
        if (this.reverseIndexer.containsKey(tgtVecId)) {
            TIntSet range = this.reverseIndexer.get(tgtVecId);
            TIntIterator iter = range.iterator();
            while (iter.hasNext()) {
                int srcVecId = iter.next();
                this.sorters.get(srcVecId).remove(tgtVecId);
            }
        }
        reverseIndexer.remove(tgtVecId);
    }

    public Sorter create(int srcVecId) {
        Sorter sorter = sorters.get(srcVecId);
        if (sorter == null) {
            sorter = new Sorter(this, srcVecId, scoring.order(), this.limit);
            this.sorters.put(srcVecId, sorter);
            this.sorterKeys.add(srcVecId);
        }
        return sorter;
    }

    public void add(int srcVecId, int tgtVecId, float score) {
        create(srcVecId).add(tgtVecId, score);
    }

    public String[] get(int vecid) {
        if (this.sorters.containsKey(vecid)) {
            return this.sorters.get(vecid).pickle();
        } else
            return new String[0];
    }

    public int[] rec(int vecid) {
        if (this.sorters.containsKey(vecid)) {
            return this.sorters.get(vecid).vecids();
        } else
            return new int[0];
    }

    public void remove(int srcVecId, int tgtVecId) {
        this.sorters.get(srcVecId).remove(tgtVecId);
        if (reverseIndexer.containsKey(tgtVecId)) {
            reverseIndexer.get(tgtVecId).remove(srcVecId);
        }
    }

    public void addListener(RecommendationListener listener) {
        listeners.add(listener);
    }

    @Override
    public void onVectorAdded(VectorSet evtSrc, int vecid, float[] inputed) {
        processDenseChangedEvt(evtSrc, vecid, inputed);
    }

    @Override
    public void onVectorAdded(VectorSet evtSrc, int vecid, int[] vector) {
        processSparseChangedEvt(evtSrc, vecid, vector);
    }

    @Override
    public void onVectorSetted(VectorSet evtSrc, int vecid, float[] old, float[] vector) {
        if (evtSrc == this.source && sorters.containsKey(vecid)) {
            sorters.get(vecid).reset();
        }
        processDenseChangedEvt(evtSrc, vecid, vector);
    }

    @Override
    public void onVectorSetted(VectorSet evtSrc, int vecid, int[] old, int[] vector) {
        if (evtSrc == this.source && sorters.containsKey(vecid)) {
            sorters.get(vecid).reset();
        }
        processSparseChangedEvt(evtSrc, vecid, vector);
    }

    @Override
    public void onVectorAccumulated(VectorSet evtSrc, int vecid, float[] vector, float[] accumulated) {
        if (evtSrc == this.source && sorters.containsKey(vecid)) {
            sorters.get(vecid).reset();
        }
        processDenseChangedEvt(evtSrc, vecid, accumulated);
    }

    @Override
    public void onVectorAccumulated(VectorSet evtSrc, int vecid, int[] vector, int[] accumulated) {
        if (evtSrc == this.source && sorters.containsKey(vecid)) {
            sorters.get(vecid).reset();
        }
        processSparseChangedEvt(evtSrc, vecid, accumulated);
    }

    @Override
    public void onVectorRemoved(VectorSet evtSrc, int vecid) {
        if (evtSrc == this.target) {
            processDeletedEvt(vecid);
        }
        if (evtSrc == this.source) {
            for (int tgtId : this.sorters.get(vecid).vecids()) {
                if (this.reverseIndexer.containsKey(tgtId)) {
                    this.reverseIndexer.get(tgtId).remove(vecid);
                }
            }
            this.sorters.remove(vecid);
            this.sorterKeys.remove(vecid);
        }
    }
}
TOP

Related Classes of com.guokr.simbase.store.Recommendation

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.