// Package: cc.mallet.grmm.types

// Source code of cc.mallet.grmm.types.LogTableFactor

/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.types;


import java.util.Collection;
import java.util.Iterator;

import cc.mallet.grmm.util.Flops;
import cc.mallet.types.Matrix;
import cc.mallet.types.Matrixn;
import cc.mallet.types.SparseMatrixn;
import cc.mallet.util.Maths;

/**
* Created: Jan 4, 2006
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: LogTableFactor.java,v 1.1 2007/10/22 21:37:44 mccallum Exp $
*/
public class LogTableFactor extends AbstractTableFactor {

  public LogTableFactor (AbstractTableFactor in)
  {
    super (in);
    probs = (Matrix) in.getLogValueMatrix ().cloneMatrix ();
  }

  public LogTableFactor (Variable var)
  {
    super (var);
  }

  public LogTableFactor (Variable[] allVars)
  {
    super (allVars);
  }

  public LogTableFactor (Collection allVars)
  {
    super (allVars);
  }

  // Create from
  //  Used by makeFromLogFactorValues
  private LogTableFactor (Variable[] vars, double[] logValues)
  {
    super (vars, logValues);
  }

  private LogTableFactor (Variable[] allVars, Matrix probsIn)
  {
    super (allVars, probsIn);
  }

  //**************************************************************************/

  public static LogTableFactor makeFromValues (Variable[] vars, double[] vals)
  {
    double[] vals2 = new double [vals.length];
    for (int i = 0; i < vals.length; i++) {
      vals2[i] = Math.log (vals[i]);
    }
    return makeFromLogValues (vars, vals2);
  }

  public static LogTableFactor makeFromLogValues (Variable[] vars, double[] vals)
  {
    return new LogTableFactor (vars, vals);
  }

  //**************************************************************************/

  void setAsIdentity ()
  {
    setAll (0.0);
  }

  public Factor duplicate ()
  {
    return new LogTableFactor (this);
  }

  protected AbstractTableFactor createBlankSubset (Variable[] vars)
  {
    return new LogTableFactor (vars);
  }

  public Factor normalize ()
  {
    double sum = logspaceOneNorm ();
    if (sum < -500)
      System.err.println ("Attempt to normalize all-0 factor "+this.dumpToString ());
   
    for (int i = 0; i < probs.numLocations (); i++) {
      double val = probs.valueAtLocation (i);
      probs.setValueAtLocation (i, val - sum);
    }
    return this;
  }

  private double logspaceOneNorm ()
  {
    double sum = Double.NEGATIVE_INFINITY; // That's 0 in log space
    for (int i = 0; i < probs.numLocations (); i++) {
      sum = Maths.sumLogProb (sum, probs.valueAtLocation (i));
    }
    Flops.sumLogProb (probs.numLocations ());
    return sum;
  }

  public double sum ()
  {
    Flops.exp ()// logspaceOneNorm counts rest
    return Math.exp (logspaceOneNorm ());
  }

  public double logsum ()
  {
    return logspaceOneNorm ();
  }

  /**
   * Does the conceptual equivalent of this *= pot.
   * Assumes that pot's variables are a subset of
   * this potential's.
   */
  protected void multiplyByInternal (DiscreteFactor ptl)
  {
    int[] projection = largeIdxToSmall (ptl);
    int numLocs = probs.numLocations ();
    for (int singleLoc = 0; singleLoc < numLocs; singleLoc++) {
      int smallIdx = projection[singleLoc];
      double prev = this.probs.valueAtLocation (singleLoc);
      double newVal = ptl.logValue (smallIdx);
      double product = prev + newVal;
      this.probs.setValueAtLocation (singleLoc, product);
    }
    Flops.increment (numLocs)// handle the pluses
  }

  // Does destructive divison on this, assuming this has all
// the variables in pot.
  protected void divideByInternal (DiscreteFactor ptl)
  {
    int[] projection = largeIdxToSmall (ptl);
    int numLocs = probs.numLocations ();
    for (int singleLoc = 0; singleLoc < numLocs; singleLoc++) {
      int smallIdx = projection[singleLoc];
      double prev = this.probs.valueAtLocation (singleLoc);
      double newVal = ptl.logValue (smallIdx);
      double product = prev - newVal;
      /* by convention, let -Inf + Inf (corresponds to 0/0) be -Inf */
      if (Double.isInfinite (newVal)) {
        product = Double.NEGATIVE_INFINITY;
      }
      this.probs.setValueAtLocation (singleLoc, product);
    }
    Flops.increment (numLocs)// handle the pluses
  }

