Package cc.mallet.grmm.inference

Source Code of cc.mallet.grmm.inference.JunctionTreeInferencer

/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.grmm.inference;

import org._3pq.jgrapht.GraphHelper;
import org._3pq.jgrapht.UndirectedGraph;
import org._3pq.jgrapht.alg.ConnectivityInspector;
import org._3pq.jgrapht.graph.SimpleGraph;
import org._3pq.jgrapht.graph.ListenableUndirectedGraph;
import org._3pq.jgrapht.traverse.BreadthFirstIterator;

import cc.mallet.grmm.types.*;
import cc.mallet.grmm.util.Graphs;
import cc.mallet.types.Alphabet;
import cc.mallet.util.MalletLogger;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.*;
import java.util.logging.Level;
import java.util.logging.Logger;


/**
* Does inference in general graphical models using
*  the Hugin junction tree algorithm.
*
* Created: Mon Nov 10 23:58:44 2003
*
* @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a>
* @version $Id: JunctionTreeInferencer.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
*/
public class JunctionTreeInferencer extends AbstractInferencer {

  private static Logger logger = MalletLogger.getLogger(JunctionTreeInferencer.class.getName());
  private boolean inLogSpace;
  private JunctionTreePropagation propagator;

  /**
   * Creates an inferencer backed by the propagator that
   *  {@link JunctionTreePropagation#createSumProductInferencer} returns.
   */
  public JunctionTreeInferencer()
  {
    this (JunctionTreePropagation.createSumProductInferencer ());
  } // JunctionTreeInferencer constructor

  /**
   * Creates an inferencer that delegates propagation and marginal lookup
   *  to the given propagator.
   * @param propagator Propagation strategy to store; it is not copied.
   */
  public JunctionTreeInferencer (JunctionTreePropagation propagator)
  {
    this.propagator = propagator;
  }

  /**
   * Factory for an inferencer backed by the propagator that
   *  {@link JunctionTreePropagation#createMaxProductInferencer} returns.
   */
  public static JunctionTreeInferencer createForMaxProduct ()
  {
    return new JunctionTreeInferencer (JunctionTreePropagation.createMaxProductInferencer ());
  }


  /* Returns true iff G reports an edge between V1 and V2, i.e. getEdge is non-null. */
  private boolean isAdjacent (UndirectedGraph g, Variable v1, Variable v2)
  {
    return g.getEdge (v1, v2) != null;
  }


  transient protected JunctionTree jtCurrent;
  transient private ArrayList cliques;


  /**
   * Returns the number of edges that would be added to a graph if a
   *  given vertex would be removed in the triangulation procedure.
   *  The return value is the number of edges in the elimination
   *  clique of V that are not already present.
   */
  private int newEdgesRequired(UndirectedGraph mdl, Variable v)
  {
    int rating = 0;

    for (Iterator it1 = neighborsIterator (mdl,v); it1.hasNext();) {
      Variable neighbor1 = (Variable) it1.next();
      Iterator it2 = neighborsIterator (mdl,v);
      while (it2.hasNext()) {
        Variable neighbor2 = (Variable) it2.next();
        if (neighbor1 != neighbor2) {
          if (!isAdjacent (mdl, neighbor1, neighbor2)) {
            rating++;
          }
        }
      }
    }

//    System.out.println(v+" = "+rating);

    return rating;
  }


  /**
   * Returns the weight of the clique that would be added to a graph if a
   *  given vertex would be removed in the triangulation procedure.
   *  The return value is the number of edges in the elimination
   *  clique of V that are not already present.
   */
  private int weightRequired (UndirectedGraph mdl, Variable v)
  {
    int rating = 1;

    for (Iterator it1 = neighborsIterator (mdl,v); it1.hasNext();) {
      Variable neighbor = (Variable) it1.next();
      rating *= neighbor.getNumOutcomes();
    }

//    System.out.println(v+" = "+rating);

    return rating;
  }


  private void connectNeighbors(UndirectedGraph mdl, Variable v)
  {
    for (Iterator it1 = neighborsIterator(mdl,v); it1.hasNext();) {
      Variable neighbor1 = (Variable) it1.next();
      Iterator it2 = neighborsIterator(mdl,v);
      while (it2.hasNext()) {
        Variable neighbor2 = (Variable) it2.next();
        if (neighbor1 != neighbor2) {
          if (!isAdjacent (mdl, neighbor1, neighbor2)) {
            try {
              mdl.addEdge(neighbor1, neighbor2);
            } catch (Exception e) {
              throw new RuntimeException(e);
            }
          }
        }
      }
    }
  }


