Package quickml.supervised.crossValidation

Source Code of quickml.supervised.crossValidation.OutOfTimeCrossValidatorTests

package quickml.supervised.crossValidation;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.AttributesMap;
import quickml.data.PredictionMap;
import quickml.supervised.PredictiveModel;
import quickml.supervised.PredictiveModelWithDataBuilder;
import quickml.supervised.crossValidation.crossValLossFunctions.ClassifierLogCVLossFunction;
import quickml.supervised.crossValidation.crossValLossFunctions.WeightedAUCCrossValLossFunction;
import quickml.supervised.crossValidation.dateTimeExtractors.TestDateTimeExtractor;
import quickml.data.Instance;
import quickml.experiments.TrainingDataGenerator2;
import quickml.supervised.classifier.decisionTree.TreeBuilder;
import quickml.supervised.classifier.randomForest.RandomForest;
import quickml.supervised.classifier.randomForest.RandomForestBuilder;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

/**
* Created by alexanderhawk on 5/17/14.
*/
public class OutOfTimeCrossValidatorTests {
    private static final Logger logger = LoggerFactory.getLogger(OutOfTimeCrossValidator.class);
    List<Instance<AttributesMap>> trainingData;

    @Before
    public void setUp() throws Exception {
        int numTraniningExamples = 40100;
        String bidRequestAttributes[] = {"seller_id", "user_id", "users_favorite_beer_id", "favorite_soccer_team_id", "user_iq"};
        TrainingDataGenerator2 trainingDataGenerator = new TrainingDataGenerator2(numTraniningExamples, .005, bidRequestAttributes);
        trainingData = trainingDataGenerator.createTrainingData();
        int millisInMinute = 60000;
        int instanceNumber = 0;
        for (Instance<AttributesMap> instance : trainingData) {
            instance.getAttributes().put("currentTimeMillis", millisInMinute * instanceNumber);
            instanceNumber++;
        }
    }

    @Test
    public void testLossBetween0And1() {
        //  List<Instance<AttributesMap>> trainingData = setUp();
        logger.info("trainingDataSize " + trainingData.size());
        PredictiveModelWithDataBuilder<AttributesMap, ? extends PredictiveModel<AttributesMap, PredictionMap>> predictiveModelWithDataBuilder = getPredictiveModelWithDataBuilder(5, 5);

        ClassifierOutOfTimeCrossValidator crossValidator = new ClassifierOutOfTimeCrossValidator(new ClassifierLogCVLossFunction(), 0.25, 30, new TestDateTimeExtractor()); //number of validation time slices
        double totalLoss = crossValidator.getCrossValidatedLoss(predictiveModelWithDataBuilder, trainingData);
        Assert.assertTrue(totalLoss > 0 && totalLoss <= 1.0);
        logger.info("\n\nAUCLoss\n");
        crossValidator = new ClassifierOutOfTimeCrossValidator(new WeightedAUCCrossValLossFunction(1.0), 0.25, 30, new TestDateTimeExtractor()); //number of validation time slices
        totalLoss = crossValidator.getCrossValidatedLoss(predictiveModelWithDataBuilder, trainingData);
        logger.info("totalLoss from AUc "+ totalLoss);
        Assert.assertTrue(totalLoss > 0 && totalLoss <= 1.0);
//
        // crossValidator = new OutOfTimeCrossValidator(new ClassifierRMSECrossValLossFunction(), 0.25, 30, new TestDateTimeExtractor()); //number of validation time slices
        // totalLoss = crossValidator.getCrossValidatedLoss(randomForestBuilder, trainingData);
        // Assert.assertTrue(totalLoss > 0 && totalLoss <=1.0);
    }


    private static RandomForest getRandomForest(List<Instance<AttributesMap>> trainingData, int maxDepth, int numTrees) {
        TreeBuilder treeBuilder = new TreeBuilder().maxDepth(maxDepth).ignoreAttributeAtNodeProbability(.7);
        RandomForestBuilder randomForestBuilder = new RandomForestBuilder(treeBuilder).numTrees(numTrees);
        return randomForestBuilder.buildPredictiveModel(trainingData);
    }

    private static PredictiveModelWithDataBuilder<AttributesMap, ? extends PredictiveModel<AttributesMap, PredictionMap>> getPredictiveModelWithDataBuilder(int maxDepth, int numTrees) {
        TreeBuilder treeBuilder = new TreeBuilder().maxDepth(maxDepth).ignoreAttributeAtNodeProbability(.7);
        RandomForestBuilder randomForestBuilder = new RandomForestBuilder(treeBuilder).numTrees(numTrees);
        PredictiveModelWithDataBuilder<AttributesMap, ? extends PredictiveModel<AttributesMap, PredictionMap>> builder = new PredictiveModelWithDataBuilder<>(randomForestBuilder);//.rebuildThreshold(4).splitNodeThreshold(2);
        return builder;
    }
}
TOP

Related Classes of quickml.supervised.crossValidation.OutOfTimeCrossValidatorTests

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.