Package cc.mallet.classify

Source Code of cc.mallet.classify.MaxEntOptimizableByGE

/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.classify;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;

import cc.mallet.classify.constraints.ge.MaxEntGEConstraint;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletProgressMessageLogger;

/**
* Training of MaxEnt models with labeled features using
* Generalized Expectation Criteria.
*
* Based on:
* "Learning from Labeled Features using Generalized Expectation Criteria"
* Gregory Druck, Gideon Mann, Andrew McCallum
* SIGIR 2008
*
* @author Gregory Druck <a href="mailto:gdruck@cs.umass.edu">gdruck@cs.umass.edu</a>
*/

/**
* @author gdruck
*
*/
public class MaxEntOptimizableByGE implements Optimizable.ByGradientValue {
 
  private static Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntOptimizableByGE.class.getName()+"-pl");
 
  protected boolean cacheStale = true;
  protected int defaultFeatureIndex;
  protected double temperature;
  protected double objWeight;
  protected double cachedValue;
  protected double gaussianPriorVariance;
  protected double[] cachedGradient;
  protected double[] parameters;
  protected InstanceList trainingList;
  protected MaxEnt classifier;
  protected ArrayList<MaxEntGEConstraint> constraints;
 
  /**
   * @param trainingList List with unlabeled training instances.
   * @param constraints Feature expectation constraints.
   * @param initClassifier Initial classifier.
   */
  public MaxEntOptimizableByGE(InstanceList trainingList, ArrayList<MaxEntGEConstraint> constraints, MaxEnt initClassifier) {
    temperature = 1.0;
    objWeight = 1.0;
    gaussianPriorVariance = 1.0;
    this.trainingList = trainingList;
   
    int numFeatures = trainingList.getDataAlphabet().size();
    defaultFeatureIndex = numFeatures;
    int numLabels = trainingList.getTargetAlphabet().size();

    cachedGradient = new double[(numFeatures + 1) * numLabels];
    cachedValue = 0;
      
    if (initClassifier != null) {
      this.parameters = initClassifier.parameters;
      this.classifier = initClassifier;
    }
    else {
      this.parameters = new double[(numFeatures + 1) * numLabels];
      this.classifier = new MaxEnt(trainingList.getPipe(),parameters);
    }
   
     this.constraints = constraints;
    
     for (MaxEntGEConstraint constraint : constraints) {
       constraint.preProcess(trainingList);
     }
  }
 
  /**
   * Sets the variance for Gaussian prior or
   * equivalently the inverse of the weight
   * of the L2 regularization term.
   *
   * @param variance Gaussian prior variance.
   */
  public void setGaussianPriorVariance(double variance) {
    this.gaussianPriorVariance = variance;
  }
 
 
  /**
   * Model probabilities are raised to the power 1/temperature and
   * renormalized. As the temperature decreases, model probabilities
   * approach 1 for the maximum probability class, and 0 for other classes.
   *
   * DEFAULT: 1 
   *
   * @param temp Temperature.
   */
  public void setTemperature(double temp) {
    this.temperature = temp;
  }
 
  /**
   * The weight of GE term in the objective function.
   *
   * @param weight GE term weight.
   */
  public void setWeight(double weight) {
    this.objWeight = weight;
  }
 
  public MaxEnt getClassifier() {
    return classifier;
  }

  public double getValue() {
    if (!cacheStale) {
      return cachedValue;
    }
   
    if (objWeight == 0) {
      return 0.0;
    }
   
    for (MaxEntGEConstraint constraint : constraints) {
      constraint.zeroExpectations();
    }
   
    Arrays.fill(cachedGradient,0);

    int numFeatures = trainingList.getDataAlphabet().size() + 1;
    int numLabels = trainingList.getTargetAlphabet().size();    

    double[][] scores = new double[trainingList.size()][numLabels];
    double[] constraintValue = new double[numLabels];
   
    // pass 1: calculate model distribution
    for (int ii = 0; ii < trainingList.size(); ii++) {
      Instance instance = trainingList.get(ii);
      double instanceWeight = trainingList.getInstanceWeight(instance);
     
      // skip if labeled
      if (instance.getTarget() != null) {
        continue;
      }
     
      FeatureVector fv = (FeatureVector) instance.getData();
      classifier.getClassificationScoresWithTemperature(instance, temperature, scores[ii]);
      for (MaxEntGEConstraint constraint : constraints) {
        constraint.computeExpectations(fv,scores[ii],instanceWeight);
      }
    }
   
    // compute value
    double value = 0;
    for (MaxEntGEConstraint constraint : constraints) {
      value += constraint.getValue();
    }
    value *= objWeight;

    // pass 2: determine per example gradient
    for (int ii = 0; ii < trainingList.size(); ii++) {
      Instance instance = trainingList.get(ii);
     
      // skip if labeled
      if (instance.getTarget() != null) {
        continue;
      }
     
      Arrays.fill(constraintValue,0);
      double instanceExpectation = 0;
      double instanceWeight = trainingList.getInstanceWeight(instance);
      FeatureVector fv = (FeatureVector) instance.getData();

      for (MaxEntGEConstraint constraint : constraints) {
        constraint.preProcess(fv);
        for (int label = 0; label < numLabels; label++) {
          double val = constraint.getCompositeConstraintFeatureValue(fv, label);
          constraintValue[label] += val;
          instanceExpectation += val * scores[ii][label];
        }
      }

      for (int label = 0; label < numLabels; label++) {
        if (scores[ii][label] == 0) continue;
        assert (!Double.isInfinite(scores[ii][label]));
        double weight = objWeight * instanceWeight * scores[ii][label] * (constraintValue[label] - instanceExpectation) / temperature;
        assert(!Double.isNaN(weight));
        MatrixOps.rowPlusEquals(cachedGradient, numFeatures, label, fv, weight);
        cachedGradient[numFeatures * label + defaultFeatureIndex] += weight;
     
    }

    cachedValue = value;
    cacheStale = false;
   
    double reg = getRegularization();
    progressLogger.info ("Value (GE=" + value + " Gaussian prior= " + reg + ") = " + cachedValue);
   
    return cachedValue;
  }

  protected double getRegularization() {
    double regularization = 0;
    for (int pi = 0; pi < parameters.length; pi++) {
      double p = parameters[pi];
      regularization -= p * p / (2 * gaussianPriorVariance);
      cachedGradient[pi] -= p / gaussianPriorVariance;
    }
    cachedValue += regularization;
    return regularization;
  }
 
  public void getValueGradient(double[] buffer) {
    if (cacheStale) {
      getValue()
    }
    assert(buffer.length == cachedGradient.length);
    System.arraycopy (cachedGradient, 0, buffer, 0, buffer.length);
  }

  public int getNumParameters() {
    return parameters.length;
  }

  public double getParameter(int index) {
    return parameters[index];
  }

  public void getParameters(double[] buffer) {
    assert(buffer.length == parameters.length);
    System.arraycopy (parameters, 0, buffer, 0, buffer.length);
  }

  public void setParameter(int index, double value) {
    cacheStale = true;
    parameters[index] = value;
  }

  public void setParameters(double[] params) {
    assert(params.length == parameters.length);
    cacheStale = true;
    System.arraycopy (params, 0, parameters, 0, parameters.length);
  }
}
TOP

Related Classes of cc.mallet.classify.MaxEntOptimizableByGE

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.