Package cc.mallet.grmm.learning

Source Code of cc.mallet.grmm.learning.GenericAcrfTui

/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning;
/**
*
* Created: Aug 23, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
* @version $Id: GenericAcrfTui.java,v 1.1 2007/10/22 21:37:43 mccallum Exp $
*/

import bsh.EvalError;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.logging.Logger;
import java.util.regex.Pattern;

import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.optimize.Optimizer;
import cc.mallet.pipe.*;
import cc.mallet.pipe.iterator.LineGroupIterator;
import cc.mallet.pipe.iterator.PipeInputIterator;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Instance;
import cc.mallet.util.*;

public class GenericAcrfTui {

  private static CommandOption.File modelFile = new CommandOption.File
          (GenericAcrfTui.class, "model-file", "FILENAME", true, null, "Text file describing model structure.", null);

  private static CommandOption.File trainFile = new CommandOption.File
          (GenericAcrfTui.class, "training", "FILENAME", true, null, "File containing training data.", null);

  private static CommandOption.File testFile = new CommandOption.File
          (GenericAcrfTui.class, "testing", "FILENAME", true, null, "File containing testing data.", null);

  private static CommandOption.Integer numLabelsOption = new CommandOption.Integer
  (GenericAcrfTui.class, "num-labels", "INT", true, -1,
          "If supplied, number of labels on each line of input file." +
                  "  Otherwise, the token ---- must separate labels from features.", null);

  private static CommandOption.String inferencerOption = new CommandOption.String
          (GenericAcrfTui.class, "inferencer", "STRING", true, "TRP",
                  "Specification of inferencer.", null);

  private static CommandOption.String maxInferencerOption = new CommandOption.String
          (GenericAcrfTui.class, "max-inferencer", "STRING", true, "TRP.createForMaxProduct()",
                  "Specification of inferencer.", null);

  private static CommandOption.String evalOption = new CommandOption.String
          (GenericAcrfTui.class, "eval", "STRING", true, "LOG",
                  "Evaluator to use.  Java code grokking performed.", null);

  private static CommandOption.Boolean usePiecewiseTraining = new CommandOption.Boolean
            (GenericAcrfTui.class, "piecewise", "true|false", true, false,
                    "Whether to use piecewise training.", null);

  private  static CommandOption.Boolean usePwplTraining = new CommandOption.Boolean
            (GenericAcrfTui.class, "pwpl", "true|false", true, false,
                    "Whether to use pwpl training.", null);

  private  static CommandOption.Boolean usePlTraining = new CommandOption.Boolean
            (GenericAcrfTui.class, "pl", "true|false", true, false,
                    "Whether to use Besag pseudolikelihood.", null);

  static CommandOption.Boolean cacheUnrolledGraph = new CommandOption.Boolean
          (GenericAcrfTui.class, "cache-graphs", "true|false", true, false,
                  "Whether to use memory-intensive caching.", null);

  static CommandOption.Boolean useTokenText = new CommandOption.Boolean
          (GenericAcrfTui.class, "use-token-text", "true|false", true, false,
                  "Set this to true if first feature in every list is should be considered the text of the " +
                          "current token.  This is used for NLP-specific debugging and error analysis.", null);

  static CommandOption.Integer randomSeedOption = new CommandOption.Integer
  (GenericAcrfTui.class, "random-seed", "INTEGER", true, 0,
   "The random seed for randomly selecting a proportion of the instance list for training", null);


  private static BshInterpreter interpreter = setupInterpreter ();

    private static ACRFTrainer createTrainer ()
    {
      if (usePiecewiseTraining.value) {
        return new PiecewiseACRFTrainer();
      } else if (usePwplTraining.value) {
        return new PwplACRFTrainer();
      } else if (usePlTraining.value) {
        return new PseudolikelihoodACRFTrainer ();
      } else {
        return new DefaultAcrfTrainer ();
      }
    }

