Package edu.stanford.nlp.sentiment

Source Code of edu.stanford.nlp.sentiment.Evaluate

package edu.stanford.nlp.sentiment;

import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.List;
import java.util.Set;

import edu.stanford.nlp.neural.rnn.TopNGramRecord;
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.ConfusionMatrix;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;

public class Evaluate {
  final SentimentCostAndGradient cag;
  final SentimentModel model;

  final int[][] equivalenceClasses;
  final String[] equivalenceClassNames;

  int labelsCorrect;
  int labelsIncorrect;

  // the matrix will be [gold][predicted]
  int[][] labelConfusion;

  int rootLabelsCorrect;
  int rootLabelsIncorrect;

  int[][] rootLabelConfusion;

  IntCounter<Integer> lengthLabelsCorrect;
  IntCounter<Integer> lengthLabelsIncorrect;

  TopNGramRecord ngrams;

  // TODO: make this an option
  static final int NUM_NGRAMS = 5;

  private static final NumberFormat NF = new DecimalFormat("0.000000");

  public Evaluate(SentimentModel model) {
    this.model = model;
    this.cag = new SentimentCostAndGradient(model, null);
    this.equivalenceClasses = model.op.equivalenceClasses;
    this.equivalenceClassNames = model.op.equivalenceClassNames;

    reset();
  }

  public void reset() {
    labelsCorrect = 0;
    labelsIncorrect = 0;
    labelConfusion = new int[model.op.numClasses][model.op.numClasses];

    rootLabelsCorrect = 0;
    rootLabelsIncorrect = 0;
    rootLabelConfusion = new int[model.op.numClasses][model.op.numClasses];

    lengthLabelsCorrect = new IntCounter<Integer>();
    lengthLabelsIncorrect = new IntCounter<Integer>();

    if (model.op.testOptions.ngramRecordSize > 0) {
      ngrams = new TopNGramRecord(model.op.numClasses, model.op.testOptions.ngramRecordSize, model.op.testOptions.ngramRecordMaximumLength);
    } else {
      ngrams = null;
    }
  }

  public void eval(List<Tree> trees) {
    for (Tree tree : trees) {
      eval(tree);
    }
  }

  public void eval(Tree tree) {
    cag.forwardPropagateTree(tree);

    countTree(tree);
    countRoot(tree);
    countLengthAccuracy(tree);
    if (ngrams != null) {
      ngrams.countTree(tree);
    }
  }

  private int countLengthAccuracy(Tree tree) {
    if (tree.isLeaf()) {
      return 0;
    }
    Integer gold = RNNCoreAnnotations.getGoldClass(tree);
    Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
    int length;
    if (tree.isPreTerminal()) {
      length = 1;
    } else {
      length = 0;
      for (Tree child : tree.children()) {
        length += countLengthAccuracy(child);
      }
    }
    if (gold >= 0) {
      if (gold.equals(predicted)) {
        lengthLabelsCorrect.incrementCount(length);
      } else {
        lengthLabelsIncorrect.incrementCount(length);
      }
    }
    return length;
  }

  private void countTree(Tree tree) {
    if (tree.isLeaf()) {
      return;
    }
    for (Tree child : tree.children()) {
      countTree(child);
    }
    Integer gold = RNNCoreAnnotations.getGoldClass(tree);
    Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
    if (gold >= 0) {
      if (gold.equals(predicted)) {
        labelsCorrect++;
      } else {
        labelsIncorrect++;
      }
      labelConfusion[gold][predicted]++;
    }
  }

  private void countRoot(Tree tree) {
    Integer gold = RNNCoreAnnotations.getGoldClass(tree);
    Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
    if (gold >= 0) {
      if (gold.equals(predicted)) {
        rootLabelsCorrect++;
      } else {
        rootLabelsIncorrect++;
      }
      rootLabelConfusion[gold][predicted]++;
    }
  }

  public double exactNodeAccuracy() {
    return (double) labelsCorrect / ((double) (labelsCorrect + labelsIncorrect));
  }

  public double exactRootAccuracy() {
    return (double) rootLabelsCorrect / ((double) (rootLabelsCorrect + rootLabelsIncorrect));
  }

  public Counter<Integer> lengthAccuracies() {
    Set<Integer> keys = Generics.newHashSet();
    keys.addAll(lengthLabelsCorrect.keySet());
    keys.addAll(lengthLabelsIncorrect.keySet());

    Counter<Integer> results = new ClassicCounter<Integer>();
    for (Integer key : keys) {
      results.setCount(key, lengthLabelsCorrect.getCount(key) / (lengthLabelsCorrect.getCount(key) + lengthLabelsIncorrect.getCount(key)));
    }
    return results;
  }

  public void printLengthAccuracies() {
    Counter<Integer> accuracies = lengthAccuracies();
    Set<Integer> keys = Generics.newTreeSet();
    keys.addAll(accuracies.keySet());
    System.err.println("Label accuracy at various lengths:");
    for (Integer key : keys) {
      System.err.println(StringUtils.padLeft(Integer.toString(key), 4) + ": " + NF.format(accuracies.getCount(key)));
    }
  }

  private static void printConfusionMatrix(String name, int[][] confusion) {
    System.err.println(name + " confusion matrix");
    ConfusionMatrix<Integer> confusionMatrix = new ConfusionMatrix<Integer>();
    confusionMatrix.setUseRealLabels(true);
    for (int i = 0; i < confusion.length; ++i) {
      for (int j = 0; j < confusion[i].length; ++j) {
        confusionMatrix.add(j, i, confusion[i][j]);
      }
    }
    System.err.println(confusionMatrix);
  }

