Package cc.mallet.fst

Source Code of cc.mallet.fst.HMMTrainerByLikelihood

package cc.mallet.fst;

import java.util.logging.Logger;

import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;

public class HMMTrainerByLikelihood extends TransducerTrainer {
  private static Logger logger = MalletLogger
      .getLogger(HMMTrainerByLikelihood.class.getName());

  HMM hmm;
  InstanceList trainingSet, unlabeledSet;
  int iterationCount = 0;
  boolean converged = false;

  public HMMTrainerByLikelihood(HMM hmm) {
    this.hmm = hmm;
  }

  @Override
  public Transducer getTransducer() {
    return hmm;
  }

  @Override
  public int getIteration() {
    return iterationCount;
  }

  @Override
  public boolean isFinishedTraining() {
    return converged;
  }

  @Override
  public boolean train(InstanceList trainingSet, int numIterations) {
    return train(trainingSet, null, numIterations);
  }

  public boolean train(InstanceList trainingSet, InstanceList unlabeledSet,
      int numIterations) {
    if (hmm.emissionEstimator == null)
      hmm.reset();

    converged = false;
    double threshold = 0.001;
    double logLikelihood = Double.NEGATIVE_INFINITY, prevLogLikelihood;
    for (int iter = 0; iter < numIterations; iter++) {
      prevLogLikelihood = logLikelihood;
      logLikelihood = 0;
      for (Instance inst : trainingSet) {
        FeatureSequence input = (FeatureSequence) inst.getData();
        FeatureSequence output = (FeatureSequence) inst.getTarget();
        double obsLikelihood = new SumLatticeDefault(hmm, input,
            output, hmm.new Incrementor()).getTotalWeight();
        logLikelihood += obsLikelihood;
      }
      logger.info("getValue() (observed log-likelihood) = "
          + logLikelihood);

      if (unlabeledSet != null) {
        int numEx = 0;
        for (Instance inst : unlabeledSet) {
          numEx++;
          if (numEx % 100 == 0) {
            System.err.print(numEx + ". ");
            System.err.flush();
          }
          FeatureSequence input = (FeatureSequence) inst.getData();
          double hiddenLikelihood = new SumLatticeDefault(hmm, input,
              null, hmm.new Incrementor()).getTotalWeight();
          logLikelihood += hiddenLikelihood;
        }
        System.err.println();
      }
      logger.info("getValue() (log-likelihood) = " + logLikelihood);

      hmm.estimate();
      iterationCount++;
      logger.info("HMM finished one iteration of maximizer, i=" + iter);

      runEvaluators();

      if (Math.abs(logLikelihood - prevLogLikelihood) < threshold) {
        converged = true;
        logger.info("HMM training has converged, i=" + iter);
        break;
      }
    }

    return converged;
  }
}
TOP

Related Classes of cc.mallet.fst.HMMTrainerByLikelihood

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.