Package quickml.supervised.classifier.splitOnAttribute

Source Code of quickml.supervised.classifier.splitOnAttribute.SplitOnAttributeClassifierBuilderTest

package quickml.supervised.classifier.splitOnAttribute;


import org.testng.Assert;
import org.testng.annotations.Test;
import quickml.data.AttributesMap;
import quickml.data.Instance;
import quickml.supervised.PredictiveModelBuilder;
import quickml.supervised.PredictiveModelWithDataBuilder;
import quickml.supervised.classifier.TreeBuilderTestUtils;
import quickml.supervised.classifier.decisionTree.Tree;
import quickml.supervised.classifier.decisionTree.TreeBuilder;
import quickml.supervised.classifier.decisionTree.scorers.SplitDiffScorer;
import quickml.supervised.classifier.randomForest.RandomForest;
import quickml.supervised.classifier.randomForest.RandomForestBuilder;

import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Created by Chris on 5/14/2014.
*/
public class SplitOnAttributeClassifierBuilderTest {
    @Test
    public void simpleBmiTest() throws Exception {
        Set<String> whiteList = new HashSet<>();
        whiteList.add("weight");
        whiteList.add("height");
        final List<Instance<AttributesMap>> instances = TreeBuilderTestUtils.getInstances(10000);
        final TreeBuilder tb = new TreeBuilder(new SplitDiffScorer()).splitPredictiveModel("gender", whiteList);
        final RandomForestBuilder rfb = new RandomForestBuilder(tb);
        final SplitOnAttributeClassifierBuilder cpmb = new SplitOnAttributeClassifierBuilder("gender", rfb, 10, 0.1, whiteList, 1);
        final long startTime = System.currentTimeMillis();
        final SplitOnAttributeClassifier splitOnAttributeClassifier = cpmb.buildPredictiveModel(instances);
        final RandomForest randomForest = (RandomForest) splitOnAttributeClassifier.getDefaultPM();

        TreeBuilderTestUtils.serializeDeserialize(randomForest);
        final List<Tree> trees = randomForest.trees;
        final int treeSize = trees.size();
        Assert.assertTrue(treeSize < 400, "Forest size should be less than 400");
        Assert.assertTrue((System.currentTimeMillis() - startTime) < 20000,"Building this node should take far less than 20 seconds");
    }

    @Test
    public void simpleBmiTestSplit() throws Exception {
        final List<Instance<AttributesMap>> instances = TreeBuilderTestUtils.getInstances(1000);
        final PredictiveModelWithDataBuilder<AttributesMap ,SplitOnAttributeClassifier> wb = getWrappedUpdatablePredictiveModelBuilder();
        wb.splitNodeThreshold(1);
        final long startTime = System.currentTimeMillis();
        final SplitOnAttributeClassifier splitOnAttributeClassifier = wb.buildPredictiveModel(instances);
        final RandomForest randomForest = (RandomForest) splitOnAttributeClassifier.getDefaultPM();

        TreeBuilderTestUtils.serializeDeserialize(randomForest);

        final List<Tree> trees = randomForest.trees;
        final int treeSize = trees.size();
        final int firstTreeNodeSize = trees.get(0).node.size();
        Assert.assertTrue(treeSize < 400, "Forest size should be less than 400");
        Assert.assertTrue((System.currentTimeMillis() - startTime) < 20000,"Building this node should take far less than 20 seconds");

        final List<Instance<AttributesMap>> newInstances = TreeBuilderTestUtils.getInstances(1000);
        final SplitOnAttributeClassifier splitOnAttributeClassifier1 = wb.buildPredictiveModel(newInstances);
        final RandomForest newRandomForest = (RandomForest) splitOnAttributeClassifier1.getDefaultPM();
        Assert.assertTrue(splitOnAttributeClassifier == splitOnAttributeClassifier1, "Expect same tree to be updated");
        Assert.assertEquals(treeSize, newRandomForest.trees.size(), "Expected same number of trees");
        Assert.assertNotEquals(firstTreeNodeSize, newRandomForest.trees.get(0).node.size(), "Expected new nodes");
    }

    @Test
    public void simpleBmiTestNoSplit() throws Exception {
        final List<Instance<AttributesMap>> instances = TreeBuilderTestUtils.getInstances(1000);
        final PredictiveModelWithDataBuilder<AttributesMap ,SplitOnAttributeClassifier> wb = getWrappedUpdatablePredictiveModelBuilder();
        final long startTime = System.currentTimeMillis();
        final SplitOnAttributeClassifier splitOnAttributeClassifier = wb.buildPredictiveModel(instances);
        final RandomForest randomForest = (RandomForest) splitOnAttributeClassifier.getDefaultPM();

        TreeBuilderTestUtils.serializeDeserialize(randomForest);

        final List<Tree> trees = randomForest.trees;
        final int treeSize = trees.size();
        final int firstTreeNodeSize = trees.get(0).node.size();
        Assert.assertTrue(treeSize < 400, "Forest size should be less than 400");
        Assert.assertTrue((System.currentTimeMillis() - startTime) < 20000,"Building this node should take far less than 20 seconds");

        final List<Instance<AttributesMap>> newInstances = TreeBuilderTestUtils.getInstances(1000);
        final SplitOnAttributeClassifier splitOnAttributeClassifier1 = wb.buildPredictiveModel(newInstances);
        final RandomForest newRandomForest = (RandomForest) splitOnAttributeClassifier1.getDefaultPM();
        Assert.assertTrue(splitOnAttributeClassifier == splitOnAttributeClassifier1, "Expect same tree to be updated");
        Assert.assertEquals(treeSize, newRandomForest.trees.size(), "Expected same number of trees");
        Assert.assertEquals(firstTreeNodeSize, newRandomForest.trees.get(0).node.size(), "Expected same nodes");
    }

    private PredictiveModelWithDataBuilder<AttributesMap ,SplitOnAttributeClassifier> getWrappedUpdatablePredictiveModelBuilder() {
        Set<String> whiteList = new HashSet<>();
        whiteList.add("weight");
        whiteList.add("height");
        final TreeBuilder tb = new TreeBuilder(new SplitDiffScorer()).splitPredictiveModel("gender", whiteList);
        final RandomForestBuilder urfb = new RandomForestBuilder(tb);
        final SplitOnAttributeClassifierBuilder ucpmb = new SplitOnAttributeClassifierBuilder("gender", urfb, 10, 0.1, whiteList, 1);
        return new PredictiveModelWithDataBuilder<>(ucpmb);
    }
}
TOP

Related Classes of quickml.supervised.classifier.splitOnAttribute.SplitOnAttributeClassifierBuilderTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.