Package quickml.supervised.classifier.decisionTree

Source Code of quickml.supervised.classifier.decisionTree.TreeBuilderTest

package quickml.supervised.classifier.decisionTree;

import org.testng.Assert;
import org.testng.annotations.Test;
import org.testng.internal.annotations.Sets;
import quickml.collections.MapUtils;
import quickml.data.AttributesMap;
import quickml.data.Instance;
import quickml.data.InstanceImpl;
import quickml.data.PredictionMap;
import quickml.supervised.PredictiveModelWithDataBuilder;
import quickml.supervised.classifier.TreeBuilderTestUtils;
import quickml.supervised.classifier.decisionTree.scorers.SplitDiffScorer;
import quickml.supervised.classifier.decisionTree.tree.Node;

import java.io.Serializable;
import java.util.HashMap;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;


public class TreeBuilderTest {
  @Test
  public void simpleBmiTest() throws Exception {
        final List<Instance<AttributesMap>> instances = TreeBuilderTestUtils.getInstances(10000);
    final TreeBuilder tb = new TreeBuilder(new SplitDiffScorer());
    final long startTime = System.currentTimeMillis();
        final Tree tree = tb.buildPredictiveModel(instances);
    final Node node = tree.node;

        TreeBuilderTestUtils.serializeDeserialize(node);

        final int nodeSize = node.size();
    Assert.assertTrue(nodeSize < 400, "Tree size should be less than 400 nodes");
    Assert.assertTrue((System.currentTimeMillis() - startTime) < 20000,"Building this node should take far less than 20 seconds");
  }



    @Test(enabled = false)
  public void multiScorerBmiTest() {
    final Set<Instance<AttributesMap>> instances = Sets.newHashSet();

    for (int x = 0; x < 10000; x++) {
      final double height = (4 * 12) + MapUtils.random.nextInt(3 * 12);
      final double weight = 120 + MapUtils.random.nextInt(110);
            AttributesMap  attributes = AttributesMap.newHashMap() ;
            attributes.put("weight", weight);
            attributes.put("height", height);
      final Instance<AttributesMap> instance = new InstanceImpl<>(attributes, TreeBuilderTestUtils.bmiHealthy(weight, height));
      instances.add(instance);
    }
    {
      final TreeBuilder tb = new TreeBuilder(new SplitDiffScorer());
      final Tree tree = tb.buildPredictiveModel(instances);
      System.out.println("SplitDiffScorer node size: " + tree.node.size());
    }
  }

    @Test
    public void simpleBmiTestSplit() throws Exception {
        final List<Instance<AttributesMap>> instances = TreeBuilderTestUtils.getInstances(10000);
        final PredictiveModelWithDataBuilder<AttributesMap ,Tree> wb = getWrappedUpdatablePredictiveModelBuilder();
        wb.splitNodeThreshold(1);
        final long startTime = System.currentTimeMillis();
        final Tree tree = wb.buildPredictiveModel(instances);

        TreeBuilderTestUtils.serializeDeserialize(tree.node);

        int nodeSize = tree.node.size();
        double nodeMeanDepth = tree.node.meanDepth();
        Assert.assertTrue(nodeSize < 400, "Tree size should be less than 400 nodes");
        Assert.assertTrue(nodeMeanDepth < 6, "Mean depth should be less than 6");
        Assert.assertTrue((System.currentTimeMillis() - startTime) < 20000,"Building this node should take far less than 20 seconds");

        final List<Instance<AttributesMap>> newInstances = TreeBuilderTestUtils.getInstances(1000);
        final Tree newTree = wb.buildPredictiveModel(newInstances);
        Assert.assertTrue(tree == newTree, "Expect same tree to be updated");
        Assert.assertNotEquals(nodeSize, newTree.node.size(), "Expected new nodes");
        Assert.assertNotEquals(nodeMeanDepth, newTree.node.meanDepth(), "Expected new mean depth");

        nodeSize = newTree.node.size();
        nodeMeanDepth = newTree.node.meanDepth();
        wb.stripData(newTree);
        Assert.assertEquals(nodeSize, newTree.node.size(), "Expected same nodes");
        Assert.assertEquals(nodeMeanDepth, newTree.node.meanDepth(), "Expected same mean depth");
    }

