// Package tv.floe.metronome.deeplearning.neuralnetwork.dbn
//
// Source code of tv.floe.metronome.deeplearning.neuralnetwork.dbn.TestDeepBeliefNetwork

package tv.floe.metronome.deeplearning.neuralnetwork.dbn;

import static org.junit.Assert.*;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;

import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.junit.Test;


import tv.floe.metronome.deeplearning.dbn.DeepBeliefNetwork;
import tv.floe.metronome.deeplearning.math.transforms.MultiplyScalar;
import tv.floe.metronome.deeplearning.neuralnetwork.core.LogisticRegression;
import tv.floe.metronome.eval.Evaluation;
import tv.floe.metronome.math.MatrixUtils;
import tv.floe.metronome.types.Pair;


public class TestDeepBeliefNetwork {


  // Toy training inputs: six 6-dimensional binary rows; rows 0-2 activate
  // mostly left-hand features, rows 3-5 mostly right-hand features.
  double[][] x_toy = new double[][]
  {
      {1,1,1,0,0,0},
      {1,0,1,0,0,0},
      {1,1,1,0,0,0},
      {0,0,1,1,1,0},
      {0,0,1,1,0,0},
      {0,0,0,1,1,1}
     
  };
 
  // One-hot labels for x_toy: class 0 for rows 0-2, class 1 for rows 3-5.
  double[][] y_toy = new double[][]
  {
      {1, 0},
      {1, 0},
      {1, 0},
      {0, 1},
      {0, 1},
      {0, 1}
  };
 
  // Mahout matrix views over the raw toy arrays.
  Matrix x_toy_Matrix = new DenseMatrix(x_toy);
  Matrix y_toy_Matrix = new DenseMatrix(y_toy);

  // Held-out input rows (one resembling each class) for manual inspection.
  double[][] x_toy_Test = new double[][] {
      {1, 1, 1, 0, 0, 0},
      {0, 0, 0, 1, 1, 1}
      //{1, 1, 1, 1, 1, 0}
  };
 
  Matrix xTestMatrix = new DenseMatrix(x_toy_Test);
   
 
 
  public static Matrix gti(Matrix m, double val) {
   
    Matrix res = m.like();
   
    for ( int r = 0; r < m.numRows(); r++ ) {
     
      for ( int c = 0; c < m.numCols(); c++ ) {
       
        if (m.get(r, c) > val) {
          res.set(r, c, 1.0);
        } else {
          res.set(r, c, 0.0);
        }
       
      }
     
     
    }
   
    return res;
   
  }
 
  public static Matrix matrixRand( int rows, int cols ) {
   
    RandomGenerator rnd = new MersenneTwister(1234);
    //rnd.nextDouble();
   
    Matrix res = new DenseMatrix( rows, cols );
   
    for ( int r = 0; r < res.numRows(); r++ ) {
     
      for ( int c = 0; c < res.numCols(); c++ ) {
       
        res.set(r,  c, rnd.nextDouble());
        //System.out.println( "next: " + rnd.nextDouble() );
       
      }
     
    }
   
    return res;
   
  }

  // XOR truth-table inputs: all four 2-bit combinations.
  double[][] xor_input = new double[][]
  {
      {0,0},
      {0,1},
      {1,0},
      {1,1}
     
  };
 
  // One-hot XOR labels: {1,0} when the bits match, {0,1} when they differ.
  double[][] xor_labels = new double[][]
  {
      {1, 0},
      {0, 1},
      {0, 1},
      {1, 0}
  };
   
  // Mahout matrix views over the raw XOR arrays.
  Matrix x_xor_Matrix = new DenseMatrix(xor_input);
  Matrix y_xor_Matrix = new DenseMatrix(xor_labels);

 
 
