Package cc.mallet.pipe

Source Code of cc.mallet.pipe.AddClassifierTokenPredictions$TokenClassifiers

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */


package cc.mallet.pipe;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.HashMap;
import java.util.logging.Logger;

import cc.mallet.classify.BalancedWinnowTrainer;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.util.MalletLogger;


/**
* This pipe uses a Classifier to label each token (i.e., using 0-th order Markov assumption),
* then adds the predictions as features to each token.
*
* This pipe assumes the input Instance's data is of type FeatureVectorSequence
* (each an augmentable feature vector).
*
* Example usage:<pre>
*     1) Create and serialize a featurePipe that converts raw input to FeatureVectorSequences
*     2) Pipe input data through featurePipe, train a TokenClassifiers via cross validation, then serialize the classifiers
*     2) Pipe input data through featurePipe and this pipe (using the saved classifiers), and train a Transducer
*     4) Serialize the trained Transducer
* </pre>
* @author ghuang
*/
public class AddClassifierTokenPredictions extends Pipe implements Serializable
{
  private static Logger logger = MalletLogger.getLogger(AddClassifierTokenPredictions.class.getName());
 
  // Specify which predictions are to be added as features. 
  // E.g., { 1, 2 } = add labels of the top 2 highest-scoring predictions as features.
  int[] m_predRanks2add;
 
  // The trained token classifier
  TokenClassifiers m_tokenClassifiers;

  // Whether to treat each instance's feature values as binary
  boolean m_binary;

  // Whether the pipe is currently being used at production time
  // (i.e., not being used as pipeline for training a transducer)  
  boolean m_inProduction;

  // Augmented data alphabet that includes the class predictions
  Alphabet m_dataAlphabet;
 
 
  public AddClassifierTokenPredictions(InstanceList trainList)
  {
    this(trainList, null);
  }

 
  public AddClassifierTokenPredictions(InstanceList trainList, InstanceList testList)
  {
    this(new TokenClassifiers(convert(trainList, (Noop) trainList.getPipe())), new int[] { 1 }, true,
          convert(testList, (Noop) trainList.getPipe()));
  }
 
 
  public AddClassifierTokenPredictions(TokenClassifiers tokenClassifiers, int[] predRanks2add,
      boolean binary, InstanceList testList)
  {
    m_predRanks2add = predRanks2add;
    m_binary = binary;
    m_tokenClassifiers = tokenClassifiers;
    m_inProduction = false;
    m_dataAlphabet = (Alphabet) tokenClassifiers.getAlphabet().clone();
    Alphabet labelAlphabet = tokenClassifiers.getLabelAlphabet();
   
    // add the token prediction features to the alphabet
    for (int i = 0; i < m_predRanks2add.length; i++) {
      for (int j = 0; j < labelAlphabet.size(); j++) {
        String featName = "TOK_PRED=" + labelAlphabet.lookupObject(j).toString() + "_@_RANK_" + m_predRanks2add[i];
        m_dataAlphabet.lookupIndex(featName, true);
      }
    }
   
    // evaluate token classifier 
    if (testList != null) {
      Trial trial = new Trial(m_tokenClassifiers, testList);
      logger.info("Token classifier accuracy on test set = " + trial.getAccuracy());
    }
  }

 
  public void setInProduction(boolean inProduction) { m_inProduction = inProduction; }
  public boolean getInProduction() { return m_inProduction; }

  public static void setInProduction(Pipe p, boolean value)
  {
    if (p instanceof AddClassifierTokenPredictions)
      ((AddClassifierTokenPredictions) p).setInProduction(value);
    else if (p instanceof SerialPipes) {
      SerialPipes sp = (SerialPipes) p;
      for (int i = 0; i < sp.size(); i++)
        setInProduction(sp.getPipe(i), value);
    }
  }
 
  public Alphabet getDataAlphabet() { return m_dataAlphabet; }
 
