package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.fsm.*;
import edu.stanford.nlp.io.NumberRangeFileFilter;
import edu.stanford.nlp.trees.*;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.stats.ClassicCounter;
import java.util.*;


/**
 * Tests compaction of a PCFG grammar extracted from a treebank, using the
 * various {@link GrammarCompactor} implementations wired up below (exact,
 * lossy, and category-merging), and evaluates the resulting parser on a test
 * section. For example, with the default settings it trains on WSJ sections
 * 200-2199 and tests on 2200-2219.
 *
 * @author Teg Grenager (grenager@cs.stanford.edu)
 */
public class GrammarCompactionTester {

  // for debugging
  //  public static MergeableGraph debugGraph = null;

  ExhaustivePCFGParser parser = null;
  ExhaustiveDependencyParser dparser = null;
  BiLexPCFGParser bparser = null;
  Scorer scorer = null;
  Options op = new Options();

  //  TreebankLangParserParams tlpParams = new EnglishTreebankParserParams();
  // tlpParams may be changed to something else later, so don't use it till
  // after options are parsed.


  GrammarCompactor compactor = null;

  Map<String, List<List<String>>> allTestPaths = Generics.newHashMap();
  Map<String, List<List<String>>> allTrainPaths = Generics.newHashMap();

  String asciiOutputPath = null;
  String path = "/u/nlp/stuff/corpora/Treebank3/parsed/mrg/wsj";
  int trainLow = 200, trainHigh = 2199, testLow = 2200, testHigh = 2219;

  String suffixOrderString = null;
  String minArcNumString = null;
  String maxMergeCostString = null;
  String sizeCutoffString = null;
  String minPortionArcsString = null;
  String ignoreUnsupportedSuffixesString = "false";
  String splitParamString = null;
  String costModelString = null;
  String verboseString = null;
  String minArcCostString = null;
  String trainThresholdString = null;
  String heldoutThresholdString = null;
  int markovOrder = -1;
  String smoothParamString = null;
  String scoringData = null;
  String allowEpsilonsString = null;
  boolean saveGraphs = false;
  private int indexRangeLow;
  private int indexRangeHigh;
  private String outputFile = null;
  private String inputFile = null;
  private boolean toy = false;

  /**
   * Loads the treebank sections low through high from the given path,
   * optionally annotates the trees with {@link TreeAnnotator}, and returns,
   * for each category, the list of paths produced by {@link PathExtractor}.
   */
  public Map<String,List<List<String>>> extractPaths(String path, int low, int high, boolean annotate) {

    // setup tree transforms
    Treebank trainTreebank = op.tlpParams.memoryTreebank(); // this is a new one
    TreebankLanguagePack tlp = op.langpack();

    trainTreebank.loadPath(path, new NumberRangeFileFilter(low, high, true));

    if (op.trainOptions.selectiveSplit) {
      op.trainOptions.splitters = ParentAnnotationStats.getSplitCategories(trainTreebank, op.trainOptions.selectiveSplitCutOff, op.tlpParams.treebankLanguagePack());
    }
    if (op.trainOptions.selectivePostSplit) {
      TreeTransformer myTransformer = new TreeAnnotator(op.tlpParams.headFinder(), op.tlpParams, op);
      Treebank annotatedTB = trainTreebank.transform(myTransformer);
      op.trainOptions.postSplitters = ParentAnnotationStats.getSplitCategories(annotatedTB, op.trainOptions.selectivePostSplitCutOff, op.tlpParams.treebankLanguagePack());
    }

    List<Tree> trainTrees = new ArrayList<Tree>();
    HeadFinder hf = null;
    if (op.trainOptions.leftToRight) {
      hf = new LeftHeadFinder();
    } else {
      hf = op.tlpParams.headFinder();
    }
    TreeTransformer annotator = new TreeAnnotator(hf, op.tlpParams, op);
    for (Tree tree : trainTreebank) {
      if (annotate) {
        tree = annotator.transformTree(tree);
      }
      trainTrees.add(tree);
    }
    Extractor<Map<String,List<List<String>>>> pExtractor = new PathExtractor(hf, op);
    Map<String,List<List<String>>> allPaths = pExtractor.extract(trainTrees);
    return allPaths;
  }


  public static void main(String[] args) {
    new GrammarCompactionTester().runTest(args);
  }

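  /**
   * Parses the command-line arguments, instantiates the requested
   * {@link GrammarCompactor} according to op.trainOptions.compactGrammar(),
   * and then runs either the toy-grammar experiment or the full
   * grammar-compaction test.
   */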
  public void runTest(String[] args) {
    System.out.println("Currently " + new Date());
    System.out.print("Invoked with arguments:");
    for (int i = 0; i < args.length; i++) {
      System.out.print(" " + args[i]);
    }

    System.out.println();

    int i = 0;
    while (i < args.length && args[i].startsWith("-")) {
      if (args[i].equalsIgnoreCase("-path") && (i + 1 < args.length)) {
        path = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-saveToAscii") && (i + 1 < args.length)) {
        asciiOutputPath = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-train") && (i + 2 < args.length)) {
        trainLow = Integer.parseInt(args[i + 1]);
        trainHigh = Integer.parseInt(args[i + 2]);
        i += 3;
      } else if (args[i].equalsIgnoreCase("-test") && (i + 2 < args.length)) {
        testLow = Integer.parseInt(args[i + 1]);
        testHigh = Integer.parseInt(args[i + 2]);
        i += 3;
      } else if (args[i].equalsIgnoreCase("-index") && (i + 2 < args.length)) {
        indexRangeLow = Integer.parseInt(args[i + 1]);
        indexRangeHigh = Integer.parseInt(args[i + 2]);
        i += 3;
      } else if (args[i].equalsIgnoreCase("-outputFile")) {
        outputFile = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-inputFile")) {
        inputFile = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-suffixOrder")) {
        suffixOrderString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-minArcNum")) {
        minArcNumString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-maxMergeCost")) {
        maxMergeCostString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-sizeCutoff")) {
        sizeCutoffString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-minPortionArcs")) {
        minPortionArcsString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-ignoreUnsupportedSuffixes")) {
        ignoreUnsupportedSuffixesString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-trainThreshold")) {
        trainThresholdString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-heldoutThreshold")) {
        heldoutThresholdString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-minArcCost")) {
        minArcCostString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-splitParam")) {
        splitParamString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-costModel")) {
        costModelString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-scoringData")) {
        scoringData = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-verbose")) {
        verboseString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-allowEpsilons")) {
        allowEpsilonsString = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-saveGraphs")) {
        saveGraphs = true;
        i++;
      } else if (args[i].equalsIgnoreCase("-toy")) {
        toy = true;
        i++;
      } else if (args[i].equalsIgnoreCase("-markovOrder")) {
        markovOrder = Integer.parseInt(args[i + 1]);
        i += 2;
      } else if (args[i].equalsIgnoreCase("-smoothParam")) {
        smoothParamString = args[i + 1];
        i += 2;
      } else {
        i = op.setOptionOrWarn(args, i);
      }
    }

    op.trainOptions.sisterSplitters = Generics.newHashSet(Arrays.asList(op.tlpParams.sisterSplitters()));
    if (op.trainOptions.compactGrammar() == 4) {
      System.out.println("Instantiating fsm.LossyGrammarCompactor");
      try {
        Class[] argTypes = new Class[13];
        Class strClass = String.class;
        for (int j = 0; j < argTypes.length; j++) {
          argTypes[j] = strClass;
        }
        Object[] cArgs = new Object[13];
        cArgs[0] = suffixOrderString;
        cArgs[1] = minArcNumString;
        cArgs[2] = trainThresholdString;
        cArgs[3] = heldoutThresholdString;
        cArgs[4] = sizeCutoffString;
        cArgs[5] = minPortionArcsString;
        cArgs[6] = splitParamString;
        cArgs[7] = ignoreUnsupportedSuffixesString;
        cArgs[8] = minArcCostString;
        cArgs[9] = smoothParamString;
        cArgs[10] = costModelString;
        cArgs[11] = scoringData;
        cArgs[12] = verboseString;
        compactor = (GrammarCompactor) Class.forName("fsm.LossyGrammarCompactor").getConstructor(argTypes).newInstance(cArgs);
      } catch (Exception e) {
        System.err.println("Couldn't instantiate GrammarCompactor: " + e);
        e.printStackTrace();
      }
    } else if (op.trainOptions.compactGrammar() == 5) {
      System.out.println("Instantiating fsm.CategoryMergingGrammarCompactor");
      try {
        Class[] argTypes = new Class[6];
        Class strClass = String.class;
        for (int j = 0; j < argTypes.length; j++) {
          argTypes[j] = strClass;
        }
        Object[] cArgs = new Object[6];
        cArgs[0] = splitParamString;
        cArgs[1] = trainThresholdString;
        cArgs[2] = heldoutThresholdString;
        cArgs[3] = minArcCostString;
        cArgs[4] = ignoreUnsupportedSuffixesString;
        cArgs[5] = smoothParamString;
        compactor = (GrammarCompactor) Class.forName("fsm.CategoryMergingGrammarCompactor").getConstructor(argTypes).newInstance(cArgs);
      } catch (Exception e) {
        throw new RuntimeException("Couldn't instantiate CategoryMergingGrammarCompactor." + e);
      }
    } else if (op.trainOptions.compactGrammar() == 3) {
      System.out.println("Instantiating fsm.ExactGrammarCompactor");
      compactor = new ExactGrammarCompactor(op, saveGraphs, true);
    } else if (op.trainOptions.compactGrammar() > 0) {
      // other positive compaction settings: no special compactor is instantiated here
    }

    if (markovOrder >= 0) {
      op.trainOptions.markovOrder = markovOrder;
      op.trainOptions.hSelSplit = false;
    }
    if (toy) {
      buildAndCompactToyGrammars();
    } else {
      testGrammarCompaction();
    }
  }

  /*
  private static void testOneAtATimeMerging() {

    // use the parser constructor to extract the grammars from the treebank once
    LexicalizedParser lp = new LexicalizedParser(path, new NumberRangeFileFilter(trainLow, trainHigh, true), tlpParams);

    ParserData pd = lp.parserData();
    Pair originalGrammar = new Pair(pd.ug, pd.bg);

    // extract a bunch of paths
    Timing.startTime();
    System.out.print("Extracting other paths...");
    allTrainPaths = extractPaths(path, trainLow, trainHigh, true);
    allTestPaths = extractPaths(path, testLow, testHigh, true);
    Timing.tick("done");

    List mergePairs = null;
    if (inputFile != null) {
      // read merge pairs from file and do them and parse
      System.out.println("getting pairs from file: " + inputFile);
      mergePairs = getMergePairsFromFile(inputFile);
    }
    // try one merge at a time and parse afterwards
    Numberer originalNumberer = Numberer.getGlobalNumberer("states");
    String header = "index\tmergePair\tmergeCost\tparseF1\n";
    StringUtils.printToFile(outputFile, header, true);

    for (int i = indexRangeLow; i < indexRangeHigh; i++) {

      Timing.startTime();
      Numberer.getNumberers().put("states", originalNumberer);
      if (mergePairs != null)
        System.out.println("passing merge pairs to compactor: " + mergePairs);
      CategoryMergingGrammarCompactor compactor = new CategoryMergingGrammarCompactor(mergePairs, i);
      System.out.println("Compacting grammars with index " + i);
      Pair compactedGrammar = compactor.compactGrammar(originalGrammar, allTrainPaths, allTestPaths);
      Pair mergePair = null;
      double mergeCosts = Double.NEGATIVE_INFINITY;
      List mergeList = compactor.getCompletedMergeList();
      if (mergeList != null && mergeList.size() > 0) {
        mergePair = (Pair) mergeList.get(0);
        mergeCosts = compactor.getActualScores().getCount(mergePair);
      }


      ParserData newPd = new ParserData(pd.lex,
                                        (BinaryGrammar) compactedGrammar.second, (UnaryGrammar) compactedGrammar.first,
                                        pd.dg, pd.numbs, pd.pt);

      lp = new LexicalizedParser(newPd);
      Timing.tick("done.");

      Treebank testTreebank = tlpParams.testMemoryTreebank();
      testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true));
      System.out.println("Currently " + new Date());
      double f1 = lp.testOnTreebank(testTreebank);
      System.out.println("Currently " + new Date());

      String resultString = i + "\t" + mergePair + "\t" + mergeCosts + "\t" + f1 + "\n";
      StringUtils.printToFile(outputFile, resultString, true);
    }
  }

  private static List getMergePairsFromFile(String filename) {
    List result = new ArrayList();
    try {
      String fileString = StringUtils.slurpFile(new File(filename));
      StringTokenizer st = new StringTokenizer(fileString);
      while (st.hasMoreTokens()) {
        String token1 = st.nextToken();
        if (st.hasMoreTokens()) {
          String token2 = st.nextToken();
          UnorderedPair pair = new UnorderedPair(token1, token2);
          result.add(pair);
        }
      }
    } catch (Exception e) {
      throw new RuntimeException("couldn't access file: " + filename);
    }
    return result;
  }
*/
  /*
//    System.out.println(MergeableGraph.areIsomorphic(graphs[0], graphs[1], graphs[0].getStartNode(), graphs[1].getStartNode()));
//    System.out.println(MergeableGraph.areIsomorphic(graphs[1], graphs[2], graphs[1].getStartNode(), graphs[2].getStartNode()));
//    System.out.println(MergeableGraph.areIsomorphic(graphs[2], graphs[0], graphs[2].getStartNode(), graphs[0].getStartNode()));

  // now go through the grammars themselves and see if they are equal
  System.out.println("UR 0 and 1: " + equalsUnary(((UnaryGrammar)grammars[0].first).rules(),((UnaryGrammar)grammars[1].first).rules()));
  System.out.println("UR 1 and 2: "  + equalsUnary(((UnaryGrammar)grammars[1].first).rules(),((UnaryGrammar)grammars[2].first).rules()));
  System.out.println("UR 2 and 0: "  + equalsUnary(((UnaryGrammar)grammars[2].first).rules(),((UnaryGrammar)grammars[0].first).rules()));

  System.out.println("BR 0 and 1: "  + equalsBinary(((BinaryGrammar)grammars[0].second).rules(),((BinaryGrammar)grammars[1].second).rules()));
  System.out.println("BR 1 and 2: " + equalsBinary(((BinaryGrammar)grammars[1].second).rules(),((BinaryGrammar)grammars[2].second).rules()));
  System.out.println("BR 2 and 0: " + equalsBinary(((BinaryGrammar)grammars[2].second).rules(),((BinaryGrammar)grammars[0].second).rules()));

    System.exit(0);

  // now go through the grammars we made and see if they are equal!
  Set[] unaryRules = new Set[3];
  Set[] binaryRules = new Set[3];
  for (int i=0; i<grammars.length; i++) {
    unaryRules[i] = new HashSet();
    System.out.println(i + " size: " + ((UnaryGrammar)grammars[i].first()).numRules());
    for (Iterator unRuleI = ((UnaryGrammar)grammars[i].first()).iterator(); unRuleI.hasNext();) {
      UnaryRule ur = (UnaryRule) unRuleI.next();
      String parent = (String) stateNumberers[i].object(ur.parent);
      String child = (String) stateNumberers[i].object(ur.child);
      unaryRules[i].add(new StringUnaryRule(parent, child, ur.score));
    }
    binaryRules[i] = new HashSet();
    System.out.println(i + " size: " + ((BinaryGrammar)grammars[i].second()).numRules());
    for (Iterator binRuleI = ((BinaryGrammar)grammars[i].second()).iterator(); binRuleI.hasNext();) {
      BinaryRule br = (BinaryRule) binRuleI.next();
      String parent = (String) stateNumberers[i].object(br.parent);
      String leftChild = (String) stateNumberers[i].object(br.leftChild);
      String rightChild = (String) stateNumberers[i].object(br.rightChild);
      binaryRules[i].add(new StringBinaryRule(parent, leftChild, rightChild, br.score));
    }
  }

  System.out.println("uR 0 and 1: " + equals(unaryRules[0],unaryRules[1]));
  System.out.println("uR 1 and 2: " + equals(unaryRules[1],unaryRules[2]));
  System.out.println("uR 2 and 0: " + equals(unaryRules[2],unaryRules[0]));

  System.out.println("bR 0 and 1: " + equals(binaryRules[0],binaryRules[1]));
  System.out.println("bR 1 and 2: " + equals(binaryRules[1],binaryRules[2]));
  System.out.println("bR 2 and 0: " + equals(binaryRules[2],binaryRules[0]));

}
*/

  /*
    public static void testCategoryMergingProblem() {
      LexicalizedParser lp = new LexicalizedParser(path, new NumberRangeFileFilter(trainLow, trainHigh, true), tlpParams);

      // test it without the change
      Treebank testTreebank = tlpParams.testMemoryTreebank();
      testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true));
      System.out.println("Currently " + new Date());
      lp.testOnTreebank(testTreebank);
      System.out.println("Currently " + new Date());

      // pull out the rules and consistently change the name of one of the states
      ParserData pd = lp.parserData();
      BinaryGrammar bg = pd.bg;
      UnaryGrammar ug = pd.ug;
      Numberer stateNumberer = Numberer.getGlobalNumberer("states");
      UnaryGrammar newUG = new UnaryGrammar(stateNumberer.total()+1);
      for (Iterator urIter = ug.iterator(); urIter.hasNext();) {
        UnaryRule rule = (UnaryRule) urIter.next();
        rule.parent = changeIfNecessary(rule.parent, stateNumberer);
        rule.child = changeIfNecessary(rule.child, stateNumberer);
        newUG.addRule(rule);
      }
      BinaryGrammar newBG = new BinaryGrammar(stateNumberer.total()+1);
      for (Iterator urIter = bg.iterator(); urIter.hasNext();) {
        BinaryRule rule = (BinaryRule) urIter.next();
        rule.parent = changeIfNecessary(rule.parent, stateNumberer);
        rule.leftChild = changeIfNecessary(rule.leftChild, stateNumberer);
        rule.rightChild = changeIfNecessary(rule.rightChild, stateNumberer);
        newBG.addRule(rule);
      }
      newUG.purgeRules();
      newBG.splitRules();
      pd.ug = newUG;
      pd.bg = newBG;
      lp = new LexicalizedParser(pd);

      // test it with the change
      testTreebank = tlpParams.testMemoryTreebank();
      testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true));
      System.out.println("Currently " + new Date());
      lp.testOnTreebank(testTreebank);
      System.out.println("Currently " + new Date());
    }
  */

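  /**
   * Rewrites the state indices of all rules in the given grammar pair from
   * oldIndex to newIndex, sorts the rules, and returns new
   * {@link UnaryGrammar} and {@link BinaryGrammar} objects built over newIndex.
   */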
  public Pair<UnaryGrammar, BinaryGrammar> translateAndSort(Pair<UnaryGrammar, BinaryGrammar> grammar, Index<String> oldIndex, Index<String> newIndex) {
    System.out.println("oldIndex.size()" + oldIndex.size() + " newIndex.size()" + newIndex.size());
    UnaryGrammar ug = grammar.first;
    List<UnaryRule> unaryRules = new ArrayList<UnaryRule>();
    for (UnaryRule rule : ug.rules()) {
      rule.parent = translate(rule.parent, oldIndex, newIndex);
      rule.child = translate(rule.child, oldIndex, newIndex);
      unaryRules.add(rule);
    }
    Collections.sort(unaryRules);

    UnaryGrammar newUG = new UnaryGrammar(newIndex);
    for (UnaryRule unaryRule : unaryRules) {
      newUG.addRule(unaryRule);
    }
    newUG.purgeRules();

    BinaryGrammar bg = grammar.second;
    List<BinaryRule> binaryRules = new ArrayList<BinaryRule>();
    for (BinaryRule rule : bg.rules()) {
      rule.parent = translate(rule.parent, oldIndex, newIndex);
      rule.leftChild = translate(rule.leftChild, oldIndex, newIndex);
      rule.rightChild = translate(rule.rightChild, oldIndex, newIndex);
      binaryRules.add(rule);
    }
    Collections.sort(binaryRules);

    BinaryGrammar newBG = new BinaryGrammar(newIndex);
    for (BinaryRule binaryRule : binaryRules) {
      newBG.addRule(binaryRule);
    }
    newBG.splitRules();

    return Generics.newPair(newUG, newBG);
  }

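  // Maps a state index in oldIndex to the index of the same state string in
  // newIndex, adding the string to newIndex if it is not already present.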
  private static int translate(int i, Index<String> oldIndex, Index<String> newIndex) {
    return newIndex.addToIndex(oldIndex.get(i));
  }

  // Debugging helper: consistently renames the "NP^PP" state to a dummy label,
  // so the effect of renaming a state on parsing performance can be tested.
  public int changeIfNecessary(int i, Index<String> n) {
    String s = n.get(i);
    if (s.equals("NP^PP")) {
      System.out.println("changed");
      return n.addToIndex("NP-987928374");
    }
    return i;
  }

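  /**
   * Compares two lists of binary rules as sets: returns false if any rule in
   * l1 is missing from l2 or has a different score, printing each mismatch;
   * rules that appear only in l2 are printed as "left over".
   */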
  public boolean equalsBinary(List<BinaryRule> l1, List<BinaryRule> l2) {
    // put each into a map to itself
    Map<BinaryRule, BinaryRule> map1 = Generics.newHashMap();
    for (BinaryRule o : l1) {
      map1.put(o, o);
    }
    Map<BinaryRule, BinaryRule> map2 = Generics.newHashMap();
    for (BinaryRule o : l2) {
      map2.put(o, o);
    }
    boolean isEqual = true;
    for (BinaryRule rule1 : map1.keySet()) {
      BinaryRule rule2 = map2.get(rule1);
      if (rule2 == null) {
        System.out.println("no rule for " + rule1);
        isEqual = false;
      } else {
        map2.remove(rule2);
        if (rule1.score != rule2.score) {
          System.out.println(rule1 + " and " + rule2 + " have diff scores");
          isEqual = false;
        }
      }
    }
    System.out.println("left over: " + map2.keySet());
    return isEqual;
  }

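  /**
   * Same comparison as {@link #equalsBinary}, but for unary rules.
   */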
  public boolean equalsUnary(List<UnaryRule> l1, List<UnaryRule> l2) {
    // put each into a map to itself
    Map<UnaryRule, UnaryRule> map1 = Generics.newHashMap();
    for (UnaryRule o : l1) {
      map1.put(o, o);
    }
    Map<UnaryRule, UnaryRule> map2 = Generics.newHashMap();
    for (UnaryRule o : l2) {
      map2.put(o, o);
    }
    boolean isEqual = true;
    for (UnaryRule rule1 : map1.keySet()) {
      UnaryRule rule2 = map2.get(rule1);
      if (rule2 == null) {
        System.out.println("no rule for " + rule1);
        isEqual = false;
      } else {
        map2.remove(rule2);
        if (rule1.score != rule2.score) {
          System.out.println(rule1 + " and " + rule2 + " have diff scores");
          isEqual = false;
        }
      }
    }
    System.out.println("left over: " + map2.keySet());
    return isEqual;
  }

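  // Reports whether two sets are equal, printing the size difference and the
  // elements present in only one of the two sets.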
  private static <T> boolean equalSets(Set<T> set1, Set<T> set2) {
    boolean isEqual = true;
    if (set1.size() != set2.size()) {
      System.out.println("sizes different: " + set1.size() + " vs. " + set2.size());
      isEqual = false;
    }
    Set<T> newSet1 = new HashSet<T>(set1);
    newSet1.removeAll(set2);
    if (newSet1.size() > 0) {
      isEqual = false;
      System.out.println("set1 left with: " + newSet1);
    }
    Set<T> newSet2 = new HashSet<T>(set2);
    newSet2.removeAll(set1);
    if (newSet2.size() > 0) {
      isEqual = false;
      System.out.println("set2 left with: " + newSet2);
    }
    return isEqual;
  }

  /*
  public static void testAutomatonCompaction() {
    // make our LossyAutomatonCompactor from the parameters passed at command line
    // now set up the compactor2 constructor args
    // extract a bunch of paths
    Timing.startTime();
    System.out.print("Extracting paths from treebank...");
    allTrainPaths = extractPaths(path, trainLow, trainHigh, false);
    allTestPaths = extractPaths(path, testLow, testHigh, false);
    Timing.tick("done");

    // for each category, construct an automaton and then compact it
    for (Iterator catIter = allTrainPaths.keySet().iterator(); catIter.hasNext();) {
      // construct an automaton from the paths
      String category = (String) catIter.next();
      List trainPaths = (List) allTrainPaths.get(category);
      List testPaths = (List) allTestPaths.get(category);
      if (testPaths == null) testPaths = new ArrayList();
      // now make the graph with the training paths (the LossyAutomatonCompactor will reestimate the weights anyway)
      TransducerGraph graph = TransducerGraph.createGraphFromPaths(trainPaths, 3);
      System.out.println("Created graph for: " + category);

      System.out.println();
      int numArcs1 = graph.getArcs().size();

      LossyAutomatonCompactor compactor = new LossyAutomatonCompactor(3, // horizonOrder, 1 means that only exactly compatible merges are considered
                      0, // min number of arcs
                      10000000.0, // maxMergeCost
                      0.5, // splitParam
                      false, //  ignoreUnsupportedSuffixes
                      -1000, // minArcCost
                      trainPaths,
                      testPaths,
                      LossyAutomatonCompactor.DATA_LIKELIHOOD_COST, // costModel
                      false); // verbose

      TransducerGraph result = compactor.compactFA(graph);
      //do we need this?      result = new TransducerGraph(result, ntsp);  // pull out strings from sets returned by minimizer
      int numArcs2 = result.getArcs().size();
      System.out.println("LossyGrammarCompactor compacted "+category+" from " + numArcs1 + " to " + numArcs2 + " arcs");

    }


  }
*/
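  // Total number of tokens (path elements) across all of the given paths.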
  private static <T> int numTokens(List<List<T>> paths) {
    int result = 0;
    for (List<T> path : paths) {
      result += path.size();
    }
    return result;
  }

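  /**
   * For each category, builds a transducer graph from the extracted training
   * paths (after dropping low-count paths), minimizes it exactly, and writes
   * the unminimized and minimized graphs to files via
   * {@link GrammarCompactor#writeFile}.
   */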
  public void buildAndCompactToyGrammars() {
    // extract a bunch of paths
    System.out.print("Extracting other paths...");
    allTrainPaths = extractPaths(path, trainLow, trainHigh, true);
    TransducerGraph.NodeProcessor ntsp = new TransducerGraph.SetToStringNodeProcessor(new PennTreebankLanguagePack());
    TransducerGraph.NodeProcessor otsp = new TransducerGraph.ObjectToSetNodeProcessor();
    TransducerGraph.ArcProcessor isp = new TransducerGraph.InputSplittingProcessor();
    TransducerGraph.ArcProcessor ocp = new TransducerGraph.OutputCombiningProcessor();
    TransducerGraph.GraphProcessor normalizer = new TransducerGraph.NormalizingGraphProcessor(false);
    TransducerGraph.GraphProcessor quasiDeterminizer = new QuasiDeterminizer();
    AutomatonMinimizer exactMinimizer = new FastExactAutomatonMinimizer();
    for (String key : allTrainPaths.keySet()) {
      System.out.println("creating graph for " + key);
      List<List<String>> paths = allTrainPaths.get(key);
      ClassicCounter<List<String>> pathCounter = new ClassicCounter<List<String>>();
      for (List<String> o : paths) {
        pathCounter.incrementCount(o);
      }
      ClassicCounter<List<String>> newPathCounter = removeLowCountPaths(pathCounter, 2);
      paths.retainAll(newPathCounter.keySet()); // get rid of the low count ones
      TransducerGraph result = TransducerGraph.createGraphFromPaths(newPathCounter, 1000);
      // exact compaction
      int numArcs = result.getArcs().size();
      int numNodes = result.getNodes().size();
      if (numArcs == 0) {
        continue;
      }
      System.out.println("initial graph has " + numArcs + " arcs and " + numNodes + " nodes.");
      GrammarCompactor.writeFile(result, "unminimized", key);
      // do exact minimization
      result = normalizer.processGraph(result); // normalize it so that exact minimization works properly
      result = quasiDeterminizer.processGraph(result); // push probabilities left or down
      result = new TransducerGraph(result, ocp); // combine outputs into inputs
      result = exactMinimizer.minimizeFA(result); // minimize the thing
      result = new TransducerGraph(result, ntsp); // pull out strings from sets returned by minimizer
      result = new TransducerGraph(result, isp); // split outputs from inputs
      numArcs = result.getArcs().size();
      numNodes = result.getNodes().size();

      System.out.println("after exact minimization graph has " + numArcs + " arcs and " + numNodes + " nodes.");
      GrammarCompactor.writeFile(result, "exactminimized", key);

      // do additional lossy minimization
      /*
      NewLossyAutomatonCompactor compactor2 = new NewLossyAutomatonCompactor(paths, true);
      result = compactor2.compactFA(result);
      result = new TransducerGraph(result, ntsp);  // pull out strings from sets returned by minimizer
      numArcs = result.getArcs().size();
      numNodes = result.getNodes().size();

      System.out.println("after lossy minimization graph has " + numArcs + " arcs and " + numNodes + " nodes.");
      GrammarCompactor.writeFile(result, "lossyminimized", key);
      */
    }
  }

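  // Returns a counter containing only those paths whose count is at least thresh.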
  private static ClassicCounter<List<String>> removeLowCountPaths(ClassicCounter<List<String>> paths, double thresh) {
    ClassicCounter<List<String>> result = new ClassicCounter<List<String>>();
    int numRetained = 0;
    for (List<String> path : paths.keySet()) {
      double count = paths.getCount(path);
      if (count >= thresh) {
        result.setCount(path, count);
        numRetained++;
      }
    }
    System.out.println("retained " + numRetained);
    return result;
  }

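  /**
   * Trains a {@link LexicalizedParser} on the training sections, optionally
   * compacts its unary and binary grammars with the configured compactor,
   * optionally saves the parser in text form, and evaluates it on the test
   * sections.
   */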
  public void testGrammarCompaction() {

    // these for testing against the markov 3rd order baseline

    // use the parser constructor to extract the grammars from the treebank,
    // reusing the Options configured from the command line in runTest()
    LexicalizedParser lp = LexicalizedParser.trainFromTreebank(path, new NumberRangeFileFilter(trainLow, trainHigh, true), op);

    // compact grammars
    if (compactor != null) {

      // extract a bunch of paths
      Timing.startTime();
      System.out.print("Extracting other paths...");
      allTrainPaths = extractPaths(path, trainLow, trainHigh, true);
      allTestPaths = extractPaths(path, testLow, testHigh, true);
      Timing.tick("done");

      // compact grammars
      Timing.startTime();
      System.out.print("Compacting grammars...");
      Pair<UnaryGrammar, BinaryGrammar> grammar = Generics.newPair(lp.ug, lp.bg);
      Triple<Index<String>, UnaryGrammar, BinaryGrammar> compactedGrammar = compactor.compactGrammar(grammar, allTrainPaths, allTestPaths, lp.stateIndex);
      lp.stateIndex = compactedGrammar.first();
      lp.ug = compactedGrammar.second();
      lp.bg = compactedGrammar.third();

      Timing.tick("done.");
    }

    if (asciiOutputPath != null) {
      lp.saveParserToTextFile(asciiOutputPath);
    }

    // test it
    Treebank testTreebank = op.tlpParams.testMemoryTreebank();
    testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true));
    System.out.println("Currently " + new Date());
    EvaluateTreebank evaluator = new EvaluateTreebank(lp);
    evaluator.testOnTreebank(testTreebank);
    System.out.println("Currently " + new Date());
  }

}

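/**
 * A unary rule keyed by category strings rather than state numbers, used by
 * the commented-out grammar-comparison code above to compare rules across
 * grammars with different state numberings.
 */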
class StringUnaryRule {
  public String parent;
  public String child;
  public double score;

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof StringUnaryRule)) {
      return false;
    }

    final StringUnaryRule stringUnaryRule = (StringUnaryRule) o;

    if (Double.compare(score, stringUnaryRule.score) != 0) {
      return false;
    }
    if (child != null ? !child.equals(stringUnaryRule.child) : stringUnaryRule.child != null) {
      return false;
    }
    if (parent != null ? !parent.equals(stringUnaryRule.parent) : stringUnaryRule.parent != null) {
      return false;
    }

    return true;
  }

  @Override
  public int hashCode() {
    int result;
    long temp;
    result = (parent != null ? parent.hashCode() : 0);
    result = 29 * result + (child != null ? child.hashCode() : 0);
    temp = Double.doubleToLongBits(score);
    result = 29 * result + (int) (temp ^ (temp >>> 32));
    return result;
  }

  @Override
  public String toString() {
    return "UR:::::" + parent + ":::::" + child + ":::::" + score;
  }

  public StringUnaryRule(String parent, String child, double score) {
    this.parent = parent;
    this.child = child;
    this.score = score;
  }
}

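/**
 * A binary rule keyed by category strings rather than state numbers; see
 * {@link StringUnaryRule}.
 */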
class StringBinaryRule {
  public String parent;
  public String leftChild;
  public String rightChild;
  public double score;

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof StringBinaryRule)) {
      return false;
    }

    final StringBinaryRule stringBinaryRule = (StringBinaryRule) o;

    if (Double.compare(score, stringBinaryRule.score) != 0) {
      return false;
    }
    if (leftChild != null ? !leftChild.equals(stringBinaryRule.leftChild) : stringBinaryRule.leftChild != null) {
      return false;
    }
    if (parent != null ? !parent.equals(stringBinaryRule.parent) : stringBinaryRule.parent != null) {
      return false;
    }
    if (rightChild != null ? !rightChild.equals(stringBinaryRule.rightChild) : stringBinaryRule.rightChild != null) {
      return false;
    }

    return true;
  }

  @Override
  public int hashCode() {
    int result;
    long temp;
    result = (parent != null ? parent.hashCode() : 0);
    result = 29 * result + (leftChild != null ? leftChild.hashCode() : 0);
    result = 29 * result + (rightChild != null ? rightChild.hashCode() : 0);
    temp = Double.doubleToLongBits(score);
    result = 29 * result + (int) (temp ^ (temp >>> 32));
    return result;
  }

  @Override
  public String toString() {
    return "BR:::::" + parent + ":::::" + leftChild + ":::::" + rightChild + ":::::" + score;
  }

  public StringBinaryRule(String parent, String leftChild, String rightChild, double score) {
    this.parent = parent;
    this.leftChild = leftChild;
    this.rightChild = rightChild;
    this.score = score;
  }
}