Package joshua.decoder.ff.lm

Source Code of joshua.decoder.ff.lm.LanguageModelFF

/* This file is part of the Joshua Machine Translation System.
*
* Joshua is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1
* of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*/
package joshua.decoder.ff.lm;

import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.chart_parser.SourcePath;
import joshua.decoder.ff.DefaultStatefulFF;
import joshua.decoder.ff.state_maintenance.DPState;
import joshua.decoder.ff.state_maintenance.NgramDPState;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;


/**
* This class performs the following:
* <ol>
* <li> Gets the additional LM score due to combinations of small
*      items into larger ones by using rules
* <li> Gets the LM state
* <li> Gets the left-side LM state estimation score
* </ol>
*
*
* @author Zhifei Li, <zhifei.work@gmail.com>
* @version $LastChangedDate: 2010-01-14 19:15:28 -0600 (Thu, 14 Jan 2010) $
*/
public class LanguageModelFF extends DefaultStatefulFF {
 
  /** Logger for this class. */
  private static final Logger logger = Logger.getLogger(LanguageModelFF.class.getName());
 
  private final String START_SYM="<s>";
  private final int START_SYM_ID;
  private final String STOP_SYM="</s>";
  private final int STOP_SYM_ID;
 
 
  /* These must be static (for now) for LMGrammar, but they shouldn't be! in case of multiple LM features */
  static String BACKOFF_LEFT_LM_STATE_SYM="<lzfbo>";
  static public int BACKOFF_LEFT_LM_STATE_SYM_ID;//used for equivelant state
  static String NULL_RIGHT_LM_STATE_SYM="<lzfrnull>";
  static public int NULL_RIGHT_LM_STATE_SYM_ID;//used for equivelant state
 
  private final boolean addStartAndEndSymbol = true;
 
  /**
   * N-gram language model. We assume the language model is
   * in ARPA format for equivalent state:
   *
   * <ol>
   * <li>We assume it is a backoff lm, and high-order ngram
   *     implies low-order ngram; absense of low-order ngram
   *     implies high-order ngram</li>
   * <li>For a ngram, existence of backoffweight => existence
   *     a probability Two ways of dealing with low counts:
   *     <ul>
   *       <li>SRILM: don't multiply zeros in for unknown
   *           words</li>
   *       <li>Pharaoh: cap at a minimum score exp(-10),
   *           including unknown words</li>
   *     </ul>
   * </li>
   */
  private final NGramLanguageModel lmGrammar;
 
  /**
   * We always use this order of ngram, though the LMGrammar
   * may provide higher order probability.
   */
  private final int ngramOrder;// = 3;
  //boolean add_boundary=false; //this is needed unless the text already has <s> and </s>
 
  /** Symbol table that maps between Strings and integers. */
  private final SymbolTable symbolTable;
 
 
  /** stateID is any integer exept -1
   **/
  public LanguageModelFF(int stateID, int featID, int ngramOrder, SymbolTable psymbol, NGramLanguageModel lmGrammar, double weight) {
   
    super(stateID, weight, featID);
    this.ngramOrder = ngramOrder;
    this.lmGrammar  = lmGrammar;
    this.symbolTable = psymbol;
    this.START_SYM_ID = psymbol.addTerminal(START_SYM);
    this.STOP_SYM_ID = psymbol.addTerminal(STOP_SYM);
   
    LanguageModelFF.BACKOFF_LEFT_LM_STATE_SYM_ID = symbolTable.addTerminal(BACKOFF_LEFT_LM_STATE_SYM);
    LanguageModelFF.NULL_RIGHT_LM_STATE_SYM_ID = symbolTable.addTerminal(NULL_RIGHT_LM_STATE_SYM);
   
    logger.info("LM feature, with an order=" + ngramOrder);
  }
 


  public double transitionLogP(Rule rule, List<HGNode> antNodes, int spanStart, int spanEnd, SourcePath srcPath, int sentID) {
    return computeTransition(rule.getEnglish(), antNodes);
  }

 
  public double finalTransitionLogP(HGNode antNode, int spanStart, int spanEnd, SourcePath srcPath, int sentID) {
    return computeFinalTransitionLogP((NgramDPState)antNode.getDPState(this.getStateID()));
  }
 
 

  /**will consider all the complete ngrams,
   * and all the incomplete-ngrams that will have sth fit into its left side*/
  public double estimateLogP(Rule rule, int sentID) {
    return estimateRuleLogProb(rule.getEnglish());
  }
 


