Package org.encog.ml.hmm.train.bw

Source Code of org.encog.ml.hmm.train.bw.BaseBaumWelch

/*
* Encog(tm) Core v3.3 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2014 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*  
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.hmm.train.bw;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.MLSequenceSet;
import org.encog.ml.hmm.HiddenMarkovModel;
import org.encog.ml.hmm.alog.ForwardBackwardCalculator;
import org.encog.ml.hmm.distributions.StateDistribution;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/**
* This class provides the base implementation for Baum-Welch learning for
* HMM's. There are currently two implementations provided.
*
* TrainBaumWelch - Regular Baum Welch Learning.
*
* TrainBaumWelchScaled - Regular Baum Welch Learning, which can handle
* underflows in long sequences.
*
* L. E. Baum, T. Petrie, G. Soules, and N. Weiss,
* "A maximization technique occurring in the statistical analysis of probabilistic functions of Markov chains"
* , Ann. Math. Statist., vol. 41, no. 1, pp. 164-171, 1970.
*
* Hidden Markov Models and the Baum-Welch Algorithm, IEEE Information Theory
* Society Newsletter, Dec. 2003.
*
*/
public abstract class BaseBaumWelch implements MLTrain {
  private int iterations;
  private HiddenMarkovModel method;
  private final MLSequenceSet training;

  public BaseBaumWelch(final HiddenMarkovModel hmm,
      final MLSequenceSet training) {
    this.method = hmm;
    this.training = training;
  }

  @Override
  public void addStrategy(final Strategy strategy) {

  }

  @Override
  public boolean canContinue() {
    return false;
  }

  protected double[][] estimateGamma(final double[][][] xi,
      final ForwardBackwardCalculator fbc) {
    final double[][] gamma = new double[xi.length + 1][xi[0].length];

    for (int t = 0; t < (xi.length + 1); t++) {
      Arrays.fill(gamma[t], 0.);
    }

    for (int t = 0; t < xi.length; t++) {
      for (int i = 0; i < xi[0].length; i++) {
        for (int j = 0; j < xi[0].length; j++) {
          gamma[t][i] += xi[t][i][j];
        }
      }
    }

    for (int j = 0; j < xi[0].length; j++) {
      for (int i = 0; i < xi[0].length; i++) {
        gamma[xi.length][j] += xi[xi.length - 1][i][j];
      }
    }

    return gamma;
  }

  public abstract double[][][] estimateXi(MLDataSet sequence,
      ForwardBackwardCalculator fbc, HiddenMarkovModel hmm);

  @Override
  public void finishTraining() {

  }

  public abstract ForwardBackwardCalculator generateForwardBackwardCalculator(
      MLDataSet sequence, HiddenMarkovModel hmm);

  @Override
  public double getError() {
    return 0;
  }

  @Override
  public TrainingImplementationType getImplementationType() {
    return TrainingImplementationType.Iterative;
  }

  @Override
  public int getIteration() {
    return this.iterations;
  }

  @Override
  public MLMethod getMethod() {
    return this.method;
  }

  @Override
  public List<Strategy> getStrategies() {
    return null;
  }

  @Override
  public MLDataSet getTraining() {
    return this.training;
  }

  @Override
  public boolean isTrainingDone() {
    return false;
  }

  @Override
  public void iteration() {
    HiddenMarkovModel nhmm;
    try {
      nhmm = this.method.clone();
    } catch (final CloneNotSupportedException e) {
      throw new InternalError();
    }

    final double allGamma[][][] = new double[this.training
        .getSequenceCount()][][];
    final double aijNum[][] = new double[this.method.getStateCount()][this.method
        .getStateCount()];
    final double aijDen[] = new double[this.method.getStateCount()];

    Arrays.fill(aijDen, 0.0);
    for (int i = 0; i < this.method.getStateCount(); i++) {
      Arrays.fill(aijNum[i], 0.);
    }

    int g = 0;
    for (final MLDataSet obsSeq : this.training.getSequences()) {
      final ForwardBackwardCalculator fbc = generateForwardBackwardCalculator(
          obsSeq, this.method);

      final double xi[][][] = estimateXi(obsSeq, fbc, this.method);
      final double gamma[][] = allGamma[g++] = estimateGamma(xi, fbc);

      for (int i = 0; i < this.method.getStateCount(); i++) {
        for (int t = 0; t < (obsSeq.size() - 1); t++) {
          aijDen[i] += gamma[t][i];

          for (int j = 0; j < this.method.getStateCount(); j++) {
            aijNum[i][j] += xi[t][i][j];
          }
        }
      }
    }

    for (int i = 0; i < this.method.getStateCount(); i++) {
      if (aijDen[i] == 0.0) {
        for (int j = 0; j < this.method.getStateCount(); j++) {
          nhmm.setTransitionProbability(i, j,
              this.method.getTransitionProbability(i, j));
        }
      } else {
        for (int j = 0; j < this.method.getStateCount(); j++) {
          nhmm.setTransitionProbability(i, j, aijNum[i][j]
              / aijDen[i]);
        }
      }
    }

    /* compute pi */
    for (int i = 0; i < this.method.getStateCount(); i++) {
      nhmm.setPi(i, 0.);
    }

    for (int o = 0; o < this.training.getSequenceCount(); o++) {
      for (int i = 0; i < this.method.getStateCount(); i++) {
        nhmm.setPi(
            i,
            nhmm.getPi(i)
                + (allGamma[o][0][i] / this.training
                    .getSequenceCount()));
      }
    }

    /* compute pdfs */
    for (int i = 0; i < this.method.getStateCount(); i++) {

      final double[] weights = new double[this.training.size()];
      double sum = 0.;
      int j = 0;

      int o = 0;
      for (final MLDataSet obsSeq : this.training.getSequences()) {
        for (int t = 0; t < obsSeq.size(); t++, j++) {
          sum += weights[j] = allGamma[o][t][i];
        }
        o++;
      }

      for (j--; j >= 0; j--) {
        weights[j] /= sum;
      }

      final StateDistribution opdf = nhmm.getStateDistribution(i);
      opdf.fit(this.training, weights);
    }

    this.method = nhmm;
  }

  @Override
  public void iteration(final int count) {
    for (int i = 0; i < count; i++) {
      iteration();
    }
  }

  @Override
  public TrainingContinuation pause() {
    return null;
  }

  @Override
  public void resume(final TrainingContinuation state) {

  }

  @Override
  public void setError(final double error) {

  }

  @Override
  public void setIteration(final int iteration) {
    this.iterations = iteration;
  }
}
TOP

Related Classes of org.encog.ml.hmm.train.bw.BaseBaumWelch

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.