/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.types;
import java.util.Collection;
import java.util.Iterator;
import cc.mallet.grmm.util.Flops;
import cc.mallet.types.Matrix;
import cc.mallet.types.Matrixn;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.util.Maths;
/**
 * A table factor whose entries are kept as natural logarithms.
 * <p>
 * Every cell of the inherited <tt>probs</tt> matrix holds <tt>log(value)</tt>;
 * consequently a product of factors is an elementwise addition, a quotient is
 * a subtraction, and a sum is a log-sum-exp (via {@link Maths#sumLogProb}).
 * An entry of <tt>Double.NEGATIVE_INFINITY</tt> represents the value 0.
 * <p>
 * The calls into {@link Flops} appear to be operation-count bookkeeping only;
 * they do not influence any value computed here.
 * <p>
 * Created: Jan 4, 2006
 *
 * @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
 * @version $Id: LogTableFactor.java,v 1.1 2007/10/22 21:37:44 mccallum Exp $
 */
public class LogTableFactor extends AbstractTableFactor {

  /**
   * Copy-constructs from any table factor.  The new factor owns a clone of
   * <tt>in</tt>'s log-value matrix, so later changes to either factor do not
   * affect the other.
   *
   * @param in factor to copy
   */
  public LogTableFactor (AbstractTableFactor in)
  {
    super (in);
    probs = (Matrix) in.getLogValueMatrix ().cloneMatrix ();
  }

  /**
   * Creates a factor over a single variable.  Initial entries are whatever
   * the superclass constructor installs.
   *
   * @param var the factor's only variable
   */
  public LogTableFactor (Variable var)
  {
    super (var);
  }

  /**
   * Creates a factor over the given variables.  Initial entries are whatever
   * the superclass constructor installs.
   *
   * @param allVars the factor's variables
   */
  public LogTableFactor (Variable[] allVars)
  {
    super (allVars);
  }

  /**
   * Creates a factor over a collection of variables.  Initial entries are
   * whatever the superclass constructor installs.
   *
   * @param allVars collection of the factor's variables
   */
  public LogTableFactor (Collection allVars)
  {
    super (allVars);
  }

  // Builds a factor directly from log-space entries.
  // Used by makeFromLogValues.
  private LogTableFactor (Variable[] vars, double[] logValues)
  {
    super (vars, logValues);
  }

  // Builds a factor around an already log-space matrix.
  // Used by makeFromMatrix and makeFromLogMatrix.
  private LogTableFactor (Variable[] allVars, Matrix probsIn)
  {
    super (allVars, probsIn);
  }

  //**************************************************************************/
  //  Static factories
  //**************************************************************************/

  /**
   * Builds a factor from ordinary (non-log) values.
   *
   * @param vars variables of the new factor
   * @param vals values in real space; the array is not modified
   * @return a new factor holding <tt>log(vals[i])</tt> for every entry
   */
  public static LogTableFactor makeFromValues (Variable[] vars, double[] vals)
  {
    double[] logVals = new double [vals.length];
    for (int i = 0; i < vals.length; i++) {
      logVals[i] = Math.log (vals[i]);
    }
    // Count the logs, as setValues and makeFromMatrix do.
    Flops.log (vals.length);
    return makeFromLogValues (vars, logVals);
  }

  /**
   * Builds a factor from values that are already logarithms.
   *
   * @param vars variables of the new factor
   * @param vals log-space values
   * @return a new factor over <tt>vars</tt>
   */
  public static LogTableFactor makeFromLogValues (Variable[] vars, double[] vals)
  {
    return new LogTableFactor (vars, vals);
  }

  /**
   * Single-variable convenience form of
   * {@link #makeFromValues(Variable[], double[])}.
   *
   * @param var   the factor's only variable
   * @param vals2 values in real space
   * @return a new factor over <tt>var</tt>
   */
  public static LogTableFactor makeFromValues (Variable var, double[] vals2)
  {
    return makeFromValues (new Variable[]{var}, vals2);
  }

  /**
   * Single-variable convenience form of
   * {@link #makeFromLogValues(Variable[], double[])}.
   *
   * @param v    the factor's only variable
   * @param vals log-space values
   * @return a new factor over <tt>v</tt>
   */
  public static LogTableFactor makeFromLogValues (Variable v, double[] vals)
  {
    return makeFromLogValues (new Variable[]{v}, vals);
  }