  /**
   * Does the conceptual equivalent of this += pot.
   * Assumes that pot's variables are a subset of
   * this potential's.
   */
  protected void plusEqualsInternal (DiscreteFactor ptl)
  {
    int[] projection = largeIdxToSmall (ptl);
    int numLocs = probs.numLocations ();
    for (int singleLoc = 0; singleLoc < numLocs; singleLoc++) {
      int smallIdx = projection[singleLoc];
      double prev = this.probs.valueAtLocation (singleLoc);
      double newVal = ptl.logValue (smallIdx);
      double product = Maths.sumLogProb (prev, newVal);
      this.probs.setValueAtLocation (singleLoc, product);
    }
    Flops.sumLogProb (numLocs);
  }


  public double value (Assignment assn)
  {
    Flops.exp ();
    if (getNumVars () == 0) return 1.0;
    return Math.exp (rawValue (assn));
  }


  public double value (AssignmentIterator it)
  {
    Flops.exp ();
    return Math.exp (rawValue (it.indexOfCurrentAssn ()));
  }


  public double value (int idx)
  {
    Flops.exp ();
    return Math.exp (rawValue (idx));
  }

  public double logValue (AssignmentIterator it)
  {
    return rawValue (it.indexOfCurrentAssn ());
  }

  public double logValue (int idx)
  {
    return rawValue (idx);
  }

  public double logValue (Assignment assn)
  {
    return rawValue (assn);
  }

  protected Factor marginalizeInternal (AbstractTableFactor result)
  {

    result.setAll (Double.NEGATIVE_INFINITY);
    int[] projection = largeIdxToSmall (result);

    /* Add each element of the single array of the large potential
to the correct element in the small potential. */
    int numLocs = probs.numLocations ();
    for (int largeLoc = 0; largeLoc < numLocs; largeLoc++) {

      /* Convert a single-index from this distribution to
one for the smaller distribution */
      int smallIdx = projection[largeLoc];

      /* Whew! Now, add it in. */
      double oldValue = this.probs.valueAtLocation (largeLoc);
      double currentValue = result.probs.singleValue (smallIdx);
      result.probs.setValueAtLocation (smallIdx,
              Maths.sumLogProb (oldValue, currentValue));

    }
    Flops.sumLogProb (numLocs);

    return result;
  }

  protected double rawValue (Assignment assn)
  {
    int numVars = getNumVars ();
    int[] indices = new int[numVars];
    for (int i = 0; i < numVars; i++) {
      Variable var = getVariable (i);
      indices[i] = assn.get (var);
    }

    return rawValue (indices);
  }

  private double rawValue (int[] indices)
  {
    // handle non-occuring indices specially, for default value is -Inf in log space.
    int singleIdx = probs.singleIndex (indices);
    return rawValue (singleIdx);
  }

  protected double rawValue (int singleIdx)
  {
    int loc = probs.location (singleIdx);
    if (loc < 0) {
      return Double.NEGATIVE_INFINITY;
    } else {
      return probs.valueAtLocation (loc);
    }
  }

  public void exponentiate (double power)
  {
    Flops.increment (probs.numLocations ());
    probs.timesEquals (power);
  }

  /*
  protected AbstractTableFactor ensureOperandCompatible (AbstractTableFactor ptl)
  {
    if (!(ptl instanceof LogTableFactor)) {
      return new LogTableFactor(ptl);
    } else {
      return ptl;
    }
  }
  */

  public void setLogValue (Assignment assn, double logValue)
  {
    setRawValue (assn, logValue);
  }

  public void setLogValue (AssignmentIterator assnIt, double logValue)
  {
    setRawValue (assnIt, logValue);
  }

  public void setValue (AssignmentIterator assnIt, double value)
  {
    Flops.log ();
    setRawValue (assnIt, Math.log (value));
  }

  public void setLogValues (double[] vals)
  {
    for (int i = 0; i < vals.length; i++) {
      setRawValue (i, vals[i]);
    }
  }

  public void setValues (double[] vals)
  {
    Flops.log (vals.length);
    for (int i = 0; i < vals.length; i++) {
      setRawValue (i, Math.log (vals[i]));
    }
  }

  // v is *not* in log space
  public void timesEquals (double v)
  {
    timesEqualsLog (Math.log (v));
  }

  private void timesEqualsLog (double logV)
  {
    Flops.increment (probs.numLocations ());
    Matrix other = (Matrix) probs.cloneMatrix ();
    other.setAll (logV);
    probs.plusEquals (other);
  }

