Package quickml

Source Code of quickml.Benchmarks

package quickml;

import com.google.common.collect.Lists;
import org.json.simple.JSONObject;
import org.json.simple.JSONValue;
import quickml.supervised.crossValidation.StationaryCrossValidator;
import quickml.supervised.crossValidation.crossValLossFunctions.ClassifierLogCVLossFunction;
import quickml.data.*;
import quickml.supervised.classifier.decisionTree.Scorer;
import quickml.supervised.classifier.decisionTree.TreeBuilder;
import quickml.supervised.classifier.decisionTree.scorers.MSEScorer;
import quickml.supervised.classifier.decisionTree.scorers.SplitDiffScorer;
import quickml.supervised.classifier.randomForest.RandomForestBuilder;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPInputStream;

public class Benchmarks {

    /**
     * @param args
     */
    public static void main(final String[] args) throws Exception {
        List<Instance<AttributesMap>> diaInstances = loadDiabetesDataset();

        testWithInstances("diabetes", diaInstances);

        final List<Instance<AttributesMap>> moboInstances = loadMoboDataset();

        testWithInstances("mobo", moboInstances);


    }

    private static void testWithInstances(String dsName, final List<Instance<AttributesMap>> instances) {
        StationaryCrossValidator crossValidator = new StationaryCrossValidator(new ClassifierLogCVLossFunction());

        for (final Scorer scorer : Lists.newArrayList(new SplitDiffScorer(), new MSEScorer(MSEScorer.CrossValidationCorrection.FALSE), new MSEScorer(MSEScorer.CrossValidationCorrection.TRUE))) {
            final TreeBuilder singleTreeBuilder = new TreeBuilder(scorer).binaryClassification(true);
            System.out.println(dsName+", single-tree, "+scorer+", "+crossValidator.getCrossValidatedLoss(singleTreeBuilder, instances));

            TreeBuilder forestTreeBuilder = new TreeBuilder(scorer).ignoreAttributeAtNodeProbability(0.5).binaryClassification(true);
            RandomForestBuilder randomForestBuilder = new RandomForestBuilder(forestTreeBuilder).numTrees(100).executorThreadCount(8);
            System.out.println(dsName+", random-forest, "+scorer+", "+crossValidator.getCrossValidatedLoss(randomForestBuilder, instances));
        }
    }

    public static List<Instance<AttributesMap>> loadDiabetesDataset() throws IOException {
        final BufferedReader br = new BufferedReader(new InputStreamReader((new GZIPInputStream(Benchmarks.class.getResourceAsStream("diabetesDataset.txt.gz")))));
        final List<Instance<AttributesMap>> instances = Lists.newLinkedList();

        while (true) {
            String line = br.readLine();
            if (line == null) {
                break;
            }
            String[] splitLine = line.split("\\s");
            AttributesMap attributes = AttributesMap.newHashMap();
            for (int x=0; x<8; x++) {
                attributes.put("attr" + x, Double.parseDouble(splitLine[x]));
            }
            final Instance<AttributesMap> instance = new InstanceImpl(attributes, splitLine[8]);
            instances.add(instance);

        }

        return instances;
    }

    public static List<Instance<AttributesMap>> loadIrisDataset() throws IOException {
        final BufferedReader br = new BufferedReader(new InputStreamReader((new GZIPInputStream(Benchmarks.class.getResourceAsStream("iris.data.gz")))));
        final List<Instance<AttributesMap>> instances = Lists.newLinkedList();

        String[] headings = new String[] {"sepal-length", "sepal-width", "petal-length", "petal-width"};
        while (true) {
            String line = br.readLine();
            if (line == null) {
                break;
            }
            String[] splitLine = line.split(",");

            AttributesMap attributes = AttributesMap.newHashMap();
            for (int x=0; x<splitLine.length - 1; x++) {
                attributes.put(headings[x], splitLine[x]);
            }
            final Instance<AttributesMap> instance = new InstanceImpl(attributes, splitLine[splitLine.length-1]);
            instances.add(instance);

        }

        return instances;
    }

    public static List<Instance<AttributesMap>> loadMoboDataset() throws IOException {
        final BufferedReader br = new BufferedReader(new InputStreamReader((new GZIPInputStream(Benchmarks.class.getResourceAsStream("mobo1.json.gz")))));

        final List<Instance<AttributesMap>> instances = Lists.newLinkedList();

        int count = 0;
        while (true) {
            count++;
            final String line = br.readLine();
            if (line == null) {
                break;
            }
            final JSONObject jo = (JSONObject) JSONValue.parse(line);
            AttributesMap a = AttributesMap.newHashMap();
            a.putAll((JSONObject) jo.get("attributes"));
            String binaryClassification = ((String) jo.get("output")).equals("none") ? "none" : "notNone";
            Instance<AttributesMap> instance = new InstanceImpl(a,binaryClassification);
            instances.add(instance);
        }

        return instances;
    }
}
TOP

Related Classes of quickml.Benchmarks

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Corporation. Contact coftware#gmail.com.