    @Test
    public void simpleBmiTestNoSplit() throws Exception {
        final List<Instance<AttributesMap>> instances = TreeBuilderTestUtils.getInstances(10000);
        final PredictiveModelWithDataBuilder<AttributesMap ,Tree> wb = getWrappedUpdatablePredictiveModelBuilder();
        final long startTime = System.currentTimeMillis();
        final Tree tree = wb.buildPredictiveModel(instances);

        TreeBuilderTestUtils.serializeDeserialize(tree.node);

        int nodeSize = tree.node.size();
        double nodeMeanDepth = tree.node.meanDepth();
        Assert.assertTrue(nodeSize < 400, "Tree size should be less than 400 nodes");
        Assert.assertTrue(nodeMeanDepth < 6, "Mean depth should be less than 6");
        Assert.assertTrue((System.currentTimeMillis() - startTime) < 20000,"Building this node should take far less than 20 seconds");

        final List<Instance<AttributesMap>> newInstances = TreeBuilderTestUtils.getInstances(10000);
        final Tree newTree = wb.buildPredictiveModel(newInstances);
        Assert.assertTrue(tree == newTree, "Expect same tree to be updated");
        Assert.assertEquals(nodeSize, newTree.node.size(), "Expected same nodes");
        Assert.assertNotEquals(nodeMeanDepth, newTree.node.meanDepth(), "Expected new mean depth");

        nodeSize = newTree.node.size();
        nodeMeanDepth = newTree.node.meanDepth();
        wb.stripData(newTree);
        Assert.assertEquals(nodeSize, newTree.node.size(), "Expected same nodes");
        Assert.assertEquals(nodeMeanDepth, newTree.node.meanDepth(), "Expected same mean depth");
    }

    @Test
    public void simpleBmiTestRebuild() throws Exception {
        final List<Instance<AttributesMap>> instances = TreeBuilderTestUtils.getInstances(10000);
        final PredictiveModelWithDataBuilder<AttributesMap ,Tree> wb = getWrappedUpdatablePredictiveModelBuilder();
        wb.rebuildThreshold(1);
        final long startTime = System.currentTimeMillis();
        final Tree tree = wb.buildPredictiveModel(instances);

        TreeBuilderTestUtils.serializeDeserialize(tree.node);

        int nodeSize = tree.node.size();
        double nodeMeanDepth = tree.node.meanDepth();
        Assert.assertTrue(nodeSize < 400, "Tree size should be less than 400 nodes");
        Assert.assertTrue(nodeMeanDepth < 6, "Mean depth should be less than 6");
        Assert.assertTrue((System.currentTimeMillis() - startTime) < 20000,"Building this node should take far less than 20 seconds");

        final List<Instance<AttributesMap>> newInstances = TreeBuilderTestUtils.getInstances(10000);
        final Tree newTree = wb.buildPredictiveModel(newInstances);
        Assert.assertFalse(tree == newTree, "Expect new tree to be built");
    }

    private PredictiveModelWithDataBuilder<AttributesMap ,Tree> getWrappedUpdatablePredictiveModelBuilder() {
        final TreeBuilder tb = new TreeBuilder(new SplitDiffScorer());
        return new PredictiveModelWithDataBuilder<>(tb);
    }

    @Test
    public void twoDeterministicDecisionTreesAreEqual() throws IOException, ClassNotFoundException {

        final List<Instance<AttributesMap>> instancesTrain = TreeBuilderTestUtils.getInstances(10000);
        MapUtils.random.setSeed(1l);
        final Tree tree1 = (new TreeBuilder(new SplitDiffScorer())).buildPredictiveModel(instancesTrain);
        MapUtils.random.setSeed(1l);
        final Tree tree2 = (new TreeBuilder(new SplitDiffScorer())).buildPredictiveModel(instancesTrain);

        TreeBuilderTestUtils.serializeDeserialize(tree1.node);
        TreeBuilderTestUtils.serializeDeserialize(tree2.node);
        Assert.assertTrue(tree1.node.size() == tree2.node.size(), "Deterministic Decision Trees must have same number of nodes");

        final List<Instance<AttributesMap>> instancesTest = TreeBuilderTestUtils.getInstances(1000);
        for (Instance<AttributesMap> instance : instancesTest) {
           PredictionMap map1 = tree1.predict(instance.getAttributes());
           PredictionMap map2 = tree2.predict(instance.getAttributes());
            Assert.assertTrue(map1.equals(map2), "Deterministic Decision Trees must have equal classifications");
        }
    }

}
TOP

Related Classes of quickml.supervised.classifier.decisionTree.TreeBuilderTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.