Package joshua.discriminative.monolingual_parser

Source Code of joshua.discriminative.monolingual_parser.EMDecoderFactory

package joshua.discriminative.monolingual_parser;

import java.io.IOException;
import java.util.ArrayList;


import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.ff.tm.Grammar;
import joshua.decoder.ff.tm.GrammarFactory;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.ff.tm.RuleCollection;
import joshua.decoder.ff.tm.Trie;


/** This class should have a way to store the accumulated count, normalize the count,
*  and assign values to the paramemters
* */
public class EMDecoderFactory extends MonolingualDecoderFactory {
 
  private float normConstant = 0;//real semiring; the norm constant for LHS
  private double dataLogProb=0;
  private String  outGrammarFile;
 
 
 
  public EMDecoderFactory(GrammarFactory[] grammar_facories, boolean have_lm_model_, ArrayList<FeatureFunction> l_feat_functions,
      ArrayList<Integer> l_default_nonterminals_, SymbolTable symbolTable, String outGrammarFile_) {
    super(grammar_facories, have_lm_model_, l_feat_functions, l_default_nonterminals_, symbolTable);
    outGrammarFile = outGrammarFile_;   
  }

  @Override
  public MonolingualDecoderThread constructThread(int decoderID, String cur_test_file, int start_sent_id) throws IOException {
    MonolingualDecoderThread pdecoder = new EMDecoderThread(
        this,
        this.p_grammar_factories,
        this.have_lm_model,
        this.p_l_feat_functions,
        this.l_default_nonterminals,
        this.symbolTable,
        cur_test_file,
        start_sent_id
        );
    return pdecoder;
  }
 

  @Override
  public void mergeParallelDecodingResults() throws IOException {
    //do nothing
  }

  @Override
  public void postProcess() throws IOException {
    //== run M step here
    reEstmateGrammars();
    System.out.println("======== Data log prob is " + dataLogProb);
  }

 
  public void accumulateDataLogProb(double exampleLogProb){   
    //data likilihood is a product of each example, so it is a sum in log domain
    dataLogProb += exampleLogProb;   
  }
 
  public void incrementRulePosteriorProb(Rule rl, double prob){
    MonolingualGrammar.incrementRulePosteriorProb(rl, prob);
  }
 
 
  private void reEstmateGrammars(){
    for (GrammarFactory grammarFactory : this.p_grammar_factories) {
      Grammar bathGrammar = grammarFactory.getGrammarForSentence(null);
      accumulatePosteriorCountInGrammar(bathGrammar);
      normalizePosteriorCountInGrammar(bathGrammar);
     
      //TODO: this will *correctly* write the regular grammar, instead of the GLUE grammar, but we should avoid this
      ((MonolingualGrammar) bathGrammar).writeGrammarOnDisk(outGrammarFile, this.symbolTable);
     
      bathGrammar.sortGrammar(this.p_l_feat_functions);
    }
  }
 
 
  private void accumulatePosteriorCountInGrammar(Grammar grammar) {
    normConstant =0;
    accumulatePosteriorCountInTrie(grammar.getTrieRoot());
    System.out.println("normConstant for a grammar is " + normConstant);
  }
 
  private void accumulatePosteriorCountInTrie(Trie trie) {
    if(trie.hasRules()){
      RuleCollection rlCollection = trie.getRules();
      for(Rule rl : rlCollection.getSortedRules()){
        //TODO: LHS specific
        normConstant += MonolingualGrammar.getRulePosteriorProb(rl);
      }
    }
   
    if (trie.hasExtensions()) {
      Object[] tem =  trie.getExtensions().toArray();
      for (int i = 0; i < tem.length; i++) {
        accumulatePosteriorCountInTrie((Trie)tem[i]);
      }
    }
  }
 
 
  private void normalizePosteriorCountInGrammar(Grammar grammar) {
    normalizePosteriorCountInTrie(grammar, grammar.getTrieRoot());
  }
 
  private void normalizePosteriorCountInTrie(Grammar grammar, Trie trie) {
    RuleCollection rlCollection = trie.getRules();
    if(trie.hasRules()){
      for(Rule rl : rlCollection.getSortedRules()){
        //TODO: LHS specific
        double oldVal = MonolingualGrammar.getRuleNormalizedCost(rl);
       
        //==add-lambda smoothing
        float smoothingConstant = 0.1f;
        float prob = (MonolingualGrammar.getRulePosteriorProb(rl)+smoothingConstant)/(normConstant+smoothingConstant*grammar.getNumRules());       
        MonolingualGrammar.setRuleNormalizedCost(rl, prob);
       
        double newVal = MonolingualGrammar.getRuleNormalizedCost(rl);
        if( symbolTable.getWord(rl.getLHS()).compareTo("S")==0 ){
          System.out.println("count: "+ MonolingualGrammar.getRulePosteriorProb(rl) + "; norm: " + normConstant + "; old: " + oldVal + "; new: " + newVal);
          System.out.println(rl.toString(symbolTable));
        }
        MonolingualGrammar.resetRulePosteriorProb(rl);//reset to zero
      }
    }
   
    if (trie.hasExtensions()) {
      Object[] tem = trie.getExtensions().toArray();
      for (int i = 0; i < tem.length; i++) {
        normalizePosteriorCountInTrie(grammar, (Trie)tem[i]);
      }
    }
  }
}
TOP

Related Classes of joshua.discriminative.monolingual_parser.EMDecoderFactory

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.