// Package edu.stanford.nlp.international.arabic.parsesegment
//
// Source code of edu.stanford.nlp.international.arabic.parsesegment.JointParsingModel$GenericLatticeScorer

package edu.stanford.nlp.international.arabic.parsesegment;

import java.io.*;
import java.util.*;

import edu.stanford.nlp.ling.CategoryWordTagFactory;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.parser.lexparser.*;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreePrint;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Timing;

public class JointParsingModel {

  // When true, getAnnotatedBinaryTreebankFromTreebank() prints a treebank summary and
  // run() prints the options and a grammar size table. Changed through setVerbose().
  private boolean VERBOSE = false;

  //Factored parsing models (Klein and Manning, 2002)
  // pparser and dparser are static because the static nested GenericLatticeScorer
  // reads them directly; all three are (re)built by makeParsers().
  private static ExhaustivePCFGParser pparser;
  private static ExhaustiveDependencyParser dparser;
  private BiLexPCFGParser bparser;

  //Parser objects
  // op is only assigned inside run(); it is null before run() has been called.
  private Options op;
  // Result of getParserDataFromTreebank(), assigned in run().
  private LexicalizedParser lp;
  // Post-processing applied to each best parse in parse(), in this order, before printing.
  private TreeTransformer debinarizer;
  private TreeTransformer subcategoryStripper;
  private TreePrint treePrint;
  // Yield of the most recent pparser best parse (set in parse(), includes the boundary
  // symbol). Static so that GenericLatticeScorer.convertItemSpan() can read it.
  private static List<CoreLabel> bestSegmentationB;

  // Passed as the second argument to LatticeXMLReader.load() in parse().
  private boolean serInput = false;
  // Copied into op.testOptions.maxLength by run(); see setMaxEvalSentLen().
  private int maxSentLen = 5000;

  // Training trees whose (yield size - 1) exceeds this value are dropped in
  // getAnnotatedBinaryTreebankFromTreebank().
  private static final int trainLengthLimit = 100000;

  public void setVerbose(boolean b) {
    VERBOSE = b;
    op.testOptions.verbose = b;
    op.trainOptions.printAnnotatedStateCounts = b;
    op.trainOptions.printAnnotatedRuleCounts = b;
  }

  public void setSerInput(boolean ser_input) { serInput = ser_input; }

  public void setMaxEvalSentLen(int maxSentLen) { this.maxSentLen = maxSentLen; }


  private void removeDeleteSplittersFromSplitters(TreebankLanguagePack tlp) {
    if (op.trainOptions.deleteSplitters != null) {
      List<String> deleted = new ArrayList<String>();
      for (String del : op.trainOptions.deleteSplitters) {
        String baseDel = tlp.basicCategory(del);
        boolean checkBasic = del.equals(baseDel);
        for (Iterator<String> it = op.trainOptions.splitters.iterator(); it.hasNext(); ) {
          String elem = it.next();
          String baseElem = tlp.basicCategory(elem);
          boolean delStr = checkBasic && baseElem.equals(baseDel) || elem.equals(del);
          if (delStr) {
            it.remove();
            deleted.add(elem);
          }
        }
      }
      if (op.testOptions.verbose) {
        System.err.println("Removed from vertical splitters: " + deleted);
      }
    }
  }

  public List<Tree> getAnnotatedBinaryTreebankFromTreebank(Treebank trainTreebank) {
    TreebankLangParserParams tlpParams = op.tlpParams;
    TreebankLanguagePack tlp = tlpParams.treebankLanguagePack();

    if (VERBOSE) System.err.println("\n\n" + trainTreebank.textualSummary(tlp));

    System.err.print("Binarizing trees...");
    TreeAnnotatorAndBinarizer binarizer = new TreeAnnotatorAndBinarizer(tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), true, op);
    Timing.tick("done.");

