Package com.github.neuralnetworks.calculation

Source Code of com.github.neuralnetworks.calculation.RBMLayerCalculator

package com.github.neuralnetworks.calculation;

import java.util.HashSet;
import java.util.Set;

import com.github.neuralnetworks.architecture.Layer;
import com.github.neuralnetworks.architecture.types.RBM;
import com.github.neuralnetworks.util.Matrix;

/**
* Implementation of LayerCalculatorImpl for RBMs
* Contains some helper methods like calculateVisibleLayer and calculateHiddenLayer and also takes gibbs sampling into account
*/
public class RBMLayerCalculator extends LayerCalculatorImpl {

    private static final long serialVersionUID = -7524966192939615856L;

    private Set<Layer> calculatedLayers;
    private ValuesProvider results;
    private ValuesProvider intermediateResults;

    public RBMLayerCalculator() {
  super();
  calculatedLayers = new HashSet<>();
  results = new ValuesProvider();
  intermediateResults = new ValuesProvider();
    }

    public void gibbsSampling(RBM rbm, Matrix posPhaseVisible, Matrix posPhaseHidden, Matrix negPhaseVisible, Matrix negPhaseHidden, int samplingCount, boolean resetNetwork, boolean useIntermediateResults) {
  Matrix hidden, visible;
  calculateHiddenLayer(rbm, posPhaseVisible, posPhaseHidden);

  if (resetNetwork) {
      System.arraycopy(posPhaseHidden.getElements(), 0, negPhaseHidden.getElements(), 0, negPhaseHidden.getElements().length);
  }

  // Gibbs sampling
  for (int i = 1; i <= samplingCount; i++) {
      hidden = getLayerResult(rbm.getHiddenLayer(), negPhaseHidden, useIntermediateResults);
      calculateVisibleLayer(rbm, negPhaseVisible, hidden);

      visible = getLayerResult(rbm.getVisibleLayer(), negPhaseVisible, useIntermediateResults);
      calculateHiddenLayer(rbm, visible, negPhaseHidden);
  }
    }

    public void calculateVisibleLayer(RBM rbm, Matrix visibleLayerResults, Matrix hiddenLayerResults) {
  Layer visibleLayer = rbm.getVisibleLayer();
  Layer hiddenLayer = rbm.getHiddenLayer();

  calculatedLayers.clear();
  calculatedLayers.add(hiddenLayer);

  results.addValues(visibleLayer, visibleLayerResults);
  results.addValues(hiddenLayer, hiddenLayerResults);

  super.calculate(rbm, visibleLayer, calculatedLayers, results);
    }

    public void calculateHiddenLayer(RBM rbm, Matrix visibleLayerResults, Matrix hiddenLayerResults) {
  Layer visibleLayer = rbm.getVisibleLayer();
  Layer hiddenLayer = rbm.getHiddenLayer();

  calculatedLayers.clear();
  calculatedLayers.add(visibleLayer);

  results.addValues(visibleLayer, visibleLayerResults);
  results.addValues(hiddenLayer, hiddenLayerResults);

  super.calculate(rbm, hiddenLayer, calculatedLayers, results);
    }

    private Matrix getLayerResult(Layer layer, Matrix realResult, boolean useIntermediateResults) {
  Matrix result = realResult;
  if (useIntermediateResults) {
      intermediateResults.setColumns(realResult.getColumns());
      result = intermediateResults.getValues(layer, realResult.getRows());
      if (result == null) {
    intermediateResults.addValues(layer, result = new Matrix(realResult));
      }

      System.arraycopy(realResult.getElements(), 0, result.getElements(), 0, result.getElements().length);
  }

  return result;
    }
}
TOP

Related Classes of com.github.neuralnetworks.calculation.RBMLayerCalculator

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.