Package com.guokr.simbase.score

Source Code of com.guokr.simbase.score.CosineSquareSimilarity

package com.guokr.simbase.score;

import gnu.trove.map.TIntFloatMap;
import gnu.trove.map.hash.TIntFloatHashMap;

import java.util.HashMap;
import java.util.Map;

import com.guokr.simbase.SimScore;
import com.guokr.simbase.store.VectorSet;

public class CosineSquareSimilarity implements SimScore {

    private static String                    name   = "cosinesq";
    private static Map<String, TIntFloatMap> caches = new HashMap<String, TIntFloatMap>();

    private float flengthsq(float[] vector) {
        float result = 0f;
        int len = vector.length;
        for (int i = 0; i < len; i++) {
            result += vector[i] * vector[i];
        }
        return result;
    }

    private float ilengthsq(int[] vector) {
        int result = 0;
        int len = vector.length;
        for (int i = 0; i < len;) {
            result += vector[i + 1] * vector[i + 1];
            i += 2;
        }
        return result;
    }

    @Override
    public String name() {
        return name;
    }

    @Override
    public SortOrder order() {
        return SortOrder.Desc;
    }

    @Override
    public float score(String srcVKey, int srcId, float[] source, String tgtVKey, int tgtId, float[] target) {
        TIntFloatMap sourceCache = caches.get(srcVKey);
        TIntFloatMap targetCache = caches.get(tgtVKey);

        float scoring = 0;
        int len = source.length;
        for (int i = 0; i < len; i++) {
            scoring += source[i] * target[i];
        }

        scoring = scoring * scoring / sourceCache.get(srcId) / targetCache.get(tgtId);

        return scoring;
    }

    @Override
    public float score(String srcVKey, int srcId, int[] source, int srclen, String tgtVKey, int tgtId, int[] target,
            int tgtlen) {
        TIntFloatMap sourceCache = caches.get(srcVKey);
        TIntFloatMap targetCache = caches.get(tgtVKey);

        float scoring = 0f;
        int idx1 = 0, idx2 = 0;
        if (idx1 < srclen && idx2 < tgtlen) {
            while (true) {
                if (source[idx1] < target[idx2]) {
                    idx1 += 2;
                    if (idx1 >= srclen)
                        break;
                } else if (source[idx1] > target[idx2]) {
                    idx2 += 2;
                    if (idx2 >= tgtlen)
                        break;
                } else {
                    scoring += source[idx1 + 1] * target[idx2 + 1];
                    idx1 += 2;
                    idx2 += 2;
                    if (idx1 >= srclen || idx2 >= tgtlen)
                        break;
                }
            }
        }

        scoring = scoring * scoring / sourceCache.get(srcId) / targetCache.get(tgtId);

        return scoring;
    }

    public void onAttached(String vkey) {
        caches.put(vkey, new TIntFloatHashMap());
    }

    public void onUpdated(String vkey, int vid, float[] vector) {
        caches.get(vkey).put(vid, flengthsq(vector));
    }

    public void onUpdated(String vkey, int vid, int[] vector) {
        caches.get(vkey).put(vid, ilengthsq(vector));
    }

    public void onRemoved(String vkey, int vid) {
        caches.get(vkey).remove(vid);
    }

    @Override
    public void onVectorAdded(VectorSet evtSrc, int vecid, float[] vector) {
        onUpdated(evtSrc.key(), vecid, vector);
    }

    @Override
    public void onVectorAdded(VectorSet evtSrc, int vecid, int[] vector) {
        onUpdated(evtSrc.key(), vecid, vector);
    }

    @Override
    public void onVectorSetted(VectorSet evtSrc, int vecid, float[] old, float[] vector) {
        onUpdated(evtSrc.key(), vecid, vector);
    }

    @Override
    public void onVectorSetted(VectorSet evtSrc, int vecid, int[] old, int[] vector) {
        onUpdated(evtSrc.key(), vecid, vector);
    }

    @Override
    public void onVectorAccumulated(VectorSet evtSrc, int vecid, float[] vector, float[] accumulated) {
        onUpdated(evtSrc.key(), vecid, accumulated);
    }

    @Override
    public void onVectorAccumulated(VectorSet evtSrc, int vecid, int[] vector, int[] accumulated) {
        onUpdated(evtSrc.key(), vecid, accumulated);
    }

    @Override
    public void onVectorRemoved(VectorSet evtSrc, int vecid) {
        onRemoved(evtSrc.key(), vecid);
    }

}
TOP

Related Classes of com.guokr.simbase.score.CosineSquareSimilarity

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle. Contact coftware#gmail.com.