// Source: joshua.discriminative.ranker.RescorerHGSimple

package joshua.discriminative.ranker;

import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.ViterbiExtractor;
import joshua.discriminative.DiscriminativeSupport;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.feature_related.feature_template.EdgeBigramFT;
import joshua.discriminative.feature_related.feature_template.FeatureTemplate;
import joshua.discriminative.feature_related.feature_template.NgramFT;
import joshua.discriminative.feature_related.feature_template.TMFT;
import joshua.discriminative.feature_related.feature_template.TableBasedBaselineFT;




/*This class rescore the hypergraph, and directly change the one-best pointer and logPs
* (Note: we should not change the order of dependency, that is, all features are local)
* */

@Deprecated
public class RescorerHGSimple {
  
  private HashMap processedNodesTtbl =  new HashMap();//help to tranverse a hypergraph
 
  private HashMap correctiveModel = null; //corrective model: this should not have the baseline feature

  private HashSet<String> restrictedFeatSet =null; //feature set
  private List<FeatureTemplate> featTemplates=null;
 
  private HashMap<HyperEdge, Double> hyperEdgeBaselineLogPTbl = new  HashMap<HyperEdge, Double>();
 

  private int numChanges=0;

 
 
  public RescorerHGSimple(){
    //do nothing
  }
 
//  ###########  public interfaces
//  the baseline scale must be explicitly indicated in baseline_scale, instead of implicitly indicated in corrective_model
 

  //return the reranked onebest hg
  //  corrective model: this should not have the baseline feature
  public HyperGraph rerankHGAndGet1best(HyperGraph hg, HashMap correctiveModel, HashSet<String> restrictedFeatSet,
      List<FeatureTemplate> featTemplates, boolean isValueAVector){
    numChanges=0;
    adjustHGLogP( hg, correctiveModel, restrictedFeatSet, featTemplates);
    System.out.println("numChanges="+numChanges);   
    return   ViterbiExtractor.getViterbiTreeHG(hg);     
  }
 
 
  //==========================================
  public HashMap<HyperEdge, Double>  collectTransitionLogPs(HyperGraph hg){
    hyperEdgeBaselineLogPTbl.clear();
    processedNodesTtbl.clear()
    numChanges=0;
    collectTransitionLogPs(hg.goalNode);
    processedNodesTtbl.clear();
   
    return hyperEdgeBaselineLogPTbl;
  }

 
  private  void collectTransitionLogPs(HGNode it ){
    if(processedNodesTtbl.containsKey(it))  return;
    processedNodesTtbl.put(it,1);   
    for(HyperEdge dt : it.hyperedges){         
      collectTransitionLogPs(it, dt);
    }   
  }

  private void collectTransitionLogPs(HGNode parentNode, HyperEdge dt){
    hyperEdgeBaselineLogPTbl.put(dt, dt.getTransitionLogP(false));//get baseline score 
   
    if(dt.getAntNodes()!=null){
      for(HGNode antNode : dt.getAntNodes()){
        collectTransitionLogPs(antNode);
      }
    }
  } 
  //==========================================
 
 
 
  private void adjustHGLogP(HyperGraph hg, HashMap correctiveModel_, HashSet<String> restrictedFeatSet_, List<FeatureTemplate> featTemplates_){
    processedNodesTtbl.clear();
    correctiveModel=correctiveModel_;
    restrictedFeatSet = restrictedFeatSet_;
    featTemplates = featTemplates_;
    adjustNodeLogP(hg.goalNode);
    processedNodesTtbl.clear();
  }
//########### end public interfaces 
 
