Package com.guokr.simbase.store

Source Code of com.guokr.simbase.store.SparseVectorSet

package com.guokr.simbase.store;

import gnu.trove.iterator.TIntIntIterator;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.TFloatList;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.map.hash.TIntIntHashMap;

import java.util.ArrayList;
import java.util.List;

import com.guokr.simbase.events.BasisListener;
import com.guokr.simbase.events.VectorSetListener;

public class SparseVectorSet implements VectorSet, BasisListener {

    public static final String      TYPE    = "sparse";

    String                          key;

    TFloatList                      data    = new TFloatArrayList();
    TIntIntMap                      lengths = new TIntIntHashMap();
    TIntIntMap                      indexer = new TIntIntHashMap();

    float                           accumuFactor;
    int                             sparseFactor;

    Basis                           base;

    private boolean                 listening;
    private List<VectorSetListener> listeners;

    private int[]                   iReuseList;
    private float[]                 fReuseList;

    public SparseVectorSet(String key, Basis base) {
        this(key, base, 0.01f, 4096);
    }

    public SparseVectorSet(String key, Basis base, float accumuFactor, int sparseFactor) {
        this.key = key;
        this.base = base;
        this.accumuFactor = accumuFactor;
        this.sparseFactor = sparseFactor;
        this.listening = true;
        this.listeners = new ArrayList<VectorSetListener>();

        this.fReuseList = new float[this.base.size()];
        this.iReuseList = new int[this.base.size() * 2];
        this.base.addListener(this);
    }

    private void validateParams(int vecid, int[] pairs) {
        if (pairs.length % 2 != 0) {
            throw new IllegalArgumentException("the size of the input array must be a even number!");
        }
    }

    @Override
    public String type() {
        return "" + TYPE;
    }

    @Override
    public String key() {
        return key;
    }

    @Override
    public int size() {
        return this.indexer.size();
    }

    @Override
    public boolean contains(int vecid) {
        return this.indexer.containsKey(vecid);
    }

    @Override
    public void clean() {
        TFloatList olddata = data;
        TIntIntMap oldindexer = indexer;
        data = new TFloatArrayList(olddata.size());
        indexer = new TIntIntHashMap(oldindexer.size());

        int pos = 0;
        TIntIntIterator iter = oldindexer.iterator();
        while (iter.hasNext()) {
            iter.advance();
            int vecid = iter.key();
            int start = iter.value();
            int length = lengths.get(vecid);

            int cursor = 0;
            indexer.put(vecid, pos);
            while (cursor < length) {
                data.add(olddata.get(start + cursor));
                pos++;
                cursor++;
            }
        }
    }

    @Override
    public int[] ids() {
        return indexer.keys();
    }

    @Override
    public void remove(int vecid) {
        if (indexer.containsKey(vecid)) {
            indexer.remove(vecid);
            lengths.remove(vecid);

            if (listening) {
                for (VectorSetListener l : listeners) {
                    l.onVectorRemoved(this, vecid);
                }
            }
        }
    }

    @Override
    public int length(int vecid) {
        return lengths.get(vecid);
    }

    public float[] get(int vecid, int[] input, float[] result) {
        _get(vecid, input);
        Basis.densify(base.size(), sparseFactor, input, result);
        return result;
    }

    @Override
    public float[] get(int vecid) {
        float[] result = new float[base.size()];
        Basis.densify(base.size(), sparseFactor, _get(vecid), result);
        return result;
    }

    @Override
    public void add(int vecid, float[] vector) {
        _add(vecid, Basis.sparsify(sparseFactor, vector));
    }

    @Override
    public void set(int vecid, float[] vector) {
        _set(vecid, Basis.sparsify(sparseFactor, vector));
    }

    @Override
    public void accumulate(int vecid, float[] vector) {
        _accumulate(vecid, Basis.sparsify(sparseFactor, vector));
    }

    protected void _get(int vecid, int[] result) {
        int length = lengths.get(vecid);
        int cursor = indexer.get(vecid), i = 0;
        while (length > 0) {
            int pos = (int) data.get(cursor++);
            int val = Math.round(data.get(cursor++));
            result[i++] = pos;
            result[i++] = val;
            length -= 2;
        }
    }

    @Override
    public int[] _get(int vecid) {
        int[] result = new int[lengths.get(vecid)];
        _get(vecid, result);
        return result;
    }

    @Override
    public void _add(int vecid, int[] pairs) {
        validateParams(vecid, pairs);
        if (!indexer.containsKey(vecid)) {
            int start = data.size();
            indexer.put(vecid, start);
            lengths.put(vecid, pairs.length);
            for (int val : pairs) {
                data.add(val);
            }

            if (listening) {
                for (VectorSetListener l : listeners) {
                    l.onVectorAdded(this, vecid, pairs);
                }
            }
        }
    }

