Package weka.classifiers.functions

Source Code of weka.classifiers.functions.LibLINEARDataTest

package weka.classifiers.functions;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import java.io.File;
import java.io.IOException;

import org.junit.Test;

import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ArffLoader;
import de.bwaldvogel.liblinear.SolverType;


public class LibLINEARDataTest {

    private final File TEST_RESOURCES = new File("src/test/resources");

    private int doCrossValidation(Instances insts, Classifier classifier, final int FOLDS) throws Exception {
        int errors = 0;

        for (int fold = 0; fold < FOLDS; fold++) {
            classifier.buildClassifier(insts.trainCV(FOLDS, fold));
            Instances testInstances = insts.testCV(FOLDS, fold);
            for (Instance testInstance : testInstances) {
                double result = classifier.classifyInstance(testInstance);
                double value = testInstance.classValue();
                if (value != result) {
                    errors++;
                }
            }
        }

        return errors;
    }

    @Test
    public void testWeather() throws Exception {
        Instances insts = loadInstancesFromARFF("weather.arff", "play");

        LibLINEAR liblinear = new LibLINEAR();
        for (SolverType solverType : new SolverType[] {SolverType.L2R_LR, SolverType.L1R_LR}) {
            liblinear.setSolverType(solverType);
            liblinear.setBias(-1);
            liblinear.setCost(1000); // explicitly high costs to reduce training error
            liblinear.setEps(1e-3);
            liblinear.setProbabilityEstimates(true);

            int errors = 0;
            liblinear.buildClassifier(insts);
            for (Instance instance : insts) {
                double[] d = liblinear.distributionForInstance(instance);
                assertEquals(2, d.length);
                int prediction = (int)liblinear.classifyInstance(instance);
                assertTrue(d[prediction] > d[1 - prediction]);
                if (prediction != (int)instance.classValue()) {
                    errors++;
                }
            }
            assertEquals(3, errors);
        }
    }

    @Test
    public void testIris() throws Exception {
        Instances insts = loadInstancesFromARFF("iris.arff", "class");

        LibLINEAR liblinear = new LibLINEAR();
        liblinear.setSolverType(SolverType.MCSVM_CS);
        liblinear.setBias(1);
        liblinear.setCost(1);
        liblinear.setEps(1e-2);
        liblinear.setProbabilityEstimates(false);

        int errors = doCrossValidation(insts, liblinear, 10);
        assertEquals(6, errors);
    }

    private Instances loadInstancesFromARFF(String filename, String className) throws IOException {
        File file = new File(TEST_RESOURCES, filename);
        assertTrue(file.exists());

        ArffLoader loader = new ArffLoader();
        loader.setFile(file);
        Instances insts = loader.getDataSet();
        Attribute classAttribute = insts.attribute(className);
        insts.setClass(classAttribute);
        return insts;
    }


    // tests for bug that was reported by Ondrej Dusek
    @Test
    public void testIrisNoVirginica() throws Exception {
        Instances insts = loadInstancesFromARFF("iris-novirginica.arff", "class");

        LibLINEAR liblinear = new LibLINEAR();
        liblinear.setSolverType(SolverType.L2R_LR);
        liblinear.setProbabilityEstimates(true);
        liblinear.buildClassifier(insts);

        for (Instance instance : insts) {
            liblinear.classifyInstance(instance);
            liblinear.distributionForInstance(instance);
        }

    }
}
TOP

Related Classes of weka.classifiers.functions.LibLINEARDataTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.