Package com.github.neuralnetworks.samples.test

Source Code of com.github.neuralnetworks.samples.test.CifarTest

package com.github.neuralnetworks.samples.test;

import java.util.Random;

import org.junit.Test;

import com.amd.aparapi.Kernel.EXECUTION_MODE;
import com.github.neuralnetworks.architecture.NeuralNetworkImpl;
import com.github.neuralnetworks.architecture.types.NNFactory;
import com.github.neuralnetworks.input.MultipleNeuronsOutputError;
import com.github.neuralnetworks.input.ScalingInputFunction;
import com.github.neuralnetworks.samples.cifar.CIFARInputProvider.CIFAR10TestingInputProvider;
import com.github.neuralnetworks.samples.cifar.CIFARInputProvider.CIFAR10TrainingInputProvider;
import com.github.neuralnetworks.training.TrainerFactory;
import com.github.neuralnetworks.training.backpropagation.BackPropagationTrainer;
import com.github.neuralnetworks.training.events.LogTrainingListener;
import com.github.neuralnetworks.training.random.NNRandomInitializer;
import com.github.neuralnetworks.training.random.RandomInitializerImpl;
import com.github.neuralnetworks.util.Environment;

/**
* CIFAR test
*/
public class CifarTest {

    /**
     * NOT TESTED
     */
    @Test
    public void testSigmoidBP() {
  Environment.getInstance().setUseDataSharedMemory(false);
  Environment.getInstance().setUseWeightsSharedMemory(false);
  NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[] { 3072, 10 }, true);

  CIFAR10TrainingInputProvider trainInputProvider = new CIFAR10TrainingInputProvider("cifar-10-batches-bin"); // specify your own path
  trainInputProvider.getProperties().setGroupByChannel(true);
  trainInputProvider.getProperties().setScaleColors(true);
  trainInputProvider.addInputModifier(new ScalingInputFunction(255));

  // specify your own path
  CIFAR10TestingInputProvider testInputProvider = new CIFAR10TestingInputProvider("cifar-10-batches-bin"); // specify your own path
  testInputProvider.getProperties().setGroupByChannel(true);
  testInputProvider.getProperties().setScaleColors(true);
  testInputProvider.addInputModifier(new ScalingInputFunction(255));

  BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(mlp, trainInputProvider, testInputProvider, new MultipleNeuronsOutputError(), new NNRandomInitializer(new RandomInitializerImpl(new Random(), -0.01f, 0.01f)), 0.02f, 0.5f, 0f, 0f, 0f, 1, 1000, 1);

  bpt.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName(), false, true));

  Environment.getInstance().setExecutionMode(EXECUTION_MODE.CPU);

  bpt.train();
  bpt.test();
    }
}
TOP

Related Classes of com.github.neuralnetworks.samples.test.CifarTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.