Package edu.umd.hooka.alignment.hmm

Source Code of edu.umd.hooka.alignment.hmm.HMM

package edu.umd.hooka.alignment.hmm;

import java.io.IOException;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;

import edu.umd.hooka.Alignment;
import edu.umd.hooka.AlignmentPosteriorGrid;
import edu.umd.hooka.Array2D;
import edu.umd.hooka.PhrasePair;
import edu.umd.hooka.alignment.PartialCountContainer;
import edu.umd.hooka.alignment.PerplexityReporter;
import edu.umd.hooka.alignment.CrossEntropyCounters;
import edu.umd.hooka.alignment.ZeroProbabilityException;
import edu.umd.hooka.alignment.model1.Model1;
import edu.umd.hooka.ttables.TTable;

/**
* Represents an HMM that applies to a single sentence pair, which is
* derived from the parameters stored in a TTable and an ATable object.
*
* @author redpony
*
*/
public class HMM extends Model1 {
  public static final IntWritable ACOUNT_VOC_ID = new IntWritable(999999);
  static final int MAX_LENGTH = 500;
  static final float THRESH =0.5f;
 
  /**
   * (s,j) = p(f_j|e(s))
   */
  Array2D emission = new Array2D(MAX_LENGTH * MAX_LENGTH);

  /**
   * (s,j) = i s.t. e(s) = e_i or -1 if n.a.
   */
  IntArray2D e_coords = new IntArray2D(MAX_LENGTH * MAX_LENGTH);
 
  /**
   * (s,j) = the english word corresponding to state s
   */
  IntArray2D e_words = new IntArray2D(MAX_LENGTH * MAX_LENGTH);
 
  /**
   * (i',i) = p(i-i')
   */
  Array2D transition = new Array2D(MAX_LENGTH * MAX_LENGTH);
  IntArray2D transition_coords = new IntArray2D(MAX_LENGTH * MAX_LENGTH);

  Array2D alphas = new Array2D(MAX_LENGTH * MAX_LENGTH);
  Array2D betas  = new Array2D(MAX_LENGTH * MAX_LENGTH);

  Array2D viterbi = new Array2D(MAX_LENGTH * MAX_LENGTH);
  IntArray2D backtrace = new IntArray2D(MAX_LENGTH * MAX_LENGTH);
 
  ATable amodel;
  ATable acounts;
 
  int l = -1;
  int m = -1;
  AlignmentPosteriorGrid m1_post = null;
 
  public void setModel1Posteriors(AlignmentPosteriorGrid m1pg) {
    m1_post = m1pg;
  }

  protected HMM(TTable ttable, ATable atable, boolean useNull) {
    super(ttable, useNull);
    amodel = atable;
    acounts = (ATable)amodel.clone(); acounts.clear();   
  }
 
  public HMM(TTable ttable, ATable atable) {
    super(ttable, false);
    amodel = atable;
    acounts = (ATable)amodel.clone(); acounts.clear();
  }
 
  public void writePartialCounts(OutputCollector<IntWritable,PartialCountContainer> output) throws IOException
  {
    super.writePartialCounts(output);
    PartialCountContainer pcc = new PartialCountContainer();
    pcc.setContent(acounts);
    output.collect(ACOUNT_VOC_ID, pcc);
    acounts.clear();
  }
 
  public void buildHMMTables(PhrasePair pp) {
    int[] es = pp.getE().getWords();
    int[] fs = pp.getF().getWords();
    l = es.length;
    m = fs.length;
    emission.resize(m + 1, l + 1);
    e_coords.resize(m + 1, l + 1);
    e_words.resize(m + 1, l + 1);
    e_words.fill(-1);
    e_coords.fill(-1);
    for (int i = 1; i <= l; i++) {
      int ei = es[i-1];
      for (int j = 1; j <= m; j++) {
        int fj = fs[j-1];
        e_coords.set(j, i, i);
        emission.set(j, i, tmodel.get(ei, fj));
        e_words.set(j, i, i - 1);
      }
    }
    //System.out.println("b:\n"+emission);

    transition.resize(l+1, l+1);
    transition_coords.resize(l+1, l+1);
    transition_coords.fill(-1);
    for (int i_prev = 0; i_prev <= l; i_prev++) {
      for (int i = 1; i <= l; i++) {
        transition_coords.set(i_prev, i, amodel.getCoord(i - i_prev, (char)l));
        transition.set(i_prev, i, amodel.get(i - i_prev, (char)l));
      }
    }
   
    //System.out.println("a:\n"+transition);
  }
 
