Package tv.floe.metronome.classification.neuralnetworks.core.neurons

Source Code of tv.floe.metronome.classification.neuralnetworks.core.neurons.Neuron

package tv.floe.metronome.classification.neuralnetworks.core.neurons;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;

import tv.floe.metronome.classification.neuralnetworks.activation.ActivationFunction;
import tv.floe.metronome.classification.neuralnetworks.conf.Config;
import tv.floe.metronome.classification.neuralnetworks.core.Connection;
import tv.floe.metronome.classification.neuralnetworks.core.Layer;
import tv.floe.metronome.classification.neuralnetworks.core.Weight;
import tv.floe.metronome.classification.neuralnetworks.input.InputFunction;
import tv.floe.metronome.classification.neuralnetworks.input.WeightedSum;
import tv.floe.metronome.classification.neuralnetworks.activation.Sigmoid;
import tv.floe.metronome.classification.neuralnetworks.activation.Step;
import tv.floe.metronome.classification.neuralnetworks.activation.ActivationFunction;

public class Neuron implements Serializable {
 
  public Layer parentLayer = null;
  public ArrayList<Connection> inConnections = null;
  public ArrayList<Connection> outConnections = null;
 
  protected InputFunction inputFunction;
  protected ActivationFunction activationFunction;
    double netInput = 0;
         
 
  protected transient double totalNetInputFunctionInput = 0;

  protected transient double output = 0;

  protected transient double error = 0;
 
  public Neuron() {
   
    this.inputFunction = new WeightedSum();
    this.activationFunction = new Step();
   
   
    this.inConnections = new ArrayList<Connection>();
    this.outConnections = new ArrayList<Connection>();
   
   
  }

 
  public Neuron( InputFunction inputFunc, ActivationFunction actFunc) {
   
    this.inConnections = new ArrayList<Connection>();
    this.outConnections = new ArrayList<Connection>();
   
    this.inputFunction = inputFunc;
    this.activationFunction = actFunc;
   
   
  }
 
 
  public static Neuron createNeuron(Config c, int layerIndex) {
   
    Neuron n = null;
    ActivationFunction tf = null;
   
    if (0 == layerIndex) {
      n = new InputNeuron();
     
    } else {
      n = new Neuron(new WeightedSum(), new Sigmoid());
    }
   
   
   
    return n;
   
   
  }
 
  public Weight[] getWeights() {
   
    Weight[] weights = new Weight[inConnections.size()];
    for(int i = 0; i< inConnections.size(); i++) {
      weights[i] = inConnections.get(i).getWeight();
    }
    return weights;

   
  }
 
  public void randomizeWeights() {
   
    for(Connection connection : this.inConnections) {
     
      connection.getWeight().randomize();
     
    }   
   
   
  }
 
  public void randomizeWeights(double min, double max) {
    for(Connection connection : this.inConnections) {
      connection.getWeight().randomize(min, max);
    }
  } 
 
  public void initWeights(double val) {
   
        for(Connection connection : this.inConnections) {
         
            connection.getWeight().setValue(val);
           
       }
   
  }
 
  public void reset() {
   
   
  }
 
  /**
   * Calculate the output, cache the value to be quickly loaded by backprop
   *
   */
  public void calcOutput() {
   
        if ((this.inConnections.size() > 0)) {
            this.netInput = this.inputFunction.getOutput(this.inConnections);
        }

        this.output = this.activationFunction.getOutput(this.netInput);
   
  }
 
    public void removeAllInputConnections() {
        // run through all input connections
        for(int i = 0; i < inConnections.size(); i++) {
          inConnections.get(i).getFromNeuron().removeOutputConnection(inConnections.get(i));   
          inConnections.set(i,  null );                                  
        }
       
        this.inConnections = new ArrayList<Connection>();
    }
   
    public void removeAllOutputConnections() {
        for(int i = 0; i <  outConnections.size(); i++) {
            outConnections.get(i).getToNeuron().removeInputConnection(outConnections.get(i));
            outConnections.set( i, null );
        }           
        this.outConnections = new ArrayList<Connection>();               
    }
   
    public void removeAllConnections() {
        removeAllInputConnections();
        removeAllOutputConnections();
   
 
 
    public boolean hasOutputConnectionTo(Neuron neuron) {
        for(Connection connection : outConnections) {
            if (connection.getToNeuron() == neuron) {
                return true;
            }
        }           
        return false;           
    }
   
    public boolean hasInputConnectionFrom(Neuron neuron) {
        for(Connection connection : inConnections) {
            if (connection.getFromNeuron() == neuron) {
                return true;
            }
        }           
        return false;           
    }
   
   
 
  public void addInConnection(Connection connection) throws Exception {    
        if (connection == null) {
            throw new Exception("Attempt to add null connection to neuron!");
        }
             
        if (connection.getToNeuron() != this) {
            throw new Exception("Cannot add input connection - bad toNeuron specified!");
        }
       
        if (this.hasInputConnectionFrom(connection.getFromNeuron())) {
          System.out.println( "> already connected!" );
            return;
        }           
       
        this.inConnections.add(connection);
       
        Neuron fromNeuron = connection.getFromNeuron();
        fromNeuron.addOutConnection(connection);                   
  }
 
  public void addInConnection(Neuron fromNeuron) throws Exception {
    Connection connection = new Connection(fromNeuron, this);
    this.addInConnection(connection);
  }
 
  public void addInConnection(Neuron fromNeuron, double weightVal) throws Exception {
    Connection connection = new Connection(fromNeuron, this, weightVal);
    this.addInConnection(connection);
 
 
  private void addOutConnection(Connection c) {

    this.outConnections.add(c);
   
  }
 
 
 
  public ArrayList<Connection> getInConnections() {
    return this.inConnections;
  }

  public ArrayList<Connection> getOutConnections() {
    return this.outConnections;
  }
 
 
  public double getOutput() {
   
    return this.output;
   
  }
 
  public void setInput(double input) {
    this.netInput = input;
  }

  public double getNetInput() {
    return this.netInput;
  }
 
  public double getError() {
    return error;
  }

  public void setError(double error) {
    this.error = error;
  }
 
 
  public void setParentLayer(Layer p) {
    this.parentLayer = p;
  }

  public Layer getParentLayer() {
    return this.parentLayer;
  }
 
  public ActivationFunction getActivationFunction() {
    return this.activationFunction;
  }

 
 
    protected void removeInputConnection(Connection conn) {
     
      this.inConnections.remove(conn);

    }
   
    protected void removeOutputConnection(Connection conn) {
     
      this.outConnections.remove(conn);

    }       

  public void removeInputConnectionFrom(Neuron fromNeuron) {

        for (int x = 0; x < inConnections.size(); x++) {

          if (inConnections.get(x).getFromNeuron() == fromNeuron) {
           
            fromNeuron.removeOutputConnection(inConnections.get(x));
           
                this.removeInputConnection(inConnections.get(x));
               
                return;
               
          } // if
         
        } // for
         
  }
     
  public void removeOutputConnectionTo(Neuron toNeuron) {
   
        for(int i = 0; i < outConnections.size(); i++) {
         
      if (outConnections.get(i).getToNeuron() == toNeuron) {
       
        toNeuron.removeInputConnection( outConnections.get( i ) );   
        this.removeOutputConnection( outConnections.get( i ) );
       
        return;
       
      }
     
        }  
       
  }       
 
 
 

}
TOP

Related Classes of tv.floe.metronome.classification.neuralnetworks.core.neurons.Neuron

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.