  // xx should refactor into Collections.any (Coll, TObjectProc)
  /* Return true iff a clique in L strictly contains c. */
  private boolean findSuperClique(List l, VarSet c)
  {
    for (Iterator it = l.iterator(); it.hasNext();) {
      VarSet c2 = (VarSet) it.next();
      if (c2.containsAll(c)) {
        return true;
      }
    }
    return false;
  }


  // works like the obscure <=> operator in Perl.
  private static int cmp(int i1, int i2)
  {
    if (i1 < i2) {
      return -1;
    } else if (i1 > i2) {
      return 1;
    } else {
      return 0;
    }
  }

  public Variable pickVertexToRemove (UndirectedGraph mdl, ArrayList lst)
  {
    Iterator it = lst.iterator();
    Variable best = (Variable) it.next();
    int bestVal1 = newEdgesRequired (mdl, best);
    int bestVal2 = weightRequired (mdl, best);

    while (it.hasNext()) {
      Variable v = (Variable) it.next();
      int val = newEdgesRequired (mdl, v);
      if (val < bestVal1) {
        best = v;
        bestVal1 = val;
        bestVal2 = weightRequired (mdl, v);
      } else if (val == bestVal1) {
        int val2 = weightRequired (mdl, v);
        if (val2 < bestVal2) {
          best = v;
          bestVal1 = val;
          bestVal2 = val2;
        }
      }
    }

    return best;
  }


  /**
   * Adds edges to graph until it is triangulated.
   */
  private void triangulate(final UndirectedGraph mdl)
  {
    UndirectedGraph mdl2 = dupGraph (mdl);
    ArrayList vars = new ArrayList(mdl.vertexSet());
    Alphabet varMap = makeVertexMap(vars);
    cliques = new ArrayList();

    // debug
    if (logger.isLoggable (Level.FINER)) {
      logger.finer ("Triangulating model: "+mdl);
      String ret = "";
      for (int i = 0; i < vars.size(); i++) {
        Variable next = (Variable) vars.get(i);
        ret += next.toString() + "\n"; // " (" + mdl.getIndex(next) + ")\n  ";
      }
      logger.finer(ret);
    }

    while (!vars.isEmpty()) {
      Variable v = (Variable) pickVertexToRemove (mdl2, vars);
      logger.finer("Triangulating vertex " + v);

      VarSet varSet = new BitVarSet (v.getUniverse (), GraphHelper.neighborListOf (mdl2, v));
      varSet.add(v);
      if (!findSuperClique(cliques, varSet)) {
        cliques.add(varSet);
        if (logger.isLoggable (Level.FINER)) {
          logger.finer ("  Elim clique " + varSet + " size " + varSet.size () + " weight " + varSet.weight ());
        }
      }

      // must remove V from graph first, because adding the edges
//  will change the rating of other vertices

      connectNeighbors (mdl2, v);
      vars.remove(v);
      mdl2.removeVertex (v);
    }

    if (logger.isLoggable(Level.FINE)) {
      logger.fine("Triangulation done. Cliques are: ");
      int totSize = 0, totWeight = 0, maxSize = 0, maxWeight = 0;
      for (Iterator it = cliques.iterator(); it.hasNext();) {
        VarSet c = (VarSet) it.next();
        logger.finer(c.toString());
        totSize += c.size();
        maxSize = Math.max(c.size(), maxSize);
        totWeight += c.weight();
        maxWeight = Math.max(c.weight(), maxWeight);
      }
      double sz = cliques.size();
      logger.fine("Jt created " + sz + " cliques. Size: avg " + (totSize / sz)
                  + " max " + (maxSize) + " Weight: avg " + (totWeight / sz)
                  + " max " + (maxWeight));
    }
  }


  private Alphabet makeVertexMap(ArrayList vars)
  {
    Alphabet map = new Alphabet (vars.size (), Variable.class);
    map.lookupIndices(vars.toArray(), true);
    return map;
  }


  private static int sepsetSize(BitVarSet[] pair)
  {
    assert pair.length == 2;
    return pair[0].intersectionSize(pair[1]);
  }


  private static int sepsetCost(VarSet[] pair)
  {
    assert pair.length == 2;
    return pair[0].weight() + pair[1].weight();
  }


  // Given two pairs of cliques, returns -1 if the pair o1 should be
  // added to the tree first.  We add pairs that have the largest
  // mass (number of vertices in common) to ensure that the clique
  // tree satifies the running intersection property.
  private static Comparator sepsetChooser = new Comparator() {
    public int compare(Object o1, Object o2)
    {
      if (o1 == o2) return 0;
      BitVarSet[] pair1 = (BitVarSet[]) o1;
      BitVarSet[] pair2 = (BitVarSet[]) o2;
      int size1 = sepsetSize(pair1);
      int size2 = sepsetSize(pair2);
      int retval = -cmp(size1, size2);
      if (retval == 0) {
        // Break ties by adding the sepset with the
        //  smallest cost (sum of weights of connected clusters)
        int cost1 = sepsetCost(pair1);
        int cost2 = sepsetCost(pair2);
        retval = cmp(cost1, cost2);

        // Still a tie? Break arbitrarily but consistently.
        if (retval == 0) {
          retval = cmp (o1.hashCode (), o2.hashCode ());
        }
      }
      return retval;
    }
  };


