// Package joshua.discriminative.training.risk_annealer.hypergraph
//
// Source code of joshua.discriminative.training.risk_annealer.hypergraph.RiskAndFeatureAnnotationOnLMHG
package joshua.discriminative.training.risk_annealer.hypergraph;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.BLEU;
import joshua.decoder.ff.lm.NgramExtractor;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.discriminative.feature_related.feature_template.FeatureTemplate;


/**The way we extract features (stored in featureTbl) for each edge is as following:
* (1) Feature template will return feature-name (i.e. string) and feature value.
* (2) This class will convert the feature-name to feature-id (i.e. integer) by using featureStringToIntegerMap
* (3) The featureStringToIntegerMap may be also used for feature filtering
* */

public class RiskAndFeatureAnnotationOnLMHG {
 
  //private SymbolTable symbolTbl;
 
  private HashSet<HGNode> processedHGNodesTbl =  new HashSet<HGNode>();
 
//  == variables related to BLEU risk
  private boolean doRiskAnnotation = true;//TODO
  private double[] linearCorpusGainThetas; //weights in the Goolge linear corpus gain function
 
  private NgramExtractor ngramExtractor;
 
  private int baselineLMOrder;
  private static int startNgramOrder=1;
  private static int endNgramOrder=4;
 
 
//  == variables related to feature annotation
  private boolean doFeatureAnnotation = true;
  private HashMap<String, Integer> featureStringToIntegerMap; //this can also be used as feature filter
  private HashSet<String> restrictedFeatureSet;
  private List<FeatureTemplate> featTemplates;
  double scale = 1.0; //TODO
 
  Logger logger = Logger.getLogger(RiskAndFeatureAnnotationOnLMHG.class.getSimpleName());
 
  public RiskAndFeatureAnnotationOnLMHG(int baselineLMOrder,  int ngramStateID,  double[] linearCorpusGainThetas,  SymbolTable symbolTbl,
      HashMap<String, Integer> featureStringToIntegerMap,  List<FeatureTemplate> featTemplates, boolean doRiskAnnotation){
   
    this.baselineLMOrder = baselineLMOrder;
    this.doRiskAnnotation = doRiskAnnotation;
   
    if(this.baselineLMOrder<endNgramOrder){
      System.out.println("Baseline LM n-gram order is too small");
      System.exit(1);
    }
    //this.symbolTbl = symbolTbl;
   

    this.linearCorpusGainThetas = linearCorpusGainThetas;
    this.ngramExtractor = new NgramExtractor(symbolTbl, ngramStateID, false, baselineLMOrder);
 
    //=== feature related
    this.featureStringToIntegerMap = featureStringToIntegerMap;
    this.restrictedFeatureSet = new HashSet<String>( featureStringToIntegerMap.keySet() );
    this.featTemplates = featTemplates;
     
    //System.out.println("use riskAnnotatorNoEquiv====");
  }
 
  /**Input a hypergraph, return
   * a FeatureForest
   * Note that the input hypergraph has been changed*/ 
 
  public FeatureForest riskAnnotationOnHG(HyperGraph hg, String[] referenceSentences){
   
    processedHGNodesTbl.clear();
       
    if(doRiskAnnotation){
      HashMap<String, Integer> refereceNgramTable = BLEU.constructMaxRefCountTable(referenceSentences, endNgramOrder);
      annotateNode(hg.goalNode, refereceNgramTable);
    }else{
      annotateNode(hg.goalNode, null);
    }
    releaseNodesStateMemroy();
    return new FeatureForest(hg);   
  }
 

  private void annotateNode(HGNode node, HashMap<String, Integer> refereceNgramTable){
   
    if(processedHGNodesTbl.contains(node))
      return;
    processedHGNodesTbl.add(node);
   
    //=== recursive call on each edge
    for(int i=0; i<node.hyperedges.size(); i++){
      HyperEdge oldEdge = node.hyperedges.get(i);
      HyperEdge newEdge = annotateHyperEdge(oldEdge, refereceNgramTable);
      node.hyperedges.set(i, newEdge);
    }
   
    //===@todo: release the memory consumed by the state of node, but we have to make sure all parent hyperedges have been processed
  }
 
