Package joshua.decoder.chart_parser

Source Code of joshua.decoder.chart_parser.CubePruneCombiner$CubePruneState

package joshua.decoder.chart_parser;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.PriorityQueue;

import joshua.decoder.JoshuaConfiguration;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.ff.state_maintenance.StateComputer;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;

public class CubePruneCombiner implements Combiner{
 
  private List<FeatureFunction> featureFunctions;
  private List<StateComputer> stateComputers;
 
  public CubePruneCombiner(List<FeatureFunction> featureFunctions, List<StateComputer> stateComputers){
    this.featureFunctions = featureFunctions;
    this.stateComputers = stateComputers;
  }
 
  //BUG:???????????????????? CubePrune will depend on relativeThresholdPruning, but  cell.beamPruner can be null ????????????????
 
 
 

  public void addAxioms(Chart chart, Cell cell, int i, int j, List<Rule> rules, SourcePath srcPath) {
    for (Rule rule : rules) {
      addAxiom(chart, cell, i, j, rule, srcPath);
    }
  }



  public void addAxiom(Chart chart, Cell cell, int i, int j, Rule rule, SourcePath srcPath) {
    cell.addHyperEdgeInCell(
        new ComputeNodeResult(this.featureFunctions, rule, null, i, j, srcPath, stateComputers, chart.segmentID),
        rule, i, j, null, srcPath, false);
  }

 
  /** Add complete Items in Chart pruning inside this function */
  // TODO: our implementation do the prunining for each DotItem
  //       under each grammar, not aggregated as in the python
  //       version
  // TODO: the implementation is little bit different from
  //       the description in Liang'2007 ACL paper
  public void combine(Chart chart, Cell cell, int i, int j, List<SuperNode> superNodes, List<Rule> rules, int arity, SourcePath srcPath) {
   
    //combinations: rules, antecent nodes
    //in the paper, combinationHeap is called cand[v]
    PriorityQueue<CubePruneState> combinationHeap =  new PriorityQueue<CubePruneState>();
   
    // rememeber which state has been explored
    HashMap<String,Integer> cubeStateTbl = new HashMap<String,Integer>();
   
    if (null == rules || rules.size() <= 0) {
      return;
    }
   
    //===== seed the heap with best node
    Rule currentRule = rules.get(0);
    List<HGNode> currentAntNodes = new ArrayList<HGNode>();
    for (SuperNode si : superNodes) {
      // TODO: si.nodes must be sorted
      currentAntNodes.add(si.nodes.get(0));
    }
    ComputeNodeResult result =  new ComputeNodeResult(featureFunctions, currentRule, currentAntNodes, i, j, srcPath, stateComputers, chart.segmentID);
   
    int[] ranks = new int[1+superNodes.size()]; // rule, ant items
    for (int d = 0; d < ranks.length; d++) {
      ranks[d] = 1;
    }
   
    CubePruneState bestState =  new CubePruneState(result, ranks, currentRule, currentAntNodes);
    combinationHeap.add(bestState);
    cubeStateTbl.put(bestState.getSignature(),1);
    // cube_state_tbl.put(best_state,1);
   
    //====== extend the heap
    Rule   oldRule = null;
    HGNode oldItem = null;
    int    tem_c   = 0;
    while (combinationHeap.size() > 0) {
     
      //========== decide if the top in the heap should be pruned
      tem_c++;
      CubePruneState curState = combinationHeap.poll();
      currentRule = curState.rule;
      currentAntNodes = new ArrayList<HGNode>(curState.antNodes); // critical to create a new list
      //cube_state_tbl.remove(cur_state.get_signature()); // TODO, repeat
      cell.addHyperEdgeInCell(curState.nodeStatesTbl, curState.rule, i, j, curState.antNodes, srcPath, false); // pre-pruning inside this function
     
      //if the best state is pruned, then all the remaining states should be pruned away
      if (curState.nodeStatesTbl.getExpectedTotalLogP() < cell.beamPruner.getCutoffLogP() - JoshuaConfiguration.fuzz1) {
        //n_prepruned += heap_cands.size();
        chart.nPreprunedFuzz1 += combinationHeap.size();
        break;
      }
     
      //========== extend the curState, and add the candidates into the heap
      for (int k = 0; k < curState.ranks.length; k++) {
       
        //GET new_ranks
        int[] newRanks = new int[curState.ranks.length];
        for (int d = 0; d < curState.ranks.length; d++) {
          newRanks[d] = curState.ranks[d];
        }
        newRanks[k] = curState.ranks[k] + 1;
       
        String new_sig = CubePruneState.getSignature(newRanks);
       
        if (cubeStateTbl.containsKey(new_sig) // explored before
        || (k == 0 && newRanks[k] > rules.size())
        || (k != 0 && newRanks[k] > superNodes.get(k-1).nodes.size())
        ) {
          continue;
        }
       
        if (k == 0) { // slide rule
          oldRule = currentRule;
          currentRule = rules.get(newRanks[k]-1);
        } else { // slide ant
          oldItem = currentAntNodes.get(k-1); // conside k == 0 is rule
          currentAntNodes.set(k-1,
            superNodes.get(k-1).nodes.get(newRanks[k]-1));
        }
       
        CubePruneState tState = new CubePruneState(
            new ComputeNodeResult(featureFunctions, currentRule,
                currentAntNodes, i, j, srcPath, stateComputers, chart.segmentID),
          newRanks, currentRule, currentAntNodes);
       
        // add state into heap
        cubeStateTbl.put(new_sig,1);       
        if (result.getExpectedTotalLogP() > cell.beamPruner.getCutoffLogP() - JoshuaConfiguration.fuzz2) {
          combinationHeap.add(tState);
        } else {
          //n_prepruned += 1;
          chart.nPreprunedFuzz2 += 1;
        }
       
        // recover
        if (k == 0) { // rule
          currentRule = oldRule;
        } else { // ant
          currentAntNodes.set(k-1, oldItem);
        }
      }
    }
   
  }
 
 

 
//  ===============================================================
//   CubePruneState class
//  ===============================================================
    private static class CubePruneState implements Comparable<CubePruneState> {
      int[]             ranks;
      ComputeNodeResult nodeStatesTbl;
      Rule              rule;
      List<HGNode> antNodes;
     
