Package com.github.neuralnetworks.training.backpropagation

Source Code of com.github.neuralnetworks.training.backpropagation.AparapiBackpropagationConv2D

package com.github.neuralnetworks.training.backpropagation;

import java.util.List;

import com.github.neuralnetworks.architecture.Connections;
import com.github.neuralnetworks.architecture.Conv2DConnection;
import com.github.neuralnetworks.architecture.Layer;
import com.github.neuralnetworks.calculation.memory.ValuesProvider;
import com.github.neuralnetworks.calculation.neuronfunctions.AparapiConv2D;
import com.github.neuralnetworks.tensor.Tensor;
import com.github.neuralnetworks.tensor.TensorFactory;
import com.github.neuralnetworks.util.Util;

/**
* BackPropagation base function for convolutional layers
*/
public class AparapiBackpropagationConv2D extends AparapiConv2D implements BackPropagationConnectionCalculator {

    private static final long serialVersionUID = -345286029645674230L;

    /**
     * Activation of the output layer from the feedforward phase
     */
    protected float[] ffActivation;
    protected final int activationStartIndex;
    protected final int activationFeatureMapRowsDistance;
    protected final int activationFeatureMapColumnsDistance;

    /**
     * weight updates and momentum
     */
    protected final Tensor weightUpdatesTensor;
    protected final float[] weightUpdates;
    protected final float[] weightUpdatesMomentum;

    /**
     * BP parameters
     */
    protected float learningRate;
    protected float momentum;
    protected float l1weightDecay;
    protected float l2weightDecay;

    /**
     * activations from the feedforward phase
     */
    protected ValuesProvider activations;

    public AparapiBackpropagationConv2D(Conv2DConnection c, ValuesProvider valuesProvider, ValuesProvider activations, Tensor weightUpdates, Layer targetLayer) {
  super(c, valuesProvider, targetLayer);

  if (c.getWeights().getSize() != weightUpdates.getSize()) {
      throw new IllegalArgumentException("weights and weightUpdates must have the same size");
  }

  Tensor t = TensorFactory.tensor(targetLayer, c, activations);
  this.ffActivation = t.getElements();
  this.activationStartIndex = t.getStartIndex();
  this.activationFeatureMapRowsDistance = t.getDimensionElementsDistance(1);
  this.activationFeatureMapColumnsDistance = t.getDimensionElementsDistance(2);

  this.weightUpdatesTensor = weightUpdates;
  this.weightUpdates = weightUpdates.getElements();
  this.weightUpdatesMomentum = new float[weightUpdates.getSize()];
    }

    @Override
    public void calculate(List<Connections> connections, ValuesProvider valuesProvider, Layer targetLayer) {
  Conv2DConnection c = null;

  for (Connections con : connections) {
      if (con instanceof Conv2DConnection) {
    c = (Conv2DConnection) con;
      }
  }

  if (c != null) {
      weightUpdatesTensor.forEach(i -> weightUpdates[i] = 0);

      // currently works only as a feedforward (including bp)
      if (targetLayer == c.getOutputLayer()) {
    super.calculate(c, valuesProvider, targetLayer);
      } else {
    super.calculate(c, valuesProvider, Util.getOppositeLayer(c, targetLayer));
      }

      updateWeights();
  }
    }

    @Override
    protected void conv(int weightsStartId, int inputStartId, int outputStartId) {
  float activationDerivative = 0;
  int activationStartId = activationStartIndex + ((getGlobalId() % outputFeatureMapLength) / outputColumns) * activationFeatureMapRowsDistance * stride + (getGlobalId() % outputColumns) * activationFeatureMapColumnsDistance * stride;

  for (int i = 0; i < miniBatchSize; i++) {
      activationDerivative = activationFunctionDerivative(output[outputStartId + i * outputMiniBatchDistance]);
      output[outputStartId + i * outputMiniBatchDistance] = activationDerivative;

      for (int j = 0; j < featureMapWeights; j++) {
    weightUpdates[weightsStartId + j] += activationDerivative * ffActivation[activationStartId + featureMapOffsets[i * featureMapWeights + j]];
    input[inputStartId + featureMapOffsets[i * featureMapWeights + j]] += activationDerivative * weights[weightsStartId + j];
      }
  }
    }

    /**
     * Weight updates after the backpropagation
     */
    protected void updateWeights() {
  float weightUpdate = 0;
  for (int i = weightsStartIndex, j = 0, size = weightUpdatesTensor.getSize(); j < size; j++, i++) {
      weightUpdate = learningRate * weightUpdates[i] + momentum * weightUpdatesMomentum[j] - l1weightDecay * Math.abs(weights[i]) - l2weightDecay * weights[i] * weights[i] / 2;
      weights[i] += weightUpdate;
      weightUpdatesMomentum[j] = weightUpdates[i];
      weightUpdates[i] = weightUpdate;
  }
    }

    /**
     * Derivative of the FF activation function
     *
     * @param value
     * @return
     */
    protected float activationFunctionDerivative(float value) {
  return value;
    }

    @Override
    public float getLearningRate() {
        return learningRate;
    }

    @Override
    public void setLearningRate(float learningRate) {
        this.learningRate = learningRate;
    }

    @Override
    public float getMomentum() {
        return momentum;
    }

    @Override
    public void setMomentum(float momentum) {
        this.momentum = momentum;
    }

    @Override
    public float getL1weightDecay() {
        return l1weightDecay;
    }

    @Override
    public void setL1weightDecay(float weightDecay) {
  this.l1weightDecay = weightDecay;
    }
   
    @Override
    public float getL2weightDecay() {
  return l2weightDecay;
    }
   
    @Override
    public void setL2weightDecay(float weightDecay) {
  this.l2weightDecay = weightDecay;
    }

    @Override
    public ValuesProvider getActivations() {
        return activations;
    }

    @Override
    public void setActivations(ValuesProvider activations) {
        this.activations = activations;
    }
}
TOP

Related Classes of com.github.neuralnetworks.training.backpropagation.AparapiBackpropagationConv2D

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.