Package cc.mrlda.polylda

Source Code of cc.mrlda.polylda.DocumentMapper

package cc.mrlda.polylda;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;

import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapred.lib.MultipleOutputs;

import cc.mrlda.InformedPrior;
import cc.mrlda.polylda.VariationalInference.ParameterCounter;

import com.google.common.base.Preconditions;

import edu.umd.cloud9.io.triple.TripleOfInts;
import edu.umd.cloud9.math.Gamma;
import edu.umd.cloud9.util.map.HMapII;
import edu.umd.cloud9.util.map.HMapIV;

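/**
 * Mapper for the document (E-) step of variational inference in polylingual LDA. For each input
 * document it iterates the per-document variational parameters (gamma and, per language, phi),
 * emits phi and alpha sufficient statistics keyed by (language, topic, term) triples, and, unless
 * training with random gamma restarts, writes the updated gamma back out with the document via
 * {@link MultipleOutputs}.
 */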
public class DocumentMapper extends MapReduceBase implements
    Mapper<IntWritable, Document, TripleOfInts, DoubleWritable> {
  private long configurationTime = 0;
  private long trainingTime = 0;

  private static HMapIV<double[]>[] beta = null;

  private static double[] alpha = null;
  private static double likelihoodAlpha = 0;

  private static int numberOfTopics = 0;
  private static int numberOfLanguages = 0;
  private static int[] numberOfTerms = null;

  private static int maximumGammaIteration = Settings.MAXIMUM_LOCAL_ITERATION;

  private static boolean learning = Settings.LEARNING_MODE;
  private static boolean randomStartGamma = Settings.RANDOM_START_GAMMA;

  private TripleOfInts outputKey = new TripleOfInts();
  private DoubleWritable outputValue = new DoubleWritable();

  // boolean seededGamma = false;
  // boolean seededAlpha = false;

  private static MultipleOutputs multipleOutputs;
  private static OutputCollector<IntWritable, Document> outputDocument;

  private double[] tempBeta = null;

  private double[] tempGamma = null;
  private double[] updateGamma = null;

  private double[] phi = null;
  private HashMap<Integer, double[]>[] phiTable = null;

  private Iterator<Integer> itr = null;
  private HMapII hmap = null;

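  /**
   * Reads the model settings from the job configuration and loads the per-language beta matrices
   * and the alpha vector from the distributed cache. Parameters missing from the cache are
   * initialized locally: empty beta maps and a randomly drawn alpha.
   */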
  public void configure(JobConf conf) {
    configurationTime = System.currentTimeMillis();

    numberOfTopics = conf.getInt(Settings.PROPERTY_PREFIX + "model.topics", 0);
    numberOfLanguages = conf.getInt(Settings.PROPERTY_PREFIX + "model.languages", 0);
    beta = new HMapIV[numberOfLanguages];

    numberOfTerms = new int[numberOfLanguages];
    // language index starts from 1
    for (int languageIndex = 0; languageIndex < numberOfLanguages; languageIndex++) {
      numberOfTerms[languageIndex] = conf.getInt(Settings.PROPERTY_PREFIX + "corpus.terms"
          + Settings.DOT + (languageIndex + 1), Integer.MAX_VALUE);
    }
    maximumGammaIteration = conf.getInt(Settings.PROPERTY_PREFIX
        + "model.mapper.converge.iteration", Settings.MAXIMUM_LOCAL_ITERATION);

    learning = conf.getBoolean(Settings.PROPERTY_PREFIX + "model.train", Settings.LEARNING_MODE);
    randomStartGamma = conf.getBoolean(Settings.PROPERTY_PREFIX + "model.random.start",
        Settings.RANDOM_START_GAMMA);

    multipleOutputs = new MultipleOutputs(conf);

    updateGamma = new double[numberOfTopics];
    phiTable = new HashMap[numberOfLanguages];
    for (int languageIndex = 0; languageIndex < numberOfLanguages; languageIndex++) {
      phiTable[languageIndex] = new HashMap<Integer, double[]>();
    }

    double alphaSum = 0;

    Path[] inputFiles;
    SequenceFile.Reader sequenceFileReader = null;
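    // Cached model files are dispatched by file-name prefix: beta files carry the per-language
    // topic-term parameters, alpha files carry the Dirichlet hyperparameters, and informed-prior
    // eta files are currently skipped.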
    try {
      inputFiles = DistributedCache.getLocalCacheFiles(conf);
      if (inputFiles != null) {
        for (Path path : inputFiles) {
          try {
            sequenceFileReader = new SequenceFile.Reader(FileSystem.getLocal(conf), path, conf);

            if (path.getName().startsWith(Settings.BETA)) {
              // Settings.BETA + Settings.LEFT_BRACKET + languageIndex + Settings.RIGHT_BRACKET
              String fileName = path.getName();
              int languageIndex = Integer.parseInt(fileName.substring(
                  fileName.indexOf(Settings.LANGUAGE_INDICATOR)
                      + Settings.LANGUAGE_INDICATOR.length(), fileName.indexOf(Settings.DASH)));

              // TODO: check whether seeded beta is valid, i.e., a true probability distribution
              // language index starts from 1
              Preconditions.checkArgument(beta[languageIndex - 1] == null,
                  "Beta matrix was initialized already...");
              beta[languageIndex - 1] = cc.mrlda.DocumentMapper.importBeta(sequenceFileReader,
                  numberOfTopics, numberOfTerms[languageIndex - 1]);

              // System.out.println(beta[languageIndex - 1] + " ");
            } else if (path.getName().startsWith(Settings.ALPHA)) {
              Preconditions.checkArgument(alpha == null, "Alpha vector was initialized already...");

              // TODO: check the validity of alpha
              alpha = cc.mrlda.VariationalInference.importAlpha(sequenceFileReader, numberOfTopics);
              double sumLnGammaAlpha = 0;
              for (double value : alpha) {
                sumLnGammaAlpha += Gamma.lngamma(value);
                alphaSum += value;
              }
              likelihoodAlpha = Gamma.lngamma(alphaSum) - sumLnGammaAlpha;
            } else if (path.getName().startsWith(InformedPrior.ETA)) {
              // beta = parseEta(sequenceFileReader, numberOfTopics);
              continue;
            } else {
              throw new IllegalArgumentException("Unexpected file in distributed cache: "
                  + path.getName());
            }
          } catch (IllegalArgumentException iae) {
            iae.printStackTrace();
          } catch (IOException ioe) {
            ioe.printStackTrace();
          } finally {
            IOUtils.closeStream(sequenceFileReader);
          }
        }
      }
    } catch (IOException ioe) {
      ioe.printStackTrace();
    }

    for (int languageIndex = 0; languageIndex < numberOfLanguages; languageIndex++) {
      if (beta[languageIndex] == null) {
        beta[languageIndex] = new HMapIV<double[]>();
      }
    }

    if (alpha == null) {
      alpha = new double[numberOfTopics];
      double alphaLnGammaSum = 0;
      for (int i = 0; i < numberOfTopics; i++) {
        alpha[i] = Math.random();
        alphaSum += alpha[i];
        alphaLnGammaSum += Gamma.lngamma(alpha[i]);
      }
      likelihoodAlpha = Gamma.lngamma(alphaSum) - alphaLnGammaSum;
    }

    // System.out.println("======================================================================");
    // System.out.println("Available processors (cores): "
    // + Runtime.getRuntime().availableProcessors());
    // long maxMemory = Runtime.getRuntime().maxMemory();
    // System.out.println("Maximum memory (bytes): "
    // + (maxMemory == Long.MAX_VALUE ? "no limit" : maxMemory));
    // System.out.println("Free memory (bytes): " + Runtime.getRuntime().freeMemory());
    // System.out.println("Total memory (bytes): " + Runtime.getRuntime().totalMemory());
    // System.out.println("======================================================================");

    configurationTime = System.currentTimeMillis() - configurationTime;
  }

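  /**
   * Runs variational inference for a single document: repeats the gamma/phi updates for the
   * configured number of local iterations, charges the document's (negative) log-likelihood to a
   * counter, and, in training mode, emits the phi and alpha sufficient statistics.
   */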
  @SuppressWarnings("deprecation")
  public void map(IntWritable key, Document value,
      OutputCollector<TripleOfInts, DoubleWritable> output, Reporter reporter) throws IOException {
    reporter.incrCounter(ParameterCounter.CONFIG_TIME, configurationTime);
    reporter.incrCounter(ParameterCounter.TOTAL_DOC, 1);
    trainingTime = System.currentTimeMillis();

    double likelihoodPhi = 0;

    // initialize gamma: reuse the document's stored gamma when it matches the topic count and
    // random restarts are disabled; otherwise start from gamma_k = alpha_k + N / K, where N is
    // the document's total word count across languages and K the number of topics
    if (value.getGamma() != null && value.getNumberOfTopics() == numberOfTopics
        && !randomStartGamma) {
      // TODO: set up mechanisms to prevent starting from some irrelevant gamma value
      tempGamma = value.getGamma();
    } else {
      int totalNumberOfWords = 0;
      for (int i : value.getNumberOfWords()) {
        totalNumberOfWords += i;
      }

      tempGamma = new double[numberOfTopics];
      for (int i = 0; i < numberOfTopics; i++) {
        tempGamma[i] = alpha[i] + 1.0f * totalNumberOfWords / numberOfTopics;
      }
    }

    double[] phi = null;

    int gammaUpdateIterationCount = 1;
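    // Fixed-point iteration over the variational parameters. Within each pass, tempGamma
    // temporarily holds digamma(gamma) and updateGamma accumulates the new gamma in log space
    // (seeded with log alpha); updatePhi (from cc.mrlda.DocumentMapper) is expected to fold each
    // term's weighted phi vector into updateGamma and to return that term's contribution to the
    // variational lower bound.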
    do {
      likelihoodPhi = 0;

      for (int i = 0; i < numberOfTopics; i++) {
        tempGamma[i] = Gamma.digamma(tempGamma[i]);
        updateGamma[i] = Math.log(alpha[i]);
      }

      for (int languageIndex = 0; languageIndex < numberOfLanguages; languageIndex++) {
        hmap = value.getContent(languageIndex);
        if (hmap == null) {
          continue;
        }

        itr = value.getContent(languageIndex).keySet().iterator();
        while (itr.hasNext()) {
          int termID = itr.next();
          int termCounts = hmap.get(termID);

          if (phiTable[languageIndex].containsKey(termID)) {
            // reuse existing object
            phi = phiTable[languageIndex].get(termID);
          } else {
            phi = new double[numberOfTopics];
            phiTable[languageIndex].put(termID, phi);
          }

          // acquire the corresponding beta vector for this term
          tempBeta = cc.mrlda.DocumentMapper.retrieveBeta(numberOfTopics, beta[languageIndex],
              termID, numberOfTerms[languageIndex]);

          likelihoodPhi += cc.mrlda.DocumentMapper.updatePhi(numberOfTopics, termCounts, tempBeta,
              tempGamma, phi, updateGamma);
        }
      }

      // send out heart beat message
      if (Math.random() < 0.01) {
        reporter.incrCounter(ParameterCounter.DUMMY_COUNTER, 1);
      }

      for (int i = 0; i < numberOfTopics; i++) {
        tempGamma[i] = Math.exp(updateGamma[i]);
      }

      gammaUpdateIterationCount++;
    } while (gammaUpdateIterationCount < maximumGammaIteration);

    // accumulate the gamma terms of the variational lower bound and combine them with the alpha
    // and phi terms into the document log-likelihood
    double sumGamma = 0;
    double likelihoodGamma = 0;
    for (int i = 0; i < numberOfTopics; i++) {
      sumGamma += tempGamma[i];
      likelihoodGamma += Gamma.lngamma(tempGamma[i]);
    }
    likelihoodGamma -= Gamma.lngamma(sumGamma);
    double documentLogLikelihood = likelihoodAlpha + likelihoodGamma + likelihoodPhi;
    reporter.incrCounter(ParameterCounter.LOG_LIKELIHOOD,
        (long) (-documentLogLikelihood * Settings.DEFAULT_COUNTER_SCALE));

    double digammaSumGamma = Gamma.digamma(sumGamma);

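    // In training mode, emit the sufficient statistics for the M-step: per-language phi values
    // keyed by (language, topic, term), and per-topic alpha statistics keyed by (0, topic, 0).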
    if (learning) {
      for (int languageIndex = 0; languageIndex < numberOfLanguages; languageIndex++) {
        hmap = value.getContent(languageIndex);
        if (hmap == null) {
          continue;
        }

        // emit the phi counts for each language
        itr = hmap.keySet().iterator();
        while (itr.hasNext()) {
          int termID = itr.next();
          phi = phiTable[languageIndex].get(termID);
          for (int i = 0; i < numberOfTopics; i++) {
            // emit phi values.
            outputKey.set(languageIndex + 1, i + 1, termID);
            outputValue.set(phi[i]);
            // System.out.println("phi\t" + outputKey + "\t" + outputValue);
            output.collect(outputKey, outputValue);
          }
        }
      }

      // emit the alpha sufficient statistics
      for (int i = 0; i < numberOfTopics; i++) {
        outputKey.set(0, i + 1, 0);
        // alpha sufficient statistic: digamma(gamma_k) - digamma(sum of gamma)
        outputValue.set(Gamma.digamma(tempGamma[i]) - digammaSumGamma);
        // System.out.println("alpha\t" + outputKey + "\t" + outputValue);
        output.collect(outputKey, outputValue);
      }
    }

    // write the updated gamma back out, embedded in the document (skipped only when training
    // from a random gamma restart)
    if (!learning || !randomStartGamma) {
      outputDocument = multipleOutputs.getCollector(Settings.GAMMA, Settings.GAMMA, reporter);
      value.setGamma(tempGamma);
      outputDocument.collect(key, value);
    }

    trainingTime = System.currentTimeMillis() - trainingTime;
    reporter.incrCounter(ParameterCounter.TRAINING_TIME, trainingTime);
  }

  public void close() throws IOException {
    multipleOutputs.close();
  }
}
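
Example Driver Sketch (not part of the original source)

For context, the sketch below shows roughly how this mapper could be wired into an old-API (JobConf) Hadoop job, using only the configuration keys, distributed-cache prefixes, and named output that appear in the source above. It is a minimal, hypothetical sketch rather than the project's actual VariationalInference driver: the driver class name, the topic and language counts, all file paths and cache-file names, and the omitted reducer are placeholder assumptions.

package cc.mrlda.polylda;

import java.net.URI;

import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.hadoop.mapred.lib.MultipleOutputs;

import edu.umd.cloud9.io.triple.TripleOfInts;

// Hypothetical driver class; not part of the original package.
public class DocumentMapperDriverSketch {
  public static void main(String[] args) throws Exception {
    JobConf conf = new JobConf(DocumentMapperDriverSketch.class);

    // Model settings read back by DocumentMapper.configure(); the concrete values are examples.
    conf.setInt(Settings.PROPERTY_PREFIX + "model.topics", 50);
    conf.setInt(Settings.PROPERTY_PREFIX + "model.languages", 2);
    conf.setInt(Settings.PROPERTY_PREFIX + "corpus.terms" + Settings.DOT + 1, 20000);
    conf.setInt(Settings.PROPERTY_PREFIX + "corpus.terms" + Settings.DOT + 2, 20000);
    conf.setBoolean(Settings.PROPERTY_PREFIX + "model.train", true);

    // Previously estimated parameters travel through the distributed cache; configure() selects
    // them by file-name prefix (Settings.BETA / Settings.ALPHA). The paths below are placeholders.
    DistributedCache.addCacheFile(new URI("/mrlda/model/beta-lang1-00000"), conf);
    DistributedCache.addCacheFile(new URI("/mrlda/model/alpha-00000"), conf);

    conf.setMapperClass(DocumentMapper.class);
    conf.setMapOutputKeyClass(TripleOfInts.class);
    conf.setMapOutputValueClass(DoubleWritable.class);
    // A reducer that aggregates the (language, topic, term) statistics would be configured here.

    conf.setInputFormat(SequenceFileInputFormat.class);
    conf.setOutputFormat(SequenceFileOutputFormat.class);
    FileInputFormat.addInputPath(conf, new Path("/mrlda/corpus"));
    FileOutputFormat.setOutputPath(conf, new Path("/mrlda/temp-output"));

    // Multi named output used by the mapper to write documents with their updated gamma vectors.
    MultipleOutputs.addMultiNamedOutput(conf, Settings.GAMMA, SequenceFileOutputFormat.class,
        IntWritable.class, Document.class);

    JobClient.runJob(conf);
  }
}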