  public static void main (String[] args) throws IOException, EvalError
  {
    doProcessOptions (GenericAcrfTui.class, args);
    Timing timing = new Timing ();

    GenericAcrfData2TokenSequence basePipe;
    if (!numLabelsOption.wasInvoked ()) {
      basePipe = new GenericAcrfData2TokenSequence ();
    } else {
      basePipe = new GenericAcrfData2TokenSequence (numLabelsOption.value);
    }

    basePipe.setFeaturesIncludeToken(useTokenText.value);
    basePipe.setIncludeTokenText(useTokenText.value);
   
    Pipe pipe = new SerialPipes (new Pipe[] {
        basePipe,
        new TokenSequence2FeatureVectorSequence (true, true),
    });

    Iterator<Instance> trainSource = new LineGroupIterator (new FileReader (trainFile.value), Pattern.compile ("^\\s*$"), true);
    Iterator<Instance> testSource;
    if (testFile.wasInvoked ()) {
      testSource = new LineGroupIterator (new FileReader (testFile.value), Pattern.compile ("^\\s*$"), true);
    } else {
      testSource = null;
    }

    InstanceList training = new InstanceList (pipe);
    training.addThruPipe (trainSource);
    InstanceList testing = new InstanceList (pipe);
    testing.addThruPipe (testSource);

    ACRF.Template[] tmpls = parseModelFile (modelFile.value);
    ACRFEvaluator eval = createEvaluator (evalOption.value);

    Inferencer inf = createInferencer (inferencerOption.value);
    Inferencer maxInf = createInferencer (maxInferencerOption.value);

    ACRF acrf = new ACRF (pipe, tmpls);
    acrf.setInferencer (inf);
    acrf.setViterbiInferencer (maxInf);

    ACRFTrainer trainer = createTrainer();
    System.err.println ("ACRF Trainer = "+trainer);
    trainer.train (acrf, training, null, testing, eval, 9999);
    timing.tick ("Training");

    FileUtils.writeGzippedObject (new File ("acrf.ser.gz"), acrf);
    timing.tick ("Serializing");

    System.err.println ("Total time (ms) = " + timing.elapsedTime ());
  }

  private static BshInterpreter setupInterpreter ()
  {
    BshInterpreter interpreter = CommandOption.getInterpreter ();
    try {
      interpreter.eval ("import cc.mallet.base.extract.*");
      interpreter.eval ("import cc.mallet.grmm.inference.*");
      interpreter.eval ("import cc.mallet.grmm.learning.*");
      interpreter.eval ("import cc.mallet.grmm.learning.templates.*");
    } catch (EvalError e) {
      throw new RuntimeException (e);
    }

    return interpreter;
  }

  public static ACRFEvaluator createEvaluator (String spec) throws EvalError
  {
    if (spec.indexOf ('(') >= 0) {
      // assume it's Java code, and don't screw with it.
      return (ACRFEvaluator) interpreter.eval (spec);
    } else {
      LinkedList toks = new LinkedList (Arrays.asList (spec.split ("\\s+")));
      return createEvaluator (toks);
    }
  }

  private static ACRFEvaluator createEvaluator (LinkedList toks)
  {
    String type = (String) toks.removeFirst ();

    if (type.equalsIgnoreCase ("SEGMENT")) {
      int slice = Integer.parseInt ((String) toks.removeFirst ());
      if (toks.size() % 2 != 0)
        throw new RuntimeException ("Error in --eval "+evalOption.value+": Every start tag must have a continue.");
      int numTags = toks.size () / 2;
      String[] startTags = new String [numTags];
      String[] continueTags = new String [numTags];

      for (int i = 0; i < numTags; i++) {
        startTags[i] = (String) toks.removeFirst ();
        continueTags[i] = (String) toks.removeFirst ();
      }

      return new MultiSegmentationEvaluatorACRF (startTags, continueTags, slice);

    } else if (type.equalsIgnoreCase ("LOG")) {
      return new DefaultAcrfTrainer.LogEvaluator ();

    } else if (type.equalsIgnoreCase ("SERIAL")) {
      List evals = new ArrayList ();
      while (!toks.isEmpty ()) {
        evals.add (createEvaluator (toks));
      }
      return new AcrfSerialEvaluator (evals);

    } else {
      throw new RuntimeException ("Error in --eval "+evalOption.value+": illegal evaluator "+type);
    }
  }

  private static Inferencer createInferencer (String spec) throws EvalError
  {
    String cmd;
    if (spec.indexOf ('(') >= 0) {
      // assume it's Java code, and don't screw with it.
      cmd = spec;
    } else {
      cmd = "new "+spec+"()";
    }

    // Return whatever the Java code says to
    Object inf = interpreter.eval (cmd);

    if (inf instanceof Inferencer)
      return (Inferencer) inf;

    else throw new RuntimeException ("Don't know what to do with inferencer "+inf);
  }


  public static void doProcessOptions (Class childClass, String[] args)
  {
    CommandOption.List options = new CommandOption.List ("", new CommandOption[0]);
    options.add (childClass);
    options.process (args);
    options.logOptions (Logger.getLogger (""));
  }

  private static ACRF.Template[] parseModelFile (File mdlFile) throws IOException, EvalError
  {
    BufferedReader in = new BufferedReader (new FileReader (mdlFile));

    List tmpls = new ArrayList ();
    String line = in.readLine ();
    while (line != null) {
      Object tmpl = interpreter.eval (line);
      if (!(tmpl instanceof ACRF.Template)) {
        throw new RuntimeException ("Error in "+mdlFile+" line "+in.toString ()+":\n  Object "+tmpl+" not a template");
      }
      tmpls.add (tmpl);
      line = in.readLine ();
    }

    return (ACRF.Template[]) tmpls.toArray (new ACRF.Template [0]);
  }

}
TOP

Related Classes of cc.mallet.grmm.learning.GenericAcrfTui

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.