Package tv.floe.metronome.deeplearning.dbn.model.evaluation

Source Code of tv.floe.metronome.deeplearning.dbn.model.evaluation.ModelTester

package tv.floe.metronome.deeplearning.dbn.model.evaluation;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Properties;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.TextInputFormat;
import org.apache.mahout.math.Matrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.cloudera.iterativereduce.io.TextRecordParser;

import tv.floe.metronome.deeplearning.datasets.DataSet;
import tv.floe.metronome.deeplearning.datasets.iterator.BaseDatasetIterator;
import tv.floe.metronome.deeplearning.datasets.iterator.impl.MnistDataSetIterator;
import tv.floe.metronome.deeplearning.datasets.iterator.impl.MnistHDFSDataSetIterator;
import tv.floe.metronome.deeplearning.dbn.DeepBeliefNetwork;
import tv.floe.metronome.deeplearning.neuralnetwork.core.BaseMultiLayerNeuralNetworkVectorized;
import tv.floe.metronome.eval.Evaluation;

/**
* Model tester built specifically for Deep Belief Networks
*
* @author josh
*
*/
public class ModelTester {
 

  private static Logger log = LoggerFactory.getLogger(ModelTester.class);
 
  public static String model_path = "";
  public static String test_input_data_path = "";
 
 


    private static JobConf defaultConf = new JobConf();
    private static FileSystem localFs = null;
    static {
      try {
        defaultConf.set("fs.defaultFS", "file:///");
        localFs = FileSystem.getLocal(defaultConf);
      } catch (IOException e) {
        throw new RuntimeException("init failure", e);
      }
    }
 
 
  private static InputSplit[] generateDebugSplits(Path input_path, JobConf job) {

    long block_size = localFs.getDefaultBlockSize();

    System.out.println("default block size: " + (block_size / 1024 / 1024)
        + "MB");

    // ---- set where we'll read the input files from -------------
    FileInputFormat.setInputPaths(job, input_path);

    // try splitting the file in a variety of sizes
    TextInputFormat format = new TextInputFormat();
    format.configure(job);

    int numSplits = 1;

    InputSplit[] splits = null;

    try {
      splits = format.getSplits(job, numSplits);
    } catch (IOException e) {
      log.error( "unable to compute input splits", e );
    }

    return splits;
  }

  public static void evaluateModel( BaseDatasetIterator iterator, BaseMultiLayerNeuralNetworkVectorized model ) throws IOException {
   

    Evaluation eval = new Evaluation();
    //BaseMultiLayerNeuralNetworkVectorized load = BaseMultiLayerNeuralNetworkVectorized.loadFromFile(new FileInputStream(new File(modelLocation)));
   
    while (iterator.hasNext()) {
     
      DataSet inputs = iterator.next();

      Matrix in = inputs.getFirst();
      Matrix outcomes = inputs.getSecond();
      Matrix predicted = model.predict(in);
      eval.eval( outcomes, predicted );
     
    }
   
   
   
    log.warn( "evaluateModel" );
    log.info( eval.stats() );   
   
    //writeReportToDisk( eval, pathForReport );

  }

  public static void evaluateModel( Matrix inputs, Matrix labels, BaseMultiLayerNeuralNetworkVectorized model ) throws IOException {
   

    Evaluation eval = new Evaluation();

    Matrix predicted = model.predict( inputs );

    eval.eval( labels, predicted );
   
   
   
    log.warn( "evaluateModel" );
    log.info( eval.stats() );   
   
    //writeReportToDisk( eval, pathForReport );
   
  }   
 
  public static void evaluateSavedModel( BaseDatasetIterator iterator, String modelLocation, String pathForReport ) throws IOException {
   

    Evaluation eval = new Evaluation();
    BaseMultiLayerNeuralNetworkVectorized load = BaseMultiLayerNeuralNetworkVectorized.loadFromFile(new FileInputStream(new File(modelLocation)));
   
    while (iterator.hasNext()) {
     
      DataSet inputs = iterator.next();

      Matrix in = inputs.getFirst();
      Matrix outcomes = inputs.getSecond();
      Matrix predicted = load.predict(in);
      eval.eval( outcomes, predicted );
     
    }
   
   
   
   
    log.info( eval.stats() );   
   
    writeReportToDisk( eval, pathForReport );
   
  }
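
  // Example (sketch): evaluating a serialized model with the bundled MNIST
  // iterator; the model and report locations below are hypothetical.
  //
  //   evaluateSavedModel( new MnistDataSetIterator( 10, 60000 ),
  //       "/tmp/dbn_model.bin", "/tmp/dbn_eval_report.txt" );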
 
  public static void evaluateModel(String propsFilepath, Configuration hadoopConf, int batchSize ) throws Exception {

   
    parsePropertiesFile(propsFilepath);

    // set up splits, HDFS-style -------------------
   
      JobConf job = new JobConf(defaultConf);
      if (null != hadoopConf ) {
        job = new JobConf( hadoopConf );
      } else {
        hadoopConf = defaultConf;
      }
     
      Path workDir = new Path( test_input_data_path );
   
   
      InputSplit[] splits = generateDebugSplits(workDir, job);
     
      System.out.println( "> splits: " + splits[0].toString() );

     
      TextRecordParser txt_reader = new TextRecordParser();

      // drop the fully qualified host / scheme prefix from the split's
      // toString(); a cleaner URI-based alternative is sketched below

      long len = 0;
      String path = "";

      if ( splits[0].toString().startsWith( "file:///" ) ) {
       
        path = splits[0].toString().replaceFirst("file:///", "/").split(":")[0];
       
        len = splits[0].getLength();
       
      } else if ( splits[0].toString().startsWith( "file://" ) ) {
         
          path = splits[0].toString().replaceFirst("file://", "/").split(":")[0];
         
          len = splits[0].getLength();
     
      } else if ( splits[0].toString().startsWith( "file:/" ) ) {
       
        path = splits[0].toString().replaceFirst("file:/", "/").split(":")[0];
       
        len = splits[0].getLength();
     
      } else if ( splits[0].toString().startsWith( "/" ) ) {
     
        // we have something like:
        // /user/cloudera/metronome/dbn/mnist/mnist_filtered_conversion_test.metronome:0+20908
       
        len = Integer.parseInt( splits[0].toString().split(":")[1]
            .split("\\+")[1] );
       
        //System.out.println("Split Length: " + len + ", versus method: " + splits[0].getLength() );
       
        path = splits[0].toString().split(":")[0];
       
      } else if ( splits[0].toString().startsWith( "hdfs:///" ) ) {
       
        // we have something like:
        // hdfs:///user/cloudera/metronome/dbn/mnist/mnist_filtered_conversion_test.metronome:0+20908
       
       
        len = Integer.parseInt(splits[0].toString().split(":")[2]
            .split("\\+")[1]);

        //System.out.println( "number of slices from ':' " + splits[0].toString().split(":").length );
       
        //System.out.println("Split Length: " + len + ", versus method: " + splits[0].getLength() );
       
        path = splits[0].toString().split(":")[1];
       
      } else {
       
        // we have something like:
        // hdfs://localhost.localdomain:8020/user/cloudera/metronome/dbn/mnist/mnist_filtered_conversion_test.metronome:0+20908
       
        len = Integer.parseInt(splits[0].toString().split(":")[3]
            .split("\\+")[1]);

        //System.out.println( "number of slices from ':' " + splits[0].toString().split(":").length );
       
        //System.out.println("Split Length: " + len + ", versus method: " + splits[0].getLength() );
       
        String[] parts = splits[0].toString().split(":");
       
        path = parts[0] + ":" + parts[1] + ":" + parts[2];
       
       
      }
     
      //System.out.println("raw path: " + path);
     

    //  txt_reader.setFile(splits[0].toString().split(":")[1], 0, len);   
      txt_reader.setFile( path, 0, len);
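
      // A more robust alternative (sketch, untested against this codebase):
      // cast the split to a FileSplit and let Hadoop's Path/URI machinery
      // strip the scheme and host instead of the string surgery above:
      //
      //   org.apache.hadoop.mapred.FileSplit fileSplit =
      //       (org.apache.hadoop.mapred.FileSplit) splits[0];
      //   String cleanPath = fileSplit.getPath().toUri().getPath();
      //   long cleanLen = fileSplit.getLength();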
   
         

     
    MnistHDFSDataSetIterator hdfs_fetcher = new MnistHDFSDataSetIterator( batchSize, 1, txt_reader );
    DataSet hdfs_recordBatch = hdfs_fetcher.next();
   
    Matrix hdfs_input = hdfs_recordBatch.getFirst();
    Matrix hdfs_labels = hdfs_recordBatch.getSecond();     
   
    //ModelTester.evaluateModel( recordBatch.getFirst(), recordBatch.getSecond(), dbn);
     
   
    // placeholder layer sizes; the serialized model loaded below supplies the real network shape
    int[] hiddenLayerSizes = new int[] { 2, 2, 2 };
   
    if (model_path.startsWith("file://")) {
      model_path = model_path.replaceFirst("file://", "");
    }
   
    System.out.println("Evaluating DBN Model Saved at: " + model_path );
   
    Path modelInputPath = new Path( model_path );
   
  //  System.out.println("Evaluating DBN Model Saved at: " + modelInputPath.getName() );
  //  System.out.println("Evaluating DBN Model Saved at: " + modelInputPath.toString() );
  //  System.out.println("Evaluating DBN Model Saved at: " + modelInputPath.toUri().toString() );
   
//    FileInputStream oFileInputStream = new FileInputStream( modelInputPath.toUri().toString().replaceAll( "hdfs:///" , "hdfs://localhost.localdomain:8020/" ) );
   
    Path model_new_path = new Path( model_path );
      FileSystem fs = model_new_path.getFileSystem( hadoopConf );
      FSDataInputStream inputStream = fs.open( model_new_path);   
   
   
    DeepBeliefNetwork dbn_deserialize = new DeepBeliefNetwork(1, hiddenLayerSizes, 1, hiddenLayerSizes.length, null ); //, Matrix input, Matrix labels);
    //dbn_deserialize.load(oFileInputStream);
    dbn_deserialize.load( inputStream );
   
    evaluateModel( hdfs_input, hdfs_labels, dbn_deserialize );
     
     
   
  }
 

  public static void parsePropertiesFile(String app_properties_file) throws Exception {
   
    Properties props = new Properties();
    // Configuration conf = getConf();

    try {
      FileInputStream fis = new FileInputStream(app_properties_file);
      props.load(fis);
      fis.close();
    } catch (FileNotFoundException ex) {
      // throw ex; // TODO: be nice
      System.out.println(ex);
    } catch (IOException ex) {
      // throw ex; // TODO: be nice
      System.out.println(ex);
    }

    model_path = props.getProperty("app.output.path");
    if (null == model_path) {
      throw new Exception("Can't find the model output path in the properites file!");
    }

/*    schema = props.getProperty("tv.floe.metronome.neuralnetwork.conf.InputRecordSchema");
    if (null == schema) {
      throw new Exception("Can't find the input record schema in the properites file!");
    }
*/
    test_input_data_path = props.getProperty("tv.floe.metronome.evaluate.dataset.path");
    if (null == test_input_data_path) {
      throw new Exception("Can't find the eval/test recordset in the properites file!");
    }

   
  } 
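
  // Example properties file for parsePropertiesFile() (a sketch; the values
  // are hypothetical placeholders, and only these two keys are read here):
  //
  //   app.output.path=/tmp/metronome/dbn_model.bin
  //   tv.floe.metronome.evaluate.dataset.path=/tmp/metronome/mnist_test.metronome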
   
 
  public static void writeReportToDisk( Evaluation eval, String fileLocation ) throws IOException {
   
    // write the evaluation stats out as a plain-text report
    File reportFile = new File( fileLocation );
    if ( !reportFile.exists() ) {
      reportFile.createNewFile();
    }

    FileOutputStream oFile = new FileOutputStream( reportFile, false );

    oFile.write( eval.stats().getBytes() );

    oFile.close();
   
   
  }
 
 
  /**
   * @param args
   * @throws IOException
   */
/*  public static void main(String[] args) throws IOException {
   
    MnistDataSetIterator iter = new MnistDataSetIterator(10, 60000);
   
    Evaluation eval = new Evaluation();
    BaseMultiLayerNeuralNetworkVectorized load = BaseMultiLayerNeuralNetworkVectorized.loadFromFile(new FileInputStream(new File(args[0])));
   
    while (iter.hasNext()) {
     
      DataSet inputs = iter.next();

      Matrix in = inputs.getFirst();
      Matrix outcomes = inputs.getSecond();
      Matrix predicted = load.predict(in);
      eval.eval( outcomes, predicted );
     
    }
   
   
   
   
    log.info( eval.stats() );
  } 
*/ 
}

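
A minimal usage sketch, assuming a properties file that defines the two keys
read by parsePropertiesFile(); the file path and batch size below are
hypothetical placeholders:

import org.apache.hadoop.conf.Configuration;

import tv.floe.metronome.deeplearning.dbn.model.evaluation.ModelTester;

public class ModelTesterExample {

  public static void main(String[] args) throws Exception {

    // hypothetical properties file defining app.output.path (the saved model)
    // and tv.floe.metronome.evaluate.dataset.path (the test records)
    String propsFile = "/tmp/dbn_eval.properties";

    // this configuration is used to resolve the filesystem the model lives on
    Configuration conf = new Configuration();

    // loads the saved DBN, evaluates it on one batch of 20 test records,
    // and logs the evaluation stats
    ModelTester.evaluateModel( propsFile, conf, 20 );
  }
}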