Package joshua.discriminative.training.risk_annealer.hypergraph.deprecated

Source Code of joshua.discriminative.training.risk_annealer.hypergraph.deprecated.RiskAndFeatureAnnotation

package joshua.discriminative.training.risk_annealer.hypergraph.deprecated;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.BLEU;
import joshua.decoder.ff.state_maintenance.NgramStateComputer;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.discriminative.feature_related.feature_template.FeatureTemplate;
import joshua.discriminative.training.oracle.DPStateOracle;
import joshua.discriminative.training.oracle.EquivLMState;
import joshua.discriminative.training.oracle.RefineHG;
import joshua.discriminative.training.oracle.RefineHG.RefinedNode;
import joshua.discriminative.training.risk_annealer.hypergraph.FeatureForest;
import joshua.discriminative.training.risk_annealer.hypergraph.FeatureHyperEdge;
import joshua.util.Ngram;
import joshua.util.Regex;



/**The way we extract features (stored in featureTbl) for each edge is as following:
* (1) Feature template will return feature-name (i.e. string) and feature value.
* (2) This class will convert the feature-name to feature-id (i.e. integer) by using featureStringToIntegerMap
* (3) The featureStringToIntegerMap may be also used for feature filtering
* */

@Deprecated
public class RiskAndFeatureAnnotation extends  RefineHG<DPStateOracle> {

  SymbolTable symbolTable;
 
  //== variables related to BLEU risk 
  boolean doRiskAnnotation = true;//TODO
  double[] linearCorpusGainThetas; //weights in the Goolge linear corpus gain function
 
  int ngramOrder;
  private static int bleuOrder = 4;
 
  EquivLMState equi; 
  NgramStateComputer ngramStateComputer;
  int ngramStateID = 0;
 
  //== sentence-specific
  protected HashMap<String, Integer> refNgramsTbl = new HashMap<String, Integer>();
 
  EquivLMState equiv;
 
 
  //== variables related to features
  boolean doFeatureAnnotation = true;//TODO
  HashSet<String> restrictedFeatureSet;//this can also be used as feature filter
  private HashMap<String, Integer> featureStringToIntegerMap;
  private List<FeatureTemplate> featTemplates;
 
 
  /**
   * @param symbolTable_
   * @param nGramOrder_
   * @param linearCorpusGainThetas_
   * @param featureStringToIntegerMap_
   * @param doFeatureFiltering_
   * @param featTemplates_
   */
  public RiskAndFeatureAnnotation(SymbolTable symbolTable_, int nGramOrder_, double[] linearCorpusGainThetas_,
      HashMap<String, Integer> featureStringToIntegerMap_,  List<FeatureTemplate> featTemplates_) {
   
   
    this.symbolTable = symbolTable_;
    this.ngramOrder = nGramOrder_;
    this.linearCorpusGainThetas = linearCorpusGainThetas_;
   
    this.featureStringToIntegerMap = featureStringToIntegerMap_;
    restrictedFeatureSet = new HashSet<String>(featureStringToIntegerMap.keySet());
   
    this.featTemplates = featTemplates_;
   
    this.equi = new EquivLMState(symbolTable_, bleuOrder);
    this.ngramStateComputer = new NgramStateComputer(symbolTable_, bleuOrder, ngramStateID);
   
    this.equiv = new EquivLMState(this.symbolTable, bleuOrder);
    System.out.println("use RiskAndFeatureAnnotation====");
  }

 
  public FeatureForest riskAnnotationOnHG(HyperGraph hg, String refSentStr){

    setupRefAndPrefixAndSurfixTbl(refSentStr);//TODO: should enforce this in parent class   
    return new FeatureForestsplitHG(hg) );
  }
 
  public FeatureForest riskAnnotationOnHG(HyperGraph hg, String[] refSentStrs){

    setupRefAndPrefixAndSurfixTbl(refSentStrs);//TODO: should enforce this in parent class   
    return new FeatureForestsplitHG(hg) );
  }