    if (op.trainOptions.selectiveSplit) {
      op.trainOptions.splitters = ParentAnnotationStats.getSplitCategories(trainTreebank, op.trainOptions.tagSelectiveSplit, 0, op.trainOptions.selectiveSplitCutOff, op.trainOptions.tagSelectiveSplitCutOff, tlp);
      removeDeleteSplittersFromSplitters(tlp);
      if (op.testOptions.verbose) {
        List<String> list = new ArrayList<String>(op.trainOptions.splitters);
        Collections.sort(list);
        System.err.println("Parent split categories: " + list);
      }
    }
    //    if (op.trainOptions.selectivePostSplit) {
    //      // Do all the transformations once just to learn selective splits on annotated categories
    //      TreeTransformer myTransformer = new TreeAnnotator(tlpParams.headFinder(), tlpParams);
    //      Treebank annotatedTB = trainTreebank.transform(myTransformer);
    //      op.trainOptions.postSplitters = ParentAnnotationStats.getSplitCategories(annotatedTB, true, 0, op.trainOptions.selectivePostSplitCutOff, op.trainOptions.tagSelectivePostSplitCutOff, tlp);
    //      if (op.testOptions.verbose) {
    //        System.err.println("Parent post annotation split categories: " + op.trainOptions.postSplitters);
    //      }
    //    }
    if (op.trainOptions.hSelSplit) {
      // We run through all the trees once just to gather counts for hSelSplit!
      int ptt = op.trainOptions.printTreeTransformations;
      op.trainOptions.printTreeTransformations = 0;
      binarizer.setDoSelectiveSplit(false);
      for (Tree tree : trainTreebank) {
        binarizer.transformTree(tree);
      }
      binarizer.setDoSelectiveSplit(true);
      op.trainOptions.printTreeTransformations = ptt;
    }

    //Tree transformation
    //
    List<Tree> binaryTrainTrees = new ArrayList<Tree>();
    for (Tree tree : trainTreebank) {
      tree = binarizer.transformTree(tree);
      if (tree.yield().size() - 1 <= trainLengthLimit) {
        binaryTrainTrees.add(tree);
      }
    }

    // WSGDEBUG: Lot's of stuff on the grammar
    //    if(VERBOSE) {
    //      binarizer.printStateCounts();
    //      binarizer.printRuleCounts();
    //    binarizer.dumpStats();
    //    }