  /**
   * Builds a factor from a sparse matrix of real-space values.  The matrix is
   * cloned, and the log is taken of each stored location of the clone.
   *
   * @param vars   variables of the new factor
   * @param values real-space values; not modified
   * @return a new factor backed by the log-transformed clone
   */
  public static LogTableFactor makeFromMatrix (Variable[] vars, SparseMatrixn values)
  {
    SparseMatrixn logValues = (SparseMatrixn) values.cloneMatrix ();
    int numLocs = logValues.numLocations ();
    for (int loc = 0; loc < numLocs; loc++) {
      logValues.setValueAtLocation (loc, Math.log (logValues.valueAtLocation (loc)));
    }
    Flops.log (numLocs);
    return new LogTableFactor (vars, logValues);
  }

  /**
   * Builds a factor from a matrix whose entries are already logarithms.
   *
   * @param vars   variables of the new factor
   * @param values log-space values; cloned, not modified
   * @return a new factor backed by a clone of <tt>values</tt>
   */
  public static LogTableFactor makeFromLogMatrix (Variable[] vars, Matrix values)
  {
    Matrix logValues = (Matrix) values.cloneMatrix ();
    return new LogTableFactor (vars, logValues);
  }

  /**
   * Multiplies together all factors in a collection.
   *
   * @param phis collection of {@link Factor} objects
   * @return a new factor over the union of all the factors' variables, with
   *         every element of <tt>phis</tt> multiplied into it
   */
  public static LogTableFactor multiplyAll (Collection phis)
  {
    /* First pass: gather every variable mentioned by any factor. */
    VarSet allVars = new HashVarSet ();
    for (Iterator it = phis.iterator (); it.hasNext ();) {
      Factor phi = (Factor) it.next ();
      allVars.addAll (phi.varSet ());
    }

    /* Second pass: fold each factor into a fresh factor over that union. */
    LogTableFactor product = new LogTableFactor (allVars);
    for (Iterator it = phis.iterator (); it.hasNext ();) {
      Factor phi = (Factor) it.next ();
      product.multiplyBy (phi);
    }
    return product;
  }

  //**************************************************************************/
  //  Basic factor operations
  //**************************************************************************/

  /** Sets every entry to log(1) = 0, the identity for multiplication. */
  void setAsIdentity ()
  {
    setAll (0.0);
  }

  /** @return a copy of this factor (see the copy constructor) */
  public Factor duplicate ()
  {
    return new LogTableFactor (this);
  }

  /** @return a freshly constructed <tt>LogTableFactor</tt> over <tt>vars</tt> */
  protected AbstractTableFactor createBlankSubset (Variable[] vars)
  {
    return new LogTableFactor (vars);
  }

  /**
   * Normalizes this factor in place by subtracting the log of the total mass
   * from every stored entry.  If the log-sum is below -500 a warning is
   * printed to <tt>System.err</tt>, but normalization proceeds regardless.
   *
   * @return this factor
   */
  public Factor normalize ()
  {
    double logTotal = logspaceOneNorm ();
    if (logTotal < -500) {
      System.err.println ("Attempt to normalize all-0 factor "+this.dumpToString ());
    }
    int numLocs = probs.numLocations ();
    for (int loc = 0; loc < numLocs; loc++) {
      probs.setValueAtLocation (loc, probs.valueAtLocation (loc) - logTotal);
    }
    return this;
  }

  // Accumulates all stored entries with Maths.sumLogProb, starting from
  // -Infinity (0 in log space); i.e. the log of the sum of the real values.
  private double logspaceOneNorm ()
  {
    double logTotal = Double.NEGATIVE_INFINITY; // That's 0 in log space
    int numLocs = probs.numLocations ();
    for (int loc = 0; loc < numLocs; loc++) {
      logTotal = Maths.sumLogProb (logTotal, probs.valueAtLocation (loc));
    }
    Flops.sumLogProb (numLocs);
    return logTotal;
  }

  /** @return the sum of all entries, in real space */
  public double sum ()
  {
    Flops.exp (); // logspaceOneNorm counts rest
    return Math.exp (logspaceOneNorm ());
  }

  /** @return the log of the sum of all entries */
  public double logsum ()
  {
    return logspaceOneNorm ();
  }

