// Package: joshua.discriminative.training.oracle
// Source code of joshua.discriminative.training.oracle.ApproximateFilterHGByOneString

package joshua.discriminative.training.oracle;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.state_maintenance.NgramDPState;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;


/* Given a general hypergraph and a string (known to be contained in the hypergraph), this
* class returns a filtered hypergraph that contains only the derivations yielding that string
* The method is not exact, in other words, it may actually contains derivations that do not yield the given string
* Or, the recall (of the derivations we are seeking) is perfect, but the precision is not
* */


/**This programm assume the transition_cost in the input hypergraph is properly set!!!!!!!!!!!!!!!!!
* */

public class ApproximateFilterHGByOneString {
 
  HashMap<HGNode, HGNode> tbl_processed_items =  new HashMap<HGNode, HGNode>();//key: item; value: true if the item should be filtered out
  HashMap<Integer, ArrayList<Integer>> tbl_unigram_start_pos = new HashMap<Integer, ArrayList<Integer>>();//ref word position is indexed from zero
  int[] ref_sent_wrds = null;
  //String ref_sent_string = null;
  SymbolTable p_symbol;
  int ngramStateID = 0;//TODO
  int baseline_lm_order = 3;//TODO
 
  public ApproximateFilterHGByOneString(SymbolTable symbol_, int lm_feat_id_, int baseline_lm_order_ ){
    p_symbol = symbol_;
    ngramStateID = lm_feat_id_;
    baseline_lm_order = baseline_lm_order_;
  }
 
 
//############### filter hg ##### 
  public HyperGraph approximate_filter_hg(HyperGraph hg, int[] ref_sent_wrds_in){  
    tbl_processed_items.clear();
    tbl_unigram_start_pos.clear();
    ref_sent_wrds = ref_sent_wrds_in;
   
    //debug
    /*StringBuffer tem = new StringBuffer();
    for(int i=0; i<ref_sent_wrds.length; i++){
       tem.append(ref_sent_wrds[i]);
       if(i!=ref_sent_wrds.length-1) tem.append(" ");
    }
    ref_sent_string = tem.toString();*/
    //end
   
    create_start_pos_index(ref_sent_wrds);
   
    System.out.println("goal_item has " + hg.goalNode.hyperedges.size() );
    HGNode new_goal_node = filter_item(hg.goalNode);
    if(new_goal_node==null){
      System.out.println("The hypergraph has been fully filterd out, must be wrong");
      System.exit(0);
    }
    System.out.println("Finished filtering the hypergraph");
    HyperGraph filtered_hg =  new HyperGraph(new_goal_node, -1, -1, hg.sentID, hg.sentLen);
    return filtered_hg;//filtered hg
 
 
  public void clear_state(){
    tbl_processed_items.clear();
    tbl_unigram_start_pos.clear();
  }
 
 
 
  //return true if the node should be filtered out
  private HGNode filter_item(HGNode it){
    if(tbl_processed_items.containsKey(it))//processed before
      return tbl_processed_items.get(it);
    else{
      boolean should_filter = true;//default: filter out
     
      ArrayList<HyperEdge> l_survive_deductions= new   ArrayList<HyperEdge>();
      HyperEdge best_survive_deduction = null;
     
      //### recursive call on each deduction, and update l_deductions and best_deduction
      for(HyperEdge dt : it.hyperedges){
        HyperEdge new_edge = filter_deduction(dt);//deduction-specifc operation
        if(new_edge!=null){//survive
          should_filter = false;//the item survives as long as at least one hyperedge survives
          l_survive_deductions.add(new_edge);
          if(best_survive_deduction ==null || best_survive_deduction.bestDerivationLogP>new_edge.bestDerivationLogP)
            best_survive_deduction = new_edge;
        }
      }
     
      HGNode new_node = null;
      if(should_filter==false){//survive
        new_node = new HGNode(it.i, it.j, it.lhs, l_survive_deductions, best_survive_deduction, it.getDPStates());
      }

      tbl_processed_items.put(it,new_node);
      //System.out.println("res=" + res+ "; old_n_deducts=" + old_num_deductions + "; new="+ it.l_deductions.size());
      return new_node;
    }
  } 
 
  //return true if the hyperedge should be filtered out
  private HyperEdge filter_deduction(HyperEdge dt){
    //### first see if the combinations of rule with the items yield valid string
    if(filter_deduction_helper(dt)==true){//filter out
      return null;
    }else{
      boolean should_filter = false;//survive
      //### recursive call on each ant item, to see if each item yield valid string
      //make sure if the transition_cost is properly set in the disk hypergraph; Also, we cannot force the function to compute the transition cost!!!!!
      double bestLogP = dt.getTransitionLogP(false);
      ArrayList<HGNode> l_ant_items = null;
      if(dt.getAntNodes()!=null){
        l_ant_items = new ArrayList<HGNode>();
        for(HGNode ant_it : dt.getAntNodes()){
          HGNode new_node = filter_item(ant_it);
          if(new_node==null){//the edge should be filtered out as long as one item is filtered out
            should_filter = true;
            break;
          }else{
            bestLogP += new_node.bestHyperedge.bestDerivationLogP;
            l_ant_items.add(new_node);
          }
        }
      }
     
      HyperEdge new_edge = null;
      if(should_filter==false){
        new_edge = new HyperEdge(dt.getRule(), bestLogP, dt.getTransitionLogP(false), l_ant_items, null);
      }
      //System.out.println("hyperEdge filter="+res);
      return new_edge;
    }
  }
 
 
//   return true if the hyperedge should be filtered out
  // non-recursive
  private boolean filter_deduction_helper(HyperEdge dt){
    Rule rl = dt.getRule();
    if(rl!=null){//not hyperedges under goal item
      int[] en_words = dt.getRule().getEnglish();
      ArrayList<Integer> words= new ArrayList<Integer>();//a continous sequence of words due to combination; the sequence stops whenever the right-lm-state jumps in (i.e., having eclipsed words)       
      for(int c=0; c<en_words.length; c++){
          int c_id = en_words[c];
          if(p_symbol.isNonterminal(c_id)==true){
            //## get the left and right context
            int index= p_symbol.getTargetNonterminalIndex(c_id);
            HGNode ant_item = (HGNode) dt.getAntNodes().get(index);
            NgramDPState state     = (NgramDPState) ant_item.getDPState(this.ngramStateID);
          List<Integer>   l_context = state.getLeftLMStateWords();
          List<Integer>   r_context = state.getRightLMStateWords();
          if (l_context.size() != r_context.size()) {
            System.out.println("LMModel>>lookup_words1_equv_state: left and right contexts have unequal lengths");
            System.exit(1);
          }
         
          for(int t : l_context)//always have l_context
              words.add(t);               
           
            if(r_context.size()>=baseline_lm_order-1){//the right and left are NOT overlapping
              if(match_a_span(words)==false)//no match
                return true;//filter out
                   
              words.clear();//start a new chunk; the sequence stops whenever the right-lm-state jumps in (i.e., having eclipsed words) 
              for(int t : r_context)
                words.add(t);           
            }
          }else{
            words.add(c_id);
          }
        }
      if(words.size()>0){
        if(match_a_span(words)==false)//no match
          return true;//filter out
      }
    }else{//hyperedges under goal item
      if(dt.getAntNodes().size()!=1){System.out.println("error deduction under goal item have more than one item"); System.exit(0);}
      HGNode ant_item = (HGNode) dt.getAntNodes().get(0);
      NgramDPState state     = (NgramDPState) ant_item.getDPState(this.ngramStateID);
      List<Integer>   l_context = state.getLeftLMStateWords();
      List<Integer>   r_context = state.getRightLMStateWords();
      if(matchLeftOrRightMostSpan(l_context, true)==false ||
         matchLeftOrRightMostSpan(r_context, false)==false )//the left-most or right-most word does not match
        return true;
    }
    return false;
  }
 
  protected  boolean matchLeftOrRightMostSpan(List<Integer> words, boolean is_left_most){
    boolean res =true;
    if(words.size()>ref_sent_wrds.length)
      res = false;
    else{
      for(int j=0; j<words.size(); j++){
        if(is_left_most){
          if(words.get(j)!=ref_sent_wrds[j]) {res=false; break;}
        }else{//right most
          if(words.get(j)!=ref_sent_wrds[ref_sent_wrds.length-words.size()+j]) {res=false; break;}
        }
      }
    }
    //System.out.println("string: " + p_symbol.getWords(words) + "; res=" + res + "; left=" + is_left_most);
   
    /*if(match_left_or_right_most_span2(words,is_left_most)!=res){
      System.out.println("match is not correct");
      System.exit(0);
    }*/
    return res;
  }
 
 
 
  //This procedure check wheter l_words matches span in references
  protected  boolean match_a_span(ArrayList<Integer> l_words){
    boolean res = false; //no match
    ArrayList<Integer> start_positions = tbl_unigram_start_pos.get(l_words.get(0));
    if(start_positions!=null){
      for(int start_pos : start_positions){//for each possible span
        boolean is_match = true;
        if(start_pos+l_words.size()>ref_sent_wrds.length){
          is_match = false;
        }else{
          for(int j=0; j<l_words.size(); j++){     
            if(l_words.get(j)!=ref_sent_wrds[start_pos+j]){
              is_match = false;
              break;
            }
          }
        }
       
        if(is_match == true){//at least matching one span
          res = true;
          break;
        }
      } 
    }
    //System.out.println("string: " + p_symbol.getWords(l_words) + "; res=" + res);
    /*if(match_a_span2(l_words)!=res){
      System.out.println("match is not correct");
      System.exit(0);
    }*/
    return res;
  }
 
  /*
  protected  boolean match_a_span2(ArrayList l_words){
    StringBuffer tem = new StringBuffer();
    for(int i=0; i<l_words.size(); i++){
       tem.append(l_words.get(i));
       if(i!=l_words.size()-1) tem.append(" ");
    }
    if(ref_sent_string.indexOf(tem.toString())==-1)
      return false;
    else
      return true;   
  }
  */
 
  /*
  //this function is not right; the length in a string is in termos of number of unicode-16, while a chinese character may costs 32bit
  protected  boolean match_left_or_right_most_span2(int[] words, boolean is_left_most){
    StringBuffer tem = new StringBuffer();
    for(int i=0; i<words.length; i++){
       tem.append(words[i]);
       if(i!=words.length-1) tem.append(" ");
    }
    if(is_left_most==true && ref_sent_string.indexOf(tem.toString())!=0)
      return false;
   
    if(is_left_most==false && ref_sent_string.lastIndexOf(tem.toString())!=(ref_sent_string.length()-words.length)){
      System.out.println("val0="+ ref_sent_string.length() +"; va11=" + ref_sent_string.lastIndexOf(tem.toString()) + "; val2="+ (ref_sent_string.length()-words.length));
      return false;
    }
    return true;   
  }
  */
  private  void create_start_pos_index(int[] ref_sent_wrds_in){
    tbl_unigram_start_pos.clear();//ref word position is indexed from zero
    for(int i=0; i<ref_sent_wrds_in.length; i++){
      ArrayList<Integer> l_start_pos = tbl_unigram_start_pos.get(ref_sent_wrds_in[i]);
      if(l_start_pos==null){
        l_start_pos = new ArrayList<Integer>();
        l_start_pos.add(i);
        tbl_unigram_start_pos.put(ref_sent_wrds_in[i], l_start_pos);
      }else
        l_start_pos.add(i);
    }
    //System.out.println("tbl is: " + tbl_unigram_start_pos.toString());
  }
   
}
// (Removed scraped-page footer: site navigation and copyright boilerplate from
// www.massapi.com that was appended after the class and would not compile.)