Package cc.mallet.classify.evaluate

Source Code of cc.mallet.classify.evaluate.AccuracyCoverage

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */


/**
   @author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a>
*/

package cc.mallet.classify.evaluate;


import java.awt.*;
import java.awt.event.*;
import javax.swing.*;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.Trial;
import cc.mallet.classify.evaluate.GraphItem;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.PrintUtilities;

import java.util.*;
import java.util.logging.*;
import java.text.DecimalFormat;

/**
* Methods for calculating and displaying the accuracy v.
* coverage data for a Trial
*/
public class AccuracyCoverage implements ActionListener
{
  private static Logger logger = MalletLogger.getLogger(AccuracyCoverage.class.getName());
  static final int DEFAULT_NUM_BUCKETS = 20;
  static final int DEFAULT_MAX_X = 100;
  private ArrayList classifications;
  private double [] accuracyValues;
  private int numBuckets;
  private double step;
  private Graph2 graph;
  private JFrame frame;
 
  /**
   * Constructs object, sorts classifications, and creates
   * accuracyValues array
     * @param t trial to get data from
     * @param numBuckets number of x-axis measurements to find accuracy
     */
  public AccuracyCoverage(Trial t, int numBuckets, String title, String dataName)
  {
    this.classifications = t;
    this.numBuckets = numBuckets;
    this.step = (double)DEFAULT_MAX_X/numBuckets;
    this.accuracyValues = new double[numBuckets];
    this.frame = null;
    logger.info("Constructing AccCov with " +
                       this.classifications.size());
    sortClassifications();
/*    for(int i=0; i<classifications.size(); i++)
    {
      Classification c = (Classification)this.classifications.get(i);
      LabelVector distr = c.getLabelVector();
      System.out.println(distr.getBestValue());
    }
*/
    createAccuracyArray();
    this.graph = new Graph2(
      title, 0, 100,
      "Coverage", "Accuracy");
    addDataToGraph(this.accuracyValues, numBuckets, dataName);
  }

  public AccuracyCoverage(Trial t, String title, String name)
  {
    this(t, DEFAULT_NUM_BUCKETS, title, name);
  }
  public AccuracyCoverage(Trial t, String title)
  {
    this(t, DEFAULT_NUM_BUCKETS, title, "unnamed");
  }
 
  public AccuracyCoverage(Classifier C, InstanceList ilist, String title)
  {
    this(new Trial(C, ilist), DEFAULT_NUM_BUCKETS, title, "unnamed");
  }
 
  public AccuracyCoverage(Classifier C, InstanceList ilist, int numBuckets, String title)
  {
    this(new Trial(C, ilist), numBuckets, title, "unnamed");
  }
 
  /**
   * Finds the "area under the acc/cov curve"
   * steps by one percentage point and calcs area
   * of trapezoid
   */
  public double cumulativeAccuracy()
  {
    double area = 0.0;
    for(int i=1; i<100; i++)
    {
      double leftAccuracy = accuracyAtCoverage((double)i/100);
      double rightAccuracy = accuracyAtCoverage((double)(i+1)/100);
      area += .5*(leftAccuracy + rightAccuracy);
    }
    return area;   
  }
 
  /**
   * Creates array of accuracy values for coverage
   * at each step as defined by numBuckets.
  
   */
  public void createAccuracyArray()
  {
//    System.out.println("Creating accuracyArray. Step= "+step);
    for(int i=0 ; i<numBuckets; i++)
    {
      accuracyValues[i] =
        accuracyAtCoverage(step
                           *(double)(i+1)/100.0);
    }
  }
 
  /**
   * accuracy at a given coverage percentage
   * @param cov coverage percentage
   * @return accuracy value
   */
  public double accuracyAtCoverage(double cov)
  {
    assert(cov <= 1 && cov > 0);
    int numTrials = (int)(Math.round((double)classifications.size()*cov));
    int numCorrect = 0;
//    System.out.println("NumTrials="+numTrials);
    for(int i= classifications.size()-1;
        i >= classifications.size()-numTrials; i--)
    {
      Classification temp = (Classification)classifications.get(i);
      if(temp.bestLabelIsCorrect())
        numCorrect++;
    }
//    System.out.println("Accuracy at cov "+cov+" is "+
    //(double)numCorrect/numTrials);
    return((double)numCorrect/numTrials);
  }
 
  /**
   * Sort classifications ArrayList
   * by winner's value
   */
  public void sortClassifications()
  {
    Collections.sort(classifications, new  ClassificationComparator());
  }
 
 
  public void addDataToGraph(double [] accValues, int nBuckets, String name)
  {
    Vector values = new Vector(nBuckets);
    for(int i=0; i<nBuckets; i++)
    {
      GraphItem temp = new GraphItem("",
                                     (int)(accValues[i]*100),
                                     Color.black);
      values.add(temp);
    }
    logger.info("Sending "+values.size()+" elements to graph");
    this.graph.addItemVector(values, name);
  }
 
/**
* Displays the accuracy v. coverage graph
*/
  public void displayGraph()
  {
    Vector values = new Vector(this.numBuckets);
    JButton printButton = new JButton("Print");
    frame = new JFrame("Graph");
    DecimalFormat df = new DecimalFormat();

    printButton.addActionListener(this);
   
    frame.addWindowListener
      (new WindowAdapter()
        {
          public void windowClosing(WindowEvent e)
          {
            System.exit(0);
          }
        }
        );

    // Get content pane
    Container pane = frame.getContentPane();
   
    // Set layout manager
    pane.setLayout( new FlowLayout() );

    assert(graph!= null); // make sure we've got data in the graph
    // Add to pane
    pane.add( graph );
    pane.add( printButton );
    frame.pack();
   
    // Center the frame
    Toolkit toolkit = Toolkit.getDefaultToolkit();
   
    // Get the current screen size
    Dimension scrnsize = toolkit.getScreenSize();
   
    // Get the frame size
    Dimension framesize= frame.getSize();
   
    // Set X,Y location
    frame.setLocation ( (int) (scrnsize.getWidth()
                               - frame.getWidth() ) / 2 ,
                        (int) (scrnsize.getHeight()
                               - frame.getHeight()) / 2);
   
    frame.setVisible(true);
  }
 
 
  public void actionPerformed(ActionEvent event)
  {
    PrintUtilities.printComponent(graph);
  }

  public void addTrial(Trial t, String name)
  {
    addTrial(t, DEFAULT_NUM_BUCKETS, name);
  }
 
  public void addTrial(Trial t, int nBuckets, String name)
  {
    AccuracyCoverage newData = new AccuracyCoverage(t, nBuckets, "untitled", name);
    double [] accValues = newData.accuracyValues();
    addDataToGraph(accValues, nBuckets, name);
  }

  public double[] accuracyValues()
  {
    return this.accuracyValues;
  }
  public class ClassificationComparator implements Comparator
  {
    public final int compare (Object a, Object b)
    {
      LabelVector x = (LabelVector) (((Classification)a).getLabelVector());
      LabelVector y = (LabelVector) (((Classification)b).getLabelVector());
      double difference = x.getBestValue() - y.getBestValue();
      int toReturn = 0;
      if(difference > 0)
        toReturn = 1;
      else if (difference < 0)
        toReturn = -1;
      return(toReturn);   
    }
   
  }
 
}
TOP

Related Classes of cc.mallet.classify.evaluate.AccuracyCoverage

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.