  public double estimateFutureLogP(Rule rule, DPState curDPState, int sentID) {
    //TODO: do not consider <s> and </s>
    boolean addStart = false;
    boolean addEnd = false;
    
    return estimateStateLogProb((NgramDPState)curDPState, addStart, addEnd);
  }



 
  /**when calculate transition prob: when saw a <bo>, then need to add backoff weights, start from non-state words
   * */
  private double computeTransition(int[] enWords,  List<HGNode> antNodes) {
       
    List<Integer> currentNgram   = new ArrayList<Integer>();
    double             transitionLogP = 0.0;
   
    for (int c = 0; c < enWords.length; c++) {
      int curID = enWords[c];
      if (symbolTable.isNonterminal(curID)) {       
        int index = symbolTable.getTargetNonterminalIndex(curID);
     
        NgramDPState state = (NgramDPState) antNodes.get(index).getDPState(this.getStateID());
        List<Integer> leftContext = state.getLeftLMStateWords();
        List<Integer> rightContext = state.getRightLMStateWords();
        if (leftContext.size() != rightContext.size() ) {
          throw new RuntimeException("computeTransition: left and right contexts have unequal lengths");
        }
       
        //================ left context
        for (int i = 0; i < leftContext.size(); i++) {
          int t = leftContext.get(i);
          currentNgram.add(t);
         
          //always calculate logP for <bo>: additional backoff weight
          if (t == BACKOFF_LEFT_LM_STATE_SYM_ID) {
            int numAdditionalBackoffWeight = currentNgram.size() - (i+1);//number of non-state words
           
            //compute additional backoff weight
            transitionLogP  += this.lmGrammar.logProbOfBackoffState(currentNgram, currentNgram.size(), numAdditionalBackoffWeight);
           
            if (currentNgram.size() == this.ngramOrder) {
              currentNgram.remove(0);
            }
          } else if (currentNgram.size() == this.ngramOrder) {
            // compute the current word probablity, and remove it
            transitionLogP += this.lmGrammar.ngramLogProbability(currentNgram, this.ngramOrder);
           
            currentNgram.remove(0);
          }
         
        }
       
        //================  right context
        //note: left_state_org_wrds will never take words from right context because it is either duplicate or out of range
        //also, we will never score the right context probablity because they are either duplicate or partional ngram
        int tSize = currentNgram.size();
        for (int i = 0; i < rightContext.size(); i++) {
          // replace context
          currentNgram.set(tSize - rightContext.size() + i, rightContext.get(i) );
        }
     
      } else {//terminal words
        currentNgram.add(curID);
        if (currentNgram.size() == this.ngramOrder) {
          // compute the current word probablity, and remove it
          transitionLogP += this.lmGrammar.ngramLogProbability(currentNgram, this.ngramOrder);
         
          currentNgram.remove(0);
        }
      }
    }
    //===== create tabl
   
    //===== get left euquiv state
    //double[] lmLeftCost = new double[2];
    //int[] equivLeftState = this.lmGrammar.leftEquivalentState(Support.subIntArray(leftLMStateWrds, 0, leftLMStateWrds.size()),  this.ngramOrder, lmLeftCost);
   
    //transitionCost += lmLeftCost[0];//add finalized cost for the left state words
    return transitionLogP;
  }