  /**
   * Add the token classifier's predictions as features to the instance.
   * This method assumes the input instance contains FeatureVectorSequence as data 
   */
  public Instance pipe(Instance carrier)
  {
    FeatureVectorSequence fvs = (FeatureVectorSequence) carrier.getData();
    InstanceList ilist = convert(carrier, (Noop) m_tokenClassifiers.getInstancePipe());
    assert (fvs.size() == ilist.size());

    // For passing instances to the token classifier, each instance's data alphabet needs to
    // match that used by the token classifier at training time.  For the resulting piped
    // instance, each instance's data alphabet needs to contain token classifier's prediction
    // as features
    FeatureVector[] fva = new FeatureVector[fvs.size()];

    for (int i = 0; i < ilist.size(); i++) {
      Instance inst = ilist.get(i);
      Classification c = m_tokenClassifiers.classify(inst, ! m_inProduction);
      LabelVector lv = c.getLabelVector();
      AugmentableFeatureVector afv1 = (AugmentableFeatureVector) inst.getData();
      int[] indices = afv1.getIndices();
      AugmentableFeatureVector afv2 = new AugmentableFeatureVector(m_dataAlphabet,
          indices, afv1.getValues(), indices.length + m_predRanks2add.length);

      for (int j = 0; j < m_predRanks2add.length; j++) {
        Label label = lv.getLabelAtRank(m_predRanks2add[j]);
        int idx = m_dataAlphabet.lookupIndex("TOK_PRED=" + label.toString() + "_@_RANK_" + m_predRanks2add[j]);

        assert(idx >= 0);
        afv2.add(idx, 1);
      }
      fva[i] = afv2;
    }

    carrier.setData(new FeatureVectorSequence(fva));
    return carrier;
  }


  /**
   * Converts each instance containing a FeatureVectorSequence to multiple instances,
   * each containing an AugmentableFeatureVector as data. 
   * 
   * @param ilist Instances with FeatureVectorSequence as data field
   * @param alphabetsPipe a Noop pipe containing the data and target alphabets for the resulting InstanceList
   * @return an InstanceList where each Instance contains one Token's AugmentableFeatureVector as data
   */
  public static InstanceList convert(InstanceList ilist, Noop alphabetsPipe)
  {
    if (ilist == null) return null;
   
    // This monstrosity is necessary b/c Classifiers obtain the data/target alphabets via pipes
    InstanceList ret = new InstanceList(alphabetsPipe);

    for (Instance inst : ilist)
      ret.add(inst);
    //for (int i = 0; i < ilist.size(); i++) ret.add(convert(ilist.get(i), alphabetsPipe));

    return ret;
  }


  /**
   *
   * @param inst input instance, with FeatureVectorSequence as data.
   * @param alphabetsPipe a Noop pipe containing the data and target alphabets for
   * the resulting InstanceList and AugmentableFeatureVectors
   * @return list of instances, each with one AugmentableFeatureVector as data
   */
  public static InstanceList convert(Instance inst, Noop alphabetsPipe)
  {
    InstanceList ret = new InstanceList(alphabetsPipe);
    Object obj = inst.getData();
    assert(obj instanceof FeatureVectorSequence);

    FeatureVectorSequence fvs = (FeatureVectorSequence) obj;
    LabelSequence ls = (LabelSequence) inst.getTarget();
    assert(fvs.size() == ls.size());

    Object instName = (inst.getName() == null ? "NONAME" : inst.getName());
   
    for (int j = 0; j < fvs.size(); j++) {
      FeatureVector fv = fvs.getFeatureVector(j);
      int[] indices = fv.getIndices();
      FeatureVector data = new AugmentableFeatureVector (alphabetsPipe.getDataAlphabet(),
          indices, fv.getValues(), indices.length);
      Labeling target = ls.getLabelAtPosition(j);
      String name = instName.toString() + "_@_POS_" + (j + 1);
      Object source = inst.getSource();
      Instance toAdd = alphabetsPipe.pipe(new Instance(data, target, name, source));

      ret.add(toAdd);
    }

    return ret;
  }


  // Serialization
  private static final long serialVersionUID = 1;

  /**
   * This inner class represents the trained token classifiers.
   * @author ghuang
   */
  public static class TokenClassifiers extends Classifier implements Serializable
  {
    // number of folds in cross-validation training
    int m_numCV;

    // random seed to split training data for cross-validation
    int m_randSeed;
   
    // trainer for token classifier
    ClassifierTrainer m_trainer;
   
    // token classifier trained on the entirety of the training set
    Classifier m_tokenClassifier;
   
    // table storing instance name -->  out-of-fold classifier
    // Used to prevent overfitting to the token classifier's predictions
    HashMap m_table;
   

