Source Code of edu.stanford.nlp.parser.lexparser.FactoredLexicon

package edu.stanford.nlp.parser.lexparser;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

import edu.stanford.nlp.international.Languages;
import edu.stanford.nlp.international.Languages.Language;
import edu.stanford.nlp.international.arabic.ArabicMorphoFeatureSpecification;
import edu.stanford.nlp.international.french.FrenchMorphoFeatureSpecification;
import edu.stanford.nlp.international.morph.MorphoFeatureSpecification;
import edu.stanford.nlp.international.morph.MorphoFeatureSpecification.MorphoFeatureType;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.TwoDimensionalIntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;

/**
*
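 * A factored lexicon that decomposes the lexical probability p(word, tag) into
 * word, lemma, and morphological-tag factors, each conditioned on the tag.
 * Factor counts are collected in {@code train} and combined in log space in
 * {@code score}.
 * <p>
 * A minimal training sketch (mirroring {@code main} below; the options,
 * morphoSpec, and index objects are assumed to be built as in that method):
 * <pre>
 * FactoredLexicon lexicon = new FactoredLexicon(options, morphoSpec, wordIndex, tagIndex);
 * lexicon.initializeTraining(trainTrees.size());
 * lexicon.train(trainTrees, null);
 * lexicon.finishTraining();
 * </pre>
 *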
* @author Spence Green
*
*/
public class FactoredLexicon extends BaseLexicon {

  private static final long serialVersionUID = -744693222804176489L;

  private static final boolean DEBUG = false;

  private MorphoFeatureSpecification morphoSpec;

  // Placeholder morph tag for tokens without a morphological analysis
  private static final String NO_MORPH_ANALYSIS = "xXxNONExXx";

  private Index<String> morphIndex = new HashIndex<String>();

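  // Seen counts c(x, t) for each factor x in {word, lemma, morph}, plus per-tag
  // counts of rare events used as the unseen distributions for smoothing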
  private TwoDimensionalIntCounter<Integer,Integer> wordTag = new TwoDimensionalIntCounter<Integer,Integer>(40000);
  private Counter<Integer> wordTagUnseen = new ClassicCounter<Integer>(500);

  private TwoDimensionalIntCounter<Integer,Integer> lemmaTag = new TwoDimensionalIntCounter<Integer,Integer>(40000);
  private Counter<Integer> lemmaTagUnseen = new ClassicCounter<Integer>(500);

  private TwoDimensionalIntCounter<Integer,Integer> morphTag = new TwoDimensionalIntCounter<Integer,Integer>(500);
  private Counter<Integer> morphTagUnseen = new ClassicCounter<Integer>(500);

  private Counter<Integer> tagCounter = new ClassicCounter<Integer>(300);

  public FactoredLexicon(MorphoFeatureSpecification morphoSpec, Index<String> wordIndex, Index<String> tagIndex) {
    super(wordIndex, tagIndex);
    this.morphoSpec = morphoSpec;
  }

  public FactoredLexicon(Options op, MorphoFeatureSpecification morphoSpec, Index<String> wordIndex, Index<String> tagIndex) {
    super(op, wordIndex, tagIndex);
    this.morphoSpec = morphoSpec;
  }

  /**
   * The rule table is indexed by lemmas, so isKnown() is slightly trickier.
   */
  @Override
  public Iterator<IntTaggedWord> ruleIteratorByWord(int word, int loc, String featureSpec) {

    if (word == wordIndex.indexOf(BOUNDARY)) {
      // Deterministic tagging of the boundary symbol
      return rulesWithWord[word].iterator();

    } else if (isKnown(word)) {
      // Strict lexical tagging for seen *lemma* types
      // We need to copy the word form into the rules, which currently have lemmas in them
      return rulesWithWord[word].iterator();

    } else {
      if (DEBUG) System.err.println("UNKNOWN WORD");
      // Unknown word signatures
      Set<IntTaggedWord> lexRules = Generics.newHashSet(10);
      List<IntTaggedWord> uwRules = rulesWithWord[wordIndex.indexOf(UNKNOWN_WORD)];
      // Inject the word into these rules instead of the UW signature
      for (IntTaggedWord iTW : uwRules) {
        lexRules.add(new IntTaggedWord(word, iTW.tag));
      }
      return lexRules.iterator();
    }
  }

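  /**
   * Scores a tagging in log space as the sum of the per-factor log
   * probabilities log p(w|t) + log p(lemma|t) + log p(morph|t). Note that the
   * lemma factor is currently disabled below (p_L_T is fixed at 0.0).
   */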
  @Override
  public float score(IntTaggedWord iTW, int loc, String word, String featureSpec) {
    final int wordId = iTW.word();
    final int tagId = iTW.tag();

    // Force 1-best path to go through the boundary symbol
    // (deterministic tagging)
    final int boundaryId = wordIndex.indexOf(BOUNDARY);
    final int boundaryTagId = tagIndex.indexOf(BOUNDARY_TAG);
    if (wordId == boundaryId && tagId == boundaryTagId) {
      return 0.0f;
    }

    // Morphological features
    String tag = tagIndex.get(iTW.tag());
    Pair<String,String> lemmaMorph = MorphoFeatureSpecification.splitMorphString(word, featureSpec);
    String lemma = lemmaMorph.first();
    int lemmaId = wordIndex.indexOf(lemma);
    String richMorphTag = lemmaMorph.second();
    String reducedMorphTag = morphoSpec.strToFeatures(richMorphTag).toString().trim();
    reducedMorphTag = reducedMorphTag.length() == 0 ? NO_MORPH_ANALYSIS : reducedMorphTag;
    int morphId = morphIndex.addToIndex(reducedMorphTag);

    // Score the factors and create the rule score p_W_T
    double p_W_Tf = Math.log(probWordTag(word, loc, wordId, tagId));
//    double p_L_T = Math.log(probLemmaTag(word, loc, tagId, lemmaId));
    double p_L_T = 0.0;
    double p_M_T = Math.log(probMorphTag(tagId, morphId));
    double p_W_T = p_W_Tf + p_L_T + p_M_T;

    if (DEBUG) {
//      String tag = tagIndex.get(tagId);
      System.err.printf("WSGDEBUG: %s --> %s %s %s ||  %.10f (%.5f / %.5f / %.5f)%n", tag, word, lemma,
          reducedMorphTag, p_W_T, p_W_Tf, p_L_T, p_M_T);
    }

    // Filter low probability taggings
    return p_W_T > -100.0 ? (float) p_W_T : Float.NEGATIVE_INFINITY;
  }

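  /**
   * p(w|t) estimated via Bayes' rule, p(w|t) = p(t|w) * p(w) / p(t), where
   * p(t|w) is smoothed toward the unseen-word tag distribution for rare words;
   * unseen words back off to the unknown word model.
   */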
  private double probWordTag(String word, int loc, int wordId, int tagId) {
    double cW = wordTag.totalCount(wordId);
    double cWT = wordTag.getCount(wordId, tagId);

    // p_W
    double p_W = cW / wordTag.totalCount();

    // p_T
    double cTseen = tagCounter.getCount(tagId);
    double p_T = cTseen / tagCounter.totalCount();

    // p_W_T
    double p_W_T = 0.0;
    if (cW > 0.0) { // Seen word
      double p_T_W = 0.0;
      if (cW > 100.0 && cWT > 0.0) {
        p_T_W = cWT / cW;
      } else {
        double cTunseen = wordTagUnseen.getCount(tagId);
        // TODO p_T_U is 0?
        double p_T_U = cTunseen / wordTagUnseen.totalCount();
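        // Interpolate the MLE estimate of p(t|w) with the unseen-word tag
        // distribution, weighted by the smoothing constant smooth[1]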
        p_T_W = (cWT + smooth[1]*p_T_U) / (cW + smooth[1]);
      }
      p_W_T = p_T_W * p_W / p_T;

    } else { // Unseen word. Score based on the word signature (of the surface form)
      IntTaggedWord iTW = new IntTaggedWord(wordId, tagId);
      double c_T = tagCounter.getCount(tagId);
      p_W_T = Math.exp(getUnknownWordModel().score(iTW, loc, c_T, tagCounter.totalCount(), smooth[0], word));
    }

    return p_W_T;
  }

  /**
   * This method should never return 0!!
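   * Analogous to probWordTag, using Bayes' rule over lemma types:
   * p(l|t) = p(t|l) * p(l) / p(t). Currently bypassed in score(), where the
   * lemma factor is fixed at 0.0.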
   */
  private double probLemmaTag(String word, int loc, int tagId, int lemmaId) {
    double cL = lemmaTag.totalCount(lemmaId);
    double cLT = lemmaTag.getCount(lemmaId, tagId);

    // p_L
    double p_L = cL / lemmaTag.totalCount();

    // p_T
    double cTseen = tagCounter.getCount(tagId);
    double p_T = cTseen / tagCounter.totalCount();

    // p_L_T
    double p_L_T = 0.0;
    if (cL > 0.0) { // Seen lemma
      double p_T_L = 0.0;
      if (cL > 100.0 && cLT > 0.0) {
        p_T_L = cLT / cL;
      } else {
        double cTunseen = lemmaTagUnseen.getCount(tagId);
        // TODO(spenceg): p_T_U is 0??
        double p_T_U = cTunseen / lemmaTagUnseen.totalCount();
        p_T_L = (cLT + smooth[1]*p_T_U) / (cL + smooth[1]);
      }
      p_L_T = p_T_L * p_L / p_T;

    } else { // Unseen lemma: back off to the per-tag unseen-lemma counts
      // Hack
      double cTunseen = lemmaTagUnseen.getCount(tagId);
      p_L_T = cTunseen / tagCounter.totalCount();

      //      int wordId = wordIndex.indexOf(word);
//      IntTaggedWord iTW = new IntTaggedWord(wordId, tagId);
//      double c_T = tagCounter.getCount(tagId);
//      p_L_T = Math.exp(getUnknownWordModel().score(iTW, loc, c_T, tagCounter.totalCount(), smooth[0], word));
    }

    return p_L_T;
  }

  /**
   * This method should never return 0!
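   * Uses p(m|t) = p(t|m) * p(m) / p(t) for frequent analyses; rare or unseen
   * analyses fall back to add-one smoothing over the tag set.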
   */
  private double probMorphTag(int tagId, int morphId) {
    double cM = morphTag.totalCount(morphId);
    double cMT = morphTag.getCount(morphId, tagId);

    // p_M
    double p_M = cM / morphTag.totalCount();

    // p_T
    double cTseen = tagCounter.getCount(tagId);
    double p_T = cTseen / tagCounter.totalCount();

    double p_M_T = 0.0;
    if (cM > 100.0 && cMT > 0.0) {
      double p_T_M = cMT / cM;

//      else {
//        double cTunseen = morphTagUnseen.getCount(tagId);
//        double p_T_U = cTunseen / morphTagUnseen.totalCount();
//        p_T_M = (cMT + smooth[1]*p_T_U) / (cM + smooth[1]);
//      }
      p_M_T = p_T_M * p_M / p_T;

    } else { // Rare or unseen morphological analysis
      // Hack: unseen morph tags are extremely rare, so use add-one smoothing
      p_M_T = 1.0 / (morphTag.totalCount() + tagIndex.size() + 1.0);
    }

    return p_M_T;
  }

  /**
   * This method should populate wordIndex, tagIndex, and morphIndex.
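   * For each preterminal it increments the seen counts c(word,tag),
   * c(lemma,tag), and c(morph,tag); after a configurable fraction of the trees
   * has been seen, first-time and singleton events also feed the unseen
   * counters used for smoothing.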
   */
  @Override
  public void train(Collection<Tree> trees, Collection<Tree> rawTrees) {
    double weight = 1.0;
    // Train uw model on words
    uwModelTrainer.train(trees, weight);

    final double numTrees = trees.size();
    Iterator<Tree> rawTreesItr = rawTrees == null ? null : rawTrees.iterator();
    Iterator<Tree> treeItr = trees.iterator();

    // Train factored lexicon on lemmas and morph tags
    int treeId = 0;
    while (treeItr.hasNext()) {
      Tree tree = treeItr.next();
      // CoreLabels, with morph analysis in the originalText annotation
      List<Label> yield = rawTrees == null ? tree.yield() : rawTreesItr.next().yield();
      // Annotated, binarized tree for the tags (labels are usually CategoryWordTag)
      List<Label> pretermYield = tree.preTerminalYield();

      int yieldLen = yield.size();
      for (int i = 0; i < yieldLen; ++i) {
        String word = yield.get(i).value();
        int wordId = wordIndex.addToIndex(word); // Don't do anything with words
        String tag = pretermYield.get(i).value();
        int tagId = tagIndex.addToIndex(tag);

        // Use the word as backup if there is no lemma
        String featureStr = ((CoreLabel) yield.get(i)).originalText();
        Pair<String,String> lemmaMorph = MorphoFeatureSpecification.splitMorphString(word, featureStr);
        String lemma = lemmaMorph.first();
        int lemmaId = wordIndex.addToIndex(lemma);
        String richMorphTag = lemmaMorph.second();
        String reducedMorphTag = morphoSpec.strToFeatures(richMorphTag).toString().trim();
        reducedMorphTag = reducedMorphTag.isEmpty() ? NO_MORPH_ANALYSIS : reducedMorphTag;
        int morphId = morphIndex.addToIndex(reducedMorphTag);

        // Seen event counts
        wordTag.incrementCount(wordId, tagId);
        lemmaTag.incrementCount(lemmaId, tagId);
        morphTag.incrementCount(morphId, tagId);
        tagCounter.incrementCount(tagId);

        // Unseen event counts
        if (treeId > op.trainOptions.fractionBeforeUnseenCounting*numTrees) {
          if (! wordTag.firstKeySet().contains(wordId) || wordTag.getCounter(wordId).totalCount() < 2) {
            wordTagUnseen.incrementCount(tagId);
          }
          if (! lemmaTag.firstKeySet().contains(lemmaId) || lemmaTag.getCounter(lemmaId).totalCount() < 2) {
            lemmaTagUnseen.incrementCount(tagId);
          }
          if (! morphTag.firstKeySet().contains(morphId) || morphTag.getCounter(morphId).totalCount() < 2) {
            morphTagUnseen.incrementCount(tagId);
          }
        }
      }
      ++treeId;

      if (DEBUG && (treeId % 100) == 0) {
        System.err.printf("[%d]",treeId);
      }
      if (DEBUG && (treeId % 10000) == 0) {
        System.err.println();
      }
    }
  }

  /**
   * The rule table is indexed by lemmas.
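   * Builds the per-word rule lists from the seen (word, tag) counts and mixes
   * unknown-word signature rules into the UNKNOWN_WORD entry for sufficiently
   * open tag classes.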
   */
  @Override
  protected void initRulesWithWord() {
    // Add synthetic symbols to the indices
    int unkWord = wordIndex.addToIndex(UNKNOWN_WORD);
    int boundaryWordId = wordIndex.addToIndex(BOUNDARY);
    int boundaryTagId = tagIndex.addToIndex(BOUNDARY_TAG);

    // Initialize rules table
    final int numWords = wordIndex.size();
    rulesWithWord = new List[numWords];
    for (int w = 0; w < numWords; w++) {
      rulesWithWord[w] = new ArrayList<IntTaggedWord>(1);
    }

    // Collect rules, indexed by word
    Set<IntTaggedWord> lexRules = Generics.newHashSet(40000);
    for (int wordId : wordTag.firstKeySet()) {
      for (int tagId : wordTag.getCounter(wordId).keySet()) {
        lexRules.add(new IntTaggedWord(wordId, tagId));
        lexRules.add(new IntTaggedWord(nullWord, tagId));
      }
    }

    // Known words and signatures
    for (IntTaggedWord iTW : lexRules) {
      if (iTW.word() == nullWord) {
        // Mix in UW signature rules for open class types
        double types = uwModel.unSeenCounter().getCount(iTW);
        if (types > trainOptions.openClassTypesThreshold) {
          IntTaggedWord iTU = new IntTaggedWord(unkWord, iTW.tag);
          if (!rulesWithWord[unkWord].contains(iTU)) {
            rulesWithWord[unkWord].add(iTU);
          }
        }
      } else {
        // Known word
        rulesWithWord[iTW.word].add(iTW);
      }
    }

    System.err.print("The " + rulesWithWord[unkWord].size() + " open class tags are: [");
    for (IntTaggedWord item : rulesWithWord[unkWord]) {
      System.err.print(" " + tagIndex.get(item.tag()));
    }
    System.err.println(" ] ");

    // Boundary symbol has one tagging
    rulesWithWord[boundaryWordId].add(new IntTaggedWord(boundaryWordId, boundaryTagId));
  }

  /**
   * Convert a treebank to factored lexicon events for fast iteration in the
   * optimizer.
   */
  private static List<FactoredLexiconEvent> treebankToLexiconEvents(List<Tree> treebank,
      FactoredLexicon lexicon) {
    List<FactoredLexiconEvent> events = new ArrayList<FactoredLexiconEvent>(70000);
    for (Tree tree : treebank) {
      List<Label> yield = tree.yield();
      List<Label> preterm = tree.preTerminalYield();
      assert yield.size() == preterm.size();
      int yieldLen = yield.size();
      for (int i = 0; i < yieldLen; ++i) {
        String tag = preterm.get(i).value();
        int tagId = lexicon.tagIndex.indexOf(tag);
        String word = yield.get(i).value();
        int wordId = lexicon.wordIndex.indexOf(word);
        // Two checks to see if we keep this example
        if (tagId < 0) {
          System.err.println("Discarding training example: " + word + " " + tag);
          continue;
        }
//        if (counts.probWordTag(wordId, tagId) == 0.0) {
//          System.err.println("Discarding low counts <w,t> pair: " + word + " " + tag);
//          continue;
//        }

        String featureStr = ((CoreLabel) yield.get(i)).originalText();
        Pair<String,String> lemmaMorph = MorphoFeatureSpecification.splitMorphString(word, featureStr);
        String lemma = lemmaMorph.first();
        String richTag = lemmaMorph.second();
        String reducedTag = lexicon.morphoSpec.strToFeatures(richTag).toString();
        reducedTag = reducedTag.length() == 0 ? NO_MORPH_ANALYSIS : reducedTag;

        int lemmaId = lexicon.wordIndex.indexOf(lemma);
        int morphId = lexicon.morphIndex.indexOf(reducedTag);
        FactoredLexiconEvent event = new FactoredLexiconEvent(wordId, tagId, lemmaId, morphId, i, word, featureStr);
        events.add(event);
      }
    }
    return events;
  }

  private static List<FactoredLexiconEvent> getTuningSet(Treebank devTreebank,
      FactoredLexicon lexicon, TreebankLangParserParams tlpp) {
    List<Tree> devTrees = new ArrayList<Tree>(3000);
    for (Tree tree : devTreebank) {
      for (Tree subTree : tree) {
        if (!subTree.isLeaf()) {
          tlpp.transformTree(subTree, tree);
        }
      }
      devTrees.add(tree);
    }
    List<FactoredLexiconEvent> tuningSet = treebankToLexiconEvents(devTrees, lexicon);
    return tuningSet;
  }


  private static Options getOptions(Language language) {
    Options options = new Options();
    if (language.equals(Language.Arabic)) {
      options.lexOptions.useUnknownWordSignatures = 9;
      options.lexOptions.unknownPrefixSize = 1;
      options.lexOptions.unknownSuffixSize = 1;
      options.lexOptions.uwModelTrainer = "edu.stanford.nlp.parser.lexparser.ArabicUnknownWordModelTrainer";
    } else if (language.equals(Language.French)) {
      options.lexOptions.useUnknownWordSignatures = 1;
      options.lexOptions.unknownPrefixSize = 1;
      options.lexOptions.unknownSuffixSize = 2;
      options.lexOptions.uwModelTrainer = "edu.stanford.nlp.parser.lexparser.FrenchUnknownWordModelTrainer";
    } else {
      throw new UnsupportedOperationException();
    }
    return options;
  }

  /**
   * @param args Command-line arguments: language features train_file dev_file
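   *             A hypothetical invocation: {@code French GEN,NUM train.mrg dev.mrg}
   *             (feature names must be valid MorphoFeatureType values).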
   */
  public static void main(String[] args) {
    if (args.length != 4) {
      System.err.printf("Usage: java %s language features train_file dev_file%n", FactoredLexicon.class.getName());
      System.exit(-1);
    }
    // Command line options
    Language language = Language.valueOf(args[0]);
    TreebankLangParserParams tlpp = Languages.getLanguageParams(language);
    Treebank trainTreebank = tlpp.diskTreebank();
    trainTreebank.loadPath(args[2]);
    Treebank devTreebank = tlpp.diskTreebank();
    devTreebank.loadPath(args[3]);
    MorphoFeatureSpecification morphoSpec;
    Options options = getOptions(language);
    if (language.equals(Language.Arabic)) {
      morphoSpec = new ArabicMorphoFeatureSpecification();
      String[] languageOptions = {"-arabicFactored"};
      tlpp.setOptionFlag(languageOptions, 0);
    } else if (language.equals(Language.French)) {
      morphoSpec = new FrenchMorphoFeatureSpecification();
      String[] languageOptions = {"-frenchFactored"};
      tlpp.setOptionFlag(languageOptions, 0);
    } else {
      throw new UnsupportedOperationException();
    }
    String featureList = args[1];
    String[] features = featureList.trim().split(",");
    for (String feature : features) {
      morphoSpec.activate(MorphoFeatureType.valueOf(feature));
    }
    System.out.println("Language: " + language.toString());
    System.out.println("Features: " + args[1]);

    // Create word and tag indices
    // Save trees in a collection since the interface requires it
    System.out.print("Loading training trees...");
    List<Tree> trainTrees = new ArrayList<Tree>(19000);
    Index<String> wordIndex = new HashIndex<String>();
    Index<String> tagIndex = new HashIndex<String>();
    for (Tree tree : trainTreebank) {
      for (Tree subTree : tree) {
        if (!subTree.isLeaf()) {
          tlpp.transformTree(subTree, tree);
        }
      }
      trainTrees.add(tree);
    }
    System.out.printf("Done! (%d trees)%n", trainTrees.size());

    // Setup and train the lexicon.
    System.out.print("Collecting sufficient statistics for lexicon...");
    FactoredLexicon lexicon = new FactoredLexicon(options, morphoSpec, wordIndex, tagIndex);
    lexicon.initializeTraining(trainTrees.size());
    lexicon.train(trainTrees, null);
    lexicon.finishTraining();
    System.out.println("Done!");
    trainTrees = null;

    // Load the tuning set
    System.out.print("Loading tuning set...");
    List<FactoredLexiconEvent> tuningSet = getTuningSet(devTreebank, lexicon, tlpp);
    System.out.printf("...Done! (%d events)%n", tuningSet.size());

    // Print the probabilities that we obtain
    // TODO(spenceg): Implement tagging accuracy with FactLex
    int nCorrect = 0;
    Counter<String> errors = new ClassicCounter<String>();
    for (FactoredLexiconEvent event : tuningSet) {
      Iterator<IntTaggedWord> itr = lexicon.ruleIteratorByWord(event.word(), event.getLoc(), event.featureStr());
      Counter<Integer> logScores = new ClassicCounter<Integer>();
      boolean noRules = true;
      int goldTagId = -1;
      while (itr.hasNext()) {
        noRules = false;
        IntTaggedWord iTW = itr.next();
        if (iTW.tag() == event.tagId()) {
          System.err.print("GOLD-");
          goldTagId = iTW.tag();
        }
        float tagScore = lexicon.score(iTW, event.getLoc(), event.word(), event.featureStr());
        logScores.incrementCount(iTW.tag(), tagScore);
      }
      if (noRules) {
        System.err.printf("NO TAGGINGS: %s %s%n", event.word(), event.featureStr());
      } else {
        // Score the tagging
        int hypTagId = Counters.argmax(logScores);
        if (hypTagId == goldTagId) {
          ++nCorrect;
        } else {
          String goldTag = goldTagId < 0 ? "UNSEEN" : lexicon.tagIndex.get(goldTagId);
          errors.incrementCount(goldTag);
        }
      }
      System.err.println();
    }

    // Output accuracy
    double acc = (double) nCorrect / (double) tuningSet.size();
    System.err.printf("%n%nACCURACY: %.2f%n%n", acc*100.0);
    System.err.println("% of errors by type:");
    List<String> biggestKeys = new ArrayList<String>(errors.keySet());
    Collections.sort(biggestKeys, Counters.toComparator(errors, false, true));
    Counters.normalize(errors);
    for (String key : biggestKeys) {
      System.err.printf("%s\t%.2f%n", key, errors.getCount(key)*100.0);
    }
  }

}