Package cc.mallet.fst

Source Code of cc.mallet.fst.CRFTrainerByL1LabelLikelihood

package cc.mallet.fst;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

import cc.mallet.optimize.Optimizer;
import cc.mallet.optimize.OrthantWiseLimitedMemoryBFGS;
import cc.mallet.types.InstanceList;

/**
* CRF trainer that implements L1-regularization.
*
* @author Kedar Bellare
*/
public class CRFTrainerByL1LabelLikelihood extends CRFTrainerByLabelLikelihood {
  static final double SPARSE_PRIOR = 0.0;

  double l1Weight = SPARSE_PRIOR;

  public CRFTrainerByL1LabelLikelihood(CRF crf) {
    this(crf, SPARSE_PRIOR);
  }

  /**
   * Constructor for CRF trainer.
   *
   * @param crf
   *            CRF to train.
   * @param l1Weight
   *            Weight of L1 term in objective (l1Weight*|w|). Higher L1
   *            weight means sparser solutions.
   */
  public CRFTrainerByL1LabelLikelihood(CRF crf, double l1Weight) {
    super(crf);
    this.l1Weight = l1Weight;
  }

  public void setL1RegularizationWeight(double l1Weight) {
    this.l1Weight = l1Weight;
  }

  public Optimizer getOptimizer(InstanceList trainingSet) {
    getOptimizableCRF(trainingSet);
    if (opt == null || ocrf != opt.getOptimizable())
      opt = new OrthantWiseLimitedMemoryBFGS(ocrf, l1Weight);
    return opt;
  }

  // Serialization
  private static final long serialVersionUID = 1L;

  private static final int CURRENT_SERIAL_VERSION = 0;

  private void writeObject(ObjectOutputStream out) throws IOException {
    out.writeInt(CURRENT_SERIAL_VERSION);
    out.writeDouble(l1Weight);
  }

  private void readObject(ObjectInputStream in) throws IOException {
    in.readInt(); // version
    l1Weight = in.readDouble();
  }
}
TOP

Related Classes of cc.mallet.fst.CRFTrainerByL1LabelLikelihood

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., and is owned by Oracle Inc. Contact coftware#gmail.com.