Package cc.mallet.grmm.inference

Source Code of cc.mallet.grmm.inference.JunctionTree$Sepset

/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.grmm.inference;

import java.util.Set;

import java.util.HashSet;
import java.util.Iterator;
import java.util.Collection;
import java.util.List;
import java.util.Arrays;

import cc.mallet.grmm.types.*;

import gnu.trove.TIntObjectHashMap;
import gnu.trove.THashSet;
import gnu.trove.TIntObjectIterator;


/**
* Datastructure for a junction tree.
*
* Created: Tue Sep 30 10:30:25 2003
*
* @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a>
* @version $Id: JunctionTree.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
*/
public class JunctionTree extends Tree {

  private int numNodes;

  private static class Sepset {

    Sepset(Set s, Factor p)
    {
      set = s;
      ptl = p;
    }


    Set set;
    Factor ptl;
  }

  private TIntObjectHashMap sepsets;
  private Factor[] cpfs;


  public JunctionTree(int size)
  {
    super();

    numNodes = size;
    sepsets = new TIntObjectHashMap();
    cpfs = new Factor[size];
  } // JunctionTree constructor


  public void addNode (Object parent1, Object child1)
  {
    super.addNode(parent1, child1);
    VarSet parent = (VarSet) parent1;
    VarSet child = (VarSet) child1;
    Set sepset = parent.intersection(child);
    int id1 = lookupIndex(parent);
    int id2 = lookupIndex(child);
    putSepset(id1, id2, new Sepset (sepset, newSepsetPtl (sepset)));
  }

  private Factor newSepsetPtl (Set sepset)
  {
    if (sepset.isEmpty ()) {
      // use identity factor
      return ConstantFactor.makeIdentityFactor ();
    } else {
      return new TableFactor (sepset);
    }
  }

  private int hashIdxIdx(int id1, int id2)
  {
    assert (id1 < 65536) && (id2 < 65536);

    int id;
    if (id1 < id2) {
      id = (id1 << 16) | id2;
    } else {
      id = (id2 << 16) | id1;
    }
    return id;
  }


  private void putSepset(int id1, int id2, Sepset sepset)
  {
    int id = hashIdxIdx(id1, id2);
    sepsets.put(id, sepset);
  }


  private Sepset getSepset(int id1, int id2)
  {
    int id = hashIdxIdx(id1, id2);
    return (Sepset) sepsets.get(id);
  }

  //  CPF accessors

  public Factor getCPF(VarSet c)
  {
    return cpfs[lookupIndex(c)];
  }


  public void setCPF(VarSet c, Factor pot)
  {
    cpfs[lookupIndex(c)] = pot;
  }


  void clearCPFs()
  {
    for (int i = 0; i < cpfs.length; i++) {
      cpfs[i] = new TableFactor ((VarSet) lookupVertex (i));
    }

    TIntObjectIterator it = sepsets.iterator();
    while (it.hasNext()) {
      it.advance();
      Sepset sepset = (Sepset) it.value();
      sepset.ptl = newSepsetPtl (sepset.set);
    }

  }


  public Set sepsetPotentials()
  {
    THashSet set = new THashSet();
    TIntObjectIterator it = sepsets.iterator();
    while (it.hasNext()) {
      it.advance();
      Factor ptl = ((Sepset) it.value()).ptl;
      set.add(ptl);
    }

    return set;
  }


  void setSepsetPot(Factor pot, VarSet v1, VarSet v2)
  {
    int id1 = lookupIndex(v1);
    int id2 = lookupIndex(v2);
    getSepset(id1, id2).ptl = pot;
  }


  public Factor getSepsetPot(VarSet v1, VarSet v2)
  {
    int id1 = lookupIndex(v1);
    int id2 = lookupIndex(v2);
    return getSepset(id1, id2).ptl;
  }

  /**
   * Returns a collection of all the potentials of cliques in the junction tree.
   *  (i.e., these are the terms in the numerator of the jounction tre theorem).
   * @see #sepsetPotentials()
   */
  public Collection clusterPotentials ()
  {
    HashSet h = new HashSet();
    for (int i = 0; i < cpfs.length; i++) {
      if (cpfs[i] != null) {
        h.add(cpfs[i]);
      }
    }
    return h;
  }


  public Set getSepset(VarSet v1, VarSet v2)
  {
    int id1 = lookupIndex(v1);
    int id2 = lookupIndex(v2);
    return getSepset(id1, id2).set;
  }


  public Factor lookupMarginal(Variable var)
  {
    VarSet c = findParentCluster(var);
    Factor pot = getCPF(c);
    return pot.marginalize(var);
  }