  private double computeFinalTransitionLogP(NgramDPState state) {
   
    double res = 0.0;
    List<Integer> currentNgram = new ArrayList<Integer>();
    List<Integer>   leftContext = state.getLeftLMStateWords();   
    List<Integer>   rightContext = state.getRightLMStateWords();
   
    if (leftContext.size() != rightContext.size()) {
      throw new RuntimeException(
        "LMModel.compute_equiv_state_final_transition: left and right contexts have unequal lengths");
    }
   
    //================ left context
    if (addStartAndEndSymbol)
      currentNgram.add(START_SYM_ID);
   
    for (int i = 0; i < leftContext.size(); i++) {
      int t = leftContext.get(i);
      currentNgram.add(t);
     
      if (t == BACKOFF_LEFT_LM_STATE_SYM_ID) {//calculate logP for <bo>: additional backoff weight
        int additionalBackoffWeight = currentNgram.size() - (i+1);
        //compute additional backoff weight
        //TOTO: may not work with the case that add_start_and_end_symbol=false
        res += this.lmGrammar.logProbOfBackoffState(
          currentNgram, currentNgram.size(), additionalBackoffWeight);
       
      } else { // partial ngram
        //compute the current word probablity
        if (currentNgram.size() >= 2) { // start from bigram
          res += this.lmGrammar.ngramLogProbability(
            currentNgram, currentNgram.size());
        }
      }
      if (currentNgram.size() == this.ngramOrder) {
        currentNgram.remove(0);
      }
    }
   
    //================ right context
    //switch context, we will never score the right context probablity because they are either duplicate or partional ngram
    if(addStartAndEndSymbol){
      int tSize = currentNgram.size();
      for (int i = 0; i < rightContext.size(); i++) {//replace context
        currentNgram.set(tSize - rightContext.size() + i, rightContext.get(i));
      }
     
      currentNgram.add(STOP_SYM_ID);
      res += this.lmGrammar.ngramLogProbability(currentNgram, currentNgram.size());
    }
    return res;
  }

 
  /*in general: consider all the complete ngrams, and all the incomplete-ngrams that WILL have sth fit into its left side, so
  *if the left side of incomplete-ngrams is a ECLIPS, then ignore the incomplete-ngrams
  *if the left side of incomplete-ngrams is a Non-Terminal, then consider the incomplete-ngrams 
  *if the left side of incomplete-ngrams is boundary of a rule, then consider the incomplete-ngrams*/
  private double estimateRuleLogProb(int[] enWords) {
    double    estimate   = 0.0;
    boolean   considerIncompleteNgrams = true;
    List<Integer> words      = new ArrayList<Integer>();
    boolean   skipStart = (enWords[0] == START_SYM_ID);
   
    for (int c = 0; c < enWords.length; c++) {
      int curWrd = enWords[c];
      /*if (c_wrd == Symbol.ECLIPS_SYM_ID) {
        estimate += score_chunk(
          words, consider_incomplete_ngrams, skip_start);
        consider_incomplete_ngrams = false;
        //for the LM bonus function: this simply means the right state will not be considered at all because all the ngrams in right-context will be incomplete
        words.clear();
        skip_start = false;
      } else*/ if( symbolTable.isNonterminal(curWrd) ) {
        estimate += scoreChunkLogPwords, considerIncompleteNgrams, skipStart);
        considerIncompleteNgrams = true;
        words.clear();
        skipStart = false;
      } else {
        words.add(curWrd);
      }
    }
    estimate += scoreChunkLogP( words, considerIncompleteNgrams, skipStart );
    return estimate;
  }
 
 
  /**TODO:
   * This does not work when addStart == true or addEnd == true
   **/
  private double estimateStateLogProb(NgramDPState state, boolean addStart, boolean addEnd) {
   
    double res = 0.0;   
    List<Integer>   leftContext = state.getLeftLMStateWords();
   
    if (null != leftContext) {
      List<Integer> words = new ArrayList<Integer>();;
      if (addStart == true)
        words.add(START_SYM_ID);
      words.addAll(leftContext);
     
      boolean considerIncompleteNgrams = true;
      boolean skipStart = true;
      if (words.size() >0 && words.get(0) != START_SYM_ID) {
        skipStart = false;
      }
      res += scoreChunkLogP(words, considerIncompleteNgrams, skipStart);
    }
    /*if (add_start == true) {
      System.out.println("left context: " +Symbol.get_string(l_context) + ";prob "+res);
    }*/
    if (addEnd == true) {//only when add_end is true, we get a complete ngram, otherwise, all ngrams in r_state are incomplete and we should do nothing
      List<Integer>    rightContext = state.getRightLMStateWords();
      List<Integer> list = new ArrayList<Integer>(rightContext);
      list.add(STOP_SYM_ID);
      double tem = scoreChunkLogP(list, false, false);
      res += tem;
      //System.out.println("right context:"+ Symbol.get_string(r_context) + "; score: "  + tem);
    }
    return res;
  }
 


  private double scoreChunkLogP(List<Integer> words, boolean considerIncompleteNgrams, boolean skipStart) {
    if (words.size() <= 0) {
      return 0.0;
    } else {
      int startIndex;
      if (! considerIncompleteNgrams) {
        startIndex = this.ngramOrder;
      } else if (skipStart) {
        startIndex = 2;
      } else {
        startIndex = 1;
      }
     
      return this.lmGrammar.sentenceLogProbability(
        words, this.ngramOrder, startIndex);
    }
  }
 
}

TOP

Related Classes of joshua.decoder.ff.lm.LanguageModelFF

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.