Package cc.mallet.types

Examples of cc.mallet.types.InstanceList


      // Fragment (scraped excerpt): tail of a SerialPipes construction —
      // SGML markup becomes a token sequence, targets become labels, and each
      // instance is echoed for debugging.
      new SGML2TokenSequence (new CharSequenceLexer (CharSequenceLexer.LEX_NONWHITESPACE_CLASSES  ), "O"),
      new Target2LabelSequence (),
      new PrintInputAndTarget (),
    });

    // Run predicted and true strings through the same pipe so both
    // InstanceLists share alphabets and stay index-aligned.
    InstanceList pred = new InstanceList (pipe);
    pred.addThruPipe (new ArrayIterator (predStrings));

    InstanceList targets = new InstanceList (pipe);
    targets.addThruPipe (new ArrayIterator (trueStrings));

    // Collect one DocumentExtraction per document pair into an Extraction.
    LabelAlphabet dict = (LabelAlphabet) pipe.getTargetAlphabet ();
    Extraction extraction = new Extraction (null, dict);

    for (int i = 0; i < pred.size(); i++) {
      Instance aPred = pred.get (i);
      Instance aTarget = targets.get (i);
      Tokenization input = (Tokenization) aPred.getData ();
      Sequence predSeq = (Sequence) aPred.getTarget ();
      Sequence targetSeq = (Sequence) aTarget.getTarget ();
      // "O" presumably marks the background/outside tag — TODO confirm against
      // DocumentExtraction's javadoc.
      DocumentExtraction docextr = new DocumentExtraction ("TEST"+i, dict, input, predSeq, targetSeq, "O");
      extraction.addDocumentExtraction (docextr);
View Full Code Here


    // Fragment: redirect System.out into a buffer so the pipe's printed
    // debug output can be compared against the expected string.
    // Print to a string
    ByteArrayOutputStream out = new ByteArrayOutputStream ();
    PrintStream oldOut = System.out;
    System.setOut (new PrintStream (out));

    // Piping two instances triggers the pipe's printing side effects.
    InstanceList lst = new InstanceList (p);
    lst.addThruPipe (new ArrayIterator (new String[] { TestCRF.data[0],
                                               TestCRF.data[1], }));

    // Restore the real stdout before asserting.
    System.setOut (oldOut);
   
    assertEquals (spacePipeOutput, out.toString());
View Full Code Here

  {
    // Fragment: body of a test (header not visible in this excerpt) that
    // trains a CRF on one instance and holds out another for testing.
    Pipe pipe = TestMEMM.makeSpacePredictionPipe ();
    String[] data0 = { TestCRF.data[0] };
    String[] data1 = { TestCRF.data[1] };

    // Single-instance training and testing sets, piped through the same pipe
    // so both share data/target alphabets.
    InstanceList training = new InstanceList (pipe);
    training.addThruPipe (new ArrayIterator (data0));
    InstanceList testing = new InstanceList (pipe);
    testing.addThruPipe (new ArrayIterator (data1));

    // Fully-connected label topology, trained incrementally by label likelihood.
    CRF crf = new CRF (pipe, null);
    crf.addFullyConnectedStatesForLabels ();
    CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood (crf);
    crft.trainIncremental (training);
View Full Code Here

    // Fragment: self-training (EM-style) loop — repeatedly fills in labels on
    // unlabeled instances using the current classifier's predictions, then
    // retrains. The loop's convergence test lies outside this excerpt.
    boolean converged = false;

    int iteration = 0;
    while (!converged) {
      // Make a new trainingSet that has some labels set
      InstanceList trainingSet2 = new InstanceList (trainingSet.getPipe());
      for (int ii = 0; ii < trainingSet.size(); ii++) {
        Instance inst = trainingSet.get(ii);
        if (inst.getLabeling() != null)
          trainingSet2.add(inst, 1.0);  // already labeled: keep at full weight
        else {
          // Unlabeled: copy the instance, attach the classifier's predicted
          // labeling, and down-weight it by unlabeledDataWeight.
          Instance inst2 = inst.shallowCopy();
          inst2.unLock();
          inst2.setLabeling(c.classify(inst).getLabeling());
          inst2.lock();
          trainingSet2.add(inst2, unlabeledDataWeight);
        }
      }
      // Retrain on the augmented set and track its data log-likelihood.
      c = (NaiveBayes) nbTrainer.newClassifierTrainer().train (trainingSet2);
      logLikelihood = c.dataLogLikelihood (trainingSet2);
      System.err.println ("Loglikelihood = "+logLikelihood);
View Full Code Here

    /** Returns the gain-ratio statistics held by this node. */
    public GainRatio getGainRatio() { return m_gainRatio; }
    /** Returns the feature at the gain ratio's maximum-valued index, looked up in the data dictionary. */
    public Object getSplitFeature() { return m_dataDict.lookupObject(m_gainRatio.getMaxValuedIndex()); }
   
    public InstanceList getInstances()
    {
      InstanceList ret = new InstanceList(m_ilist.getPipe());
      for (int ii = 0; ii < m_instIndices.length; ii++)
        ret.add(m_ilist.get(m_instIndices[ii]));
      return ret;
    }
View Full Code Here

  }
   
  public void testEvaluators ()
  {
    // Fragment: truncated in this excerpt — builds a small random instance
    // set, clusters it, and exercises neighbor iterators.
    Randoms random = new Randoms(1);
    // Presumably 100 random instances over 2 classes, keeping only the first
    // 10 — TODO confirm against the InstanceList(Randoms,int,int) constructor.
    InstanceList instances = new InstanceList(random, 100, 2).subList(0,10);
    System.err.println(instances.size() + " instances");
    Clustering clustering = generateClustering(instances);
    System.err.println("clustering=" + clustering);

    System.err.println("ClusterSampleIterator");
    NeighborIterator iter = new ClusterSampleIterator(clustering,
View Full Code Here

                    new int[] { 0, 1 },
                    false),
            new PrintInputAndTarget (),
    });

    // Fragment: run the same document through two differently-configured
    // pipes (mtPipe vs. noMtPipe) and compare the token sequences produced.
    InstanceList mtLst = new InstanceList (mtPipe);
    InstanceList noMtLst = new InstanceList (noMtPipe);

    mtLst.addThruPipe (new ArrayIterator (doc1));
    noMtLst.addThruPipe (new ArrayIterator (doc1));

    Instance mtInst = mtLst.get (0);
    Instance noMtInst = noMtLst.get (0);

    TokenSequence mtTs = (TokenSequence) mtInst.getData ();
    TokenSequence noMtTs = (TokenSequence) noMtInst.getData ();

    // The mtPipe tokenization of doc1 is expected to yield 6 tokens.
    assertEquals (6, mtTs.size ());
View Full Code Here

                    true),
            new PrintInputAndTarget (),
    });

    // Fragment: verify the pipe survives a serialization round-trip — the
    // deserialized clone must process doc1 identically to the original.
    Pipe mtPipe = (Pipe) TestSerializable.cloneViaSerialization (origPipe);
    InstanceList mtLst = new InstanceList (mtPipe);
    mtLst.addThruPipe (new ArrayIterator (doc1));
    Instance mtInst = mtLst.get (0);
    TokenSequence mtTs = (TokenSequence) mtInst.getData ();
    // Expect 6 tokens, with a "time" feature of 1.0 on tokens 3 and 4.
    assertEquals (6, mtTs.size ());
    assertEquals (1.0, mtTs.get (3).getFeatureValue ("time"), 1e-15);
    assertEquals (1.0, mtTs.get (4).getFeatureValue ("time"), 1e-15);
  }