  @Test
  public void testXor() {






    RandomGenerator rng = new MersenneTwister(123);

    double preTrainLr = 0.0001;
    int preTrainEpochs = 1000;
    int k = 1;
    int nIns = 2,nOuts = 2;
    int[] hiddenLayerSizes = new int[] { 2 };
    double fineTuneLr = 0.0001;
    int fineTuneEpochs = 1000;
/*
    DBN dbn = new DBN.Builder()
    .transformWeightsAt(0, new MultiplyScalar(1000))
    .transformWeightsAt(1, new MultiplyScalar(100))

    .hiddenLayerSizes(hiddenLayerSizes).numberOfInputs(nIns).renderWeights(0)
    .useRegularization(false).withMomentum(0).withDist(new NormalDistribution(0,0.001))
    .numberOfOutPuts(nOuts).withRng(rng).build();
*/
   
    DeepBeliefNetwork dbn = new DeepBeliefNetwork(nIns, hiddenLayerSizes, nOuts, hiddenLayerSizes.length, rng ); //, Matrix input, Matrix labels);
//    dbn.addWeightTransform(0, new MultiplyScalar(100));
//    dbn.addWeightTransform(1, new MultiplyScalar(10));

   
   
    dbn.preTrain(x_xor_Matrix,k, preTrainLr, preTrainEpochs);
   
   
   
    dbn.finetune(y_xor_Matrix,fineTuneLr, fineTuneEpochs);







    Matrix predict = dbn.predict( x_xor_Matrix );
    System.out.println("--- Predictions XOR ----");
    MatrixUtils.debug_print(predict);
    //log.info(predict.toString());

    Evaluation eval = new Evaluation();
    eval.eval( y_xor_Matrix, predict );
    //log.info(eval.stats());
    System.out.println( eval.stats() );

    System.out.println( "Total Correct: " + eval.correctScores() + " out of " + x_xor_Matrix.numRows() );
 
 
  }
 

 
  @Test
  public void testToyInput() {






    RandomGenerator rng = new MersenneTwister(123);

    double preTrainLr = 0.01;
    int preTrainEpochs = 1000;
    int k = 1;
   
    int nIns = x_toy_Matrix.columnSize();
    int nOuts = y_toy_Matrix.columnSize();
   
    int[] hiddenLayerSizes = new int[] { 10, 8 };
    double fineTuneLr = 0.01;
    int fineTuneEpochs = 1000;
   
    MatrixUtils.debug_print(x_toy_Matrix);
   
/*
    DBN dbn = new DBN.Builder()
    .transformWeightsAt(0, new MultiplyScalar(1000))
    .transformWeightsAt(1, new MultiplyScalar(100))

    .hiddenLayerSizes(hiddenLayerSizes).numberOfInputs(nIns).renderWeights(0)
    .useRegularization(false).withMomentum(0).withDist(new NormalDistribution(0,0.001))
    .numberOfOutPuts(nOuts).withRng(rng).build();
*/
   
    DeepBeliefNetwork dbn = new DeepBeliefNetwork(nIns, hiddenLayerSizes, nOuts, hiddenLayerSizes.length, rng ); //, Matrix input, Matrix labels);
//    dbn.addWeightTransform(0, new MultiplyScalar(100));
//    dbn.addWeightTransform(1, new MultiplyScalar(10));

   
   
    dbn.preTrain(x_toy_Matrix,k, preTrainLr, preTrainEpochs);
   
   
   
    dbn.finetune(y_toy_Matrix,fineTuneLr, fineTuneEpochs);







    Matrix predict = dbn.predict( x_toy_Matrix );
    System.out.println("--- Predictions Toy Matrix ----");
    MatrixUtils.debug_print(predict);
    //log.info(predict.toString());

    Evaluation eval = new Evaluation();
    eval.eval( y_toy_Matrix, predict );
    //log.info(eval.stats());
    System.out.println( eval.stats() );

    System.out.println( "Total Correct: " + eval.correctScores() + " out of " + x_toy_Matrix.numRows() );
 
 
 
 
 
  @Test
  public void testSerdeMechanics() throws FileNotFoundException {
   
    String tmpFilename = "/tmp/DBNSerdeTest.model";
   
    /*
    int n = 10;
    Pair<Matrix, Matrix> d = xorDataStatic(); //xorData(n);
    Matrix x = d.getFirst();
    Matrix y = d.getSecond();
*/




    RandomGenerator rng = new MersenneTwister(123);

    double preTrainLr = 0.001;
    int preTrainEpochs = 100;
    int k = 1;
    int nIns = 2,nOuts = 2;
    int[] hiddenLayerSizes = new int[] {2,2,2};
    double fineTuneLr = 0.001;
    int fineTuneEpochs = 100;
/*
    DBN dbn = new DBN.Builder()
    .transformWeightsAt(0, new MultiplyScalar(1000))
    .transformWeightsAt(1, new MultiplyScalar(100))

    .hiddenLayerSizes(hiddenLayerSizes).numberOfInputs(nIns).renderWeights(0)
    .useRegularization(false).withMomentum(0).withDist(new NormalDistribution(0,0.001))
    .numberOfOutPuts(nOuts).withRng(rng).build();
*/
   
    DeepBeliefNetwork dbn = new DeepBeliefNetwork(nIns, hiddenLayerSizes, nOuts, hiddenLayerSizes.length, rng ); //, Matrix input, Matrix labels);

   
   
    dbn.preTrain(x_xor_Matrix,k, preTrainLr, preTrainEpochs);
    dbn.finetune(y_xor_Matrix,fineTuneLr, fineTuneEpochs);

   
   
   
    // save / write the model
   
    FileOutputStream oFileOutStream = new FileOutputStream( tmpFilename, false);
    dbn.write( oFileOutStream );
   
    // read / load the model
    FileInputStream oFileInputStream = new FileInputStream( tmpFilename );
   
   
    DeepBeliefNetwork dbn_deserialize = new DeepBeliefNetwork(1, hiddenLayerSizes, 1, hiddenLayerSizes.length, rng ); //, Matrix input, Matrix labels);
    dbn_deserialize.load(oFileInputStream);
   
   
   
   
    assertEquals( dbn.inputNeuronCount, dbn_deserialize.inputNeuronCount );
    assertEquals( dbn.l2, dbn_deserialize.l2, 0.0 );

    // check hidden layers
    for ( int i = 0; i < dbn.hiddenLayers.length; i++ ) {
     
      assertEquals( true, MatrixUtils.elementwiseSame(dbn.hiddenLayers[i].biasTerms, dbn_deserialize.hiddenLayers[i].biasTerms ) );
      assertEquals( true, MatrixUtils.elementwiseSame(dbn.hiddenLayers[i].connectionWeights, dbn_deserialize.hiddenLayers[i].connectionWeights ) );

    }
   
    // check RBMs
    for ( int i = 0; i < dbn.preTrainingLayers.length; i++ ) {
     
      assertEquals( true, MatrixUtils.elementwiseSame(dbn.preTrainingLayers[i].getConnectionWeights(), dbn_deserialize.preTrainingLayers[i].getConnectionWeights() ) );
      assertEquals( true, MatrixUtils.elementwiseSame(dbn.preTrainingLayers[i].getHiddenBias(), dbn_deserialize.preTrainingLayers[i].getHiddenBias() ) );
      assertEquals( true, MatrixUtils.elementwiseSame(dbn.preTrainingLayers[i].getVisibleBias(), dbn_deserialize.preTrainingLayers[i].getVisibleBias() ) );

    }
   
  }   

 
  @Test
  public void testParameterAvgSerdeMechanics() throws FileNotFoundException {
   
    String tmpFilename = "/tmp/DBNParameterAvgSerdeTest.model";
   
   



    RandomGenerator rng = new MersenneTwister(123);

    double preTrainLr = 0.001;
    int preTrainEpochs = 100;
    int k = 1;
    int nIns = 2,nOuts = 2;
    int[] hiddenLayerSizes = new int[] {2,2,2};
    double fineTuneLr = 0.001;
    int fineTuneEpochs = 100;
/*
    DBN dbn = new DBN.Builder()
    .transformWeightsAt(0, new MultiplyScalar(1000))
    .transformWeightsAt(1, new MultiplyScalar(100))

    .hiddenLayerSizes(hiddenLayerSizes).numberOfInputs(nIns).renderWeights(0)
    .useRegularization(false).withMomentum(0).withDist(new NormalDistribution(0,0.001))
    .numberOfOutPuts(nOuts).withRng(rng).build();
*/
   
    DeepBeliefNetwork dbn = new DeepBeliefNetwork(nIns, hiddenLayerSizes, nOuts, hiddenLayerSizes.length, rng ); //, Matrix input, Matrix labels);

   
   
    dbn.preTrain( x_xor_Matrix,k, preTrainLr, preTrainEpochs );
    dbn.finetune( y_xor_Matrix,fineTuneLr, fineTuneEpochs );

   
   
   
    // save / write the model
   
    FileOutputStream oFileOutStream = new FileOutputStream( tmpFilename, false);
    //dbn.serializeParameters( oFileOutStream );
    dbn.write(oFileOutStream);
   
    // read / load the model
    FileInputStream oFileInputStream = new FileInputStream( tmpFilename );
   
   
    DeepBeliefNetwork dbn_deserialize = new DeepBeliefNetwork(1, hiddenLayerSizes, 1, hiddenLayerSizes.length, rng ); //, Matrix input, Matrix labels);
    //dbn_deserialize.loadParameterValues( oFileInputStream );
    dbn_deserialize.load(oFileInputStream);
   
    int[] hiddenLayerSizesTmp = new int[] {1};
   
    // now setup a DBN based on a clone operation via initBasedOn()
    DeepBeliefNetwork dbn_merge_load = new DeepBeliefNetwork(1, hiddenLayerSizesTmp, 1, hiddenLayerSizesTmp.length, null); //1, , 1, hiddenLayerSizes.length, rng );
    // this will clear the weights
    dbn_merge_load.initBasedOn(dbn_deserialize);
   
    DeepBeliefNetwork dbn_merge_load_2 = new DeepBeliefNetwork(1, hiddenLayerSizesTmp, 1, hiddenLayerSizesTmp.length, null); //1, , 1, hiddenLayerSizes.length, rng );
    // this will clear the weights
    dbn_merge_load_2.initBasedOn( dbn );
   
    //assertEquals( dbn.inputNeuronCount, dbn_deserialize.inputNeuronCount );
    //assertEquals( dbn.l2, dbn_deserialize.l2, 0.0 );

    // check logistic layer
   
    assertEquals( true, MatrixUtils.elementwiseSame(dbn.logisticRegressionLayer.connectionWeights, dbn_deserialize.logisticRegressionLayer.connectionWeights ) );
   
    // check sizes here
    //assertEquals( true, MatrixUtils.elementwiseSame(dbn.logisticRegressionLayer.connectionWeights, dbn_merge_load.logisticRegressionLayer.connectionWeights ) );
    assertEquals( dbn.logisticRegressionLayer.connectionWeights.numRows(), dbn_merge_load.logisticRegressionLayer.connectionWeights.numRows() );
    assertEquals( dbn.logisticRegressionLayer.connectionWeights.numCols(), dbn_merge_load.logisticRegressionLayer.connectionWeights.numCols() );
   
    assertEquals( true, MatrixUtils.elementwiseSame(dbn.logisticRegressionLayer.biasTerms, dbn_deserialize.logisticRegressionLayer.biasTerms ) );
   
    //check sizes
    //assertEquals( true, MatrixUtils.elementwiseSame(dbn.logisticRegressionLayer.biasTerms, dbn_merge_load.logisticRegressionLayer.biasTerms ) );
    assertEquals( dbn.logisticRegressionLayer.biasTerms.numRows(), dbn_merge_load.logisticRegressionLayer.biasTerms.numRows() );
    assertEquals( dbn.logisticRegressionLayer.biasTerms.numCols(), dbn_merge_load.logisticRegressionLayer.biasTerms.numCols() );
   

   
    // check hidden layers
    for ( int i = 0; i < dbn.hiddenLayers.length; i++ ) {
     
      assertEquals( true, MatrixUtils.elementwiseSame(dbn.hiddenLayers[i].biasTerms, dbn_deserialize.hiddenLayers[i].biasTerms ) );
      assertEquals( true, MatrixUtils.elementwiseSame(dbn.hiddenLayers[i].connectionWeights, dbn_deserialize.hiddenLayers[i].connectionWeights ) );

      assertEquals( dbn.hiddenLayers[i].biasTerms.numCols(), dbn_merge_load.hiddenLayers[i].biasTerms.numCols() );
      assertEquals( dbn.hiddenLayers[i].connectionWeights.numCols(), dbn_merge_load.hiddenLayers[i].connectionWeights.numCols() );
     
     
    }
   
    // check RBMs
    for ( int i = 0; i < dbn.preTrainingLayers.length; i++ ) {
     
      assertEquals( true, MatrixUtils.elementwiseSame(dbn.preTrainingLayers[i].getConnectionWeights(), dbn_deserialize.preTrainingLayers[i].getConnectionWeights() ) );
      assertEquals( true, MatrixUtils.elementwiseSame(dbn.preTrainingLayers[i].getHiddenBias(), dbn_deserialize.preTrainingLayers[i].getHiddenBias() ) );
      assertEquals( true, MatrixUtils.elementwiseSame(dbn.preTrainingLayers[i].getVisibleBias(), dbn_deserialize.preTrainingLayers[i].getVisibleBias() ) );

      assertEquals( dbn.preTrainingLayers[i].getConnectionWeights().numCols(), dbn_merge_load.preTrainingLayers[i].getConnectionWeights().numCols() );
     
      for ( int row = 0; row < dbn.preTrainingLayers[i].getConnectionWeights().numRows(); row++ ) {

        for ( int col = 0; col < dbn.preTrainingLayers[i].getConnectionWeights().numCols(); col++ ) {
       
          assertEquals( 0.0, dbn_merge_load.preTrainingLayers[i].getConnectionWeights().get( row, col), 0.0 );
         
        }
       
      }
     
      assertEquals( dbn.preTrainingLayers[i].getHiddenBias().numCols(), dbn_merge_load.preTrainingLayers[i].getHiddenBias().numCols() );

      for ( int row = 0; row < dbn.preTrainingLayers[i].getHiddenBias().numRows(); row++ ) {

        for ( int col = 0; col < dbn.preTrainingLayers[i].getHiddenBias().numCols(); col++ ) {

          assertEquals( 0.0, dbn_merge_load.preTrainingLayers[i].getHiddenBias().get( row, col), 0.0 );
         
        }
       
      }
     
      assertEquals( dbn.preTrainingLayers[i].getVisibleBias().numCols(), dbn_merge_load.preTrainingLayers[i].getVisibleBias().numCols() );
     
      for ( int row = 0; row < dbn.preTrainingLayers[i].getVisibleBias().numRows(); row++ ) {

        for ( int col = 0; col < dbn.preTrainingLayers[i].getVisibleBias().numCols(); col++ ) {

          assertEquals( 0.0, dbn_merge_load.preTrainingLayers[i].getVisibleBias().get( row, col), 0.0 );
         
        }
       
      }
     
    }
   
  } 
 
 
  @Test
  public void testParameterAveragingCode() {
   

    RandomGenerator rng = new MersenneTwister(123);

    double preTrainLr = 0.001;
    int preTrainEpochs = 100;
    int k = 1;
    int nIns = 2,nOuts = 2;
    int[] hiddenLayerSizes = new int[] {2,2};
    double fineTuneLr = 0.001;
    int fineTuneEpochs = 100;
   
    DeepBeliefNetwork dbn_A = new DeepBeliefNetwork(nIns, hiddenLayerSizes, nOuts, hiddenLayerSizes.length, rng, x_xor_Matrix, y_xor_Matrix );
   
    // setup layer 0
    dbn_A.preTrainingLayers[ 0 ].getConnectionWeights().set(0, 0, 1.0);
    dbn_A.preTrainingLayers[ 0 ].getConnectionWeights().set(0, 1, 1.0);
    dbn_A.preTrainingLayers[ 0 ].getConnectionWeights().set(1, 0, 1.0);
    dbn_A.preTrainingLayers[ 0 ].getConnectionWeights().set(1, 1, 1.0);
   
    // Layer 0: hidden bias
   
    dbn_A.preTrainingLayers[ 0 ].getHiddenBias().set( 0, 0, 15.0 );
    dbn_A.preTrainingLayers[ 0 ].getHiddenBias().set( 0, 1, 16.0 );
       
    // Layer 0: visible bias

    dbn_A.preTrainingLayers[ 0 ].getVisibleBias().set( 0, 0, 25.0 );
    dbn_A.preTrainingLayers[ 0 ].getVisibleBias().set( 0, 1, 26.0 );
   
   
    // setup layer 1
    dbn_A.preTrainingLayers[ 1 ].getConnectionWeights().set(0, 0, 1.0);
    dbn_A.preTrainingLayers[ 1 ].getConnectionWeights().set(0, 1, 1.0);
    dbn_A.preTrainingLayers[ 1 ].getConnectionWeights().set(1, 0, 1.0);
    dbn_A.preTrainingLayers[ 1 ].getConnectionWeights().set(1, 1, 1.0);

   
   
    // setup log layer
    dbn_A.logisticRegressionLayer.connectionWeights.set(0, 0, 1.0);
    dbn_A.logisticRegressionLayer.connectionWeights.set(0, 1, 1.0);
    dbn_A.logisticRegressionLayer.connectionWeights.set(1, 0, 1.0);
    dbn_A.logisticRegressionLayer.connectionWeights.set(1, 1, 1.0);
   
    MatrixUtils.debug_print( dbn_A.logisticRegressionLayer.biasTerms );

    // set up bias terms
   
    dbn_A.logisticRegressionLayer.biasTerms.set( 0, 0, 35.0 );
    dbn_A.logisticRegressionLayer.biasTerms.set( 0, 1, 36.0 );
   
   
   
   
    DeepBeliefNetwork dbn_B = new DeepBeliefNetwork(nIns, hiddenLayerSizes, nOuts, hiddenLayerSizes.length, rng, x_xor_Matrix, y_xor_Matrix );

    dbn_B.preTrainingLayers[ 0 ].getConnectionWeights().set(0, 0, 2.0);
    dbn_B.preTrainingLayers[ 0 ].getConnectionWeights().set(0, 1, 1.0);
    dbn_B.preTrainingLayers[ 0 ].getConnectionWeights().set(1, 0, 3.0);
    dbn_B.preTrainingLayers[ 0 ].getConnectionWeights().set(1, 1, 4.0);

    // Layer 0: hidden bias
   
    dbn_B.preTrainingLayers[ 0 ].getHiddenBias().set( 0, 0, 5.0 );
    dbn_B.preTrainingLayers[ 0 ].getHiddenBias().set( 0, 1, 6.0 );
       
    // Layer 0: visible bias

    dbn_B.preTrainingLayers[ 0 ].getVisibleBias().set( 0, 0, 5.0 );
    dbn_B.preTrainingLayers[ 0 ].getVisibleBias().set( 0, 1, 6.0 );

    // layer 2
   
    dbn_B.preTrainingLayers[ 1 ].getConnectionWeights().set(0, 0, 2.0);
    dbn_B.preTrainingLayers[ 1 ].getConnectionWeights().set(0, 1, 1.0);
    dbn_B.preTrainingLayers[ 1 ].getConnectionWeights().set(1, 0, 3.0);
    dbn_B.preTrainingLayers[ 1 ].getConnectionWeights().set(1, 1, 4.0);
   
    dbn_B.logisticRegressionLayer.connectionWeights.set(0, 0, 1.0);
    dbn_B.logisticRegressionLayer.connectionWeights.set(0, 1, 2.0);
    dbn_B.logisticRegressionLayer.connectionWeights.set(1, 0, 3.0);
    dbn_B.logisticRegressionLayer.connectionWeights.set(1, 1, 4.0);
   
    dbn_B.logisticRegressionLayer.biasTerms.set( 0, 0, 5.0 );
    dbn_B.logisticRegressionLayer.biasTerms.set( 0, 1, 6.0 );
   
   
    // MatrixUtils.debug_print( dbn_B.preTrainingLayers[ 0 ].getConnectionWeights() );
   
    int[] hiddenLayerSizesTmp = new int[] {1};
   
    // now setup a DBN based on a clone operation via initBasedOn()
//    DeepBeliefNetwork dbn_merge_load = new DeepBeliefNetwork(1, hiddenLayerSizesTmp, 1, hiddenLayerSizesTmp.length, null);
   
    // this only works if the workers have been initialized on data
    DeepBeliefNetwork dbn_master = new DeepBeliefNetwork(1, hiddenLayerSizesTmp, 1, hiddenLayerSizesTmp.length, null);
    dbn_master.initBasedOn(dbn_A);
   
   
    ArrayList<DeepBeliefNetwork> workers = new ArrayList<DeepBeliefNetwork>();
    workers.add(dbn_A);
    workers.add(dbn_B);
   
    dbn_master.computeAverageDBNParameterVector(workers);
   
    assertEquals( 1.5, dbn_master.preTrainingLayers[ 0 ].getConnectionWeights().get(0, 0), 0.0 );
    assertEquals( 1.0, dbn_master.preTrainingLayers[ 0 ].getConnectionWeights().get(0, 1), 0.0 );
    assertEquals( 2.0, dbn_master.preTrainingLayers[ 0 ].getConnectionWeights().get(1, 0), 0.0 );
    assertEquals( 2.5, dbn_master.preTrainingLayers[ 0 ].getConnectionWeights().get(1, 1), 0.0 );

    // check hidden bias averaged weights
   
    assertEquals( 10.0, dbn_master.preTrainingLayers[ 0 ].getHiddenBias().get(0, 0), 0.0 );
    assertEquals( 11.0, dbn_master.preTrainingLayers[ 0 ].getHiddenBias().get(0, 1), 0.0 );

    // check visible bias averaged weights
   
    assertEquals( 15.0, dbn_master.preTrainingLayers[ 0 ].getVisibleBias().get(0, 0), 0.0 );
    assertEquals( 16.0, dbn_master.preTrainingLayers[ 0 ].getVisibleBias().get(0, 1), 0.0 );
   
    // now make sure the stock hidden weights are the same as the pre-train layer
   
    assertEquals( true, MatrixUtils.elementwiseSame( dbn_master.preTrainingLayers[ 0 ].getConnectionWeights(), dbn_master.hiddenLayers[ 0 ].connectionWeights ) );
    // hidden bias == bias ?
    assertEquals( true, MatrixUtils.elementwiseSame( dbn_master.preTrainingLayers[ 0 ].getHiddenBias(), dbn_master.hiddenLayers[ 0 ].biasTerms ) );

   
    assertEquals( 1.5, dbn_master.preTrainingLayers[ 1 ].getConnectionWeights().get(0, 0), 0.0 );
    assertEquals( 1.0, dbn_master.preTrainingLayers[ 1 ].getConnectionWeights().get(0, 1), 0.0 );
    assertEquals( 2.0, dbn_master.preTrainingLayers[ 1 ].getConnectionWeights().get(1, 0), 0.0 );
    assertEquals( 2.5, dbn_master.preTrainingLayers[ 1 ].getConnectionWeights().get(1, 1), 0.0 );
   
   
    assertEquals( 1.0, dbn_master.logisticRegressionLayer.connectionWeights.get(0, 0), 0.0 );   
    assertEquals( 1.5, dbn_master.logisticRegressionLayer.connectionWeights.get(0, 1), 0.0 );
    assertEquals( 2.0, dbn_master.logisticRegressionLayer.connectionWeights.get(1, 0), 0.0 );
    assertEquals( 2.5, dbn_master.logisticRegressionLayer.connectionWeights.get(1, 1), 0.0 );

    assertEquals( 20.0, dbn_master.logisticRegressionLayer.biasTerms.get(0, 0), 0.0 );
    assertEquals( 21.0, dbn_master.logisticRegressionLayer.biasTerms.get(0, 1), 0.0 );

    // TODO: check the hidden layers ----
   
   
   
  }
 
 
}
// TOP
//
// Related Classes of tv.floe.metronome.deeplearning.neuralnetwork.dbn.TestDeepBeliefNetwork
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.