  public final int getNumStates() {
    return transition.getSize2();
  }
 
  public final float getTransitionProb(int s_prev, int s) {
    return transition.get(s_prev, s);
  }
 
  public final float getEmissionProb(int j, int s) {
    return emission.get(j, s);
  }
 
  public final void addPartialJumpCountsToATable(ATable ac) {
    ac.plusEquals(acounts);
  }

  @Override
  public void processTrainingInstance(PhrasePair pp, Reporter r) {
    if (pp.getE().size() >= amodel.getMaxDist()-1) return;
    if (pp.getF().size() >= amodel.getMaxDist()-1) return;
    if (pp.getE().size() == 0) return;
    if (pp.getF().size() == 0) return;

    this.buildHMMTables(pp);
    float totalLogProb = this.baumWelch(pp, null);
    if (r != null) {
      r.incrCounter(CrossEntropyCounters.LOGPROB, (long)(-totalLogProb));
      r.incrCounter(CrossEntropyCounters.WORDCOUNT, pp.getF().size());
    }
  }

  /**
   * @return negative log probability of sentence
   */
  public final float baumWelch(PhrasePair pp, AlignmentPosteriorGrid pg) {

    initializeCountTableForSentencePair(pp);

    int[] obs = pp.getF().getWords();
    int J = obs.length + 1;
    int numStates = getNumStates();
    int l = pp.getE().getWords().length;
    float[] anorms = new float[J];
    alphas.resize(J + 1, getNumStates());
    betas.resize(J + 1, getNumStates());
    alphas.set(0, 0, 1.0f); anorms[0]=1.0f;
    Alignment m1a = null;
    if (m1_post != null)
      m1a = m1_post.alignPosteriorThreshold(THRESH);
    for (int j = 1; j < J; j++) {
      //System.out.println("J="+j);
      for (int s = 0; s < numStates; s++) {
        float alpha = 0.0f;
        float m1boost = 1.0f;
        float m1penalty = 0.0f;
        boolean use_m1 = false;
        if (m1a != null && m1a.isFAligned(j-1)) {
          float m1post = 0.0f;
          use_m1 = true;
          for (int i=0; i<l; i++)
            if (m1a.aligned(j-1, i))
              m1post = m1_post.getAlignmentPointPosterior(j-1, i+1);
          //System.out.println(m1post);
          m1boost = (float)(Math.sqrt(m1post));
          m1penalty = 1.0f - m1boost;
        }
        for (int s_prev = 0; s_prev < numStates; s_prev++) {
          float trans = getTransitionProb(s_prev, s);
          if (use_m1) {
            if (s <= l && s > 0 && m1a.aligned(j-1, s-1))
              trans = m1boost;
            else
              trans *= m1penalty;
          }
          alpha += alphas.get(j - 1, s_prev) * trans;
        }
        alpha *= getEmissionProb(j, s);
        //System.out.println(" ep:" + hmm.getEmissionProb(s, j));
        alphas.set(j, s, alpha);
      }
      //anorms[j] = 1.0f;
      try {
        anorms[j] = alphas.normalizeColumn(j);
      } catch (ZeroProbabilityException ex) {
        this.notifyUnalignablePair(pp, ex.getMessage());
        return 0.0f;
      }
    }
    for (int s=1; s<numStates; s++)
      betas.set(J-1, s, 1.0f);
    for (int j=J-2; j>=1; j--) {
      //System.out.println("J="+j);
      for (int s = 0; s < numStates; s++) {
        float beta = 0.0f;
        float m1boost = 1.0f;
        float m1penalty = 0.0f;
        boolean use_m1 = false;
        if (m1a != null && m1a.isFAligned(j-1)) {
          float m1post = 0.0f;
          use_m1 = true;
          for (int i=0; i<l; i++)
            if (m1a.aligned(j-1, i))
              m1post = m1_post.getAlignmentPointPosterior(j-1, i+1);
          m1boost = (float)(Math.sqrt(m1post));
          m1penalty = 1.0f - m1boost;
        }
        for (int s_next = 0; s_next < numStates; s_next++) {
          //System.out.println("    s_next="+s_next + " b(j+1,s_next)="+ betas.get(j+1, s_next) + " * " +
          //    hmm.getTransitionProb(s, s_next) + " * " + hmm.getEmissionProb(s_next, j));
          float trans = getTransitionProb(s, s_next);
          if (use_m1) {
            if (s <= l && s > 0 && m1a.aligned(j-1, s-1))
              trans = m1boost;
            else
              trans *= m1penalty;
          }
          beta += betas.get(j+1, s_next) *
            trans *
            getEmissionProb(j+1, s_next);
        }
       
        beta /= anorms[j];
        //System.out.println("  s="+s+ "  b:"+beta);
        betas.set(j, s, beta);
      }
    }
   
    // PARTIAL COUNTS FOR EMMISSIONS (WORD TRANSLATION)
    float totalProb[] = new float[J];
    for (int j=1; j<J; j++) {
      float tp = 0.0f;
      for (int s = 0; s < numStates; s++) {
        tp += betas.get(j, s) * alphas.get(j, s);
      }
      // System.out.println("total prob(" + j + ")=" + tp);
      totalProb[j] = tp;
      for (int s = 0; s < numStates; s++) {
        // j=1 s=14
        int iplus1 = e_coords.get(j, s);
        if (iplus1 == -1) continue;
        float pc = betas.get(j, s) * alphas.get(j, s) / tp;
        if (pg != null) {
          int e = 0;
          if (s <= l)
            e = s;
          if (s != 0) {
            float p = pg.getAlignmentPointPosterior(j-1, e) + pc;
            pg.setAlignmentPointPosterior(j-1, e, p);
          }
        } else {
          try {
            addTranslationCount(iplus1, j-1, pc);
          } catch (Exception e) {
            throw new RuntimeException("J=" + J + ", numStates=" + numStates +": Failed to add (" +iplus1+","+(j-1)+") += " + pc + " s=" + s + " pp=" + pp + "\n E:\n"+ e_coords);
          }
        }
        //System.out.println("ec="+ec+" pc="+pc);
      }
    }
   
    // PARTIAL COUNTS FOR TRANSITIONS
    if (pg == null) {
      for (int j=1; j<J-1; j++) {
        for (int s_prev=0; s_prev < numStates; s_prev++) {
          for (int s=0; s < numStates; s++) {
            int tc = transition_coords.get(s_prev, s);
            if (tc == -1) continue;
            float m1boost = 1.0f;
            float m1penalty = 0.0f;
            boolean use_m1 = false;
            if (m1a != null && m1a.isFAligned(j-1)) {
              float m1post = 0.0f;
              use_m1 = true;
              for (int i=0; i<l; i++)
                if (m1a.aligned(j-1, i))
                  m1post = m1_post.getAlignmentPointPosterior(j-1, i+1);
              m1boost = (float)(Math.sqrt(m1post));
              m1penalty = 1.0f - m1boost;
            }
            float trans = getTransitionProb(s_prev, s);
            if (use_m1) {
              if (s <= l && s > 0 && m1a.aligned(j-1, s-1))
                trans = m1boost;
              else
                trans *= m1penalty;
            }
            // SKIPPING: REMOVE!!!
            if (use_m1) continue;
            float pc = alphas.get(j, s_prev)
              * trans
              * emission.get(j+1, s)
              / anorms[j+1]
              * betas.get(j+1, s)
              / totalProb[j+1];
            acounts.add(tc, (char)l, pc);
            //System.out.println("tc="+tc+"  pc="+pc);
          }
        }
      }
    }
   
    float tlp = 0.0f;
    for (float n : anorms)
      tlp += Math.log(n);
    return tlp;
    //System.out.println(acounts);
   
//    System.out.println(alphas + "\n" + betas);
  }