  private HyperEdge annotateHyperEdge(HyperEdge oldEdge, HashMap<String, Integer> refereceNgramTable){
   
    //=== recursive call on each ant item
    if(oldEdge.getAntNodes()!=null)
      for(HGNode antNode : oldEdge.getAntNodes())
        annotateNode(antNode, refereceNgramTable);
   
    //=== HyperEdge-specific operation       
    return createNewHyperEdge(oldEdge, refereceNgramTable);
  }
 
 
  protected HyperEdge createNewHyperEdge(HyperEdge oldEdge, HashMap<String, Integer> refereceNgramTable) {
 
    //======== risk annotation
    double transitionRisk = 0;
    if(doRiskAnnotation)
      transitionRisk = getTransitionRisk(oldEdge, refereceNgramTable);
   
    //System.out.println("tran2=" + riskTransitionCost);
   
    //======== feature annotation
    HashMap<Integer, Double> featureTbl= null;
    if(doFeatureAnnotation)
      featureTbl = featureExtraction(oldEdge, null, scale);
   
 
    /**compared wit the original edge, two changes:
     * (1) add risk at edge (but does not change the orignal model score)
     * (2) add feature tbl
     * */
    return new FeatureHyperEdge(oldEdge, featureTbl, transitionRisk);
  }
 
 
  private void releaseNodesStateMemroy(){
    for(HGNode node : processedHGNodesTbl){
      //System.out.println("releaseNodesStateMemroy");
      node.releaseDPStatesMemory();
    }
  }
 
 
  private double getTransitionRisk(HyperEdge dt, HashMap<String, Integer> refereceNgramTable){
   
    double transitionRisk = 0;
    if(dt.getRule() != null){//note: hyperedges under goal item does not contribute BLEU
      int hypLength = dt.getRule().getEnglish().length-dt.getRule().getArity();
      HashMap<String, Integer> hyperedgeNgramTable = ngramExtractor.getTransitionNgrams(dt, startNgramOrder, endNgramOrder);     
      transitionRisk = - BLEU.computeLinearCorpusGain(linearCorpusGainThetas, hypLength, hyperedgeNgramTable, refereceNgramTable);
     
      /*
      System.out.println("hyp tbl: " + hyperedgeNgramTable);
      System.out.println("ref tbl: " + refereceNgramTable);
      System.out.println("hypLength: " + hypLength);
      System.out.println("risk is " + transitionRisk); 
      System.exit(1);*/
    }
   
    return transitionRisk;
  }
 
 
 
 
  //TODO: copied from RiskAndFeatureAnnotation, consider merge
 
  //============================================================================================
  //==================================== feature extraction function ======================================
  //============================================================================================
 

  /**The way we extract features (stored in featureTbl) for each edge is as following:
   * (1) Feature template will return feature-name (i.e. string) and feature value.
   * (2) This class will convert the feature-name to feature-id (i.e. integer) by using featureStringToIntegerMap
   * (3) The featureStringToIntegerMap (derive restrictedFeatureSet) will be also used for feature filtering
   * */
 
  private final  HashMap<Integer, Double> featureExtraction(HyperEdge dt, HGNode parentItem, double scale){   
 
    //=== extract feature counts
    HashMap<String, Double> activeFeaturesHelper = new HashMap<String, Double>();

    for(FeatureTemplate template : featTemplates){     
      template.getFeatureCounts(dt,  activeFeaturesHelper,  restrictedFeatureSet, scale);
     
    }
   
    //=== convert the featureString to featureInteger
    HashMap<Integer, Double> res = new HashMap<Integer, Double>();
    for(Map.Entry<String, Double> feature : activeFeaturesHelper.entrySet()){
      Integer featureID = featureStringToIntegerMap.get(feature.getKey());
      if(featureID==null){
        logger.severe("Null feature ID, featureID="+feature.getKey());
        System.exit(1);
      }
      res.put(featureID, feature.getValue());
    }
    //System.out.println("Feature extraction res: " + res);
    return res;     
  }
 
}
// TOP
//
// Related classes of joshua.discriminative.training.risk_annealer.hypergraph.RiskAndFeatureAnnotationOnLMHG
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., owned by Oracle Inc. Contact coftware#gmail.com.