Package joshua.discriminative.training.oracle

Source Code of joshua.discriminative.training.oracle.OracleExtractionOnHGV3

package joshua.discriminative.training.oracle;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.BLEU;
import joshua.decoder.Support;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;


public class OracleExtractionOnHGV3 extends RefineHG<DPStateOracle> {
 
  SymbolTable symbolTable;
 
  protected int srcSentLen = 0;
  EquivLMState equiv;
 
  protected static boolean doLocalNgramClip = false;
  protected static boolean maitainLengthState = false;
  protected static int bleuOrder = 4;
 
 
  //== the way to compute the effective reference length
  boolean useShortestRef = false; //TODO
 
  //derived from the reference sentence(s)
  protected HashMap<String, Integer> refNgramsTbl = new HashMap<String, Integer>();
  protected double refSentLen = 0;
 

 
  public OracleExtractionOnHGV3(SymbolTable symbolTable ) {
    this.symbolTable = symbolTable;
   
    this.equiv = new EquivLMState(this.symbolTable, bleuOrder);
  }

 
//  find the oracle hypothesis in the nbest list
  public Object[] oracleExtractOnNbest(KBestExtractor kbest_extractor, HyperGraph hg, int n, boolean do_ngram_clip, String ref_sent){
   
    if(hg.goalNode==null)
      return null;
    kbest_extractor.resetState();       
    int next_n=0;
    double orc_bleu=-1;
    String orc_sent=null;
    while(true){
      String hyp_sent = kbest_extractor.getKthHyp(hg.goalNode, ++next_n, -1, null, null);//get the next-best hypothesis string from the goal node
      //System.out.println(hyp_sent);
      if(hyp_sent==null || next_n > n) break;
      double t_bleu = computeSentenceBleu(this.symbolTable, ref_sent, hyp_sent, do_ngram_clip, 4);
      if(t_bleu>orc_bleu){
        orc_bleu = t_bleu;
        orc_sent = hyp_sent;
      }     
    }
    System.out.println("Oracle sent in nbest: " + orc_sent);
    System.out.println("Oracle bleu in nbest: " + orc_bleu);
    Object[] res = new Object[2];
    res[0]=orc_sent;
    res[1]=orc_bleu;
    return res;
  }
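 
  /* A minimal usage sketch (hypothetical; kbestExtractor, hyperGraph and refSentence are
   * illustrative names for objects assumed to exist, and n=300 is an arbitrary list size):
   *
   *   Object[] res = oracleExtractOnNbest(kbestExtractor, hyperGraph, 300, false, refSentence);
   *   String oracleSent = (String) res[0];  // best-BLEU hypothesis among the top-n strings
   *   double oracleBleu = (Double) res[1];  // its sentence-level BLEU score
   */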
 
 
  public HyperGraph oracleExtractOnHG(HyperGraph hg, int srcSentLenIn,  int baselineLMOrder, String refSentStr){
//TODO: baselineLMOrder
   
    srcSentLen = srcSentLenIn;
   
    //== build the reference ngram table and the effective reference length
    int[] refWords = this.symbolTable.addTerminals(refSentStr.split("\\s+"));
   
    refSentLen = refWords.length;       
    refNgramsTbl.clear();
    getNgrams(refNgramsTbl, bleuOrder, refWords, false);
   
    equiv.setupPrefixAndSurfixTbl(refNgramsTbl);
   
    HyperGraph res= splitHG(hg);
   
    return res;
  }
 
  public HyperGraph oracleExtractOnHG(HyperGraph hg, int srcSentLenIn,  int baselineLMOrder, String[] refSentStrs){
//    TODO: baselineLMOrder
       
    srcSentLen = srcSentLenIn;
    //== build the reference ngram tables and the effective reference length
    int[] refLens = new int[refSentStrs.length];
    ArrayList<HashMap<String, Integer>> listRefNgramTbl = new ArrayList<HashMap<String, Integer>>();
   
    for(int i =0; i<refSentStrs.length; i++){
      int[] refWords = this.symbolTable.addTerminals(refSentStrs[i].split("\\s+"));
      refLens[i] = refWords.length;   
      HashMap<String, Integer> tRefNgramsTbl = new HashMap<String, Integer>();
      getNgrams(tRefNgramsTbl, bleuOrder, refWords, false);
      listRefNgramTbl.add(tRefNgramsTbl);     
    }
    refSentLen = BLEU.computeEffectiveLen(refLens, useShortestRef);
    refNgramsTbl =  BLEU.computeMaxRefCountTbl(listRefNgramTbl);
   
    equiv.setupPrefixAndSurfixTbl(refNgramsTbl);
   
    HyperGraph res= splitHG(hg);
   
    return res;
  }
     
 
   
 
  private double computeAvgLen(int spanLen, int srcSentLen, double refSentLen){
    //expected reference length for this span: the full reference length scaled by the fraction of the source covered
    return (spanLen>=srcSentLen) ? refSentLen : spanLen*refSentLen*1.0/srcSentLen;
  }
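 
  /* Example (illustrative numbers): with spanLen=5, srcSentLen=20 and refSentLen=24.0,
   * the expected reference length for the span is 5*24.0/20 = 6.0; a span covering the
   * whole source sentence (spanLen>=srcSentLen) simply uses the full reference length. */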


 

  @Override
  protected HyperEdge createNewHyperEdge(HyperEdge originalEdge, List<HGNode> antVirtualItems, DPStateOracle dps) {
    /**Compared with the original edge, two things change:
     * (1) the list of antecedent nodes
     * (2) the transition logProb, which becomes the BLEU gain instead of the model logProb
     * */
    return new HyperEdge(originalEdge.getRule(), dps.bestDerivationLogP, null, antVirtualItems, originalEdge.getSourcePath());
  }

 


  /*This procedure
   * (1) creates a new hyperedge (based on curEdge and ant_virtual_item)
   * (2) finds whether an item can contain this hyperedge (based on virtualItemSigs, a hashmap specific to a parent_item)
   *   (2.1) if yes, adds the hyperedge,
   *   (2.2) otherwise
   *     (2.2.1) creates a new item
   *     (2.2.2) and adds the item into virtualItemSigs
   **/

  protected  DPStateOracle computeState(HGNode originalParentItem, HyperEdge originalEdge, List<HGNode> antVirtualItems){
   
    double refLen = computeAvgLen(originalParentItem.j-originalParentItem.i, srcSentLen, refSentLen);
   
    //=== hyperedges under the "goal item" do not have a rule
    if(originalEdge.getRule()==null){
      if(antVirtualItems.size()!=1){
        System.out.println("error: hyperedge under the goal item has more than one antecedent item");
        System.exit(0);
      }
      double bleu = antVirtualItems.get(0).bestHyperedge.bestDerivationLogP;
      return  new DPStateOracle(0, null, null, null, bleu);//no DPState at all
      
    }
   
   
    ComputeOracleStateResult lmState = computeLMState(originalEdge, antVirtualItems);
     
    int hypLen = lmState.numNewWordsAtEdge;
   
    int[] numNgramMatches = new int[bleuOrder];
   
    Iterator<String> iter = lmState.newNgramsTbl.keySet().iterator();
    while(iter.hasNext()){
      String ngram = iter.next();
      int finalCount = lmState.newNgramsTbl.get(ngram);
      if(doLocalNgramClip)
        numNgramMatches[ngram.split("\\s+").length-1] += Support.findMin(finalCount, refNgramsTbl.get(ngram));
      else
        numNgramMatches[ngram.split("\\s+").length-1] += finalCount; //do not do any clipping
    }
   
    if(antVirtualItems!=null){
      for(int i=0; i<antVirtualItems.size(); i++){
        DPStateOracle antDPState = (DPStateOracle)((RefinedNode)antVirtualItems.get(i)).dpState;
        hypLen += antDPState.bestLen;
        for(int t=0; t<bleuOrder; t++)
          numNgramMatches[t] += antDPState.ngramMatches[t];
      }
    }
    double bleu = computeBleu(hypLen, refLen, numNgramMatches, bleuOrder);

    return  new DPStateOracle(hypLen, numNgramMatches, lmState.leftEdgeWords, lmState.rightEdgeWords, bleu);
  }
 

  protected  ComputeOracleStateResult computeLMState(HyperEdge dt, List<HGNode> antVirtualItems){ 
   
    //=== hyperedges under the "goal item" do not have a rule
    if(dt.getRule()==null){
      /**TODO: we did not consider <s> and </s> here
       * */
      HashMap<String, Integer> finalNewGramCounts = null;
      int hypLen = 0;
      return  new ComputeOracleStateResult(finalNewGramCounts, null, null, hypLen);
    }
   
    //======== hyperedges *not* under the "goal item"
    HashMap<String, Integer> newGramCounts = new HashMap<String, Integer>();//new ngrams created by this combination
    HashMap<String, Integer> oldNgramCounts = new HashMap<String, Integer>();//ngrams that have already been counted inside the antecedent items
    int hypLen =0;
    int[] enWords = dt.getRule().getEnglish();
   
   
      List<Integer> words= new ArrayList<Integer>();
      List<Integer> leftStateSequence =  new ArrayList<Integer>();
      List<Integer> rightStateSequence = new ArrayList<Integer>();
 
     
      //==== get left_state_sequence, right_state_sequence, total_hyp_len, num_ngram_match
      for(int c=0; c<enWords.length; c++){
       
        int word = enWords[c];
        if(symbolTable.isNonterminal(word)==true){
         
          int index = this.symbolTable.getTargetNonterminalIndex(word);
         
          DPStateOracle antState = (DPStateOracle)((RefinedNode)antVirtualItems.get(index)).dpState;//TODO         
                      
          List<Integer> leftContext = antState.leftLMState;
          List<Integer> rightContext = antState.rightLMState;
         
          for(int t : leftContext){//always have l_context
            words.add(t);
            if(leftStateSequence!=null && leftStateSequence.size()<bleuOrder-1)
              leftStateSequence.add(t);
          }
          getNgrams(oldNgramCounts, bleuOrder, leftContext, true);  
         
          if(rightContext.size()>=bleuOrder-1){//the right and left are NOT overlapping       
            getNgrams(newGramCounts, bleuOrder, words, true);
            getNgrams(oldNgramCounts, bleuOrder, rightContext, true);
            words.clear();//start a new chunk   
            if(rightStateSequence!=null)
              rightStateSequence.clear();
            for(int t : rightContext)
              words.add(t);           
          }
          if(rightStateSequence!=null)
            for(int t : rightContext)
              rightStateSequence.add(t);
        }else{
          words.add(word);
          hypLen++;
          if(leftStateSequence!=null && leftStateSequence.size()<bleuOrder-1)
            leftStateSequence.add(word);
          if(rightStateSequence!=null)
            rightStateSequence.add(word);
        }
      }
      getNgrams(newGramCounts, bleuOrder, words, true);
   
      //=== now deduct ngram counts
      HashMap<String, Integer> finalNewGramCounts = new HashMap<String, Integer>();
      Iterator<String> iter = newGramCounts.keySet().iterator();
      //System.out.println("new size= " + newGramCounts.size());
      //System.out.println("old size= " + oldNgramCounts.size());
      while(iter.hasNext()){
        String ngram = iter.next();
        if(refNgramsTbl.containsKey(ngram)){//TODO
          int finalCount = newGramCounts.get(ngram);
          if(oldNgramCounts.containsKey(ngram)){
            finalCount -= oldNgramCounts.get(ngram);
            if(finalCount<0){
              System.out.println("error: negative count for ngram: " + ngram + "; new: " + newGramCounts.get(ngram) + "; old: " + oldNgramCounts.get(ngram));
              System.exit(0);
            }
          }
          if(finalCount>0){
            finalNewGramCounts.put(ngram, finalCount);             
          }
        }
      }
     
      List<Integer> leftLMState = equiv.getLeftEquivState(leftStateSequence);
      List<Integer> rightLMState =  equiv.getRightEquivState(rightStateSequence);    
   
      ComputeOracleStateResult res = new  ComputeOracleStateResult(finalNewGramCounts, leftLMState, rightLMState, hypLen);
      //res.printInfo();
     
    return  res;
  }
 
 
 
 
 
 
//  =================================================================================================
  //==================== BLEU-related functions ==========================================
  //=================================================================================================
  //TODO: consider merging with joshua.decoder.BLEU
 
  //do_ngram_clip: consider global n-gram clip
  public  double computeSentenceBleu(SymbolTable p_symbol, String ref_sent, String hyp_sent, boolean do_ngram_clip, int bleu_order){
    int[] numeric_ref_sent = p_symbol.addTerminals(ref_sent.split("\\s+"));
    int[] numeric_hyp_sent = p_symbol.addTerminals(hyp_sent.split("\\s+"));
    return computeSentenceBleu(numeric_ref_sent, numeric_hyp_sent, do_ngram_clip, bleu_order);   
  }
 
 
  public  double computeSentenceBleu( int[] ref_sent, int[] hyp_sent, boolean do_ngram_clip, int bleu_order){
    double res_bleu = 0;
    int order = bleu_order;
    HashMap<String, Integer>  ref_ngram_tbl = new HashMap<String, Integer> ();
    getNgrams(ref_ngram_tbl, order, ref_sent,false);
    HashMap<String, Integer>  hyp_ngram_tbl = new HashMap<String, Integer> ();
    getNgrams(hyp_ngram_tbl, order, hyp_sent,false);
   
    int[] num_ngram_match = new int[order];
    for(Iterator<String> it = hyp_ngram_tbl.keySet().iterator(); it.hasNext();){
      String ngram = it.next();
      if(ref_ngram_tbl.containsKey(ngram)){
        if(do_ngram_clip)
          num_ngram_match[ngram.split("\\s+").length-1] += Support.findMin(ref_ngram_tbl.get(ngram), hyp_ngram_tbl.get(ngram)); //ngram clip
        else
          num_ngram_match[ngram.split("\\s+").length-1] += hyp_ngram_tbl.get(ngram);//without ngram count clipping
      }
    }
    res_bleu = computeBleu(hyp_sent.length, ref_sent.length, num_ngram_match, bleu_order);
    //System.out.println("hyp_len: " + hyp_sent.length + "; ref_len:" + ref_sent.length + "; bleu: " + res_bleu +" num_ngram_matches: " + num_ngram_match[0] + " " +num_ngram_match[1]+
    //    " " + num_ngram_match[2] + " " +num_ngram_match[3]);

    return res_bleu;
  }
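 
  /* Illustration of do_ngram_clip (word strings shown instead of the integer ids actually used):
   * for hyp "the the the cat" against ref "the cat sat", clipped unigram matches are
   * min(3,1)+min(1,1)=2, while unclipped matches are 3+1=4; clipping prevents a hypothesis
   * from being rewarded for repeating a reference word. */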
   
 
//  sentence-level BLEU: BLEU = bp * prec, where prec = exp( (1/bleuOrder) * sum_t log(prec_t) )
  public static double computeBleu(int hypLen, double refLen, int[] numNgramMatches, int bleuOrder){
    if(hypLen<=0 || refLen<=0){
      System.out.println("error: ref or hyp is zero len");
      System.exit(0);
    }
   
    double res=0;   
    double wt = 1.0/bleuOrder;
    double prec = 0;
    double smoothFactor=1.0;
    for(int t=0; t<bleuOrder && t<hypLen; t++){
      if(numNgramMatches[t]>0)
        prec += wt*Math.log(numNgramMatches[t]*1.0/(hypLen-t));
      else{
        smoothFactor *= 0.5;//TODO
        prec += wt*Math.log(smoothFactor/(hypLen-t));
      }
    }
   
    double bp = (hypLen>=refLen) ? 1.0 : Math.exp(1-refLen/hypLen);
   
    res = bp*Math.exp(prec);
    //System.out.println("hyp_len: " + hyp_len + "; ref_len:" + ref_len + "prec: " + Math.exp(prec) + "; bp: " + bp + "; bleu: " + res);
    return res;
  }
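 
  /* Worked example for computeBleu (illustrative numbers):
   *   hypLen=10, refLen=12, numNgramMatches={8,5,3,2}, bleuOrder=4
   *   prec = exp( 0.25*(log(8/10)+log(5/9)+log(3/8)+log(2/7)) ) ≈ 0.467
   *   bp   = exp(1 - 12/10) ≈ 0.819   (brevity penalty, since hypLen < refLen)
   *   BLEU = bp * prec ≈ 0.38
   * If some numNgramMatches[t] were zero, the 0.5^k smoothing above would replace the
   * zero numerator instead of driving the whole geometric mean to zero. */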




  //=================================================================================================
  //==================== ngram extraction functions ==========================================
  //=================================================================================================
  protected void getNgrams(HashMap<String, Integer> tbl, int order, int[] wrds, boolean ignoreNullEquivSymbol){
   
    for(int i=0; i<wrds.length; i++)
      for(int j=0; j<order && j+i<wrds.length; j++){//ngram: [i,i+j]
        boolean contain_null=false;//TODO: the null-symbol check is not implemented here, so ignoreNullEquivSymbol currently has no effect
        StringBuffer ngram = new StringBuffer();
        for(int k=i; k<=i+j; k++){
         
          ngram.append(wrds[k]);
          if(k<i+j) ngram.append(" ");
        }
        if(ignoreNullEquivSymbol && contain_null)
          continue;//skip this ngram
        String ngram_str = ngram.toString();
        if(tbl.containsKey(ngram_str))
          tbl.put(ngram_str,  tbl.get(ngram_str)+1);
        else
          tbl.put(ngram_str, 1);
      }
  }
 
//  accumulate ngram counts into tbl
  protected void getNgrams(HashMap<String, Integer>  tbl, int order, List<Integer> wrds, boolean ignoreNullEquivSymbol){
   
    for(int i=0; i<wrds.size(); i++)
      for(int j=0; j<order && j+i<wrds.size(); j++){//ngram: [i,i+j]
        boolean contain_null=false;//TODO: the null-symbol check is not implemented here, so ignoreNullEquivSymbol currently has no effect
        StringBuffer ngram = new StringBuffer();
        for(int k=i; k<=i+j; k++){
          int t_wrd = wrds.get(k);
         
          ngram.append(t_wrd);
          if(k<i+j) ngram.append(" ");
        }
        if(ignoreNullEquivSymbol && contain_null)
          continue;//skip this ngram
        String ngram_str = ngram.toString();
        if(tbl.containsKey(ngram_str))
          tbl.put(ngram_str, tbl.get(ngram_str)+1);
        else
          tbl.put(ngram_str, 1);
      }
  }
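 
  /* Example (illustrative ids): for the word-id sequence [7, 9, 7] with order=2 the table
   * accumulates {"7"->2, "9"->1, "7 9"->1, "9 7"->1}, i.e. unigrams and bigrams keyed by
   * their space-joined id strings with their counts. */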
 
 

}
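 
A minimal end-to-end usage sketch follows; it is not part of the original file. The driver class and method names below are hypothetical, and the KBestExtractor is assumed to be constructed elsewhere as required by the Joshua version in use.
 
package joshua.discriminative.training.oracle;
 
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
 
public class OracleExtractionOnHGV3Example {
 
  /** Rescore a decoded hypergraph with BLEU gains and return the resulting oracle string. */
  public static String extractOracleString(SymbolTable symbolTable, HyperGraph hg,
      KBestExtractor kbestExtractor, int srcSentLen, int baselineLMOrder, String refSent) {
    OracleExtractionOnHGV3 oracle = new OracleExtractionOnHGV3(symbolTable);
    //refine the hypergraph so that each edge's transition score becomes its (approximate) BLEU gain
    HyperGraph oracleHG = oracle.oracleExtractOnHG(hg, srcSentLen, baselineLMOrder, refSent);
    kbestExtractor.resetState();
    //the 1-best derivation of the refined hypergraph is the (approximate) maximum-BLEU hypothesis
    return kbestExtractor.getKthHyp(oracleHG.goalNode, 1, -1, null, null);
  }
}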