    return binaryTrainTrees;
  }


  public LexicalizedParser getParserDataFromTreebank(Treebank trainTreebank) {

    System.err.print("Binarizing training trees...");
    List<Tree> binaryTrainTrees = getAnnotatedBinaryTreebankFromTreebank(trainTreebank);
    Timing.tick("done.");

    Index<String> stateIndex = new HashIndex<String>();

    System.err.print("Extracting PCFG...");
    Extractor<Pair<UnaryGrammar,BinaryGrammar>> bgExtractor = new BinaryGrammarExtractor(op, stateIndex);
    Pair<UnaryGrammar,BinaryGrammar> bgug = bgExtractor.extract(binaryTrainTrees);

    BinaryGrammar bg = bgug.second;
    bg.splitRules();

    UnaryGrammar ug = bgug.first;
    ug.purgeRules();
    Timing.tick("done.");

    System.err.print("Extracting Lexicon...");
    Index<String> wordIndex = new HashIndex<String>();
    Index<String> tagIndex = new HashIndex<String>();
    Lexicon lex = op.tlpParams.lex(op, wordIndex, tagIndex);
    lex.initializeTraining(binaryTrainTrees.size());
    lex.train(binaryTrainTrees);
    lex.finishTraining();
    Timing.tick("done.");

    Extractor<DependencyGrammar> dgExtractor = op.tlpParams.dependencyGrammarExtractor(op, wordIndex, tagIndex);
    DependencyGrammar dg = null;
    if (op.doDep) {
      System.err.print("Extracting Dependencies...");
      dg = dgExtractor.extract(binaryTrainTrees);
      dg.setLexicon(lex);
      Timing.tick("done.");
    }

    System.err.println("Done extracting grammars and lexicon.");

    return new LexicalizedParser(lex, bg, ug, dg, stateIndex, wordIndex, tagIndex, op);
  }

  private void makeParsers() {
    if (lp == null)
      throw new RuntimeException(this.getClass().getName() + ": Parser grammar does not exist");

    //a la (Klein and Manning, 2002)
    pparser = new ExhaustivePCFGParser(lp.bg, lp.ug, lp.lex, op, lp.stateIndex, lp.wordIndex, lp.tagIndex);
    dparser = new ExhaustiveDependencyParser(lp.dg, lp.lex, op, lp.wordIndex, lp.tagIndex);
    bparser = new BiLexPCFGParser(new GenericLatticeScorer(), pparser, dparser, lp.bg, lp.ug, lp.dg, lp.lex, op, lp.stateIndex, lp.wordIndex, lp.tagIndex);
  }

  private boolean parse(InputStream inputStream) {
    final LatticeXMLReader reader = new LatticeXMLReader();

    if(!reader.load(inputStream,serInput)) {
      System.err.printf("%s: Error loading input lattice xml from stdin\n", this.getClass().getName());
      return false;
    }

    System.err.printf("%s: Entering main parsing loop...\n", this.getClass().getName());

    int latticeNum = 0;
    int parseable = 0;
    int successes = 0;
    int fParseSucceeded = 0;
    for(final Lattice lattice : reader) {

      if (lattice.getNumNodes() > op.testOptions.maxLength + 1) {  // + 1 for boundary symbol
        System.err.printf("%s: Lattice %d too big! (%d nodes)\n",this.getClass().getName(),latticeNum,lattice.getNumNodes());
        latticeNum++;
        continue;
      }

      parseable++;

      //TODO This doesn't work for what we want. Check the implementation in ExhaustivePCFG parser
      //op.testOptions.constraints = lattice.getConstraints();

      try {
        Tree rawTree = null;
        if(op.doPCFG && pparser.parse(lattice)) {
          rawTree = pparser.getBestParse(); //1best segmentation
          bestSegmentationB = rawTree.yield(new ArrayList<CoreLabel>()); //has boundary symbol

          if(op.doDep && dparser.parse(bestSegmentationB)) {
            System.err.printf("%s: Dependency parse succeeded!\n", this.getClass().getName());
            if(bparser.parse(bestSegmentationB)) {
              System.err.printf("%s: Factored parse succeeded!\n", this.getClass().getName());
              rawTree = bparser.getBestParse();
              fParseSucceeded++;
            }

          } else {
            System.out.printf("%s: Dependency parse failed. Backing off to PCFG...\n", this.getClass().getName());
          }

        } else {
          System.out.printf("%s: WARNING: parsing failed for lattice %d\n", this.getClass().getName(), latticeNum);
        }

        //Post-process the tree
        if (rawTree == null) {
          System.out.printf("%s: WARNING: Could not extract best parse for lattice %d\n", this.getClass().getName(), latticeNum);

        } else {
          Tree t = debinarizer.transformTree(rawTree);
          t = subcategoryStripper.transformTree(t);
          treePrint.printTree(t);

          successes++;
        }

        //When a best parse can't be extracted
      } catch (Exception e) {
        System.out.printf("%s: WARNING: Could not extract best parse for lattice %d\n", this.getClass().getName(), latticeNum);
        e.printStackTrace();
      }

      latticeNum++;
    }

    System.err.println("===================================================================");
    System.err.println("===================================================================");
    System.err.println("Post mortem:");
    System.err.println("  Input:     " + latticeNum);
    System.err.println("  Parseable: " + parseable);
    System.err.println("  Parsed:    " + successes);
    System.err.println("  f_Parsed:  " + fParseSucceeded);
    System.err.println("  String %:  " + (int)((double) successes * 10000.0 / (double) parseable) / 100.0);

    return true;
  }


  public boolean run(File trainTreebankFile, File testTreebankFile, InputStream inputStream) {
    op = new Options();
    op.tlpParams = new ArabicTreebankParserParams();
    op.setOptions("-arabicFactored");
    op.testOptions.maxLength = maxSentLen;
    op.testOptions.MAX_ITEMS = 5000000; //500000 is the default for Arabic, but we have substantially more edges now
    op.testOptions.outputFormatOptions = "removeTopBracket,includePunctuationDependencies";

    // WSG: Just set this to some high value so that extractBestParse()
    // actually calls the lattice reader (e.g., this says that we can't have a word longer than
    // 80 characters...seems sensible for Arabic
    op.testOptions.maxSpanForTags = 80;

    treePrint = op.testOptions.treePrint(op.tlpParams);
    debinarizer = new Debinarizer(op.forceCNF, new CategoryWordTagFactory());
    subcategoryStripper = op.tlpParams.subcategoryStripper();

    Timing.startTime();

    final Treebank trainTreebank = op.tlpParams.diskTreebank();
    trainTreebank.loadPath(trainTreebankFile);

    lp = getParserDataFromTreebank(trainTreebank);

    makeParsers();

    if (VERBOSE) {
      op.display();
      String lexNumRules = (pparser != null) ? Integer.toString(lp.lex.numRules()): "";
      System.err.println("Grammar\tStates\tTags\tWords\tUnaryR\tBinaryR\tTaggings");
      System.err.println("Grammar\t" +
                         lp.stateIndex.size() + '\t' +
                         lp.tagIndex.size() + '\t' +
                         lp.wordIndex.size() + '\t' +
                         (pparser != null ? lp.ug.numRules(): "") + '\t' +
                         (pparser != null ? lp.bg.numRules(): "") + '\t' +
                         lexNumRules);
      System.err.println("ParserPack is " + op.tlpParams.getClass().getName());
      System.err.println("Lexicon is " + lp.lex.getClass().getName());
    }

    return parse(inputStream);
  }


  /*
   * pparser chart uses segmentation interstices; dparser uses 1best word
   * interstices. Convert between the two here for bparser.
   */
  private static class GenericLatticeScorer implements LatticeScorer {

    public Item convertItemSpan(Item item) {
      if(bestSegmentationB == null || bestSegmentationB.size() == 0)
        throw new RuntimeException(this.getClass().getName() + ": No 1best segmentation available");

      item.start = bestSegmentationB.get(item.start).beginPosition();
      item.end = bestSegmentationB.get(item.end - 1).endPosition();
      return item;
    }

    public double oScore(Edge edge) {
      final Edge latticeEdge = (Edge) convertItemSpan(new Edge(edge));

      double pOscore = pparser.oScore(latticeEdge);
      double dOscore = dparser.oScore(edge);

      return pOscore + dOscore;
    }

    public double iScore(Edge edge) {
      final Edge latticeEdge = (Edge) convertItemSpan(new Edge(edge));

      double pIscore = pparser.iScore(latticeEdge);
      double dIscore = dparser.iScore(edge);

      return pIscore + dIscore;
    }

    public boolean oPossible(Hook hook) {
      final Hook latticeHook = (Hook) convertItemSpan(new Hook(hook));

      return pparser.oPossible(latticeHook) && dparser.oPossible(hook);
    }

    public boolean iPossible(Hook hook) {
      final Hook latticeHook = (Hook) convertItemSpan(new Hook(hook));

      return pparser.iPossible(latticeHook) && dparser.iPossible(hook);
    }

    public boolean parse(List<? extends HasWord> words) {
      throw new UnsupportedOperationException(this.getClass().getName() + ": Does not support parse operation.");
    }
  }

}
// TOP
//
// Related Classes of edu.stanford.nlp.international.arabic.parsesegment.JointParsingModel$GenericLatticeScorer
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.