Package com.cloudera.oryx.ml.mllib.als

Source Code of com.cloudera.oryx.ml.mllib.als.ALSUpdateIT

/*
* Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
*
* Cloudera, Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"). You may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for
* the specific language governing permissions and limitations under the
* License.
*/

package com.cloudera.oryx.ml.mllib.als;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.typesafe.config.Config;
import org.dmg.pmml.Extension;
import org.dmg.pmml.PMML;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.io.IOUtils;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.ml.MLUpdate;
import com.cloudera.oryx.common.pmml.PMMLUtils;

public final class ALSUpdateIT extends AbstractALSIT {

  private static final Logger log = LoggerFactory.getLogger(ALSUpdateIT.class);
  private static final ObjectMapper MAPPER = new ObjectMapper();

  private static final int DATA_TO_WRITE = 2000;
  private static final int WRITE_INTERVAL_MSEC = 10;
  private static final int GEN_INTERVAL_SEC = 10;
  private static final int BLOCK_INTERVAL_SEC = 1;
  private static final int FEATURES = 4;
  private static final double LAMBDA = 0.001;
  private static final int NUM_USERS_ITEMS = 1000;

  @Test
  public void testALS() throws Exception {
    Path tempDir = getTempDir();
    Path dataDir =  tempDir.resolve("data");
    Path modelDir = tempDir.resolve("model");

    Map<String,String> overlayConfig = new HashMap<>();
    overlayConfig.put("batch.update-class", ALSUpdate.class.getName());
    overlayConfig.put("batch.storage.data-dir",
                      "\"" + dataDir.toUri() + "\"");
    overlayConfig.put("batch.storage.model-dir",
                      "\"" + modelDir.toUri() + "\"");
    overlayConfig.put("batch.generation-interval-sec",
                      Integer.toString(GEN_INTERVAL_SEC));
    overlayConfig.put("batch.block-interval-sec",
                      Integer.toString(BLOCK_INTERVAL_SEC));
    overlayConfig.put("als.implicit", "false");
    overlayConfig.put("als.hyperparams.lambda", Double.toString(LAMBDA));
    overlayConfig.put("als.hyperparams.features", Integer.toString(FEATURES));
    overlayConfig.put("als.no-known-items", "false");
    Config config = ConfigUtils.overlayOn(overlayConfig, getConfig());

    startMessageQueue();

    List<Pair<String,String>> updates = startServerProduceConsumeQueues(
        config,
        new RandomALSDataGenerator(NUM_USERS_ITEMS, NUM_USERS_ITEMS, 1, 5),
        DATA_TO_WRITE,
        WRITE_INTERVAL_MSEC);

    List<Path> modelInstanceDirs = IOUtils.listFiles(modelDir, "*");
    log.info("Model instance dirs: {}", modelInstanceDirs);

    int generations = modelInstanceDirs.size();
    checkIntervals(generations, DATA_TO_WRITE, WRITE_INTERVAL_MSEC, GEN_INTERVAL_SEC);

    List<Collection<Integer>> userIDs = new ArrayList<>();
    userIDs.add(Collections.<Integer>emptySet());
    List<Collection<Integer>> productIDs = new ArrayList<>();
    productIDs.add(Collections.<Integer>emptySet());

    for (Path modelInstanceDir : modelInstanceDirs) {
      log.info("Testing model instance dir {}", modelInstanceDir);
      Path modelFile = modelInstanceDir.resolve(MLUpdate.MODEL_FILE_NAME);
      assertTrue("Model file should exist: " + modelFile, Files.exists(modelFile));
      assertTrue("Model file should not be empty: " + modelFile, Files.size(modelFile) > 0);
      Path xDir = modelInstanceDir.resolve("X");
      assertTrue(Files.exists(xDir));
      userIDs.add(checkFeatures(xDir, userIDs.get(userIDs.size() - 1)));
      Path yDir = modelInstanceDir.resolve("Y");
      assertTrue(Files.exists(yDir));
      productIDs.add(checkFeatures(yDir, productIDs.get(productIDs.size() - 1)));
    }
    userIDs.remove(0);
    productIDs.remove(0);

    Collection<Integer> expectedUsers = null;
    Collection<Integer> expectedProducts = null;
    Collection<Integer> seenUsers = null;
    Collection<Integer> seenProducts = null;
    Collection<Integer> lastModelUsers = null;
    Collection<Integer> lastModelProducts = null;
    int whichGeneration = -1;
    for (Pair<String,String> km : updates) {

      String type = km.getFirst();
      String value = km.getSecond();

      log.debug("{} = {}", type, value);

      boolean isModel = "MODEL".equals(type);
      boolean isUpdate = "UP".equals(type);
      assertTrue(isModel || isUpdate);

      if (isUpdate) {

        List<?> update = MAPPER.readValue(value, List.class);
        // First field is X or Y, depending on whether it's a user or item vector
        boolean isUser = "X".equals(update.get(0).toString());
        boolean isProduct = "Y".equals(update.get(0).toString());
        // Next is user/item ID
        Integer id = Integer.valueOf(update.get(1).toString());
        assertTrue(isUser || isProduct);
        if (isUser) {
          seenUsers.add(id);
        } else {
          seenProducts.add(id);
        }
        // Verify that feature vector are valid floats
        for (float f : MAPPER.convertValue(update.get(2), float[].class)) {
          assertTrue(!Float.isNaN(f) && !Float.isInfinite(f));
        }

        if (isUser) {
          // Only known-items for users exist now, not known users for items
          @SuppressWarnings("unchecked")
          Collection<String> knownUsersItems = (Collection<String>) update.get(3);
          assertFalse(knownUsersItems.isEmpty());
          for (String known : knownUsersItems) {
            int i = Integer.parseInt(known);
            assertTrue(i >= 0 && i < NUM_USERS_ITEMS);
          }
        }

      } else {

        PMML pmml = PMMLUtils.fromString(value);
        List<Extension> extensions = pmml.getExtensions();
        assertEquals(7, extensions.size());
        // Basic hyperparameters should match
        assertEquals(Integer.toString(FEATURES), PMMLUtils.getExtensionValue(pmml, "features"));
        assertEquals(Double.toString(LAMBDA),PMMLUtils.getExtensionValue(pmml, "lambda"));
        assertEquals("false", PMMLUtils.getExtensionValue(pmml, "implicit"));

        // See if users/item sets seen in updates match what was expected from output
        assertEquals(expectedUsers, seenUsers);
        assertEquals(expectedProducts, seenProducts);

        // Also check key sets reported in model
        assertEquals(expectedUsers, lastModelUsers);
        assertEquals(expectedProducts, lastModelProducts);

        // Update for next round
        whichGeneration++;
        expectedUsers = userIDs.get(whichGeneration);
        expectedProducts = productIDs.get(whichGeneration);
        seenUsers = new HashSet<>();
        seenProducts = new HashSet<>();
        lastModelUsers = parseIDsFromContent(PMMLUtils.getExtensionContent(pmml, "XIDs"));
        lastModelProducts = parseIDsFromContent(PMMLUtils.getExtensionContent(pmml, "YIDs"));

      }
    }

  }

  private static Collection<Integer> parseIDsFromContent(List<?> content) {
    List<String> values = PMMLUtils.parseArray(content);
    Collection<Integer> result = new HashSet<>(values.size());
    for (String s : values) {
      result.add(Integer.valueOf(s));
    }
    return result;
  }

  private static Collection<Integer> checkFeatures(Path path, Collection<Integer> previousIDs)
      throws IOException {
    Collection<Integer> seenIDs = new HashSet<>();
    for (Path file : IOUtils.listFiles(path, "part-*")) {
      for (String line : IOUtils.readLines(file)) {
        List<?> update = MAPPER.readValue(line, List.class);
        seenIDs.add(Integer.valueOf(update.get(0).toString()));
        assertEquals(FEATURES, MAPPER.convertValue(update.get(1), float[].class).length);
      }
    }
    assertFalse(seenIDs.isEmpty());
    assertTrue(seenIDs.containsAll(previousIDs));
    return seenIDs;
  }

}
TOP

Related Classes of com.cloudera.oryx.ml.mllib.als.ALSUpdateIT

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.