Package cc.mallet.classify

Source Code of cc.mallet.classify.ConfidencePredictingClassifier

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */




/**
   @author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/

package cc.mallet.classify;

import java.util.ArrayList;

import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.Label;
import cc.mallet.types.LabelVector;

public class ConfidencePredictingClassifier extends Classifier
{
  Classifier underlyingClassifier;
  Classifier confidencePredictingClassifier;
  double totalCorrect;
  double totalIncorrect;
  double  totalIncorrectIncorrect;
  double  totalIncorrectCorrect;
  int numCorrectInstances;
  int numIncorrectInstances;
  int numConfidenceCorrect;
  int numFalsePositive;
  int numFalseNegative;
 
  public ConfidencePredictingClassifier (Classifier underlyingClassifier, Classifier confidencePredictingClassifier)
  {
    super (underlyingClassifier.getInstancePipe());
    this.underlyingClassifier = underlyingClassifier;
    this.confidencePredictingClassifier = confidencePredictingClassifier;
    // for testing confidence accuracy
    totalCorrect = 0.0;
    totalIncorrect = 0.0;
    totalIncorrectIncorrect = 0.0;
    totalIncorrectCorrect = 0.0;
    numCorrectInstances = 0;
    numIncorrectInstances = 0;
    numConfidenceCorrect = 0;
    numFalsePositive = 0;
     numFalseNegative = 0;

  }

  public Classification classify (Instance instance)
  {
    Classification c = underlyingClassifier.classify (instance);
    Classification cpc = confidencePredictingClassifier.classify (c);
    LabelVector lv = c.getLabelVector();
    int bestIndex = lv.getBestIndex();
    double [] values = new double[lv.numLocations()];
    //// Put score of "correct" into score of the winning class...
    // xxx Can't set lv - it's immutable.
    //     Must create copy and new classification object
    // lv.set (bestIndex, cpc.getLabelVector().value("correct"));
    //for (int i = 0; i < lv.numLocations(); i++)
    //  if (i != bestIndex)
    //    lv.set (i, 0.0);

    // Put score of "correct" in winning class and
    // set rest to 0
    for (int i = 0; i < lv.numLocations(); i++) {
      if (i != bestIndex)
        values[i] = 0.0;
      else values[i] = cpc.getLabelVector().value("correct");
    }
    //return c;
   
    if(c.bestLabelIsCorrect()){
      numCorrectInstances++;
      totalCorrect+=cpc.getLabelVector().value("correct");
      totalIncorrectCorrect+=cpc.getLabelVector().value("incorrect");
      String correct = new String("correct");
      if(correct.equals(cpc.getLabelVector().getBestLabel().toString()))
        numConfidenceCorrect++;
      else numFalseNegative++;
    }
   
    else{
      numIncorrectInstances++;
      totalIncorrect+=cpc.getLabelVector().value("correct");
      totalIncorrectIncorrect+=cpc.getLabelVector().value("incorrect");
      if((new String("incorrect")).equals(cpc.getLabelVector().getBestLabel().toString()))
        numConfidenceCorrect++;
      else numFalsePositive++;
    }
   
    return new Classification(instance, this, new LabelVector(lv.getLabelAlphabet(), values));
//    return cpc;
  }
 
  public void printAverageScores() {
      System.out.println("Mean score of correct for correct instances = " + meanCorrect());
      System.out.println("Mean score of correct for incorrect instances = " + meanIncorrect());
      System.out.println("Mean score of incorrect for correct instances = " +
                         this.totalIncorrectCorrect/this.numCorrectInstances);
      System.out.println("Mean score of incorrect for incorrect instances = " +
                         this.totalIncorrectIncorrect/this.numIncorrectInstances);
  }

  public void printConfidenceAccuracy() {
    System.out.println("Confidence predicting accuracy = " +
                       ((double)numConfidenceCorrect/(numIncorrectInstances + numCorrectInstances))+ " false negatives: "+ numFalseNegative + "/"+numCorrectInstances + " false positives: "+ numFalsePositive +" / " +numIncorrectInstances);
  }
  public double meanCorrect()
  {
    if(this.numCorrectInstances==0)
      return 0.0;
    return (this.totalCorrect/(double)this.numCorrectInstances);
  }

  public double meanIncorrect()
  {
    if(this.numIncorrectInstances==0)
      return 0.0;
    return (this.totalIncorrect/(double)this.numIncorrectInstances);
  }

}
TOP

Related Classes of cc.mallet.classify.ConfidencePredictingClassifier

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.