Package opennlp.tools.cmdline.postag

Source Code of opennlp.tools.cmdline.postag.POSTaggerFineGrainedReportListener$Counter

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package opennlp.tools.cmdline.postag;

import java.io.OutputStream;
import java.io.PrintStream;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import opennlp.tools.postag.POSSample;
import opennlp.tools.postag.POSTaggerEvaluationMonitor;
import opennlp.tools.util.Span;
import opennlp.tools.util.eval.FMeasure;
import opennlp.tools.util.eval.Mean;

/**
* Generates a detailed report for the POS Tagger.
* <p>
* It is possible to use it from an API and access the statistics using the
* provided getters
*
*/
public class POSTaggerFineGrainedReportListener implements
    POSTaggerEvaluationMonitor {

  private final PrintStream printStream;
  private final Stats stats = new Stats();

  private static final char[] alpha = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h',
      'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
      'w', 'x', 'y', 'z' };

  /**
   * Creates a listener that will print to {@link System#err}
   */
  public POSTaggerFineGrainedReportListener() {
    this(System.err);
  }

  /**
   * Creates a listener that prints to a given {@link OutputStream}
   */
  public POSTaggerFineGrainedReportListener(OutputStream outputStream) {
    this.printStream = new PrintStream(outputStream);
  }

  // methods inherited from EvaluationMonitor

  public void missclassified(POSSample reference, POSSample prediction) {
    stats.add(reference, prediction);
  }

  public void correctlyClassified(POSSample reference, POSSample prediction) {
    stats.add(reference, prediction);
  }

  /**
   * Writes the report to the {@link OutputStream}. Should be called only after
   * the evaluation process
   */
  public void writeReport() {
    printGeneralStatistics();
    // token stats
    printTokenErrorRank();
    printTokenOcurrenciesRank();
    // tag stats
    printTagsErrorRank();
    // confusion tables
    printGeneralConfusionTable();
    printDetailedConfusionMatrix();
  }

  // api methods
  // general stats

  public long getNumberOfSentences() {
    return stats.getNumberOfSentences();
  }

  public double getAverageSentenceSize() {
    return stats.getAverageSentenceSize();
  }

  public int getMinSentenceSize() {
    return stats.getMinSentenceSize();
  }

  public int getMaxSentenceSize() {
    return stats.getMaxSentenceSize();
  }

  public int getNumberOfTags() {
    return stats.getNumberOfTags();
  }

  public double getAccuracy() {
    return stats.getAccuracy();
  }

  // token stats

  public double getTokenAccuracy(String token) {
    return stats.getTokenAccuracy(token);
  }

  public SortedSet<String> getTokensOrderedByFrequency() {
    return stats.getTokensOrderedByFrequency();
  }

  public int getTokenFrequency(String token) {
    return stats.getTokenFrequency(token);
  }

  public int getTokenErrors(String token) {
    return stats.getTokenErrors(token);
  }

  public SortedSet<String> getTokensOrderedByNumberOfErrors() {
    return stats.getTokensOrderedByNumberOfErrors();
  }

  public SortedSet<String> getTagsOrderedByErrors() {
    return stats.getTagsOrderedByErrors();
  }

  public int getTagFrequency(String tag) {
    return stats.getTagFrequency(tag);
  }

  public int getTagErrors(String tag) {
    return stats.getTagErrors(tag);
  }

  public double getTagPrecision(String tag) {
    return stats.getTagPrecision(tag);
  }

  public double getTagRecall(String tag) {
    return stats.getTagRecall(tag);
  }

  public double getTagFMeasure(String tag) {
    return stats.getTagFMeasure(tag);
  }

  public SortedSet<String> getConfusionMatrixTagset() {
    return stats.getConfusionMatrixTagset();
  }

  public SortedSet<String> getConfusionMatrixTagset(String token) {
    return stats.getConfusionMatrixTagset(token);
  }

  public double[][] getConfusionMatrix() {
    return stats.getConfusionMatrix();
  }

  public double[][] getConfusionMatrix(String token) {
    return stats.getConfusionMatrix(token);
  }

  private String matrixToString(SortedSet<String> tagset, double[][] data,
      boolean filter) {
    // we dont want to print trivial cases (acc=1)
    int initialIndex = 0;
    String[] tags = tagset.toArray(new String[tagset.size()]);
    StringBuilder sb = new StringBuilder();
    int minColumnSize = Integer.MIN_VALUE;
    String[][] matrix = new String[data.length][data[0].length];
    for (int i = 0; i < data.length; i++) {
      int j = 0;
      for (; j < data[i].length - 1; j++) {
        matrix[i][j] = data[i][j] > 0 ? Integer.toString((int) data[i][j])
            : ".";
        if (minColumnSize < matrix[i][j].length()) {
          minColumnSize = matrix[i][j].length();
        }
      }
      matrix[i][j] = MessageFormat.format("{0,number,#.##%}", data[i][j]);
      if (data[i][j] == 1 && filter) {
        initialIndex = i + 1;
      }
    }

    final String headerFormat = "%" + (minColumnSize + 2) + "s "; // | 1234567 |
    final String cellFormat = "%" + (minColumnSize + 2) + "s "; // | 12345 |
    final String diagFormat = " %" + (minColumnSize + 2) + "s";
    for (int i = initialIndex; i < tagset.size(); i++) {
      sb.append(String.format(headerFormat,
          generateAlphaLabel(i - initialIndex).trim()));
    }
    sb.append("| Accuracy | <-- classified as\n");
    for (int i = initialIndex; i < data.length; i++) {
      int j = initialIndex;
      for (; j < data[i].length - 1; j++) {
        if (i == j) {
          String val = "<" + matrix[i][j] + ">";
          sb.append(String.format(diagFormat, val));
        } else {
          sb.append(String.format(cellFormat, matrix[i][j]));
        }
      }
      sb.append(
          String.format("|   %-6s |   %3s = ", matrix[i][j],
              generateAlphaLabel(i - initialIndex))).append(tags[i]);
      sb.append("\n");
    }
    return sb.toString();
  }

  private void printGeneralStatistics() {
    printHeader("Evaluation summary");
    printStream.append(
        String.format("%21s: %6s", "Number of sentences",
            Long.toString(getNumberOfSentences()))).append("\n");
    printStream.append(
        String.format("%21s: %6s", "Min sentence size", getMinSentenceSize()))
        .append("\n");
    printStream.append(
        String.format("%21s: %6s", "Max sentence size", getMaxSentenceSize()))
        .append("\n");
    printStream.append(
        String.format("%21s: %6s", "Average sentence size",
            MessageFormat.format("{0,number,#.##}", getAverageSentenceSize())))
        .append("\n");
    printStream.append(
        String.format("%21s: %6s", "Tags count", getNumberOfTags())).append(
        "\n");
    printStream.append(
        String.format("%21s: %6s", "Accuracy",
            MessageFormat.format("{0,number,#.##%}", getAccuracy()))).append(
        "\n");
    printFooter("Evaluation Corpus Statistics");
  }

  private void printTokenOcurrenciesRank() {
    printHeader("Most frequent tokens");

    SortedSet<String> toks = getTokensOrderedByFrequency();
    final int maxLines = 20;

    int maxTokSize = 5;

    int count = 0;
    Iterator<String> tokIterator = toks.iterator();
    while (tokIterator.hasNext() && count++ < maxLines) {
      String tok = tokIterator.next();
      if (tok.length() > maxTokSize) {
        maxTokSize = tok.length();
      }
    }

    int tableSize = maxTokSize + 19;
    String format = "| %3s | %6s | %" + maxTokSize + "s |";

    printLine(tableSize);
    printStream.append(String.format(format, "Pos", "Count", "Token")).append(
        "\n");
    printLine(tableSize);

    // get the first 20 errors
    count = 0;
    tokIterator = toks.iterator();
    while (tokIterator.hasNext() && count++ < maxLines) {
      String tok = tokIterator.next();
      int ocurrencies = getTokenFrequency(tok);

      printStream.append(String.format(format, count, ocurrencies, tok)

      ).append("\n");
    }
    printLine(tableSize);
    printFooter("Most frequent tokens");
  }

  private void printTokenErrorRank() {
    printHeader("Tokens with the highest number of errors");
    printStream.append("\n");

    SortedSet<String> toks = getTokensOrderedByNumberOfErrors();
    int maxTokenSize = 5;

    int count = 0;
    Iterator<String> tokIterator = toks.iterator();
    while (tokIterator.hasNext() && count++ < 20) {
      String tok = tokIterator.next();
      if (tok.length() > maxTokenSize) {
        maxTokenSize = tok.length();
      }
    }

    int tableSize = 31 + maxTokenSize;

    String format = "| %" + maxTokenSize + "s | %6s | %5s | %7s |\n";

    printLine(tableSize);
    printStream.append(String.format(format, "Token", "Errors", "Count",
        "% Err"));
    printLine(tableSize);

    // get the first 20 errors
    count = 0;
    tokIterator = toks.iterator();
    while (tokIterator.hasNext() && count++ < 20) {
      String tok = tokIterator.next();
      int ocurrencies = getTokenFrequency(tok);
      int errors = getTokenErrors(tok);
      String rate = MessageFormat.format("{0,number,#.##%}", (double) errors
          / ocurrencies);

      printStream.append(String.format(format, tok, errors, ocurrencies, rate)

      );
    }
    printLine(tableSize);
    printFooter("Tokens with the highest number of errors");
  }

  private void printTagsErrorRank() {
    printHeader("Detailed Accuracy By Tag");
    SortedSet<String> tags = getTagsOrderedByErrors();
    printStream.append("\n");

    int maxTagSize = 3;

    for (String t : tags) {
      if (t.length() > maxTagSize) {
        maxTagSize = t.length();
      }
    }

    int tableSize = 65 + maxTagSize;

    String headerFormat = "| %" + maxTagSize
        + "s | %6s | %6s | %7s | %9s | %6s | %9s |\n";
    String format = "| %" + maxTagSize
        + "s | %6s | %6s | %-7s | %-9s | %-6s | %-9s |\n";

    printLine(tableSize);
    printStream.append(String.format(headerFormat, "Tag", "Errors", "Count",
        "% Err", "Precision", "Recall", "F-Measure"));
    printLine(tableSize);

    Iterator<String> tagIterator = tags.iterator();
    while (tagIterator.hasNext()) {
      String tag = tagIterator.next();
      int ocurrencies = getTagFrequency(tag);
      int errors = getTagErrors(tag);
      String rate = MessageFormat.format("{0,number,#.###}", (double) errors
          / ocurrencies);

      double p = getTagPrecision(tag);
      double r = getTagRecall(tag);
      double f = getTagFMeasure(tag);

      printStream.append(String.format(format, tag, errors, ocurrencies, rate,
          MessageFormat.format("{0,number,#.###}", p > 0 ? p : 0),
          MessageFormat.format("{0,number,#.###}", r > 0 ? r : 0),
          MessageFormat.format("{0,number,#.###}", f > 0 ? f : 0))

      );
    }
    printLine(tableSize);

    printFooter("Tags with the highest number of errors");
  }

  private void printGeneralConfusionTable() {
    printHeader("Confusion matrix");

    SortedSet<String> labels = getConfusionMatrixTagset();

    double[][] confusionMatrix = getConfusionMatrix();

    printStream.append("\nTags with 100% accuracy: ");
    int line = 0;
    for (String label : labels) {
      if (confusionMatrix[line][confusionMatrix[0].length - 1] == 1) {
        printStream.append(label).append(" (")
            .append(Integer.toString((int) confusionMatrix[line][line]))
            .append(") ");
      }
      line++;
    }

    printStream.append("\n\n");

    printStream.append(matrixToString(labels, confusionMatrix, true));

    printFooter("Confusion matrix");
  }

  private void printDetailedConfusionMatrix() {
    printHeader("Confusion matrix for tokens");
    printStream.append("  sorted by number of errors\n");
    SortedSet<String> toks = getTokensOrderedByNumberOfErrors();

    for (String t : toks) {
      double acc = getTokenAccuracy(t);
      if (acc < 1) {
        printStream
            .append("\n[")
            .append(t)
            .append("]\n")
            .append(
                String.format("%12s: %-8s", "Accuracy",
                    MessageFormat.format("{0,number,#.##%}", acc)))
            .append("\n");
        printStream.append(
            String.format("%12s: %-8s", "Ocurrencies",
                Integer.toString(getTokenFrequency(t)))).append("\n");
        printStream.append(
            String.format("%12s: %-8s", "Errors",
                Integer.toString(getTokenErrors(t)))).append("\n");

        SortedSet<String> labels = getConfusionMatrixTagset(t);

        double[][] confusionMatrix = getConfusionMatrix(t);

        printStream.append(matrixToString(labels, confusionMatrix, false));
      }
    }
    printFooter("Confusion matrix for tokens");
  }

  /** Auxiliary method that prints a emphasised report header */
  private void printHeader(String text) {
    printStream.append("=== ").append(text).append(" ===\n");
  }

  /** Auxiliary method that prints a marker to the end of a report */
  private void printFooter(String text) {
    printStream.append("\n<-end> ").append(text).append("\n\n");
  }

  /** Auxiliary method that prints a horizontal line of a given size */
  private void printLine(int size) {
    for (int i = 0; i < size; i++) {
      printStream.append("-");
    }
    printStream.append("\n");
  }

  private static final String generateAlphaLabel(int index) {

    char labelChars[] = new char[3];
    int i;

    for (i = 2; i >= 0; i--) {
      labelChars[i] = alpha[index % alpha.length];
      index = index / alpha.length - 1;
      if (index < 0) {
        break;
      }
    }

    return new String(labelChars);
  }

  private class Stats {

    // general statistics
    private final Mean accuracy = new Mean();
    private final Mean averageSentenceLength = new Mean();
    private int minimalSentenceLength = Integer.MAX_VALUE;
    private int maximumSentenceLength = Integer.MIN_VALUE;

    // token statistics
    private final Map<String, Mean> tokAccuracies = new HashMap<String, Mean>();
    private final Map<String, Counter> tokOcurrencies = new HashMap<String, Counter>();
    private final Map<String, Counter> tokErrors = new HashMap<String, Counter>();

    // tag statistics
    private final Map<String, Counter> tagOcurrencies = new HashMap<String, Counter>();
    private final Map<String, Counter> tagErrors = new HashMap<String, Counter>();
    private final Map<String, FMeasure> tagFMeasure = new HashMap<String, FMeasure>();

    // represents a Confusion Matrix that aggregates all tokens
    private final Map<String, ConfusionMatrixLine> generalConfusionMatrix = new HashMap<String, ConfusionMatrixLine>();

    // represents a set of Confusion Matrix for each token
    private final Map<String, Map<String, ConfusionMatrixLine>> tokenConfusionMatrix = new HashMap<String, Map<String, ConfusionMatrixLine>>();

    public void add(POSSample reference, POSSample prediction) {
      int length = reference.getSentence().length;
      averageSentenceLength.add(length);

      if (minimalSentenceLength > length) {
        minimalSentenceLength = length;
      }
      if (maximumSentenceLength < length) {
        maximumSentenceLength = length;
      }

      String[] toks = reference.getSentence();
      String[] refs = reference.getTags();
      String[] preds = prediction.getTags();

      updateTagFMeasure(refs, preds);

      for (int i = 0; i < toks.length; i++) {
        add(toks[i], refs[i], preds[i]);
      }
    }

    /**
     * Includes a new evaluation data
     *
     * @param tok
     *          the evaluated token
     * @param ref
     *          the reference pos tag
     * @param pred
     *          the predicted pos tag
     */
    private void add(String tok, String ref, String pred) {
      // token stats
      if (!tokAccuracies.containsKey(tok)) {
        tokAccuracies.put(tok, new Mean());
        tokOcurrencies.put(tok, new Counter());
        tokErrors.put(tok, new Counter());
      }
      tokOcurrencies.get(tok).increment();

      // tag stats
      if (!tagOcurrencies.containsKey(ref)) {
        tagOcurrencies.put(ref, new Counter());
        tagErrors.put(ref, new Counter());
      }
      tagOcurrencies.get(ref).increment();

      // updates general, token and tag error stats
      if (ref.equals(pred)) {
        tokAccuracies.get(tok).add(1);
        accuracy.add(1);
      } else {
        tokAccuracies.get(tok).add(0);
        tokErrors.get(tok).increment();
        tagErrors.get(ref).increment();
        accuracy.add(0);
      }

      // populate confusion matrixes
      if (!generalConfusionMatrix.containsKey(ref)) {
        generalConfusionMatrix.put(ref, new ConfusionMatrixLine(ref));
      }
      generalConfusionMatrix.get(ref).increment(pred);

      if (!tokenConfusionMatrix.containsKey(tok)) {
        tokenConfusionMatrix.put(tok,
            new HashMap<String, ConfusionMatrixLine>());
      }
      if (!tokenConfusionMatrix.get(tok).containsKey(ref)) {
        tokenConfusionMatrix.get(tok).put(ref, new ConfusionMatrixLine(ref));
      }
      tokenConfusionMatrix.get(tok).get(ref).increment(pred);
    }

    private void updateTagFMeasure(String[] refs, String[] preds) {
      // create a set with all tags
      Set<String> tags = new HashSet<String>(Arrays.asList(refs));
      tags.addAll(Arrays.asList(preds));

      // create samples for each tag
      for (String tag : tags) {
        List<Span> reference = new ArrayList<Span>();
        List<Span> prediction = new ArrayList<Span>();
        for (int i = 0; i < refs.length; i++) {
          if (refs[i].equals(tag)) {
            reference.add(new Span(i, i + 1));
          }
          if (preds[i].equals(tag)) {
            prediction.add(new Span(i, i + 1));
          }
        }
        if (!this.tagFMeasure.containsKey(tag)) {
          this.tagFMeasure.put(tag, new FMeasure());
        }
        // populate the fmeasure
        this.tagFMeasure.get(tag).updateScores(
            reference.toArray(new Span[reference.size()]),
            prediction.toArray(new Span[prediction.size()]));
      }
    }

    public double getAccuracy() {
      return accuracy.mean();
    }

    public int getNumberOfTags() {
      return this.tagOcurrencies.keySet().size();
    }

    public long getNumberOfSentences() {
      return this.averageSentenceLength.count();
    }

    public double getAverageSentenceSize() {
      return this.averageSentenceLength.mean();
    }

    public int getMinSentenceSize() {
      return this.minimalSentenceLength;
    }

    public int getMaxSentenceSize() {
      return this.maximumSentenceLength;
    }

    public double getTokenAccuracy(String token) {
      return tokAccuracies.get(token).mean();
    }

    public int getTokenErrors(String token) {
      return tokErrors.get(token).value();
    }

    public int getTokenFrequency(String token) {
      return tokOcurrencies.get(token).value();
    }

    public SortedSet<String> getTokensOrderedByFrequency() {
      SortedSet<String> toks = new TreeSet<String>(new Comparator<String>() {
        public int compare(String o1, String o2) {
          if (o1.equals(o2)) {
            return 0;
          }
          int e1 = 0, e2 = 0;
          if (tokOcurrencies.containsKey(o1))
            e1 = tokOcurrencies.get(o1).value();
          if (tokOcurrencies.containsKey(o2))
            e2 = tokOcurrencies.get(o2).value();
          if (e1 == e2) {
            return o1.compareTo(o2);
          }
          return e2 - e1;
        }
      });

      toks.addAll(tokOcurrencies.keySet());

      return Collections.unmodifiableSortedSet(toks);
    }

    public SortedSet<String> getTokensOrderedByNumberOfErrors() {
      SortedSet<String> toks = new TreeSet<String>(new Comparator<String>() {
        public int compare(String o1, String o2) {
          if (o1.equals(o2)) {
            return 0;
          }
          int e1 = 0, e2 = 0;
          if (tokErrors.containsKey(o1))
            e1 = tokErrors.get(o1).value();
          if (tokErrors.containsKey(o2))
            e2 = tokErrors.get(o2).value();
          if (e1 == e2) {
            return o1.compareTo(o2);
          }
          return e2 - e1;
        }
      });
      toks.addAll(tokErrors.keySet());
      return toks;
    }

    public int getTagFrequency(String tag) {
      return tagOcurrencies.get(tag).value();
    }

    public int getTagErrors(String tag) {
      return tagErrors.get(tag).value();
    }

    public double getTagFMeasure(String tag) {
      return tagFMeasure.get(tag).getFMeasure();
    }

    public double getTagRecall(String tag) {
      return tagFMeasure.get(tag).getRecallScore();
    }

    public double getTagPrecision(String tag) {
      return tagFMeasure.get(tag).getPrecisionScore();
    }

    public SortedSet<String> getTagsOrderedByErrors() {
      SortedSet<String> tags = new TreeSet<String>(new Comparator<String>() {
        public int compare(String o1, String o2) {
          if (o1.equals(o2)) {
            return 0;
          }
          int e1 = 0, e2 = 0;
          if (tagErrors.containsKey(o1))
            e1 = tagErrors.get(o1).value();
          if (tagErrors.containsKey(o2))
            e2 = tagErrors.get(o2).value();
          if (e1 == e2) {
            return o1.compareTo(o2);
          }
          return e2 - e1;
        }
      });
      tags.addAll(tagErrors.keySet());
      return Collections.unmodifiableSortedSet(tags);
    }

    public SortedSet<String> getConfusionMatrixTagset() {
      return getConfusionMatrixTagset(generalConfusionMatrix);
    }

    public double[][] getConfusionMatrix() {
      return createConfusionMatrix(getConfusionMatrixTagset(),
          generalConfusionMatrix);
    }

    public SortedSet<String> getConfusionMatrixTagset(String token) {
      return getConfusionMatrixTagset(tokenConfusionMatrix.get(token));
    }

    public double[][] getConfusionMatrix(String token) {
      return createConfusionMatrix(getConfusionMatrixTagset(token),
          tokenConfusionMatrix.get(token));
    }

    /**
     * Creates a matrix with N lines and N + 1 columns with the data from
     * confusion matrix. The last column is the accuracy.
     */
    private double[][] createConfusionMatrix(SortedSet<String> tagset,
        Map<String, ConfusionMatrixLine> data) {
      int size = tagset.size();
      double[][] matrix = new double[size][size + 1];
      int line = 0;
      for (String ref : tagset) {
        int column = 0;
        for (String pred : tagset) {
          matrix[line][column] = (double) (data.get(ref) != null ? data
              .get(ref).getValue(pred) : 0);
          column++;
        }
        // set accuracy
        matrix[line][column] = (double) (data.get(ref) != null ? data.get(ref)
            .getAccuracy() : 0);
        line++;
      }

      return matrix;
    }

    private SortedSet<String> getConfusionMatrixTagset(
        Map<String, ConfusionMatrixLine> data) {
      SortedSet<String> tags = new TreeSet<String>(new CategoryComparator(data));
      tags.addAll(data.keySet());
      List<String> col = new LinkedList<String>();
      for (String t : tags) {
        col.addAll(data.get(t).line.keySet());
      }
      tags.addAll(col);
      return Collections.unmodifiableSortedSet(tags);
    }
  }

  /**
   * A comparator that sorts the confusion matrix labels according to the
   * accuracy of each line
   */
  private static class CategoryComparator implements Comparator<String> {

    private Map<String, ConfusionMatrixLine> confusionMatrix;

    public CategoryComparator(Map<String, ConfusionMatrixLine> confusionMatrix) {
      this.confusionMatrix = confusionMatrix;
    }

    public int compare(String o1, String o2) {
      if (o1.equals(o2)) {
        return 0;
      }
      ConfusionMatrixLine t1 = confusionMatrix.get(o1);
      ConfusionMatrixLine t2 = confusionMatrix.get(o2);
      if (t1 == null || t2 == null) {
        if (t1 == null) {
          return 1;
        } else if (t2 == null) {
          return -1;
        }
        return 0;
      }
      double r1 = t1.getAccuracy();
      double r2 = t2.getAccuracy();
      if (r1 == r2) {
        return o1.compareTo(o2);
      }
      if (r2 > r1) {
        return 1;
      }
      return -1;
    }

  }

  /**
   * Represents a line in the confusion table.
   */
  private static class ConfusionMatrixLine {

    private Map<String, Counter> line = new HashMap<String, Counter>();
    private String ref;
    private int total = 0;
    private int correct = 0;
    private double acc = -1;

    /**
     * Creates a new {@link ConfusionMatrixLine}
     *
     * @param ref
     *          the reference column
     */
    public ConfusionMatrixLine(String ref) {
      this.ref = ref;
    }

    /**
     * Increments the counter for the given column and updates the statistics.
     *
     * @param column
     *          the column to be incremented
     */
    public void increment(String column) {
      total++;
      if (column.equals(ref))
        correct++;
      if (!line.containsKey(column)) {
        line.put(column, new Counter());
      }
      line.get(column).increment();
    }

    /**
     * Gets the calculated accuracy of this element
     *
     * @return the accuracy
     */
    public double getAccuracy() {
      // we save the accuracy because it is frequently used by the comparator
      if (acc == -1) {
        if (total == 0)
          acc = 0;
        acc = (double) correct / (double) total;
      }
      return acc;
    }

    /**
     * Gets the value given a column
     *
     * @param column
     *          the column
     * @return the counter value
     */
    public int getValue(String column) {
      Counter c = line.get(column);
      if (c == null)
        return 0;
      return c.value();
    }
  }

  /**
   * Implements a simple counter
   */
  private static class Counter {
    private int c = 0;

    public void increment() {
      c++;
    }

    public int value() {
      return c;
    }
  }

}
TOP

Related Classes of opennlp.tools.cmdline.postag.POSTaggerFineGrainedReportListener$Counter

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle. Contact coftware#gmail.com.