  private JunctionTree graphToJt (UndirectedGraph g)
  {
    JunctionTree jt = new JunctionTree (g.vertexSet ().size ());
    Object root = g.vertexSet ().iterator ().next ();
    jt.add (root);

    for (Iterator it1 = new BreadthFirstIterator (g, root); it1.hasNext ();) {
      Object v1 = it1.next ();
      for (Iterator it2 = GraphHelper.neighborListOf (g, v1).iterator (); it2.hasNext ();) {
        Object v2 = it2.next ();
        if (jt.getParent (v1) != v2) {
          jt.addNode (v1, v2);
        }
      }
    }
    return jt;
  }


  private JunctionTree buildJtStructure()
  {
    TreeSet pq = new TreeSet(sepsetChooser);

    // Initialize pq with all possible edges...
    for (Iterator it = cliques.iterator(); it.hasNext();) {
      BitVarSet c1 = (BitVarSet) it.next();
      for (Iterator it2 = cliques.iterator(); it2.hasNext();) {
        BitVarSet c2 = (BitVarSet) it2.next();
        if (c1 == c2) break;
        pq.add(new BitVarSet[]{c1, c2});
      }
    }

    // ...and add the edges to jt that come to the top of the queue
    //  and don't cause a cycle.
    // xxx OK, this sucks.  openjgraph doesn't allow adding
    //  disconnected edges to a tree, so what we'll do is create a
    //  Graph frist, then convert it to a Tree.
    ListenableUndirectedGraph g = new ListenableUndirectedGraph (new SimpleGraph ());

    // first add every clique to the graph
    for (Iterator it = cliques.iterator(); it.hasNext();) {
      VarSet c = (VarSet) it.next();
      g.addVertex (c);
    }

    ConnectivityInspector inspector = new ConnectivityInspector (g);
    g.addGraphListener (inspector);
   
    // then add n - 1 edges
    int numCliques = cliques.size();
    int edgesAdded = 0;
    while (edgesAdded < numCliques - 1) {
      VarSet[] pair = (VarSet[]) pq.first();
      pq.remove(pair);

      if (!inspector.pathExists(pair[0], pair[1])) {
          g.addEdge(pair[0], pair[1]);
          edgesAdded++;
      }
    }

    JunctionTree jt = graphToJt(g);
    if (logger.isLoggable (Level.FINER)) {
      logger.finer ("  jt structure was " + jt);
    }
    return jt;
  }


  private void initJtCpts(FactorGraph mdl, JunctionTree jt)
  {
    for (Iterator it = jt.getVerticesIterator(); it.hasNext();) {
      VarSet c = (VarSet) it.next();
//      DiscreteFactor ptl = createBlankFactor (c);
//      jt.setCPF(c, ptl);
      jt.setCPF (c, new ConstantFactor (1.0));
    }

    for (Iterator it = mdl.factors ().iterator(); it.hasNext();) {
      Factor ptl = (Factor) it.next();
      VarSet parent = jt.findParentCluster(ptl.varSet());
      assert parent != null
              : "Unable to find parent cluster for ptl " + ptl + "in jt " + jt;

      Factor cpf = jt.getCPF(parent);
      Factor newCpf = cpf.multiply(ptl);
      jt.setCPF (parent, newCpf);

      /* debug
         if (jt.isNaN()) {
           throw new RuntimeException ("Got a NaN");
         }
         */
    }
  }

  private AbstractTableFactor createBlankFactor (VarSet c)
  {
    if (inLogSpace) {
      return new LogTableFactor (c);
    } else {
      return new TableFactor (c);
    }
  }


  public void computeMarginals (FactorGraph mdl)
  {
    inLogSpace = mdl.getFactor (0) instanceof LogTableFactor;
    buildJunctionTree(mdl);
    propagator.computeMarginals(jtCurrent);
    totalMessagesSent += propagator.getTotalMessagesSent();
  }

  public void computeMarginals (JunctionTree jt)
  {
    inLogSpace = false; //??
    jtCurrent = jt;
    propagator.computeMarginals(jtCurrent);
    totalMessagesSent += propagator.getTotalMessagesSent();
  }



