Package com.cloudera.oryx.common.pmml

Source Code of com.cloudera.oryx.common.pmml.PMMLUtilsTest

/*
* Copyright (c) 2014, Cloudera, Inc. All Rights Reserved.
*
* Cloudera, Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"). You may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for
* the specific language governing permissions and limitations under the
* License.
*/

package com.cloudera.oryx.common.pmml;

import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.Model;
import org.dmg.pmml.Node;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TreeModel;
import org.junit.Test;

import com.cloudera.oryx.common.OryxTest;

public final class PMMLUtilsTest extends OryxTest {

  public static PMML buildDummyModel() {
    Node node = new Node();
    node.setRecordCount(123.0);
    TreeModel treeModel = new TreeModel(null, node, MiningFunctionType.CLASSIFICATION);
    PMML pmml = PMMLUtils.buildSkeletonPMML();
    pmml.getModels().add(treeModel);
    return pmml;
  }

  @Test
  public void testSkeleton() {
    PMML pmml = PMMLUtils.buildSkeletonPMML();
    assertEquals("Oryx", pmml.getHeader().getApplication().getName());
    assertNotNull(pmml.getHeader().getTimestamp());
  }

  @Test
  public void testReadWrite() throws Exception {
    Path tempModelFile = Files.createTempFile(getTempDir(), "model", ".pmml.gz");
    PMML model = buildDummyModel();
    PMMLUtils.write(model, tempModelFile);
    assertTrue(Files.exists(tempModelFile));
    PMML model2 = PMMLUtils.read(tempModelFile);
    List<Model> models = model2.getModels();
    assertEquals(1, models.size());
    assertTrue(models.get(0) instanceof TreeModel);
    TreeModel treeModel = (TreeModel) models.get(0);
    assertEquals(123.0, treeModel.getNode().getRecordCount().doubleValue());
    assertEquals(MiningFunctionType.CLASSIFICATION, treeModel.getFunctionName());
  }

  @Test
  public void testExtensionValue() {
    PMML model = buildDummyModel();
    assertNull(PMMLUtils.getExtensionValue(model, "foo"));
    PMMLUtils.addExtension(model, "foo", "bar");
    assertEquals("bar", PMMLUtils.getExtensionValue(model, "foo"));
  }

  @Test
  public void testExtensionContent() {
    PMML model = buildDummyModel();
    assertNull(PMMLUtils.getExtensionContent(model, "foo"));
    PMMLUtils.addExtensionContent(model, "foo", Arrays.asList("bar", "baz"));
    assertEquals(Arrays.<Object>asList("bar", "baz"), PMMLUtils.getExtensionContent(model, "foo"));
    PMMLUtils.addExtensionContent(model, "foo", Collections.emptyList());
    assertEquals(Arrays.<Object>asList("bar", "baz"), PMMLUtils.getExtensionContent(model, "foo"));
  }

  @Test
  public void testParseArray() {
    assertEquals(Arrays.asList("foo", "bar", "baz"),
        PMMLUtils.parseArray(Collections.singletonList("foo bar baz")));
    assertTrue(PMMLUtils.parseArray(Collections.singletonList("")).isEmpty());
  }

  @Test
  public void testToString() throws Exception {
    PMML model = buildDummyModel();
    model.getHeader().setTimestamp(null);
    assertEquals("<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"yes\"?>\n" +
                 "<PMML version=\"4.2.1\" xmlns=\"http://www.dmg.org/PMML-4_2\">\n" +
                 "    <Header>\n" +
                 "        <Application name=\"Oryx\"/>\n" +
                 "    </Header>\n" +
                 "    <TreeModel functionName=\"classification\">\n" +
                 "        <Node recordCount=\"123.0\"/>\n" +
                 "    </TreeModel>\n" +
                 "</PMML>\n",
                 PMMLUtils.toString(model));
  }

  @Test
  public void testFromString() throws Exception {
    PMML model = buildDummyModel();
    PMML model2 = PMMLUtils.fromString(PMMLUtils.toString(model));
    assertEquals(model.getHeader().getApplication().getName(),
                 model2.getHeader().getApplication().getName());
    assertEquals(model.getModels().get(0).getFunctionName(),
                 model2.getModels().get(0).getFunctionName());
  }

}
TOP

Related Classes of com.cloudera.oryx.common.pmml.PMMLUtilsTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.