package joshua.discriminative.training.risk_annealer.hypergraph.deprecated;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.BLEU;
import joshua.decoder.ff.state_maintenance.NgramStateComputer;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.discriminative.feature_related.feature_template.FeatureTemplate;
import joshua.discriminative.training.oracle.DPStateOracle;
import joshua.discriminative.training.oracle.EquivLMState;
import joshua.discriminative.training.oracle.RefineHG;
import joshua.discriminative.training.oracle.RefineHG.RefinedNode;
import joshua.discriminative.training.risk_annealer.hypergraph.FeatureForest;
import joshua.discriminative.training.risk_annealer.hypergraph.FeatureHyperEdge;
import joshua.util.Ngram;
import joshua.util.Regex;
/**The way we extract features (stored in featureTbl) for each edge is as follows:
* (1) Each feature template returns a feature-name (i.e. string) and a feature value.
* (2) This class converts the feature-name to a feature-id (i.e. integer) by using featureStringToIntegerMap
* (3) The featureStringToIntegerMap may also be used for feature filtering
* */
@Deprecated
public class RiskAndFeatureAnnotation extends RefineHG<DPStateOracle> {
// Symbol table used to turn reference words into integer ids; the constructor also
// hands it to the EquivLMState and NgramStateComputer helpers.
SymbolTable symbolTable;
//== variables related to BLEU risk
// When true, createNewHyperEdge computes a risk transition cost for every new edge;
// when false the cost stays at 0.
boolean doRiskAnnotation = true;//TODO
// Stored by the constructor; not read anywhere else in this class.
double[] linearCorpusGainThetas; //weights in the Google linear corpus gain function
// Stored by the constructor; not read anywhere else in this class.
int ngramOrder;
// Maximum n-gram order used when collecting reference n-grams and when building the
// EquivLMState / NgramStateComputer helpers.
private static int bleuOrder = 4;
// Helper whose prefix/suffix table is rebuilt from refNgramsTbl for every sentence.
EquivLMState equi;
// Built in the constructor; not used anywhere else in this class.
NgramStateComputer ngramStateComputer;
// State id passed to the NgramStateComputer constructor.
int ngramStateID = 0;
//== sentence-specific
// N-gram -> count table of the current reference(s). The single-reference setup clears and
// refills this map in place; the multi-reference setup replaces it with a new map.
protected HashMap<String, Integer> refNgramsTbl = new HashMap<String, Integer>();
// NOTE(review): second EquivLMState built with the same arguments as equi and not referenced
// anywhere else in this class -- looks like a duplicate; confirm no external use before removing.
EquivLMState equiv;
//== variables related to features
// When true, createNewHyperEdge attaches a feature table to every new edge;
// when false the feature table is null.
boolean doFeatureAnnotation = true;//TODO
// Copy of the key set of featureStringToIntegerMap taken at construction time; passed to
// every feature template during feature extraction.
HashSet<String> restrictedFeatureSet;//this can also be used as feature filter
// Maps feature names (as produced by the templates) to integer feature ids.
private HashMap<String, Integer> featureStringToIntegerMap;
// Feature templates queried, in list order, for every edge.
private List<FeatureTemplate> featTemplates;
/**
 * Creates an annotator that attaches a risk cost and a feature table to refined hyperedges.
 *
 * @param symbolTable_ symbol table used to map reference words to ids; it is also passed to
 *        the n-gram state helpers created here
 * @param nGramOrder_ n-gram order, stored unchanged in {@code ngramOrder}
 * @param linearCorpusGainThetas_ linear corpus gain weights, stored by reference (no copy)
 * @param featureStringToIntegerMap_ feature-name to feature-id map; its key set is copied
 *        into {@code restrictedFeatureSet}, so it must not be {@code null}
 * @param featTemplates_ feature templates consulted for every edge
 */
public RiskAndFeatureAnnotation(SymbolTable symbolTable_, int nGramOrder_, double[] linearCorpusGainThetas_,
        HashMap<String, Integer> featureStringToIntegerMap_, List<FeatureTemplate> featTemplates_) {
    //== plain field wiring
    this.symbolTable = symbolTable_;
    this.ngramOrder = nGramOrder_;
    this.linearCorpusGainThetas = linearCorpusGainThetas_;
    this.featTemplates = featTemplates_;

    //== feature-name mapping plus a snapshot of the known feature names
    this.featureStringToIntegerMap = featureStringToIntegerMap_;
    this.restrictedFeatureSet = new HashSet<String>(featureStringToIntegerMap_.keySet());

    //== n-gram state helpers (construction order kept as in the original)
    this.equi = new EquivLMState(symbolTable_, bleuOrder);
    this.ngramStateComputer = new NgramStateComputer(symbolTable_, bleuOrder, this.ngramStateID);
    this.equiv = new EquivLMState(symbolTable_, bleuOrder);

    System.out.println("use RiskAndFeatureAnnotation====");
}
/**
 * Annotates a hypergraph against a single reference sentence.
 *
 * @param hg hypergraph to refine
 * @param refSentStr reference translation as one whitespace-separated string
 * @return a {@link FeatureForest} wrapping the result of {@code splitHG(hg)}
 */
public FeatureForest riskAnnotationOnHG(HyperGraph hg, String refSentStr) {
    // The reference tables are rebuilt first so that they are in place when splitHG runs.
    this.setupRefAndPrefixAndSurfixTbl(refSentStr);//TODO: should enforce this in parent class

    return new FeatureForest(splitHG(hg));
}
/**
 * Annotates a hypergraph against several reference sentences.
 *
 * @param hg hypergraph to refine
 * @param refSentStrs reference translations, each a whitespace-separated string
 * @return a {@link FeatureForest} wrapping the result of {@code splitHG(hg)}
 */
public FeatureForest riskAnnotationOnHG(HyperGraph hg, String[] refSentStrs) {
    // The reference tables are rebuilt first so that they are in place when splitHG runs.
    this.setupRefAndPrefixAndSurfixTbl(refSentStrs);//TODO: should enforce this in parent class

    return new FeatureForest(splitHG(hg));
}
// TODO: should enforce this in parent class
/**
 * Rebuilds the sentence-specific tables from a single reference sentence.
 * The existing {@code refNgramsTbl} map is emptied and refilled in place with the
 * reference n-grams of order 1..bleuOrder, then handed to {@code equi}.
 *
 * @param refSentStr reference translation; it is split on runs of whitespace
 */
private void setupRefAndPrefixAndSurfixTbl(String refSentStr) {
    //== reference n-gram table
    String[] refTokens = refSentStr.split("\\s+");
    int[] refWordIds = this.symbolTable.addTerminals(refTokens);
    this.refNgramsTbl.clear();
    Ngram.getNgrams(this.refNgramsTbl, 1, bleuOrder, refWordIds);

    //== prefix and suffix table derived from the reference n-grams
    this.equi.setupPrefixAndSurfixTbl(this.refNgramsTbl);
}
/**
 * Rebuilds the sentence-specific tables from several reference sentences.
 * One n-gram table (orders 1..bleuOrder) is collected per reference; the tables are then
 * merged by {@code BLEU.computeMaxRefCountTbl} and the merged table replaces
 * {@code refNgramsTbl} before being handed to {@code equi}.
 *
 * @param refSentStrs reference translations; each one is split on runs of whitespace
 */
private void setupRefAndPrefixAndSurfixTbl(String[] refSentStrs) {
    //== one n-gram table per reference
    ArrayList<HashMap<String, Integer>> listRefNgramTbl =
            new ArrayList<HashMap<String, Integer>>(refSentStrs.length);
    for (String refSentStr : refSentStrs) {
        int[] refWords = this.symbolTable.addTerminals(refSentStr.split("\\s+"));
        HashMap<String, Integer> tRefNgramsTbl = new HashMap<String, Integer>();
        Ngram.getNgrams(tRefNgramsTbl, 1, bleuOrder, refWords);
        listRefNgramTbl.add(tRefNgramsTbl);
    }

    //== merged reference table (presumably the per-n-gram maximum count over all references)
    refNgramsTbl = BLEU.computeMaxRefCountTbl(listRefNgramTbl);

    //== prefix and suffix tbl
    equi.setupPrefixAndSurfixTbl(refNgramsTbl);
}
/**
 * Builds the refined counterpart of {@code originalEdge}.
 * Compared with the original edge, three things change:
 * (1) the antecedent list is replaced by {@code antVirtualItems};
 * (2) a risk cost is attached to the edge (the original model cost is not changed);
 * (3) a feature table is attached.
 *
 * @param originalEdge edge being refined
 * @param antVirtualItems antecedent nodes of the new edge
 * @param dps state used to derive the risk cost when risk annotation is enabled
 * @return a {@link FeatureHyperEdge} carrying the feature table and risk cost
 */
@Override
protected HyperEdge createNewHyperEdge(HyperEdge originalEdge, List<HGNode> antVirtualItems, DPStateOracle dps) {
    //== risk annotation: stays at zero when disabled
    double riskCost = doRiskAnnotation
            ? getRiskTransitionCost(originalEdge, antVirtualItems, dps)//TODO
            : 0;

    //== feature annotation: stays null when disabled
    HashMap<Integer, Double> edgeFeatures = doFeatureAnnotation
            ? featureExtraction(originalEdge, null)//TODO: originalEdge? null parentNode
            : null;

    return new FeatureHyperEdge(
            originalEdge.getRule(),
            originalEdge.bestDerivationLogP,
            originalEdge.getTransitionLogP(false),
            antVirtualItems,
            originalEdge.getSourcePath(),
            edgeFeatures,
            riskCost);
}
/**
 * Computes the risk cost contributed by one edge: the negated difference between the
 * state's {@code bestDerivationLogP} and the {@code bestDerivationLogP} values of all
 * antecedent states.
 * Note: transition_cost is already linearly interpolated.
 *
 * @param originalEdge not used by this method
 * @param antVirtualItems antecedent nodes, each expected to be a {@code RefinedNode}
 *        holding a {@link DPStateOracle}; may be {@code null}
 * @param dps state of the node the edge leads to
 * @return the negated value of dps minus the sum of the antecedent values
 */
private double getRiskTransitionCost(HyperEdge originalEdge, List<HGNode> antVirtualItems, DPStateOracle dps) {
    double delta = dps.bestDerivationLogP;
    if (antVirtualItems == null) {
        return -delta;
    }

    for (HGNode antecedent : antVirtualItems) {
        RefinedNode refined = (RefinedNode) antecedent;//TODO
        DPStateOracle antecedentState = (DPStateOracle) refined.dpState;
        delta -= antecedentState.bestDerivationLogP;
    }
    return -delta;
}
/**
 * State computation for a refined node. The implementation is currently disabled and
 * this method always returns {@code null}.
 *
 * The implementation that used to live here (it was commented out) worked as follows:
 * edges without a rule took their BLEU cost from their single antecedent; all other edges
 * added the antecedents' hypothesis lengths and n-gram match counts to the counts of the
 * new n-grams at the edge (optionally clipped against the reference table) and used the
 * negated BLEU score of those totals as the cost.
 *
 * NOTE(review): getRiskTransitionCost dereferences the state it is given, so a null state
 * reaching it would fail there -- confirm how the parent class uses this return value.
 *
 * @param originalParentItem ignored
 * @param originalEdge ignored
 * @param antVirtualItems ignored
 * @return always {@code null}
 */
protected DPStateOracle computeState(HGNode originalParentItem, HyperEdge originalEdge, List<HGNode> antVirtualItems) {
    return null;
}
//========================================== BLEU related =====================================
//TODO: merge with joshua.decoder.BLEU
/**
 * Computes the linear corpus gain of a hypothesis against a reference n-gram table:
 * {@code thetas[0] * hypLength} plus, for every hypothesis n-gram that also occurs in the
 * reference table, {@code count * thetas[order]}, where the order is the number of tokens
 * obtained by splitting the n-gram key with {@code Regex.spaces}.
 *
 * Speed consideration: only the hypothesis table is iterated, so it is assumed to be the
 * smaller of the two tables.
 *
 * N-grams whose order has no weight (order >= {@code linearCorpusGainThetas.length}) add
 * nothing to the gain; previously they caused an {@code ArrayIndexOutOfBoundsException}.
 *
 * @param linearCorpusGainThetas weights; index 0 weights the hypothesis length, index n
 *        weights matched n-grams of order n
 * @param hypLength hypothesis length
 * @param hypNgramTable n-gram to count table of the hypothesis
 * @param referenceNgramTable n-gram table of the reference; only key presence (with a
 *        non-null value) is checked, the reference counts are not used
 * @return the linear corpus gain
 */
public static double computeLinearCorpusGain(double[] linearCorpusGainThetas, int hypLength, HashMap<String,Integer> hypNgramTable, HashMap<String,Integer> referenceNgramTable) {
    double res = 0;
    res += linearCorpusGainThetas[0] * hypLength;

    for (Entry<String,Integer> entry : hypNgramTable.entrySet()) {
        String key = entry.getKey();
        if (referenceNgramTable.get(key) == null) {
            continue;//delta function: the n-gram does not occur in the reference
        }
        int ngramOrder = Regex.spaces.split(key).length;
        if (ngramOrder >= linearCorpusGainThetas.length) {
            continue;//no weight is defined for this order
        }
        res += entry.getValue() * linearCorpusGainThetas[ngramOrder];
    }
    return res;
}
/**
 * Derives the five linear corpus gain weights and prints them to standard output.
 * Index 0 is {@code -1 / numUnigramTokens}; index i (1..4) is
 * {@code 1 / (4 * numUnigramTokens * unigramPrecision * decayRatio^(i-1))}.
 *
 * @param numUnigramTokens number of unigram tokens
 * @param unigramPrecision unigram precision
 * @param decayRatio ratio applied once more for each additional n-gram order
 * @return a new array of length 5 holding the weights
 */
public static double[] computeLinearCorpusThetas(int numUnigramTokens, double unigramPrecision, double decayRatio) {
    final int numThetas = 5;
    double[] thetas = new double[numThetas];

    thetas[0] = -1.0 / numUnigramTokens;
    int order = 1;
    while (order < numThetas) {
        thetas[order] = 1.0 / (4.0 * numUnigramTokens * unigramPrecision * Math.pow(decayRatio, order - 1));
        order++;
    }

    // Same text as before, assembled once and written in a single call.
    StringBuilder report = new StringBuilder("Thetas are: ");
    for (double theta : thetas) {
        report.append(theta).append(" ");
    }
    report.append("\n");
    System.out.print(report.toString());

    return thetas;
}
//============================================================================================
//==================================== feature extraction function ======================================
//============================================================================================
/**
 * Extracts the feature table of one edge:
 * (1) every feature template adds its feature-name (string) / feature-value pairs to a
 *     shared map, restricted by {@code restrictedFeatureSet};
 * (2) every feature-name is converted to its feature-id (integer) via
 *     {@code featureStringToIntegerMap};
 * (3) the featureStringToIntegerMap may also be used for feature filtering.
 *
 * @param dt edge to extract features from
 * @param parentItem not used by this method
 * @return a new map from feature-id to feature value
 * @throws IllegalStateException if a template produced a feature name that has no id in
 *         {@code featureStringToIntegerMap}; this case used to print a message and call
 *         {@code System.exit(1)}
 */
private final HashMap<Integer, Double> featureExtraction(HyperEdge dt, HGNode parentItem) {
    //=== extract feature counts
    HashMap<String, Double> activeFeaturesHelper = new HashMap<String, Double>();
    double tScale = 1.0;//TODO
    for (FeatureTemplate template : featTemplates) {
        template.getFeatureCounts(dt, activeFeaturesHelper, restrictedFeatureSet, tScale);
    }

    //=== convert the featureString to featureInteger
    HashMap<Integer, Double> res = new HashMap<Integer, Double>();
    for (Map.Entry<String, Double> feature : activeFeaturesHelper.entrySet()) {
        Integer featureID = featureStringToIntegerMap.get(feature.getKey());
        if (featureID == null) {
            // Fail with the offending name instead of terminating the whole JVM.
            throw new IllegalStateException("Null feature ID for feature '" + feature.getKey() + "'");
        }
        res.put(featureID, feature.getValue());
    }
    return res;
}
}