    /**
     * Train a token classifier using the given Instances with 5-fold cross validation
     * @param trainList training instances
     */
    public TokenClassifiers(InstanceList trainList)
    {
      this(trainList, 0, 5);
    }
   
   
    public TokenClassifiers(InstanceList trainList, int randSeed, int numCV)
    {
//      this(new AdaBoostM2Trainer(new DecisionTreeTrainer(2), 10), trainList, randSeed, numCV);
//      this(new NaiveBayesTrainer(), trainList, randSeed, numCV);
      this(new BalancedWinnowTrainer(), trainList, randSeed, numCV);
//      this(new SVMTrainer(), trainList, randSeed, numCV);
    }
   
   
    public TokenClassifiers(ClassifierTrainer trainer, InstanceList trainList, int randSeed, int numCV)
    {
      super(trainList.getPipe());

      m_trainer = trainer;
      m_randSeed = randSeed;
      m_numCV = numCV;
      m_table = new HashMap();

      doTraining(trainList);
    }


    // train the token classifier
    private void doTraining(InstanceList trainList)
    {
      // train a classifier on the entire training set
      logger.info("Training token classifier on entire data set (size=" + trainList.size() + ")...");
      m_tokenClassifier = m_trainer.train(trainList);

      Trial t = new Trial(m_tokenClassifier, trainList);
      logger.info("Training set accuracy = " + t.getAccuracy());
     
      if (m_numCV == 0)
        return;

      // train classifiers using cross validation
      InstanceList.CrossValidationIterator cvIter = trainList.new CrossValidationIterator(m_numCV, m_randSeed);
      int f = 1;

      while (cvIter.hasNext()) {
        f++;
        InstanceList[] fold = cvIter.nextSplit();

        logger.info("Training token classifier on cv fold " + f + " / " + m_numCV + " (size=" + fold[0].size() + ")...");
       
        Classifier foldClassifier = m_trainer.train(fold[0]);
        Trial t1 = new Trial(foldClassifier, fold[0]);
        Trial t2 = new Trial(foldClassifier, fold[1]);

        logger.info("Within-fold accuracy = " + t1.getAccuracy());
        logger.info("Out-of-fold accuracy = " + t2.getAccuracy());

        /*for (int x = 0; x < t2.size(); x++) {
          logger.info("xxx pred:" + t2.getClassification(x).getLabeling().getBestLabel() + " true:" + t2.getClassification(x).getInstance().getLabeling());
        }*/
       
        for (int i = 0; i < fold[1].size(); i++) {
          Instance inst = fold[1].get(i);
          m_table.put(inst.getName(), foldClassifier);
        }
      }
    }


    public Classification classify(Instance instance)
    {
      return classify(instance, false);
    }


    /**
     *
     * @param instance the instance to classify
     * @param useOutOfFold whether to check the instance name and use the out-of-fold classifier
     * if the instance name matches one in the training data
     * @return the token classifier's output
     */
    public Classification classify(Instance instance, boolean useOutOfFold)
    {
      Object instName = instance.getName();
     
      if (! useOutOfFold || ! m_table.containsKey(instName))
        return m_tokenClassifier.classify(instance);
     
      Classifier classifier = (Classifier) m_table.get(instName);

      return classifier.classify(instance);
    }

    // serialization
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;
   
    private void writeObject(ObjectOutputStream out) throws IOException
    {
      out.writeInt(CURRENT_SERIAL_VERSION);
      out.writeObject(getInstancePipe());
      out.writeInt(m_numCV);
      out.writeInt(m_randSeed);
      out.writeObject(m_table);
      out.writeObject(m_tokenClassifier);
      out.writeObject(m_trainer);
    }
   
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
      int version = in.readInt();
      if (version != CURRENT_SERIAL_VERSION)
        throw new ClassNotFoundException("Mismatched TokenClassifiers versions: wanted " +
            CURRENT_SERIAL_VERSION + ", got " +
            version);
      instancePipe = (Pipe) in.readObject();
      m_numCV = in.readInt();
      m_randSeed = in.readInt();
      m_table = (HashMap) in.readObject();
      m_tokenClassifier = (Classifier) in.readObject();
      m_trainer = (ClassifierTrainer) in.readObject();
    }
  }
}
TOP

Related Classes of cc.mallet.pipe.AddClassifierTokenPredictions$TokenClassifiers

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.