Package org.apache.mahout.utils

Source Code of org.apache.mahout.utils.SplitInputJob

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.utils;

import java.io.IOException;
import java.io.Serializable;
import java.util.Random;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.MultipleOutputs;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;

/**
* Class which implements a map-reduce version of SplitInput.
* This class takes a SequenceFile input, e.g. a set of training data
* for a learning algorithm, downsamples it, applies a random
* permutation and splits it into test and training sets.
*/
public final class SplitInputJob {

  private static final String DOWNSAMPLING_FACTOR = "SplitInputJob.downsamplingFactor";
  private static final String RANDOM_SELECTION_PCT = "SplitInputJob.randomSelectionPct";
  private static final String TRAINING_TAG = "training";
  private static final String TEST_TAG = "test";

  private SplitInputJob() {}

  /**
   * Run job to downsample, randomly permute and split data into test and
   * training sets. This job takes a SequenceFile as input and outputs two
   * SequenceFiles, test-r-00000 and training-r-00000, which contain the test
   * and training sets respectively
   *
   * @param initialConf
   *          Hadoop configuration used as the basis for the job configuration
   * @param inputPath
   *          path to input data SequenceFile
   * @param outputPath
   *          path for output data SequenceFiles
   * @param keepPct
   *          percentage of key value pairs in the input to keep. The rest are
   *          discarded
   * @param randomSelectionPercent
   *          percentage of key value pairs to allocate to the test set. The
   *          remainder is allocated to the training set
   */
  @SuppressWarnings("rawtypes")
  public static void run(Configuration initialConf, Path inputPath,
      Path outputPath, int keepPct, float randomSelectionPercent)
    throws IOException, ClassNotFoundException, InterruptedException {

    // keepPct percent of the input is kept, i.e. roughly one record out of every (100 / keepPct)
    int downsamplingFactor = (int) (100.0 / keepPct);
    initialConf.setInt(DOWNSAMPLING_FACTOR, downsamplingFactor);
    initialConf.setFloat(RANDOM_SELECTION_PCT, randomSelectionPercent);

    // Determine class of keys and values
    FileSystem fs = FileSystem.get(initialConf);

    SequenceFileDirIterator<? extends WritableComparable, Writable> iterator =
        new SequenceFileDirIterator<WritableComparable, Writable>(inputPath,
            PathType.LIST, PathFilters.partFilter(), null, false, fs.getConf());
    Class<? extends WritableComparable> keyClass;
    Class<? extends Writable> valueClass;
    if (iterator.hasNext()) {
      Pair<? extends WritableComparable, Writable> pair = iterator.next();
      keyClass = pair.getFirst().getClass();
      valueClass = pair.getSecond().getClass();
    } else {
      throw new IllegalStateException("Couldn't determine class of the input values");
    }

    Job job = new Job(new Configuration(initialConf));

    MultipleOutputs.addNamedOutput(job, TRAINING_TAG, SequenceFileOutputFormat.class, keyClass, valueClass);
    MultipleOutputs.addNamedOutput(job, TEST_TAG, SequenceFileOutputFormat.class, keyClass, valueClass);
    job.setJarByClass(SplitInputJob.class);
    FileInputFormat.addInputPath(job, inputPath);
    FileOutputFormat.setOutputPath(job, outputPath);
    job.setNumReduceTasks(1);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setMapperClass(SplitInputMapper.class);
    job.setReducerClass(SplitInputReducer.class);
    job.setSortComparatorClass(SplitInputComparator.class);
    job.setOutputKeyClass(keyClass);
    job.setOutputValueClass(valueClass);
    job.submit();
    boolean succeeded = job.waitForCompletion(true);
    if (!succeeded) {
      throw new IllegalStateException("Job failed!");
    }
  }

  /** Mapper which downsamples the input by downsamplingFactor */
  public static class SplitInputMapper extends
      Mapper<WritableComparable<?>, Writable, WritableComparable<?>, Writable> {

    private int downsamplingFactor;

    @Override
    public void setup(Context ctx) {
      downsamplingFactor = ctx.getConfiguration().getInt(DOWNSAMPLING_FACTOR, 1);
    }

    /** Only run map() for one out of every downsamplingFactor inputs */
    @Override
    public void run(Context context) throws IOException, InterruptedException {
      setup(context);
      int i = 0;
      while (context.nextKeyValue()) {
        if (i % downsamplingFactor == 0) {
          map(context.getCurrentKey(), context.getCurrentValue(), context);
        }
        i++;
      }
      cleanup(context);
    }

  }

  /** Reducer which uses MultipleOutputs to randomly allocate key value pairs between test and training outputs */
  public static class SplitInputReducer extends
      Reducer<WritableComparable<?>, Writable, WritableComparable<?>, Writable> {

    private MultipleOutputs multipleOutputs;
    private final Random rnd = RandomUtils.getRandom();
    private float randomSelectionPercent;

    @Override
    protected void setup(Context ctx) throws IOException {
      randomSelectionPercent = ctx.getConfiguration().getFloat(RANDOM_SELECTION_PCT, 0);
      multipleOutputs = new MultipleOutputs(ctx);
    }

    /**
     * Randomly allocate key value pairs between test and training sets.
     * randomSelectionPercent of the pairs will go to the test set.
     */
    @Override
    protected void reduce(WritableComparable<?> key, Iterable<Writable> values,
        Context context) throws IOException, InterruptedException {
      for (Writable value : values) {
        if (rnd.nextInt(100) < randomSelectionPercent) {
          multipleOutputs.write(TEST_TAG, key, value);
        } else {
          multipleOutputs.write(TRAINING_TAG, key, value);
        }
      }

    }

    @Override
    protected void cleanup(Context context) throws IOException {
      try {
        multipleOutputs.close();
      } catch (InterruptedException e) {
        throw new IOException(e);
      }
    }

  }

  /** Comparator which imposes a random ordering, so that the sort phase effectively permutes the key value pairs */
  public static class SplitInputComparator extends WritableComparator implements Serializable {

    private final Random rnd = RandomUtils.getRandom();

    protected SplitInputComparator() {
      super(WritableComparable.class);
    }

    @Override
    public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
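      // Ignore the actual bytes and flip a coin, so the merge sort produces an essentially random key order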
      if (rnd.nextBoolean()) {
        return 1;
      } else {
        return -1;
      }
    }
  }

}
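
For context, a minimal driver sketch showing how this job might be invoked from application code. Only SplitInputJob.run and the Hadoop Configuration/Path classes come from the source above; the class name SplitInputDriver, the argument handling and the percentage values are hypothetical examples.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.utils.SplitInputJob;

public class SplitInputDriver {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    Path input = new Path(args[0]);   // SequenceFile(s) of <WritableComparable, Writable> pairs
    Path output = new Path(args[1]);  // directory receiving training-r-00000 and test-r-00000
    int keepPct = 50;                 // keep roughly every second input record
    float testPct = 20.0f;            // ~20% of the kept records go to the "test" output
    SplitInputJob.run(conf, input, output, keepPct, testPct);
  }
}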