package joshua.discriminative.feature_related.feature_function;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.BLEU;
import joshua.decoder.chart_parser.SourcePath;
import joshua.decoder.ff.DefaultStatefulFF;
import joshua.decoder.ff.lm.NgramExtractor;
import joshua.decoder.ff.state_maintenance.DPState;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
import joshua.util.Regex;
import joshua.util.io.LineReader;
import joshua.util.io.Reader;
import joshua.util.io.UncheckedIOException;
/**
 * Stateful feature function that scores each hyperedge by its linear-corpus BLEU gain
 * (via {@code BLEU.computeLinearCorpusGain}) against the reference translation(s) of the
 * sentence being decoded.
 *
 * <p>Reference files are read lazily and strictly in order, one line per sentence from every
 * reference file. The n-gram table for sentence {@code i} is built the first time sentence
 * {@code i} or any later sentence is requested. Sentence IDs are therefore assumed to start
 * at 0 and be consumed without gaps larger than the files. Built tables are cached for the
 * lifetime of this object.
 */
public class BLEUOracleModel extends DefaultStatefulFF {

    private static final Logger logger = Logger.getLogger(BLEUOracleModel.class.getName());

    /** Lowest n-gram order extracted from hypotheses. */
    private final int startNgramOrder = 1;
    /** Highest n-gram order extracted from hypotheses and references. */
    private final int endNgramOrder = 4;
    /** When true, n-gram keys are space-separated symbol IDs instead of surface words. */
    private final boolean useIntegerNgram = true;

    private final SymbolTable symbolTbl;
    private final NgramExtractor ngramExtractor;
    /** One reader per reference file; all are advanced together, one line per sentence. */
    private final Reader<String>[] referenceReaders;
    /** Weights handed to {@code BLEU.computeLinearCorpusGain}. */
    private final double[] linearCorpusGainThetas;

    /** Cache of per-sentence reference n-gram count tables, keyed by sentence ID. */
    Map<Integer, Map<String, Integer>> tblOfReferenceNgramTbls;

    /** Highest sentence ID whose references have been read and cached; -1 before any. */
    private int maxSentIDSoFar = -1;

    /**
     * Creates a model backed by a single reference file.
     *
     * @param ngramStateID id of the n-gram DP state this feature reads
     * @param baselineLMOrder order of the baseline LM; a severe message is logged if it is
     *        smaller than the highest BLEU n-gram order
     * @param featID feature id
     * @param psymbol symbol table used for n-gram extraction and reference conversion
     * @param weight feature weight
     * @param referenceFile path of the reference file (one sentence per line)
     * @param linearCorpusGainThetas weights for the linear corpus gain computation
     */
    public BLEUOracleModel(int ngramStateID, int baselineLMOrder, int featID, SymbolTable psymbol,
            double weight, String referenceFile, double[] linearCorpusGainThetas) {
        this(ngramStateID, baselineLMOrder, featID, psymbol, weight,
                new String[] { referenceFile }, linearCorpusGainThetas);
    }

    /**
     * Creates a model backed by one or more parallel reference files.
     *
     * @param ngramStateID id of the n-gram DP state this feature reads
     * @param baselineLMOrder order of the baseline LM; a severe message is logged if it is
     *        smaller than the highest BLEU n-gram order
     * @param featID feature id
     * @param psymbol symbol table used for n-gram extraction and reference conversion
     * @param weight feature weight
     * @param referenceFiles paths of the reference files, each with one sentence per line
     * @param linearCorpusGainThetas weights for the linear corpus gain computation
     * @throws IllegalArgumentException if {@code referenceFiles} is null or empty
     * @throws UncheckedIOException if a reference file cannot be opened
     */
    public BLEUOracleModel(int ngramStateID, int baselineLMOrder, int featID, SymbolTable psymbol,
            double weight, String[] referenceFiles, double[] linearCorpusGainThetas) {
        super(ngramStateID, weight, featID);
        if (referenceFiles == null || referenceFiles.length == 0) {
            throw new IllegalArgumentException("at least one reference file is required");
        }
        if (baselineLMOrder < endNgramOrder) {
            logger.severe("baselineLMOrder is too small; baselineLMOrder=" + baselineLMOrder);
        }
        this.symbolTbl = psymbol;
        this.ngramExtractor = new NgramExtractor(symbolTbl, ngramStateID, useIntegerNgram, baselineLMOrder);
        this.linearCorpusGainThetas = linearCorpusGainThetas;
        // Arrays.toString so the log shows the values rather than the array identity hash.
        logger.info("linearCorpusGainThetas=" + Arrays.toString(this.linearCorpusGainThetas));

        this.referenceReaders = new LineReader[referenceFiles.length];
        for (int i = 0; i < referenceFiles.length; i++) {
            this.referenceReaders[i] = openOneFile(referenceFiles[i]);
        }
        this.tblOfReferenceNgramTbls = new HashMap<Integer, Map<String, Integer>>();
        logger.info("number of references used is " + referenceReaders.length);
    }

