Source code of cc.mallet.topics.LDAStream

/**
 * Implements different Gibbs-sampling-based inference methods.
 */
package cc.mallet.topics;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.zip.GZIPOutputStream;

import cc.mallet.types.FeatureCounter;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.util.Randoms;
import gnu.trove.TIntIntHashMap;

/**
* @author Limin Yao, David Mimno
*
*/
public class LDAStream extends LDAHyper {
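
  // Held-out inference strategies implemented below:
  //   inferenceAll       - unseen test words are assigned topics and counted immediately;
  //                        estimateAll() can then resume Gibbs sampling on training + test data.
  //   inference          - iteration-major: each Gibbs iteration sweeps all test docs.
  //   inferenceOneByOne  - document-major: all iterations run on one doc before the next.
  //   inferenceWithTheta - samples assignments under a fixed, externally supplied theta.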

  protected ArrayList<Topication> test; // the test instances and their topic assignments
  /**
   * @param numberOfTopics
   */
  public LDAStream(int numberOfTopics) {
    super(numberOfTopics);
    // TODO Auto-generated constructor stub
  }

  /**
   * @param numberOfTopics
   * @param alphaSum
   * @param beta
   */
  public LDAStream(int numberOfTopics, double alphaSum, double beta) {
    super(numberOfTopics, alphaSum, beta);
    // TODO Auto-generated constructor stub
  }

  /**
   * @param numberOfTopics
   * @param alphaSum
   * @param beta
   * @param random
   */
  public LDAStream(int numberOfTopics, double alphaSum, double beta,
      Randoms random) {
    super(numberOfTopics, alphaSum, beta, random);
    // TODO Auto-generated constructor stub
  }

  /**
   * @param topicAlphabet
   * @param alphaSum
   * @param beta
   * @param random
   */
  public LDAStream(LabelAlphabet topicAlphabet, double alphaSum, double beta,
      Randoms random) {
    super(topicAlphabet, alphaSum, beta, random);
    // TODO Auto-generated constructor stub
  }

  public ArrayList<Topication> getTest() { return test; }
 
  //first training a topic model on training data,
  //inference on test data, count typeTopicCounts
  // re-sampling on all data
  public void inferenceAll(int maxIteration){
    this.test = new ArrayList<Topication>(); // initialize test
    //initial sampling on testdata
    ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
    for (Instance instance : testing) {
      LabelSequence topicSequence = new LabelSequence(topicAlphabet, new int[instanceLength(instance)]);
      if (false) {
        // sampleTopicsForOneDoc does not yet obey its final "false" argument;
        // it must before this branch can be enabled.
        //sampleTopicsForOneDoc((FeatureSequence)instance.getData(), topicSequence, false, false);
      } else {
        Randoms r = new Randoms();
        FeatureSequence fs = (FeatureSequence) instance.getData();
        int[] topics = topicSequence.getFeatures();
        for (int i = 0; i < topics.length; i++) {
          int type = fs.getIndexAtPosition(i);
          topics[i] = r.nextInt(numTopics);
          typeTopicCounts[type].adjustOrPutValue(topics[i], 1, 1);
          tokensPerTopic[topics[i]]++;
        }
      }
      topicSequences.add (topicSequence);
    }

    //construct test
    assert (testing.size() == topicSequences.size());
    for (int i = 0; i < testing.size(); i++) {
      Topication t = new Topication (testing.get(i), this, topicSequences.get(i));
      test.add (t);
    }

    long startTime = System.currentTimeMillis();
    //loop
    int iter = 0;
    for ( ; iter <= maxIteration; iter++) {
      if (iter % 100 == 0) {
        System.out.println("Iteration: " + iter);
      }
      int numDocs = test.size();
      for (int di = 0; di < numDocs; di++) {
        FeatureSequence tokenSequence = (FeatureSequence) test.get(di).instance.getData();
        LabelSequence topicSequence = test.get(di).topicSequence;
        sampleTopicsForOneTestDocAll (tokenSequence, topicSequence);
      }
    }

    long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
    long minutes = seconds / 60;  seconds %= 60;
    long hours = minutes / 60;  minutes %= 60;
    long days = hours / 24;  hours %= 24;
    System.out.print ("\nTotal inferencing time: ");
    if (days != 0) { System.out.print(days); System.out.print(" days "); }
    if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
    if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
    System.out.print(seconds); System.out.println(" seconds");
  }

  // called by inferenceAll; unseen words in the test data are sampled and counted as well
  private void sampleTopicsForOneTestDocAll(FeatureSequence tokenSequence,
      LabelSequence topicSequence) {
    int[] oneDocTopics = topicSequence.getFeatures();

    TIntIntHashMap currentTypeTopicCounts;
    int type, oldTopic, newTopic;
    double tw;
    double[] topicWeights = new double[numTopics];
    double topicWeightsSum;
    int docLength = tokenSequence.getLength();

    // populate local topic counts (new int[] entries are already zero in Java)
    int[] localTopicCounts = new int[numTopics];
    for (int position = 0; position < docLength; position++) {
      localTopicCounts[oneDocTopics[position]]++;
    }

    // Iterate over the positions (words) in the document
    for (int si = 0; si < docLength; si++) {
      type = tokenSequence.getIndexAtPosition(si);
      oldTopic = oneDocTopics[si];

      // Remove this token from all counts
      localTopicCounts[oldTopic] --;

      currentTypeTopicCounts = typeTopicCounts[type];
      assert(currentTypeTopicCounts.get(oldTopic) >= 0);

      if (currentTypeTopicCounts.get(oldTopic) == 1) {
        currentTypeTopicCounts.remove(oldTopic);
      }
      else {
        currentTypeTopicCounts.adjustValue(oldTopic, -1);
      }
      tokensPerTopic[oldTopic]--;

      // Build a distribution over topics for this token
      Arrays.fill (topicWeights, 0.0);
      topicWeightsSum = 0;

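      // Gibbs update: p(z_i = t | z_-i, w) is proportional to
      //   (n_{w|t} + beta) / (n_t + betaSum) * (n_{t|d} + alpha_t)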
      for (int ti = 0; ti < numTopics; ti++) {
        tw = ((currentTypeTopicCounts.get(ti) + beta) / (tokensPerTopic[ti] + betaSum))
              * (localTopicCounts[ti] + alpha[ti]); // dividing by (docLen-1+alphaSum) is constant across topics, so omitted
        topicWeightsSum += tw;
        topicWeights[ti] = tw;
      }
      // Sample a topic assignment from this distribution
      newTopic = random.nextDiscrete (topicWeights, topicWeightsSum);

      // Put that new topic into the counts
      oneDocTopics[si] = newTopic;
      currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
      localTopicCounts[newTopic] ++;
      tokensPerTopic[newTopic]++;
    }
  }
 
  // At this point we have: typeTopicCounts, tokensPerTopic, and the
  // topic sequences of both training and test data.
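  // Typical call order (sketch): estimate(...) on the training data, then inference(...)
  // over the test data, then estimateAll(...) to resume sampling on the combined corpus.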
  public void estimateAll(int iteration) throws IOException {
    //re-Gibbs sampling on all data
    data.addAll(test);
    initializeHistogramsAndCachedValues();
    estimate(iteration);
  }

  // inference on test data; one problem is how to deal with unseen words:
  // an unseen word is in the Alphabet, but its typeTopicCounts entry is empty
  // added by Limin Yao
  /**
   * @param maxIteration  maximum number of Gibbs sweeps over the test data
   */
  public void inference(int maxIteration){
    this.test = new ArrayList<Topication>(); // initialize test
    //initial sampling on testdata
    ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
    for (Instance instance : testing) {
      LabelSequence topicSequence = new LabelSequence(topicAlphabet, new int[instanceLength(instance)]);
      if (false) {
        // sampleTopicsForOneDoc does not yet obey its final "false" argument;
        // it must before this branch can be enabled.
        //sampleTopicsForOneDoc((FeatureSequence)instance.getData(), topicSequence, false, false);
      } else {
        Randoms r = new Randoms();
        FeatureSequence fs = (FeatureSequence) instance.getData();
        int[] topics = topicSequence.getFeatures();
        for (int i = 0; i < topics.length; i++) {
          int type = fs.getIndexAtPosition(i);
          topics[i] = r.nextInt(numTopics);
        /*  if(typeTopicCounts[type].size() != 0) {
            topics[i] = r.nextInt(numTopics);
          } else {
            topics[i] = -1;  // for unseen words
          }*/
        }
      }
      topicSequences.add (topicSequence);
    }

    //construct test
    assert (testing.size() == topicSequences.size());
    for (int i = 0; i < testing.size(); i++) {
      Topication t = new Topication (testing.get(i), this, topicSequences.get(i));
      test.add (t);
      // Include sufficient statistics for this one doc
      // add count on new data to n[k][w] and n[k][*]
      // pay attention to unseen words
      FeatureSequence tokenSequence = (FeatureSequence) t.instance.getData();
      LabelSequence topicSequence = t.topicSequence;
      for (int pi = 0; pi < topicSequence.getLength(); pi++) {
        int topic = topicSequence.getIndexAtPosition(pi);
        int type = tokenSequence.getIndexAtPosition(pi);
        if (topic != -1) { // type seen in training
          typeTopicCounts[type].adjustOrPutValue(topic, 1, 1);
          tokensPerTopic[topic]++;
        }
      }
    }

    long startTime = System.currentTimeMillis();
    //loop
    int iter = 0;
    for ( ; iter <= maxIteration; iter++) {
      if (iter % 100 == 0) {
        System.out.println("Iteration: " + iter);
      }
      int numDocs = test.size();
      for (int di = 0; di < numDocs; di++) {
        FeatureSequence tokenSequence = (FeatureSequence) test.get(di).instance.getData();
        LabelSequence topicSequence = test.get(di).topicSequence;
        sampleTopicsForOneTestDoc (tokenSequence, topicSequence);
      }
    }

    long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
    long minutes = seconds / 60;  seconds %= 60;
    long hours = minutes / 60;  minutes %= 60;
    long days = hours / 24;  hours %= 24;
    System.out.print ("\nTotal inferencing time: ");
    if (days != 0) { System.out.print(days); System.out.print(" days "); }
    if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
    if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
    System.out.print(seconds); System.out.println(" seconds");
  }

  private void sampleTopicsForOneTestDoc(FeatureSequence tokenSequence,
      LabelSequence topicSequence) {
    int[] oneDocTopics = topicSequence.getFeatures();

    TIntIntHashMap currentTypeTopicCounts;
    int type, oldTopic, newTopic;
    double tw;
    double[] topicWeights = new double[numTopics];
    double topicWeightsSum;
    int docLength = tokenSequence.getLength();

    // populate local topic counts (new int[] entries are already zero in Java)
    int[] localTopicCounts = new int[numTopics];
    for (int position = 0; position < docLength; position++) {
      if (oneDocTopics[position] != -1) {
        localTopicCounts[oneDocTopics[position]]++;
      }
    }

    // Iterate over the positions (words) in the document
    for (int si = 0; si < docLength; si++) {
      type = tokenSequence.getIndexAtPosition(si);
      oldTopic = oneDocTopics[si];
      if(oldTopic == -1) {
        continue;
      }

      // Remove this token from all counts
      localTopicCounts[oldTopic]--;
      currentTypeTopicCounts = typeTopicCounts[type];
      assert(currentTypeTopicCounts.get(oldTopic) >= 0);

      if (currentTypeTopicCounts.get(oldTopic) == 1) {
        currentTypeTopicCounts.remove(oldTopic);
      }
      else {
        currentTypeTopicCounts.adjustValue(oldTopic, -1);
      }
      tokensPerTopic[oldTopic]--;

      // Build a distribution over topics for this token
      Arrays.fill (topicWeights, 0.0);
      topicWeightsSum = 0;

      for (int ti = 0; ti < numTopics; ti++) {
        tw = ((currentTypeTopicCounts.get(ti) + beta) / (tokensPerTopic[ti] + betaSum))
              * (localTopicCounts[ti] + alpha[ti]); // dividing by (docLen-1+alphaSum) is constant across topics, so omitted
        topicWeightsSum += tw;
        topicWeights[ti] = tw;
      }
      // Sample a topic assignment from this distribution
      newTopic = random.nextDiscrete (topicWeights, topicWeightsSum);

      // Put that new topic into the counts
      oneDocTopics[si] = newTopic;
      currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
      localTopicCounts[newTopic] ++;
      tokensPerTopic[newTopic]++;
    }
  }

  // inference method 3: for each doc, run all iterations over its words before moving on
  // compare against inference() (method 2): for each iteration, sweep every doc
  public void inferenceOneByOne(int maxIteration){
    this.test = new ArrayList<Topication>(); // initialize test
    //initial sampling on testdata
    ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
    for (Instance instance : testing) {
      LabelSequence topicSequence = new LabelSequence(topicAlphabet, new int[instanceLength(instance)]);
      if (false) {
        // sampleTopicsForOneDoc does not yet obey its final "false" argument;
        // it must before this branch can be enabled.
        //sampleTopicsForOneDoc((FeatureSequence)instance.getData(), topicSequence, false, false);
      } else {
        Randoms r = new Randoms();
        FeatureSequence fs = (FeatureSequence) instance.getData();
        int[] topics = topicSequence.getFeatures();
        for (int i = 0; i < topics.length; i++) {
          int type = fs.getIndexAtPosition(i);
          topics[i] = r.nextInt(numTopics);
          typeTopicCounts[type].adjustOrPutValue(topics[i], 1, 1);
          tokensPerTopic[topics[i]]++;
        /*  if(typeTopicCounts[type].size() != 0) {
            topics[i] = r.nextInt(numTopics);
            typeTopicCounts[type].adjustOrPutValue(topics[i], 1, 1);
            tokensPerTopic[topics[i]]++;
          } else {
            topics[i] = -1;  // for unseen words
          }*/
        }
      }
      topicSequences.add (topicSequence);
    }

    //construct test
    assert (testing.size() == topicSequences.size());
    for (int i = 0; i < testing.size(); i++) {
      Topication t = new Topication (testing.get(i), this, topicSequences.get(i));
      test.add (t);
    }

    long startTime = System.currentTimeMillis();
    //loop
    int iter = 0;
    int numDocs = test.size();
    for (int di = 0; di < numDocs; di++) {
      iter = 0;
      FeatureSequence tokenSequence = (FeatureSequence) test.get(di).instance.getData();
      LabelSequence topicSequence = test.get(di).topicSequence;
      for( ; iter <= maxIteration; iter++) {
        sampleTopicsForOneTestDoc (tokenSequence, topicSequence);
      }
      if (di % 100 == 0) {
        System.out.println("Docnum: " + di);
      }
    }

    long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
    long minutes = seconds / 60;  seconds %= 60;
    long hours = minutes / 60;  minutes %= 60;
    long days = hours / 24;  hours %= 24;
    System.out.print ("\nTotal inferencing time: ");
    if (days != 0) { System.out.print(days); System.out.print(" days "); }
    if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
    if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
    System.out.print(seconds); System.out.println(" seconds");
  }
 
  public void inferenceWithTheta(int maxIteration, InstanceList theta){
    this.test = new ArrayList<Topication>(); // initialize test
    //initial sampling on testdata
    ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
    for (Instance instance : testing) {
      LabelSequence topicSequence = new LabelSequence(topicAlphabet, new int[instanceLength(instance)]);
      if (false) {
        // sampleTopicsForOneDoc does not yet obey its final "false" argument;
        // it must before this branch can be enabled.
        //sampleTopicsForOneDoc((FeatureSequence)instance.getData(), topicSequence, false, false);
      } else {
        Randoms r = new Randoms();
        FeatureSequence fs = (FeatureSequence) instance.getData();
        int[] topics = topicSequence.getFeatures();
        for (int i = 0; i < topics.length; i++) {
          int type = fs.getIndexAtPosition(i);
          topics[i] = r.nextInt(numTopics);
        }
      }
      topicSequences.add (topicSequence);
    }

    //construct test
    assert (testing.size() == topicSequences.size());
    for (int i = 0; i < testing.size(); i++) {
      Topication t = new Topication (testing.get(i), this, topicSequences.get(i));
      test.add (t);
      // Include sufficient statistics for this one doc
      // add count on new data to n[k][w] and n[k][*]
      // pay attention to unseen words
      FeatureSequence tokenSequence = (FeatureSequence) t.instance.getData();
      LabelSequence topicSequence = t.topicSequence;
      for (int pi = 0; pi < topicSequence.getLength(); pi++) {
        int topic = topicSequence.getIndexAtPosition(pi);
        int type = tokenSequence.getIndexAtPosition(pi);
        if (topic != -1) { // type seen in training
          typeTopicCounts[type].adjustOrPutValue(topic, 1, 1);
          tokensPerTopic[topic]++;
        }
      }
    }

    long startTime = System.currentTimeMillis();
    //loop
    int iter = 0;
    for ( ; iter <= maxIteration; iter++) {
      if (iter % 100 == 0) {
        System.out.println("Iteration: " + iter);
      }
      int numDocs = test.size();
      for (int di = 0; di < numDocs; di++) {
        FeatureVector fvTheta = (FeatureVector) theta.get(di).getData();
        double[] topicDistribution = fvTheta.getValues();
        FeatureSequence tokenSequence = (FeatureSequence) test.get(di).instance.getData();
        LabelSequence topicSequence = test.get(di).topicSequence;
        sampleTopicsForOneDocWithTheta (tokenSequence, topicSequence, topicDistribution);
      }
    }

    long seconds = Math.round((System.currentTimeMillis() - startTime)/1000.0);
    long minutes = seconds / 60;  seconds %= 60;
    long hours = minutes / 60;  minutes %= 60;
    long days = hours / 24;  hours %= 24;
    System.out.print ("\nTotal inferencing time: ");
    if (days != 0) { System.out.print(days); System.out.print(" days "); }
    if (hours != 0) { System.out.print(hours); System.out.print(" hours "); }
    if (minutes != 0) { System.out.print(minutes); System.out.print(" minutes "); }
    System.out.print(seconds); System.out.println(" seconds");
  }

  // sampling with a known theta, e.g. estimated by a maxent classifier
  private void sampleTopicsForOneDocWithTheta(FeatureSequence tokenSequence,
      LabelSequence topicSequence, double[] topicDistribution) {
    int[] oneDocTopics = topicSequence.getFeatures();

    TIntIntHashMap currentTypeTopicCounts;
    int type, oldTopic, newTopic;
    double tw;
    double[] topicWeights = new double[numTopics];
    double topicWeightsSum;
    int docLength = tokenSequence.getLength();
   
    // Iterate over the positions (words) in the document
    for (int si = 0; si < docLength; si++) {
      type = tokenSequence.getIndexAtPosition(si);
      oldTopic = oneDocTopics[si];
      if(oldTopic == -1) {
        continue;
      }

      currentTypeTopicCounts = typeTopicCounts[type];
      assert(currentTypeTopicCounts.get(oldTopic) >= 0);

      if (currentTypeTopicCounts.get(oldTopic) == 1) {
        currentTypeTopicCounts.remove(oldTopic);
      }
      else {
        currentTypeTopicCounts.adjustValue(oldTopic, -1);
      }
      tokensPerTopic[oldTopic]--;

      // Build a distribution over topics for this token
      Arrays.fill (topicWeights, 0.0);
      topicWeightsSum = 0;

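      // Fixed-theta update: p(z_i = t | z_-i, w, theta) is proportional to
      //   (n_{w|t} + beta) / (n_t + betaSum) * theta_t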
      for (int ti = 0; ti < numTopics; ti++) {
        tw = ((currentTypeTopicCounts.get(ti) + beta) / (tokensPerTopic[ti] + betaSum))
              * topicDistribution[ti]; // the document factor comes from the fixed theta
        topicWeightsSum += tw;
        topicWeights[ti] = tw;
      }
      // Sample a topic assignment from this distribution
      newTopic = random.nextDiscrete (topicWeights, topicWeightsSum);

      // Put that new topic into the counts
      oneDocTopics[si] = newTopic;
      currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
      tokensPerTopic[newTopic]++;
    }
  }

  // print a human-readable doc-topic matrix, for further IR use
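  // Output per document: the instance name, then one "topic<k>\t<prob>" line per topic,
  // then a blank line. Note: the threshold and max parameters are currently unused.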
  public void printTheta(ArrayList<Topication> dataset, File f, double threshold, int max) throws IOException{
    PrintWriter pw = new PrintWriter(new FileWriter(f));
    int[] topicCounts = new int[ numTopics ];
    int docLen;
   
    for (int di = 0; di < dataset.size(); di++) {
      LabelSequence topicSequence = dataset.get(di).topicSequence;
      int[] currentDocTopics = topicSequence.getFeatures();
      docLen = currentDocTopics.length;
      for (int token=0; token < docLen; token++) {
        topicCounts[ currentDocTopics[token] ]++;
      }
      pw.println(dataset.get(di).instance.getName());
      // theta(t|d) = (n(t|d) + alpha(t)) / (docLen + alphaSum)
      for (int topic = 0; topic < numTopics; topic++) {
        double prob = (double) (topicCounts[topic]+alpha[topic]) / (docLen + alphaSum);
        pw.println("topic"+ topic + "\t" + prob);
      }

      pw.println();
      Arrays.fill(topicCounts, 0);
    }
    pw.close();
  }
 
  //print topic-word matrix, for further IR use
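  // Output per topic: "Topic\t<k>", then one "<word>\t<prob>" line for each word with a
  // nonzero count in that topic. Note: the threshold parameter is currently unused.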
  public void printPhi(File f, double threshold) throws IOException{
    PrintWriter pw = new PrintWriter(new FileWriter(f));
    FeatureCounter[] wordCountsPerTopic = new FeatureCounter[numTopics];
    for (int ti = 0; ti < numTopics; ti++) {
      wordCountsPerTopic[ti] = new FeatureCounter(alphabet);
    }

    for (int fi = 0; fi < numTypes; fi++) {
      int[] topics = typeTopicCounts[fi].keys();
      for (int i = 0; i < topics.length; i++) {
        wordCountsPerTopic[topics[i]].increment(fi, typeTopicCounts[fi].get(topics[i]));
      }
    }
   
    for(int ti = 0; ti < numTopics; ti++){
      pw.println("Topic\t" + ti);
      FeatureCounter counter = wordCountsPerTopic[ti];
      FeatureVector fv = counter.toFeatureVector();
      for(int pos = 0; pos < fv.numLocations(); pos++){
        int fi = fv.indexAtLocation(pos);
        String word = (String) alphabet.lookupObject(fi);
        int count = (int) fv.valueAtLocation(pos);
        double prob;
        prob = (double) (count+beta)/(tokensPerTopic[ti] + betaSum);
        pw.println(word + "\t" + prob);
      }
      pw.println();
    }
    pw.close();
  }

  public void printDocumentTopics (ArrayList<Topication> dataset, File f) throws IOException {
    printDocumentTopics (dataset, new PrintWriter (new FileWriter (f) ) );
  }

  public void printDocumentTopics (ArrayList<Topication> dataset, PrintWriter pw) {
    printDocumentTopics (dataset, pw, 0.0, -1);
  }

  /**
   *  @param dataset     The documents with their topic assignments
   *  @param pw          A print writer
   *  @param threshold   Only print topics with proportion greater than this number
   *  @param max         Print no more than this many topics
   */
  public void printDocumentTopics (ArrayList<Topication> dataset, PrintWriter pw, double threshold, int max)  {
    pw.print ("#doc source topic proportion ...\n");
    int docLen;
    int[] topicCounts = new int[ numTopics ];

    IDSorter[] sortedTopics = new IDSorter[ numTopics ];
    for (int topic = 0; topic < numTopics; topic++) {
      // Initialize the sorters with dummy values
      sortedTopics[topic] = new IDSorter(topic, topic);
    }

    if (max < 0 || max > numTopics) {
      max = numTopics;
    }

    for (int di = 0; di < dataset.size(); di++) {
      LabelSequence topicSequence = dataset.get(di).topicSequence;
      int[] currentDocTopics = topicSequence.getFeatures();

      pw.print (di); pw.print (' ');

      if (dataset.get(di).instance.getSource() != null) {
        pw.print (dataset.get(di).instance.getSource());
      }
      else {
        pw.print ("null-source");
      }

      pw.print (' ');
      docLen = currentDocTopics.length;

      // Count up the tokens
      int realDocLen = 0;
      for (int token=0; token < docLen; token++) {
        if(currentDocTopics[token] != -1) {
          topicCounts[ currentDocTopics[token] ]++;
          realDocLen ++;
        }
      }
      assert(realDocLen == docLen);
      alphaSum = 0.0;
      for (int topic = 0; topic < numTopics; topic++) {
        alphaSum += alpha[topic];
      }
     
      // And normalize and smooth by Dirichlet prior alpha
      for (int topic = 0; topic < numTopics; topic++) {
        sortedTopics[topic].set(topic, (double) (topicCounts[topic]+alpha[topic]) / (docLen + alphaSum));
      }
    
      Arrays.sort(sortedTopics);

      for (int i = 0; i < max; i++) {
        if (sortedTopics[i].getWeight() < threshold) { break; }

        pw.print (sortedTopics[i].getID() + " " +
              sortedTopics[i].getWeight() + " ");
      }
      pw.print (" \n");

      Arrays.fill(topicCounts, 0);
    }
    pw.close();
  }
 
  public void printState (ArrayList<Topication> dataset, File f) throws IOException {
    PrintStream out =
      new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
    printState(dataset, out);
    out.close();
  }

  public void printState (ArrayList<Topication> dataset, PrintStream out) {

    out.println ("#doc source pos typeindex type topic");

    for (int di = 0; di < dataset.size(); di++) {
      FeatureSequence tokenSequence =  (FeatureSequence) dataset.get(di).instance.getData();
      LabelSequence topicSequence =  dataset.get(di).topicSequence;

      String source = "NA";
      if (dataset.get(di).instance.getSource() != null) {
        source = dataset.get(di).instance.getSource().toString();
      }

      for (int pi = 0; pi < topicSequence.getLength(); pi++) {
        int type = tokenSequence.getIndexAtPosition(pi);
        int topic = topicSequence.getIndexAtPosition(pi);
        out.print(di); out.print(' ');
        out.print(source); out.print(' ');
        out.print(pi); out.print(' ');
        out.print(type); out.print(' ');
        out.print(alphabet.lookupObject(type)); out.print(' ');
        out.print(topic); out.println();
      }
    }
  }

}
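
A minimal usage sketch (not part of the original source): it assumes the standard
LDAHyper API (addInstances, estimate); the direct assignment of the inherited
"testing" list, the file paths, and the demo class name are hypothetical, so adjust
them to your MALLET build.

package cc.mallet.topics; // same package, so inherited non-public fields are reachable

import java.io.File;

import cc.mallet.types.InstanceList;

public class LDAStreamDemo {
  public static void main(String[] args) throws Exception {
    // hypothetical serialized InstanceList files
    InstanceList training = InstanceList.load(new File("training.mallet"));
    InstanceList heldOut  = InstanceList.load(new File("test.mallet"));

    LDAStream lda = new LDAStream(50, 50.0, 0.01); // numTopics, alphaSum, beta
    lda.addInstances(training);  // inherited from LDAHyper
    lda.estimate(1000);          // Gibbs sampling on the training data

    lda.testing = heldOut;       // assumption: the inherited testing list is assignable
    lda.inference(100);          // method 2: iteration-major sweeps over the test docs

    lda.printTheta(lda.getTest(), new File("theta.txt"), 0.0, -1);
    lda.printPhi(new File("phi.txt"), 0.0);
  }
}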