Package org.apache.mahout.clustering.lda

Source Code of org.apache.mahout.clustering.lda.LDAWordTopicMapper

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.clustering.lda;

import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/**
* Runs inference on the input documents (which are sparse vectors of word counts) and outputs the sufficient
* statistics for the word-topic assignments.
*/
public class LDAWordTopicMapper extends Mapper<WritableComparable<?>,VectorWritable,IntPairWritable,DoubleWritable> {
 
  private LDAState state;
  private LDAInference infer;
 
  @Override
  protected void map(WritableComparable<?> key,
                     VectorWritable wordCountsWritable,
                     Context context) throws IOException, InterruptedException {
    Vector wordCounts = wordCountsWritable.get();
    LDAInference.InferredDocument doc;
    try {
      doc = infer.infer(wordCounts);
    } catch (ArrayIndexOutOfBoundsException e1) {
      throw new IllegalStateException(
         "This is probably because the --numWords argument is set too small.  \n"
         + "\tIt needs to be >= than the number of words (terms actually) in the corpus and can be \n"
         + "\tlarger if some storage inefficiency can be tolerated.", e1);
    }
   
    double[] logTotals = new double[state.getNumTopics()];
    Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
   
    // Output sufficient statistics for each word. == pseudo-log counts.
    DoubleWritable v = new DoubleWritable();
    for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
      Vector.Element e = iter.next();
      int w = e.index();
     
      for (int k = 0; k < state.getNumTopics(); ++k) {
        v.set(doc.phi(k, w) + Math.log(e.get()));
       
        IntPairWritable kw = new IntPairWritable(k, w);
       
        // output (topic, word)'s logProb contribution
        context.write(kw, v);
        logTotals[k] = LDAUtil.logSum(logTotals[k], v.get());
      }
    }

    // Output the totals for the statistics. This is to make
    // normalizing a lot easier.
    for (int k = 0; k < state.getNumTopics(); ++k) {
      IntPairWritable kw = new IntPairWritable(k, LDADriver.TOPIC_SUM_KEY);
      v.set(logTotals[k]);
      assert !Double.isNaN(v.get());
      context.write(kw, v);
    }
    IntPairWritable llk = new IntPairWritable(LDADriver.LOG_LIKELIHOOD_KEY, LDADriver.LOG_LIKELIHOOD_KEY);
    // Output log-likelihoods.
    v.set(doc.getLogLikelihood());
    context.write(llk, v);
  }
 
  public void configure(LDAState myState) {
    this.state = myState;
    this.infer = new LDAInference(state);
  }
 
  public void configure(Configuration job) {
    LDAState myState = LDADriver.createState(job);
    configure(myState);
  }
 
  @Override
  protected void setup(Context context) {
    configure(context.getConfiguration());
  }
 
}
TOP

Related Classes of org.apache.mahout.clustering.lda.LDAWordTopicMapper

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.