/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a>
@author Andrew McCallum <a href="mailto:mccallum@cs.umass.edu">mccallum@cs.umass.edu</a>
*/
package cc.mallet.fst;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Random;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Multinomial;
import cc.mallet.types.Sequence;
import cc.mallet.pipe.Pipe;
import cc.mallet.util.MalletLogger;
/** A Hidden Markov Model. */
public class HMM extends Transducer implements Serializable {
private static Logger logger = MalletLogger.getLogger(HMM.class.getName());
static final String LABEL_SEPARATOR = ",";
Alphabet inputAlphabet;
Alphabet outputAlphabet;
ArrayList<State> states = new ArrayList<State>();
ArrayList<State> initialStates = new ArrayList<State>();
HashMap<String, State> name2state = new HashMap<String, State>();
Multinomial.Estimator[] transitionEstimator;
Multinomial.Estimator[] emissionEstimator;
Multinomial.Estimator initialEstimator;
Multinomial[] transitionMultinomial;
Multinomial[] emissionMultinomial;
Multinomial initialMultinomial;
public HMM(Pipe inputPipe, Pipe outputPipe) {
this.inputPipe = inputPipe;
this.outputPipe = outputPipe;
this.inputAlphabet = inputPipe.getDataAlphabet();
this.outputAlphabet = inputPipe.getTargetAlphabet();
}
public HMM(Alphabet inputAlphabet, Alphabet outputAlphabet) {
inputAlphabet.stopGrowth();
logger.info("HMM input dictionary size = " + inputAlphabet.size());
this.inputAlphabet = inputAlphabet;
this.outputAlphabet = outputAlphabet;
}
public Alphabet getInputAlphabet() {
return inputAlphabet;
}
public Alphabet getOutputAlphabet() {
return outputAlphabet;
}
public void print() {
StringBuffer sb = new StringBuffer();
for (int i = 0; i < numStates(); i++) {
State s = (State) getState(i);
sb.append("STATE NAME=\"");
sb.append(s.name);
sb.append("\" (");
sb.append(s.destinations.length);
sb.append(" outgoing transitions)\n");
sb.append(" ");
sb.append("initialWeight= ");
sb.append(s.initialWeight);
sb.append('\n');
sb.append(" ");
sb.append("finalWeight= ");
sb.append(s.finalWeight);
sb.append('\n');
sb.append("Emission distribution:\n" + emissionMultinomial[i]
+ "\n\n");
sb.append("Transition distribution:\n"
+ transitionMultinomial[i].toString());
}
System.out.println(sb.toString());
}
public void addState(String name, double initialWeight, double finalWeight,
String[] destinationNames, String[] labelNames) {
assert (labelNames.length == destinationNames.length);
if (name2state.get(name) != null)
throw new IllegalArgumentException("State with name `" + name
+ "' already exists.");
State s = new State(name, states.size(), initialWeight, finalWeight,
destinationNames, labelNames, this);
s.print();
states.add(s);
if (initialWeight > IMPOSSIBLE_WEIGHT)
initialStates.add(s);
name2state.put(name, s);
}
/**
* Add a state with parameters equal zero, and labels on out-going arcs the
* same name as their destination state names.
*/
public void addState(String name, String[] destinationNames) {
this.addState(name, 0, 0, destinationNames, destinationNames);
}
/**
* Add a group of states that are fully connected with each other, with
* parameters equal zero, and labels on their out-going arcs the same name
* as their destination state names.
*/
public void addFullyConnectedStates(String[] stateNames) {
for (int i = 0; i < stateNames.length; i++)
addState(stateNames[i], stateNames);
}
public void addFullyConnectedStatesForLabels() {
String[] labels = new String[outputAlphabet.size()];
// This is assuming the the entries in the outputAlphabet are Strings!
for (int i = 0; i < outputAlphabet.size(); i++) {
labels[i] = (String) outputAlphabet.lookupObject(i);
}
addFullyConnectedStates(labels);
}
private boolean[][] labelConnectionsIn(InstanceList trainingSet) {
int numLabels = outputAlphabet.size();
boolean[][] connections = new boolean[numLabels][numLabels];
for (Instance instance : trainingSet) {
FeatureSequence output = (FeatureSequence) instance.getTarget();
for (int j = 1; j < output.size(); j++) {
int sourceIndex = outputAlphabet.lookupIndex(output.get(j - 1));
int destIndex = outputAlphabet.lookupIndex(output.get(j));
assert (sourceIndex >= 0 && destIndex >= 0);
connections[sourceIndex][destIndex] = true;
}
}
return connections;
}
/**
* Add states to create a first-order Markov model on labels, adding only
* those transitions the occur in the given trainingSet.
*/
public void addStatesForLabelsConnectedAsIn(InstanceList trainingSet) {
int numLabels = outputAlphabet.size();
boolean[][] connections = labelConnectionsIn(trainingSet);
for (int i = 0; i < numLabels; i++) {
int numDestinations = 0;
for (int j = 0; j < numLabels; j++)
if (connections[i][j])
numDestinations++;
String[] destinationNames = new String[numDestinations];
int destinationIndex = 0;
for (int j = 0; j < numLabels; j++)
if (connections[i][j])
destinationNames[destinationIndex++] = (String) outputAlphabet
.lookupObject(j);
addState((String) outputAlphabet.lookupObject(i), destinationNames);
}
}
/**
* Add as many states as there are labels, but don't create separate weights
* for each source-destination pair of states. Instead have all the incoming
* transitions to a state share the same weights.
*/
public void addStatesForHalfLabelsConnectedAsIn(InstanceList trainingSet) {
int numLabels = outputAlphabet.size();
boolean[][] connections = labelConnectionsIn(trainingSet);
for (int i = 0; i < numLabels; i++) {
int numDestinations = 0;
for (int j = 0; j < numLabels; j++)
if (connections[i][j])
numDestinations++;
String[] destinationNames = new String[numDestinations];
int destinationIndex = 0;
for (int j = 0; j < numLabels; j++)
if (connections[i][j])
destinationNames[destinationIndex++] = (String) outputAlphabet
.lookupObject(j);
addState((String) outputAlphabet.lookupObject(i), 0.0, 0.0,
destinationNames, destinationNames);
}
}
/**
* Add as many states as there are labels, but don't create separate
* observational-test-weights for each source-destination pair of
* states---instead have all the incoming transitions to a state share the
* same observational-feature-test weights. However, do create separate
* default feature for each transition, (which acts as an HMM-style
* transition probability).
*/
public void addStatesForThreeQuarterLabelsConnectedAsIn(
InstanceList trainingSet) {
int numLabels = outputAlphabet.size();
boolean[][] connections = labelConnectionsIn(trainingSet);
for (int i = 0; i < numLabels; i++) {
int numDestinations = 0;
for (int j = 0; j < numLabels; j++)
if (connections[i][j])
numDestinations++;
String[] destinationNames = new String[numDestinations];
int destinationIndex = 0;
for (int j = 0; j < numLabels; j++)
if (connections[i][j]) {
String labelName = (String) outputAlphabet.lookupObject(j);
destinationNames[destinationIndex] = labelName;
// The "transition" weights will include only the default
// feature
// gsc: variable is never used
// String wn = (String)outputAlphabet.lookupObject(i) + "->"
// + (String)outputAlphabet.lookupObject(j);
destinationIndex++;
}
addState((String) outputAlphabet.lookupObject(i), 0.0, 0.0,
destinationNames, destinationNames);
}
}
public void addFullyConnectedStatesForThreeQuarterLabels(
InstanceList trainingSet) {
int numLabels = outputAlphabet.size();
for (int i = 0; i < numLabels; i++) {
String[] destinationNames = new String[numLabels];
for (int j = 0; j < numLabels; j++) {
String labelName = (String) outputAlphabet.lookupObject(j);
destinationNames[j] = labelName;
}
addState((String) outputAlphabet.lookupObject(i), 0.0, 0.0,
destinationNames, destinationNames);
}
}
public void addFullyConnectedStatesForBiLabels() {
String[] labels = new String[outputAlphabet.size()];
// This is assuming the the entries in the outputAlphabet are Strings!
for (int i = 0; i < outputAlphabet.size(); i++) {
labels[i] = outputAlphabet.lookupObject(i).toString();
}
for (int i = 0; i < labels.length; i++) {
for (int j = 0; j < labels.length; j++) {
String[] destinationNames = new String[labels.length];
for (int k = 0; k < labels.length; k++)
destinationNames[k] = labels[j] + LABEL_SEPARATOR
+ labels[k];
addState(labels[i] + LABEL_SEPARATOR + labels[j], 0.0, 0.0,
destinationNames, labels);
}
}
}
/**
* Add states to create a second-order Markov model on labels, adding only
* those transitions the occur in the given trainingSet.
*/
public void addStatesForBiLabelsConnectedAsIn(InstanceList trainingSet) {
int numLabels = outputAlphabet.size();
boolean[][] connections = labelConnectionsIn(trainingSet);
for (int i = 0; i < numLabels; i++) {
for (int j = 0; j < numLabels; j++) {
if (!connections[i][j])
continue;
int numDestinations = 0;
for (int k = 0; k < numLabels; k++)
if (connections[j][k])
numDestinations++;
String[] destinationNames = new String[numDestinations];
String[] labels = new String[numDestinations];
int destinationIndex = 0;
for (int k = 0; k < numLabels; k++)
if (connections[j][k]) {
destinationNames[destinationIndex] = (String) outputAlphabet
.lookupObject(j)
+ LABEL_SEPARATOR
+ (String) outputAlphabet.lookupObject(k);
labels[destinationIndex] = (String) outputAlphabet
.lookupObject(k);
destinationIndex++;
}
addState((String) outputAlphabet.lookupObject(i)
+ LABEL_SEPARATOR
+ (String) outputAlphabet.lookupObject(j), 0.0, 0.0,
destinationNames, labels);
}
}
}
public void addFullyConnectedStatesForTriLabels() {
String[] labels = new String[outputAlphabet.size()];
// This is assuming the the entries in the outputAlphabet are Strings!
for (int i = 0; i < outputAlphabet.size(); i++) {
logger.info("HMM: outputAlphabet.lookup class = "
+ outputAlphabet.lookupObject(i).getClass().getName());
labels[i] = outputAlphabet.lookupObject(i).toString();
}
for (int i = 0; i < labels.length; i++) {
for (int j = 0; j < labels.length; j++) {
for (int k = 0; k < labels.length; k++) {
String[] destinationNames = new String[labels.length];
for (int l = 0; l < labels.length; l++)
destinationNames[l] = labels[j] + LABEL_SEPARATOR
+ labels[k] + LABEL_SEPARATOR + labels[l];
addState(labels[i] + LABEL_SEPARATOR + labels[j]
+ LABEL_SEPARATOR + labels[k], 0.0, 0.0,
destinationNames, labels);
}
}
}
}
public void addSelfTransitioningStateForAllLabels(String name) {
String[] labels = new String[outputAlphabet.size()];
String[] destinationNames = new String[outputAlphabet.size()];
for (int i = 0; i < outputAlphabet.size(); i++) {
labels[i] = outputAlphabet.lookupObject(i).toString();
destinationNames[i] = name;
}
addState(name, 0.0, 0.0, destinationNames, labels);
}
private String concatLabels(String[] labels) {
String sep = "";
StringBuffer buf = new StringBuffer();
for (int i = 0; i < labels.length; i++) {
buf.append(sep).append(labels[i]);
sep = LABEL_SEPARATOR;
}
return buf.toString();
}
private String nextKGram(String[] history, int k, String next) {
String sep = "";
StringBuffer buf = new StringBuffer();
int start = history.length + 1 - k;
for (int i = start; i < history.length; i++) {
buf.append(sep).append(history[i]);
sep = LABEL_SEPARATOR;
}
buf.append(sep).append(next);
return buf.toString();
}
private boolean allowedTransition(String prev, String curr, Pattern no,
Pattern yes) {
String pair = concatLabels(new String[] { prev, curr });
if (no != null && no.matcher(pair).matches())
return false;
if (yes != null && !yes.matcher(pair).matches())
return false;
return true;
}
private boolean allowedHistory(String[] history, Pattern no, Pattern yes) {
for (int i = 1; i < history.length; i++)
if (!allowedTransition(history[i - 1], history[i], no, yes))
return false;
return true;
}
/**
* Assumes that the HMM's output alphabet contains <code>String</code>s.
* Creates an order-<em>n</em> HMM with input predicates and output labels
* given by <code>trainingSet</code> and order, connectivity, and weights
* given by the remaining arguments.
*
* @param trainingSet
* the training instances
* @param orders
* an array of increasing non-negative numbers giving the orders
* of the features for this HMM. The largest number <em>n</em> is
* the Markov order of the HMM. States are <em>n</em>-tuples of
* output labels. Each of the other numbers <em>k</em> in
* <code>orders</code> represents a weight set shared by all
* destination states whose last (most recent) <em>k</em> labels
* agree. If <code>orders</code> is <code>null</code>, an order-0
* HMM is built.
* @param defaults
* If non-null, it must be the same length as <code>orders</code>
* , with <code>true</code> positions indicating that the weight
* set for the corresponding order contains only the weight for a
* default feature; otherwise, the weight set has weights for all
* features built from input predicates.
* @param start
* The label that represents the context of the start of a
* sequence. It may be also used for sequence labels.
* @param forbidden
* If non-null, specifies what pairs of successive labels are not
* allowed, both for constructing <em>n</em>order states or for
* transitions. A label pair (<em>u</em>,<em>v</em>) is not
* allowed if <em>u</em> + "," + <em>v</em> matches
* <code>forbidden</code>.
* @param allowed
* If non-null, specifies what pairs of successive labels are
* allowed, both for constructing <em>n</em>order states or for
* transitions. A label pair (<em>u</em>,<em>v</em>) is allowed
* only if <em>u</em> + "," + <em>v</em> matches
* <code>allowed</code>.
* @param fullyConnected
* Whether to include all allowed transitions, even those not
* occurring in <code>trainingSet</code>,
* @returns The name of the start state.
*
*/
public String addOrderNStates(InstanceList trainingSet, int[] orders,
boolean[] defaults, String start, Pattern forbidden,
Pattern allowed, boolean fullyConnected) {
boolean[][] connections = null;
if (!fullyConnected)
connections = labelConnectionsIn(trainingSet);
int order = -1;
if (defaults != null && defaults.length != orders.length)
throw new IllegalArgumentException(
"Defaults must be null or match orders");
if (orders == null)
order = 0;
else {
for (int i = 0; i < orders.length; i++) {
if (orders[i] <= order)
throw new IllegalArgumentException(
"Orders must be non-negative and in ascending order");
order = orders[i];
}
if (order < 0)
order = 0;
}
if (order > 0) {
int[] historyIndexes = new int[order];
String[] history = new String[order];
String label0 = (String) outputAlphabet.lookupObject(0);
for (int i = 0; i < order; i++)
history[i] = label0;
int numLabels = outputAlphabet.size();
while (historyIndexes[0] < numLabels) {
logger.info("Preparing " + concatLabels(history));
if (allowedHistory(history, forbidden, allowed)) {
String stateName = concatLabels(history);
int nt = 0;
String[] destNames = new String[numLabels];
String[] labelNames = new String[numLabels];
for (int nextIndex = 0; nextIndex < numLabels; nextIndex++) {
String next = (String) outputAlphabet
.lookupObject(nextIndex);
if (allowedTransition(history[order - 1], next,
forbidden, allowed)
&& (fullyConnected || connections[historyIndexes[order - 1]][nextIndex])) {
destNames[nt] = nextKGram(history, order, next);
labelNames[nt] = next;
nt++;
}
}
if (nt < numLabels) {
String[] newDestNames = new String[nt];
String[] newLabelNames = new String[nt];
for (int t = 0; t < nt; t++) {
newDestNames[t] = destNames[t];
newLabelNames[t] = labelNames[t];
}
destNames = newDestNames;
labelNames = newLabelNames;
}
addState(stateName, 0.0, 0.0, destNames, labelNames);
}
for (int o = order - 1; o >= 0; o--)
if (++historyIndexes[o] < numLabels) {
history[o] = (String) outputAlphabet
.lookupObject(historyIndexes[o]);
break;
} else if (o > 0) {
historyIndexes[o] = 0;
history[o] = label0;
}
}
for (int i = 0; i < order; i++)
history[i] = start;
return concatLabels(history);
}
String[] stateNames = new String[outputAlphabet.size()];
for (int s = 0; s < outputAlphabet.size(); s++)
stateNames[s] = (String) outputAlphabet.lookupObject(s);
for (int s = 0; s < outputAlphabet.size(); s++)
addState(stateNames[s], 0.0, 0.0, stateNames, stateNames);
return start;
}
public State getState(String name) {
return (State) name2state.get(name);
}
public int numStates() {
return states.size();
}
public Transducer.State getState(int index) {
return (Transducer.State) states.get(index);
}
public Iterator initialStateIterator() {
return initialStates.iterator();
}
public boolean isTrainable() {
return true;
}
private Alphabet getTransitionAlphabet() {
Alphabet transitionAlphabet = new Alphabet();
for (int i = 0; i < numStates(); i++)
transitionAlphabet.lookupIndex(getState(i).getName(), true);
return transitionAlphabet;
}
@Deprecated
public void reset() {
emissionEstimator = new Multinomial.LaplaceEstimator[numStates()];
transitionEstimator = new Multinomial.LaplaceEstimator[numStates()];
emissionMultinomial = new Multinomial[numStates()];
transitionMultinomial = new Multinomial[numStates()];
Alphabet transitionAlphabet = getTransitionAlphabet();
for (int i = 0; i < numStates(); i++) {
emissionEstimator[i] = new Multinomial.LaplaceEstimator(
inputAlphabet);
transitionEstimator[i] = new Multinomial.LaplaceEstimator(
transitionAlphabet);
emissionMultinomial[i] = new Multinomial(
getUniformArray(inputAlphabet.size()), inputAlphabet);
transitionMultinomial[i] = new Multinomial(
getUniformArray(transitionAlphabet.size()),
transitionAlphabet);
}
initialMultinomial = new Multinomial(getUniformArray(transitionAlphabet
.size()), transitionAlphabet);
initialEstimator = new Multinomial.LaplaceEstimator(transitionAlphabet);
}
/**
* Separate initialization of initial/transitions and emissions. All
* probabilities are proportional to (1+Uniform[0,1])^noise.
*
* @author kedarb
* @param random
* Random object (if null use uniform distribution)
* @param noise
* Noise exponent to use. If zero, then uniform distribution.
*/
public void initTransitions(Random random, double noise) {
Alphabet transitionAlphabet = getTransitionAlphabet();
initialMultinomial = new Multinomial(getRandomArray(transitionAlphabet
.size(), random, noise), transitionAlphabet);
initialEstimator = new Multinomial.LaplaceEstimator(transitionAlphabet);
transitionMultinomial = new Multinomial[numStates()];
transitionEstimator = new Multinomial.LaplaceEstimator[numStates()];
for (int i = 0; i < numStates(); i++) {
transitionMultinomial[i] = new Multinomial(getRandomArray(
transitionAlphabet.size(), random, noise),
transitionAlphabet);
transitionEstimator[i] = new Multinomial.LaplaceEstimator(
transitionAlphabet);
// set state's initial weight
State s = (State) getState(i);
s.setInitialWeight(initialMultinomial.logProbability(s.getName()));
}
}
public void initEmissions(Random random, double noise) {
emissionMultinomial = new Multinomial[numStates()];
emissionEstimator = new Multinomial.LaplaceEstimator[numStates()];
for (int i = 0; i < numStates(); i++) {
emissionMultinomial[i] = new Multinomial(getRandomArray(
inputAlphabet.size(), random, noise), inputAlphabet);
emissionEstimator[i] = new Multinomial.LaplaceEstimator(
inputAlphabet);
}
}
public void estimate() {
Alphabet transitionAlphabet = getTransitionAlphabet();
initialMultinomial = initialEstimator.estimate();
initialEstimator = new Multinomial.LaplaceEstimator(transitionAlphabet);
for (int i = 0; i < numStates(); i++) {
State s = (State) getState(i);
emissionMultinomial[i] = emissionEstimator[i].estimate();
transitionMultinomial[i] = transitionEstimator[i].estimate();
s.setInitialWeight(initialMultinomial.logProbability(s.getName()));
// reset estimators
emissionEstimator[i] = new Multinomial.LaplaceEstimator(
inputAlphabet);
transitionEstimator[i] = new Multinomial.LaplaceEstimator(
transitionAlphabet);
}
}
/**
* Trains a HMM without validation and evaluation.
*/
public boolean train(InstanceList ilist) {
return train(ilist, (InstanceList) null, (InstanceList) null);
}
/**
* Trains a HMM with <tt>evaluator</tt> set to null.
*/
public boolean train(InstanceList ilist, InstanceList validation,
InstanceList testing) {
return train(ilist, validation, testing, (TransducerEvaluator) null);
}
public boolean train(InstanceList ilist, InstanceList validation,
InstanceList testing, TransducerEvaluator eval) {
assert (ilist.size() > 0);
if (emissionEstimator == null) {
emissionEstimator = new Multinomial.LaplaceEstimator[numStates()];
transitionEstimator = new Multinomial.LaplaceEstimator[numStates()];
emissionMultinomial = new Multinomial[numStates()];
transitionMultinomial = new Multinomial[numStates()];
Alphabet transitionAlphabet = new Alphabet();
for (int i = 0; i < numStates(); i++)
transitionAlphabet.lookupIndex(((State) states.get(i))
.getName(), true);
for (int i = 0; i < numStates(); i++) {
emissionEstimator[i] = new Multinomial.LaplaceEstimator(
inputAlphabet);
transitionEstimator[i] = new Multinomial.LaplaceEstimator(
transitionAlphabet);
emissionMultinomial[i] = new Multinomial(
getUniformArray(inputAlphabet.size()), inputAlphabet);
transitionMultinomial[i] = new Multinomial(
getUniformArray(transitionAlphabet.size()),
transitionAlphabet);
}
initialEstimator = new Multinomial.LaplaceEstimator(
transitionAlphabet);
}
for (Instance instance : ilist) {
FeatureSequence input = (FeatureSequence) instance.getData();
FeatureSequence output = (FeatureSequence) instance.getTarget();
new SumLatticeDefault(this, input, output, new Incrementor());
}
initialMultinomial = initialEstimator.estimate();
for (int i = 0; i < numStates(); i++) {
emissionMultinomial[i] = emissionEstimator[i].estimate();
transitionMultinomial[i] = transitionEstimator[i].estimate();
getState(i).setInitialWeight(
initialMultinomial.logProbability(getState(i).getName()));
}
return true;
}
public class Incrementor implements Transducer.Incrementor {
public void incrementFinalState(Transducer.State s, double count) {
}
public void incrementInitialState(Transducer.State s, double count) {
initialEstimator.increment(s.getName(), count);
}
public void incrementTransition(Transducer.TransitionIterator ti,
double count) {
int inputFtr = (Integer) ti.getInput();
State src = (HMM.State) ((TransitionIterator) ti).getSourceState();
State dest = (HMM.State) ((TransitionIterator) ti)
.getDestinationState();
int index = ti.getIndex();
emissionEstimator[index].increment(inputFtr, count);
transitionEstimator[src.getIndex()]
.increment(dest.getName(), count);
}
}
public class WeightedIncrementor implements Transducer.Incrementor {
double weight = 1.0;
public WeightedIncrementor(double wt) {
this.weight = wt;
}
public void incrementFinalState(Transducer.State s, double count) {
}
public void incrementInitialState(Transducer.State s, double count) {
initialEstimator.increment(s.getName(), weight * count);
}
public void incrementTransition(Transducer.TransitionIterator ti,
double count) {
int inputFtr = (Integer) ti.getInput();
State src = (HMM.State) ((TransitionIterator) ti).getSourceState();
State dest = (HMM.State) ((TransitionIterator) ti)
.getDestinationState();
int index = ti.getIndex();
emissionEstimator[index].increment(inputFtr, weight * count);
transitionEstimator[src.getIndex()].increment(dest.getName(),
weight * count);
}
}
public void write(File f) {
try {
ObjectOutputStream oos = new ObjectOutputStream(
new FileOutputStream(f));
oos.writeObject(this);
oos.close();
} catch (IOException e) {
System.err.println("Exception writing file " + f + ": " + e);
}
}
private double[] getUniformArray(int size) {
double[] ret = new double[size];
for (int i = 0; i < size; i++)
// gsc: removing unnecessary cast from 'size'
ret[i] = 1.0 / size;
return ret;
}
// kedarb: p[i] = (1+random)^noise/sum
private double[] getRandomArray(int size, Random random, double noise) {
double[] ret = new double[size];
double sum = 0;
for (int i = 0; i < size; i++) {
ret[i] = random == null ? 1.0 : Math.pow(1.0 + random.nextDouble(),
noise);
sum += ret[i];
}
for (int i = 0; i < size; i++)
ret[i] /= sum;
return ret;
}
// Serialization
// For HMM class
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 1;
static final int NULL_INTEGER = -1;
/* Need to check for null pointers. */
/* Bug fix from Cheng-Ju Kuo cju.kuo@gmail.com */
private void writeObject(ObjectOutputStream out) throws IOException {
int i, size;
out.writeInt(CURRENT_SERIAL_VERSION);
out.writeObject(inputPipe);
out.writeObject(outputPipe);
out.writeObject(inputAlphabet);
out.writeObject(outputAlphabet);
size = states.size();
out.writeInt(size);
for (i = 0; i < size; i++)
out.writeObject(states.get(i));
size = initialStates.size();
out.writeInt(size);
for (i = 0; i < size; i++)
out.writeObject(initialStates.get(i));
out.writeObject(name2state);
if (emissionEstimator != null) {
size = emissionEstimator.length;
out.writeInt(size);
for (i = 0; i < size; i++)
out.writeObject(emissionEstimator[i]);
} else
out.writeInt(NULL_INTEGER);
if (emissionMultinomial != null) {
size = emissionMultinomial.length;
out.writeInt(size);
for (i = 0; i < size; i++)
out.writeObject(emissionMultinomial[i]);
} else
out.writeInt(NULL_INTEGER);
if (transitionEstimator != null) {
size = transitionEstimator.length;
out.writeInt(size);
for (i = 0; i < size; i++)
out.writeObject(transitionEstimator[i]);
} else
out.writeInt(NULL_INTEGER);
if (transitionMultinomial != null) {
size = transitionMultinomial.length;
out.writeInt(size);
for (i = 0; i < size; i++)
out.writeObject(transitionMultinomial[i]);
} else
out.writeInt(NULL_INTEGER);
}
/* Bug fix from Cheng-Ju Kuo cju.kuo@gmail.com */
private void readObject(ObjectInputStream in) throws IOException,
ClassNotFoundException {
int size, i;
int version = in.readInt();
inputPipe = (Pipe) in.readObject();
outputPipe = (Pipe) in.readObject();
inputAlphabet = (Alphabet) in.readObject();
outputAlphabet = (Alphabet) in.readObject();
size = in.readInt();
states = new ArrayList();
for (i = 0; i < size; i++) {
State s = (HMM.State) in.readObject();
states.add(s);
}
size = in.readInt();
initialStates = new ArrayList();
for (i = 0; i < size; i++) {
State s = (HMM.State) in.readObject();
initialStates.add(s);
}
name2state = (HashMap) in.readObject();
size = in.readInt();
if (size == NULL_INTEGER) {
emissionEstimator = null;
} else {
emissionEstimator = new Multinomial.Estimator[size];
for (i = 0; i < size; i++) {
emissionEstimator[i] = (Multinomial.Estimator) in.readObject();
}
}
size = in.readInt();
if (size == NULL_INTEGER) {
emissionMultinomial = null;
} else {
emissionMultinomial = new Multinomial[size];
for (i = 0; i < size; i++) {
emissionMultinomial[i] = (Multinomial) in.readObject();
}
}
size = in.readInt();
if (size == NULL_INTEGER) {
transitionEstimator = null;
} else {
transitionEstimator = new Multinomial.Estimator[size];
for (i = 0; i < size; i++) {
transitionEstimator[i] = (Multinomial.Estimator) in
.readObject();
}
}
size = in.readInt();
if (size == NULL_INTEGER) {
transitionMultinomial = null;
} else {
transitionMultinomial = new Multinomial[size];
for (i = 0; i < size; i++) {
transitionMultinomial[i] = (Multinomial) in.readObject();
}
}
}
public static class State extends Transducer.State implements Serializable {
// Parameters indexed by destination state, feature index
String name;
int index;
double initialWeight, finalWeight;
String[] destinationNames;
State[] destinations;
String[] labels;
HMM hmm;
// No arg constructor so serialization works
protected State() {
super();
}
protected State(String name, int index, double initialWeight,
double finalWeight, String[] destinationNames,
String[] labelNames, HMM hmm) {
super();
assert (destinationNames.length == labelNames.length);
this.name = name;
this.index = index;
this.initialWeight = initialWeight;
this.finalWeight = finalWeight;
this.destinationNames = new String[destinationNames.length];
this.destinations = new State[labelNames.length];
this.labels = new String[labelNames.length];
this.hmm = hmm;
for (int i = 0; i < labelNames.length; i++) {
// Make sure this label appears in our output Alphabet
hmm.outputAlphabet.lookupIndex(labelNames[i]);
this.destinationNames[i] = destinationNames[i];
this.labels[i] = labelNames[i];
}
}
public Transducer getTransducer() {
return hmm;
}
public double getFinalWeight() {
return finalWeight;
}
public double getInitialWeight() {
return initialWeight;
}
public void setFinalWeight(double c) {
finalWeight = c;
}
public void setInitialWeight(double c) {
initialWeight = c;
}
public void print() {
System.out.println("State #" + index + " \"" + name + "\"");
System.out.println("initialWeight=" + initialWeight
+ ", finalWeight=" + finalWeight);
System.out.println("#destinations=" + destinations.length);
for (int i = 0; i < destinations.length; i++)
System.out.println("-> " + destinationNames[i]);
}
public State getDestinationState(int index) {
State ret;
if ((ret = destinations[index]) == null) {
ret = destinations[index] = (State) hmm.name2state
.get(destinationNames[index]);
assert (ret != null) : index;
}
return ret;
}
public Transducer.TransitionIterator transitionIterator(
Sequence inputSequence, int inputPosition,
Sequence outputSequence, int outputPosition) {
if (inputPosition < 0 || outputPosition < 0)
throw new UnsupportedOperationException(
"Epsilon transitions not implemented.");
if (inputSequence == null)
throw new UnsupportedOperationException(
"HMMs are generative models; but this is not yet implemented.");
if (!(inputSequence instanceof FeatureSequence))
throw new UnsupportedOperationException(
"HMMs currently expect Instances to have FeatureSequence data");
return new TransitionIterator(this,
(FeatureSequence) inputSequence, inputPosition,
(outputSequence == null ? null : (String) outputSequence
.get(outputPosition)), hmm);
}
public String getName() {
return name;
}
public int getIndex() {
return index;
}
public void incrementInitialCount(double count) {
}
public void incrementFinalCount(double count) {
}
// Serialization
// For class State
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
private static final int NULL_INTEGER = -1;
private void writeObject(ObjectOutputStream out) throws IOException {
int i, size;
out.writeInt(CURRENT_SERIAL_VERSION);
out.writeObject(name);
out.writeInt(index);
size = (destinationNames == null) ? NULL_INTEGER
: destinationNames.length;
out.writeInt(size);
if (size != NULL_INTEGER) {
for (i = 0; i < size; i++) {
out.writeObject(destinationNames[i]);
}
}
size = (destinations == null) ? NULL_INTEGER : destinations.length;
out.writeInt(size);
if (size != NULL_INTEGER) {
for (i = 0; i < size; i++) {
out.writeObject(destinations[i]);
}
}
size = (labels == null) ? NULL_INTEGER : labels.length;
out.writeInt(size);
if (size != NULL_INTEGER) {
for (i = 0; i < size; i++)
out.writeObject(labels[i]);
}
out.writeObject(hmm);
}
private void readObject(ObjectInputStream in) throws IOException,
ClassNotFoundException {
int size, i;
int version = in.readInt();
name = (String) in.readObject();
index = in.readInt();
size = in.readInt();
if (size != NULL_INTEGER) {
destinationNames = new String[size];
for (i = 0; i < size; i++) {
destinationNames[i] = (String) in.readObject();
}
} else {
destinationNames = null;
}
size = in.readInt();
if (size != NULL_INTEGER) {
destinations = new State[size];
for (i = 0; i < size; i++) {
destinations[i] = (State) in.readObject();
}
} else {
destinations = null;
}
size = in.readInt();
if (size != NULL_INTEGER) {
labels = new String[size];
for (i = 0; i < size; i++)
labels[i] = (String) in.readObject();
// inputAlphabet = (Alphabet) in.readObject();
// outputAlphabet = (Alphabet) in.readObject();
} else {
labels = null;
}
hmm = (HMM) in.readObject();
}
}
protected static class TransitionIterator extends
Transducer.TransitionIterator implements Serializable {
State source;
int index, nextIndex, inputPos;
double[] weights; // -logProb
// Eventually change this because we will have a more space-efficient
// FeatureVectorSequence that cannot break out each FeatureVector
FeatureSequence inputSequence;
Integer inputFeature;
HMM hmm;
public TransitionIterator(State source, FeatureSequence inputSeq,
int inputPosition, String output, HMM hmm) {
this.source = source;
this.hmm = hmm;
this.inputSequence = inputSeq;
this.inputFeature = new Integer(inputSequence
.getIndexAtPosition(inputPosition));
this.inputPos = inputPosition;
this.weights = new double[source.destinations.length];
for (int transIndex = 0; transIndex < source.destinations.length; transIndex++) {
if (output == null || output.equals(source.labels[transIndex])) {
weights[transIndex] = 0;
// xxx should this be emission of the _next_ observation?
// double logEmissionProb =
// hmm.emissionMultinomial[source.getIndex()].logProbability
// (inputSeq.get (inputPosition));
int destIndex = source.getDestinationState(transIndex).getIndex();
double logEmissionProb = hmm.emissionMultinomial[destIndex]
.logProbability(inputSeq.get(inputPosition));
double logTransitionProb = hmm.transitionMultinomial[source
.getIndex()]
.logProbability(source.destinationNames[transIndex]);
// weight = logProbability
weights[transIndex] = (logEmissionProb + logTransitionProb);
assert (!Double.isNaN(weights[transIndex]));
} else
weights[transIndex] = IMPOSSIBLE_WEIGHT;
}
nextIndex = 0;
while (nextIndex < source.destinations.length
&& weights[nextIndex] == IMPOSSIBLE_WEIGHT)
nextIndex++;
}
public boolean hasNext() {
return nextIndex < source.destinations.length;
}
public Transducer.State nextState() {
assert (nextIndex < source.destinations.length);
index = nextIndex;
nextIndex++;
while (nextIndex < source.destinations.length
&& weights[nextIndex] == IMPOSSIBLE_WEIGHT)
nextIndex++;
return source.getDestinationState(index);
}
public int getIndex() {
return index;
}
/*
* Returns an Integer object containing the feature index of the symbol
* at this position in the input sequence.
*/
public Object getInput() {
return inputFeature;
}
// public int getInputPosition () { return inputPos; }
public Object getOutput() {
return source.labels[index];
}
public double getWeight() {
return weights[index];
}
public Transducer.State getSourceState() {
return source;
}
public Transducer.State getDestinationState() {
return source.getDestinationState(index);
}
// Serialization
// TransitionIterator
private static final long serialVersionUID = 1;
private static final int CURRENT_SERIAL_VERSION = 0;
private static final int NULL_INTEGER = -1;
private void writeObject(ObjectOutputStream out) throws IOException {
out.writeInt(CURRENT_SERIAL_VERSION);
out.writeObject(source);
out.writeInt(index);
out.writeInt(nextIndex);
out.writeInt(inputPos);
if (weights != null) {
out.writeInt(weights.length);
for (int i = 0; i < weights.length; i++) {
out.writeDouble(weights[i]);
}
} else {
out.writeInt(NULL_INTEGER);
}
out.writeObject(inputSequence);
out.writeObject(inputFeature);
out.writeObject(hmm);
}
private void readObject(ObjectInputStream in) throws IOException,
ClassNotFoundException {
int version = in.readInt();
source = (State) in.readObject();
index = in.readInt();
nextIndex = in.readInt();
inputPos = in.readInt();
int size = in.readInt();
if (size == NULL_INTEGER) {
weights = null;
} else {
weights = new double[size];
for (int i = 0; i < size; i++) {
weights[i] = in.readDouble();
}
}
inputSequence = (FeatureSequence) in.readObject();
inputFeature = (Integer) in.readObject();
hmm = (HMM) in.readObject();
}
}
}