    @Override
    public void _set(int vecid, int[] pairs) {
        validateParams(vecid, pairs);
        if (indexer.containsKey(vecid)) {
            int[] old = _get(vecid);

            listening = false;
            remove(vecid);
            _add(vecid, pairs);
            listening = true;

            if (listening) {
                for (VectorSetListener l : listeners) {
                    l.onVectorSetted(this, vecid, old, pairs);
                }
            }
        } else {
            _add(vecid, pairs);
        }
    }

    @Override
    public void _accumulate(int vecid, int[] pairs) {
        validateParams(vecid, pairs);
        if (!indexer.containsKey(vecid)) {
            _add(vecid, pairs);
        } else {
            TIntList indexes = new TIntArrayList();
            TIntFloatMap results = new TIntFloatHashMap();

            float max = Float.NEGATIVE_INFINITY;
            int cursor = indexer.get(vecid);
            int length = lengths.get(vecid);
            while (length > 0) {
                int pos = (int) data.get(cursor++);
                float val = data.get(cursor++);
                results.put(pos, val);
                if (val > max) {
                    max = val;
                }
                indexes.add(pos);
                length -= 2;
            }

            cursor = 0;
            while (cursor < pairs.length) {
                int pos = pairs[cursor++];
                float val = (float) pairs[cursor++];
                if (results.containsKey(pos)) {
                    val = results.get(pos) + val;
                    results.put(pos, val);
                    if (val > max) {
                        max = val;
                    }
                } else {
                    results.put(pos, val);
                    indexes.add(pos);
                }
            }
            indexes.sort();

            int start = data.size();
            indexer.put(vecid, start);
            lengths.put(vecid, indexes.size() * 2);
            TIntIterator iter = indexes.iterator();
            if (max < accumuFactor * sparseFactor) {
                while (iter.hasNext()) {
                    int key = iter.next();
                    float value = results.get(key);
                    data.add(key);
                    data.add(value);
                }
            } else {
                while (iter.hasNext()) {
                    int key = iter.next();
                    float value = results.get(key) * accumuFactor / max;
                    data.add(key);
                    data.add(value);
                }
            }

            if (listening) {
                int[] accumulated = _get(vecid);
                for (VectorSetListener l : listeners) {
                    l.onVectorAccumulated(this, vecid, pairs, accumulated);
                }
            }
        }
    }

    @Override
    public void addListener(VectorSetListener listener) {
        listeners.add(listener);
    }

    @Override
    public void rescore(String key, int vecid, float[] vector, Recommendation rec) {
        rec.create(vecid);
        TIntIntIterator iter = indexer.iterator();
        int[] input = new int[this.base.size() * 2];
        if (this == rec.source) {
            while (iter.hasNext()) {
                iter.advance();
                int tgtId = iter.key();
                get(tgtId, input, fReuseList);
                float score = rec.scoring.score(key, vecid, vector, this.key, tgtId, fReuseList);
                rec.add(vecid, tgtId, score);
                rec.add(tgtId, vecid, score);
            }
            rec.remove(vecid, vecid);
        } else {
            while (iter.hasNext()) {
                iter.advance();
                int tgtId = iter.key();
                get(tgtId, input, fReuseList);
                float score = rec.scoring.score(key, vecid, vector, this.key, tgtId, fReuseList);
                rec.add(vecid, tgtId, score);
            }
        }
    }

    @Override
    public void rescore(String key, int vecid, int[] vector, Recommendation rec) {
        rec.create(vecid);
        TIntIntIterator iter = indexer.iterator();
        if (this == rec.source) {
            while (iter.hasNext()) {
                iter.advance();
                int tgtId = iter.key();
                _get(tgtId, iReuseList);
                float score = rec.scoring.score(key, vecid, vector, vector.length, this.key, tgtId, iReuseList,
                        length(tgtId));
                rec.add(vecid, tgtId, score);
                rec.add(tgtId, vecid, score);
            }
            rec.remove(vecid, vecid);
        } else {
            while (iter.hasNext()) {
                iter.advance();
                int tgtId = iter.key();
                _get(tgtId, iReuseList);
                float score = rec.scoring.score(key, vecid, vector, vector.length, this.key, tgtId, iReuseList,
                        length(tgtId));
                rec.add(vecid, tgtId, score);
            }
        }
    }

    @Override
    public void onBasisRevised(Basis evtSrc, String[] oldSchema, String[] newSchema) {
        fReuseList = new float[this.base.size()];
        iReuseList = new int[this.base.size() * 2];
    }
}
TOP

Related Classes of com.guokr.simbase.store.SparseVectorSet

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.