package joshua.discriminative.ranker;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.ViterbiExtractor;
import joshua.discriminative.DiscriminativeSupport;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.feature_related.feature_template.EdgeBigramFT;
import joshua.discriminative.feature_related.feature_template.FeatureTemplate;
import joshua.discriminative.feature_related.feature_template.NgramFT;
import joshua.discriminative.feature_related.feature_template.TMFT;
import joshua.discriminative.feature_related.feature_template.TableBasedBaselineFT;
/*This class rescore the hypergraph, and directly change the one-best pointer and logPs
* (Note: we should not change the order of dependency, that is, all features are local)
* */
@Deprecated
public class RescorerHGSimple {
private HashMap processedNodesTtbl = new HashMap();//help to tranverse a hypergraph
private HashMap correctiveModel = null; //corrective model: this should not have the baseline feature
private HashSet<String> restrictedFeatSet =null; //feature set
private List<FeatureTemplate> featTemplates=null;
private HashMap<HyperEdge, Double> hyperEdgeBaselineLogPTbl = new HashMap<HyperEdge, Double>();
private int numChanges=0;
public RescorerHGSimple(){
//do nothing
}
// ########### public interfaces
// the baseline scale must be explicitly indicated in baseline_scale, instead of implicitly indicated in corrective_model
//return the reranked onebest hg
// corrective model: this should not have the baseline feature
public HyperGraph rerankHGAndGet1best(HyperGraph hg, HashMap correctiveModel, HashSet<String> restrictedFeatSet,
List<FeatureTemplate> featTemplates, boolean isValueAVector){
numChanges=0;
adjustHGLogP( hg, correctiveModel, restrictedFeatSet, featTemplates);
System.out.println("numChanges="+numChanges);
return ViterbiExtractor.getViterbiTreeHG(hg);
}
//==========================================
public HashMap<HyperEdge, Double> collectTransitionLogPs(HyperGraph hg){
hyperEdgeBaselineLogPTbl.clear();
processedNodesTtbl.clear();
numChanges=0;
collectTransitionLogPs(hg.goalNode);
processedNodesTtbl.clear();
return hyperEdgeBaselineLogPTbl;
}
private void collectTransitionLogPs(HGNode it ){
if(processedNodesTtbl.containsKey(it)) return;
processedNodesTtbl.put(it,1);
for(HyperEdge dt : it.hyperedges){
collectTransitionLogPs(it, dt);
}
}
private void collectTransitionLogPs(HGNode parentNode, HyperEdge dt){
hyperEdgeBaselineLogPTbl.put(dt, dt.getTransitionLogP(false));//get baseline score
if(dt.getAntNodes()!=null){
for(HGNode antNode : dt.getAntNodes()){
collectTransitionLogPs(antNode);
}
}
}
//==========================================
private void adjustHGLogP(HyperGraph hg, HashMap correctiveModel_, HashSet<String> restrictedFeatSet_, List<FeatureTemplate> featTemplates_){
processedNodesTtbl.clear();
correctiveModel=correctiveModel_;
restrictedFeatSet = restrictedFeatSet_;
featTemplates = featTemplates_;
adjustNodeLogP(hg.goalNode);
processedNodesTtbl.clear();
}
//########### end public interfaces
//item: recursively call my children edges, change pointer for bestHyperedge
private void adjustNodeLogP(HGNode it ){
if(processedNodesTtbl.containsKey(it))
return;
processedNodesTtbl.put(it,1);
//==== recursively call my children edges, change pointer for bestHyperedge
HyperEdge oldEdge =it.bestHyperedge;
it.bestHyperedge=null;
for(HyperEdge dt : it.hyperedges){
adjustHyperedgeLogP(it, dt );//deduction-specifc feature
it.semiringPlus(dt);
}
if(it.bestHyperedge!=oldEdge)
numChanges++;
}
//adjust best_cost, and recursively call my ant items
private void adjustHyperedgeLogP(HGNode parentNode, HyperEdge dt){
dt.bestDerivationLogP =0;
if(dt.getAntNodes()!=null){
for(HGNode antNode : dt.getAntNodes()){
adjustNodeLogP(antNode);
dt.bestDerivationLogP += antNode.bestHyperedge.bestDerivationLogP;
}
}
double res = getTransitionLogP(parentNode, dt);
dt.setTransitionLogP(res) ;
dt.bestDerivationLogP += res;
}
//give a dt and pointers to ant items, find all features that apply, non-recursive
private double getTransitionLogP(HGNode parentNode, HyperEdge dt ){
double res =0;
HashMap featTbl = new HashMap();
for(FeatureTemplate template : featTemplates){
template.getFeatureCounts(dt, featTbl, restrictedFeatSet, 1);//scale is one: hard count
}
return DiscriminativeSupport.computeLinearCombinationLogP(featTbl, correctiveModel);
}
//================== example main funciton======================
public static void main(String[] args) throws IOException{
if(args.length<11){
System.out.println("wrong command, correct command should be: java Perceptron_TEST is_avg_model is_nbest f_test_items f_test_rules num_sent f_perceptron_model baseline_scale use_tm_feat use_lm_feat use_edge_bigram_feat_only freranked_1best");
System.out.println("num of args is "+ args.length);
for(int i=0; i <args.length; i++)
System.out.println("arg is: " + args[i]);
System.exit(0);
}
long start_time = System.currentTimeMillis();
boolean isAvgModel = new Boolean(args[0].trim());
boolean isNbest = new Boolean(args[1].trim());
String testNodesFile=args[2].trim();
String testRulesFile=args[3].trim();
int numSent = new Integer(args[4].trim());
String modelFile = args[5].trim();
double baselineScale = new Double(args[6].trim());
boolean useTMFeat = new Boolean(args[7].trim());
boolean useLMFeat = new Boolean(args[8].trim());
boolean useEdgeNgramOnly = new Boolean(args[9].trim());
String reranked1bestFile = args[10].trim();
boolean saveModelCosts = true;
String featureFile = null;
if(args.length>11)
featureFile = args[11].trim();
// ????????????????????????????????????????????????????
int ngramStateID = 0;
//??????????????????????????????????????
SymbolTable symbolTbl = new BuildinSymbol(null);
boolean useIntegerString = false;
boolean useRuleIDName = false;
//####### nbest decoding
/*if(is_nbest){
//TODO
}else{//#####hg decoding*/
//======== setup feature templates list
List<FeatureTemplate> featTemplates = new ArrayList<FeatureTemplate>();
String baselineName = "baseline_lzf";//TODO
FeatureTemplate baselineFeature = new TableBasedBaselineFT(baselineName, baselineScale);
featTemplates.add(baselineFeature);
if(useTMFeat==true){
FeatureTemplate ft = new TMFT(symbolTbl, useIntegerString, useRuleIDName);
featTemplates.add(ft);
}
int baselineLMOrder = 5;//TODO??????????????????
if(useLMFeat==true){
FeatureTemplate ft = new NgramFT(symbolTbl, false, ngramStateID, baselineLMOrder, 1, 2);//TODO: unigram and bi gram
featTemplates.add(ft);
}else if(useEdgeNgramOnly){//exclusive with use_lm_feat
FeatureTemplate ft = new EdgeBigramFT(symbolTbl, ngramStateID, baselineLMOrder, useIntegerString);
featTemplates.add(ft);
}
System.out.println("templates are: " + featTemplates);
//============= restricted feature set : normally this is not used as the model itself is a restriction
HashSet<String> restrictedFeatureSet = null;
if(featureFile!=null){
restrictedFeatureSet = new HashSet<String>();
DiscriminativeSupport.loadFeatureSet(featureFile, restrictedFeatureSet);
//restricted_feature_set.put(HGDiscriminativeLearner.g_baseline_feat_name, 1.0); //should not add the baseline feature
System.out.println("============use restricted feature set========================");
}
//================ model
HashMap<String, Double> modelTbl = new HashMap<String, Double>();
DiscriminativeSupport.loadModel(modelFile, modelTbl, null);
BufferedWriter out1best = FileUtilityOld.getWriteFileStream(reranked1bestFile);
RescorerHGSimple reranker = new RescorerHGSimple();
DiskHyperGraph diskHG = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
diskHG.initRead(testNodesFile, testRulesFile, null);
for(int sent_id=0; sent_id < numSent; sent_id ++){
System.out.println("#Process sentence " + sent_id);
HyperGraph testHG = diskHG.readHyperGraph();
((TableBasedBaselineFT) baselineFeature).setBaselineScoreTbl( reranker.collectTransitionLogPs(testHG) );
HyperGraph rerankedOnebestHG = reranker.rerankHGAndGet1best(testHG, modelTbl, restrictedFeatureSet, featTemplates, isAvgModel);
System.out.println("bestScore=" + rerankedOnebestHG.goalNode.bestHyperedge.bestDerivationLogP );
String reranked_1best = ViterbiExtractor.extractViterbiString(symbolTbl, rerankedOnebestHG.goalNode);
FileUtilityOld.writeLzf(out1best, reranked_1best + "\n");
}
FileUtilityOld.closeWriteFile(out1best);
//}
System.out.println("Time cost: " + ((System.currentTimeMillis()-start_time)/1000));
}
}