Package com.github.neuralnetworks.input

Source Code of com.github.neuralnetworks.input.MultipleNeuronsOutputError$OutputTargetTuple

package com.github.neuralnetworks.input;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.github.neuralnetworks.calculation.OutputError;
import com.github.neuralnetworks.tensor.Matrix;
import com.github.neuralnetworks.tensor.Tensor;

public class MultipleNeuronsOutputError implements OutputError {

    private static final long serialVersionUID = 1L;

    private List<OutputTargetTuple> tuples;
    private Map<Integer, Integer> outputToTarget;
    private int nullCount;
    private int dim;

    public MultipleNeuronsOutputError() {
  super();
  reset();
    }

    @Override
    public void addItem(Tensor newtorkOutput, Tensor targetOutput) {
  Matrix target = (Matrix) targetOutput;
  Matrix actual = (Matrix) newtorkOutput;
 
  if (dim == -1) {
      dim = target.getRows();
  }

  if (!Arrays.equals(actual.getDimensions(), target.getDimensions())) {
      throw new IllegalArgumentException("Dimensions don't match");
  }

  for (int i = 0; i < target.getColumns(); i++) {
      boolean hasDifferentValues = false;
      for (int j = 0; j < actual.getRows(); j++) {
    if (actual.get(j, i) != actual.get(0, i)) {
        hasDifferentValues = true;
        break;
    }
      }

      if (hasDifferentValues) {
    int targetPos = 0;
    for (int j = 0; j < target.getRows(); j++) {
        if (target.get(j, i) == 1) {
      targetPos = j;
      break;
        }
    }

    int outputPos = 0;
    float max = actual.get(0, i);
    for (int j = 0; j < actual.getRows(); j++) {
        if (actual.get(j, i) > max) {
      max = actual.get(j, i);
      outputPos = j;
        }
    }

    tuples.add(new OutputTargetTuple(outputPos, targetPos));
      } else {
    nullCount++;
      }
  }
    }

    @Override
    public float getTotalNetworkError() {
  return getTotalInputSize() > 0 ? ((float) getTotalErrorSamples()) / getTotalInputSize() : 0;
    }

    @Override
    public int getTotalErrorSamples() {
  if (outputToTarget == null) {
      outputToTarget = outputToTarget();
  }

  int errorSamples = 0;
  for (OutputTargetTuple t : tuples) {
      if (!outputToTarget.get(t.outputPos).equals(t.targetPos)) {
    errorSamples++;
      }
  }

  return nullCount + errorSamples;
    }

    @Override
    public int getTotalInputSize() {
  return tuples.size() + nullCount;
    }

    private Map<Integer, Integer> outputToTarget() {
  Map<Integer, Integer> result = new HashMap<>();
  Map<Integer, int[]> targetToOutput = new HashMap<>();
  for (OutputTargetTuple t : tuples) {
      if (!targetToOutput.containsKey(t.targetPos)) {
    targetToOutput.put(t.targetPos, new int[dim]);
      }

      targetToOutput.get(t.targetPos)[t.outputPos]++;
  }

  for (int i = 0; i < dim; i++) {
      int[] d = targetToOutput.get(i);
      if (d != null) {
    int max = 0;
    for (int j = 0; j < dim; j++) {
        if (d[j] > d[max] && !result.values().contains(j)) {
      max = j;
        }
    }

    result.put(i, max);
      }
  }

  return result;
    }

    @Override
    public void reset() {
  this.tuples = new ArrayList<>();
  this.dim = -1;
  this.nullCount = 0;
  this.outputToTarget = null;
    }

    private static class OutputTargetTuple {

  public OutputTargetTuple(Integer outputPos, Integer targetPos) {
      this.outputPos = outputPos;
      this.targetPos = targetPos;
  }

  public Integer outputPos;
  public Integer targetPos;
    }
}
TOP

Related Classes of com.github.neuralnetworks.input.MultipleNeuronsOutputError$OutputTargetTuple

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.