      public CubePruneState(ComputeNodeResult state, int[] ranks, Rule rule,
          List<HGNode> antecedents)
      {
        this.nodeStatesTbl = state;
        this.ranks           = ranks;
        this.rule            = rule;
        // create a new vector is critical, because
        // currentAntecedents will change later
        this.antNodes = new ArrayList<HGNode>(antecedents);
      }
     
     
      private static String getSignature(int[] ranks2) {
        StringBuffer sb = new StringBuffer();
        if (null != ranks2) {
          for (int i = 0; i < ranks2.length; i++) {
            sb.append(' ').append(ranks2[i]);
          }
        }
        return sb.toString();
      }
     
     
      private String getSignature() {
        return getSignature(ranks);
      }
     
     
      /**
       * Compares states by ExpectedTotalLogP, allowing states
       * to be sorted according to their inverse order (high-prob first).
       */
      public int compareTo(CubePruneState another) {
        if (this.nodeStatesTbl.getExpectedTotalLogP() < another.nodeStatesTbl.getExpectedTotalLogP()) {
          return 1;
        } else if (this.nodeStatesTbl.getExpectedTotalLogP() == another.nodeStatesTbl.getExpectedTotalLogP()) {
          return 0;
        } else {
          return -1;
        }
      }
    }
   

}
TOP

Related Classes of joshua.decoder.chart_parser.CubePruneCombiner$CubePruneState

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle, Inc. Contact coftware#gmail.com.