Package etc.aloe.data

Source Code of etc.aloe.data.CrossValidationSplitTest

/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.

* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.

* You should have received a copy of the GNU General Public License
* along with ALOE.  If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe.data;

import etc.aloe.TestLabelable;
import etc.aloe.processes.CrossValidationSplit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import static org.junit.Assert.*;

/**
*
* @author Michael Brooks <mjbrooks@uw.edu>
*/
public class CrossValidationSplitTest {

    private List<TestLabelable> originalItems;

    public CrossValidationSplitTest() {
    }

    @BeforeClass
    public static void setUpClass() {
    }

    @AfterClass
    public static void tearDownClass() {
    }

    @Before
    public void setUp() {
        TestLabelable[] itemsArray = new TestLabelable[]{
            new TestLabelable("one", true, true),
            new TestLabelable("two", false, true),
            new TestLabelable("three", false, true),
            new TestLabelable("four", true, true),
            new TestLabelable("five", false, true),
            new TestLabelable("six", true, true),
            new TestLabelable("seven", true, true),
            new TestLabelable("eight", true, true),
            new TestLabelable("nine", false, true),
            new TestLabelable("ten", false, true),
            new TestLabelable("eleven", false, true),
            new TestLabelable("twelve", true, true)
        };

        this.originalItems = Arrays.asList(itemsArray);
    }

    @After
    public void tearDown() {
    }

    /**
     * Test of getTrainingForFold method, of class CrossValidationSplit.
     * Tests the middle fold.
     */
    @Test
    public void testGetTrainingForMiddleFold() {
        System.out.println("getTrainingForFold (middle)");

        int numFolds = 3;
        int foldIndex = 1;

        List<TestLabelable> expectedTraining = new ArrayList<TestLabelable>();
        expectedTraining.addAll(originalItems.subList(0, 4));
        expectedTraining.addAll(originalItems.subList(8, 12));

        CrossValidationSplit instance = new CrossValidationSplit();
        List<TestLabelable> training = instance.getTrainingForFold(originalItems, foldIndex, numFolds);

        assertArrayEquals("Correct training set for middle fold", expectedTraining.toArray(), training.toArray());
    }

    /**
     * Test of getTestingForFold method, of class CrossValidationSplit.
     * Tests the middle fold.
     */
    @Test
    public void testGetTestingForMiddleFold() {
        System.out.println("getTestingForFold (middle)");
        int numFolds = 3;
        int foldIndex = 1;

        List<TestLabelable> expectedTesting = new ArrayList<TestLabelable>();
        expectedTesting.addAll(originalItems.subList(4, 8));

        CrossValidationSplit instance = new CrossValidationSplit();

        List<TestLabelable> testing = instance.getTestingForFold(originalItems, foldIndex, numFolds);

        assertArrayEquals("Correct test set for middle fold", expectedTesting.toArray(), testing.toArray());
    }

    /**
     * Test of getTrainingForFold method, of class CrossValidationSplit.
     * Tests the first fold.
     */
    @Test
    public void testGetTrainingForFirstFold() {
        System.out.println("getTrainingForFold (first)");

        int numFolds = 3;
        int foldIndex = 0;

        List<TestLabelable> expectedTraining = new ArrayList<TestLabelable>();
        expectedTraining.addAll(originalItems.subList(4, 12));

        CrossValidationSplit instance = new CrossValidationSplit();
        List<TestLabelable> training = instance.getTrainingForFold(originalItems, foldIndex, numFolds);

        assertArrayEquals("Correct training set for first fold", expectedTraining.toArray(), training.toArray());
    }

    /**
     * Test of getTestingForFold method, of class CrossValidationSplit.
     * Tests the first fold.
     */
    @Test
    public void testGetTestingForFirstFold() {
        System.out.println("getTestingForFold (first)");

        int numFolds = 3;
        int foldIndex = 0;

        List<TestLabelable> expectedTesting = new ArrayList<TestLabelable>();
        expectedTesting.addAll(originalItems.subList(0, 4));

        CrossValidationSplit instance = new CrossValidationSplit();

        List<TestLabelable> testing = instance.getTestingForFold(originalItems, foldIndex, numFolds);

        assertArrayEquals("Correct test set for first fold", expectedTesting.toArray(), testing.toArray());
    }

    /**
     * Test of getTrainingForFold method, of class CrossValidationSplit.
     * Tests the last fold.
     */
    @Test
    public void testGetTrainingForLastFold() {
        System.out.println("getTrainingForFold (last)");

        int numFolds = 3;
        int foldIndex = 2;

        List<TestLabelable> expectedTraining = new ArrayList<TestLabelable>();
        expectedTraining.addAll(originalItems.subList(0, 8));

        CrossValidationSplit instance = new CrossValidationSplit();
        List<TestLabelable> training = instance.getTrainingForFold(originalItems, foldIndex, numFolds);

        assertArrayEquals("Correct training set for last fold", expectedTraining.toArray(), training.toArray());
    }

    /**
     * Test of getTestingForFold method, of class CrossValidationSplit.
     * Tests the middle fold.
     */
    @Test
    public void testGetTestingForLastFold() {
        System.out.println("getTestingForFold (last)");

        int numFolds = 3;
        int foldIndex = 2;

        List<TestLabelable> expectedTesting = new ArrayList<TestLabelable>();
        expectedTesting.addAll(originalItems.subList(8, 12));

        CrossValidationSplit instance = new CrossValidationSplit();

        List<TestLabelable> testing = instance.getTestingForFold(originalItems, foldIndex, numFolds);

        assertArrayEquals("Correct test set for last fold", expectedTesting.toArray(), testing.toArray());
    }
}
TOP

Related Classes of etc.aloe.data.CrossValidationSplitTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.