Package cc.mallet.grmm.inference.gbp

Source Code of cc.mallet.grmm.inference.gbp.Kikuchi4SquareRegionGenerator

/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference.gbp;


import java.util.Iterator;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.UndirectedGrid;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.ArrayUtils;

/**
* Created: May 31, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: Kikuchi4SquareRegionGenerator.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $
*/
public class Kikuchi4SquareRegionGenerator implements RegionGraphGenerator {

  public RegionGraph constructRegionGraph (FactorGraph mdl)
  {
    if (mdl instanceof UndirectedGrid) {
      RegionGraph rg = new RegionGraph ();

      UndirectedGrid grid = (UndirectedGrid) mdl;

      // First set up regions for all
      for (int x = 0; x < grid.getWidth () - 1; x++) {
        for (int y = 0; y < grid.getHeight () - 1; y++) {
          Variable[] vars = new Variable[] {
           grid.get (x, y),
           grid.get (x+1, y),
           grid.get (x+1, y+1),
           grid.get (x, y+1), };

          Factor[] edges = new Factor[] {
            mdl.factorOf (vars[0], vars[1]),
           mdl.factorOf (vars[1], vars[2]),
           mdl.factorOf (vars[2], vars[3]),
           mdl.factorOf (vars[0], vars[3]), };

          // Create region for 4-clique
          Region fourSquare = new Region (vars, edges);

          // Create 1-clique region
          for (int i = 0; i < 4; i++) {
            Variable var = vars[i];
            Factor ptl = mdl.factorOf (var);
            if (ptl != null) {
              fourSquare.factors.add (ptl);
            }
          }

          // Finally create edge regions, and connect to everyone else
          for (int i = 0; i < 4; i++) {
            Factor edgePtl = edges[i];
            Region edgeRgn = rg.findRegion (edgePtl, true);
            rg.add (fourSquare, edgeRgn);

            Variable v1 = (Variable) edgeRgn.vars.get (0);
            Region nodeRgn = createVarRegion (rg, mdl, v1);
            edgeRgn.factors.addAll (nodeRgn.factors);
            rg.add (edgeRgn, nodeRgn);

            Variable v2 = (Variable) edgeRgn.vars.get (1);
            nodeRgn = createVarRegion (rg, mdl, v2);
            edgeRgn.factors.addAll (nodeRgn.factors);

            rg.add (edgeRgn, nodeRgn);
          }
        }
      }

      rg.computeInferenceCaches ();

      return rg;

    } else {
      throw new UnsupportedOperationException ("Kikuchi4SquareRegionGenerator requires that you use UndirectedGrid.");
    }
  }

  private Region createVarRegion (RegionGraph rg, FactorGraph mdl, Variable v1)
  {
    Factor ptl = mdl.factorOf (v1);
    if (ptl == null) {
      return rg.findRegion (v1, true);
    } else {
      return rg.findRegion (ptl, true);
    }
  }

  private void checkAllSingles (RegionGraph rg, Region[] nodeRegions)
  {
    for (Iterator it = rg.iterator (); it.hasNext ();) {
      Region region = (Region) it.next ();
      if (region.vars.size() == 1) {
        if (ArrayUtils.indexOf (nodeRegions, region) < 0) {
          throw new IllegalStateException ("huh?");
        }
      }
    }
  }

  private void checkTooManyDoubles (RegionGraph rg, FactorGraph mdl)
  {
    int nv = mdl.factors ().size ();
    int doubles = 0;
    for (Iterator it = rg.iterator (); it.hasNext ();) {
      Region region = (Region) it.next ();
      if (region.vars.size() == 2)
        doubles++;
    }

    if (doubles > nv) {
      throw new IllegalStateException ("huh? ");
    }
  }

}
TOP

Related Classes of cc.mallet.grmm.inference.gbp.Kikuchi4SquareRegionGenerator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., which is now owned by Oracle Inc. Contact coftware#gmail.com.