Package com.cloudera.oryx.ml

Source Code of com.cloudera.oryx.ml.SimpleMLUpdateIT

/*
* Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
*
* Cloudera, Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"). You may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for
* the specific language governing permissions and limitations under the
* License.
*/

package com.cloudera.oryx.ml;

import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.typesafe.config.Config;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.cloudera.oryx.common.random.RandomManager;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.lambda.AbstractBatchIT;

/**
* Tests {@link MLUpdate}.
*/
public final class SimpleMLUpdateIT extends AbstractBatchIT {

  private static final Logger log = LoggerFactory.getLogger(SimpleMLUpdateIT.class);

  private static final int DATA_TO_WRITE = 2400;
  private static final int WRITE_INTERVAL_MSEC = 25;
  private static final int GEN_INTERVAL_SEC = 20;
  private static final int BLOCK_INTERVAL_SEC = 1;
  private static final double TEST_FRACTION = 0.2;

  @Test
  public void testMLUpdate() throws Exception {
    Path tempDir = getTempDir();
    Path dataDir = tempDir.resolve("data");
    Path modelDir = tempDir.resolve("model");
    Map<String,String> overlayConfig = new HashMap<>();
    overlayConfig.put("batch.update-class", MockMLUpdate.class.getName());
    overlayConfig.put("batch.storage.data-dir",
                      "\"" + dataDir.toUri() + "\"");
    overlayConfig.put("batch.storage.model-dir",
                      "\"" + modelDir.toUri() + "\"");
    overlayConfig.put("batch.generation-interval-sec",
                      Integer.toString(GEN_INTERVAL_SEC));
    overlayConfig.put("batch.block-interval-sec",
                      Integer.toString(BLOCK_INTERVAL_SEC));
    overlayConfig.put("ml.eval.test-fraction", Double.toString(TEST_FRACTION));
    Config config = ConfigUtils.overlayOn(overlayConfig, getConfig());

    startMessageQueue();

    List<Integer> trainCounts = new ArrayList<>();
    List<Integer> testCounts = new ArrayList<>();

    MockMLUpdate.setCountHolders(trainCounts, testCounts);

    startServerProduceConsumeQueues(config, DATA_TO_WRITE, WRITE_INTERVAL_MSEC);

    // If lists are unequal at this point, there must have been an empty test set
    // which yielded no call to evaluate(). Fill in the blank
    while (trainCounts.size() > testCounts.size()) {
      testCounts.add(0);
    }

    log.info("trainCounts = {}", trainCounts);
    log.info("testCounts = {}", testCounts);

    checkOutputData(dataDir, DATA_TO_WRITE);
    checkIntervals(trainCounts.size(), DATA_TO_WRITE, WRITE_INTERVAL_MSEC, GEN_INTERVAL_SEC);

    assertEquals(testCounts.size(), trainCounts.size());

    RandomGenerator random = RandomManager.getRandom();
    int lastTotalTrainCount = 0;
    int lastTestCount = 0;
    for (int i = 0; i < testCounts.size(); i++) {
      int totalTrainCount = trainCounts.get(i);
      int testCount = testCounts.get(i);
      int newTrainInGen = totalTrainCount - (lastTotalTrainCount + lastTestCount);
      if (newTrainInGen == 0) {
        continue;
      }
      lastTotalTrainCount = totalTrainCount;
      lastTestCount = testCount;
      int totalNew = testCount + newTrainInGen;

      IntegerDistribution dist = new BinomialDistribution(random, totalNew, TEST_FRACTION);
      double probability;
      if (testCount < dist.getNumericalMean()) {
        probability = dist.cumulativeProbability(testCount);
      } else {
        probability = 1.0 - dist.cumulativeProbability(testCount);
      }
      log.info("Probability of observing {} as {} sample of {}: {}",
               testCount, TEST_FRACTION, totalNew, probability);
      assertTrue(probability >= 0.001);
    }

  }

}
TOP

Related Classes of com.cloudera.oryx.ml.SimpleMLUpdateIT

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark originally of Sun Microsystems, Inc., now owned by Oracle Corporation. Contact coftware#gmail.com.