  protected void plusEqualsAtLocation (int loc, double v)
  {
    Flops.log (); Flops.sumLogProb (1);
    double oldVal = logValue (loc);
    setRawValue (loc, Maths.sumLogProb (oldVal, Math.log (v)));
  }

  public static LogTableFactor makeFromValues (Variable var, double[] vals2)
  {
    return makeFromValues (new Variable[]{var}, vals2);
  }

  public static LogTableFactor makeFromMatrix (Variable[] vars, SparseMatrixn values)
  {
    SparseMatrixn logValues = (SparseMatrixn) values.cloneMatrix ();
    for (int i = 0; i < logValues.numLocations (); i++) {
      logValues.setValueAtLocation (i, Math.log (logValues.valueAtLocation (i)));
    }
    Flops.log (logValues.numLocations ());
    return new LogTableFactor (vars, logValues);
  }

  public static LogTableFactor makeFromLogMatrix (Variable[] vars, Matrix values)
  {
    Matrix logValues = (Matrix) values.cloneMatrix ();
    return new LogTableFactor (vars, logValues);
  }

  public static LogTableFactor makeFromLogValues (Variable v, double[] vals)
  {
    return makeFromLogValues (new Variable[]{v}, vals);
  }

  public Matrix getValueMatrix ()
  {
    Matrix logProbs = (Matrix) probs.cloneMatrix ();
    for (int loc = 0; loc < probs.numLocations (); loc++) {
      logProbs.setValueAtLocation (loc, Math.exp (logProbs.valueAtLocation (loc)));
    }
    Flops.exp (probs.numLocations ());
    return logProbs;
  }

  public Matrix getLogValueMatrix ()
  {
    return probs;
  }

  public double valueAtLocation (int idx)
  {
    Flops.exp ();
    return Math.exp (probs.valueAtLocation (idx));
  }

  protected Factor slice_onevar (Variable var, Assignment observed)
  {
    Assignment assn = (Assignment) observed.duplicate ();
    double[] vals = new double [var.getNumOutcomes ()];
    for (int i = 0; i < var.getNumOutcomes (); i++) {
      assn.setValue (var, i);
      vals[i] = logValue (assn);
    }

    return LogTableFactor.makeFromLogValues (var, vals);
  }

  protected Factor slice_twovar (Variable v1, Variable v2, Assignment observed)
  {
    Assignment assn = (Assignment) observed.duplicate ();

    int N1 = v1.getNumOutcomes ();
    int N2 = v2.getNumOutcomes ();
    int[] szs = new int[]{N1, N2};

    double[] vals = new double [N1 * N2];
    for (int i = 0; i < N1; i++) {
      assn.setValue (v1, i);
      for (int j = 0; j < N2; j++) {
        assn.setValue (v2, j);
        int idx = Matrixn.singleIndex (szs, new int[]{i, j}); // Inefficient, but much less error prone
        vals[idx] = logValue (assn);
      }
    }

    return LogTableFactor.makeFromLogValues (new Variable[]{v1, v2}, vals);
  }

  protected Factor slice_general (Variable[] vars, Assignment observed)
  {
    VarSet toKeep = new HashVarSet (vars);
    toKeep.removeAll (observed.varSet ());
    double[] vals = new double [toKeep.weight ()];

    AssignmentIterator it = toKeep.assignmentIterator ();
    while (it.hasNext ()) {
      Assignment union = Assignment.union (observed, it.assignment ());
      vals[it.indexOfCurrentAssn ()] = logValue (union);
      it.advance ();
    }

    return LogTableFactor.makeFromLogValues (toKeep.toVariableArray (), vals);
  }

  public static LogTableFactor multiplyAll (Collection phis)
  {
    /* Get all the variables */
    VarSet vs = new HashVarSet ();
    for (Iterator it = phis.iterator (); it.hasNext ();) {
      Factor phi = (Factor) it.next ();
      vs.addAll (phi.varSet ());
    }

    /* define a new potential over the neighbors of NODE */
    LogTableFactor newCPF = new LogTableFactor (vs);
    for (Iterator it = phis.iterator (); it.hasNext ();) {
      Factor phi = (Factor) it.next ();
      newCPF.multiplyBy (phi);
    }

    return newCPF;
  }

  public AbstractTableFactor recenter ()
  {
//    return (AbstractTableFactor) normalize ();
    int loc = argmax ();
    double lval = probs.valueAtLocation(loc)// should be refactored
    timesEqualsLog (-lval);
    return this;
  }
}
// TOP
//
// Related Classes of cc.mallet.grmm.types.LogTableFactor
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by Oracle Inc. Contact coftware#gmail.com.