Package cc.mrlda

Source Code of cc.mrlda.TermReducer

package cc.mrlda;

import java.io.IOException;
import java.util.Iterator;
import java.util.Set;

import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapred.lib.MultipleOutputs;

import cc.mrlda.VariationalInference.ParameterCounter;

import com.google.common.base.Preconditions;

import edu.umd.cloud9.io.map.HMapIDW;
import edu.umd.cloud9.io.pair.PairOfIntFloat;
import edu.umd.cloud9.io.pair.PairOfInts;
import edu.umd.cloud9.math.Gamma;
import edu.umd.cloud9.math.LogMath;
import edu.umd.cloud9.util.map.HMapIV;

public class TermReducer extends MapReduceBase implements
    Reducer<PairOfInts, DoubleWritable, IntWritable, DoubleWritable> {
  // boolean approximateBeta = false;
  // boolean truncateBeta = false;
  // int truncationSize = 10000;
  // TreeMap<Double, Integer> treeMap = new TreeMap<Double, Integer>();
  // Iterator<Entry<Double, Integer>> itr = null;

  private static HMapIV<Set<Integer>> lambdaMap = null;

  private static boolean learning = Settings.LEARNING_MODE;
  // private static int numberOfTerms = 0;

  private int topicIndex = 0;
  private double logNormalizeFactor = 0;

  private MultipleOutputs multipleOutputs;
  private OutputCollector<PairOfIntFloat, HMapIDW> outputBeta;
  // private OutputCollector<PairOfIntFloat, ProbDist> outputBeta;
  // private OutputCollector<PairOfIntFloat, HashMap> outputBeta;
  // private OutputCollector<PairOfIntFloat, BloomMap> outputBeta;

  private IntWritable intWritable = new IntWritable();
  private DoubleWritable doubleWritable = new DoubleWritable();

  private PairOfIntFloat outputKey = new PairOfIntFloat();

  // private HashMap outputValue = null;
  // private BloomMap outputValue = null;
  // private ProbDist outputValue = null;
  private HMapIDW outputValue = null;

  public void configure(JobConf conf) {
    multipleOutputs = new MultipleOutputs(conf);

    learning = conf.getBoolean(Settings.PROPERTY_PREFIX + "model.train", Settings.LEARNING_MODE);

    // truncateBeta = conf.getBoolean(Settings.PROPERTY_PREFIX + "model.truncate.beta", false);

    // outputValue = new HashMap();
    outputValue = new HMapIDW();

    // approximateBeta = conf.getBoolean(Settings.PROPERTY_PREFIX + "model.truncate.beta", false);
    // if (!approximateBeta) {
    // outputValue = new HashMap();
    // } else {
    // outputValue = new BloomMap(numberOfTerms * 100, 3, Hash.JENKINS_HASH);
    // }

    boolean informedPrior = conf.getBoolean(Settings.PROPERTY_PREFIX + "model.informed.prior",
        false);

    Path[] inputFiles;
    SequenceFile.Reader sequenceFileReader = null;

    try {
      inputFiles = DistributedCache.getLocalCacheFiles(conf);

      if (inputFiles != null) {
        for (Path path : inputFiles) {
          try {
            sequenceFileReader = new SequenceFile.Reader(FileSystem.getLocal(conf), path, conf);

            if (path.getName().startsWith(Settings.ALPHA)) {
              continue;
            } else if (path.getName().startsWith(Settings.BETA)) {
              continue;
            } else if (path.getName().startsWith(InformedPrior.ETA)) {
              Preconditions.checkArgument(lambdaMap == null,
                  "Lambda matrix was initialized already...");
              lambdaMap = InformedPrior.importEta(sequenceFileReader);
            } else {
              throw new IllegalArgumentException("Unexpected file in distributed cache: "
                  + path.getName());
            }
          } catch (IllegalArgumentException iae) {
            iae.printStackTrace();
          } catch (IOException ioe) {
            ioe.printStackTrace();
          } finally {
            IOUtils.closeStream(sequenceFileReader);
          }
        }
      }
    } catch (IOException ioe) {
      ioe.printStackTrace();
    }

    Preconditions.checkArgument(informedPrior == (lambdaMap != null),
        "Fail to initialize informed prior...");

    // System.out.println("======================================================================");
    // System.out.println("Available processors (cores): " +
    // Runtime.getRuntime().availableProcessors());
    // long maxMemory = Runtime.getRuntime().maxMemory();
    // System.out.println("Maximum memory (bytes): " + (maxMemory == Long.MAX_VALUE ? "no limit" :
    // maxMemory));
    // System.out.println("Free memory (bytes): " + Runtime.getRuntime().freeMemory());
    // System.out.println("Total memory (bytes): " + Runtime.getRuntime().totalMemory());
    // System.out.println("======================================================================");
  }

  public void reduce(PairOfInts key, Iterator<DoubleWritable> values,
      OutputCollector<IntWritable, DoubleWritable> output, Reporter reporter) throws IOException {
    // if this value is the sufficient statistics for alpha updating
    if (key.getLeftElement() == 0) {
      double sum = values.next().get();
      while (values.hasNext()) {
        sum += values.next().get();
      }

      Preconditions.checkArgument(key.getRightElement() > 0,
          "Unexpected sequence order for alpha sufficient statistics: " + key.toString());

      intWritable.set(key.getRightElement());
      doubleWritable.set(sum);
      output.collect(intWritable, doubleWritable);

      return;
    }

    // I would be very surprised to get here...
    Preconditions.checkArgument(learning, "Invalid key from Mapper");
    reporter.incrCounter(ParameterCounter.TOTAL_TERMS, 1);

    double logPhiValue = values.next().get();
    while (values.hasNext()) {
      logPhiValue = LogMath.add(logPhiValue, values.next().get());
    }

    if (lambdaMap != null) {
      logPhiValue = LogMath.add(
          InformedPrior.getLogEta(key.getRightElement(), lambdaMap.get(topicIndex)), logPhiValue);
    } else {
      logPhiValue = LogMath.add(Settings.DEFAULT_LOG_ETA, logPhiValue);
    }

    if (topicIndex != key.getLeftElement()) {
      if (topicIndex == 0) {
        outputBeta = multipleOutputs.getCollector(Settings.BETA, Settings.BETA, reporter);
      } else {
        outputKey.set(topicIndex, (float) Gamma.digamma(Math.exp(logNormalizeFactor)));

        // if (truncateBeta) {
        // itr = treeMap.entrySet().iterator();
        // Entry<Double, Integer> temp = null;
        // outputValue.clear();
        // while (itr.hasNext()) {
        // temp = itr.next();
        // outputValue.put(temp.getValue(), temp.getKey());
        // }
        // }

        outputBeta.collect(outputKey, outputValue);
      }

      topicIndex = key.getLeftElement();
      logNormalizeFactor = logPhiValue;
      // if (truncateBeta) {
      // treeMap.clear();
      // treeMap.put(phiValue, key.getRightElement());
      // } else {
      outputValue.clear();
      outputValue.put(key.getRightElement(), Gamma.digamma(Math.exp(logPhiValue)));
      // }
    } else {
      // if (truncateBeta) {
      // if (treeMap.size() >= truncationSize) {
      // if (treeMap.firstKey() < phiValue) {
      // normalizeFactor = Math.log(Math.exp(normalizeFactor) - Math.exp(treeMap.firstKey()));
      // treeMap.remove(treeMap.firstKey());
      //
      // treeMap.put(phiValue, key.getRightElement());
      // normalizeFactor = LogMath.add(normalizeFactor, phiValue);
      // }
      // } else {
      // treeMap.put(phiValue, key.getRightElement());
      // normalizeFactor = LogMath.add(normalizeFactor, phiValue);
      // }
      // } else {
      logNormalizeFactor = LogMath.add(logNormalizeFactor, logPhiValue);
      outputValue.put(key.getRightElement(), Gamma.digamma(Math.exp(logPhiValue)));
      // }
    }
  }

  public void close() throws IOException {
    // if (truncateBeta) {
    // if (!treeMap.isEmpty()) {
    // outputKey.set(topicIndex, (float) normalizeFactor);
    // itr = treeMap.entrySet().iterator();
    // Entry<Double, Integer> temp = null;
    // outputValue.clear();
    // while (itr.hasNext()) {
    // temp = itr.next();
    // outputValue.put(temp.getValue(), temp.getKey().floatValue());
    // }
    // outputBeta.collect(outputKey, outputValue);
    // }
    // } else {
    if (!outputValue.isEmpty()) {
      outputKey.set(topicIndex, (float) Gamma.digamma(Math.exp(logNormalizeFactor)));
      outputBeta.collect(outputKey, outputValue);
    }
    // }
    multipleOutputs.close();
  }
}
TOP

Related Classes of cc.mrlda.TermReducer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Corporation. Contact coftware#gmail.com.