//  TODO: should enforce this in parent class
  private void setupRefAndPrefixAndSurfixTbl(String refSentStr){ 
   
    //== ref tbL and effective ref len
    int[] refWords = this.symbolTable.addTerminals(refSentStr.split("\\s+"));
         
    refNgramsTbl.clear();
    Ngram.getNgrams(refNgramsTbl, 1, bleuOrder, refWords);
   
    //== prefix and suffix tbl
    equi.setupPrefixAndSurfixTbl(refNgramsTbl);
   
  }
 
 
  private void setupRefAndPrefixAndSurfixTbl(String[] refSentStrs){   
   
    //== ref tbL and effective ref len
    int[] refLens = new int[refSentStrs.length];
    ArrayList<HashMap<String, Integer>> listRefNgramTbl = new ArrayList<HashMap<String, Integer>>();   
    for(int i =0; i<refSentStrs.length; i++){
      int[] refWords = this.symbolTable.addTerminals(refSentStrs[i].split("\\s+"));
      refLens[i] = refWords.length;   
      HashMap<String, Integer> tRefNgramsTbl = new HashMap<String, Integer>();
      Ngram.getNgrams(tRefNgramsTbl, 1, bleuOrder, refWords)
      listRefNgramTbl.add(tRefNgramsTbl);     
    }
   
    refNgramsTbl =  BLEU.computeMaxRefCountTbl(listRefNgramTbl);
   
    //== prefix and suffix tbl
    equi.setupPrefixAndSurfixTbl(refNgramsTbl);
  }
   


  @Override
  protected HyperEdge createNewHyperEdge(HyperEdge originalEdge, List<HGNode> antVirtualItems, DPStateOracle dps) {
 
    //== risk annotation
    double riskTransitionCost = 0;
    if(doRiskAnnotation)
      riskTransitionCost = getRiskTransitionCost(originalEdge, antVirtualItems, dps);//TODO
   
    //System.out.println("tran2=" + riskTransitionCost);
   
    //== feature annotation
    HashMap<Integer, Double> featureTbl= null;
    if(doFeatureAnnotation)
      featureTbl = featureExtraction(originalEdge, null);//TODO: originalEdge? null parentNode
   
    /**compared wit the original edge, three changes:
     * (1) change the list of ant nodes
     * (2) add risk cost at edge (but does not change the orignal model cost)
     * (3) add feature tbl
     * */
    return new FeatureHyperEdge(originalEdge.getRule(), originalEdge.bestDerivationLogP, originalEdge.getTransitionLogP(false), antVirtualItems,
        originalEdge.getSourcePath(), featureTbl, riskTransitionCost);
  }
 
 
  private double getRiskTransitionCost(HyperEdge originalEdge, List<HGNode> antVirtualItems, DPStateOracle dps){//note: transition_cost is already linearly interpolated
   
    double riskTransitionCost = dps.bestDerivationLogP;
    if(antVirtualItems!=null
      for(HGNode ant_it :antVirtualItems ){
        RefinedNode it2 = (RefinedNode) ant_it;//TODO
        riskTransitionCost -= ((DPStateOracle)it2.dpState).bestDerivationLogP;
      }     
    return -riskTransitionCost;
  }
 
  protected  DPStateOracle computeState(HGNode originalParentItem, HyperEdge originalEdge, List<HGNode> antVirtualItems){
    /*
    double refLen = computeAvgLen(originalParentItem.j-originalParentItem.i, srcSentLen, refSentLen);
   
    //=== hypereges under "goal item" does not have rule
    if(originalEdge.getRule()==null){
      if(antVirtualItems.size()!=1){
        System.out.println("error deduction under goal item have more than one item");
        System.exit(0);
      }
      double bleuCost = antVirtualItems.get(0).bestHyperedge.bestDerivationCost;
      return  new DPStateOracle(0, null, null, null, bleuCost);//no DPState at all
      
    }
   
   
    ComputeOracleStateResult lmState = computeLMState(originalEdge, antVirtualItems); 
     
    int hypLen = lmState.numNewWordsAtEdge;
   
    int[] numNgramMatches = new int[bleuOrder];
   
    Iterator iter = lmState.newNgramsTbl.keySet().iterator();
      while(iter.hasNext()){
        String ngram = (String)iter.next();
       
      int finalCount =  lmState.newNgramsTbl.get(ngram);
      if(doLocalNgramClip)
        numNgramMatches[ngram.split("\\s+").length-1] += Support.findMin(finalCount, refNgramsTbl.get(ngram)) ;
      else
        numNgramMatches[ngram.split("\\s+").length-1] += finalCount; //do not do any cliping         
      }
   
      if(antVirtualItems!=null){
      for(int i=0; i<antVirtualItems.size(); i++){
        DPStateOracle antDPState = (DPStateOracle)((RefinedNode)antVirtualItems.get(i)).dpState;         
        hypLen += antDPState.bestLen;
        for(int t=0; t<bleuOrder; t++)
          numNgramMatches[t] += antDPState.ngramMatches[t];
      }
      }
    double bleuCost = - computeBleu(hypLen, refLen, numNgramMatches, bleuOrder);

    return  new DPStateOracle(hypLen, numNgramMatches, lmState.leftEdgeWords, lmState.rightEdgeWords, bleuCost);*/
    return null;
  }
 

 
 
 
 
  //========================================== BLEU realted =====================================
  //TODO: merge with joshua.decoder.BLEU
 
  /**
   * speed consideration: assume hypNgramTable has a smaller
   * size than referenceNgramTable does
   */
  public static double computeLinearCorpusGain(double[] linearCorpusGainThetas, int hypLength, HashMap<String,Integer> hypNgramTable,  HashMap<String,Integer> referenceNgramTable) {
    double res = 0;
    int[] numMatches = new int[5];
    res += linearCorpusGainThetas[0] * hypLength;
    numMatches[0] = hypLength;
    for (Entry<String,Integer> entry : hypNgramTable.entrySet()) {
      String   key = entry.getKey();
      Integer refNgramCount = referenceNgramTable.get(key);
      //System.out.println("key is " + key); System.exit(1);
      if(refNgramCount!=null){//delta function
        int ngramOrder = Regex.spaces.split(key).length;
        res += entry.getValue() * linearCorpusGainThetas[ngramOrder];
        numMatches[ngramOrder] += entry.getValue();
      }
    }
    /*
    System.out.print("Google BLEU stats are: ");
    for(int i=0; i<5; i++)
      System.out.print(numMatches[i]+ " ");
    System.out.print(" ; BLUE is " + res);
    System.out.println();
    */
    return res;
  }
 
 
 
  public static double[] computeLinearCorpusThetas(int numUnigramTokens, double unigramPrecision, double decayRatio){
    double[] res = new double[5];
    res[0] = -1.0/numUnigramTokens;
    for(int i=1; i<5; i++)
      res[i] = 1.0/(4.0*numUnigramTokens*unigramPrecision*Math.pow(decayRatio, i-1));
    System.out.print("Thetas are: ");
    for(int i=0; i<5; i++)
      System.out.print(res[i] + " ");
    System.out.print("\n");
    return res;
  }



 
  //============================================================================================
  //==================================== feature extraction function ======================================
  //============================================================================================
 
  /**The way we extract features (stored in featureTbl) for each edge is as following:
   * (1) Feature template will return feature-name (i.e. string) and feature value.
   * (2) This method will convert the feature-name to feature-id (i.e. integer) by using featureStringToIntegerMap
   * (3) The featureStringToIntegerMap may be also used for feature filtering
   * */
 
  private final  HashMap<Integer, Double> featureExtraction(HyperEdge dt, HGNode parentItem){   
 
    //=== extract feature counts
    HashMap<String, Double> activeFeaturesHelper = new HashMap<String, Double>();

    double tScale = 1.0;//TODO
    for(FeatureTemplate template : featTemplates){   
      template.getFeatureCounts(dt,  activeFeaturesHelper,  restrictedFeatureSet, tScale);   
    }
   
    //=== convert the featureString to featureInteger
    HashMap<Integer, Double> res = new HashMap<Integer, Double>();
    for(Map.Entry<String, Double> feature : activeFeaturesHelper.entrySet()){
      Integer featureID = featureStringToIntegerMap.get(feature.getKey());
      if(featureID==null){
        System.out.println("Null feature ID");
        System.exit(1);
      }
      res.put(featureID, feature.getValue());     
      //System.out.println("Str1=" + feature.getKey() + "; ID=" + featureID + "; val=" +feature.getValue());
    }
    //System.out.println("Feature extraction res: " + res);
    return res;
     
  }


 
 
}
TOP

Related Classes of joshua.discriminative.training.risk_annealer.hypergraph.deprecated.RiskAndFeatureAnnotation

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.