/* Copyright (C) 2011 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.classify;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;
import cc.mallet.classify.constraints.ge.MaxEntGEConstraint;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletProgressMessageLogger;
/**
 * Training of MaxEnt models with labeled features using
 * Generalized Expectation Criteria.
 *
 * Based on:
 * "Learning from Labeled Features using Generalized Expectation Criteria"
 * Gregory Druck, Gideon Mann, Andrew McCallum
 * SIGIR 2008
 *
 * @author Gregory Druck <a href="mailto:gdruck@cs.umass.edu">gdruck@cs.umass.edu</a>
 */
public class MaxEntOptimizableByGE implements Optimizable.ByGradientValue {
private static Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntOptimizableByGE.class.getName()+"-pl");
protected boolean cacheStale = true;
protected int defaultFeatureIndex;
protected double temperature;
protected double objWeight;
protected double cachedValue;
protected double gaussianPriorVariance;
protected double[] cachedGradient;
protected double[] parameters;
protected InstanceList trainingList;
protected MaxEnt classifier;
protected ArrayList<MaxEntGEConstraint> constraints;
/**
* @param trainingList List with unlabeled training instances.
* @param constraints Feature expectation constraints.
* @param initClassifier Initial classifier.
*/
public MaxEntOptimizableByGE(InstanceList trainingList, ArrayList<MaxEntGEConstraint> constraints, MaxEnt initClassifier) {
temperature = 1.0;
objWeight = 1.0;
gaussianPriorVariance = 1.0;
this.trainingList = trainingList;
int numFeatures = trainingList.getDataAlphabet().size();
defaultFeatureIndex = numFeatures;
int numLabels = trainingList.getTargetAlphabet().size();
cachedGradient = new double[(numFeatures + 1) * numLabels];
cachedValue = 0;
if (initClassifier != null) {
this.parameters = initClassifier.parameters;
this.classifier = initClassifier;
}
else {
this.parameters = new double[(numFeatures + 1) * numLabels];
this.classifier = new MaxEnt(trainingList.getPipe(),parameters);
}
this.constraints = constraints;
for (MaxEntGEConstraint constraint : constraints) {
constraint.preProcess(trainingList);
}
}
/**
* Sets the variance for Gaussian prior or
* equivalently the inverse of the weight
* of the L2 regularization term.
*
* @param variance Gaussian prior variance.
*/
public void setGaussianPriorVariance(double variance) {
this.gaussianPriorVariance = variance;
}
/**
* Model probabilities are raised to the power 1/temperature and
* renormalized. As the temperature decreases, model probabilities
* approach 1 for the maximum probability class, and 0 for other classes.
*
* DEFAULT: 1
*
* @param temp Temperature.
*/
public void setTemperature(double temp) {
this.temperature = temp;
}
/**
* The weight of GE term in the objective function.
*
* @param weight GE term weight.
*/
public void setWeight(double weight) {
this.objWeight = weight;
}
public MaxEnt getClassifier() {
return classifier;
}
public double getValue() {
if (!cacheStale) {
return cachedValue;
}
if (objWeight == 0) {
return 0.0;
}
for (MaxEntGEConstraint constraint : constraints) {
constraint.zeroExpectations();
}
Arrays.fill(cachedGradient,0);
int numFeatures = trainingList.getDataAlphabet().size() + 1;
int numLabels = trainingList.getTargetAlphabet().size();
double[][] scores = new double[trainingList.size()][numLabels];
double[] constraintValue = new double[numLabels];
// pass 1: calculate model distribution
for (int ii = 0; ii < trainingList.size(); ii++) {
Instance instance = trainingList.get(ii);
double instanceWeight = trainingList.getInstanceWeight(instance);
// skip if labeled
if (instance.getTarget() != null) {
continue;
}
FeatureVector fv = (FeatureVector) instance.getData();
classifier.getClassificationScoresWithTemperature(instance, temperature, scores[ii]);
for (MaxEntGEConstraint constraint : constraints) {
constraint.computeExpectations(fv,scores[ii],instanceWeight);
}
}
// compute value
double value = 0;
for (MaxEntGEConstraint constraint : constraints) {
value += constraint.getValue();
}
value *= objWeight;
// pass 2: determine per example gradient
for (int ii = 0; ii < trainingList.size(); ii++) {
Instance instance = trainingList.get(ii);
// skip if labeled
if (instance.getTarget() != null) {
continue;
}
Arrays.fill(constraintValue,0);
double instanceExpectation = 0;
double instanceWeight = trainingList.getInstanceWeight(instance);
FeatureVector fv = (FeatureVector) instance.getData();
for (MaxEntGEConstraint constraint : constraints) {
constraint.preProcess(fv);
for (int label = 0; label < numLabels; label++) {
double val = constraint.getCompositeConstraintFeatureValue(fv, label);
constraintValue[label] += val;
instanceExpectation += val * scores[ii][label];
}
}
for (int label = 0; label < numLabels; label++) {
if (scores[ii][label] == 0) continue;
assert (!Double.isInfinite(scores[ii][label]));
double weight = objWeight * instanceWeight * scores[ii][label] * (constraintValue[label] - instanceExpectation) / temperature;
assert(!Double.isNaN(weight));
MatrixOps.rowPlusEquals(cachedGradient, numFeatures, label, fv, weight);
cachedGradient[numFeatures * label + defaultFeatureIndex] += weight;
}
}
cachedValue = value;
cacheStale = false;
double reg = getRegularization();
progressLogger.info ("Value (GE=" + value + " Gaussian prior= " + reg + ") = " + cachedValue);
return cachedValue;
}
protected double getRegularization() {
double regularization = 0;
for (int pi = 0; pi < parameters.length; pi++) {
double p = parameters[pi];
regularization -= p * p / (2 * gaussianPriorVariance);
cachedGradient[pi] -= p / gaussianPriorVariance;
}
cachedValue += regularization;
return regularization;
}
public void getValueGradient(double[] buffer) {
if (cacheStale) {
getValue();
}
assert(buffer.length == cachedGradient.length);
System.arraycopy (cachedGradient, 0, buffer, 0, buffer.length);
}
public int getNumParameters() {
return parameters.length;
}
public double getParameter(int index) {
return parameters[index];
}
public void getParameters(double[] buffer) {
assert(buffer.length == parameters.length);
System.arraycopy (parameters, 0, buffer, 0, buffer.length);
}
public void setParameter(int index, double value) {
cacheStale = true;
parameters[index] = value;
}
public void setParameters(double[] params) {
assert(params.length == parameters.length);
cacheStale = true;
System.arraycopy (params, 0, parameters, 0, parameters.length);
}
}