Package edu.cmu.graphchi.walks.distributions

Source Code of edu.cmu.graphchi.walks.distributions.DiscreteDistribution

package edu.cmu.graphchi.walks.distributions;

import edu.cmu.graphchi.util.IdCount;

import java.util.*;

/**
* Presents a map from integers to frequencies.
* Special distributions for avoidance can be used to exclude
* certain ids from the distributions in merges.
* @author Aapo Kyrola, akyrola@cs.cmu.edu
*/
public class DiscreteDistribution {

    private int[] ids;
    private short[] counts;
    private int uniqueCount;


    private DiscreteDistribution(int initialCapacity) {
        ids = new int[initialCapacity];
        counts = new short[initialCapacity];
        uniqueCount = initialCapacity;
    }

    /**
     * Create an empty distribution
     */
    public DiscreteDistribution() {
        this(0);
    }


    public DiscreteDistribution forceToSize(int size) {
        if (size > this.ids.length) return this;
        int[] newIds = new int[size];
        short[] newCounts = new short[size];
        for(int i=0; i < size; i++) {
            newIds[i] = ids[i];
            newCounts[i] = counts[i];
        }
        DiscreteDistribution d = new DiscreteDistribution();
        d.ids = newIds;
        d.counts = newCounts;
        d.uniqueCount = size;
        return d;
    }

    public int totalCount() {
        int tot = 0;
        for(int c=0; c<counts.length; c++) {
           if (counts[c] > 0) tot += counts[c];
        }
        return tot;
    }


    public void print() {
        for(int i=0; i<ids.length; i++) {
            System.out.println("D " + ids[i] + ", " + counts[i]);
        }
    }

    public IdCount[] getTop(int topN) {
        final TreeSet<IdCount> topList = new TreeSet<IdCount>(new Comparator<IdCount>() {
            public int compare(IdCount a, IdCount b) {
                return (a.count > b.count ? -1 : (a.count == b.count ? (a.id < b.id ? -1 : (a.id == b.id ? 0 : 1)) : 1)); // Descending order
            }
        });
        IdCount least = null;
        for(int i=0; i<uniqueCount; i++) {
            if (topList.size() < topN) {
                topList.add(new IdCount(ids[i], counts[i]));
                least = topList.last();
            } else {
                if (counts[i] > least.count) {
                    topList.remove(least);
                    topList.add(new IdCount(ids[i], counts[i]));
                    least = topList.last();
                }
            }
            if (topList.size() > topN) throw new RuntimeException("Top list too big: " + topList.size());

        }

        IdCount[] topArr = new IdCount[topList.size()];
        topList.toArray(topArr);
        return topArr;
    }

