Package cc.mallet.cluster

Source Code of cc.mallet.cluster.GreedyAgglomerative

package cc.mallet.cluster;

import java.util.logging.Logger;

import cc.mallet.cluster.neighbor_evaluator.AgglomerativeNeighbor;
import cc.mallet.cluster.neighbor_evaluator.Neighbor;
import cc.mallet.cluster.neighbor_evaluator.NeighborEvaluator;
import cc.mallet.cluster.util.ClusterUtils;
import cc.mallet.cluster.util.PairwiseMatrix;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletProgressMessageLogger;


/**
* Greedily merges Instances until convergence. New merges are scored
* using {@link NeighborEvaluator}.
*
* @author "Aron Culotta" <culotta@degas.cs.umass.edu>
* @version 1.0
* @since 1.0
* @see HillClimbingClusterer
*/
public class GreedyAgglomerative extends HillClimbingClusterer {

 
  private static final long serialVersionUID = 1L;

  private static Logger progressLogger =
    MalletProgressMessageLogger.getLogger(GreedyAgglomerative.class.getName()+"-pl");

  /**
   * Converged when merge score is below this value.
   */
  protected double stoppingThreshold;

  /**
   * True if should stop clustering.
   */
   protected boolean converged;

  /**
   * Cache for calls to {@link NeighborhoodEvaluator}. In some
   * experiments, reduced running time by nearly half.
   */
   protected PairwiseMatrix scoreCache;
 
  /**
   *
   * @param instancePipe Pipe for each underying {@link Instance}.
   * @param evaluator To score potential merges.
   * @param stoppingThreshold Clustering converges when the evaluator score is below this value.
   * @return
   */
  public GreedyAgglomerative (Pipe instancePipe,
                              NeighborEvaluator evaluator,
                              double stoppingThreshold) {
    super(instancePipe, evaluator);   
    this.stoppingThreshold = stoppingThreshold;
    this.converged = false;
  }

  /**
   *
   * @param instances
   * @return A singleton clustering (each Instance in its own cluster).
   */
  public Clustering initializeClustering (InstanceList instances) {
    reset();
    return ClusterUtils.createSingletonClustering(instances);
  }

  public boolean converged (Clustering clustering) {
    return converged;
  }

  /**
   * Reset convergence to false so a new round of clustering can begin.
   */
  public void reset () {
    converged = false;
    scoreCache = null;
    evaluator.reset();
  }
 
  /**
   * For each pair of clusters, calculate the score of the {@link Neighbor}
   * that would result from merging the two clusters. Choose the merge that
   * obtains the highest score. If no merge improves score, return original
   * Clustering
   *
   * @param clustering
   * @return
   */
  public Clustering improveClustering (Clustering clustering) {
    double bestScore = Double.NEGATIVE_INFINITY;
    int[] toMerge = new int[]{-1,-1};
    for (int i = 0; i < clustering.getNumClusters(); i++) {
      for (int j = i + 1; j < clustering.getNumClusters(); j++) {
        double score = getScore(clustering, i, j);
        if (score > bestScore) {
          bestScore = score;
          toMerge[0] = i;
          toMerge[1] = j;
        }       
      }
    }
   
    converged = (bestScore < stoppingThreshold);

    if (!(converged)) {
      progressLogger.info("Merging " + toMerge[0] + "(" + clustering.size(toMerge[0]) +
                          " nodes) and " + toMerge[1] + "(" + clustering.size(toMerge[1]) +
                          " nodes) [" + bestScore + "] numClusters=" +
                          clustering.getNumClusters());
      updateScoreMatrix(clustering, toMerge[0], toMerge[1]);
      clustering = ClusterUtils.mergeClusters(clustering, toMerge[0], toMerge[1]);
    } else {
      progressLogger.info("Converged with score " + bestScore);
    }
    return clustering;
  }
 
  /**
   *
   * @param clustering
   * @param i
   * @param j
   * @return The score for merging these two clusters.
   */
  protected double getScore (Clustering clustering, int i, int j) {
    if (scoreCache == null)
      scoreCache = new PairwiseMatrix(clustering.getNumInstances());

    int[] ci = clustering.getIndicesWithLabel(i);
    int[] cj = clustering.getIndicesWithLabel(j);
    if (scoreCache.get(ci[0], cj[0]) == 0.0) {
      double val = evaluator.evaluate(
        new AgglomerativeNeighbor(clustering,
                                  ClusterUtils.copyAndMergeClusters(clustering, i, j),
                                  ci, cj));
      for (int ni = 0; ni < ci.length; ni++)
        for (int nj = 0; nj < cj.length; nj++)
          scoreCache.set(ci[ni], cj[nj], val);
    }

    return scoreCache.get(ci[0], cj[0]);                           
  }

  /**
   * Resets the values of clusters that have been merged.
   * @param clustering
   * @param i
   * @param j
   */
  protected void updateScoreMatrix (Clustering clustering, int i, int j) {
    int size = clustering.getNumInstances();
    int[] ci = clustering.getIndicesWithLabel(i);
    for (int ni = 0; ni < ci.length; ni++) {
      for (int nj = 0; nj < size; nj++)
        if (ci[ni] != nj)
          scoreCache.set(ci[ni], nj, 0.0);
    }
    int[] cj = clustering.getIndicesWithLabel(j);
    for (int ni = 0; ni < cj.length; ni++) {
      for (int nj = 0; nj < size; nj++)
        if (cj[ni] != nj)
          scoreCache.set(cj[ni], nj, 0.0);
    }
  }
   
  public String toString () {
    return "class=" + this.getClass().getName() +
      "\nstoppingThreshold=" + stoppingThreshold +
      "\nneighborhoodEvaluator=[" + evaluator + "]";   
  }
}
TOP

Related Classes of cc.mallet.cluster.GreedyAgglomerative

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.