    /** There is no sentence-specific estimation for this feature; always returns 0. */
    public double estimateLogP(Rule rule, int sentID) {
        return 0;
    }

    /** Future-cost estimation is not implemented for this feature; always returns 0. */
    public double estimateFutureLogP(Rule rule, DPState curDPState, int sentID) {
        return 0;
    }

    /** The final (goal) transition contributes no BLEU here; always returns 0. */
    public double finalTransitionLogP(HGNode antNode, int spanStart, int spanEnd, SourcePath srcPath, int sentID) {
        return 0;
    }

    /**
     * Scores a hyperedge by its linear corpus BLEU gain against the references of
     * {@code sentID}, reading those references on demand.
     */
    public double transitionLogP(Rule rule, List<HGNode> antNodes, int spanStart, int spanEnd, SourcePath srcPath, int sentID) {
        return computeTransitionBleu(rule, antNodes, setupReferenceNgramTable(sentID));
    }

    //========================= risk related =============

    /**
     * Returns the reference n-gram count table for {@code sentID}.
     *
     * <p>Any not-yet-read sentences up to and including {@code sentID} are read first. The
     * progress counter is advanced only after a sentence's table is cached. A failed read
     * therefore never leaves an ID marked as loaded without its table.
     *
     * @throws UncheckedIOException if reading a reference file fails
     * @throws IllegalStateException if a reference file runs out of lines
     */
    private synchronized Map<String, Integer> setupReferenceNgramTable(int sentID) {
        while (this.maxSentIDSoFar < sentID) {
            int nextSentID = this.maxSentIDSoFar + 1;
            logger.info("open a new sentence with id " + nextSentID);
            Map<String, Integer> ngramTable;
            try {
                ngramTable = buildReferenceNgramTable(nextSentID);
            } catch (IOException ioe) {
                // Previously System.exit(0): that reported success on failure and made the
                // rethrow unreachable. Propagate with the cause instead.
                logger.severe("read references error");
                throw new UncheckedIOException(ioe);
            }
            this.tblOfReferenceNgramTbls.put(nextSentID, ngramTable);
            this.maxSentIDSoFar = nextSentID;
        }
        return this.tblOfReferenceNgramTbls.get(sentID);
    }

    /**
     * Reads the next line from every reference reader and builds the max-reference-count
     * n-gram table (up to {@code endNgramOrder}) for sentence {@code sentID}.
     */
    private Map<String, Integer> buildReferenceNgramTable(int sentID) throws IOException {
        String[] referenceSentences = new String[referenceReaders.length];
        for (int i = 0; i < referenceReaders.length; i++) {
            String line = referenceReaders[i].readLine();
            if (line == null) {
                throw new IllegalStateException(
                        "reference file " + i + " has no line for sentence " + sentID);
            }
            referenceSentences[i] = line;
        }
        if (this.useIntegerNgram) {
            referenceSentences = convertToIntegerString(referenceSentences);
        }
        return BLEU.constructMaxRefCountTable(referenceSentences, endNgramOrder);
    }

    /**
     * Computes the linear corpus BLEU gain contributed by one hyperedge.
     *
     * @return 0 when {@code rule} is null (hyperedges under the goal item), otherwise the
     *         gain of the hyperedge's n-grams against {@code refNgramTable}
     */
    private double computeTransitionBleu(Rule rule, List<HGNode> antNodes, Map<String, Integer> refNgramTable) {
        if (rule == null) {
            // Hyperedges under the goal item do not contribute BLEU.
            return 0;
        }
        // Target-side length minus the rule's nonterminal slots; presumably the number of
        // terminal words this rule adds — TODO confirm against Rule.getEnglish() encoding.
        int hypLength = rule.getEnglish().length - rule.getArity();
        // This extraction is the most time-consuming step of the feature.
        HashMap<String, Integer> hyperedgeNgramTable =
                ngramExtractor.getTransitionNgrams(rule, antNodes, startNgramOrder, endNgramOrder);
        return BLEU.computeLinearCorpusGain(linearCorpusGainThetas, hypLength, hyperedgeNgramTable, refNgramTable);
    }

    /** Opens {@code file} as a line reader, rethrowing I/O failures unchecked. */
    private LineReader openOneFile(String file) {
        try {
            return new LineReader(file);
        } catch (IOException ioe) {
            throw new UncheckedIOException(ioe);
        }
    }

    /**
     * Converts each whitespace-tokenized sentence into a space-separated string of symbol
     * IDs. New words are registered via {@code symbolTbl.addTerminals}.
     */
    private String[] convertToIntegerString(String[] strSentences) {
        String[] intSentences = new String[strSentences.length];
        for (int j = 0; j < strSentences.length; j++) {
            String[] wrds = Regex.spaces.split(strSentences[j]);
            int[] ids = this.symbolTbl.addTerminals(wrds);
            StringBuilder intSent = new StringBuilder();
            for (int i = 0; i < ids.length; i++) {
                if (i > 0) {
                    intSent.append(' ');
                }
                intSent.append(ids[i]);
            }
            intSentences[j] = intSent.toString();
        }
        return intSentences;
    }
}