View Full Code Here

    // Fragment: post-processing of a list of Clusterings (the enclosing
    // method's start and end are not visible in this excerpt).

    // Prune clusters based on size.
    if (minClusterSize.value > 1) {
      for (int i = 0; i < clusterings.size(); i++) {
        Clustering clustering = clusterings.get(i);
        InstanceList oldInstances = clustering.getInstances();
        Alphabet alph = oldInstances.getDataAlphabet();
        LabelAlphabet lalph = (LabelAlphabet) oldInstances.getTargetAlphabet();
        // Fall back to fresh alphabets when the originals are absent.
        if (alph == null) alph = new Alphabet();
        if (lalph == null) lalph = new LabelAlphabet();
        Pipe noop = new Noop(alph, lalph);
        InstanceList newInstances = new InstanceList(noop);
        for (int j = 0; j < oldInstances.size(); j++) {
          int label = clustering.getLabel(j);
          Instance instance = oldInstances.get(j);
          // Keep only instances whose cluster meets the minimum size, with the
          // cluster id as the new target label.
          // NOTE(review): new Integer(...) is deprecated; Integer.valueOf(label)
          // would avoid the allocation.
          if (clustering.size(label) >= minClusterSize.value)
            newInstances.add(noop.pipe(new Instance(instance.getData(), lalph.lookupLabel(new Integer(label)), instance.getName(), instance.getSource())));
        }
        clusterings.set(i, createSmallerClustering(newInstances));
      }
      // Optionally serialize the pruned clusterings to disk.
      if (outputPrefixFile.value != null) {
        try {
          ObjectOutputStream oos =
            new ObjectOutputStream(new FileOutputStream(outputPrefixFile.value));
          oos.writeObject(clusterings);
          oos.close();
        } catch (Exception e) {
          logger.warning("Exception writing clustering to file " + outputPrefixFile.value                        + " " + e);
          e.printStackTrace();
        }
      }
    }
   
   
    // Split into training/testing
    if (trainingProportion.value > 0) {
      if (clusterings.size() > 1)
        throw new IllegalArgumentException("Expect one clustering to do train/test split, not " + clusterings.size());
      Clustering clustering = clusterings.get(0);
      int targetTrainSize = (int)(trainingProportion.value * clustering.getNumInstances());
      TIntHashSet clustersSampled = new TIntHashSet();
      Randoms random = new Randoms(123);
      LabelAlphabet lalph = new LabelAlphabet();
      InstanceList trainingInstances = new InstanceList(new Noop(null, lalph));
      // Sample whole clusters at random until the training set is big enough;
      // a cluster is never split between train and test.
      while (trainingInstances.size() < targetTrainSize) {
        int cluster = random.nextInt(clustering.getNumClusters());
        if (!clustersSampled.contains(cluster)) {
          clustersSampled.add(cluster);
          InstanceList instances = clustering.getCluster(cluster);
          for (int i = 0; i < instances.size(); i++) {
            Instance inst = instances.get(i);
            trainingInstances.add(new Instance(inst.getData(), lalph.lookupLabel(new Integer(cluster)), inst.getName(), inst.getSource()));
          }
        }
      }
      trainingInstances.shuffle(random);
      Clustering trainingClustering = createSmallerClustering(trainingInstances);
     
      // Everything not sampled for training becomes the test set.
      // NOTE(review): constructed with a null pipe (unlike the Noop used for
      // training above) — verify downstream code tolerates this.
      InstanceList testingInstances = new InstanceList(null, lalph);
      for (int i = 0; i < clustering.getNumClusters(); i++) {
        if (!clustersSampled.contains(i)) {
          InstanceList instances = clustering.getCluster(i);
          for (int j = 0; j < instances.size(); j++) {
            Instance inst = instances.get(j);
            testingInstances.add(new Instance(inst.getData(), lalph.lookupLabel(new Integer(i)), inst.getName(), inst.getSource()));
          }
        }
      }
      testingInstances.shuffle(random);
View Full Code Here

  CommandOption.Double(Vectors2FeatureConstraints.class, "majority-prob", "DOUBLE",
      false, 0.9, "Probability for majority labels when using heuristic target estimation.", null);

  public static void main(String[] args) {
    CommandOption.process(Vectors2FeatureConstraints.class, args);
    InstanceList list = InstanceList.load(vectorsFile.value)
   
    // Here we will assume that we use all labeled data available. 
    ArrayList<Integer> features = null;
    HashMap<Integer,ArrayList<Integer>> featuresAndLabels = null;

    // if a features file was specified, then load features from the file
    if (featuresFile.wasInvoked()) {
      if (fileContainsLabels(featuresFile.value)) {
        // better error message from dfrankow@gmail.com
        if (targets.value.equals("oracle")) {
          throw new RuntimeException("with --targets oracle, features file must be unlabeled");
        }
        featuresAndLabels = readFeaturesAndLabelsFromFile(featuresFile.value, list.getDataAlphabet(), list.getTargetAlphabet());
      }
      else {
        features = readFeaturesFromFile(featuresFile.value, list.getDataAlphabet());       
      }
    }
   
    // otherwise select features using specified method
    else {
      if (featureSelection.value.equals("infogain")) {
        features = FeatureConstraintUtil.selectFeaturesByInfoGain(list,numConstraints.value);
      }
      else if (featureSelection.value.equals("lda")) {
        try {
          ObjectInputStream ois = new ObjectInputStream(new FileInputStream(ldaFile.value));
          ParallelTopicModel lda = (ParallelTopicModel)ois.readObject();
          features = FeatureConstraintUtil.selectTopLDAFeatures(numConstraints.value, lda, list.getDataAlphabet());
        }
        catch (Exception e) {
          e.printStackTrace();
        }
      }
      else {
        throw new RuntimeException("Unsupported value for feature selection: " + featureSelection.value);
      }
    }
   
    // If the target method is oracle, then we do not need feature "labels".
    HashMap<Integer,double[]> constraints = null;
   
    if (targets.value.equals("none")) {
      constraints = new HashMap<Integer,double[]>();
      for (int fi : features) {    
        constraints.put(fi, null);
      }
    }
    else if (targets.value.equals("oracle")) {
      constraints = FeatureConstraintUtil.setTargetsUsingData(list, features);
    }
    else {
      // For other methods, we need to get feature labels, as
      // long as they haven't been already loaded from disk.
      if (featuresAndLabels == null) {
        featuresAndLabels = FeatureConstraintUtil.labelFeatures(list,features);
       
        for (int fi : featuresAndLabels.keySet()) {
          logger.info(list.getDataAlphabet().lookupObject(fi) + ":  ");
          for (int li : featuresAndLabels.get(fi)) {
            logger.info(list.getTargetAlphabet().lookupObject(li) + " ");
          }
        }
       
      }
      if (targets.value.equals("heuristic")) {
        constraints = FeatureConstraintUtil.setTargetsUsingHeuristic(featuresAndLabels,list.getTargetAlphabet().size(),majorityProb.value);
      }
      else if (targets.value.equals("voted")) {
        constraints = FeatureConstraintUtil.setTargetsUsingFeatureVoting(featuresAndLabels,list);
      }
      else {
        throw new RuntimeException("Unsupported value for targets: " + targets.value);
      }
    }
    writeConstraints(constraints,constraintsFile.value,list.getDataAlphabet(),list.getTargetAlphabet())
  }
View Full Code Here

TOP

Related Classes of cc.mallet.types.InstanceList

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware@gmail.com.