Package org.apache.mahout.df.mapreduce.partial

Source Code of org.apache.mahout.df.mapreduce.partial.PartitionBugTest$MockTreeBuilder

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.df.mapreduce.partial;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Locale;
import java.util.Random;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.df.builder.TreeBuilder;
import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataLoader;
import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.data.Instance;
import org.apache.mahout.df.data.Utils;
import org.apache.mahout.df.mapreduce.Builder;
import org.apache.mahout.df.node.Node;
import org.junit.Test;

public final class PartitionBugTest extends MahoutTestCase {
  static final int NUM_ATTRIBUTES = 40;

  static final int NUM_INSTANCES = 200;

  static final int NUM_TREES = 10;

  static final int NUM_MAPS = 5;
   
  /**
   * Make sure that the correct instance ids are being computed
   */
  @Test
  public void testProcessOutput() throws Exception {
    Random rng = RandomUtils.getRandom();
    //long seed = rng.nextLong();

    // create a dataset large enough to be split up
    String descriptor = Utils.randomDescriptor(rng, NUM_ATTRIBUTES);
    double[][] source = Utils.randomDoubles(rng, descriptor, NUM_INSTANCES);

    // each instance label is its index in the dataset
    int labelId = Utils.findLabel(descriptor);
    for (int index = 0; index < NUM_INSTANCES; index++) {
      source[index][labelId] = index;
    }

    // store the data into a file
    String[] sData = Utils.double2String(source);
    Path dataPath = Utils.writeDataToTestFile(sData);
    Dataset dataset = DataLoader.generateDataset(descriptor, sData);
    Data data = DataLoader.loadData(dataset, sData);

    Configuration conf = new Configuration();
    Step0JobTest.setMaxSplitSize(conf, dataPath, NUM_MAPS);

    // prepare a custom TreeBuilder that will classify each
    // instance with its own label (in this case its index in the dataset)
    TreeBuilder treeBuilder = new MockTreeBuilder();
   
    // disable the second step because we can test without it
    // and we won't be able to serialize the MockNode
    PartialBuilder.setStep2(conf, false);
    long seed = 1L;
    Builder builder = new PartialSequentialBuilder(treeBuilder, dataPath, dataset, seed, conf);

    // remove the output path (its only used for testing)
    Path outputPath = builder.getOutputPath(conf);
    HadoopUtil.delete(conf, outputPath);

    builder.build(NUM_TREES, new MockCallback(data));
  }

  /**
   * Asserts that the instanceId are correct
   *
   */
  private static class MockCallback implements PredictionCallback {
    private final Data data;

    private MockCallback(Data data) {
      this.data = data;
    }

    @Override
    public void prediction(int treeId, int instanceId, int prediction) {
      // because of the bagging, prediction can be -1
      if (prediction == -1) {
        return;
      }

      assertEquals(String.format(Locale.ENGLISH, "treeId: %d, InstanceId: %d, Prediction: %d",
          treeId, instanceId, prediction), data.get(instanceId).getLabel(), prediction);
    }

  }

  /**
   * Custom Leaf node that returns for each instance its own label
   *
   */
  private static class MockLeaf extends Node {

    @Override
    public int classify(Instance instance) {
      return instance.getLabel();
    }

    @Override
    protected String getString() {
      return "[MockLeaf]";
    }

    @Override
    public long maxDepth() {
      return 0;
    }

    @Override
    protected Type getType() {
      return Type.MOCKLEAF;
    }

    @Override
    public long nbNodes() {
      return 0;
    }

    @Override
    protected void writeNode(DataOutput out) throws IOException {
    }

    @Override
    public void readFields(DataInput in) throws IOException {
    }

   
  }

  private static class MockTreeBuilder implements TreeBuilder {

    @Override
    public Node build(Random rng, Data data) {
      return new MockLeaf();
    }

  }
}
TOP

Related Classes of org.apache.mahout.df.mapreduce.partial.PartitionBugTest$MockTreeBuilder

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.