Package cc.mallet.grmm.learning

Source Code of cc.mallet.grmm.learning.GenericAcrfData2TokenSequence

/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.learning;


import java.util.ArrayList;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.ObjectInputStream;

import cc.mallet.extract.StringSpan;
import cc.mallet.extract.StringTokenization;
import cc.mallet.grmm.util.LabelsAssignment;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.*;

/**
* Generic pipe that takes a linegroup of the form:
* <pre>
*  LABEL1 LABEL2 ... LABELk word feature1 feature2 ... featuren
* </pre>
* and converts it into an input FeatureVectorSequence and target LabelsSequence.
* <p>
* If the number of labels at each sequence position could vary, then use this format instead:
<pre>
*  LABEL1 LABEL2 ... LABELk ---- word feature1 feature2 ... featuren
</pre>
* The four dashes ---- must be there to separate the features from the labels.
* Whitespace is ignored.
* The difference between this pipe and {@link edu.umass.cs.iesl.casutton.experiments.dcrf.GenericDcrfPipe} is that this pipe
*  allows for a different number of labels at each sequence position.
* <p>
* Explicitly specifying which word is the token allows the use of the HTML output from
*  the extract package.
*
* Created: Aug 22, 2005
*
* @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A>
* @version $Id: GenericAcrfData2TokenSequence.java,v 1.1 2007/10/22 21:37:43 mccallum Exp $
*/
public class GenericAcrfData2TokenSequence extends Pipe {

  private ArrayList labelDicts = new ArrayList ();
  private int numLabels = -1;
  private boolean includeTokenText = true;
  private String textFeaturePrefix = "WORD=";
  private boolean featuresIncludeToken = true;
  private boolean labelsAtEnd = false;

  public GenericAcrfData2TokenSequence ()
  {
    super (new Alphabet(), new LabelAlphabet());
  }

  public GenericAcrfData2TokenSequence (int numLabels)
  {
    super (new Alphabet(), new LabelAlphabet());
    this.numLabels = numLabels;
  }

  public void setIncludeTokenText (boolean includeTokenText)
  {
    this.includeTokenText = includeTokenText;
  }

  /**
   * If true, then the first feature in the list is considered to be the token's text.
   * If false, then no feature is designated as the token text.
   * @param featuresIncludeToken
   */
  public void setFeaturesIncludeToken (boolean featuresIncludeToken)
  {
    this.featuresIncludeToken = featuresIncludeToken;
  }

  public boolean getFeaturesIncludeToken ()
  {
    return featuresIncludeToken;
  }

  public void setTextFeaturePrefix (String textFeaturePrefix)
  {
    this.textFeaturePrefix = textFeaturePrefix;
  }

  public LabelAlphabet getLabelAlphabet (int lvl)
  {
    return (LabelAlphabet) labelDicts.get (lvl);
  }

  public int numLevels ()
  {
    return labelDicts.size();
  }

  public Instance pipe (Instance carrier)
  {
    String input;
    if (carrier.getData () instanceof CharSequence) {
      input = String.valueOf(carrier.getData ());
    } else {
      throw new ClassCastException("Needed a String; got "+carrier.getData());
    }
   
    String[] lines = input.split ("\n");

    StringSpan[] spans = new StringSpan[lines.length];
    Labels[] lbls = new Labels[lines.length];
    StringBuffer buf = new StringBuffer ();

    Alphabet dict = getDataAlphabet ();

    for (int i = 0; i < lines.length; i++) {
      String line = lines[i];
      String[] toks = line.split ("\\s+");

      int j = 0;
      ArrayList thisLabels = new ArrayList ();
      if (!labelsAtEnd) {
        while (!isLabelSeparator (toks, j)) {
          thisLabels.add (labelForTok (toks[j], j));
          j++;
        }
        if ((j < toks.length) && toks[j].equals ("----")) j++;
        lbls[i] = new Labels ((Label[]) thisLabels.toArray (new Label[thisLabels.size ()]));
      }

      int maxFeatureIdx = (labelsAtEnd) ? toks.length - numLabels : toks.length;

      String text = "*???*";
      if (featuresIncludeToken) {
        if (j < maxFeatureIdx) {
          text = toks [j++];
        }
      }

      int start = buf.length ();
      buf.append (text);
      int end = buf.length ();
      buf.append (" ");

      StringSpan span = new StringSpan (buf, start, end);

      while (j < maxFeatureIdx) {
        span.setFeatureValue (toks[j].intern (), 1.0);
        j++;
      }

      if (includeTokenText) {
        span.setFeatureValue ((textFeaturePrefix+text).intern(), 1.0);
      }

      if (labelsAtEnd) {
        int firstLblIdx = j;
        while (j < toks.length) {
          thisLabels.add (labelForTok (toks[j], j - firstLblIdx));
          j++;
        }
        lbls[i] = new Labels ((Label[]) thisLabels.toArray (new Label[thisLabels.size ()]));
      }

      spans[i] = span;

    }

    StringTokenization tokenization = new StringTokenization (buf);
    tokenization.addAll (spans);
    carrier.setData (tokenization);

    carrier.setTarget (new LabelsAssignment (new LabelsSequence (lbls)));
    return carrier;
  }

  private Label labelForTok (String tok, int lvl)
  {
    while (labelDicts.size() <= lvl) {
      labelDicts.add (new LabelAlphabet ());
    }
    LabelAlphabet dict = (LabelAlphabet) labelDicts.get (lvl);
    return dict.lookupLabel (tok);
  }

  private boolean isLabelSeparator (String[] toks, int j)
  {
    if (numLabels > 0) {
      // if fixed numLabels, just return whether we have enough.
      return j >= numLabels;
    } else {
      // otherwise, use the dynamic labels separator
      return toks[j].equals ("----");
    }
  }

  // Serialization garbage

  // version 1.0 == returned a feature vector sequence
  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 2;

  private void writeObject (ObjectOutputStream out) throws IOException
  {
    out.defaultWriteObject ();
    out.writeInt (CURRENT_SERIAL_VERSION);
  }


  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
  {
    in.defaultReadObject ();
    int version = in.readInt ();
    if (version <= 1) {
      featuresIncludeToken = true;
    }
  }

  public boolean isLabelsAtEnd ()
  {
    return labelsAtEnd;
  }

  public void setLabelsAtEnd (boolean labelsAtEnd)
  {
    this.labelsAtEnd = labelsAtEnd;
  }
}
TOP

Related Classes of cc.mallet.grmm.learning.GenericAcrfData2TokenSequence

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.