  //item: recursively call my children edges, change pointer for bestHyperedge
  private  void adjustNodeLogP(HGNode it ){
    if(processedNodesTtbl.containsKey(it)) 
      return;
    processedNodesTtbl.put(it,1);
   
    //==== recursively call my children edges, change pointer for bestHyperedge
    HyperEdge oldEdge =it.bestHyperedge;
    it.bestHyperedge=null;
    for(HyperEdge dt : it.hyperedges){         
      adjustHyperedgeLogP(it, dt );//deduction-specifc feature
      it.semiringPlus(dt);
    }
    if(it.bestHyperedge!=oldEdge)
      numChanges++;
  }
 
 
  //adjust best_cost, and recursively call my ant items
  private void adjustHyperedgeLogP(HGNode parentNode, HyperEdge dt){
    dt.bestDerivationLogP =0;
    if(dt.getAntNodes()!=null){
      for(HGNode antNode : dt.getAntNodes()){
        adjustNodeLogP(antNode);
        dt.bestDerivationLogP += antNode.bestHyperedge.bestDerivationLogP;
      }
    }
    double res = getTransitionLogP(parentNode, dt);
    dt.setTransitionLogP(res) ;
    dt.bestDerivationLogP += res;
   
 
 
 
  //give a dt and pointers to ant items, find all features that apply, non-recursive
  private double getTransitionLogP(HGNode parentNode, HyperEdge dt ){
    double res =0;
   
    HashMap featTbl = new HashMap();
    for(FeatureTemplate template : featTemplates){     
      template.getFeatureCounts(dt,  featTbl,  restrictedFeatSet, 1);//scale is one: hard count     
   
   
    return DiscriminativeSupport.computeLinearCombinationLogP(featTbl, correctiveModel);
   
  }
 
 
 
 
 
 
 
 
//================== example main funciton======================
 
  public static void main(String[] args)   throws IOException{
   
    if(args.length<11){
      System.out.println("wrong command, correct command should be: java Perceptron_TEST is_avg_model is_nbest f_test_items f_test_rules num_sent f_perceptron_model baseline_scale use_tm_feat use_lm_feat use_edge_bigram_feat_only freranked_1best");
      System.out.println("num of args is "+ args.length);
      for(int i=0; i <args.length; i++)
        System.out.println("arg is: " + args[i])
      System.exit(0);   
    }
   
    long start_time = System.currentTimeMillis();

    boolean isAvgModel = new Boolean(args[0].trim());
    boolean isNbest = new Boolean(args[1].trim());
    String testNodesFile=args[2].trim();
    String testRulesFile=args[3].trim();
    int numSent = new Integer(args[4].trim());
    String modelFile = args[5].trim();
    double baselineScale = new Double(args[6].trim());
    boolean useTMFeat = new Boolean(args[7].trim());
    boolean useLMFeat = new Boolean(args[8].trim());
    boolean useEdgeNgramOnly = new Boolean(args[9].trim());
    String reranked1bestFile = args[10].trim();

    boolean saveModelCosts = true;
   
    String featureFile = null;
    if(args.length>11)
      featureFile = args[11].trim();   
   
//    ????????????????????????????????????????????????????
    int ngramStateID = 0;
    //??????????????????????????????????????
   
    SymbolTable symbolTbl = new BuildinSymbol(null);
   
    boolean useIntegerString = false;
    boolean useRuleIDName = false;
   
   
    //####### nbest decoding
    /*if(is_nbest){
      //TODO
    }else{//#####hg decoding*/
   
     
      //======== setup feature templates list
      List<FeatureTemplate> featTemplates =  new ArrayList<FeatureTemplate>()
     
      String baselineName = "baseline_lzf";//TODO
      FeatureTemplate baselineFeature = new TableBasedBaselineFT(baselineName, baselineScale);
      featTemplates.add(baselineFeature)
     
      if(useTMFeat==true){
        FeatureTemplate ft = new TMFT(symbolTbl, useIntegerString, useRuleIDName);
        featTemplates.add(ft);
      }
       
      int baselineLMOrder = 5;//TODO??????????????????
      if(useLMFeat==true){ 
        FeatureTemplate ft = new NgramFT(symbolTbl, false, ngramStateID, baselineLMOrder, 1, 2);//TODO: unigram and bi gram
        featTemplates.add(ft);
      }else if(useEdgeNgramOnly){//exclusive with use_lm_feat
        FeatureTemplate ft = new EdgeBigramFT(symbolTbl, ngramStateID, baselineLMOrder, useIntegerString);
        featTemplates.add(ft);
      }   
      System.out.println("templates are: " + featTemplates);
     
      //============= restricted feature set : normally this is not used as the model itself is a restriction
      HashSet<String> restrictedFeatureSet = null;
      if(featureFile!=null){
        restrictedFeatureSet = new HashSet<String>();
        DiscriminativeSupport.loadFeatureSet(featureFile, restrictedFeatureSet);
        //restricted_feature_set.put(HGDiscriminativeLearner.g_baseline_feat_name, 1.0); //should not add the baseline feature
        System.out.println("============use  restricted feature set========================");
      }
     
      //================ model
      HashMap<String, Double> modelTbl =  new HashMap<String, Double>();
           
      DiscriminativeSupport.loadModel(modelFile, modelTbl, null);
     
      BufferedWriter out1best = FileUtilityOld.getWriteFileStream(reranked1bestFile);
     
      RescorerHGSimple reranker = new RescorerHGSimple();
     
      DiskHyperGraph diskHG = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
      diskHG.initRead(testNodesFile, testRulesFile, null);
      for(int sent_id=0; sent_id < numSent; sent_id ++){
        System.out.println("#Process sentence " + sent_id);
        HyperGraph testHG = diskHG.readHyperGraph();
        ((TableBasedBaselineFT) baselineFeature).setBaselineScoreTbl( reranker.collectTransitionLogPs(testHG) );
        HyperGraph rerankedOnebestHG = reranker.rerankHGAndGet1best(testHG, modelTbl, restrictedFeatureSet, featTemplates, isAvgModel);
        System.out.println("bestScore=" + rerankedOnebestHG.goalNode.bestHyperedge.bestDerivationLogP );
        String reranked_1best = ViterbiExtractor.extractViterbiString(symbolTbl, rerankedOnebestHG.goalNode);
       
        FileUtilityOld.writeLzf(out1best, reranked_1best + "\n");       
      }
      FileUtilityOld.closeWriteFile(out1best);
    //}
     
    System.out.println("Time cost: " + ((System.currentTimeMillis()-start_time)/1000));
  }

 
}
// End of RescorerHGSimple.