Package joshua.discriminative.ranker

Source Code of joshua.discriminative.ranker.HGRanker

package joshua.discriminative.ranker;


import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.JoshuaConfiguration;
import joshua.decoder.chart_parser.ComputeNodeResult;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.decoder.hypergraph.ViterbiExtractor;
import joshua.discriminative.DiscriminativeSupport;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.feature_related.feature_function.EdgeTblBasedBaselineFF;


/**This class implements functions to rank HG based on a bunch of feature functions
*It does not change the topology of the HG, but it changes the
*bestHyperedge and bestLogP in hypergraph.
**/


public class HGRanker {
 
  private HashSet<HGNode> processedNodesTbl =  new HashSet<HGNode>()
  private List<FeatureFunction> featFunctions;
 
  private int numChangedBestHyperedge = 0;

 
  private static Logger logger = Logger.getLogger(HGRanker.class.getName());
 
  public HGRanker(List<FeatureFunction> featFunctions){
    this.featFunctions = featFunctions;
  }
 
  /**Change the input hypergraph based on featFunctions
   */
  public void rankHG(HyperGraph hg){
    resetState();
    rankHGNode(hg.goalNode);
    //logger.info("number of nodes whose best hyperedge changes is " + numChangedBestHyperedge
    //    + " among total number of nodes " + processedNodesTbl.size() );
    resetState();   
  }
 
  //get 1best HG
  public HyperGraph rerankHGAndGet1best(HyperGraph hg){
    rankHG(hg);
    return  ViterbiExtractor.getViterbiTreeHG(hg);     
  }
 
 
  public void resetState(){
    processedNodesTbl.clear();
    numChangedBestHyperedge = 0;
  }
 
 
 
  private  void rankHGNode(HGNode it ){
    if(processedNodesTbl.contains(it)) 
      return;
    processedNodesTbl.add(it);
   
    //==== recursively call my children deductions, change pointer for bestHyperedge
    HyperEdge oldBestHyperedge = it.bestHyperedge;
   
    it.bestHyperedge=null;
    for(HyperEdge dt : it.hyperedges){         
      rankHyperEdge(it, dt );
      it.semiringPlus(dt);
   
   
    /**Due to diskHG precision, the behavior may not be precise
     **/
    if(it.bestHyperedge!=oldBestHyperedge){
      numChangedBestHyperedge++;
    }
  }

  private void rankHyperEdge(HGNode parentNode, HyperEdge dt){
   
    dt.bestDerivationLogP = 0;
    if(dt.getAntNodes()!=null){
      for(HGNode antNode : dt.getAntNodes()){
        rankHGNode(antNode);
        dt.bestDerivationLogP += antNode.bestHyperedge.bestDerivationLogP;//semiring times
      }
    }
    double transLogP = getTransitionLogP(parentNode, dt);   
    dt.setTransitionLogP(transLogP);   
    dt.bestDerivationLogP += transLogP;
   
 
 
 
  private double getTransitionLogP(HGNode parentNode, HyperEdge dt ){
    return ComputeNodeResult.computeCombinedTransitionLogP(
        this.featFunctions, dt, parentNode.i, parentNode.j, -1);
  }
 
 
 
 
//  ================== example main funciton======================
 
  public static void main(String[] args)   throws IOException{
   
    if(args.length<10){
      System.out.println("wrong command, correct command");
      System.out.println("num of args is "+ args.length);
      for(int i=0; i <args.length; i++)
        System.out.println("arg is: " + args[i])
      System.exit(0);   
    }
   
    long startTime = System.currentTimeMillis();

    String testNodesFile=args[0].trim();
    String testRulesFile=args[1].trim();
    int numSent = new Integer(args[2].trim());
    String modelFile = args[3].trim();
    double baselineWeight = new Double(args[4].trim());
    boolean useTMFeat = new Boolean(args[5].trim());
    boolean useLMFeat = new Boolean(args[6].trim());
    boolean useEdgeNgramOnly = new Boolean(args[7].trim());
    boolean useTMTargetFeat = new Boolean(args[8].trim());
    String reranked1bestFile = args[9].trim();

    boolean saveModelCosts = true;
   
    String featureFile = null;
    if(args.length>9)
      featureFile = args[9].trim();   
   
   
    SymbolTable symbolTbl = new BuildinSymbol(null);     
    List<FeatureFunction> features =  new ArrayList<FeatureFunction>();
   
   
   
    //=== baseline feature ====
    //TODO: ????????????????????????????????????????????????????     
    int baselineFeatID = 99;
    //??????????????????????????????????????
   
    EdgeTblBasedBaselineFF baselineFeature = new EdgeTblBasedBaselineFF(baselineFeatID, baselineWeight);
    features.add(baselineFeature);
   

    //=== reranking feature ===
    //TODO: ??????????????
    int ngramStateID = 0;
    int baselineLMOrder = 5;
    int startNgramOrder = 1;
    int endNgramOrder = 2;
    int featID = 100;
    double weight = 1.0;
    //????????
   
    Map<String,Integer> rulesIDTable = null; //TODO??
 
    //TODO
    FeatureFunction rerankFF = DiscriminativeSupport.setupRerankingFeature(featID, weight, symbolTbl, useTMFeat, useLMFeat, useEdgeNgramOnly, useTMTargetFeat,
        JoshuaConfiguration.useMicroTMFeat, JoshuaConfiguration.wordMapFile,
        ngramStateID,
        baselineLMOrder, startNgramOrder, endNgramOrder, featureFile, modelFile, rulesIDTable);
   
    features.add(rerankFF);
   
    //=== reranker using the feature functions
    HGRanker reranker = new HGRanker(features);
   
    BufferedWriter out1best = FileUtilityOld.getWriteFileStream(reranked1bestFile);
   
         
    int topN=3;
    boolean useUniqueNbest =true;
    boolean useTreeNbest = false;
    boolean addCombinedCost = true
    KBestExtractor kbestExtractor = new KBestExtractor(symbolTbl, useUniqueNbest, useTreeNbest, false, addCombinedCost, false, true);
   
   
    DiskHyperGraph diskHG = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
    diskHG.initRead(testNodesFile, testRulesFile, null);
    for(int sentID=0; sentID < numSent; sentID ++){
      System.out.println("#Process sentence " + sentID);
      HyperGraph testHG = diskHG.readHyperGraph();
      baselineFeature.collectTransitionLogPs(testHG);
      reranker.rankHG(testHG);
   
      try{
        kbestExtractor.lazyKBestExtractOnHG(testHG, features, topN, sentID, out1best);
      } catch (IOException e) {
        e.printStackTrace();
      }
             
    }
    FileUtilityOld.closeWriteFile(out1best);
   
     
    System.out.println("Time cost: " + ((System.currentTimeMillis()-startTime)/1000));
  }

 

 
}
TOP

Related Classes of joshua.discriminative.ranker.HGRanker

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.