  @Override
  public AlignmentPosteriorGrid computeAlignmentPosteriors(PhrasePair pp) {
    AlignmentPosteriorGrid res = new AlignmentPosteriorGrid(pp);
    buildHMMTables(pp);
    baumWelch(pp, res);
    return res;
  }

  @Override
  public Alignment viterbiAlign(PhrasePair sentence,
      PerplexityReporter reporter) {
    this.buildHMMTables(sentence);
    Alignment res = new Alignment(sentence.getF().size(), sentence.getE().size());
    int J = sentence.getF().size() + 1;
    int numStates = getNumStates();
    viterbi.resize(J, getNumStates());
    backtrace.resize(J, getNumStates());
    viterbi.fill(Float.NEGATIVE_INFINITY);
    viterbi.set(0, 0, 0.0f);
    int lene = sentence.getE().getWords().length;
    Alignment m1a = null;
    if (m1_post != null)
      m1a = m1_post.alignPosteriorThreshold(THRESH);

    //System.out.println(emission);
    for (int j = 1; j < J; j++) {
      //System.out.println("J="+j);
      boolean valid = false;
      for (int s = 1; s < numStates; s++) {
        float best = Float.NEGATIVE_INFINITY;
        int best_s = -1;
        double emitLogProb = Math.log(emission.get(j, s));
        if (emitLogProb == Float.NEGATIVE_INFINITY) {
          //System.out.println("BAD STATE: " + j + " " + s);
          continue;
        }
        //System.out.println("j="+j + " s="+s+ "  ep"+emitLogProb);
        for (int s_prev = 0; s_prev < numStates; s_prev++) {
          float m1boost = 1.0f;
          float m1penalty = 0.0f;
          boolean use_m1 = false;
          if (m1a != null && m1a.isFAligned(j-1)) {
            float m1post = 0.0f;
            use_m1 = true;
            for (int i=0; i<lene; i++) {
              if (m1a.aligned(j-1, i))
                m1post = m1_post.getAlignmentPointPosterior(j-1, i+1);
            }
            m1boost = (float)Math.sqrt(m1post);
            m1penalty = 1.0f - m1boost;
          }
          float trans = getTransitionProb(s_prev, s);
          if (use_m1) {
            if (s <= l && s > 0 && m1a.aligned(j-1, s-1))
              trans = m1boost;
            else
              trans *= m1penalty;
          }
          float cur = (float)(viterbi.get(j - 1, s_prev) +
              Math.log(trans) +
              emitLogProb);
          //System.out.println(" s'="+s_prev + "  cur="+cur);
          if (cur > best) {
            best = cur;
            best_s = s_prev;
            //System.out.println("new best: " + s + " " + best_s);
          }
        }
        //System.out.println(" s_best="+best_s + "  cur="+best);
        viterbi.set(j, s, best);
        if (best != Float.NEGATIVE_INFINITY)
          valid = true;
        backtrace.set(j, s, best_s);
      }
      // if we don't know how to generate some column
      // create a uniform distribution over the states
      // and assume the previous state was the best
      if (!valid) {
        float best = Float.NEGATIVE_INFINITY;
        int bests = -1;
        for (int s = 1; s < numStates; s++) {
          if (viterbi.get(j-1, s) > best) {
            best = viterbi.get(j-1, s);
            bests = s;
          }
        }
        for (int s = 1; s < numStates; s++) {
          viterbi.set(j, s, 0.0f);
          backtrace.set(j, s, bests);
        }
      }
    }
    //System.out.println(viterbi);
    float best = Float.NEGATIVE_INFINITY;
    int best_s = -1;
    for (int s = 1; s < numStates; s++) {
      if (viterbi.get(J-1, s) > best) {
        best = viterbi.get(J-1,s);
        best_s = s;
      }
    }
    //System.out.println("vit: " + best + "j-1="+(J-1));
    reporter.addFactor(best, J - 1);
    //System.out.println(viterbi);
    int e = best_s;
    for (int f=J-1; f>0; f--) {
      if (e <= 0) {
        throw new ZeroProbabilityException("  Error f=" +f+" e="+e+
            "  sentence + \n" + viterbi + "\n" + emission + "\n" + transition + "\n" + backtrace);
      } else {
        if (viterbi.get(f, e) < 0.0) {
          // hack to avoid errors
          try {
            int af = f-1;
            int ae = e_words.get(f, e);
            if (ae >= 0)
              res.align(af, ae);
            //else
            //  System.err.println("ALIGN NULL TO " + af);
          } catch (RuntimeException ex) {
            throw new RuntimeException("Caught " + ex + "\nvit(f,e)="+viterbi.get(f,e)+"  size(f,e)=" + sentence.getF().size() +","+ sentence.getE().size() + " Error f=" +f+" e="+e+
                "  sentence + \n" + viterbi + "\n" + emission + "\n" + transition + "\n" + backtrace + "\n" + e_words);
          }
        }
        e = backtrace.get(f, e);
      }
    }
    return res;
  }
}
TOP

Related Classes of edu.umd.hooka.alignment.hmm.HMM

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.