  /**
   * Constructs a junction tree from a given factor graph.  Does not perform BP in the resulting
   *  graph.  So this gives you the structure of a jnuction tree, but the factors don't correspond
   *  to the true marginals unless you call BP yourself.
   * @param mdl Factor graph to compute JT for.
   */
  public JunctionTree buildJunctionTree(FactorGraph mdl)
  {
    jtCurrent = (JunctionTree) mdl.getInferenceCache(JunctionTreeInferencer.class);
    if (jtCurrent != null) {
      jtCurrent.clearCPFs();
    } else {
      /* The graph g is the topology of the MRF that corresponds to the factor graph mdl.
       * Essentially, this means that we triangulate factor graphs by converting to an MRF first.
       * I could have chosen to trianglualte the FactorGraph directly, but I didn't for historical reasons
       *  (I already had a version of triangulate() for MRFs, not bipartite factor graphs.)
       * Note that the call to mdlToGraph() is perfectly valid for FactorGraphs that are also DirectedModels,
       *  and has the effect of moralizing in that case.  */
      UndirectedGraph g = Graphs.mdlToGraph (mdl);
      triangulate (g);
      jtCurrent = buildJtStructure();
      mdl.setInferenceCache(JunctionTreeInferencer.class, jtCurrent);
    }

    initJtCpts(mdl, jtCurrent);
    return jtCurrent;
  }

  /* Returns a new SimpleGraph to which the contents of ORIGINAL have been
   * added with GraphHelper.addGraph.  ORIGINAL is only read. */
  private UndirectedGraph dupGraph (UndirectedGraph original)
  {
    UndirectedGraph copy = new SimpleGraph ();
    GraphHelper.addGraph (copy, original);
    return copy;
  }


  /**
   * Returns the marginal of a single variable, obtained by asking the
   *  propagator to look it up in the current junction tree.
   * @param var Variable whose marginal is requested.
   */
  public Factor lookupMarginal(Variable var)
  {
    return propagator.lookupMarginal (jtCurrent, var);
  }


  /**
   * Returns the marginal over a set of variables, obtained by asking the
   *  propagator to look it up in the current junction tree.
   * @param varSet Variables whose joint marginal is requested.
   */
  public Factor lookupMarginal(VarSet varSet)
  {
    return propagator.lookupMarginal (jtCurrent, varSet);
  }


  /**
   * Delegates to <tt>lookupLogJoint</tt> of the current junction tree.
   *  A current tree must exist, since it is dereferenced unconditionally.
   * @param assn Assignment to evaluate.
   */
  public double lookupLogJoint(Assignment assn)
  {
    return jtCurrent.lookupLogJoint(assn);
  }


  /**
   * Delegates to <tt>dumpLogJoint</tt> of the current junction tree.
   *  A current tree must exist, since it is dereferenced unconditionally.
   * @param assn Assignment to evaluate.
   */
  public double dumpLogJoint(Assignment assn)
  {
    return jtCurrent.dumpLogJoint(assn);
  }

  /**
   * Returns the current junction tree, i.e. the one most recently set by
   *  {@link #computeMarginals} or {@link #buildJunctionTree}.  The tree is
   *  returned directly, not as a copy, so the caller must not modify it.
   */
  public JunctionTree lookupJunctionTree ()
  {
    return jtCurrent;
  }

  /* Returns an iterator over the list that GraphHelper.neighborListOf
   * produces for V in G.  Each call asks for a new list. */
  private Iterator neighborsIterator (UndirectedGraph g, Variable v)
  {
    return GraphHelper.neighborListOf (g, v).iterator ();
  }

  public void dump ()
  {
    if (jtCurrent != null) {
      System.out.println("Current junction tree");
      jtCurrent.dump();
    } else {
      System.out.println("NO current junction tree");
    }
  }


  transient private int totalMessagesSent = 0;

  /**
   * Returns the total number of messages this inferencer has sent.
   *  The total grows on every {@link #computeMarginals} call by whatever
   *  the propagator's <tt>getTotalMessagesSent()</tt> reports afterwards.
   *  The underlying field is transient.
   */
  public int getTotalMessagesSent () { return totalMessagesSent; }


  // Serialization
  private static final long serialVersionUID = 1;

  // If serialization-incompatible changes are made to these classes,
  //  then smarts can be added to these methods for backward compatibility.
  /* Serialization hook: writes the non-transient fields using the default
   * mechanism and nothing else. */
  private void writeObject (ObjectOutputStream out) throws IOException {
     out.defaultWriteObject ();
   }

  /* Deserialization hook: restores the non-transient fields using the default
   * mechanism.  Transient fields are not restored here. */
  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
     in.defaultReadObject ();
  }

} // JunctionTreeInferencer
TOP

Related Classes of cc.mallet.grmm.inference.JunctionTreeInferencer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.