Package org.encog.engine.opencl.kernels

Source Code of org.encog.engine.opencl.kernels.KernelNetworkCalc

/*
* Encog(tm) Core v2.5 - Java Version
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2010 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*  
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/

package org.encog.engine.opencl.kernels;

import java.util.HashMap;
import java.util.Map;

import org.encog.engine.data.BasicEngineData;
import org.encog.engine.data.EngineData;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.opencl.EncogCLDevice;
import org.encog.engine.opencl.EncogCLQueue;
import org.encog.engine.opencl.exceptions.OpenCLError;
import org.encog.engine.opencl.exceptions.OutOfOpenCLResources;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ErrorCalculation;
import org.encog.engine.util.ResourceLoader;
import org.jocl.CLException;
import org.jocl.cl_mem;

/**
* An OpenCL kernel that is designed to calculate the output of a neural
* network.
*/
public class KernelNetworkCalc extends EncogKernel {

  /**
   * The input count.
   */
  public static final int PARRAY_INPUT_COUNT = 0;

  /**
   * The output count.
   */
  public static final int PARRAY_OUTPUT_COUNT = 1;

  /**
   * The layer count.
   */
  public static final int PARRAY_LAYER_COUNT = 2;

  /**
   * Are we learning? 0=no, 1 =yes.
   */
  public static final int PARRAY_LEARN = 3;

  /**
   * What is the starting index to train at.
   */
  public static final int PARRAY_START = 4;

  /**
   * Items to train per call.
   */
  public static final int PARRAY_ITEMS_PER = 5;

  /**
   * Items to train per call.
   */
  public static final int PARRAY_ITERATIONS = 6;

  /**
   * A buffer to communicate weights to the kernel.
   */
  private cl_mem weightInArrayBuffer;

  /**
   * A buffer to hold the layer index.
   */
  private cl_mem layerIndexBuffer;

  /**
   * A buffer to hold the layer counts.
   */
  private cl_mem layerCountBuffer;

  /**
   * A buffer to hold the layer feed counts.
   */
  private cl_mem layerFeedCountBuffer;

  /**
   * A buffer to hold the weight indexes.
   */
  private cl_mem weightIndexBuffer;

  /**
   * The weight and bias array for the network.
   */
  private float[] weightInArray;

  /**
   * An array to hold the input to the neural network.
   */
  private float[] inputArray;

  /**
   * An array to hold the ideal values expected from the network.
   */
  private float[] idealArray;

  /**
   * The input buffer.
   */
  private cl_mem inputBuffer;

  /**
   * The layer output buffer.
   */
  private cl_mem layerOutputBuffer;

  /**
   * The ideal buffer.
   */
  private cl_mem idealBuffer;

  /**
   * The layer output.
   */
  private float[] layerOutput;

  /**
   * Holds parameters passed to the kernel.
   */
  private int[] paramArray;

  /**
   * A buffer to hold the parameters.
   */
  private cl_mem paramBuffer;

  /**
   * A buffer to hold the errors.
   */
  private cl_mem errorBuffer;

  /**
   * The network to train.
   */
  private FlatNetwork flat;

  /**
   * The training errors for this workload.
   */
  private float[] errors;

  /**
   * The training data to use.
   */
  private EngineIndexableSet training;

  /**
   * The device to train with.
   */
  private final EncogCLDevice device;

  /**
   * The length of the training data.
   */
  private int trainingLength;

  /**
   * Construct a kernel to train the network.
   *
   * @param device
   *            The OpenCL device to use.
   * @param flat
   *            The network to train.
   * @param training
   *            The training data.
   * @param tempDataSize
   *            How much temp data.
   */
  public KernelNetworkCalc(final EncogCLDevice device) {
    super(device, "org/encog/engine/resources/KernelNetCalc.txt",
        "NetworkCalc");

    this.device = device;
    this.paramArray = new int[10];
    this.paramBuffer = createArrayReadOnly(this.paramArray);

  }

