Package quickml.supervised.classifier.downsampling

Source Code of quickml.supervised.classifier.downsampling.DownsamplingClassifierBuilderTest$SamePredictionPredictiveModel

package quickml.supervised.classifier.downsampling;

import com.beust.jcommander.internal.Lists;
import junit.framework.Assert;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.testng.annotations.Test;
import quickml.collections.MapUtils;
import quickml.data.*;
import quickml.supervised.UpdatablePredictiveModelBuilder;
import quickml.supervised.classifier.AbstractClassifier;
import quickml.supervised.classifier.Classifier;
import quickml.supervised.PredictiveModelWithDataBuilder;
import quickml.supervised.classifier.TreeBuilderTestUtils;
import quickml.supervised.classifier.decisionTree.Tree;
import quickml.supervised.classifier.decisionTree.TreeBuilder;
import quickml.supervised.classifier.decisionTree.scorers.SplitDiffScorer;
import quickml.supervised.classifier.randomForest.RandomForest;
import quickml.supervised.classifier.randomForest.RandomForestBuilder;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.mockito.Mockito.when;

/**
* Created by ian on 4/24/14.
*/
public class DownsamplingClassifierBuilderTest {
    @Test
    public void simpleTest() {
        final UpdatablePredictiveModelBuilder<AttributesMap , Classifier> updatablePredictiveModelBuilder = Mockito.mock(UpdatablePredictiveModelBuilder.class);
        when(updatablePredictiveModelBuilder.buildPredictiveModel(Mockito.any(Iterable.class))).thenAnswer(new Answer<Classifier>() {
            @Override
            public Classifier answer(final InvocationOnMock invocationOnMock) throws Throwable {
                Iterable<Instance<AttributesMap>> instances = (Iterable<Instance<AttributesMap>>) invocationOnMock.getArguments()[0];
                int total = 0, sum = 0;
                for (Instance<AttributesMap> instance : instances) {
                    total++;
                    if (instance.getLabel().equals(true)) {
                        sum++;
                    }
                }
                Classifier dumbPM = new SamePredictionPredictiveModel((double) sum / (double) total);
                return dumbPM;
            }
        });
        DownsamplingClassifierBuilder downsamplingClassifierBuilder = new DownsamplingClassifierBuilder(updatablePredictiveModelBuilder, 0.2);
        List<Instance<AttributesMap>> data = Lists.newArrayList();
        for (int x=0; x<10000; x++) {
            data.add(new InstanceImpl(AttributesMap.newHashMap(), (MapUtils.random.nextDouble() < 0.05)));
        }
        DownsamplingClassifier predictiveModel = downsamplingClassifierBuilder.buildPredictiveModel(data);
        AttributesMap  map = AttributesMap.newHashMap() ;
        map.put("true",Boolean.TRUE);
        final double correctedMinorityInstanceOccurance = predictiveModel.getProbability(map, Boolean.TRUE);
        double error = Math.abs(0.05 - correctedMinorityInstanceOccurance);
        Assert.assertTrue(String.format("Error should be < 0.1 but was %s (prob=%s, desired=0.05)", error, correctedMinorityInstanceOccurance), error < 0.01);
    }

    @Test
    public void simpleBmiTest() throws IOException, ClassNotFoundException {
        final TreeBuilder tb = new TreeBuilder(new SplitDiffScorer());
        final RandomForestBuilder urfb = new RandomForestBuilder(tb);
        final DownsamplingClassifierBuilder dpmb = new DownsamplingClassifierBuilder(urfb, 0.1);

        final List<Instance<AttributesMap>> instances = TreeBuilderTestUtils.getIntegerInstances(1000);
        final PredictiveModelWithDataBuilder<AttributesMap ,DownsamplingClassifier> wb = new PredictiveModelWithDataBuilder<>(dpmb);
        final long startTime = System.currentTimeMillis();
        final DownsamplingClassifier downsamplingClassifier = wb.buildPredictiveModel(instances);

        TreeBuilderTestUtils.serializeDeserialize(downsamplingClassifier);

        RandomForest randomForest = (RandomForest) downsamplingClassifier.wrappedClassifier;
        final List<Tree> trees = randomForest.trees;
        final int treeSize = trees.size();
        final int firstTreeNodeSize = trees.get(0).node.size();
        org.testng.Assert.assertTrue(treeSize < 400, "Forest size should be less than 400");
        org.testng.Assert.assertTrue((System.currentTimeMillis() - startTime) < 20000, "Building this node should take far less than 20 seconds");

        final List<Instance<AttributesMap>> newInstances = TreeBuilderTestUtils.getIntegerInstances(1000);
        final DownsamplingClassifier downsamplingClassifier1 = wb.buildPredictiveModel(newInstances);
        final RandomForest newRandomForest = (RandomForest) downsamplingClassifier1.wrappedClassifier;
        org.testng.Assert.assertTrue(downsamplingClassifier == downsamplingClassifier1, "Expect same tree to be updated");
        org.testng.Assert.assertEquals(treeSize, newRandomForest.trees.size(), "Expected same number of trees");
        org.testng.Assert.assertEquals(firstTreeNodeSize, newRandomForest.trees.get(0).node.size(), "Expected same nodes");
    }

    private static class SamePredictionPredictiveModel extends AbstractClassifier {

        private static final long serialVersionUID = 8241616760952568181L;
        private final double prediction;

        public SamePredictionPredictiveModel(double prediction) {

            this.prediction = prediction;
        }

        @Override
        public void dump(final Appendable appendable) {
            throw new UnsupportedOperationException();
        }


        @Override
        public PredictionMap predict(AttributesMap attributes) {
            Map<Serializable, Double> map = new HashMap<>();
            for(Serializable value : attributes.values()) {
                map.put(value, prediction);
            }
            return new PredictionMap(map);
        }
    }
}
TOP

Related Classes of quickml.supervised.classifier.downsampling.DownsamplingClassifierBuilderTest$SamePredictionPredictiveModel

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.