  private static double[] approxAccuracy(int[][] confusion, int[][] classes) {
    int[] correct = new int[classes.length];
    int[] total = new int[classes.length];
    double[] results = new double[classes.length];
    for (int i = 0; i < classes.length; ++i) {
      for (int j = 0; j < classes[i].length; ++j) {
        for (int k = 0; k < classes[i].length; ++k) {
          correct[i] += confusion[classes[i][j]][classes[i][k]];
        }
        for (int k = 0; k < confusion[classes[i][j]].length; ++k) {
          total[i] += confusion[classes[i][j]][k];
        }
      }
      results[i] = ((double) correct[i]) / ((double) (total[i]));
    }
    return results;
  }

  private static double approxCombinedAccuracy(int[][] confusion, int[][] classes) {
    int correct = 0;
    int total = 0;
    for (int i = 0; i < classes.length; ++i) {
      for (int j = 0; j < classes[i].length; ++j) {
        for (int k = 0; k < classes[i].length; ++k) {
          correct += confusion[classes[i][j]][classes[i][k]];
        }
        for (int k = 0; k < confusion[classes[i][j]].length; ++k) {
          total += confusion[classes[i][j]][k];
        }
      }
    }

    return ((double) correct) / ((double) (total));
  }

  public void printSummary() {
    System.err.println("EVALUATION SUMMARY");
    System.err.println("Tested " + (labelsCorrect + labelsIncorrect) + " labels");
    System.err.println("  " + labelsCorrect + " correct");
    System.err.println("  " + labelsIncorrect + " incorrect");
    System.err.println("  " + NF.format(exactNodeAccuracy()) + " accuracy");
    System.err.println("Tested " + (rootLabelsCorrect + rootLabelsIncorrect) + " roots");
    System.err.println("  " + rootLabelsCorrect + " correct");
    System.err.println("  " + rootLabelsIncorrect + " incorrect");
    System.err.println("  " + NF.format(exactRootAccuracy()) + " accuracy");

    printConfusionMatrix("Label", labelConfusion);
    printConfusionMatrix("Root label", rootLabelConfusion);

    if (equivalenceClasses != null && equivalenceClassNames != null) {
      double[] approxLabelAccuracy = approxAccuracy(labelConfusion, equivalenceClasses);
      for (int i = 0; i < equivalenceClassNames.length; ++i) {
        System.err.println("Approximate " + equivalenceClassNames[i] + " label accuracy: " + NF.format(approxLabelAccuracy[i]));
      }
      System.err.println("Combined approximate label accuracy: " + NF.format(approxCombinedAccuracy(labelConfusion, equivalenceClasses)));

      double[] approxRootLabelAccuracy = approxAccuracy(rootLabelConfusion, equivalenceClasses);
      for (int i = 0; i < equivalenceClassNames.length; ++i) {
        System.err.println("Approximate " + equivalenceClassNames[i] + " root label accuracy: " + NF.format(approxRootLabelAccuracy[i]));
      }
      System.err.println("Combined approximate root label accuracy: " + NF.format(approxCombinedAccuracy(rootLabelConfusion, equivalenceClasses)));
      System.err.println();
    }

    if (model.op.testOptions.ngramRecordSize > 0) {
      System.err.println(ngrams);
    }

    if (model.op.testOptions.printLengthAccuracies) {
      printLengthAccuracies();
    }
  }

  /**
   * Expected arguments are <code> -model model -treebank treebank </code> <br>
   *
   * For example <br>
   * <code>
   *  java edu.stanford.nlp.sentiment.Evaluate
   *   edu/stanford/nlp/models/sentiment/sentiment.ser.gz
   *   /u/nlp/data/sentiment/trees/dev.txt
   * </code>
   */
  public static void main(String[] args) {
    String modelPath = null;
    String treePath = null;
    boolean filterUnknown = false;

    List<String> remainingArgs = Generics.newArrayList();

    for (int argIndex = 0; argIndex < args.length; ) {
      if (args[argIndex].equalsIgnoreCase("-model")) {
        modelPath = args[argIndex + 1];
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-treebank")) {
        treePath = args[argIndex + 1];
        argIndex += 2;
      } else if (args[argIndex].equalsIgnoreCase("-filterUnknown")) {
        filterUnknown = true;
        argIndex++;
      } else {
        remainingArgs.add(args[argIndex]);
        argIndex++;
      }
    }

    String[] newArgs = new String[remainingArgs.size()];
    remainingArgs.toArray(newArgs);

    SentimentModel model = SentimentModel.loadSerialized(modelPath);
    for (int argIndex = 0; argIndex < newArgs.length; ) {
      int newIndex = model.op.setOption(newArgs, argIndex);
      if (argIndex == newIndex) {
        System.err.println("Unknown argument " + newArgs[argIndex]);
        throw new IllegalArgumentException("Unknown argument " + newArgs[argIndex]);
      }
      argIndex = newIndex;
    }

    List<Tree> trees = SentimentUtils.readTreesWithGoldLabels(treePath);
    if (filterUnknown) {
      trees = SentimentUtils.filterUnknownRoots(trees);
    }

    Evaluate eval = new Evaluate(model);
    eval.eval(trees);
    eval.printSummary();
  }
}
TOP

Related Classes of edu.stanford.nlp.sentiment.Evaluate

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.