  /**
   * Calculate one iteration over the specified range.
   *
   * @param start
   *            The starting position to calculate for.
   * @param size
   *            The ending position to calculate for.
   * @param iterations
   *            The number of iterations to execute.
   * @param learn
   *            True, if we should learn.
   */
  public void calculate(final int start, final int size) {
    prepareKernel();

    this.paramArray[KernelNetworkCalc.PARRAY_START] = start;
    this.paramArray[KernelNetworkCalc.PARRAY_ITEMS_PER] = size;
    this.setGlobalWork(size);
    this.setLocalWork(64);

    EngineArray.arrayCopy(this.flat.getWeights(), this.weightInArray);

    setArg(0, this.paramBuffer);
    setArg(1, this.errorBuffer);
    setArg(2, this.layerIndexBuffer);
    setArg(3, this.layerCountBuffer);
    setArg(4, this.layerFeedCountBuffer);
    setArg(5, this.weightIndexBuffer);
    setArg(6, this.inputBuffer);
    setArg(7, this.idealBuffer);
    setArg(8, this.weightInArrayBuffer);
    setArg(9, this.layerOutputBuffer);

    try {
      final EncogCLQueue queue = this.device.getQueue();

      this.paramArray[4] = start;

      queue.array2Buffer(this.weightInArray, this.weightInArrayBuffer);
      queue.array2Buffer(this.paramArray, this.paramBuffer);

      // Execute the kernel
      queue.execute(this);
      queue.waitFinish();

      // Read the results
      queue.buffer2Array(this.errorBuffer, this.errors);
      queue.buffer2Array(this.layerOutputBuffer, this.layerOutput );

    } catch (final CLException e) {
      if (e.getMessage().equals("CL_OUT_OF_RESOURCES")) {
        throw new OutOfOpenCLResources(e);
      } else {
        throw new OpenCLError(e);
      }
    } catch (final Exception e) {
      throw new OpenCLError(e);
    }
  }

  /**
   * Compile the kernel.
   *
   * @param options
   *            The options.
   * @param profile
   *            The OpenCL training profile.
   * @param network
   *            The network to compile for.
   */
  public void compile(final FlatNetwork network) {

    final ActivationFunction activation = network.getActivationFunctions()[0];
    final StringBuilder source = new StringBuilder();

    source.append("#define ACTIVATION(x,slope)");
    source.append(activation.getOpenCLExpression(false));
    source.append("\r\n");

    source.append(ResourceLoader.loadString(getSourceName()));
    setCLSource(source.toString());
   
    final Map<String, String> options = new HashMap<String, String>();
    options.put("NEURON_COUNT", "" + network.getNeuronCount());
    options.put("WEIGHT_COUNT", "" + network.getWeights().length);

    compile(options);
  }

  /**
   * @return the errors
   */
  public float[] getErrors() {
    return this.errors;
  }

  /**
   * Release the kernel and all buffers.
   */
  @Override
  public void release() {
    super.release();

    if (this.errorBuffer != null) {
      releaseBuffer(this.errorBuffer);
      this.errorBuffer = null;
    }

    if (this.idealBuffer != null) {
      releaseBuffer(this.idealBuffer);
      this.idealBuffer = null;
    }

    if (this.inputBuffer != null) {
      releaseBuffer(this.inputBuffer);
      this.inputBuffer = null;
    }

    if (this.layerCountBuffer != null) {
      releaseBuffer(this.layerCountBuffer);
      this.layerCountBuffer = null;
    }

    if (this.layerFeedCountBuffer != null) {
      releaseBuffer(this.layerFeedCountBuffer);
      this.layerFeedCountBuffer = null;
    }

    if (this.layerIndexBuffer != null) {
      releaseBuffer(this.layerIndexBuffer);
      this.layerIndexBuffer = null;
    }

    if (this.paramBuffer != null) {
      releaseBuffer(this.paramBuffer);
      this.paramBuffer = null;
    }

    if (this.weightInArrayBuffer != null) {
      releaseBuffer(this.weightInArrayBuffer);
      this.weightInArrayBuffer = null;
    }

    if (this.weightIndexBuffer != null) {
      releaseBuffer(this.weightIndexBuffer);
      this.weightIndexBuffer = null;
    }
  }

  public FlatNetwork getFlat() {
    return flat;
  }