  /**
   * Does the conceptual equivalent of this *= pot.
   * Assumes that pot's variables are a subset of
   * this potential's.
   */
  protected void multiplyByInternal (DiscreteFactor ptl)
  {
    int[] projection = largeIdxToSmall (ptl);
    int numLocs = probs.numLocations ();
    for (int loc = 0; loc < numLocs; loc++) {
      double mine = this.probs.valueAtLocation (loc);
      double theirs = ptl.logValue (projection[loc]);
      // Multiplication in real space is addition in log space.
      this.probs.setValueAtLocation (loc, mine + theirs);
    }
    Flops.increment (numLocs); // handle the pluses
  }

  /**
   * Does the conceptual equivalent of this /= pot, destructively.
   * Assumes that this potential has all the variables in pot.
   * Whenever the divisor's log value is infinite the result is forced to
   * -Infinity, so that in particular 0/0 is 0 rather than NaN.
   */
  protected void divideByInternal (DiscreteFactor ptl)
  {
    int[] projection = largeIdxToSmall (ptl);
    int numLocs = probs.numLocations ();
    for (int loc = 0; loc < numLocs; loc++) {
      double mine = this.probs.valueAtLocation (loc);
      double theirs = ptl.logValue (projection[loc]);
      double quotient;
      if (Double.isInfinite (theirs)) {
        /* by convention, let -Inf - (-Inf) (corresponds to 0/0) be -Inf */
        quotient = Double.NEGATIVE_INFINITY;
      } else {
        quotient = mine - theirs;
      }
      this.probs.setValueAtLocation (loc, quotient);
    }
    Flops.increment (numLocs); // handle the minuses
  }

  /**
   * Does the conceptual equivalent of this += pot.
   * Assumes that pot's variables are a subset of
   * this potential's.
   */
  protected void plusEqualsInternal (DiscreteFactor ptl)
  {
    int[] projection = largeIdxToSmall (ptl);
    int numLocs = probs.numLocations ();
    for (int loc = 0; loc < numLocs; loc++) {
      double mine = this.probs.valueAtLocation (loc);
      double theirs = ptl.logValue (projection[loc]);
      this.probs.setValueAtLocation (loc, Maths.sumLogProb (mine, theirs));
    }
    Flops.sumLogProb (numLocs);
  }

  /**
   * Sums this factor down onto the variables of <tt>result</tt>.
   *
   * @param result factor to fill; all of its entries are overwritten
   * @return <tt>result</tt>
   */
  protected Factor marginalizeInternal (AbstractTableFactor result)
  {
    // Start every cell of the marginal at 0 (that is -Infinity in log space).
    result.setAll (Double.NEGATIVE_INFINITY);
    int[] projection = largeIdxToSmall (result);

    /* Add each element of the single array of the large potential
       to the correct element in the small potential. */
    int numLocs = probs.numLocations ();
    for (int largeLoc = 0; largeLoc < numLocs; largeLoc++) {
      /* Convert a single-index from this distribution to
         one for the smaller distribution */
      int smallIdx = projection[largeLoc];

      // NOTE(review): the cell is read by single index but written by
      // location.  These agree only if result's matrix stores every index at
      // the location of the same number (presumably true for the dense matrix
      // produced via createBlankSubset) -- confirm before passing a sparse one.
      double oldValue = this.probs.valueAtLocation (largeLoc);
      double currentValue = result.probs.singleValue (smallIdx);
      result.probs.setValueAtLocation (smallIdx,
                                       Maths.sumLogProb (oldValue, currentValue));
    }
    Flops.sumLogProb (numLocs);
    return result;
  }

  /**
   * Raises every entry to the given power, i.e. multiplies each stored log
   * value by <tt>power</tt>.
   *
   * @param power exponent, in real space
   */
  public void exponentiate (double power)
  {
    Flops.increment (probs.numLocations ());
    probs.timesEquals (power);
  }

  /**
   * Shifts the factor so that the entry selected by <tt>argmax()</tt> becomes
   * log(1) = 0, by subtracting that entry's log value from every stored entry.
   * If that entry is -Infinity (no finite maximum, e.g. an all-zero factor)
   * the factor is left untouched: subtracting would compute -Inf + Inf and
   * fill the table with NaN.
   *
   * @return this factor
   */
  public AbstractTableFactor recenter ()
  {
    //    return (AbstractTableFactor) normalize ();
    int loc = argmax ();
    double lval = probs.valueAtLocation (loc);  // should be refactored
    if (lval == Double.NEGATIVE_INFINITY) {
      return this;
    }
    timesEqualsLog (-lval);
    return this;
  }

