Package edu.stanford.nlp.neural.rnn

Source Code of edu.stanford.nlp.neural.rnn.TopNGramRecord

package edu.stanford.nlp.neural.rnn;

import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.Sentence;
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;

/**
* This class stores the best K ngrams for each class for a model.  It
* does so by keeping priority queues of the best K trees seen,
* eliminating duplicates.
* <br>
* The interface is not too advanced at the moment.  To use, first
* create it with the number of classes and the counts of the ngrams,
* and then add trees one at a time with countTree.  The results can
* be extracted with toString().  There is no way to directly access
* the internal results.
*
* @author John Bauer
*/
public class TopNGramRecord {
  /** how many ngrams to store for each class */
  final private int ngramCount;

  /** How many classes we are storing */
  final private int numClasses;

  /** Longest ngram to keep */
  final private int maximumLength;

  Map<Integer, Map<Integer, PriorityQueue<Tree>>> classToNGrams = Generics.newHashMap();

  public TopNGramRecord(int numClasses, int ngramCount, int maximumLength) {
    this.numClasses = numClasses;
    this.ngramCount = ngramCount;
    this.maximumLength = maximumLength;
    for (int i = 0; i < numClasses; ++i) {
      Map<Integer, PriorityQueue<Tree>> innerMap = Generics.newHashMap();
      classToNGrams.put(i, innerMap);
    }
  }

  /**
   * Adds the tree and all its subtrees to the appropriate
   * PriorityQueues for each predicted class.
   */
  public void countTree(Tree tree) {
    Tree simplified = simplifyTree(tree);
    for (int i = 0; i < numClasses; ++i) {
      countTreeHelper(simplified, i, classToNGrams.get(i));
    }
  }

  /**
   * Remove everything but the skeleton, the predictions, and the labels
   */
  private Tree simplifyTree(Tree tree) {
    CoreLabel newLabel = new CoreLabel();
    newLabel.set(RNNCoreAnnotations.Predictions.class, RNNCoreAnnotations.getPredictions(tree));
    newLabel.setValue(tree.label().value());
    if (tree.isLeaf()) {
      return tree.treeFactory().newLeaf(newLabel);
    }

    List<Tree> children = Generics.newArrayList(tree.children().length);
    for (int i = 0; i < tree.children().length; ++i) {
      children.add(simplifyTree(tree.children()[i]));
    }
    return tree.treeFactory().newTreeNode(newLabel, children);
  }

  private int countTreeHelper(Tree tree, int prediction, Map<Integer, PriorityQueue<Tree>> ngrams) {
    if (tree.isLeaf()) {
      return 1;
    }
    int treeSize = 0;
    for (Tree child : tree.children()) {
      treeSize += countTreeHelper(child, prediction, ngrams);
    }
    if (maximumLength > 0 && treeSize > maximumLength) {
      return treeSize;
    }
    PriorityQueue<Tree> queue = getPriorityQueue(treeSize, prediction, ngrams);
    // TODO: should we allow classes which aren't the best possible
    // class for this tree to be included in the results?
    if (!queue.contains(tree)) {
      queue.add(tree);
    }
    if (queue.size() > ngramCount) {
      queue.poll();
    }
    return treeSize;
  }

  private PriorityQueue<Tree> getPriorityQueue(int size, int prediction, Map<Integer, PriorityQueue<Tree>> ngrams) {
    PriorityQueue<Tree> queue = ngrams.get(size);
    if (queue != null) {
      return queue;
    }
    queue = new PriorityQueue<Tree>(ngramCount + 1, scoreComparator(prediction));
    ngrams.put(size, queue);
    return queue;
  }

  private Comparator<Tree> scoreComparator(final int prediction) {
    return (tree1, tree2) -> {
      double score1 = RNNCoreAnnotations.getPredictions(tree1).get(prediction);
      double score2 = RNNCoreAnnotations.getPredictions(tree2).get(prediction);
      if (score1 < score2) {
        return -1;
      } else if (score1 > score2) {
        return 1;
      } else {
        return 0;
      }
    };
  }

  public String toString() {
    StringBuilder result = new StringBuilder();
    for (int prediction = 0; prediction < numClasses; ++prediction) {
      result.append("Best scores for class " + prediction + "\n");
      Map<Integer, PriorityQueue<Tree>> ngrams = classToNGrams.get(prediction);
      for (Map.Entry<Integer, PriorityQueue<Tree>> entry : ngrams.entrySet()) {
        List<Tree> trees = Generics.newArrayList(entry.getValue());
        Collections.sort(trees, scoreComparator(prediction));
        result.append("  Len " + entry.getKey() + "\n");
        for (int i = trees.size() - 1; i >= 0; i--) {
          Tree tree = trees.get(i);
          result.append("    " + Sentence.listToString(tree.yield()) + "  [" + RNNCoreAnnotations.getPredictions(tree).get(prediction) + "]\n");
        }
      }
    }
    return result.toString();
  }
}
TOP

Related Classes of edu.stanford.nlp.neural.rnn.TopNGramRecord

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.