  public void setFlat(FlatNetwork flat) {
    this.flat = flat;

    this.weightInArray = new float[flat.getWeights().length];

    final int inputSize = flat.getInputCount();
    final int idealSize = flat.getOutputCount();

    this.paramArray[0] = this.flat.getInputCount();
    this.paramArray[1] = this.flat.getOutputCount();
    this.paramArray[2] = this.flat.getLayerCounts().length;

    if (this.layerCountBuffer != null) {
      releaseBuffer(this.layerCountBuffer);
      this.layerCountBuffer = null;
    }

    if (this.layerFeedCountBuffer != null) {
      releaseBuffer(this.layerFeedCountBuffer);
      this.layerFeedCountBuffer = null;
    }

    if (this.layerIndexBuffer != null) {
      releaseBuffer(this.layerIndexBuffer);
      this.layerIndexBuffer = null;
    }

    if (this.weightInArrayBuffer != null) {
      releaseBuffer(this.weightInArrayBuffer);
      this.weightInArrayBuffer = null;
    }

    if (this.weightIndexBuffer != null) {
      releaseBuffer(this.weightIndexBuffer);
      this.weightIndexBuffer = null;
    }

    this.layerIndexBuffer = createArrayReadOnly(this.flat.getLayerIndex());
    this.layerCountBuffer = createArrayReadOnly(this.flat.getLayerCounts());
    this.layerFeedCountBuffer = createArrayReadOnly(this.flat
        .getLayerFeedCounts());
    this.weightInArrayBuffer = createArrayReadOnly(this.weightInArray);
    this.weightIndexBuffer = createArrayReadOnly(this.flat.getWeightIndex());
    allocateCommon();
    compile(flat);
  }

  private void allocateCommon() {
    if (this.training != null && this.flat != null) {

      if (this.layerOutputBuffer != null) {
        releaseBuffer(this.layerOutputBuffer);
        this.layerOutputBuffer = null;
      }

      this.layerOutput = new float[this.flat.getLayerOutput().length*this.trainingLength];
      this.layerOutputBuffer = this
          .createFloatArrayWriteOnly(this.layerOutput.length);
    }
  }

  public EngineIndexableSet getTraining() {
    return training;
  }

  public void setTraining(EngineIndexableSet training) {
    this.training = training;
    this.trainingLength = (int) this.training.getRecordCount();

    final EngineData pair = BasicEngineData.createPair(
        flat.getInputCount(), flat.getOutputCount());

    this.inputArray = new float[training.getInputSize() * this.trainingLength];
    this.idealArray = new float[training.getIdealSize() * this.trainingLength];
   
    int inputIndex = 0;
    int idealIndex = 0;

    for (int i = 0; i < this.trainingLength; i++) {
      training.getRecord(i, pair);
      for (int col = 0; col < flat.getInputCount(); col++) {
        this.inputArray[inputIndex++] = (float) pair.getInputArray()[col];
      }

      for (int col = 0; col < flat.getOutputCount(); col++) {
        this.idealArray[idealIndex++] = (float) pair.getIdealArray()[col];
      }
    }

    final int errorSize = (int) training.getRecordCount();
    this.errors = new float[errorSize];

    if (this.errorBuffer != null) {
      releaseBuffer(this.errorBuffer);
      this.errorBuffer = null;
    }

    if (this.idealBuffer != null) {
      releaseBuffer(this.idealBuffer);
      this.idealBuffer = null;
    }

    if (this.inputBuffer != null) {
      releaseBuffer(this.inputBuffer);
      this.inputBuffer = null;
    }

    this.errorBuffer = createFloatArrayWriteOnly(errorSize);
    this.inputBuffer = createArrayReadOnly(this.inputArray);
    this.idealBuffer = createArrayReadOnly(this.idealArray);

    allocateCommon();
  }

  /**
   * @return The error from the last evaluation.
   */
  public double getError() {
    ErrorCalculation ec = new ErrorCalculation();
    double result = 0;
    for (int i = 0; i < this.errors.length; i++) {
      result += this.errors[i];
    }
    return result/(this.errors.length*this.flat.getOutputCount());
  }

}
 
TOP

Related Classes of org.encog.engine.opencl.kernels.KernelNetworkCalc

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.