Source Code of cc.mallet.fst.CRFOptimizableByBatchLabelLikelihood

package cc.mallet.fst;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

import java.util.logging.Logger;

import cc.mallet.optimize.Optimizable;

import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;

import cc.mallet.util.MalletLogger;


/**
* Implements label likelihood gradient computations for batches of data; the
* batches can easily be processed in parallel. <p>
*
* The gradient computations are the same as those of
* <tt>CRFOptimizableByLabelLikelihood</tt>. <p>
*
* <b>Note:</b> Expectations corresponding to each batch of data can be computed
* in parallel. During gradient computation, the prior and the constraints are
* incorporated into the expectations of the last batch (see
* <tt>getBatchValue, getBatchValueGradient</tt>). <p>
*
* <b>Note:</b> This implementation ignores instances with infinite weights (see
* <tt>getExpectationValue</tt>).
*
* @author Gaurav Chandalia
*/
public class CRFOptimizableByBatchLabelLikelihood implements Optimizable.ByCombiningBatchGradient, Serializable {
  private static Logger logger = MalletLogger.getLogger(CRFOptimizableByBatchLabelLikelihood.class.getName());

  static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;
  static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2;
  static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0;

  protected CRF crf;
  protected InstanceList trainingSet;

  // number of batches the training set is split into
  protected int numBatches;

  // batch specific expectations
  protected List<CRF.Factors> expectations;
  // constraints over whole training set
  protected CRF.Factors constraints;

  // value and gradient for each batch, to avoid sharing
  protected double[] cachedValue;
  protected List<double[]> cachedGradient;

  boolean usingHyperbolicPrior = false;
  double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
  double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
  double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;

  public CRFOptimizableByBatchLabelLikelihood(CRF crf, InstanceList ilist, int numBatches) {
    // set up
    this.crf = crf;
    this.trainingSet = ilist;
    this.numBatches = numBatches;

    cachedValue = new double[this.numBatches];
    cachedGradient = new ArrayList<double[]>(this.numBatches);
    expectations = new ArrayList<CRF.Factors>(this.numBatches);
    int numFactors = crf.parameters.getNumFactors();
    for (int i = 0; i < this.numBatches; ++i) {
      cachedGradient.add(new double[numFactors]);
      expectations.add(new CRF.Factors(crf.parameters));
    }
    constraints = new CRF.Factors(crf.parameters);

    gatherConstraints(ilist);
  }

  /**
   * Sets the constraints by running forward-backward with the <i>output label
   * sequence provided</i>, thus restricting the lattice to only those paths
   * that agree with the label sequence.
   */
  protected void gatherConstraints(InstanceList ilist) {
    logger.info("Gathering constraints...");
    assert (constraints.structureMatches(crf.parameters));
    constraints.zero();

    for (Instance instance : ilist) {
      FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
      FeatureSequence output = (FeatureSequence) instance.getTarget();
      double instanceWeight = ilist.getInstanceWeight(instance);
      Transducer.Incrementor incrementor =
          instanceWeight == 1.0 ? constraints.new Incrementor()
                                : constraints.new WeightedIncrementor(instanceWeight);
      new SumLatticeDefault(this.crf, input, output, incrementor);
    }
    constraints.assertNotNaNOrInfinite();
  }

  /**
   * Computes the log probability of a batch of training data, filling in the
   * corresponding expectations as well.
   */
  protected double getExpectationValue(int batchIndex, int[] batchAssignments) {
    // Reset expectations to zero before we fill them again
    CRF.Factors batchExpectations = expectations.get(batchIndex);
    batchExpectations.zero();

    // count the number of instances that have infinite weight
    int numInfLabeledWeight = 0;
    int numInfUnlabeledWeight = 0;
    int numInfWeight = 0;

    double value = 0;
    double unlabeledWeight, labeledWeight, weight;
    for (int ii = batchAssignments[0]; ii < batchAssignments[1]; ii++) {
      Instance instance = trainingSet.get(ii);
      double instanceWeight = trainingSet.getInstanceWeight(instance);
      FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
      FeatureSequence output = (FeatureSequence) instance.getTarget();

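      // labeledWeight = log(total score of the paths consistent with the
      // provided label sequence), the numerator of the conditional likelihood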
      labeledWeight = new SumLatticeDefault (this.crf, input, output, null).getTotalWeight();
      if (Double.isInfinite (labeledWeight)) {
        ++numInfLabeledWeight;
      }

      Transducer.Incrementor incrementor = instanceWeight == 1.0 ? batchExpectations.new Incrementor()
        : batchExpectations.new WeightedIncrementor (instanceWeight);
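      // unlabeledWeight = log(partition function Z(input)) over all label paths,
      // the denominator; the incrementor also accumulates model expectations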
      unlabeledWeight = new SumLatticeDefault (this.crf, input, null, incrementor).getTotalWeight();
      if (Double.isInfinite (unlabeledWeight)) {
        ++numInfUnlabeledWeight;
      }

      // weight is log(conditional probability of the correct label sequence)
      weight = labeledWeight - unlabeledWeight;
      if (Double.isInfinite(weight)) {
        ++numInfWeight;
      } else {
        // Weights are log probabilities, and we want to return a log probability
        value += weight * instanceWeight;
      }
    }
    batchExpectations.assertNotNaNOrInfinite();

    if (numInfLabeledWeight > 0 || numInfUnlabeledWeight > 0 || numInfWeight > 0) {
      logger.warning("Batch: " + batchIndex + ", Number of instances with:\n" +
          "\t -infinite labeled weight: " + numInfLabeledWeight + "\n" +
          "\t -infinite unlabeled weight: " + numInfUnlabeledWeight + "\n" +
          "\t -infinite weight: " + numInfWeight);
    }
   
    return value;
  }

  /**
   * Returns the log probability of a batch of training sequence labels; if this
   * is the last batch, the prior over parameters is incorporated as well.
   */
  public double getBatchValue(int batchIndex, int[] batchAssignments) {
    assert(batchIndex < this.numBatches) : "Incorrect batch index: " + batchIndex
        + ", range(0, " + this.numBatches + ")";
    assert(batchAssignments.length == 2 && batchAssignments[0] <= batchAssignments[1])
        : "Invalid batch assignments: " + Arrays.toString(batchAssignments);

    // Get the value of the true labels for the current batch, also filling in expectations
    double value = getExpectationValue(batchIndex, batchAssignments);

    if (batchIndex == numBatches-1) {
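      // incorporate the log prior into the value; for the Gaussian prior this
      // adds -sum_i theta_i^2 / (2 * variance), up to an additive constant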
      if (usingHyperbolicPrior) // Hyperbolic prior
        value += crf.parameters.hyberbolicPrior(hyperbolicPriorSlope, hyperbolicPriorSharpness);
      else // Gaussian prior
        value += crf.parameters.gaussianPrior(gaussianPriorVariance);
    }
    assert(!(Double.isNaN(value) || Double.isInfinite(value)))
      : "Label likelihood is NaN/Infinite, batchIndex: " + batchIndex + ", batchAssignments: " + Arrays.toString(batchAssignments);
    // update cache
    cachedValue[batchIndex] = value;
   
    return value;
  }

  public void getBatchValueGradient(double[] buffer, int batchIndex, int[] batchAssignments) {
    assert(batchIndex < this.numBatches) : "Incorrect batch index: " + batchIndex
        + ", range(0, " + this.numBatches + ")";
    assert(batchAssignments.length == 2 && batchAssignments[0] <= batchAssignments[1])
        : "Invalid batch assignments: " + Arrays.toString(batchAssignments);

    CRF.Factors batchExpectations = expectations.get(batchIndex);

    if (batchIndex == numBatches-1) {
      // the check on the CRF parameters has to be done only once; infinite values are allowed
      crf.parameters.assertNotNaN();

      // factor the constraints and the prior into the expectations of last batch
      // Gradient = (constraints - expectations + prior) = -(expectations - constraints - prior)
      // The minus sign is factored in combineGradients method after all gradients are computed
      batchExpectations.plusEquals(constraints, -1.0);
      if (usingHyperbolicPrior)
        batchExpectations.plusEqualsHyperbolicPriorGradient(crf.parameters, -hyperbolicPriorSlope, hyperbolicPriorSharpness);
      else
        batchExpectations.plusEqualsGaussianPriorGradient(crf.parameters, -gaussianPriorVariance);
      batchExpectations.assertNotNaNOrInfinite();
    }

    double[] gradient = cachedGradient.get(batchIndex);
    // set the cached gradient
    batchExpectations.getParameters(gradient);
    System.arraycopy(gradient, 0, buffer, 0, gradient.length);
  }

  /**
   * Adds gradients from all batches. <p>
   * <b>Note:</b> assumes buffer is already initialized.
   */
  public void combineGradients(Collection<double[]> batchGradients, double[] buffer) {
    assert(buffer.length == crf.parameters.getNumFactors())
      : "Incorrect buffer length: " + buffer.length + ", expected: " + crf.parameters.getNumFactors();

    Arrays.fill(buffer, 0);
    for (double[] gradient : batchGradients) {
      MatrixOps.plusEquals(buffer, gradient);
    }
    // -(...) from getBatchValueGradient
    MatrixOps.timesEquals(buffer, -1.0);
  }

  public int getNumBatches() { return numBatches; }

  public void setUseHyperbolicPrior (boolean f) { usingHyperbolicPrior = f; }
  public void setHyperbolicPriorSlope (double p) { hyperbolicPriorSlope = p; }
  public void setHyperbolicPriorSharpness (double p) { hyperbolicPriorSharpness = p; }
  public double getUseHyperbolicPriorSlope () { return hyperbolicPriorSlope; }
  public double getUseHyperbolicPriorSharpness () { return hyperbolicPriorSharpness; }
  public void setGaussianPriorVariance (double p) { gaussianPriorVariance = p; }
  public double getGaussianPriorVariance () { return gaussianPriorVariance; }
  public int getNumParameters () {return crf.parameters.getNumFactors();}

  public void getParameters (double[] buffer) {
    crf.parameters.getParameters(buffer);
  }

  public double getParameter (int index) {
    return crf.parameters.getParameter(index);
  }

  public void setParameters (double [] buff) {
    crf.parameters.setParameters(buff);
    crf.weightsValueChanged();
  }

  public void setParameter (int index, double value) {
    crf.parameters.setParameter(index, value);
    crf.weightsValueChanged();
  }

  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 0;

  private void writeObject (ObjectOutputStream out) throws IOException {
    out.writeInt (CURRENT_SERIAL_VERSION);
    out.writeObject(trainingSet);
    out.writeObject(crf);
    out.writeInt(numBatches);
    out.writeObject(cachedValue);
    for (double[] gradient : cachedGradient)
      out.writeObject(gradient);
  }

  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
    in.readInt ();
    trainingSet = (InstanceList) in.readObject();
    crf = (CRF)in.readObject();
    numBatches = in.readInt();
    cachedValue = (double[]) in.readObject();
    cachedGradient = new ArrayList<double[]>(numBatches);
    for (int i = 0; i < numBatches; ++i)
      cachedGradient.add((double[]) in.readObject()); // 'add', not 'set': the list starts out empty
  }

  public static class Factory {
    public Optimizable.ByCombiningBatchGradient newCRFOptimizable (CRF crf, InstanceList trainingData, int numBatches) {
      return new CRFOptimizableByBatchLabelLikelihood (crf, trainingData, numBatches);
    }
  }
}
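
Usage sketch (not part of the source above): one plausible way to train a CRF with
this batched optimizable is to wrap it in MALLET's ThreadedOptimizable, which
computes per-batch values and gradients on worker threads and combines them via
combineGradients. The variable trainingData and the batch count of 8 below are
illustrative assumptions, not taken from this file.

    // Sketch: parallel CRF training with CRFOptimizableByBatchLabelLikelihood.
    // 'trainingData' is assumed to be an InstanceList of FeatureVectorSequence
    // inputs paired with FeatureSequence labels; 8 batches is an arbitrary choice.
    CRF crf = new CRF(trainingData.getDataAlphabet(), trainingData.getTargetAlphabet());
    crf.addFullyConnectedStatesForLabels();

    CRFOptimizableByBatchLabelLikelihood batchOptimizable =
        new CRFOptimizableByBatchLabelLikelihood(crf, trainingData, 8);
    ThreadedOptimizable optimizable = new ThreadedOptimizable(
        batchOptimizable, trainingData, crf.getParameters().getNumFactors(),
        new CRFCacheStaleIndicator(crf));

    Optimizable.ByGradientValue[] optimizables = { optimizable };
    CRFTrainerByValueGradients trainer = new CRFTrainerByValueGradients(crf, optimizables);
    trainer.train(trainingData, Integer.MAX_VALUE); // run until convergence
    optimizable.shutdown(); // stop the worker threads when done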