  /*
  protected AbstractTableFactor ensureOperandCompatible (AbstractTableFactor ptl)
  {
    if (!(ptl instanceof LogTableFactor)) {
      return new LogTableFactor(ptl);
    } else {
      return ptl;
    }
  }
  */

  //**************************************************************************/
  //  Reading values
  //**************************************************************************/

  /**
   * @param assn assignment that gives an outcome for each of this factor's variables
   * @return the real-space value at <tt>assn</tt>; always 1.0 for a factor
   *         with no variables
   */
  public double value (Assignment assn)
  {
    if (getNumVars () == 0) return 1.0;
    Flops.exp (); // counted only when an exp is actually taken
    return Math.exp (rawValue (assn));
  }

  /** @return the real-space value at the iterator's current assignment */
  public double value (AssignmentIterator it)
  {
    Flops.exp ();
    return Math.exp (rawValue (it.indexOfCurrentAssn ()));
  }

  /** @return the real-space value at single index <tt>idx</tt> */
  public double value (int idx)
  {
    Flops.exp ();
    return Math.exp (rawValue (idx));
  }

  /** @return the log value at the iterator's current assignment */
  public double logValue (AssignmentIterator it)
  {
    return rawValue (it.indexOfCurrentAssn ());
  }

  /** @return the log value at single index <tt>idx</tt> */
  public double logValue (int idx)
  {
    return rawValue (idx);
  }

  /** @return the log value at <tt>assn</tt> */
  public double logValue (Assignment assn)
  {
    return rawValue (assn);
  }

  /** @return the real-space value stored at matrix location <tt>idx</tt> */
  public double valueAtLocation (int idx)
  {
    Flops.exp ();
    return Math.exp (probs.valueAtLocation (idx));
  }

  /**
   * Looks up the stored (log) value for an assignment by reading the outcome
   * of each of this factor's variables from <tt>assn</tt>.
   */
  protected double rawValue (Assignment assn)
  {
    int numVars = getNumVars ();
    int[] indices = new int[numVars];
    for (int i = 0; i < numVars; i++) {
      indices[i] = assn.get (getVariable (i));
    }
    return rawValue (indices);
  }

  // Converts per-variable outcome indices to a single index and looks it up.
  private double rawValue (int[] indices)
  {
    return rawValue (probs.singleIndex (indices));
  }

  /**
   * Looks up the stored (log) value at a single index.  An index that has no
   * location in the matrix yields -Infinity, the log-space default of 0.
   */
  protected double rawValue (int singleIdx)
  {
    int loc = probs.location (singleIdx);
    return (loc < 0) ? Double.NEGATIVE_INFINITY : probs.valueAtLocation (loc);
  }

  /**
   * @return a new matrix holding the real-space values: a clone of the
   *         internal matrix with every stored location exponentiated
   */
  public Matrix getValueMatrix ()
  {
    Matrix values = (Matrix) probs.cloneMatrix ();
    int numLocs = probs.numLocations ();
    for (int loc = 0; loc < numLocs; loc++) {
      values.setValueAtLocation (loc, Math.exp (values.valueAtLocation (loc)));
    }
    Flops.exp (numLocs);
    return values;
  }

  /**
   * @return the internal log-value matrix itself, NOT a copy; changes made
   *         through the returned matrix change this factor
   */
  public Matrix getLogValueMatrix ()
  {
    return probs;
  }

  //**************************************************************************/
  //  Writing values
  //**************************************************************************/

  /** Stores <tt>logValue</tt> (already in log space) at <tt>assn</tt>. */
  public void setLogValue (Assignment assn, double logValue)
  {
    setRawValue (assn, logValue);
  }

  /** Stores <tt>logValue</tt> (already in log space) at the iterator's current assignment. */
  public void setLogValue (AssignmentIterator assnIt, double logValue)
  {
    setRawValue (assnIt, logValue);
  }

  /** Stores <tt>log(value)</tt> at the iterator's current assignment; <tt>value</tt> is in real space. */
  public void setValue (AssignmentIterator assnIt, double value)
  {
    Flops.log ();
    setRawValue (assnIt, Math.log (value));
  }

