Package org.encog.workbench.tabs.visualize.structure

Source Code of org.encog.workbench.tabs.visualize.structure.StructureTab

/*
* Encog(tm) Workbench v3.0
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2011 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*  
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.workbench.tabs.visualize.structure;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Paint;
import java.awt.Point;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.swing.BorderFactory;
import javax.swing.JButton;
import javax.swing.JPanel;
import javax.swing.border.Border;

import org.apache.commons.collections15.Transformer;
import org.encog.ml.MLMethod;
import org.encog.neural.flat.FlatNetwork;
import org.encog.neural.neat.NEATLink;
import org.encog.neural.neat.NEATNetwork;
import org.encog.neural.neat.NEATNeuron;
import org.encog.neural.networks.BasicNetwork;
import org.encog.workbench.WorkBenchError;
import org.encog.workbench.tabs.EncogCommonTab;

import edu.uci.ics.jung.algorithms.layout.StaticLayout;
import edu.uci.ics.jung.graph.Graph;
import edu.uci.ics.jung.graph.SparseMultigraph;
import edu.uci.ics.jung.graph.util.EdgeType;
import edu.uci.ics.jung.visualization.GraphZoomScrollPane;
import edu.uci.ics.jung.visualization.Layer;
import edu.uci.ics.jung.visualization.VisualizationViewer;
import edu.uci.ics.jung.visualization.control.AbstractModalGraphMouse;
import edu.uci.ics.jung.visualization.control.CrossoverScalingControl;
import edu.uci.ics.jung.visualization.control.DefaultModalGraphMouse;
import edu.uci.ics.jung.visualization.control.ScalingControl;
import edu.uci.ics.jung.visualization.decorators.ToStringLabeller;
import edu.uci.ics.jung.visualization.renderers.Renderer;

public class StructureTab extends EncogCommonTab {

  private VisualizationViewer<DrawnNeuron, DrawnConnection> vv;
 
  public StructureTab(MLMethod method) {
    super(null);
   
    // Graph<V, E> where V is the type of the vertices
    // and E is the type of the edges
    Graph<DrawnNeuron, DrawnConnection> g = null;
   
    if( method instanceof BasicNetwork ) {
      BasicNetwork network = (BasicNetwork)method;
      g = buildGraph(network.getStructure().getFlat());
    } else if( method instanceof NEATNetwork ) {
      NEATNetwork neat = (NEATNetwork)method;
      g = buildGraph(neat);
    }
   
    if( g==null ) {
      throw new WorkBenchError("Can't visualize network: " + method.getClass().getSimpleName());
    }

    Transformer<DrawnNeuron, Point2D> staticTranformer = new Transformer<DrawnNeuron, Point2D>() {

      public Point2D transform(DrawnNeuron n) {
        int x = (int) (n.getX() * 600);
        int y = (int) (n.getY() * 300);

        Point2D result = new Point(x + 32, y);
        return result;
      }
    };

    Transformer<DrawnNeuron, Paint> vertexPaint = new Transformer<DrawnNeuron, Paint>() {
      public Paint transform(DrawnNeuron neuron) {
        switch (neuron.getType()) {
        case Bias:
          return Color.yellow;
        case Input:
          return Color.white;
        case Output:
          return Color.green;
        case Context:
          return Color.cyan;
        default:
          return Color.red;
        }
      }

    };
   
    Transformer<DrawnConnection, Paint> edgePaint = new Transformer<DrawnConnection, Paint>() {
      public Paint transform(DrawnConnection connection) {
        if( connection.isContext() ) {
          return Color.lightGray;
        } else {
          return Color.black;
        }
      }
    };

    // The Layout<V, E> is parameterized by the vertex and edge types
    StaticLayout<DrawnNeuron, DrawnConnection> layout = new StaticLayout<DrawnNeuron, DrawnConnection>(
        g, staticTranformer);
 

    layout.setSize(new Dimension(5000,5000)); // sets the initial size of the space
    // The BasicVisualizationServer<V,E> is parameterized by the edge types
    //BasicVisualizationServer<DrawnNeuron, DrawnConnection> vv = new BasicVisualizationServer<DrawnNeuron, DrawnConnection>(
    //    layout);
   
    //Dimension d = new Dimension(600,600);
   
    vv =  new VisualizationViewer<DrawnNeuron, DrawnConnection>(layout);
   
    //vv.setPreferredSize(d); //Sets the viewing area size

    vv.getRenderer().getVertexLabelRenderer()
        .setPosition(Renderer.VertexLabel.Position.CNTR);
    vv.getRenderContext().setVertexLabelTransformer(new ToStringLabeller());
    vv.getRenderContext().setVertexFillPaintTransformer(vertexPaint);
    vv.getRenderContext().setEdgeDrawPaintTransformer(edgePaint);
    vv.getRenderContext().setArrowDrawPaintTransformer(edgePaint);
    vv.getRenderContext().setArrowFillPaintTransformer(edgePaint);
   
    vv.setVertexToolTipTransformer(new ToStringLabeller());
   
    vv.setVertexToolTipTransformer(new Transformer<DrawnNeuron,String>() {
      public String transform(DrawnNeuron edge) {
        return edge.getToolTip();
      }});
   
    vv.setEdgeToolTipTransformer(new Transformer<DrawnConnection,String>() {
      public String transform(DrawnConnection edge) {
        return edge.getToolTip();
      }});
   
    final GraphZoomScrollPane panel = new GraphZoomScrollPane(vv);
    this.setLayout(new BorderLayout());
        add(panel, BorderLayout.CENTER);
        final AbstractModalGraphMouse graphMouse = new DefaultModalGraphMouse();
        vv.setGraphMouse(graphMouse);
       
        vv.addKeyListener(graphMouse.getModeKeyListener());
       
        final ScalingControl scaler = new CrossoverScalingControl();       

        JButton plus = new JButton("+");
        plus.addActionListener(new ActionListener() {
            public void actionPerformed(ActionEvent e) {
                scaler.scale(vv, 1.1f, vv.getCenter());
            }
        });
        JButton minus = new JButton("-");
        minus.addActionListener(new ActionListener() {
            public void actionPerformed(ActionEvent e) {
                scaler.scale(vv, 1/1.1f, vv.getCenter());
            }
        });

        JButton reset = new JButton("reset");
        reset.addActionListener(new ActionListener() {

      public void actionPerformed(ActionEvent e) {
        vv.getRenderContext().getMultiLayerTransformer().getTransformer(Layer.LAYOUT).setToIdentity();
        vv.getRenderContext().getMultiLayerTransformer().getTransformer(Layer.VIEW).setToIdentity();
      }});

        JPanel controls = new JPanel();
        controls.setLayout(new FlowLayout(FlowLayout.LEFT));
        controls.add(plus);
        controls.add(minus);
        controls.add(reset);
        Border border = BorderFactory.createEtchedBorder();
        controls.setBorder(border);
        add(controls, BorderLayout.NORTH);
       
       
  }

  private Graph<DrawnNeuron, DrawnConnection> buildGraph(NEATNetwork neat) {
   
    int inputCount = 1;
    int outputCount = 1;
    int hiddenCount = 1;
    int biasCount = 1;


    List<DrawnNeuron> neurons = new ArrayList<DrawnNeuron>();
    Graph<DrawnNeuron, DrawnConnection> result = new SparseMultigraph<DrawnNeuron, DrawnConnection>();
    List<DrawnNeuron> connections = new ArrayList<DrawnNeuron>();
    Map<NEATNeuron,DrawnNeuron> neuronMap = new HashMap<NEATNeuron,DrawnNeuron>();
   
    // place all the neurons
    for(NEATNeuron neatNeuron : neat.getNeurons() ) {
      String name="";
      DrawnNeuronType t = DrawnNeuronType.Hidden;
     
      switch(neatNeuron.getNeuronType()) {
        case Bias:
          t = DrawnNeuronType.Bias;
          name="B"+(biasCount++);
          break;
        case Input:
          t = DrawnNeuronType.Input;
          name="I"+(inputCount++);
          break;
        case Output:
          t = DrawnNeuronType.Output;
          name="O"+(outputCount++);
          break;
        case Hidden:
          t = DrawnNeuronType.Hidden;
          name="H"+(hiddenCount++);
          break;
      }
     
     
      DrawnNeuron neuron = new DrawnNeuron(t, name, neatNeuron.getSplitX(), neatNeuron.getSplitY());
      neurons.add(neuron);
      neuronMap.put(neatNeuron, neuron);
    }
   
    // place all the connections
    for(NEATNeuron neatNeuron : neat.getNeurons() ) {
      for(NEATLink neatLink: neatNeuron.getOutputboundLinks() ) {
        DrawnNeuron fromNeuron = neuronMap.get(neatLink.getFromNeuron());
        DrawnNeuron toNeuron = neuronMap.get(neatLink.getToNeuron());
        DrawnConnection connection = new DrawnConnection(fromNeuron,toNeuron,neatLink.getWeight());
        fromNeuron.getOutbound().add(connection);
        toNeuron.getInbound().add(connection);
      }
    }

   
   
    for (DrawnNeuron neuron : neurons) {
      result.addVertex(neuron);
      for (DrawnConnection connection : neuron.getOutbound()) {
        result.addEdge(connection, connection.getFrom(),
            connection.getTo(), EdgeType.DIRECTED);
      }
    }

    return result;
  }

  public Graph<DrawnNeuron, DrawnConnection> buildGraph(FlatNetwork flat) {
    int inputCount = 1;
    int outputCount = 1;
    int hiddenCount = 1;
    int biasCount = 1;
    int contextCount = 1;

    int layerCount = flat.getLayerCounts().length;
    List<DrawnNeuron> neurons = new ArrayList<DrawnNeuron>();
    Graph<DrawnNeuron, DrawnConnection> result = new SparseMultigraph<DrawnNeuron, DrawnConnection>();
    List<DrawnNeuron> lastFedNeurons;
    List<DrawnNeuron> connections = new ArrayList<DrawnNeuron>();
    double layerSize = 1.0/layerCount;

   
    int neuronNumber = 1;

    for (int currentLayer = 0; currentLayer < layerCount; currentLayer++) {
      lastFedNeurons = new ArrayList<DrawnNeuron>();

      double x = (double) (layerCount - currentLayer - 1)
          / (double) layerCount;
      int neuronCount = flat.getLayerCounts()[currentLayer];
      int feedCount = flat.getLayerFeedCounts()[currentLayer];
      for (int currentNeuron = 0; currentNeuron < neuronCount; currentNeuron++) {
        DrawnNeuronType type;
        double xOffset = 0;

        String name = "?";
        // not a bias or context
        if (currentNeuron < feedCount) {
          if (currentLayer == 0) {
            type = DrawnNeuronType.Output;
            name = "O" + (outputCount++);
          } else if (currentLayer == (layerCount - 1)) {
            type = DrawnNeuronType.Input;
            name = "I" + (inputCount++);
          } else {
            type = DrawnNeuronType.Hidden;
            name = "H" + (hiddenCount++);
          }
        }
        // is a bias
        else if (currentNeuron == feedCount) {
          type = DrawnNeuronType.Bias;
          name = "B" + (biasCount++);
        }
        // is a context
        else {
          type = DrawnNeuronType.Context;
          name = "C" + (contextCount++);
          xOffset=layerSize/4;
        }

        double y = (double) currentNeuron / (double) neuronCount;

        double margin = ((double) (neuronCount - 1) / (double) neuronCount);
        margin = 1.0 - margin;
        margin /= 2.0;

        DrawnNeuron neuron = new DrawnNeuron(type, name, x+xOffset, y + margin);
        neurons.add(neuron);

        if (neuron.getType() == DrawnNeuronType.Hidden
            || neuron.getType() == DrawnNeuronType.Output) {
          lastFedNeurons.add(neuron);
        }

        int toNeuron = 0;
        int count = connections.size();       
        for (DrawnNeuron connectTo : connections) {         
          int weightIndex = flat.getLayerIndex()[currentLayer]+(toNeuron*count)+currentNeuron;
          double w = 0;// this.flat.getWeights()[weightIndex];
          DrawnConnection connection = new DrawnConnection(neuron, connectTo, w);
          neuron.getOutbound().add(connection);
          neuron.getInbound().add(connection);
          toNeuron++;
        }
      }

      connections = lastFedNeurons;
    }

    for (DrawnNeuron neuron : neurons) {
      result.addVertex(neuron);
      for (DrawnConnection connection : neuron.getOutbound()) {
        result.addEdge(connection, connection.getFrom(),
            connection.getTo(), EdgeType.DIRECTED);
      }     
    }
   

    // draw context links
    for (int currentLayer = 0; currentLayer < layerCount; currentLayer++) {
      if( flat.getContextTargetSize()[currentLayer]>0 ) {
        int count = flat.getContextTargetSize()[currentLayer];
        int offset = flat.getContextTargetOffset()[currentLayer];
        int source = flat.getLayerIndex()[currentLayer];
        for(int i=0;i<count;i++) {
          DrawnNeuron n1 = neurons.get(source+i);
          DrawnNeuron n2 = neurons.get(offset+i);
          DrawnConnection connection = new DrawnConnection(n1, n2, 0);
          result.addEdge(connection, connection.getFrom(),
              connection.getTo(), EdgeType.DIRECTED);
          connection.setContext(true);
        }
      }
    }

    return result;

  }

  @Override
  public String getName() {
    return "Structure: " + this.getEncogObject().getName();
  }
}
TOP

Related Classes of org.encog.workbench.tabs.visualize.structure.StructureTab

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.