Package com.github.neuralnetworks.samples.cifar

Source Code of com.github.neuralnetworks.samples.cifar.CIFARInputProvider$CIFAR100TestingInputProvider

package com.github.neuralnetworks.samples.cifar;

import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.util.ArrayList;
import java.util.List;

import com.github.neuralnetworks.input.ImageInputProvider;
import com.github.neuralnetworks.util.Util;

/**
* Input provider for the CIFAR-10 and CIFAR-100 datasets. Requires location of the CIFAR images
* files (not included in the library). Do not use this class directly, but use the subclasses instead
* Experimental
*/
public abstract class CIFARInputProvider extends ImageInputProvider {

    private static final long serialVersionUID = 1L;

    protected RandomAccessFile files[];
    protected int labelSize;
    protected int inputSize;
    protected byte[] nextInputRaw;
    protected float[] nextTarget;
    private List<Integer> elementsOrder;

    private CIFARInputProvider() {
  super();
  this.elementsOrder = new ArrayList<>();
  this.nextInputRaw = new byte[3072];
    }

    @Override
    public int getInputSize() {
  return inputSize;
    }

    @Override
    public float[] getNextTarget() {
  return nextTarget;
    }

    @Override
    public float[] getNextInput() {
  // if no transformations are required and the data is grouped by color
  // channel the code can be optimized
  if (!requireAugmentation() && getProperties().getGroupByChannel()) {
      if (nextInput == null) {
    nextInput = new float[3072];
      }

      float scaleColors = getProperties().getScaleColors() ? 255 : 1;
      for (int i = 0; i < nextInput.length; i++) {
    nextInput[i] = (nextInputRaw[i] & 0xFF) / scaleColors;
      }

      return nextInput;
  }

  return super.getNextInput();
    }

    @Override
    protected BufferedImage getNextImage() {
  BufferedImage image = new BufferedImage(32, 32, BufferedImage.TYPE_3BYTE_BGR);
  byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();

  for (int i = 0; i < 1024; i++) {
      pixels[i * 3] = nextInputRaw[1024 * 2 + i];
      pixels[i * 3 + 1] = nextInputRaw[1024 + i];
      pixels[i * 3 + 2] = nextInputRaw[i];
  }

  return image;
    }

    @Override
    public void beforeSample() {
  if (elementsOrder.size() == 0) {
      resetOrder();
  }

  int currentEl = elementsOrder.remove(getProperties().getUseRandomOrder() ? getProperties().getRandom().nextInt(elementsOrder.size()) : 0);
  int id = currentEl % (getInputSize() / files.length);

  RandomAccessFile f = files[currentEl / (getInputSize() / files.length)];

  try {
      f.seek(id * (3072 + labelSize));
      if (labelSize > 1) {
    f.readUnsignedByte();
      }

      Util.fillArray(nextTarget, 0);
      nextTarget[f.readUnsignedByte()] = 1;

      f.readFully(nextInputRaw);
  } catch (IOException e) {
      e.printStackTrace();
  }
    }

    @Override
    public void reset() {
  super.reset();
  resetOrder();
    }

    public void resetOrder() {
  elementsOrder = new ArrayList<Integer>(getInputSize());
  for (int i = 0; i < getInputSize(); i++) {
      elementsOrder.add(i);
  }
    }

    public static class CIFAR10TrainingInputProvider extends CIFARInputProvider {

  private static final long serialVersionUID = 1L;

  /**
   * @param directory - the folder where the CIFAR files are located
   */
  public CIFAR10TrainingInputProvider(String directory) {
      super();

      this.labelSize = 1;
      this.inputSize = 50000;
      this.nextTarget = new float[10];
      this.files = new RandomAccessFile[5];

      try {
    if (!directory.endsWith(File.separator)) {
        directory += File.separator;
    }

    files[0] = new RandomAccessFile(directory + "data_batch_1.bin", "r");
    files[1] = new RandomAccessFile(directory + "data_batch_2.bin", "r");
    files[2] = new RandomAccessFile(directory + "data_batch_3.bin", "r");
    files[3] = new RandomAccessFile(directory + "data_batch_4.bin", "r");
    files[4] = new RandomAccessFile(directory + "data_batch_5.bin", "r");
      } catch (FileNotFoundException e) {
    e.printStackTrace();
      }
  }
    }

    public static class CIFAR10TestingInputProvider extends CIFARInputProvider {

  private static final long serialVersionUID = 1L;

  /**
   * @param directory - the folder where the CIFAR files are located
   */
  public CIFAR10TestingInputProvider(String directory) {
      super();

      this.labelSize = 1;
      this.inputSize = 10000;
      this.nextTarget = new float[10];
      this.files = new RandomAccessFile[1];

      try {
    if (!directory.endsWith(File.separator)) {
        directory += File.separator;
    }

    files[0] = new RandomAccessFile(directory + "test_batch.bin", "r");
      } catch (FileNotFoundException e) {
    e.printStackTrace();
      }
  }
    }

    public static class CIFAR100TrainingInputProvider extends CIFARInputProvider {

  private static final long serialVersionUID = 1L;

  /**
   * @param directory - the folder where the CIFAR files are located
   */
  public CIFAR100TrainingInputProvider(String directory) {
      super();

      this.labelSize = 2;
      this.inputSize = 50000;
      this.nextTarget = new float[100];
      this.files = new RandomAccessFile[5];

      try {
    if (!directory.endsWith(File.separator)) {
        directory += File.separator;
    }

    files[0] = new RandomAccessFile(directory + "data_batch_1.bin", "r");
    files[1] = new RandomAccessFile(directory + "data_batch_2.bin", "r");
    files[2] = new RandomAccessFile(directory + "data_batch_3.bin", "r");
    files[3] = new RandomAccessFile(directory + "data_batch_4.bin", "r");
    files[4] = new RandomAccessFile(directory + "data_batch_5.bin", "r");
      } catch (FileNotFoundException e) {
    e.printStackTrace();
      }
  }
    }

    public static class CIFAR100TestingInputProvider extends CIFARInputProvider {
 
  private static final long serialVersionUID = 1L;

  /**
   * @param directory - the folder where the CIFAR files are located
   */
  public CIFAR100TestingInputProvider(String directory) {
      super();

      this.labelSize = 2;
      this.inputSize = 10000;
      this.nextTarget = new float[100];
      this.files = new RandomAccessFile[1];

      try {
    if (!directory.endsWith(File.separator)) {
        directory += File.separator;
    }

    files[0] = new RandomAccessFile(directory + "test_batch.bin", "r");
      } catch (FileNotFoundException e) {
    e.printStackTrace();
      }
  }
    }
}
TOP

Related Classes of com.github.neuralnetworks.samples.cifar.CIFARInputProvider$CIFAR100TestingInputProvider

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.