  /** Stores <tt>vals[i]</tt> (already in log space) at index <tt>i</tt>, for every <tt>i</tt>. */
  public void setLogValues (double[] vals)
  {
    for (int i = 0; i < vals.length; i++) {
      setRawValue (i, vals[i]);
    }
  }

  /** Stores <tt>log(vals[i])</tt> at index <tt>i</tt>, for every <tt>i</tt>; <tt>vals</tt> is in real space. */
  public void setValues (double[] vals)
  {
    Flops.log (vals.length);
    for (int i = 0; i < vals.length; i++) {
      setRawValue (i, Math.log (vals[i]));
    }
  }

  /**
   * Multiplies every entry by a scalar.
   *
   * @param v the multiplier; <b>not</b> in log space
   */
  public void timesEquals (double v)
  {
    Flops.log (); // count the log, as plusEqualsAtLocation does
    timesEqualsLog (Math.log (v));
  }

  // Adds logV to every stored entry, which multiplies the factor by exp(logV).
  // Done in place; no scratch matrix is needed to add a constant.
  private void timesEqualsLog (double logV)
  {
    int numLocs = probs.numLocations ();
    for (int loc = 0; loc < numLocs; loc++) {
      probs.setValueAtLocation (loc, probs.valueAtLocation (loc) + logV);
    }
    Flops.increment (numLocs);
  }

  /**
   * Adds the real-space amount <tt>v</tt> to one entry.
   *
   * @param loc entry to update; passed to {@link #logValue(int)} and
   *            <tt>setRawValue</tt>
   * @param v   amount to add; not in log space
   */
  protected void plusEqualsAtLocation (int loc, double v)
  {
    Flops.log (); Flops.sumLogProb (1);
    double oldVal = logValue (loc);
    setRawValue (loc, Maths.sumLogProb (oldVal, Math.log (v)));
  }

  //**************************************************************************/
  //  Slicing
  //**************************************************************************/

  /**
   * Builds a one-variable factor by varying <tt>var</tt> over all of its
   * outcomes while everything else stays as given in <tt>observed</tt>.
   * Works on a duplicate, so <tt>observed</tt> itself is not modified.
   */
  protected Factor slice_onevar (Variable var, Assignment observed)
  {
    Assignment assn = (Assignment) observed.duplicate ();
    int numOutcomes = var.getNumOutcomes ();
    double[] vals = new double [numOutcomes];
    for (int i = 0; i < numOutcomes; i++) {
      assn.setValue (var, i);
      vals[i] = logValue (assn);
    }
    return LogTableFactor.makeFromLogValues (var, vals);
  }

  /**
   * Builds a two-variable factor by varying <tt>v1</tt> and <tt>v2</tt> over
   * all outcome pairs while everything else stays as given in
   * <tt>observed</tt>.  Works on a duplicate of <tt>observed</tt>.
   */
  protected Factor slice_twovar (Variable v1, Variable v2, Assignment observed)
  {
    Assignment assn = (Assignment) observed.duplicate ();
    int N1 = v1.getNumOutcomes ();
    int N2 = v2.getNumOutcomes ();
    int[] szs = new int[]{N1, N2};
    double[] vals = new double [N1 * N2];
    for (int i = 0; i < N1; i++) {
      assn.setValue (v1, i);
      for (int j = 0; j < N2; j++) {
        assn.setValue (v2, j);
        int idx = Matrixn.singleIndex (szs, new int[]{i, j}); // Inefficient, but much less error prone
        vals[idx] = logValue (assn);
      }
    }
    return LogTableFactor.makeFromLogValues (new Variable[]{v1, v2}, vals);
  }

  /**
   * Builds a factor over those of <tt>vars</tt> that are not assigned by
   * <tt>observed</tt>.  Each entry is this factor's log value at the union of
   * <tt>observed</tt> with the corresponding assignment to the kept variables.
   */
  protected Factor slice_general (Variable[] vars, Assignment observed)
  {
    VarSet toKeep = new HashVarSet (vars);
    toKeep.removeAll (observed.varSet ());
    double[] vals = new double [toKeep.weight ()];

    AssignmentIterator it = toKeep.assignmentIterator ();
    while (it.hasNext ()) {
      Assignment union = Assignment.union (observed, it.assignment ());
      vals[it.indexOfCurrentAssn ()] = logValue (union);
      it.advance ();
    }
    return LogTableFactor.makeFromLogValues (toKeep.toVariableArray (), vals);
  }
}