Package edu.umd.hooka.alignment.model1

Source Code of edu.umd.hooka.alignment.model1.Model1

package edu.umd.hooka.alignment.model1;

import java.util.Arrays;

import org.apache.hadoop.mapred.Reporter;

import edu.umd.hooka.Alignment;
import edu.umd.hooka.AlignmentPosteriorGrid;
import edu.umd.hooka.PhrasePair;
import edu.umd.hooka.alignment.CrossEntropyCounters;
import edu.umd.hooka.alignment.PerplexityReporter;
import edu.umd.hooka.ttables.TTable;

public class Model1 extends Model1Base {

  float[] totals = null;
  protected TTable tmodel = null;
 
  public Model1(TTable ttable, boolean useNullWord) {
    super(useNullWord); // include Null word
    tmodel = ttable;
  }

  public void clearModel() {
    tmodel = null;
    totals = null;
  }

  @Override
  public void processTrainingInstance(PhrasePair pp, Reporter reporter) {
    initializeCountTableForSentencePair(pp);
    int fw[] = pp.getF().getWords();
    int ew[] = pp.getE().getWords();
    if (totals == null) {
      totals = new float[maxF];
    } else {
      Arrays.fill(totals, 0.0f);
    }
    if (_includeEnglishNullWord) {
      // handle NULL
      for (int fj:fw) {  
        totals[fj] += tmodel.get(0, fj);
      }
    }
    // handle normal e_i's
    for (int ei:ew) {
      for (int fj:fw) {  
        totals[fj] += tmodel.get(ei, fj);
      }
    }
   
    float totalLogProb = 0.0f;
    for (int j=0; j<fw.length; j++) {
      int fj = fw[j];
      float totalProb = totals[fj];
      totalLogProb += Math.log(totalProb);
      for (int i=0; i<=ew.length; i++) {
        int ei = (i > 0) ? ew[i-1] : 0;
        addTranslationCount(i, j, tmodel.get(ei, fj) / totalProb);
      }
    }
    if (reporter != null) {
      totalLogProb -= ((float)fw.length) * Math.log(1.0f + (float)ew.length);
      reporter.incrCounter(CrossEntropyCounters.LOGPROB, (long)(-totalLogProb));
      reporter.incrCounter(CrossEntropyCounters.WORDCOUNT, fw.length);
      reporter.progress();
    }
  }
 
  public AlignmentPosteriorGrid computeAlignmentPosteriors(PhrasePair pp) {
    AlignmentPosteriorGrid res = new AlignmentPosteriorGrid(pp);
    int fw[] = pp.getF().getWords();
    int ew[] = pp.getE().getWords();
    if (totals == null) {
      totals = new float[maxF];
    } else {
      Arrays.fill(totals, 0.0f);
    }
    if (_includeEnglishNullWord) {
      // handle NULL
      for (int fj:fw) {  
        totals[fj] += tmodel.get(0, fj);
      }
    }
    // handle normal e_i's
    for (int ei:ew) {
      for (int fj:fw) {  
        totals[fj] += tmodel.get(ei, fj);
      }
    }
   
    float totalLogProb = 0.0f;
    for (int j=0; j<fw.length; j++) {
      int fj = fw[j];
      float totalProb = totals[fj];
      totalLogProb += Math.log(totalProb);
      int start = 1;
      if (_includeEnglishNullWord)
        start = 0;
      for (int i=start; i<=ew.length; i++) {
        int ei = (i > 0) ? ew[i-1] : 0;
        float post = tmodel.get(ei, fj) / totalProb;
        res.setAlignmentPointPosterior(j, i, post);
      }
    }
    return res;
  }

  @Override
  public Alignment viterbiAlign(PhrasePair sentence, PerplexityReporter viterbiPerp) {
    int[] es = sentence.getE().getWords();
    int[] fs = sentence.getF().getWords();
    float threshold = 0.27f;
    Alignment res = new Alignment(fs.length, es.length);
    AlignmentPosteriorGrid g = computeAlignmentPosteriors(sentence);
    for (int j=0; j<fs.length; j++) {
      for (int i=1; i<es.length; i++) {
        float post = g.getAlignmentPointPosterior(j, i);
        if (post > threshold) { res.align(j, i-1); }
      }
    }
    return res;
  }

  public Alignment realViterbiAlign(PhrasePair sentence, PerplexityReporter viterbiPerp) {
    int[] es = sentence.getE().getWords();
    int[] fs = sentence.getF().getWords();
    Alignment res = new Alignment(fs.length, es.length);
    float viterbiScore = 0.0f;
    for (int j=0; j<fs.length; j++) {
      float bestProb = -1.0f;
      int besti = -1;
      int starti = 0; // TODO - should use NULL?
      if (_includeEnglishNullWord) starti = -1;
      for (int i=starti; i<es.length; i++) {
        float curProb = 0.0f;
        if (i == -1)
          curProb = tmodel.get(0, fs[j]);
        else
          curProb = tmodel.get(es[i], fs[j]);
        if (curProb > bestProb) { bestProb = curProb; besti = i; }
      }
      if (besti < 0) {
        ; //throw new RuntimeException("Implement or ignore!");
      } else {
        res.align(j, besti);
      }
      viterbiScore += Math.log(bestProb);
    }
    viterbiPerp.addFactor(viterbiScore - (fs.length * Math.log(es.length + 1.0)), fs.length);
    return res;
  }

}
TOP

Related Classes of edu.umd.hooka.alignment.model1.Model1

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.