Package com.github.neuralnetworks.training.rbm

Source Code of com.github.neuralnetworks.training.rbm.AparapiCDTrainer

package com.github.neuralnetworks.training.rbm;

import com.github.neuralnetworks.architecture.types.RBM;
import com.github.neuralnetworks.calculation.RBMLayerCalculator;
import com.github.neuralnetworks.util.Constants;
import com.github.neuralnetworks.util.Environment;
import com.github.neuralnetworks.util.Properties;

/**
* Base class for Aparapi Contrastive Divergence
* Supports learning rate, momentum and weight decay
*/
public class AparapiCDTrainer extends CDTrainerBase {

    private static final long serialVersionUID = 1L;

    /**
     * weights update kernel for the connections between the visible and the hidden layer
     */
    private CDWeightUpdatesKernel weightUpdatesKernel;

    /**
     * weights update kernel for visible bias connections
     */
    private CDBiasUpdatesKernel visibleBiasUpdatesKernel;

    /**
     * weights update kernel for the hidden bias connections
     */
    private CDBiasUpdatesKernel hiddenBiasUpdatesKernel;

    public AparapiCDTrainer(Properties properties) {
  super(properties);
    }

    /* (non-Javadoc)
     * @see com.github.neuralnetworks.training.rbm.CDTrainerBase#updateWeights(com.github.neuralnetworks.architecture.Matrix, com.github.neuralnetworks.architecture.Matrix, com.github.neuralnetworks.architecture.Matrix, com.github.neuralnetworks.architecture.Matrix)
     * before each update the kernel update parameters are refreshed
     */
    @Override
    protected void updateWeights() {
  RBM rbm = getNeuralNetwork();

  RBMLayerCalculator lc = getLayerCalculator();
  int mbs = lc.getPositivePhaseVisible().getDimensions()[lc.getPositivePhaseVisible().getDimensions().length - 1];

  if (weightUpdatesKernel == null || weightUpdatesKernel.getMiniBatchSize() != mbs) {
      weightUpdatesKernel = new CDWeightUpdatesKernel(lc.getPositivePhaseVisible(), lc.getPositivePhaseHidden(), lc.getNegativePhaseVisible(), lc.getNegativePhaseHidden(), rbm.getMainConnections().getWeights(), getLearningRate(), getMomentum(), getl1weightDecay(), getl2weightDecay());
  }
  Environment.getInstance().getExecutionStrategy().execute(weightUpdatesKernel, rbm.getMainConnections().getWeights().getRows());

  // update visible bias
  if (rbm.getVisibleBiasConnections() != null) {
      if (visibleBiasUpdatesKernel == null || visibleBiasUpdatesKernel.getMiniBatchSize() != mbs) {
    visibleBiasUpdatesKernel = new CDBiasUpdatesKernel(rbm.getVisibleBiasConnections().getWeights(), lc.getPositivePhaseVisible(), lc.getNegativePhaseVisible(), getLearningRate(), getMomentum());
      }

      Environment.getInstance().getExecutionStrategy().execute(visibleBiasUpdatesKernel, rbm.getVisibleBiasConnections().getWeights().getSize());
  }

  // update hidden bias
  if (rbm.getHiddenBiasConnections() != null) {
      if (hiddenBiasUpdatesKernel == null || hiddenBiasUpdatesKernel.getMiniBatchSize() != mbs) {
    hiddenBiasUpdatesKernel = new CDBiasUpdatesKernel(rbm.getHiddenBiasConnections().getWeights(), lc.getPositivePhaseHidden(), lc.getNegativePhaseHidden(), getLearningRate(), getMomentum());
      }

      Environment.getInstance().getExecutionStrategy().execute(hiddenBiasUpdatesKernel, rbm.getHiddenBiasConnections().getWeights().getSize());
  }
    }

    protected float getLearningRate() {
  return properties.getParameter(Constants.LEARNING_RATE);
    }

    protected float getMomentum() {
  return (float) (properties.getParameter(Constants.MOMENTUM) != null ? properties.getParameter(Constants.MOMENTUM) : 0f);
    }

    protected float getl1weightDecay() {
  return (float) (properties.getParameter(Constants.L1_WEIGHT_DECAY) != null ? properties.getParameter(Constants.L1_WEIGHT_DECAY) : 0f);
    }

    protected float getl2weightDecay() {
  return (float) (properties.getParameter(Constants.L2_WEIGHT_DECAY) != null ? properties.getParameter(Constants.L2_WEIGHT_DECAY) : 0f);
    }
}
TOP

Related Classes of com.github.neuralnetworks.training.rbm.AparapiCDTrainer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.