/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.types;
import java.util.*;
import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.util.Flops;
import cc.mallet.types.*;
import cc.mallet.util.*;
import gnu.trove.TIntArrayList;
import gnu.trove.TDoubleArrayList;
/**
* A static utility class containing utility methods for dealing with factors,
* especially TableFactor objects.
*
* Created: Mar 17, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
* @version $Id: Factors.java,v 1.1 2007/10/22 21:37:44 mccallum Exp $
*/
public class Factors {
/**
 * Normalizes a table factor in place, in log space, and wraps it in a CPT.
 * Every entry has subtracted from its log value the log-sum of all entries that
 * share the same assignment once <tt>var</tt> has been marginalized out.
 *
 * @param ptl The factor to normalize. It is modified by this call.
 * @param var The variable passed to <tt>marginalizeOut</tt> to form the groups
 *   that are normalized, and passed on to the CPT constructor.
 * @return A new CPT constructed from the (modified) <tt>ptl</tt> and <tt>var</tt>.
 */
public static CPT normalizeAsCpt (AbstractTableFactor ptl, Variable var)
{
  // sums[i] holds the log of the total mass of the group whose reduced
  // assignment has single index i; log(0) is the starting value.
  double[] sums = new double [ptl.numLocations ()];
  Arrays.fill (sums, Double.NEGATIVE_INFINITY);

  // Compute normalization factor for each neighbor assignment
  for (AssignmentIterator it = ptl.assignmentIterator (); it.hasNext (); it.advance ()) {
    Assignment assn = it.assignment ();
    Assignment nbrAssn = (Assignment) assn.marginalizeOut (var);
    int idx = nbrAssn.singleIndex ();
    sums[idx] = Maths.sumLogProb (ptl.logValue (assn), sums[idx]);
  }

  // ...and then normalize potential
  for (AssignmentIterator it = ptl.assignmentIterator (); it.hasNext (); it.advance ()) {
    Assignment assn = it.assignment ();
    double oldVal = ptl.logValue (assn);
    Assignment nbrAssn = (Assignment) assn.marginalizeOut (var);
    double logZ = sums[nbrAssn.singleIndex ()];
    if (Double.isInfinite (oldVal) && Double.isInfinite (logZ)) {
      // 0/0 = 0: subtracting two infinite logs would give NaN, so store log(0) instead
      ptl.setLogValue (assn, Double.NEGATIVE_INFINITY);
    } else {
      ptl.setLogValue (assn, oldVal - logZ);
    }
  }

  return new CPT (ptl, var);
}
/**
 * Combines two factors by handing them, together with <tt>weight</tt>, to
 * <tt>TableFactor.hackyMixture</tt>.  Both factors are cast to TableFactor, so
 * passing any other kind of Factor fails with a ClassCastException.
 */
public static Factor average (Factor ptl1, Factor ptl2, double weight)
{
  // complete hack: only TableFactor arguments are supported
  return TableFactor.hackyMixture ((TableFactor) ptl1, (TableFactor) ptl2, weight);
}
/**
 * Sums, over every assignment enumerated by <tt>bel1</tt>, the absolute
 * difference between the values the two factors give to that assignment.
 *
 * @param bel1 First factor; its assignment iterator drives the sum.
 * @param bel2 Second factor; must have a variable set equal to that of <tt>bel1</tt>.
 * @return The accumulated absolute difference.
 * @throws IllegalArgumentException if the two variable sets are not equal.
 */
public static double oneDistance (Factor bel1, Factor bel2)
{
  Set vs1 = bel1.varSet ();
  Set vs2 = bel2.varSet ();
  if (!vs1.equals (vs2)) {
    throw new IllegalArgumentException ("Attempt to take distance between mismatching potentials "+bel1+" and "+bel2);
  }

  double dist = 0;
  for (AssignmentIterator it = bel1.assignmentIterator (); it.hasNext (); it.advance ()) {
    Assignment assn = it.assignment ();
    dist += Math.abs (bel1.value (assn) - bel2.value (assn));
  }
  return dist;
}
/**
 * Builds a sparse TableFactor that keeps only some entries of <tt>ptl</tt>.
 * Entries are visited in the rank order given by a RankedFeatureVector over the
 * factor's log values (presumably largest first -- confirm against
 * RankedFeatureVector), and are kept until their accumulated log mass exceeds
 * <tt>Math.log (alpha)</tt>.  The entry that crosses the threshold is kept too.
 *
 * @param ptl The factor whose entries are ranked.
 * @param alpha The mass threshold, compared in log space.
 * @return A new TableFactor over the variables of <tt>ptl</tt> whose values are
 *   the sparse matrix of retained entries.
 */
public static TableFactor retainMass (DiscreteFactor ptl, double alpha)
{
  // Copy out every (index, log value) pair so that it can be ranked.
  int numLocs = ptl.numLocations ();
  int[] allIdxs = new int [numLocs];
  double[] allVals = new double [numLocs];
  for (int loc = 0; loc < numLocs; loc++) {
    allIdxs[loc] = ptl.indexAtLocation (loc);
    allVals[loc] = ptl.logValue (loc);
  }
  RankedFeatureVector ranked = new RankedFeatureVector (new Alphabet(), allIdxs, allVals);

  TIntArrayList keptIdxs = new TIntArrayList ();
  TDoubleArrayList keptVals = new TDoubleArrayList ();
  double logAlpha = Math.log (alpha);
  double logMass = Double.NEGATIVE_INFINITY;

  // Walk down the ranking, stopping right after the threshold is crossed.
  int numRanked = ranked.numLocations ();
  boolean enough = false;
  int rank = 0;
  while (rank < numRanked && !enough) {
    int idx = ranked.getIndexAtRank (rank);
    double val = ranked.value (idx);
    logMass = Maths.sumLogProb (logMass, val);
    keptIdxs.add (idx);
    keptVals.add (val);
    enough = logMass > logAlpha;
    rank++;
  }

  SparseMatrixn kept =
    new SparseMatrixn (computeSizes (ptl), keptIdxs.toNativeArray (), keptVals.toNativeArray ());
  TableFactor result = new TableFactor (computeVars (ptl));
  result.setValues (kept);
  return result;
}
/**
 * Collects the number of outcomes of each variable of a factor.
 *
 * @param result The factor to inspect.
 * @return An array whose i-th entry is <tt>getNumOutcomes ()</tt> of
 *   <tt>result.getVariable (i)</tt>.
 */
public static int[] computeSizes (Factor result)
{
  int[] sizes = new int [result.varSet ().size()];
  for (int vi = 0; vi < sizes.length; vi++) {
    sizes[vi] = result.getVariable (vi).getNumOutcomes ();
  }
  return sizes;
}
/**
 * Collects the variables of a factor into an array.
 *
 * @param result The factor to inspect.
 * @return An array whose i-th entry is <tt>result.getVariable (i)</tt>.
 */
public static Variable[] computeVars (Factor result)
{
  Variable[] all = new Variable [result.varSet ().size()];
  for (int vi = 0; vi < all.length; vi++) {
    all[vi] = result.getVariable (vi);
  }
  return all;
}
/**
 * Given a joint distribution over two variables, returns their mutual information.
 * Assignments with probability exactly zero contribute nothing to the sum
 * (the usual <tt>0 log 0 = 0</tt> convention).
 *
 * @param factor A joint distribution. Must be normalized, and over exactly two variables.
 * @return The mutual information
 * @throws IllegalArgumentException if the factor is not over exactly two variables.
 */
public static double mutualInformation (Factor factor)
{
  VarSet vs = factor.varSet ();
  if (vs.size() != 2) throw new IllegalArgumentException ("Factor must have size 2");

  Factor marg1 = factor.marginalize (vs.get (0));
  Factor marg2 = factor.marginalize (vs.get (1));

  double result = 0;
  for (Iterator it = factor.assignmentIterator (); it.hasNext(); ) {
    Assignment assn = (Assignment) it.next ();
    double prob = factor.value (assn);
    // A zero entry has an infinite log value, and 0.0 * -Infinity is NaN in
    // Java, which would turn the whole sum into NaN.  Skip such entries.
    if (prob == 0.0) continue;
    result += prob * (factor.logValue (assn) - marg1.logValue (assn) - marg2.logValue (assn));
  }
  return result;
}
/**
 * Accumulates <tt>p * log (p / q)</tt> over the locations of <tt>f1</tt>, where
 * <tt>p</tt> is the value of <tt>f1</tt> at the location and <tt>q</tt> is the
 * value of <tt>f2</tt> at the corresponding index.  Locations where <tt>p</tt>
 * is not greater than 1e-5 are left out of the sum.
 * Assumes both factors are over the same variable set.
 */
public static double KL (AbstractTableFactor f1, AbstractTableFactor f2)
{
  double divergence = 0;
  int numLocs = f1.numLocations ();
  for (int loc = 0; loc < numLocs; loc++) {
    double p = f1.valueAtLocation (loc);
    double q = f2.value (f1.indexAtLocation (loc));
    // ignore entries of f1 that are (nearly) zero
    if (!(p > 1e-5)) continue;
    divergence += p * Math.log (p / q);
  }
  return divergence;
}
/**
 * Returns a new Factor <tt>F = alpha * f1 + (1 - alpha) * f2</tt>.
 * The work is done by <tt>AbstractTableFactor.hackyMixture</tt>.
 */
public static Factor mix (AbstractTableFactor f1, AbstractTableFactor f2, double alpha)
{
  Factor mixture = AbstractTableFactor.hackyMixture (f1, f2, alpha);
  return mixture;
}
/**
 * Square root of the summed squared differences between the value of
 * <tt>f1</tt> at each of its locations and the value of <tt>f2</tt> at the
 * corresponding index.  Assumes both factors are over the same variable set.
 */
public static double euclideanDistance (AbstractTableFactor f1, AbstractTableFactor f2)
{
  double sumSquares = 0;
  int numLocs = f1.numLocations ();
  for (int loc = 0; loc < numLocs; loc++) {
    double delta = f1.valueAtLocation (loc) - f2.value (f1.indexAtLocation (loc));
    sumSquares += delta * delta;
  }
  return Math.sqrt (sumSquares);
}
/**
 * Sum of the absolute differences between the value of <tt>f1</tt> at each of
 * its locations and the value of <tt>f2</tt> at the corresponding index.
 * Assumes both factors are over the same variable set.
 */
public static double l1Distance (AbstractTableFactor f1, AbstractTableFactor f2)
{
  double total = 0;
  int numLocs = f1.numLocations ();
  for (int loc = 0; loc < numLocs; loc++) {
    double delta = f1.valueAtLocation (loc) - f2.value (f1.indexAtLocation (loc));
    total += Math.abs (delta);
  }
  return total;
}
/**
 * Adapter that allows an Inferencer to be treated as if it were a factor.
 * Values and marginals are answered by looking up the matching marginal in the
 * inferencer; <tt>marginalizeOut</tt> and <tt>varSet</tt> are not supported.
 * @param inf An inferencer on which computeMarginals() has been called.
 * @return A factor
 */
public static Factor asFactor (final Inferencer inf)
{
  return new SkeletonFactor () {

    /* Value of the assignment under the marginal over the assignment's own variables. */
    public double value (Assignment assn)
    {
      return inf.lookupMarginal (assn.varSet ()).value (assn);
    }

    public Factor marginalize (Variable[] vars)
    {
      HashVarSet wanted = new HashVarSet (vars);
      return inf.lookupMarginal (wanted);
    }

    public Factor marginalize (Collection vars)
    {
      HashVarSet wanted = new HashVarSet (vars);
      return inf.lookupMarginal (wanted);
    }

    public Factor marginalize (Variable var)
    {
      Variable[] single = new Variable[] { var };
      HashVarSet wanted = new HashVarSet (single);
      return inf.lookupMarginal (wanted);
    }

    /* The remaining operations cannot be answered from the inferencer. */

    public Factor marginalizeOut (Variable var)
    {
      throw new UnsupportedOperationException ();
    }

    public Factor marginalizeOut (VarSet varset)
    {
      throw new UnsupportedOperationException ();
    }

    public VarSet varSet ()
    {
      throw new UnsupportedOperationException ();
    }
  };
}
/**
 * Returns, in variable-set order, the variables of <tt>fg</tt> for which
 * <tt>isContinuous ()</tt> is false.
 */
public static Variable[] discreteVarsOf (Factor fg)
{
  VarSet all = fg.varSet ();
  List discrete = new ArrayList ();
  int numVars = all.size ();
  for (int vi = 0; vi < numVars; vi++) {
    Variable candidate = all.get (vi);
    if (candidate.isContinuous ()) continue;
    discrete.add (candidate);
  }

  Variable[] out = new Variable [discrete.size ()];
  discrete.toArray (out);
  return out;
}
/**
 * Returns, in variable-set order, the variables of <tt>fg</tt> for which
 * <tt>isContinuous ()</tt> is true.
 */
public static Variable[] continuousVarsOf (Factor fg)
{
  VarSet all = fg.varSet ();
  List continuous = new ArrayList ();
  int numVars = all.size ();
  for (int vi = 0; vi < numVars; vi++) {
    Variable candidate = all.get (vi);
    if (!candidate.isContinuous ()) continue;
    continuous.add (candidate);
  }

  Variable[] out = new Variable [continuous.size ()];
  continuous.toArray (out);
  return out;
}
/**
 * Computes <tt>E[XY] - E[X] E[Y]</tt> for a factor over two variables, using
 * the integer outcome indices of the variables as their numeric values.
 * Note that this quantity is a covariance: it is not divided by any
 * standard deviations.
 *
 * @param factor A factor over exactly two variables.
 * @return <tt>E[XY] - E[X] E[Y]</tt> under <tt>factor</tt>.
 * @throws IllegalArgumentException if the factor is not over exactly two variables.
 */
public static double corr (Factor factor)
{
  VarSet vars = factor.varSet ();
  if (vars.size() != 2) {
    throw new IllegalArgumentException ("corr() only works on Factors of size 2, tried "+factor);
  }

  Variable x = factor.varSet ().get (0);
  Variable y = factor.varSet ().get (1);

  // E[XY]: weight the product of the two outcome indices by the joint value
  double expectedProduct = 0.0;
  AssignmentIterator it = factor.assignmentIterator ();
  while (it.hasNext()) {
    Assignment assn = (Assignment) it.next ();
    int xVal = assn.get (x);
    int yVal = assn.get (y);
    expectedProduct += factor.value (assn) * xVal * yVal;
  }

  double expectedX = mean (factor.marginalize (x));
  double expectedY = mean (factor.marginalize (y));
  return expectedProduct - expectedX * expectedY;
}
/**
 * Expected outcome index of a single-variable factor: the sum over all
 * assignments of the factor's value times the variable's integer outcome.
 *
 * @throws IllegalArgumentException if the factor is not over exactly one variable.
 */
private static double mean (Factor factor)
{
  VarSet vars = factor.varSet ();
  if (vars.size() != 1) {
    throw new IllegalArgumentException ("mean() only works on Factors of size 1, tried "+factor);
  }

  Variable only = factor.varSet ().get (0);
  double expectation = 0.0;
  AssignmentIterator it = factor.assignmentIterator ();
  while (it.hasNext()) {
    Assignment assn = (Assignment) it.next ();
    int outcome = assn.get (only);
    expectation += factor.value (assn) * outcome;
  }
  return expectation;
}
/**
 * Multiplies together all the factors in a collection.
 * The result is a duplicate of the first factor (in iteration order) that has
 * been multiplied, via <tt>multiplyBy</tt>, by each of the remaining factors
 * exactly once.  None of the input factors is modified by this method itself.
 * <p>
 * NOTE(review): the product is accumulated in a copy of the first factor;
 * whether <tt>multiplyBy</tt> accepts factors over variables that the first
 * factor lacks is not visible here -- confirm against the Factor implementation.
 *
 * @param factors A non-empty collection of Factor objects.
 * @return A new factor holding the product.
 * @throws NoSuchElementException if <tt>factors</tt> is empty.
 */
public static Factor multiplyAll (Collection factors)
{
  Iterator it = factors.iterator ();

  // next() fails with NoSuchElementException when the collection is empty
  Factor first = (Factor) it.next ();
  Factor result = first.duplicate ();

  // The copy already accounts for the first factor, so only the rest are
  // multiplied in.  Iterating from the start again would apply it twice.
  while (it.hasNext ()) {
    Factor phi = (Factor) it.next ();
    result.multiplyBy (phi);
  }

  return result;
}
/**
 * Largest absolute elementwise difference between the log-value matrices of
 * the two factors, as computed by <tt>matrixDistLinf</tt>.
 */
public static double distLinf (AbstractTableFactor f1, AbstractTableFactor f2)
{
  return matrixDistLinf (f1.getLogValueMatrix (), f2.getLogValueMatrix ());
}
/**
 * Largest absolute elementwise difference between the value matrices of the
 * two factors, as computed by <tt>matrixDistLinf</tt>.
 */
public static double distValueLinf (AbstractTableFactor f1, AbstractTableFactor f2)
{
  return matrixDistLinf (f1.getValueMatrix (), f2.getValueMatrix ());
}
/**
 * Largest absolute difference between entries of the two matrices at the same
 * single index.  Returns positive infinity when the matrices differ in
 * <tt>singleSize ()</tt>, and 0 when no difference exceeds 0.
 */
private static double matrixDistLinf (Matrix m1, Matrix m2)
{
  int size = m1.singleSize ();
  if (size != m2.singleSize ()) {
    return Double.POSITIVE_INFINITY;
  }

  double largest = 0;
  for (int i = 0; i < size; i++) {
    double a = m1.singleValue (i);
    double b = m2.singleValue (i);

    double gap;
    if (a > b) {
      gap = a - b;
    } else {
      gap = b - a;
    }

    if (gap > largest) {
      largest = gap;
    }
  }
  return largest;
}
/**
 * Implements the error range measure from Ihler et al.
 * Over the log-value matrices of the two factors, this is the largest absolute
 * elementwise difference minus the smallest one.  Returns positive infinity
 * when the matrices differ in <tt>singleSize ()</tt>.
 */
public static double logErrorRange (AbstractTableFactor f1, AbstractTableFactor f2)
{
  Matrix logs1 = f1.getLogValueMatrix ();
  Matrix logs2 = f2.getLogValueMatrix ();

  int size = logs1.singleSize ();
  if (size != logs2.singleSize ()) {
    return Double.POSITIVE_INFINITY;
  }

  double smallest = Double.MAX_VALUE;
  double largest = 0;
  for (int i = 0; i < size; i++) {
    double a = logs1.singleValue (i);
    double b = logs2.singleValue (i);

    double gap;
    if (a > b) {
      gap = a - b;
    } else {
      gap = b - a;
    }

    if (gap > largest) {
      largest = gap;
    }
    if (gap < smallest) {
      smallest = gap;
    }
  }
  return largest - smallest;
}
}