Package cc.mallet.types

Source Code of cc.mallet.types.ROCData

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




package cc.mallet.types;

import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Arrays;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.Trial;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetCarrying;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;

/**
* Tracks ROC data for instances in {@link Trial} results.
*
* @see Trial
* @see InstanceList
* @see Classifier
* @see Classification
*
* @author Michael Bond <a href="mailto:mikejbond@gmail.com">mikejbond@gmail.com</a>
*/
public class ROCData implements AlphabetCarrying, Serializable {

    private static final long serialVersionUID  = -2060194953037720640L;
    public static final int TRUE_POSITIVE       = 0;
    public static final int FALSE_POSITIVE      = 1;
    public static final int FALSE_NEGATIVE      = 2;
    public static final int TRUE_NEGATIVE       = 3;
   
    private final LabelAlphabet labelAlphabet;

    /** Matrix of class, threshold, [tp, fp, fn, tn] */
    private final int[][][] counts;
    private final double[] thresholds;
   
    /**
     * Constructs a new object
     *
     * @param thresholds        Array of thresholds to track counts for
     * @param labelAlphabet     Label alphabet for instances in {@link Trial}
     */
    public ROCData(double[] thresholds, final LabelAlphabet labelAlphabet) {
        // ensure that thresholds are sorted
        Arrays.sort(thresholds);
        this.counts = new int[labelAlphabet.size()][thresholds.length][4];
        this.labelAlphabet = labelAlphabet;
        this.thresholds = thresholds;
    }
   
    /**
     * Adds classification results to the ROC data
     *
     * @param trial Trial results to add to ROC data
     */
    public void add(Classification classification) {
        int correctIndex = classification.getInstance().getLabeling().getBestIndex();
        LabelVector lv = classification.getLabelVector();
        double[] values = lv.getValues();
       
        if (!Alphabet.alphabetsMatch(this, lv)) {
            throw new IllegalArgumentException ("Alphabets do not match");
        }
       
        int numLabels = this.labelAlphabet.size();
        for (int label = 0; label < numLabels; label++) {
            double labelValue = values[label];
            int[][] thresholdCounts = this.counts[label];
            int threshold = 0;
           
            // add the trial to all the thresholds it would be positive for
            for (; threshold < this.thresholds.length && labelValue >= this.thresholds[threshold]; threshold++) {
                if (correctIndex == label) {
                    thresholdCounts[threshold][TRUE_POSITIVE]++;
                } else {
                    thresholdCounts[threshold][FALSE_POSITIVE]++;
                }
            }
           
            // add the trial to the thresholds it would be negative for
            for (; threshold < this.thresholds.length; threshold++) {
                if (correctIndex == label) {
                    thresholdCounts[threshold][FALSE_NEGATIVE]++;
                } else {
                    thresholdCounts[threshold][TRUE_NEGATIVE]++;
                }
            }
        }
    }

    /**
     * Adds trial results to the ROC data
     *
     * @param trial Trial results to add to ROC data
     */
    public void add(Trial trial) {
        for (Classification classification : trial) {
            add(classification);
        }
    }

    /**
     * Adds existing ROC data to this ROC data
     *
     * @param rocData ROC data to add
     */
    public void add(ROCData rocData) {
        if (!Alphabet.alphabetsMatch(this, rocData)) {
            throw new IllegalArgumentException ("Alphabets do not match");
        }

        if (!Arrays.equals(this.thresholds, rocData.thresholds)) {
            throw new IllegalArgumentException ("Thresholds do not match");
        }

        int countsLength = this.counts.length;
        for (int c = 0; c < countsLength; c++) {
            int[][] thisClassCounts = this.counts[c];
            int[][] otherClassCounts = rocData.counts[c];
            int classLength = thisClassCounts.length;
            for (int t = 0; t < classLength; t++) {
                int[] thisThrCounts = thisClassCounts[t];
                int[] otherThrCounts = otherClassCounts[t];
                int thrLength = thisThrCounts.length;
                for (int s = 0; s < thrLength; s++) {
                    thisThrCounts[s] += otherThrCounts[s];
                }
            }
        }
    }

    //@Override
    public Alphabet getAlphabet() {
        return this.labelAlphabet;
    }

    //@Override
    public Alphabet[] getAlphabets() {
        return new Alphabet[] { this.labelAlphabet };
    }
   
    /**
     * Gets the raw counts for a specified label.
     *
     * @param label     Label to get counts for
     * @see #TRUE_POSITIVE
     * @see #FALSE_POSITIVE
     * @see #FALSE_NEGATIVE
     * @see #TRUE_NEGATIVE
     * @return Array of raw counts for specified label
     */
    public int[][] getCounts(Label label) {
        return this.counts[label.getIndex()];
    }
   
    /**
     * Gets the raw counts for a specified label and threshold.
     *
     * If data was not collected for the exact threshold specified, then results
     * for the highest threshold <= the specified threshold will be returned.
     *
     * @param label     Label to get counts for
     * @param threshold Threshold to get counts for
     * @see #TRUE_POSITIVE
     * @see #FALSE_POSITIVE
     * @see #FALSE_NEGATIVE
     * @see #TRUE_NEGATIVE
     * @return Array of raw counts for specified label and threshold
     */
    public int[] getCounts(Label label, double threshold) {
        int index = Arrays.binarySearch(this.thresholds, threshold);
        if (index < 0) {
            index = (-index) - 2;
        }
        return this.counts[label.getIndex()][index];
    }

    /**
     * Gets the label alphabet
     */
    public LabelAlphabet getLabelAlphabet() {
        return this.labelAlphabet;
    }
   
