Package com.cloudera.oryx.ml.mllib.als

Source Code of com.cloudera.oryx.ml.mllib.als.HyperParamTuningIT

/*
* Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
*
* Cloudera, Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"). You may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for
* the specific language governing permissions and limitations under the
* License.
*/

package com.cloudera.oryx.ml.mllib.als;

import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.typesafe.config.Config;
import org.dmg.pmml.Extension;
import org.dmg.pmml.PMML;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.cloudera.oryx.common.io.IOUtils;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.ml.MLUpdate;
import com.cloudera.oryx.common.pmml.PMMLUtils;

public final class HyperParamTuningIT extends AbstractALSIT {

  private static final Logger log = LoggerFactory.getLogger(HyperParamTuningIT.class);

  private static final int DATA_TO_WRITE = 2000;
  private static final int WRITE_INTERVAL_MSEC = 10;
  private static final int GEN_INTERVAL_SEC = 30;
  private static final int BLOCK_INTERVAL_SEC = 1;
  private static final int TEST_FEATURES = 7;
  private static final int TEST_ELEMENTS = 1000;

  @Test
  public void testALS() throws Exception {
    Path tempDir = getTempDir();
    Path dataDir =  tempDir.resolve("data");
    Path modelDir = tempDir.resolve("model");

    Map<String,String> overlayConfig = new HashMap<>();
    overlayConfig.put("batch.update-class", ALSUpdate.class.getName());
    overlayConfig.put("batch.storage.data-dir",
                      "\"" + dataDir.toUri() + "\"");
    overlayConfig.put("batch.storage.model-dir",
                      "\"" + modelDir.toUri() + "\"");
    overlayConfig.put("batch.generation-interval-sec",
                      Integer.toString(GEN_INTERVAL_SEC));
    overlayConfig.put("batch.block-interval-sec",
                      Integer.toString(BLOCK_INTERVAL_SEC));
    // Choose pairs of values where the best is predictable
    overlayConfig.put("als.implicit", "true");
    overlayConfig.put("als.hyperparams.features", "[1," + TEST_FEATURES + "]");
    overlayConfig.put("als.hyperparams.lambda", "0.001");
    overlayConfig.put("als.hyperparams.alpha", "1.0");
    overlayConfig.put("als.no-known-items", "false");
    overlayConfig.put("ml.eval.candidates", "2");
    overlayConfig.put("ml.eval.parallelism", "2");
    Config config = ConfigUtils.overlayOn(overlayConfig, getConfig());

    startMessageQueue();

    startServerProduceConsumeQueues(config,
                                    new FeaturesALSDataGenerator(TEST_ELEMENTS,
                                                                 TEST_ELEMENTS,
                                                                 TEST_FEATURES),
                                    DATA_TO_WRITE,
                                    WRITE_INTERVAL_MSEC);

    List<Path> modelInstanceDirs = IOUtils.listFiles(modelDir, "*");
    log.info("Model instance dirs: {}", modelInstanceDirs);
    assertFalse("No models?", modelInstanceDirs.isEmpty());

    checkIntervals(modelInstanceDirs.size(), DATA_TO_WRITE, WRITE_INTERVAL_MSEC, GEN_INTERVAL_SEC);

    Path modelFile = modelInstanceDirs.get(0).resolve(MLUpdate.MODEL_FILE_NAME);
    assertTrue("No such model file: " + modelFile, Files.exists(modelFile));

    PMML pmml = PMMLUtils.read(modelFile);
    List<Extension> extensions = pmml.getExtensions();
    assertEquals(8, extensions.size());
    assertNotNull(PMMLUtils.getExtensionValue(pmml, "X"));
    assertNotNull(PMMLUtils.getExtensionValue(pmml, "Y"));
    assertTrue(Boolean.parseBoolean(PMMLUtils.getExtensionValue(pmml, "implicit")));
    assertEquals(0.001, Double.parseDouble(PMMLUtils.getExtensionValue(pmml, "lambda")));
    assertEquals(1.0, Double.parseDouble(PMMLUtils.getExtensionValue(pmml, "alpha")));
    assertEquals(TEST_FEATURES, Integer.parseInt(PMMLUtils.getExtensionValue(pmml, "features")));
  }

}
TOP

Related Classes of com.cloudera.oryx.ml.mllib.als.HyperParamTuningIT

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.