  public double lookupLogJoint(Assignment assn)
  {
    double accum = 0;
    for (int i = 0; i < cpfs.length; i++) {
      if (cpfs[i] != null) {
        double phi = cpfs[i].logValue (assn);
        accum += phi;
      }
    }

    TIntObjectIterator it = sepsets.iterator();
    while (it.hasNext()) {
      it.advance();
      Factor ptl = ((Sepset) it.value()).ptl;
      double phi = ptl.logValue (assn);
      accum -= phi;
    }

    return accum;
  }


  /** Returns a cluster in the tree that contains var. */
  public VarSet findParentCluster(Variable var)
  {
    int best = Integer.MAX_VALUE;
    VarSet retval = null;
    // xxx Inefficient
    for (Iterator it = getVerticesIterator(); it.hasNext();) {
      VarSet c = (VarSet) it.next();
      if (c.contains(var) && c.weight() < best) {
        retval = c;
        best = c.weight();
      }
    }
    return retval;
  }


  /**
   * Returns a cluster in the tree that contains all the vars in a
   *   collection.
   */
  public VarSet findParentCluster(Collection vars)
  {
    int best = Integer.MAX_VALUE;
    VarSet retval = null;
    // xxx Inefficient
    for (Iterator it = getVerticesIterator(); it.hasNext();) {
      VarSet c = (VarSet) it.next();
      if (c.containsAll(vars) && c.weight() < best) {
        retval = c;
        best = c.weight();
      }
    }
    return retval;
  }


  /** Returns a cluster in the tree that contains exactly the given
   *   variables, or null if no such cluster exists. */
  public VarSet findCluster(Variable[] vars)
  {
    List l = Arrays.asList(vars);
    for (Iterator it = getVerticesIterator(); it.hasNext();) {
      VarSet c2 = (VarSet) it.next();
      if (c2.containsAll(l) && l.containsAll(c2))
        return c2;
    }
    return null;
  }


  /** Normalizes all potentials in the tree, both node and sepset. */
  public void normalizeAll()
  {
    int n = cpfs.length;
    for (int i = 0; i < n; i++) {
      if (cpfs[i] != null) {
        cpfs[i].normalize();
      }
    }

    TIntObjectIterator it = sepsets.iterator();
    while (it.hasNext()) {
      it.advance();
      Factor ptl = ((Sepset) it.value()).ptl;
      ptl.normalize();
    }
  }


  int getId(VarSet c)
  {
    return lookupIndex(c);
  }

// Debugging functions

  public void dump ()
  {
    int n = cpfs.length;
    // This will cause OpenJGraph to print all our nodes and edges
    System.out.println(this);
    // Now lets print all the cpfs
    System.out.println("Vertex CPFs");
    for (int i = 0; i < n; i++) {
      if (cpfs[i] != null) {
        System.out.println("CPF "+i+" "+cpfs[i].dumpToString ());
      }
    }

    // And the sepset potentials
    System.out.println("sepset CPFs");
    TIntObjectIterator it = sepsets.iterator();
    while (it.hasNext()) {
      it.advance();
      Factor ptl = ((Sepset) it.value()).ptl;
      System.out.println(ptl.dumpToString ());
    }
    System.out.println ("/End JT");
  }

  public double dumpLogJoint (Assignment assn)
  {
    double accum = 0;
    for (int i = 0; i < cpfs.length; i++) {
      if (cpfs[i] != null) {
        double phi = cpfs[i].logValue (assn);
        System.out.println ("CPF "+i+" accum = "+accum);
      }
    }

    TIntObjectIterator it = sepsets.iterator();
    while (it.hasNext()) {
      it.advance();
      Factor ptl = ((Sepset) it.value()).ptl;
      double phi = ptl.logValue (assn);
      System.out.println("Sepset "+ptl.varSet()+" accum "+accum);
    }

    return accum;
  }

  public boolean isNaN()
  {
    int n = cpfs.length;
    for (int i = 0; i < n; i++)
      if (cpfs[i].isNaN()) return true;

    // And the sepset potentials
    TIntObjectIterator it = sepsets.iterator();
    while (it.hasNext()) {
      it.advance();
      Factor ptl = ((Sepset) it.value()).ptl;
      if (ptl.isNaN()) return true;
    }

    return false;
  }

  public double entropy ()
  {
    double entropy = 0;
    for (Iterator it = clusterPotentials ().iterator (); it.hasNext ();) {
      Factor ptl = (Factor) it.next ();
      entropy += ptl.entropy ();
    }
    for (Iterator it = sepsetPotentials ().iterator (); it.hasNext ();) {
      Factor ptl = (Factor) it.next ();
      entropy -= ptl.entropy ();
    }
    return entropy;
  }



// Implementation of edu.umass.cs.mallet.users.casutton.graphical.Compactible

  public void decompact()
  {
    cpfs = new Factor[numNodes];
    clearCPFs();
  }


  public void compact()
  {
    cpfs = null;
  }

} // JunctionTree
TOP

Related Classes of cc.mallet.grmm.inference.JunctionTree$Sepset

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., and is owned by Oracle Inc. Contact coftware#gmail.com.