Package cc.mallet.grmm.util

Source Code of cc.mallet.grmm.util.Models

/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.util;


import java.util.*;

import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.inference.JunctionTree;
import cc.mallet.grmm.inference.JunctionTreeInferencer;
import cc.mallet.grmm.types.*;

import gnu.trove.THashSet;

/**
* Static utilities that do useful things with factor graphs.
*
* Created: Sep 22, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: Models.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $
*/
public class Models {

  /**
   * Returns a new factor graph, the same as a given one, except that all the nodes in
   *  the given Assignment are clamped as evidence.
   * @param mdl Old model.  Will not be modified.
   * @param assn Evidence to add
   * @return A new factor graph.
   */
  public static FactorGraph addEvidence (FactorGraph mdl, Assignment assn)
  {
    return addEvidence (mdl, assn, null);
  }

  public static FactorGraph addEvidence (FactorGraph mdl, Assignment assn, Map toSlicedMap)
  {
    FactorGraph newMdl = new FactorGraph (mdl.numVariables ());
    addSlicedPotentials (mdl, newMdl, assn, toSlicedMap);
    return newMdl;
  }

  public static UndirectedModel addEvidence (UndirectedModel mdl, Assignment assn)
  {
    UndirectedModel newMdl = new UndirectedModel (mdl.numVariables ());
    addSlicedPotentials (mdl, newMdl, assn, null);
    return newMdl;
  }

  private static void addSlicedPotentials (FactorGraph fromMdl, FactorGraph toMdl, Assignment assn, Map toSlicedMap)
  {
    Set inputVars = new THashSet (Arrays.asList (assn.getVars ()));
    Set remainingVars = new THashSet (fromMdl.variablesSet ());
    remainingVars.removeAll (inputVars);
    for (Iterator it = fromMdl.factorsIterator (); it.hasNext ();) {
      Factor ptl = (Factor) it.next ();
      Set theseVars = new THashSet (ptl.varSet ());
      theseVars.retainAll (remainingVars);
      Factor slicedPtl = ptl.slice (assn);
      toMdl.addFactor (slicedPtl);
      if (toSlicedMap != null) {
        toSlicedMap.put (ptl, slicedPtl);
      }
    }
  }

  /**
   * Returns the highest-score Assignment in a model according to a given inferencer.
   * @param mdl Factor graph to use
   * @param inf Inferencer to use.  No need to call <tt>computeMarginals</tt> first.
   * @return An Assignment
   */
  public static Assignment bestAssignment (FactorGraph mdl, Inferencer inf)
  {
    inf.computeMarginals (mdl);
    int[] outcomes = new int [mdl.numVariables ()];
    for (int i = 0; i < outcomes.length; i++) {
      Variable var = mdl.get (i);
      int best = inf.lookupMarginal (var).argmax ();
      outcomes[i] = best;
    }
    return new Assignment (mdl, outcomes);
  }

  /**
   * Computes the exact entropy of a factor graph distribution using the junction tree algorithm.
   *  If the model is intractable, then this method won't return a number anytime soon.
   */
  public static double entropy (FactorGraph mdl)
  {
    JunctionTreeInferencer inf = new JunctionTreeInferencer ();
    inf.computeMarginals (mdl);
    JunctionTree jt = inf.lookupJunctionTree ();
    return jt.entropy ();
  }

  /**
   * Computes the KL divergence <tt>KL(mdl1||mdl2)</tt>.  Junction tree is used to compute the entropy.
   * <p>
   * TODO: This probably won't handle when the jnuction tree for MDL2 contains a clique that's not present in the
   *  junction tree for mdl1.  If so, this is a bug.
   *
   * @param mdl1
   * @param mdl2
   * @return KL(mdl1||mdl2)
   */
  public static double KL (FactorGraph mdl1, FactorGraph mdl2)
  {
    JunctionTreeInferencer inf1 = new JunctionTreeInferencer ();
    inf1.computeMarginals (mdl1);
    JunctionTree jt1 = inf1.lookupJunctionTree ();

    JunctionTreeInferencer inf2 = new JunctionTreeInferencer ();
    inf2.computeMarginals (mdl2);
    JunctionTree jt2 = inf2.lookupJunctionTree ();

    double entropy = jt1.entropy ();
    double energy = 0;

    for (Iterator it = jt2.clusterPotentials ().iterator(); it.hasNext();) {
      Factor marg2 = (Factor) it.next ();
      Factor marg1 = inf1.lookupMarginal (marg2.varSet ());
      for (AssignmentIterator assnIt = marg2.assignmentIterator (); assnIt.hasNext();) {
        energy += marg1.value (assnIt) * marg2.logValue (assnIt);
        assnIt.advance();
      }
    }
    for (Iterator it = jt2.sepsetPotentials ().iterator(); it.hasNext();) {
      Factor marg2 = (Factor) it.next ();
      Factor marg1 = inf1.lookupMarginal (marg2.varSet ());
      for (AssignmentIterator assnIt = marg2.assignmentIterator (); assnIt.hasNext();) {
        energy -= marg1.value (assnIt) * marg2.logValue (assnIt);
        assnIt.advance();
      }
    }

    return -entropy - energy;
  }

  public static void removeConstantFactors (FactorGraph sliced)
  {
    List factors = new ArrayList (sliced.factors ());
    for (Iterator it = factors.iterator (); it.hasNext();) {
      Factor factor = (Factor) it.next ();
      if (factor instanceof ConstantFactor) {
        sliced.divideBy (factor);
      }
    }
  }
}
TOP

Related Classes of cc.mallet.grmm.util.Models

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark originally of Sun Microsystems, Inc., now owned by Oracle. Contact coftware#gmail.com.