Package tv.floe.metronome.deeplearning.datasets

Examples of tv.floe.metronome.deeplearning.datasets.DataSet
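Before the excerpts, a minimal usage sketch of the class itself, inferred from the snippets below (the DataSet(Matrix, Matrix) constructor and the getFirst()/getSecond() accessors appear in the examples; the Mahout DenseMatrix type and the dimensions here are illustrative):

    // a DataSet pairs a feature matrix with a label matrix
    Matrix input  = new DenseMatrix( 20, 784 );  // e.g. 20 MNIST records, 784 pixels each
    Matrix labels = new DenseMatrix( 20, 10 );   // one-hot labels for 10 digit classes
    DataSet batch = new DataSet( input, labels );
    Matrix x = batch.getFirst();   // features
    Matrix y = batch.getSecond();  // labels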


    double learningRate = 0.001;
   
    // descending reconstruction cross-entropy targets; training advances to the
    // next threshold once the current one is reached
    int[] batchSteps = { 250, 200, 150, 100, 50, 25, 5 };
   
    DataSet first = fetcher.next();
/*
    RestrictedBoltzmannMachine da = new RBM.Builder().numberOfVisible(784).numHidden(400).withRandom(rand).renderWeights(1000)
        .useRegularization(false)
        .withMomentum(0).build();
*/
    // 784 visible units (28x28 MNIST pixels), 400 hidden units, default RNG
    RestrictedBoltzmannMachine rbm = new RestrictedBoltzmannMachine( 784, 400, null );
    rbm.useRegularization = false;
    //rbm.scaleWeights( 1000 );
    rbm.momentum = 0;
    rbm.sparsity = 0.01;
    // TODO: investigate "render weights"



    rbm.trainingDataset = first.getFirst();

    //MatrixUtils.debug_print( rbm.trainingDataset );

    // render base activations pre train
   
    this.renderActivationsToDisk(rbm, "init");
    this.renderWeightValuesToDisk(rbm, "init");
    this.renderFiltersToDisk(rbm, "init");
   
   
    System.out.println(" ----- Training ------");
   
    //for(int i = 0; i < 2; i++) {
    int epoch = 0;
   
    System.out.println("Epoch " + epoch + " Negative Log Likelhood: " + rbm.getReConstructionCrossEntropy() );
   
    for (int stepIndex = 0; stepIndex < batchSteps.length; stepIndex++ ) {
   
      // train until reconstruction cross-entropy drops below this threshold
      int targetCrossEntropy = batchSteps[ stepIndex ];
     
      while ( rbm.getReConstructionCrossEntropy() > targetCrossEntropy ) {
       
        System.out.println("Epoch " + epoch + " Negative Log Likelhood: " + rbm.getReConstructionCrossEntropy() );
       
        //rbm.trainTillConvergence( first.getFirst(), learningRate, new Object[]{ 1 } );
        //rbm.trainTillConvergence(learningRate, 1, first.getFirst());
        // new Object[]{1,0.01,1000}
        rbm.trainTillConvergence(first.getFirst(), learningRate, new Object[]{ 1, learningRate, 10 } );
       
        epoch++;
       
      }

    // ... (remainder of example truncated)
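The threshold loop above can be factored into a small helper. A sketch against the same API (the trainTillConvergence Object[] arguments are passed through exactly as in the snippet; the helper name is mine):

    // hypothetical helper: run training passes until the RBM's reconstruction
    // cross-entropy falls below each target in turn
    static void trainToTargets( RestrictedBoltzmannMachine rbm, Matrix data,
        double learningRate, int[] targets ) {
      for ( int target : targets ) {
        while ( rbm.getReConstructionCrossEntropy() > target ) {
          rbm.trainTillConvergence( data, learningRate, new Object[]{ 1, learningRate, 10 } );
        }
      }
    }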


    int totalNumExamples = 100 * datasetSize;
   
   
   
    MnistDataSetIterator fetcher = new MnistDataSetIterator( batchSize, totalNumExamples );
    DataSet recordBatch = fetcher.next();
   
   
    Map<Integer, Integer> filter = new HashMap<Integer, Integer>();
    for (int x = 0; x < classIndexes.length; x++ ) {
     
      filter.put(classIndexes[x], 1);
     
    }
   
   
   
    Matrix input = recordBatch.getFirst();
    Matrix labels = recordBatch.getSecond();
   
    Matrix inputFiltered = new DenseMatrix( datasetSize, input.numCols() );
    Matrix labelsFiltered = new DenseMatrix( datasetSize, labels.numCols() );
   
    int recFound = 0;
   
    for ( int row = 0; row < input.numRows(); row++ ) {
     
      // label rows are one-hot, so the argmax is the class index
      int rowLabel = labels.viewRow( row ).maxValueIndex();
     
      if ( filter.containsKey(rowLabel)) {
       
        inputFiltered.viewRow(recFound).assign( input.viewRow(row) );
        labelsFiltered.viewRow(recFound).assign( labels.viewRow(row) );
        recFound++;
       
        if ( recFound >= inputFiltered.numRows() ) {
          break;
        }
       
      }
     
     
    }

    if ( recFound < inputFiltered.numRows() ) {

      System.out.println("We did not fill the filtered input count fully.");
     
    }
   
   
    //DataSet ret = new DataSet();
    return new DataSet( inputFiltered, labelsFiltered );
  }
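The Map above is used purely as a membership set; a HashSet expresses that more directly. A behavior-equivalent sketch (my rewrite, not the original source):

    // same filter, expressed as a set (requires java.util.HashSet / java.util.Set)
    Set<Integer> keep = new HashSet<Integer>();
    for ( int classIndex : classIndexes ) {
      keep.add( classIndex );
    }
    // later, in the row loop: if ( keep.contains( rowLabel ) ) { ... }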

    int totalNumExamples = 20;
       
    int batchSize = 1;
   
    // keep only digit classes 0 and 1; request 20 matching records
    int[] filter = { 0, 1 };
    DataSet recordBatch = this.filterDataset( filter, 20 );
   
    Matrix input = recordBatch.getFirst();
    Matrix labels = recordBatch.getSecond();
   
    assertEquals(20, input.numRows() );
    assertEquals(20, labels.numRows() );
   
    //MatrixUtils.debug_print( input );
    // ... (remainder of test truncated)
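A natural extension of these assertions would check that only the requested classes survive the filter. A sketch reusing the same maxValueIndex() idiom as filterDataset above (the assertion itself is my addition):

    // every surviving label row should decode to class 0 or 1
    for ( int row = 0; row < labels.numRows(); row++ ) {
      int cls = labels.viewRow( row ).maxValueIndex();
      assertTrue( cls == 0 || cls == 1 );
    }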

    // mini-batches through dataset
//    MnistDataSetIterator fetcher = new MnistDataSetIterator( batchSize, totalNumExamples );
//    DataSet first = fetcher.next();
   
    int[] filter = { 0, 1 };
    DataSet recordBatch = this.filterDataset( filter, 20 );
   
    MatrixUtils.debug_print(recordBatch.getSecond());
   
    //int numIns = first.getFirst().numCols();
    //int numLabels = first.getSecond().numCols();
   
    int numIns = recordBatch.getFirst().numCols();
    int numLabels = recordBatch.getSecond().numCols();

    int n_layers = hiddenLayerSizes.length;
    RandomGenerator rng = new MersenneTwister(123);
   
   
    DeepBeliefNetwork dbn = new DeepBeliefNetwork( numIns, hiddenLayerSizes, numLabels, n_layers, rng ); //, Matrix input, Matrix labels);
       
    dbn.useRegularization = false;
    dbn.setSparsity(0.01);
    dbn.setMomentum(0);
   
    int recordsProcessed = 0;
   
   
    StopWatch watch = new StopWatch();
    watch.start();
   
    StopWatch batchWatch = new StopWatch();
   
   
//    do  {
     
      // the surrounding do/while is commented out, so this single filtered
      // batch is processed exactly once
      recordsProcessed += batchSize;
     
      System.out.println( "PreTrain: Batch Mode, Processed Total " + recordsProcessed + ", Elapsed Time " + watch.toString() );
     
      batchWatch.reset();
      batchWatch.start();
      dbn.preTrain( recordBatch.getFirst(), 1, learningRate, preTrainEpochs);
      batchWatch.stop();
     
      System.out.println( "Batch Training Elapsed Time " + batchWatch.toString() );

      System.out.println( "DBN Network Stats:\n" + dbn.generateNetworkSizeReport() );

/*     
      if (fetcher.hasNext()) {
        first = fetcher.next();
      }
     
    } while (fetcher.hasNext());

    fetcher.reset();
    first = fetcher.next();
   
    recordsProcessed = 0;
   
    do {
     
      recordsProcessed += batchSize;
*/     
      System.out.println( "FineTune: Batch Mode, Processed Total " + recordsProcessed + ", Elapsed Time " + watch.toString() );
     
     
      dbn.finetune( recordBatch.getSecond(), learningRate, fineTuneEpochs );

      // two more alternating pre-train / fine-tune passes over the same batch
      dbn.preTrain( recordBatch.getFirst(), 1, learningRate, preTrainEpochs);
      dbn.finetune( recordBatch.getSecond(), learningRate, fineTuneEpochs );
     
      dbn.preTrain( recordBatch.getFirst(), 1, learningRate, preTrainEpochs);
      dbn.finetune( recordBatch.getSecond(), learningRate, fineTuneEpochs );
     
      /*     
      if (fetcher.hasNext()) {
        first = fetcher.next();
      }
     
    } while (fetcher.hasNext());
*/   
    watch.stop();
   
    System.out.println("----------- Training Complete! -----------");
    System.out.println( "Processed Total " + recordsProcessed + ", Elapsed Time " + watch.toString() );
   
    // save model
   
  //  dbn.write( "/tmp/metronome/dbn/TEST_DBN_MNIST/models/mnist.model" );
   
    // serialize the trained network to local disk (note: the stream is never closed)
    FileOutputStream oFileOutStream = new FileOutputStream( "/tmp/Metronome_DBN_Mnist.model", false);
    dbn.write( oFileOutStream );
   
   
    // now do evaluation of results ....
//    fetcher.reset();
 
   
   
    ModelTester.evaluateModel( recordBatch.getFirst(), recordBatch.getSecond(), dbn);
   
  }
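The FileOutputStream above is left open; a sketch of the same save step with try-with-resources (dbn.write(OutputStream) is assumed to have the signature its usage above implies):

    try ( FileOutputStream modelOut = new FileOutputStream( "/tmp/Metronome_DBN_Mnist.model", false ) ) {
      dbn.write( modelOut );
    } catch ( IOException e ) {
      e.printStackTrace();
    }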

    int batchSize = 50;
    boolean showNetworkStats = true;
   
    // mini-batches through dataset
    MnistDataSetIterator fetcher = new MnistDataSetIterator( batchSize, totalNumExamples );
    DataSet first = fetcher.next();
    int numIns = first.getFirst().numCols();
    int numLabels = first.getSecond().numCols();

    int n_layers = hiddenLayerSizes.length;
    RandomGenerator rng = new MersenneTwister(123);
   
   
    DeepBeliefNetwork dbn = new DeepBeliefNetwork( numIns, hiddenLayerSizes, numLabels, n_layers, rng ); //, Matrix input, Matrix labels);
       
    dbn.useRegularization = false;
    dbn.setSparsity(0.01);
    dbn.setMomentum(0);
   
   
    int recordsProcessed = 0;
    int batchesProcessed = 0;
    long totalBatchProcessingTime = 0;
   
    StopWatch watch = new StopWatch();
    watch.start();
   
    StopWatch batchWatch = new StopWatch();
   
   
    do  {
     
      recordsProcessed += batchSize;
      batchesProcessed++;
     
      System.out.println( "PreTrain: Batch Mode, Processed Total " + recordsProcessed + ", Elapsed Time " + watch.toString() );
     
      batchWatch.reset();
      batchWatch.start();
      dbn.preTrain( first.getFirst(), 1, learningRate, preTrainEpochs);
      batchWatch.stop();
     
      totalBatchProcessingTime += batchWatch.getTime();
     
      System.out.println( "Batch Training Elapsed Time " + batchWatch.toString() );
      // ... (remainder of example truncated)
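totalBatchProcessingTime and batchesProcessed are accumulated above, but the excerpt cuts off before they are reported. A plausible summary line (my addition; commons-lang StopWatch.getTime() returns milliseconds):

      // average wall-clock time per mini-batch, in milliseconds
      double avgBatchMs = totalBatchProcessingTime / (double) batchesProcessed;
      System.out.println( "Average Batch Time (ms): " + avgBatchMs );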

    //int rowLimit = 100;
       
    int batchSize = 10;
    // mini-batches through dataset
    MnistDataSetIterator fetcher = new MnistDataSetIterator( batchSize, totalNumExamples );
    DataSet first = fetcher.next();
    int numIns = first.getFirst().numCols();
    int numLabels = first.getSecond().numCols();

    int n_layers = hiddenLayerSizes.length;
    RandomGenerator rng = new MersenneTwister(123);
   
   
    DeepBeliefNetwork dbn = new DeepBeliefNetwork( numIns, hiddenLayerSizes, numLabels, n_layers, rng ); //, Matrix input, Matrix labels);
       
    int recordsProcessed = 0;
   
    do  {
     
      recordsProcessed += batchSize;
     
      System.out.println( "PreTrain: Batch Mode, Processed Total " + recordsProcessed );
      dbn.preTrain( first.getFirst(), 1, learningRate, preTrainEpochs);

      // note: if the iterator is exhausted immediately after this fetch, the
      // loop exits without processing the batch just fetched (see sketch below)
      if (fetcher.hasNext()) {
        first = fetcher.next();
      }
    } while (fetcher.hasNext());

    fetcher.reset();
    first = fetcher.next();
   
    recordsProcessed = 0;
   
    do {
     
      recordsProcessed += batchSize;
     
      System.out.println( "FineTune: Batch Mode, Processed Total " + recordsProcessed );
     
     
      dbn.finetune( first.getSecond(), learningRate, fineTuneEpochs );
     
      if (fetcher.hasNext()) {
        first = fetcher.next();
      }
     
    // ... (remainder of example truncated)
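As noted in the comment above, the do/while pattern drops the final batch of each sweep. A behavior-corrected sketch of the pre-train sweep as a plain while loop (my refactor, not the original code):

    fetcher.reset();
    while ( fetcher.hasNext() ) {
      DataSet batch = fetcher.next();
      recordsProcessed += batchSize;
      System.out.println( "PreTrain: Batch Mode, Processed Total " + recordsProcessed );
      dbn.preTrain( batch.getFirst(), 1, learningRate, preTrainEpochs );
    }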


   
       

     
    // pull one batch of MNIST records from HDFS via the text-record parser
    MnistHDFSDataSetIterator hdfs_fetcher = new MnistHDFSDataSetIterator( batchSize, totalNumExamples, txt_reader );
    DataSet hdfs_recordBatch = hdfs_fetcher.next();
   
    return hdfs_recordBatch;
  }

      //labels.putRow(i,examples.get(i).getSecond());
      // each example's label matrix has a single row; copy it into row i of the batch
      labels.assignRow( i, examples.get(i).getSecond().viewRow(0) );
   
    }
   
    curr = new DataSet(inputs,labels);

  }
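Only the tail of this merge loop survives in the excerpt. A sketch of what the full routine plausibly looks like, reconstructed around the surviving lines (the inputs handling and matrix dimensions are my assumptions):

    // merge a list of single-row DataSets into one batch DataSet
    Matrix inputs = new DenseMatrix( examples.size(), examples.get(0).getFirst().numCols() );
    Matrix labels = new DenseMatrix( examples.size(), examples.get(0).getSecond().numCols() );
    for ( int i = 0; i < examples.size(); i++ ) {
      inputs.assignRow( i, examples.get(i).getFirst().viewRow(0) );
      labels.assignRow( i, examples.get(i).getSecond().viewRow(0) );
    }
    curr = new DataSet( inputs, labels );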

   
    int recordsProcessed = 0;
   
    StopWatch batchWatch = new StopWatch();
   
    DataSet hdfs_recordBatch = null; //this.hdfs_fetcher.next();
   
    System.out.println("Iteration: " + this.currentIteration );
   
//    if (hdfs_recordBatch.getFirst().numRows() > 0) {
//    do  {
   
    if ( TrainingState.PRE_TRAIN == this.currentTrainingState ) {
   
      System.out.println("Worker > PRE TRAIN! " );
     
       if ( this.hdfs_fetcher.hasNext() ) {
       
        
        
        hdfs_recordBatch = this.hdfs_fetcher.next();

        System.out.println("Worker > Has Next! > Recs: " + hdfs_recordBatch.getFirst().numRows() );
       
        // check for the straggler batch condition
        if (0 == this.currentIteration && hdfs_recordBatch.getFirst().numRows() > 0 && hdfs_recordBatch.getFirst().numRows() < this.batchSize) {
         
        //  System.out.println( "Worker > Straggler Batch Condition!" );
         
          // ok, only in this situation do we lower the batch size
          this.batchSize = hdfs_recordBatch.getFirst().numRows();

          // re-setup the dataset iterator
          try {
            this.hdfs_fetcher = new MnistHDFSDataSetIterator( this.batchSize, this.totalTrainingDatasetSize, (TextRecordParser)lineParser );
          } catch (IOException e) {
            // rebuilding the iterator failed; log and fall through with the reduced batch size
            e.printStackTrace();
          }

        //  System.out.println( "Worker > PreTrain: Setting up for a straggler split... (sub batch size)" );         
        //  System.out.println( "New batch size: " + this.batchSize );
        } else {
         
        //  System.out.println( "Worker > NO Straggler Batch Condition!" );
         
        }
       
        if (hdfs_recordBatch.getFirst().numRows() > 0) {
         
          if (hdfs_recordBatch.getFirst().numRows() < this.batchSize) {
           
          //  System.out.println( "Worker > PreTrain: [Jagged End of Split: Skipped] Processed Total " + recordsProcessed + " Total Time " + watch.toString() );
           
           
          } else {
           
          //  System.out.println( "Worker > Normal Processing!" );
           
            // calc stats on number records processed
            recordsProcessed += hdfs_recordBatch.getFirst().numRows();
           
            //System.out.println( "PreTrain: Batch Size: " + hdfs_recordBatch.getFirst().numRows() );
           
            batchWatch.reset();
           
            batchWatch.start();
       
            this.dbn.preTrain( hdfs_recordBatch.getFirst(), 1, this.learningRate, this.preTrainEpochs);
           
            batchWatch.stop();
   
            System.out.println( "Worker > PreTrain: Batch Mode, Processed Total " + recordsProcessed + ", Batch Time " + batchWatch.toString() + " Total Time " + watch.toString() );
   
          } // if
         
         
        } else {
       
          // in case we get a blank line
          System.out.println( "Worker > PreTrain > Idle pass, no records left to process in phase" );
         
        }
       
      } else {
       
        System.out.println( "Worker > PreTrain > Idle pass, no records left to process in phase" );
       
      }
     
    //  System.out.println( "Worker > Check PreTrain completion > completedEpochs: " + this.completedDatasetEpochs + ", preTrainDatasetPasses: " + this.preTrainDatasetPasses );
     
      // check for completion of split, to signal master on state change
      if ( !this.hdfs_fetcher.hasNext() && this.completedDatasetEpochs + 1 >= this.preTrainDatasetPasses ) {
       
        this.preTrainPhaseComplete = true;
      //  System.out.println( "Worker > Completion of pre-train phase" );
       
      }
     
         
   
    } else if ( TrainingState.FINE_TUNE == this.currentTrainingState) {
     
      //System.out.println( "DBN Network Stats:\n" + dbn.generateNetworkSizeReport() );

      if ( this.hdfs_fetcher.hasNext() ) {
       
        hdfs_recordBatch = this.hdfs_fetcher.next();
       
        if (hdfs_recordBatch.getFirst().numRows() > 0) {
         
          if (hdfs_recordBatch.getFirst().numRows() < this.batchSize) {
           
          //  System.out.println( "Worker > FineTune: [Jagged End of Split: Skipped] Processed Total " + recordsProcessed + " Total Time " + watch.toString() );

          } else {
           
            batchWatch.reset();
           
            batchWatch.start();
           
            this.dbn.finetune( hdfs_recordBatch.getSecond(), learningRate, fineTuneEpochs );
           
            batchWatch.stop();
           
            System.out.println( "Worker > FineTune > Batch Mode, Processed Total " + recordsProcessed + ", Batch Time " + batchWatch.toString() + " Total Time " + watch.toString() );
           
    // ... (remainder of worker logic truncated)
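The worker branches on TrainingState.PRE_TRAIN and TrainingState.FINE_TUNE, but the enum itself is not shown in these excerpts. A minimal sketch consistent with the two constants used above (any further states are not evidenced in the source):

    // minimal sketch of the state enum the worker switches on
    enum TrainingState {
      PRE_TRAIN,
      FINE_TUNE
    }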
