Package tv.floe.metronome.deeplearning.datasets

Source Code of tv.floe.metronome.deeplearning.datasets.DeepLearningTest

package tv.floe.metronome.deeplearning.datasets;

import static org.junit.Assert.*;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import tv.floe.metronome.deeplearning.datasets.fetchers.MnistFetcher;
import tv.floe.metronome.math.ArrayUtils;
import tv.floe.metronome.math.MathUtils;
import tv.floe.metronome.math.MatrixUtils;
import tv.floe.metronome.types.Pair;


public abstract class DeepLearningTest {

  private static Logger log = LoggerFactory.getLogger(DeepLearningTest.class);
/*
  public static Pair<Matrix,Matrix> getIris() throws IOException {
    Pair<Matrix,Matrix> pair = IrisUtils.loadIris();
    return pair;
  }
  public static Pair<Matrix,Matrix> getIris(int num) throws IOException {
    Pair<Matrix,Matrix> pair = IrisUtils.loadIris(num);
    return pair;
  }
  */
 
 
 
  /**
   * LFW Dataset: pick first num faces
   * @param num
   * @return
   * @throws Exception
   */
/*  public static Pair<Matrix,Matrix> getFaces(int num) throws Exception {
    LFWLoader loader = new LFWLoader();
    loader.getIfNotExists();
    return loader.getAllImagesAsMatrix(num);
  }
  */
 
  /**
   * LFW Dataset: pick all faces
   * @param num
   * @return
   * @throws Exception
   */
/*  public static Pair<Matrix,Matrix> getFacesMatrix() throws Exception {
    LFWLoader loader = new LFWLoader();
    loader.getIfNotExists();
    return loader.getAllImagesAsMatrix();
  }
  */
 
 
  /**
   * LFW Dataset: pick first num faces
   * @param num
   * @return
   * @throws Exception
   */
/*  public static List<Pair<Matrix,Matrix>> getFirstFaces(int num) throws Exception {
    LFWLoader loader = new LFWLoader();
    loader.getIfNotExists();
    return loader.getFirst(num);
  }
  */
 
  /**
   * LFW Dataset: pick all faces
   * @param num
   * @return
   * @throws Exception
   */
/*  public List<Pair<Matrix,Matrix>> getFaces() throws Exception {
    LFWLoader loader = new LFWLoader();
    loader.getIfNotExists();
    return loader.getImagesAsList();
  }
  */
 
 
  /**
   * Gets an mnist example as an input, label pair.
   * Keep in mind the return matrix for out come is a 1x1 matrix.
   * If you need multiple labels, remember to convert to a zeros
   * with 1 as index of the label for the output training vector.
   * @param example the example to get
   * @return the image,label pair
   * @throws IOException
   */
  public static Pair<Matrix,Matrix> getMnistExample(int example) throws IOException {
    File ensureExists = new File("/tmp/MNIST");
    if(!ensureExists.exists())
      new MnistFetcher().downloadAndUntar();

    MnistManager man = new MnistManager("/tmp/MNIST/" + MnistFetcher.trainingFilesFilename_unzipped,"/tmp/MNIST/" + MnistFetcher.trainingFileLabelsFilename_unzipped);
    man.setCurrent(example);
    int[] imageExample = ArrayUtils.flatten(man.readImage());
    //return new Pair<Matrix,Matrix>(MatrixUtils.toMatrix(imageExample).transpose(),MatrixUtils.toOutcomeVector(man.readLabel(),10));
    return new Pair<Matrix,Matrix>(MatrixUtils.toMatrix(imageExample),MatrixUtils.toOutcomeVector(man.readLabel(),10));
  }



  /**
   * Gets an mnist example as an input, label pair.
   * Keep in mind the return matrix for out come is a 1x1 matrix.
   * If you need multiple labels, remember to convert to a zeros
   * with 1 as index of the label for the output training vector.
   * @param example the example to get
   * @param batchSize the batch size of examples to get
   * @return the image,label pair
   * @throws IOException
   */
  public List<Pair<Matrix,Matrix>> getMnistExampleBatches(int batchSize,int numBatches) throws IOException {
    File ensureExists = new File("/tmp/MNIST");
    List<Pair<Matrix,Matrix>> ret = new ArrayList<Pair<Matrix, Matrix>>();
    if(!ensureExists.exists())
      new MnistFetcher().downloadAndUntar();
    MnistManager man = new MnistManager("/tmp/MNIST/" + MnistFetcher.trainingFilesFilename_unzipped,"/tmp/MNIST/" + MnistFetcher.trainingFileLabelsFilename_unzipped);

    int[][] image = man.readImage();
    int[] imageExample = ArrayUtils.flatten(image);

    for(int batch = 0; batch < numBatches; batch++) {
      double[][] examples = new double[batchSize][imageExample.length];
      int[][] outcomeMatrix = new int[batchSize][10];

      for(int i = 1 + batch; i < batchSize + 1 + batch; i++) {
        //1 based indices
        man.setCurrent(i);
        double[] currExample = ArrayUtils.flatten(ArrayUtils.toDouble(man.readImage()));
        examples[i - 1 - batch] = currExample;
        outcomeMatrix[i - 1 - batch] = ArrayUtils.toOutcomeArray(man.readLabel(), 10);
      }
      //Matrix training = new Matrix(examples);
      Matrix training = new DenseMatrix( batchSize, imageExample.length );
      training.assign(examples);
     
      ret.add(new Pair<Matrix, Matrix>(training,MatrixUtils.toMatrix(outcomeMatrix)));
    }

    return ret;
  }

  /**
   * Gets an mnist example as an input, label pair.
   * Keep in mind the return matrix for out come is a 1x1 matrix.
   * If you need multiple labels, remember to convert to a zeros
   * with 1 as index of the label for the output training vector.
   * @param example the example to get
   * @param batchSize the batch size of examples to get
   * @return the image,label pair
   * @throws IOException
   */
  public static Pair<Matrix,Matrix> getMnistExampleBatch(int batchSize) throws IOException {
    File ensureExists = new File("/tmp/MNIST");
    if(!ensureExists.exists() || !new File("/tmp/MNIST/" + MnistFetcher.trainingFilesFilename_unzipped).exists() || !new File("/tmp/MNIST/" + MnistFetcher.trainingFileLabelsFilename_unzipped).exists())
      new MnistFetcher().downloadAndUntar();
    MnistManager man = new MnistManager("/tmp/MNIST/" + MnistFetcher.trainingFilesFilename_unzipped,"/tmp/MNIST/" + MnistFetcher.trainingFileLabelsFilename_unzipped);

    int[][] image = man.readImage();
    int[] imageExample = ArrayUtils.flatten(image);
    double[][] examples = new double[batchSize][imageExample.length];
    int[][] outcomeMatrix = new int[batchSize][10];
    for(int i = 1; i < batchSize + 1; i++) {
      //1 based indices
      man.setCurrent(i);
      double[] currExample = ArrayUtils.flatten(ArrayUtils.toDouble(man.readImage()));
      for(int j = 0; j < currExample.length; j++)
        currExample[j] = MathUtils.normalize(currExample[j], 0, 255);
      examples[i - 1] = currExample;
      outcomeMatrix[i - 1] = ArrayUtils.toOutcomeArray(man.readLabel(), 10);
    }
    Matrix training = new DenseMatrix( batchSize, imageExample.length ); //Matrix(examples);
    training.assign(examples);
   
    return new Pair<Matrix,Matrix>(training,MatrixUtils.toMatrix(outcomeMatrix));

  }


}

TOP

Related Classes of tv.floe.metronome.deeplearning.datasets.DeepLearningTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.