Package tv.floe.metronome.classification.neuralnetworks.iterativereduce

Source Code of tv.floe.metronome.classification.neuralnetworks.iterativereduce.NeuralNetworkWeightsDelta

package tv.floe.metronome.classification.neuralnetworks.iterativereduce;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

//import org.apache.mahout.math.Matrix;
//import org.apache.mahout.math.MatrixWritable;

import tv.floe.metronome.classification.neuralnetworks.core.NeuralNetwork;

/**
* TODO:
* - move the connection weights into this structure
* - provide a custom serde for the connection weights
*  
*
*
* @author josh
*
*/
public class NeuralNetworkWeightsDelta {
 
  public NeuralNetwork network = null;
 
    //public int SrcWorkerPassCount = 0;
   
//    public Matrix parameter_vector = null;
    public int GlobalPassCount = 0; // what pass should the worker dealing with?
   
    public int IterationComplete = 0; // 0 = no, 1 = yes
    public int CurrentIteration = 0;
   
    public int TrainedRecords = 0;
//    public float AvgLogLikelihood = 0;
    public float PercentCorrect = 0;
    public double RMSE = 0.0;
 
    public byte[] Serialize() throws IOException {
       
        // DataOutput d
       
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        DataOutput d = new DataOutputStream(out);
       
        // d.writeUTF(src_host);
        //d.writeInt(this.SrcWorkerPassCount);
        d.writeInt(this.GlobalPassCount);
       
        d.writeInt(this.IterationComplete);
        d.writeInt(this.CurrentIteration);
       
        d.writeInt(this.TrainedRecords);
        //d.writeFloat(this.AvgLogLikelihood);
        d.writeFloat(this.PercentCorrect);
        d.writeDouble(this.RMSE);
       
        //d.write
       
        // buf.write
        // MatrixWritable.writeMatrix(d, this.worker_gradient.getMatrix());
        //MatrixWritable.writeMatrix(d, this.parameter_vector);
        // MatrixWritable.
        ObjectOutputStream oos = new ObjectOutputStream(out);
       
        //System.out.println("Worker:Serialize() > " + this.network.getClass());
       
        oos.writeObject( this.network );
       
        oos.flush();
        oos.close();

       
        return out.toByteArray();
      }
     
      public void Deserialize(byte[] bytes) throws IOException {
        // DataInput in) throws IOException {
       
        ByteArrayInputStream b = new ByteArrayInputStream(bytes);
        DataInput in = new DataInputStream(b);
        // this.src_host = in.readUTF();
        //this.SrcWorkerPassCount = in.readInt();
        this.GlobalPassCount = in.readInt();
       
        this.IterationComplete = in.readInt();
        this.CurrentIteration = in.readInt();
       
        this.TrainedRecords = in.readInt(); // d.writeInt(this.TrainedRecords);
        //this.AvgLogLikelihood = in.readFloat(); // d.writeFloat(this.AvgLogLikelihood);
        this.PercentCorrect = in.readFloat(); // d.writeFloat(this.PercentCorrect);
        this.RMSE = in.readDouble();

         ObjectInputStream oistream = null;

          try {

              oistream = new ObjectInputStream(b);
              this.network = (NeuralNetwork) oistream.readObject();

          } catch (IOException ioe) {
              ioe.printStackTrace();
          } catch (ClassNotFoundException cnfe) {
              cnfe.printStackTrace();
          } finally {
              if (oistream != null) {
                  try {
                      oistream.close();
                  } catch (IOException ioe) {
                  }
              }
          }       
       
       
      }   

}
TOP

Related Classes of tv.floe.metronome.classification.neuralnetworks.iterativereduce.NeuralNetworkWeightsDelta

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.