/*
* Encog(tm) Core v2.5 - Java Version
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2010 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.engine.opencl.kernels;
import java.util.HashMap;
import java.util.Map;
import org.encog.engine.data.BasicEngineData;
import org.encog.engine.data.EngineData;
import org.encog.engine.data.EngineIndexableSet;
import org.encog.engine.network.activation.ActivationFunction;
import org.encog.engine.network.flat.FlatNetwork;
import org.encog.engine.opencl.EncogCLDevice;
import org.encog.engine.opencl.EncogCLQueue;
import org.encog.engine.opencl.exceptions.OpenCLError;
import org.encog.engine.opencl.exceptions.OutOfOpenCLResources;
import org.encog.engine.util.EngineArray;
import org.encog.engine.util.ErrorCalculation;
import org.encog.engine.util.ResourceLoader;
import org.jocl.CLException;
import org.jocl.cl_mem;
/**
* An OpenCL kernel that is designed to calculate the output of a neural
* network.
*/
public class KernelNetworkCalc extends EncogKernel {
/**
* The input count.
*/
public static final int PARRAY_INPUT_COUNT = 0;
/**
* The output count.
*/
public static final int PARRAY_OUTPUT_COUNT = 1;
/**
* The layer count.
*/
public static final int PARRAY_LAYER_COUNT = 2;
/**
* Are we learning? 0=no, 1 =yes.
*/
public static final int PARRAY_LEARN = 3;
/**
* What is the starting index to train at.
*/
public static final int PARRAY_START = 4;
/**
* Items to train per call.
*/
public static final int PARRAY_ITEMS_PER = 5;
/**
* Items to train per call.
*/
public static final int PARRAY_ITERATIONS = 6;
/**
* A buffer to communicate weights to the kernel.
*/
private cl_mem weightInArrayBuffer;
/**
* A buffer to hold the layer index.
*/
private cl_mem layerIndexBuffer;
/**
* A buffer to hold the layer counts.
*/
private cl_mem layerCountBuffer;
/**
* A buffer to hold the layer feed counts.
*/
private cl_mem layerFeedCountBuffer;
/**
* A buffer to hold the weight indexes.
*/
private cl_mem weightIndexBuffer;
/**
* The weight and bias array for the network.
*/
private float[] weightInArray;
/**
* An array to hold the input to the neural network.
*/
private float[] inputArray;
/**
* An array to hold the ideal values expected from the network.
*/
private float[] idealArray;
/**
* The input buffer.
*/
private cl_mem inputBuffer;
/**
* The layer output buffer.
*/
private cl_mem layerOutputBuffer;
/**
* The ideal buffer.
*/
private cl_mem idealBuffer;
/**
* The layer output.
*/
private float[] layerOutput;
/**
* Holds parameters passed to the kernel.
*/
private int[] paramArray;
/**
* A buffer to hold the parameters.
*/
private cl_mem paramBuffer;
/**
* A buffer to hold the errors.
*/
private cl_mem errorBuffer;
/**
* The network to train.
*/
private FlatNetwork flat;
/**
* The training errors for this workload.
*/
private float[] errors;
/**
* The training data to use.
*/
private EngineIndexableSet training;
/**
* The device to train with.
*/
private final EncogCLDevice device;
/**
* The length of the training data.
*/
private int trainingLength;
/**
* Construct a kernel to train the network.
*
* @param device
* The OpenCL device to use.
* @param flat
* The network to train.
* @param training
* The training data.
* @param tempDataSize
* How much temp data.
*/
public KernelNetworkCalc(final EncogCLDevice device) {
super(device, "org/encog/engine/resources/KernelNetCalc.txt",
"NetworkCalc");
this.device = device;
this.paramArray = new int[10];
this.paramBuffer = createArrayReadOnly(this.paramArray);
}
/**
* Calculate one iteration over the specified range.
*
* @param start
* The starting position to calculate for.
* @param size
* The ending position to calculate for.
* @param iterations
* The number of iterations to execute.
* @param learn
* True, if we should learn.
*/
public void calculate(final int start, final int size) {
prepareKernel();
this.paramArray[KernelNetworkCalc.PARRAY_START] = start;
this.paramArray[KernelNetworkCalc.PARRAY_ITEMS_PER] = size;
this.setGlobalWork(size);
this.setLocalWork(64);
EngineArray.arrayCopy(this.flat.getWeights(), this.weightInArray);
setArg(0, this.paramBuffer);
setArg(1, this.errorBuffer);
setArg(2, this.layerIndexBuffer);
setArg(3, this.layerCountBuffer);
setArg(4, this.layerFeedCountBuffer);
setArg(5, this.weightIndexBuffer);
setArg(6, this.inputBuffer);
setArg(7, this.idealBuffer);
setArg(8, this.weightInArrayBuffer);
setArg(9, this.layerOutputBuffer);
try {
final EncogCLQueue queue = this.device.getQueue();
this.paramArray[4] = start;
queue.array2Buffer(this.weightInArray, this.weightInArrayBuffer);
queue.array2Buffer(this.paramArray, this.paramBuffer);
// Execute the kernel
queue.execute(this);
queue.waitFinish();
// Read the results
queue.buffer2Array(this.errorBuffer, this.errors);
queue.buffer2Array(this.layerOutputBuffer, this.layerOutput );
} catch (final CLException e) {
if (e.getMessage().equals("CL_OUT_OF_RESOURCES")) {
throw new OutOfOpenCLResources(e);
} else {
throw new OpenCLError(e);
}
} catch (final Exception e) {
throw new OpenCLError(e);
}
}
/**
* Compile the kernel.
*
* @param options
* The options.
* @param profile
* The OpenCL training profile.
* @param network
* The network to compile for.
*/
public void compile(final FlatNetwork network) {
final ActivationFunction activation = network.getActivationFunctions()[0];
final StringBuilder source = new StringBuilder();
source.append("#define ACTIVATION(x,slope)");
source.append(activation.getOpenCLExpression(false));
source.append("\r\n");
source.append(ResourceLoader.loadString(getSourceName()));
setCLSource(source.toString());
final Map<String, String> options = new HashMap<String, String>();
options.put("NEURON_COUNT", "" + network.getNeuronCount());
options.put("WEIGHT_COUNT", "" + network.getWeights().length);
compile(options);
}
/**
* @return the errors
*/
public float[] getErrors() {
return this.errors;
}
/**
* Release the kernel and all buffers.
*/
@Override
public void release() {
super.release();
if (this.errorBuffer != null) {
releaseBuffer(this.errorBuffer);
this.errorBuffer = null;
}
if (this.idealBuffer != null) {
releaseBuffer(this.idealBuffer);
this.idealBuffer = null;
}
if (this.inputBuffer != null) {
releaseBuffer(this.inputBuffer);
this.inputBuffer = null;
}
if (this.layerCountBuffer != null) {
releaseBuffer(this.layerCountBuffer);
this.layerCountBuffer = null;
}
if (this.layerFeedCountBuffer != null) {
releaseBuffer(this.layerFeedCountBuffer);
this.layerFeedCountBuffer = null;
}
if (this.layerIndexBuffer != null) {
releaseBuffer(this.layerIndexBuffer);
this.layerIndexBuffer = null;
}
if (this.paramBuffer != null) {
releaseBuffer(this.paramBuffer);
this.paramBuffer = null;
}
if (this.weightInArrayBuffer != null) {
releaseBuffer(this.weightInArrayBuffer);
this.weightInArrayBuffer = null;
}
if (this.weightIndexBuffer != null) {
releaseBuffer(this.weightIndexBuffer);
this.weightIndexBuffer = null;
}
}
public FlatNetwork getFlat() {
return flat;
}
public void setFlat(FlatNetwork flat) {
this.flat = flat;
this.weightInArray = new float[flat.getWeights().length];
final int inputSize = flat.getInputCount();
final int idealSize = flat.getOutputCount();
this.paramArray[0] = this.flat.getInputCount();
this.paramArray[1] = this.flat.getOutputCount();
this.paramArray[2] = this.flat.getLayerCounts().length;
if (this.layerCountBuffer != null) {
releaseBuffer(this.layerCountBuffer);
this.layerCountBuffer = null;
}
if (this.layerFeedCountBuffer != null) {
releaseBuffer(this.layerFeedCountBuffer);
this.layerFeedCountBuffer = null;
}
if (this.layerIndexBuffer != null) {
releaseBuffer(this.layerIndexBuffer);
this.layerIndexBuffer = null;
}
if (this.weightInArrayBuffer != null) {
releaseBuffer(this.weightInArrayBuffer);
this.weightInArrayBuffer = null;
}
if (this.weightIndexBuffer != null) {
releaseBuffer(this.weightIndexBuffer);
this.weightIndexBuffer = null;
}
this.layerIndexBuffer = createArrayReadOnly(this.flat.getLayerIndex());
this.layerCountBuffer = createArrayReadOnly(this.flat.getLayerCounts());
this.layerFeedCountBuffer = createArrayReadOnly(this.flat
.getLayerFeedCounts());
this.weightInArrayBuffer = createArrayReadOnly(this.weightInArray);
this.weightIndexBuffer = createArrayReadOnly(this.flat.getWeightIndex());
allocateCommon();
compile(flat);
}
private void allocateCommon() {
if (this.training != null && this.flat != null) {
if (this.layerOutputBuffer != null) {
releaseBuffer(this.layerOutputBuffer);
this.layerOutputBuffer = null;
}
this.layerOutput = new float[this.flat.getLayerOutput().length*this.trainingLength];
this.layerOutputBuffer = this
.createFloatArrayWriteOnly(this.layerOutput.length);
}
}
public EngineIndexableSet getTraining() {
return training;
}
public void setTraining(EngineIndexableSet training) {
this.training = training;
this.trainingLength = (int) this.training.getRecordCount();
final EngineData pair = BasicEngineData.createPair(
flat.getInputCount(), flat.getOutputCount());
this.inputArray = new float[training.getInputSize() * this.trainingLength];
this.idealArray = new float[training.getIdealSize() * this.trainingLength];
int inputIndex = 0;
int idealIndex = 0;
for (int i = 0; i < this.trainingLength; i++) {
training.getRecord(i, pair);
for (int col = 0; col < flat.getInputCount(); col++) {
this.inputArray[inputIndex++] = (float) pair.getInputArray()[col];
}
for (int col = 0; col < flat.getOutputCount(); col++) {
this.idealArray[idealIndex++] = (float) pair.getIdealArray()[col];
}
}
final int errorSize = (int) training.getRecordCount();
this.errors = new float[errorSize];
if (this.errorBuffer != null) {
releaseBuffer(this.errorBuffer);
this.errorBuffer = null;
}
if (this.idealBuffer != null) {
releaseBuffer(this.idealBuffer);
this.idealBuffer = null;
}
if (this.inputBuffer != null) {
releaseBuffer(this.inputBuffer);
this.inputBuffer = null;
}
this.errorBuffer = createFloatArrayWriteOnly(errorSize);
this.inputBuffer = createArrayReadOnly(this.inputArray);
this.idealBuffer = createArrayReadOnly(this.idealArray);
allocateCommon();
}
/**
* @return The error from the last evaluation.
*/
public double getError() {
ErrorCalculation ec = new ErrorCalculation();
double result = 0;
for (int i = 0; i < this.errors.length; i++) {
result += this.errors[i];
}
return result/(this.errors.length*this.flat.getOutputCount());
}
}