/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/**
@author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a>
*/
package cc.mallet.pipe.iterator;
import java.util.ArrayList;
import java.util.Iterator;
import java.net.URI;
import java.io.*;
import cc.mallet.fst.*;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.*;
/**
Iterates over {@link Segment}s extracted by a {@link Transducer}
for some {@link InstanceList}.
*/
public class SegmentIterator implements Iterator<Instance>
{
Iterator subIterator;
ArrayList segments;
/**
NOTE!: Assumes that <code>segmentStartTags[i]</code> corresponds
to <code>segmentContinueTags[i]</code>.
@param model model to segment input sequences
@param ilist list of instances to be segmented
@param segmentStartTags array of tags indicating the start of a segment
@param segmentContinueTags array of tags indicating the continuation of a segment
*/
public SegmentIterator (Transducer model, InstanceList ilist, Object[] segmentStartTags,
Object[] segmentContinueTags) {
setSubIterator (model, ilist, segmentStartTags, segmentContinueTags);
}
/**
Iterates over {@link Segment}s for only one {@link Instance}.
*/
public SegmentIterator (Transducer model, Instance instance, Object[] segmentStartTags,
Object[] segmentContinueTags) {
InstanceList ilist = new InstanceList (new Noop (instance.getDataAlphabet(), instance.getTargetAlphabet()));
ilist.add (instance);
setSubIterator (model, ilist, segmentStartTags, segmentContinueTags);
}
/**
Useful when no {@link Transduce} is specified. A list of
sequences specifies the output.
@param ilist InstanceList containing sequence.
@param segmentStartTags array of tags indicating the start of a segment
@param segmentContinueTags array of tags indicating the continuation of a segment
@param predictions list of {@link Sequence}s that are the
predicted output of some {@link Transducer}
*/
public SegmentIterator (InstanceList ilist, Object[] startTags, Object[] inTags,
ArrayList predictions) {
setSubIterator (ilist, startTags, inTags, predictions);
}
/**
Iterate over segments in one instance.
@param ilist InstanceList containing sequence.
@param segmentStartTags array of tags indicating the start of a segment
@param segmentContinueTags array of tags indicating the continuation of a segment
@param predictions list of {@link Sequence}s that are the
predicted output of some {@link Transducer}
*/
public SegmentIterator (Instance instance, Object[] startTags, Object[] inTags, Sequence prediction) {
InstanceList ilist = new InstanceList (new Noop (instance.getDataAlphabet(), instance.getTargetAlphabet()));
ilist.add (instance);
ArrayList predictions = new ArrayList();
predictions.add (prediction);
setSubIterator (ilist, startTags, inTags, predictions);
}
/**
Iterate over segments in one labeled sequence
*/
public SegmentIterator (Sequence input, Sequence predicted, Sequence truth,
Object[] startTags, Object[] inTags) {
segments = new ArrayList ();
if (input.size() != truth.size () || predicted.size () != truth.size ())
throw new IllegalStateException ("sequence lengths not equal. input: " + input.size ()
+ " true: " + truth.size ()
+ " predicted: " + predicted.size ());
// find predicted segments
for (int n=0; n < predicted.size (); n++) {
for (int s=0; s < startTags.length; s++) {
if (startTags[s].equals (predicted.get (n))) { // found start tag
int j=n+1;
while (j < predicted.size() && inTags[s].equals (predicted.get (j))) // find end tag
j++;
segments.add (new Segment (input, predicted, truth, n, j-1,
startTags[s], inTags[s]));
}
}
}
this.subIterator = segments.iterator();
}
private void setSubIterator (InstanceList ilist, Object[] startTags, Object[] inTags,
ArrayList predictions) {
segments = new ArrayList (); // stores predicted <code>Segment</code>s
Iterator iter = ilist.iterator ();
for (int i=0; i < ilist.size(); i++) {
Instance instance = (Instance) ilist.get (i);
Sequence input = (Sequence) instance.getData ();
Sequence trueOutput = (Sequence) instance.getTarget ();
Sequence predOutput = (Sequence) predictions.get (i);
if (input.size() != trueOutput.size () || predOutput.size () != trueOutput.size ())
throw new IllegalStateException ("sequence lengths not equal. input: " + input.size ()
+ " true: " + trueOutput.size ()
+ " predicted: " + predOutput.size ());
// find predicted segments
for (int n=0; n < predOutput.size (); n++) {
for (int s=0; s < startTags.length; s++) {
if (startTags[s].equals (predOutput.get (n))) { // found start tag
int j=n+1;
while (j < predOutput.size() &&
inTags[s].equals (predOutput.get (j))) // find end tag
j++;
segments.add (new Segment (input, predOutput, trueOutput, n, j-1,
startTags[s], inTags[s]));
}
}
}
}
this.subIterator = segments.iterator ();
}
private void setSubIterator (Transducer model, InstanceList ilist, Object[] segmentStartTags, Object[] segmentContinueTags) {
segments = new ArrayList (); // stores predicted <code>Segment</code>s
Iterator iter = ilist.iterator ();
while (iter.hasNext ()) {
Instance instance = (Instance) iter.next ();
Sequence input = (Sequence) instance.getData ();
Sequence trueOutput = (Sequence) instance.getTarget ();
Sequence predOutput = new MaxLatticeDefault (model, input).bestOutputSequence();
if (input.size() != trueOutput.size () || predOutput.size () != trueOutput.size ())
throw new IllegalStateException ("sequence lengths not equal. input: " + input.size ()
+ " true: " + trueOutput.size ()
+ " predicted: " + predOutput.size ());
// find predicted segments
for (int i=0; i < predOutput.size (); i++) {
for (int s=0; s < segmentStartTags.length; s++) {
if (segmentStartTags[s].equals (predOutput.get (i))) { // found start tag
int j=i+1;
while (j < predOutput.size() &&
segmentContinueTags[s].equals (predOutput.get (j))) // find end tag
j++;
segments.add (new Segment (input, predOutput, trueOutput, i, j-1,
segmentStartTags[s], segmentContinueTags[s]));
}
}
}
}
this.subIterator = segments.iterator ();
}
// The PipeInputIterator interface
public Instance next ()
{
Segment nextSegment = (Segment) subIterator.next();
return new Instance (nextSegment, nextSegment.getTruth (), null, null);
}
public Segment nextSegment () {
return (Segment) subIterator.next ();
}
public boolean hasNext () { return subIterator.hasNext(); }
public ArrayList toArrayList () { return this.segments; }
public void remove () {
throw new IllegalStateException ("This Iterator<Instance> does not support remove().");
}
}