    /**
     * Constructs a distribution from a sorted list of entries.
     * @param sortedIdList
     */
    public DiscreteDistribution(int[] sortedIdList) {
        if (sortedIdList.length == 0) {
            ids = new int[0];
            counts = new short[0];
            return;
        }
        if (sortedIdList.length == 1) {
            ids = new int[] {sortedIdList[0]};
            counts = new short[] {1};
            return;
        }

        /* First count unique */
        uniqueCount = 1;
        for(int i=1; i < sortedIdList.length; i++) {
            if (sortedIdList[i] != sortedIdList[i - 1]) {
                uniqueCount++;
                if (sortedIdList[i] < sortedIdList[i - 1]) {
                    throw new RuntimeException("Not ordered! " + sortedIdList[i] + " < " + sortedIdList[i - 1]);
                }
            }
        }

        ids = new int[uniqueCount];
        counts = new short[uniqueCount];

        /* Create counts */
        int idx = 0;
        short curCount = 1;

        for(int i=1; i < sortedIdList.length; i++) {
            if (sortedIdList[i] != sortedIdList[i - 1]) {
                ids[idx] = sortedIdList[i - 1];
                counts[idx] = curCount;
                idx++;
                curCount = 1;
            else curCount++;
        }

        ids[idx] = sortedIdList[sortedIdList.length - 1];
        counts[idx] = curCount;
    }


    public DiscreteDistribution filteredAndShift(int minimumCount) {
         return filteredAndShift((short)minimumCount);
    }

    /**
     * Creates a new distribution with all entries with count less than
     * minimumCount removed, and rest changed by - minimumCount. Does not remove avoids
     */
    public DiscreteDistribution filteredAndShift(short minimumCount) {
        if (minimumCount <= 1) {
            return this;
        }
        int toRemove = 0;
        for(int i=0; i < uniqueCount; i++) {
            toRemove += (counts[i] < minimumCount &&  counts[i] > 0 ? 1 : 0);
        }

        if (toRemove == 0) {
            return this;   // We can safely return same instance, as this is immutable
        }

        DiscreteDistribution filteredDist = new DiscreteDistribution(uniqueCount - toRemove);
        int idx = 0;
        for(int i=0; i < uniqueCount; i++) {
            if (counts[i] >= minimumCount || counts[i] == (-1)) {
                filteredDist.ids[idx] = ids[i];
                if ( counts[i] != (-1)) {
                    filteredDist.counts[idx] = (short) (counts[i] - minimumCount + 1);
                } else {
                    filteredDist.counts[idx] = -1;
                }
                idx++;
            }
        }
        return filteredDist;
    }

    /**
     * Create a special avoidance distribution, where each count is -1.
     * @param avoids  sorted list of ids to avoid
     * @return
     */
    public static DiscreteDistribution createAvoidanceDistribution(int[] avoids) {
        DiscreteDistribution avoidDistr = new DiscreteDistribution(avoids);
        Arrays.fill(avoidDistr.counts, (short) -1);
        return avoidDistr;
    }



    public int getCount(int id) {
        int idx = Arrays.binarySearch(ids, id);
        if (idx >= 0) {
            return counts[idx];
        } else {
            return 0;
        }
    }

    public int size() {
        return uniqueCount;
    }

    public int avoidCount() {
        int a = 0;
        for(int i=0; i < counts.length; i++) a += counts[i] >= 0 ? 0 : 1;
        return a;
    }

    public int memorySizeEst() {
        int OVERHEAD = 64; // Estimate of the java memory overhead
        return 6 * counts.length + 4 + OVERHEAD;
    }




    public int max() {
        int mx = Integer.MIN_VALUE;
        for(int i=0; i < counts.length; i++) {
            if (counts[i] > mx) mx = counts[i];
        }
        return mx;
    }

    /**
     * Merge too distribution. If either has a negative entry for some id, it will
     * remain negative in the merged distribution (see createAvoidDistribution).
     * @param d1
     * @param d2
     * @return
     */
    public static DiscreteDistribution merge(DiscreteDistribution d1, DiscreteDistribution d2) {

        if (d1.uniqueCount == 0) return d2;
        if (d2.uniqueCount == 0) return d1;

        /* Merge style algorithm */

        /* 1. first count number of different pairs */
        int combinedUniqueIndices = 0;
        int leftIdx = 0;
        int rightIdx = 0;
        final int leftCount = d1.size();
        final int rightCount = d2.size();

        while(leftIdx < leftCount && rightIdx < rightCount) {
            if (d1.ids[leftIdx] == d2.ids[rightIdx]) {
                leftIdx++;
                rightIdx++;
                combinedUniqueIndices++;
            } else if (d1.ids[leftIdx] < d2.ids[rightIdx]) {
                combinedUniqueIndices++;
                leftIdx++;
            } else {
                combinedUniqueIndices++;
                rightIdx++;
            }
        }
        combinedUniqueIndices += (rightCount - rightIdx);
        combinedUniqueIndices += (leftCount - leftIdx);

        /* Create new merged distribution by doing the actual merge */
        DiscreteDistribution merged = new DiscreteDistribution(combinedUniqueIndices);
        leftIdx = 0;
        rightIdx = 0;
        int idx = 0;
        while(leftIdx < leftCount && rightIdx < rightCount) {
            if (d1.ids[leftIdx] == d2.ids[rightIdx]) {
                merged.ids[idx] = d1.ids[leftIdx];
                merged.counts[idx] = (short) (d1.counts[leftIdx] + d2.counts[rightIdx]);

                // Force to retain negativity
                if (d1.counts[leftIdx] < 0 || d2.counts[rightIdx] < 0) {
                    merged.counts[idx] = -1;
                }
                leftIdx++;
                rightIdx++;
                idx++;

            } else if (d1.ids[leftIdx] < d2.ids[rightIdx]) {
                merged.ids[idx] = d1.ids[leftIdx];
                merged.counts[idx] = d1.counts[leftIdx];   // Note, retains negativity
                idx++;
                leftIdx++;
            } else {
                merged.ids[idx] = d2.ids[rightIdx];
                merged.counts[idx] = d2.counts[rightIdx];   // Note, retains negativity
                idx++;
                rightIdx++;
            }
        }

        while(leftIdx < leftCount) {
            merged.ids[idx] = d1.ids[leftIdx];
            merged.counts[idx] = d1.counts[leftIdx];
            leftIdx++;
            idx++;
        }

        while(rightIdx < rightCount) {
            merged.ids[idx] = d2.ids[rightIdx];
            merged.counts[idx] = d2.counts[rightIdx];
            rightIdx++;
            idx++;
        }
        return merged;
    }

    public int sizeExcludingAvoids() {
        int j = 0;
        for(int i=0; i < uniqueCount; i++) {
            if (counts[i] > 0) j++;
        }
        return j;
    }
}
TOP

Related Classes of edu.cmu.graphchi.walks.distributions.DiscreteDistribution

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.