Source Code of cc.mallet.topics.MarginalProbEstimator

/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.topics;

import java.util.Arrays;
import java.util.ArrayList;

import java.util.zip.*;

import java.io.*;
import java.text.NumberFormat;

import cc.mallet.types.*;
import cc.mallet.util.Randoms;

/**
* An implementation of the left-to-right topic model marginal probability
*  estimator presented in Wallach et al., "Evaluation Methods for Topic Models", ICML (2009).
*
* @author David Mimno
*/

public class MarginalProbEstimator implements Serializable {

  protected int numTopics; // Number of topics in the trained model being evaluated

  // These values are used to encode type/topic counts as
  //  count/topic pairs in a single int.
  protected int topicMask;
  protected int topicBits;

  protected double[] alpha;   // Dirichlet(alpha[0], ..., alpha[numTopics-1]) prior over per-document topic distributions
  protected double alphaSum;
  protected double beta;   // Prior on per-topic multinomial distribution over words
  protected double betaSum;
 
  protected double smoothingOnlyMass = 0.0;
  protected double[] cachedCoefficients;

  protected int[][] typeTopicCounts; // indexed by feature (word type) index; rows hold packed count/topic entries
  protected int[] tokensPerTopic; // indexed by <topic index>
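  // Each row typeTopicCounts[type] is a zero-terminated list of packed
  //  entries of the form (count << topicBits) | topic, so a single int
  //  records both a topic id and the number of tokens of this type
  //  assigned to that topic in the trained model. The sampling loops
  //  below stop at the first entry whose packed value is 0.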

  protected Randoms random;
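
  /**
   * Build an evaluator from the state of a trained topic model. The
   *  type-topic counts and per-topic token totals are taken as fixed
   *  ("clamped") model estimates: they are read but never updated
   *  during evaluation.
   */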
 
  public MarginalProbEstimator (int numTopics,
                  double[] alpha, double alphaSum,
                  double beta,
                  int[][] typeTopicCounts,
                  int[] tokensPerTopic) {

    this.numTopics = numTopics;

    if (Integer.bitCount(numTopics) == 1) {
      // exact power of 2
      topicMask = numTopics - 1;
      topicBits = Integer.bitCount(topicMask);
    }
    else {
      // otherwise add an extra bit
      topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
      topicBits = Integer.bitCount(topicMask);
    }
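
    // Example: numTopics = 100 is not a power of 2, so
    //  highestOneBit(100) = 64, topicMask = 127 (binary 1111111),
    //  and topicBits = 7; a packed entry (5 << 7) | 3 then encodes
    //  count 5 for topic 3.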

    this.typeTopicCounts = typeTopicCounts;
    this.tokensPerTopic = tokensPerTopic;
   
    this.alphaSum = alphaSum;
    this.alpha = alpha;
    this.beta = beta;
    this.betaSum = beta * typeTopicCounts.length;
    this.random = new Randoms();
   
    cachedCoefficients = new double[ numTopics ];

    // Initialize the smoothing-only sampling bucket
    smoothingOnlyMass = 0;
   
    // Initialize the cached coefficients, using only smoothing.
    //  These values will be selectively replaced in documents with
    //  non-zero counts in particular topics.
   
    for (int topic=0; topic < numTopics; topic++) {
      smoothingOnlyMass += alpha[topic] * beta / (tokensPerTopic[topic] + betaSum);
      cachedCoefficients[topic] =  alpha[topic] / (tokensPerTopic[topic] + betaSum);
    }
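
    // At this point smoothingOnlyMass = sum_t alpha[t] * beta / (tokensPerTopic[t] + betaSum)
    //  and cachedCoefficients[t] = alpha[t] / (tokensPerTopic[t] + betaSum),
    //  i.e. the topic weights for a document with no tokens assigned yet.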
   
    System.err.println("Topic Evaluator: " + numTopics + " topics, " + topicBits + " topic bits, " +
               Integer.toBinaryString(topicMask) + " topic mask");

  }

  public int[] getTokensPerTopic() { return tokensPerTopic; }
  public int[][] getTypeTopicCounts() { return typeTopicCounts; }
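
  /**
   * Estimate the log marginal probability of a set of held-out documents
   *  using the left-to-right algorithm of Wallach et al. (2009).
   *
   * <p>A minimal usage sketch (the held-out instance list, particle count,
   *  and null output stream below are placeholders, not values taken from
   *  this file):
   *
   * <pre>{@code
   * MarginalProbEstimator estimator =
   *   new MarginalProbEstimator(numTopics, alpha, alphaSum, beta,
   *                             typeTopicCounts, tokensPerTopic);
   * double totalLogLik = estimator.evaluateLeftToRight(heldOut, 10, false, null);
   * }</pre>
   *
   * @param testing the held-out documents, whose data are FeatureSequence instances
   * @param numParticles number of independent left-to-right runs per document;
   *          per-position probabilities are averaged over particles
   * @param usingResampling if true, resample the topics of all earlier positions
   *          before predicting each position (slower but closer to the full algorithm)
   * @param docProbabilityStream if non-null, each document's log probability is
   *          also written to this stream, one value per line
   * @return the total log probability (natural log), summed over documents
   */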

  public double evaluateLeftToRight (InstanceList testing, int numParticles, boolean usingResampling,
                     PrintStream docProbabilityStream) {
    random = new Randoms();

    double logNumParticles = Math.log(numParticles);
    double totalLogLikelihood = 0;
    for (Instance instance : testing) {
     
      FeatureSequence tokenSequence = (FeatureSequence) instance.getData();

      double docLogLikelihood = 0;

      double[][] particleProbabilities = new double[ numParticles ][];
      for (int particle = 0; particle < numParticles; particle++) {
        particleProbabilities[particle] =
          leftToRight(tokenSequence, usingResampling);
      }

      for (int position = 0; position < particleProbabilities[0].length; position++) {
        double sum = 0;
        for (int particle = 0; particle < numParticles; particle++) {
          sum += particleProbabilities[particle][position];
        }

        if (sum > 0.0) {
          docLogLikelihood += Math.log(sum) - logNumParticles;
        }
      }

      if (docProbabilityStream != null) {
        docProbabilityStream.println(docLogLikelihood);
      }
      totalLogLikelihood += docLogLikelihood;
    }

    return totalLogLikelihood;
  }
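
  /**
   * Run a single left-to-right "particle" over one document. The returned array
   *  gives, for each position n, the estimated predictive probability
   *  P(w_n | w_1 ... w_{n-1}) under the trained model; out-of-vocabulary
   *  positions are left at 0 and are skipped by evaluateLeftToRight.
   */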
 
  protected double[] leftToRight (FeatureSequence tokenSequence, boolean usingResampling) {

    int[] oneDocTopics = new int[tokenSequence.getLength()];
    double[] wordProbabilities = new double[tokenSequence.getLength()];

    int[] currentTypeTopicCounts;
    int type, oldTopic, newTopic;
    double topicWeightsSum;
    int docLength = tokenSequence.getLength();

    // Keep track of the number of tokens we've examined, not
    //  including out-of-vocabulary words
    int tokensSoFar = 0;

    int[] localTopicCounts = new int[numTopics];
    int[] localTopicIndex = new int[numTopics];
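
    // localTopicCounts[t] is the number of tokens in this document currently
    //  assigned to topic t; localTopicIndex[0 .. nonZeroTopics-1] lists, in
    //  increasing order, the topics whose local count is non-zero.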

    // Build an array that densely lists the topics that
    //  have non-zero counts.
    int denseIndex = 0;

    // Record the total number of non-zero topics
    int nonZeroTopics = denseIndex;

    //    Initialize the topic count/beta sampling bucket
    double topicBetaMass = 0.0;
    double topicTermMass = 0.0;
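
    // The unnormalized sampling weight of topic t for the current word w is
    //   (alpha[t] + n_{t|d}) * (n_{w|t} + beta) / (n_t + betaSum)
    //     = alpha[t] * beta / (n_t + betaSum)
    //     + n_{t|d} * beta / (n_t + betaSum)
    //     + (alpha[t] + n_{t|d}) * n_{w|t} / (n_t + betaSum).
    //  Summed over topics, the three terms give smoothingOnlyMass,
    //  topicBetaMass, and topicTermMass: the SparseLDA-style bucket
    //  decomposition (Yao, Mimno & McCallum, 2009). Only the last two
    //  change as this document's topic counts change.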

    double[] topicTermScores = new double[numTopics];
    int[] topicTermIndices;
    int[] topicTermValues;
    int i;
    double score;

    double logLikelihood = 0;

    // All counts are now zero, we are starting completely fresh.

    //  Iterate over the positions (words) in the document
    for (int limit = 0; limit < docLength; limit++) {
     
      // Record the marginal probability of the token
      //  at the current limit, summed over all topics.

      if (usingResampling) {

        // Iterate up to the current limit
        for (int position = 0; position < limit; position++) {

          type = tokenSequence.getIndexAtPosition(position);
          oldTopic = oneDocTopics[position];

          // Check for out-of-vocabulary words
          if (type >= typeTopicCounts.length ||
            typeTopicCounts[type] == null) {
            continue;
          }

          currentTypeTopicCounts = typeTopicCounts[type];
       
          //  Remove this token from all counts.
       
          // Remove this topic's contribution to the
          //  normalizing constants.
          // Note that we are using clamped estimates of P(w|t),
          //  so we are NOT changing smoothingOnlyMass.
          topicBetaMass -= beta * localTopicCounts[oldTopic] /
            (tokensPerTopic[oldTopic] + betaSum);
       
          // Decrement the local doc/topic counts
       
          localTopicCounts[oldTopic]--;
       
          // Maintain the dense index, if we are deleting
          //  the old topic
          if (localTopicCounts[oldTopic] == 0) {
         
            // First get to the dense location associated with
            //  the old topic.
         
            denseIndex = 0;
         
            // We know it's in there somewhere, so we don't
            //  need bounds checking.
            while (localTopicIndex[denseIndex] != oldTopic) {
              denseIndex++;
            }
         
            // shift all remaining dense indices to the left.
            while (denseIndex < nonZeroTopics) {
              if (denseIndex < localTopicIndex.length - 1) {
                localTopicIndex[denseIndex] =
                  localTopicIndex[denseIndex + 1];
              }
              denseIndex++;
            }
         
            nonZeroTopics --;
          }

          // Add the old topic's contribution back into the
          //  normalizing constants.
          topicBetaMass += beta * localTopicCounts[oldTopic] /
            (tokensPerTopic[oldTopic] + betaSum);

          // Reset the cached coefficient for this topic
          cachedCoefficients[oldTopic] =
            (alpha[oldTopic] + localTopicCounts[oldTopic]) /
            (tokensPerTopic[oldTopic] + betaSum);
       

          // Now go over the type/topic counts, calculating the score
          //  for each topic.
       
          int index = 0;
          int currentTopic, currentValue;
       
          boolean alreadyDecremented = false;
       
          topicTermMass = 0.0;
       
          while (index < currentTypeTopicCounts.length &&
               currentTypeTopicCounts[index] > 0) {
            currentTopic = currentTypeTopicCounts[index] & topicMask;
            currentValue = currentTypeTopicCounts[index] >> topicBits;
         
            score =
              cachedCoefficients[currentTopic] * currentValue;
            topicTermMass += score;
            topicTermScores[index] = score;
         
            index++;
          }
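
          // topicTermScores[0 .. index-1] now holds one score per packed
          //  entry, aligned with currentTypeTopicCounts by list position
          //  (not by topic id), and topicTermMass is their sum. That is
          //  why the topic-term branch below recovers the sampled topic
          //  with currentTypeTopicCounts[i] & topicMask.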
     
          double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
          double origSample = sample;
       
          //  Make sure it actually gets set
          newTopic = -1;
       
          if (sample < topicTermMass) {
         
            i = -1;
            while (sample > 0) {
              i++;
              sample -= topicTermScores[i];
            }
         
            newTopic = currentTypeTopicCounts[i] & topicMask;
          }
          else {
            sample -= topicTermMass;
         
            if (sample < topicBetaMass) {
              //betaTopicCount++;
           
              sample /= beta;
           
              for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
                int topic = localTopicIndex[denseIndex];
             
                sample -= localTopicCounts[topic] /
                  (tokensPerTopic[topic] + betaSum);
             
                if (sample <= 0.0) {
                  newTopic = topic;
                  break;
                }
              }
           
            }
            else {
              //smoothingOnlyCount++;
           
              sample -= topicBetaMass;
           
              sample /= beta;
           
              newTopic = 0;
              sample -= alpha[newTopic] /
                (tokensPerTopic[newTopic] + betaSum);
           
              while (sample > 0.0) {
                newTopic++;
                sample -= alpha[newTopic] /
                  (tokensPerTopic[newTopic] + betaSum);
              }
           
            }
         
          }
       
          if (newTopic == -1) {
            System.err.println("sampling error: "+ origSample + " " + sample + " " + smoothingOnlyMass + " " +
                       topicBetaMass + " " + topicTermMass);
            newTopic = numTopics-1; // TODO is this appropriate
            //throw new IllegalStateException ("WorkerRunnable: New topic not sampled.");
          }
          //assert(newTopic != -1);
       
          //      Put that new topic into the counts
          oneDocTopics[position] = newTopic;
       
          topicBetaMass -= beta * localTopicCounts[newTopic] /
            (tokensPerTopic[newTopic] + betaSum);
       
          localTopicCounts[newTopic]++;
       
          // If this is a new topic for this document,
          //  add the topic to the dense index.
          if (localTopicCounts[newTopic] == 1) {
         
            // First find the point where we
            //  should insert the new topic by going to
            //  the end (which is the only reason we're keeping
            //  track of the number of non-zero
            //  topics) and working backwards
         
            denseIndex = nonZeroTopics;
         
            while (denseIndex > 0 &&
                 localTopicIndex[denseIndex - 1] > newTopic) {
           
              localTopicIndex[denseIndex] =
                localTopicIndex[denseIndex - 1];
              denseIndex--;
            }
         
            localTopicIndex[denseIndex] = newTopic;
            nonZeroTopics++;
          }
       
          //  update the coefficients for the non-zero topics
          cachedCoefficients[newTopic] =
            (alpha[newTopic] + localTopicCounts[newTopic]) /
            (tokensPerTopic[newTopic] + betaSum);
       
          topicBetaMass += beta * localTopicCounts[newTopic] /
            (tokensPerTopic[newTopic] + betaSum);
       
        }
      }
     
      // We've just resampled all tokens UP TO the current limit,
      //  now sample the token AT the current limit.
     
      type = tokenSequence.getIndexAtPosition(limit);       

      // Check for out-of-vocabulary words
      if (type >= typeTopicCounts.length ||
        typeTopicCounts[type] == null) {
        continue;
      }

      currentTypeTopicCounts = typeTopicCounts[type];

      int index = 0;
      int currentTopic, currentValue;
     
      topicTermMass = 0.0;
     
      while (index < currentTypeTopicCounts.length &&
           currentTypeTopicCounts[index] > 0) {
        currentTopic = currentTypeTopicCounts[index] & topicMask;
        currentValue = currentTypeTopicCounts[index] >> topicBits;
       
        score =
          cachedCoefficients[currentTopic] * currentValue;
        topicTermMass += score;
        topicTermScores[index] = score;
       
        //System.out.println("  " + currentTopic + " = " + currentValue);

        index++;
      }
     
      /* // Debugging, to make sure we're getting the right probabilities
         for (int topic = 0; topic < numTopics; topic++) {
         index = 0;
         int displayCount = 0;
       
         while (index < currentTypeTopicCounts.length &&
         currentTypeTopicCounts[index] > 0) {
         currentTopic = currentTypeTopicCounts[index] & topicMask;
         currentValue = currentTypeTopicCounts[index] >> topicBits;

         if (currentTopic == topic) {
         displayCount = currentValue;
         break;
         }

         index++;
         }
       
         System.out.print(topic + "\t");
         System.out.print("(" + localTopicCounts[topic] + " + " + alpha[topic] + ") / " +
         "(" + alphaSum + " + " + tokensSoFar + ") * ");

         System.out.println("(" + displayCount + " + " + beta + ") / " +
         "(" + tokensPerTopic[topic] + " + " + betaSum + ") =" +
         ((displayCount + beta) / (tokensPerTopic[topic] + betaSum)));


         }
      */
     
      double sample = random.nextUniform() * (smoothingOnlyMass + topicBetaMass + topicTermMass);
      double origSample = sample;
       
      // The sampling masses above drop the normalizer (alphaSum + tokensSoFar)
      //  from the document-topic factor. The true predictive probability
      //  needs that term, so we divide it back in here.
      wordProbabilities[limit] +=
        (smoothingOnlyMass + topicBetaMass + topicTermMass) /
        (alphaSum + tokensSoFar);
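      // The quantity just added is the predictive probability of this token,
      //   P(w | w_1 ... w_{limit-1})
      //     = sum_t (alpha[t] + n_{t|d}) * (n_{w|t} + beta)
      //             / ((n_t + betaSum) * (alphaSum + tokensSoFar)),
      //  since the three sampling masses sum the numerator over topics.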

      //System.out.println("normalizer: " + alphaSum + " + " + tokensSoFar);
      tokensSoFar++;

      //  Make sure it actually gets set
      newTopic = -1;
       
      if (sample < topicTermMass) {
         
        i = -1;
        while (sample > 0) {
          i++;
          sample -= topicTermScores[i];
        }
         
        newTopic = currentTypeTopicCounts[i] & topicMask;
      }
      else {
        sample -= topicTermMass;
         
        if (sample < topicBetaMass) {
          //betaTopicCount++;
           
          sample /= beta;
           
          for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
            int topic = localTopicIndex[denseIndex];
             
            sample -= localTopicCounts[topic] /
              (tokensPerTopic[topic] + betaSum);
             
            if (sample <= 0.0) {
              newTopic = topic;
              break;
            }
          }
           
        }
        else {
          //smoothingOnlyCount++;
           
          sample -= topicBetaMass;
           
          sample /= beta;
           
          newTopic = 0;
          sample -= alpha[newTopic] /
            (tokensPerTopic[newTopic] + betaSum);
           
          while (sample > 0.0) {
            newTopic++;
            sample -= alpha[newTopic] /
              (tokensPerTopic[newTopic] + betaSum);
          }
           
        }
         
      }
       
      if (newTopic == -1) {
        System.err.println("sampling error: "+ origSample + " " +
                   sample + " " + smoothingOnlyMass + " " +
                   topicBetaMass + " " + topicTermMass);
        newTopic = numTopics-1; // TODO is this appropriate
      }
       
      // Put that new topic into the counts
      oneDocTopics[limit] = newTopic;
       
      topicBetaMass -= beta * localTopicCounts[newTopic] /
        (tokensPerTopic[newTopic] + betaSum);
       
      localTopicCounts[newTopic]++;
       
      // If this is a new topic for this document,
      //  add the topic to the dense index.
      if (localTopicCounts[newTopic] == 1) {
         
        // First find the point where we
        //  should insert the new topic by going to
        //  the end (which is the only reason we're keeping
        //  track of the number of non-zero
        //  topics) and working backwards
         
        denseIndex = nonZeroTopics;
         
        while (denseIndex > 0 &&
             localTopicIndex[denseIndex - 1] > newTopic) {
           
          localTopicIndex[denseIndex] =
            localTopicIndex[denseIndex - 1];
          denseIndex--;
        }
         
        localTopicIndex[denseIndex] = newTopic;
        nonZeroTopics++;
      }
     
      //  update the coefficients for the non-zero topics
      cachedCoefficients[newTopic] =
        (alpha[newTopic] + localTopicCounts[newTopic]) /
        (tokensPerTopic[newTopic] + betaSum);
       
      topicBetaMass += beta * localTopicCounts[newTopic] /
        (tokensPerTopic[newTopic] + betaSum);
     
      //System.out.println(type + "\t" + newTopic + "\t" + logLikelihood);
     
    }

    //  Clean up our mess: reset the coefficients to values with only
    //  smoothing. The next doc will update its own non-zero topics...

    for (denseIndex = 0; denseIndex < nonZeroTopics; denseIndex++) {
      int topic = localTopicIndex[denseIndex];

      cachedCoefficients[topic] =
        alpha[topic] / (tokensPerTopic[topic] + betaSum);
    }

    return wordProbabilities;

  }

  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 0;
  private static final int NULL_INTEGER = -1;

  private void writeObject (ObjectOutputStream out) throws IOException {
    out.writeInt (CURRENT_SERIAL_VERSION);

    out.writeInt(numTopics);

    out.writeInt(topicMask);
    out.writeInt(topicBits);

    out.writeObject(alpha);
    out.writeDouble(alphaSum);
    out.writeDouble(beta);
    out.writeDouble(betaSum);

    out.writeObject(typeTopicCounts);
    out.writeObject(tokensPerTopic);

    out.writeObject(random);

    out.writeDouble(smoothingOnlyMass);
    out.writeObject(cachedCoefficients);
  }

  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {

    int version = in.readInt ();

    numTopics = in.readInt();

    topicMask = in.readInt();
    topicBits = in.readInt();

    alpha = (double[]) in.readObject();
    alphaSum = in.readDouble();
    beta = in.readDouble();
    betaSum = in.readDouble();

    typeTopicCounts = (int[][]) in.readObject();
    tokensPerTopic = (int[]) in.readObject();

    random = (Randoms) in.readObject();

    smoothingOnlyMass = in.readDouble();
    cachedCoefficients = (double[]) in.readObject();
  }

  public static MarginalProbEstimator read (File f) throws Exception {

    MarginalProbEstimator estimator = null;

    ObjectInputStream ois = new ObjectInputStream (new FileInputStream(f));
    estimator = (MarginalProbEstimator) ois.readObject();
    ois.close();

    return estimator;
  }


}