    /**
     * Gets the precision for a specified label and threshold.
     *
     * If data was not collected for the exact threshold specified, then results
     * will for the highest threshold <= the specified threshold will be
     * returned.
     *
     * @param label     Label to get precision for
     * @param threshold Threshold to get precision for
     * @return Precision for specified label and threshold
     */
    public double getPrecision(Label label, double threshold) {
        int[] counts = getCounts(label, threshold);
        return (double) counts[TRUE_POSITIVE] / (double) (counts[TRUE_POSITIVE] + counts[FALSE_POSITIVE]);
    }
   
    /**
     * Gets the precision for a specified label and score. This differs from
     * {@link ROCData.getPrecision(Label, double)} in that it is the precision
     * for only scores falling in the one score value, not for all scores
     * above the threshold.
     *
     * If data was not collected for the exact threshold specified, then results
     * will for the highest threshold <= the specified threshold will be
     * returned.
     *
     * @param label     Label to get precision for
     * @param threshold Threshold to get precision for
     * @return Precision for specified label and score
     */
    public double getPrecisionForScore(Label label, double score) {
        final int[][] buckets = this.counts[label.getIndex()];

        int index = Arrays.binarySearch(this.thresholds, score);
        if (index < 0) {
            index = (-index) - 2;
        }

        final double tp;
        final double fp;
        if (index == this.thresholds.length - 1) {
            tp = buckets[index][TRUE_POSITIVE];
            fp = buckets[index][FALSE_POSITIVE];
        } else {
            tp = buckets[index][TRUE_POSITIVE] - buckets[index + 1][TRUE_POSITIVE];
            fp = buckets[index][FALSE_POSITIVE] - buckets[index + 1][FALSE_POSITIVE];
        }

        return (double) tp / (double) (tp + fp);
    }
   
    /**
     * Gets the estimated percentage of training events that exceed the
     * threshold.
     *
     * @param label     Label to get precision for
     * @param threshold Threshold to get precision for
     * @return Estimated percentage of events exceeding threshold
     */
    public double getPositivePercent(Label label, double threshold) {
        final int[] counts = getCounts(label, threshold);
        final int positive = counts[TRUE_POSITIVE] + counts[FALSE_POSITIVE];
        return ((double) positive / (double) (positive + counts[FALSE_NEGATIVE] + counts[TRUE_NEGATIVE])) * 100.0;
    }
   
    /**
     * Gets the recall rate for a specified label and threshold.
     *
     * If data was not collected for the exact threshold specified, then results
     * will for the highest threshold <= the specified threshold will be
     * returned.
     *
     * @param label     Label to get recall for
     * @param threshold Threshold to get recall for
     * @return Recall rate for specified label and threshold
     */
    public double getRecall(Label label, double threshold) {
        int[] counts = getCounts(label, threshold);
        return (double) counts[TRUE_POSITIVE] / (double) (counts[TRUE_POSITIVE] + counts[FALSE_NEGATIVE]);
    }

    /**
     * Gets the thresholds being tracked
     *
     * @return Array of thresholds
     */
    public double[] getThresholds() {
        return this.thresholds;
    }
   
    /**
     * Sets the raw counts for a specified label and threshold.
     *
     * If data is not collected for the exact threshold specified, then counts
     * for the highest threshold <= the specified threshold will be set.
     *
     * @param label     Label to get counts for
     * @param threshold Threshold to get counts for
     * @param newCounts New count values for the label and threshold
     * @see #TRUE_POSITIVE
     * @see #FALSE_POSITIVE
     * @see #FALSE_NEGATIVE
     * @see #TRUE_NEGATIVE
     */
    public void setCounts(Label label, double threshold, int[] newCounts) {
        int index = Arrays.binarySearch(this.thresholds, threshold);
        if (index < 0) {
            index = (-index) - 2;
        }
       
        final int[] oldCounts = this.counts[label.getIndex()][index];
        if (newCounts.length != oldCounts.length) {
            throw new IllegalArgumentException ("Array of counts must contain " + oldCounts.length + " elements.");
        }

        for (int i = 0; i < oldCounts.length; i++) {
            oldCounts[i] = newCounts[i];
        }
    }
   
    //@Override
    public String toString() {
        final StringBuilder buf = new StringBuilder();
        final NumberFormat format = new DecimalFormat("0.####");
       
        for (int i = 0; i < this.labelAlphabet.size(); i++) {
            int[][] labelData = this.counts[i];
           
            buf.append("ROC data for ");
            buf.append(this.labelAlphabet.lookupObject(i).toString());
            buf.append('\n');
            buf.append("THR\tTP\tFP\tFN\tTN\tPrecis\tRecall\n");
           
            // add one row for each threshold
            for (int t = 0; t < this.thresholds.length; t++) {
                buf.append(this.thresholds[t]);
                for (int res : labelData[t]) {
                    buf.append('\t').append(res);
                }

                double tp = labelData[t][TRUE_POSITIVE];
                double sum = tp + labelData[t][FALSE_POSITIVE];
                double precision = 0.0;
                if (sum != 0) {
                    precision = tp / sum;
                }
               
                sum = tp + labelData[t][FALSE_NEGATIVE];
                double recall = 0.0;
                if (sum != 0) {
                    recall = tp / sum;
                }
               
                buf.append('\t').append(format.format(precision));
                buf.append('\t').append(format.format(recall));
                buf.append('\n');
            }
           
            buf.append('\n');
        }
       
        return buf.toString();
    }
   
}
TOP

Related Classes of cc.mallet.types.ROCData

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.