Package cc.mallet.types

Examples of cc.mallet.types.InstanceList

The excerpts below are drawn from MALLET and MALLET-based code. Together they illustrate the common InstanceList idioms: building lists through pipes, splitting and cross-validating, loading and saving, and feeding classifiers, CRFs, and topic models.


    public void split (FeatureSelection fs)
    {
      if (ilist == null)
        throw new IllegalStateException ("Frozen.  Cannot split.");
      InstanceList ilist0 = new InstanceList (ilist.getPipe());
      InstanceList ilist1 = new InstanceList (ilist.getPipe());
      for (int i = 0; i < ilist.size(); i++) {
        Instance instance = ilist.get(i);
        FeatureVector fv = (FeatureVector) instance.getData ();
        // xxx What test should this be?  What to do with negative values?
        // Whatever is decided here should also go in InfoGain.calcInfoGains()
        if (fv.value (featureIndex) != 0) {
          //System.out.println ("list1 add "+instance.getUri()+" weight="+ilist.getInstanceWeight(i));
          ilist1.add (instance, ilist.getInstanceWeight(i));
        } else {
          //System.out.println ("list0 add "+instance.getUri()+" weight="+ilist.getInstanceWeight(i));
          ilist0.add (instance, ilist.getInstanceWeight(i));
        }
      }
      logger.info("child0="+ilist0.size()+" child1="+ilist1.size());
      child0 = new Node (ilist0, this, fs);
      child1 = new Node (ilist1, this, fs);
    }
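The split above partitions a node's instances by the presence of a single feature, as a decision-tree trainer would. For an ordinary random train/test split, InstanceList offers split(Random, double[]), the same method the SimpleTagger example at the end of this page uses. A minimal sketch, assuming a pre-built list named instances (the seed and 80/20 proportions are illustrative):

    java.util.Random r = new java.util.Random(42);
    InstanceList[] lists = instances.split(r, new double[] {0.8, 0.2});
    InstanceList train = lists[0];   // 80% of the instances
    InstanceList test  = lists[1];   // remaining 20%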


    for (MaxEntPRConstraint constraint : constraints) {
      BitSet bitset = constraint.preProcess(data);
      instancesWithConstraints.or(bitset);
    }
   
    InstanceList unlabeled = data.cloneEmpty();
    for (int ii = 0; ii < data.size(); ii++) {
      if (instancesWithConstraints.get(ii)) {
        boolean noLabel = data.get(ii).getTarget() == null;
        if (noLabel) {
          data.get(ii).unLock();
          data.get(ii).setTarget(new NullLabel((LabelAlphabet)data.getTargetAlphabet()));
        }
        unlabeled.add(data.get(ii));
      }
    }

    int numFeatures = unlabeled.getDataAlphabet().size();
   
    // set up the model
    int numParameters = (numFeatures + 1) * unlabeled.getTargetAlphabet().size();
    if (p == null) {
      p = new MaxEnt(unlabeled.getPipe(), new double[numParameters]);
    }

    // set up the auxiliary model
    q = new PRAuxClassifier(unlabeled.getPipe(), constraints);
   
    double oldValue = -Double.MAX_VALUE;
    for (numIterations = 0; numIterations < maxIterations; numIterations++) {

      double[][] base = optimizeQ(unlabeled,p,numIterations==0);
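The cloneEmpty() call above is the key InstanceList idiom in this excerpt: it creates an empty list that shares the source list's pipe and alphabets, so instances can be copied across without being re-piped. A minimal sketch of the same filtering pattern, assuming a populated list named data (the predicate is illustrative):

    InstanceList unlabeledOnly = data.cloneEmpty();
    for (Instance inst : data) {        // InstanceList is a List<Instance>
      if (inst.getTarget() == null) {   // keep only instances without a label
        unlabeledOnly.add(inst);
      }
    }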

    return p;
  }
 
  private double optimizePAndComputeValue(InstanceList data, PRAuxClassifier q, double[][] base, double pGPV) {
   
    InstanceList dataLabeled = data.cloneEmpty();
   
    double entropy = 0;
   
    int numLabels = data.getTargetAlphabet().size();
    for (int ii = 0; ii < data.size(); ii++) {
      double[] scores = new double[numLabels];
      q.getClassificationScores(data.get(ii), scores);
      for (int li = 0; li < numLabels; li++) {
        if (base != null && base[ii][li] == 0) {
          scores[li] = Double.NEGATIVE_INFINITY;
        }
        else if (base != null) {
          double logP = Math.log(base[ii][li]);
          scores[li] += logP; 
        }
      }
      MatrixOps.expNormalize(scores);
  
      entropy += Maths.getEntropy(scores);

      LabelVector lv = new LabelVector((LabelAlphabet)data.getTargetAlphabet(), scores);
      Instance instance = new Instance(data.get(ii).getData(),lv,null,null);
      dataLabeled.add(instance);
    }
   
    // train supervised
    MaxEntOptimizableByLabelDistribution opt = new MaxEntOptimizableByLabelDistribution(dataLabeled, p);
    opt.setGaussianPriorVariance(pGPV);
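The loop above re-labels each instance with a full posterior distribution rather than a single class: a LabelVector over the target alphabet acts as a soft target that MaxEntOptimizableByLabelDistribution can train against. The re-labeling step in isolation, as a sketch (the two-class probabilities are illustrative):

    LabelAlphabet labels = (LabelAlphabet) data.getTargetAlphabet();
    double[] probs = new double[] {0.7, 0.3};            // assumed class distribution
    LabelVector soft = new LabelVector(labels, probs);   // soft label over the alphabet
    Instance relabeled = new Instance(data.get(0).getData(), soft, null, null);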

    private Pair<Double, Sequence<?>> applyCRF(String testingdata) {
        Sequence<?> input = null;
        Sequence<?> output = null;
        Double conf;

        InstanceList testSequence = null;
        crf_pipe.setTargetProcessing(true);
        testSequence = new InstanceList(crf_pipe);
        testSequence.addThruPipe(new LineGroupIterator(new StringReader(
                testingdata), Pattern.compile("^\\s*$"), true));

        if (testSequence.size() < 1) {
            return new Pair<Double, Sequence<?>>(-1.0, null);
        }

        Instance inst = testSequence.get(0);
        input = (Sequence<?>) inst.getData();

        output = crf.transduce(input);
        conf = crf_estimator.estimateConfidenceFor(inst, startTags, inTags);

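The method above assumes crf, crf_pipe, and crf_estimator are initialized elsewhere in the class. A minimal sketch of the setup it implies, loading a serialized model (the file name is illustrative):

        ObjectInputStream ois = new ObjectInputStream(new FileInputStream("tagger.model"));
        CRF crf = (CRF) ois.readObject();
        ois.close();
        Pipe crfPipe = crf.getInputPipe();   // reuse the pipe the CRF was trained with

        InstanceList testSequence = new InstanceList(crfPipe);
        testSequence.addThruPipe(new LineGroupIterator(new StringReader(testingdata),
                Pattern.compile("^\\s*$"), true));
        Sequence<?> output = crf.transduce((Sequence<?>) testSequence.get(0).getData());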

    CommandOption.process (TopicTrain.class, args);

    LDAStream lda = null;

    if (inputFile.value != null) {
      InstanceList instances = InstanceList.load (new File(inputFile.value));
      System.out.println ("Training Data loaded.");
      lda=new LDAStream(numTopics.value, alpha.value, beta.value);
      lda.addInstances(instances);
    }
    if (testFile.value != null) {
      // note: lda is only initialized above when inputFile is given
      InstanceList testing = InstanceList.load(new File(testFile.value));
      lda.setTestingInstances(testing);
    }

    lda.setTopicDisplay(showTopicsInterval.value, topWords.value);
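InstanceList.load(File), used above, deserializes a list written either by InstanceList.save(File) or by the MALLET import tools. A minimal round-trip sketch (the paths are illustrative):

    InstanceList instances = InstanceList.load(new File("training.mallet"));
    System.out.println("loaded " + instances.size() + " instances");
    instances.save(new File("training-copy.mallet"));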

   public void runCrossValidation(int nfolds, int iterations) {

     InstanceList.CrossValidationIterator iter = instances.crossValidationIterator(nfolds);
     double[] prepResult = new double[nfolds];
     InstanceList[] splitLists;
     InstanceList trainList, testList;

     int fold = 0;
     while (iter.hasNext()) {
       splitLists = iter.next();
       trainList = splitLists[0];
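CrossValidationIterator.next() returns a two-element array, with the training fold at index 0 and the held-out fold at index 1. A sketch of the complete loop (the evaluation step is only indicated):

     while (iter.hasNext()) {
       InstanceList[] folds = iter.next();
       InstanceList trainFold = folds[0];
       InstanceList testFold = folds[1];
       // ... train on trainFold, evaluate on testFold, record prepResult[fold] ...
       fold++;
     }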

    //pipeList.add(new TokenSequenceNGrams(new int[] {2} ));
       
    // convert token sequences to feature sequences
    pipeList.add( new TokenSequence2FeatureSequence() );

    InstanceList instances = new InstanceList (new SerialPipes(pipeList));
    InstanceList testInstances = new InstanceList (instances.getPipe());
       
    Reader insfileReader = new InputStreamReader(new FileInputStream(new File(inputFileName)), "UTF-8");
    Reader testfileReader = new InputStreamReader(new FileInputStream(new File(testFileName)), "UTF-8");
       
    instances.addThruPipe(new CsvIterator (insfileReader, Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
                             3, 2, 1)); // data, label, name fields
    testInstances.addThruPipe(new CsvIterator (testfileReader, Pattern.compile("^(\\S*)[\\s,]*(\\S*)[\\s,]*(.*)$"),
           3, 2, 1)); // data, label, name fields
   
    // set up HDP parameters (alpha, beta, gamma, initialTopics)
    HDP hdp = new HDP(1.0, 0.1, 1.0, 10);
    hdp.initialize(instances);
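The excerpt above begins mid-pipeline; a typical pipe list feeding the SerialPipes constructor looks like the following sketch (the usual MALLET import chain; the tokenizer pattern is illustrative and should match your data):

    ArrayList<Pipe> pipeList = new ArrayList<Pipe>();
    pipeList.add(new CharSequence2TokenSequence(Pattern.compile("\\p{L}+")));  // tokenize
    pipeList.add(new TokenSequenceLowercase());                                // normalize case
    pipeList.add(new TokenSequence2FeatureSequence());                         // map tokens to feature indices
    InstanceList instances = new InstanceList(new SerialPipes(pipeList));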

    CommandOption.setSummary (Vectors2Topics.class,
                  "A tool for estimating, saving and printing diagnostics for topic models, such as LDA.");
    CommandOption.process (Vectors2Topics.class, args);

    if (usePAM.value) {
      InstanceList ilist = InstanceList.load (new File(inputFile.value));
      System.out.println ("Data loaded.");
      if (inputModelFilename.value != null)
        throw new IllegalArgumentException ("--input-model not supported with --use-pam.");
      PAM4L pam = new PAM4L(pamNumSupertopics.value, pamNumSubtopics.value);
      pam.estimate (ilist, numIterations.value, /*optimizeModelInterval*/50,
              showTopicsInterval.value,
              outputModelInterval.value, outputModelFilename.value,
              randomSeed.value == 0 ? new Randoms() : new Randoms(randomSeed.value));
      pam.printTopWords(topWords.value, true);
      if (stateFile.value != null)
        pam.printState (new File(stateFile.value));
      if (docTopicsFile.value != null) {
        PrintWriter out = new PrintWriter (new FileWriter ((new File(docTopicsFile.value))));
        pam.printDocumentTopics (out, docTopicsThreshold.value, docTopicsMax.value);
        out.close();
      }

     
      if (outputModelFilename.value != null) {
        assert (pam != null);
        try {
          ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream (outputModelFilename.value));
          oos.writeObject (pam);
          oos.close();
        } catch (Exception e) {
          e.printStackTrace();
          throw new IllegalArgumentException ("Couldn't write topic model to filename "+outputModelFilename.value);
        }
      }
     

    }
   
    else if (useNgrams.value) {
      InstanceList ilist = InstanceList.load (new File(inputFile.value));
      System.out.println ("Data loaded.");
      if (inputModelFilename.value != null)
        throw new IllegalArgumentException ("--input-model not supported with --use-ngrams.");
      TopicalNGrams tng = new TopicalNGrams(numTopics.value,
                          alpha.value,
                          beta.value,
                          gamma.value,
                          delta.value,
                          delta1.value,
                          delta2.value);
      tng.estimate (ilist, numIterations.value, showTopicsInterval.value,
              outputModelInterval.value, outputModelFilename.value,
              randomSeed.value == 0 ? new Randoms() : new Randoms(randomSeed.value));
      tng.printTopWords(topWords.value, true);
      if (stateFile.value != null)
        tng.printState (new File(stateFile.value));
      if (docTopicsFile.value != null) {
        PrintWriter out = new PrintWriter (new FileWriter ((new File(docTopicsFile.value))));
        tng.printDocumentTopics (out, docTopicsThreshold.value, docTopicsMax.value);
        out.close();
      }

      if (outputModelFilename.value != null) {
        assert (tng != null);
        try {
          ObjectOutputStream oos = new ObjectOutputStream (new FileOutputStream (outputModelFilename.value));
          oos.writeObject (tng);
          oos.close();
        } catch (Exception e) {
          e.printStackTrace();
          throw new IllegalArgumentException ("Couldn't write topic model to filename "+outputModelFilename.value);
        }
      }
     
    }
    else if (languageInputFiles.value != null) {
      // Start a new polylingual topic model
     
      PolylingualTopicModel topicModel = null;

      InstanceList[] training = new InstanceList[ languageInputFiles.value.length ];
      for (int i=0; i < training.length; i++) {
        training[i] = InstanceList.load(new File(languageInputFiles.value[i]));
        if (training[i] != null) { System.out.println(i + " is not null"); }
        else { System.out.println(i + " is null"); }
      }

      System.out.println ("Data loaded.");
     
      // For historical reasons we currently only support FeatureSequence data,
      //  not the FeatureVector, which is the default for the input functions.
      //  Provide a warning to avoid ClassCastExceptions.
      if (training[0].size() > 0 &&
        training[0].get(0) != null) {
        Object data = training[0].get(0).getData();
        if (! (data instanceof FeatureSequence)) {
          System.err.println("Topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.");
          System.exit(1);
        }
      }
     
      topicModel = new PolylingualTopicModel (numTopics.value, alpha.value);
      if (randomSeed.value != 0) {
        topicModel.setRandomSeed(randomSeed.value);
      }
     
      topicModel.addInstances(training);

      topicModel.setTopicDisplay(showTopicsInterval.value, topWords.value);

      topicModel.setNumIterations(numIterations.value);
      topicModel.setOptimizeInterval(optimizeInterval.value);
      topicModel.setBurninPeriod(optimizeBurnIn.value);

      if (outputStateInterval.value != 0) {
        topicModel.setSaveState(outputStateInterval.value, stateFile.value);
      }

      if (outputModelInterval.value != 0) {
        topicModel.setModelOutput(outputModelInterval.value, outputModelFilename.value);
      }

      topicModel.estimate();

      if (topicKeysFile.value != null) {
        topicModel.printTopWords(new File(topicKeysFile.value), topWords.value, false);
      }

      if (stateFile.value != null) {
        topicModel.printState (new File(stateFile.value));
      }

      if (docTopicsFile.value != null) {
        PrintWriter out = new PrintWriter (new FileWriter ((new File(docTopicsFile.value))));
        topicModel.printDocumentTopics(out, docTopicsThreshold.value, docTopicsMax.value);
        out.close();
      }

      if (outputModelFilename.value != null) {
        assert (topicModel != null);
        try {

          ObjectOutputStream oos =
            new ObjectOutputStream (new FileOutputStream (outputModelFilename.value));
          oos.writeObject (topicModel);
          oos.close();

        } catch (Exception e) {
          e.printStackTrace();
          throw new IllegalArgumentException ("Couldn't write topic model to filename "+outputModelFilename.value);
        }
      }

    }
    else {

      // Start a new LDA topic model
     
      ParallelTopicModel topicModel = null;

      if (inputModelFilename.value != null) {
       
        try {
          topicModel = ParallelTopicModel.read(new File(inputModelFilename.value));
        } catch (Exception e) {
          System.err.println("Unable to restore saved topic model " +
                     inputModelFilename.value + ": " + e);
          System.exit(1);
        }
        /*
        // Loading new data is optional if we are restoring a saved state.
        if (inputFile.value != null) {
          InstanceList instances = InstanceList.load (new File(inputFile.value));
          System.out.println ("Data loaded.");
          lda.addInstances(instances);
        }
        */
      }
      else {
        InstanceList training = InstanceList.load (new File(inputFile.value));
        System.out.println ("Data loaded.");

        if (training.size() > 0 &&
          training.get(0) != null) {
          Object data = training.get(0).getData();
          if (! (data instanceof FeatureSequence)) {
            System.err.println("Topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.");
            System.exit(1);
          }
        }
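The excerpt is cut off before the plain-LDA branch finishes, but its remaining ParallelTopicModel calls follow the same pattern as the polylingual case above. A condensed sketch (the alpha-sum and beta values shown are MALLET's customary defaults, used here illustratively):

      ParallelTopicModel topicModel = new ParallelTopicModel(numTopics.value, 50.0, 0.01);
      topicModel.addInstances(training);
      topicModel.setTopicDisplay(showTopicsInterval.value, topWords.value);
      topicModel.setNumIterations(numIterations.value);
      topicModel.estimate();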

                 inputModelFilename.value + ": " + e);
        System.exit(1);
      }
    }
    else {
      InstanceList training = InstanceList.load (new File(inputFile.value));
      logger.info("Data loaded.");
     
      if (training.size() > 0 &&
        training.get(0) != null) {
        Object data = training.get(0).getData();
        if (! (data instanceof FeatureSequence)) {
          logger.warning("Topic modeling currently only supports feature sequences: use --keep-sequence option when importing data.");
          System.exit(1);
        }
      }
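Both copies of this guard perform the same check; factored into a reusable helper it might look like the following sketch (the method name is hypothetical):

    // True when the list holds FeatureSequence data, as the topic trainers
    // require; import with --keep-sequence to get this layout.
    static boolean hasFeatureSequenceData(InstanceList list) {
      return list.size() == 0
          || list.get(0) == null
          || list.get(0).getData() instanceof FeatureSequence;
    }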

   * @exception Exception if an error occurs
   */
  public static void main (String[] args) throws Exception
  {
    Reader trainingFile = null, testFile = null;
    InstanceList trainingData = null, testData = null;
    int numEvaluations = 0;
    int iterationsBetweenEvals = 16;
    int restArgs = commandOptions.processOptions(args);
    if (restArgs == args.length)
    {
      commandOptions.printUsage(true);
      throw new IllegalArgumentException("Missing data file(s)");
    }
    if (trainOption.value)
    {
      trainingFile = new FileReader(new File(args[restArgs]));
      if (testOption.value != null && restArgs < args.length - 1)
        testFile = new FileReader(new File(args[restArgs+1]));
    } else
      testFile = new FileReader(new File(args[restArgs]));

    Pipe p = null;
    CRF crf = null;
    TransducerEvaluator eval = null;
    if (continueTrainingOption.value || !trainOption.value) {
      if (modelOption.value == null)
      {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException("Missing model file option");
      }
      ObjectInputStream s =
        new ObjectInputStream(new FileInputStream(modelOption.value));
      crf = (CRF) s.readObject();
      s.close();
      p = crf.getInputPipe();
    }
    else {
      p = new SimpleTaggerSentence2FeatureVectorSequence();
      p.getTargetAlphabet().lookupIndex(defaultOption.value);
    }


    if (trainOption.value)
    {
      p.setTargetProcessing(true);
      trainingData = new InstanceList(p);
      trainingData.addThruPipe(
          new LineGroupIterator(trainingFile,
            Pattern.compile("^\\s*$"), true));
      logger.info
        ("Number of features in training data: "+p.getDataAlphabet().size());
      if (testOption.value != null)
      {
        if (testFile != null)
        {
          testData = new InstanceList(p);
          testData.addThruPipe(
              new LineGroupIterator(testFile,
                Pattern.compile("^\\s*$"), true));
        } else
        {
          Random r = new Random (randomSeedOption.value);
          InstanceList[] trainingLists =
            trainingData.split(
                r, new double[] {trainingFractionOption.value,
                  1-trainingFractionOption.value});
          trainingData = trainingLists[0];
          testData = trainingLists[1];
        }
      }
    } else if (testOption.value != null)
    {
      p.setTargetProcessing(true);
      testData = new InstanceList(p);
      testData.addThruPipe(
          new LineGroupIterator(testFile,
            Pattern.compile("^\\s*$"), true));
    } else
    {
      p.setTargetProcessing(false);
      testData = new InstanceList(p);
      testData.addThruPipe(
          new LineGroupIterator(testFile,
            Pattern.compile("^\\s*$"), true));
    }
    logger.info ("Number of predicates: "+p.getDataAlphabet().size());
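The excerpt stops before training starts; a sketch of how the declared crf and eval are typically wired once the data is loaded (class names from cc.mallet.fst; the iteration count is illustrative):

    if (crf == null) {
      crf = new CRF(p, null);
      crf.addStatesForLabelsConnectedAsIn(trainingData);
    }
    CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
    eval = new TokenAccuracyEvaluator(
        new InstanceList[] {trainingData, testData}, new String[] {"train", "test"});
    trainer.train(trainingData, 500);
    eval.evaluate(trainer);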
