Package de.jungblut.classification.tree

Source Code of de.jungblut.classification.tree.DecisionTree

package de.jungblut.classification.tree;

import gnu.trove.iterator.TDoubleIterator;
import gnu.trove.iterator.TIntObjectIterator;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.set.hash.TDoubleHashSet;
import gnu.trove.set.hash.TIntHashSet;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import org.apache.commons.math3.util.FastMath;
import org.apache.hadoop.io.WritableUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;

import de.jungblut.classification.AbstractClassifier;
import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleVector;
import de.jungblut.math.tuple.Tuple;

/**
* A decision tree that can be used for classification with numerical or
* categorical features. The tree is built by maximizing information gain using
* the ID3 algorithm. If no featureTypes were supplied, the default is assumed
* to be nominal features at all feature dimensions. <br/>
* Instances can be created by the static factory methods #create().
*
* @author thomasjungblut
*
*/
public final class DecisionTree extends AbstractClassifier {

  private static final double LOG2 = FastMath.log(2);

  private AbstractTreeNode rootNode;
  private FeatureType[] featureTypes;
  private int numRandomFeaturesToChoose;
  private int maxHeight = 25;
  private long seed = System.currentTimeMillis();

  // default is binary classification 0 or 1.
  private boolean binaryClassification = true;
  private boolean compile = false;
  private String compiledName = null;
  private byte[] compiledClass = null;
  private int outcomeDimension;
  private int numFeatures;

  // use the static factory methods!
  private DecisionTree() {
  }

  // serialization constructor
  private DecisionTree(AbstractTreeNode rootNode, FeatureType[] featureTypes,
      boolean binaryClassification, int numFeatures, int outcomeDimension) {
    this.binaryClassification = binaryClassification;
    this.rootNode = rootNode;
    this.featureTypes = featureTypes;
    this.numFeatures = numFeatures;
    this.outcomeDimension = outcomeDimension;
    this.compile = true;
  }

  @Override
  public void train(DoubleVector[] features, DoubleVector[] outcome) {
    Preconditions.checkArgument(features.length == outcome.length,
        "Number of examples and outcomes must match!");
    // assume all nominal if nothing was set
    if (featureTypes == null) {
      featureTypes = new FeatureType[features[0].getDimension()];
      Arrays.fill(featureTypes, FeatureType.NOMINAL);
    }
    Preconditions.checkArgument(
        featureTypes.length == features[0].getDimension(),
        "FeatureType length must match the dimension of the features!");
    binaryClassification = outcome[0].getDimension() == 1;
    if (binaryClassification) {
      outcomeDimension = 2;
    } else {
      outcomeDimension = outcome[0].getDimension();
    }
    numFeatures = features[0].getDimension();
    TIntHashSet possibleFeatureIndices = getPossibleFeatures();
    // recursively build the tree...
    // note that we use linked lists to remove examples we don't need, linked
    // structures do not require costly copy operations after removal
    rootNode = build(Lists.newLinkedList(Arrays.asList(features)),
        Lists.newLinkedList(Arrays.asList(outcome)), possibleFeatureIndices, 0);
    if (compile) {
      try {
        compileTree();
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    }
  }

  @Override
  public DoubleVector predict(DoubleVector features) {
    int clz = rootNode.predict(features);
    if (clz < 0) {
      // let's assume the default case ("negative") here, instead of making NPEs
      // in other areas, as the callers aren't nullsafe
      clz = 0;
    }
    if (binaryClassification) {
      return new DenseDoubleVector(new double[] { clz });
    } else {
      DoubleVector vec = outcomeDimension > 10 ? new SparseDoubleVector(
          outcomeDimension) : new DenseDoubleVector(outcomeDimension);
      vec.set(clz, 1);
      return vec;
    }
  }

  /**
   * Compiles this current tree representation into byte code and loads it into
   * this class. This is considered faster, as the interpreted code can be
   * optimized by the hotspot JVM.
   *
   * @throws Exception some error might happen during compilation or loading.
   */
  public void compileTree() throws Exception {
    if (compiledClass == null) {
      compiledName = TreeCompiler.generateClassName();
      compiledClass = TreeCompiler.compileNode(compiledName, rootNode);
      rootNode = TreeCompiler.load(compiledName, compiledClass);
    }
  }

  /**
   * @return a random subset of the features.
   */
  TIntHashSet chooseRandomFeatures(TIntHashSet possibleFeatureIndices) {
    if (numRandomFeaturesToChoose > 0
        && numRandomFeaturesToChoose < numFeatures
        && possibleFeatureIndices.size() > numRandomFeaturesToChoose) {
      // make a copy
      TIntHashSet set = new TIntHashSet();
      int[] arr = possibleFeatureIndices.toArray();
      Random rnd = new Random(seed);
      while (set.size() < numRandomFeaturesToChoose) {
        set.add(arr[rnd.nextInt(arr.length)]);
      }
      return set;
    }
    return possibleFeatureIndices;
  }

  /**
   * Recursively build the decision tree in a top down fashion.
   */
  private AbstractTreeNode build(List<DoubleVector> features,
      List<DoubleVector> outcome, TIntHashSet possibleFeatureIndices, int level) {

    // we select a subset of features at every tree level
    possibleFeatureIndices = chooseRandomFeatures(possibleFeatureIndices);

    int[] countOutcomeClasses = getPossibleClasses(outcome);
    TIntHashSet notZeroClasses = new TIntHashSet();
    for (int i = 0; i < countOutcomeClasses.length; i++) {
      if (countOutcomeClasses[i] != 0) {
        notZeroClasses.add(i);
      }
    }

    // if we only have a single class to predict, we will create a leaf to
    // predict it
    if (notZeroClasses.size() == 1) {
      return new LeafNode(notZeroClasses.iterator().next());
    }

    // if we don't have anymore features to split on or reached the max height,
    // we will use the majority class and predict it
    if (possibleFeatureIndices.isEmpty() || level >= maxHeight) {
      return new LeafNode(ArrayUtils.maxIndex(countOutcomeClasses));
    }

    // now we can evaluate the infogain for every possible feature and choose
    // that one that maximizes it and split on it
    double targetEntropy = getEntropy(countOutcomeClasses);
    Split[] infoGain = new Split[numFeatures];
    for (int featureIndex : possibleFeatureIndices.toArray()) {
      infoGain[featureIndex] = computeSplit(targetEntropy, featureIndex,
          countOutcomeClasses, features, outcome);
    }

    // pick the split with highest info gain
    int maxIndex = 0;
    double maxGain = infoGain[maxIndex] != null ? infoGain[maxIndex]
        .getInformationGain() : Integer.MIN_VALUE;
    for (int i = 1; i < infoGain.length; i++) {
      if (infoGain[i] != null && infoGain[i].getInformationGain() > maxGain) {
        maxGain = infoGain[i].getInformationGain();
        maxIndex = i;
      }
    }
    Split bestSplit = infoGain[maxIndex];
    int bestSplitIndex = bestSplit.getSplitAttributeIndex();
    if (featureTypes[bestSplitIndex].isNominal()) {
      TIntHashSet uniqueFeatures = getNominalValues(bestSplitIndex, features);
      NominalNode node = new NominalNode(bestSplitIndex, uniqueFeatures.size());
      int cIndex = 0;
      for (int nominalValue : uniqueFeatures.toArray()) {
        node.nominalSplitValues[cIndex] = nominalValue;
        Tuple<List<DoubleVector>, List<DoubleVector>> filtered = filterNominal(
            features, outcome, bestSplitIndex, nominalValue);
        TIntHashSet newPossibleFeatures = new TIntHashSet(
            possibleFeatureIndices);
        // remove that feature
        newPossibleFeatures.remove(bestSplitIndex);
        node.children[cIndex] = build(filtered.getFirst(),
            filtered.getSecond(), newPossibleFeatures, level + 1);
        cIndex++;
      }
      // make a faster lookup by sorting
      node.sortInternal();
      return node;
    } else {
      // numerical split
      TIntHashSet newPossibleFeatures = new TIntHashSet(possibleFeatureIndices);
      Tuple<List<DoubleVector>, List<DoubleVector>> filterNumeric = filterNumeric(
          features, outcome, bestSplitIndex,
          bestSplit.getNumericalSplitValue(), true);
      Tuple<List<DoubleVector>, List<DoubleVector>> filterNumericHigher = filterNumeric(
          features, outcome, bestSplitIndex,
          bestSplit.getNumericalSplitValue(), false);

      if (filterNumeric.getFirst().isEmpty()
          || filterNumericHigher.getFirst().isEmpty()) {
        newPossibleFeatures.remove(bestSplitIndex);
      } else {
        // we changed something, thus we can unselect all numerical features
        for (int i = 0; i < featureTypes.length; i++) {
          if (featureTypes[i].isNumerical()) {
            newPossibleFeatures.add(i);
          }
        }
      }

      // build subtrees
      AbstractTreeNode lower = build(filterNumeric.getFirst(),
          filterNumeric.getSecond(), new TIntHashSet(newPossibleFeatures),
          level + 1);
      AbstractTreeNode higher = build(filterNumericHigher.getFirst(),
          filterNumericHigher.getSecond(),
          new TIntHashSet(newPossibleFeatures), level + 1);
      // now we can return this completed node
      return new NumericalNode(bestSplitIndex,
          bestSplit.getNumericalSplitValue(), lower, higher);
    }

  }

  /**
   * Filters all examples where the feature at the given index has NOT the
   * specific value. So the returned lists contain only vectors where the
   * feature has the specific value.
   *
   * @return a new tuple of two new lists (features and their outcome).
   */
  private Tuple<List<DoubleVector>, List<DoubleVector>> filterNominal(
      List<DoubleVector> features, List<DoubleVector> outcome,
      int bestSplitIndex, int nominalValue) {

    List<DoubleVector> newFeatures = Lists.newLinkedList();
    List<DoubleVector> newOutcomes = Lists.newLinkedList();

    Iterator<DoubleVector> featureIterator = features.iterator();
    Iterator<DoubleVector> outcomeIterator = outcome.iterator();
    while (featureIterator.hasNext()) {
      DoubleVector feature = featureIterator.next();
      DoubleVector out = outcomeIterator.next();
      if (((int) feature.get(bestSplitIndex)) == nominalValue) {
        newFeatures.add(feature);
        newOutcomes.add(out);
      }
    }

    return new Tuple<>(newFeatures, newOutcomes);
  }

  /**
   * Filters the lists by numerical decision.
   *
   * @param features the features to filter.
   * @param outcome the outcome to filter.
   * @param bestSplitIndex the feature split index with highest information
   *          gain.
   * @param splitValue the value of the split point.
   * @param lower true if the returned list should contain lower items, else it
   *          contains strictly higher items.
   * @return two filtered parallel lists.
   */
  private Tuple<List<DoubleVector>, List<DoubleVector>> filterNumeric(
      List<DoubleVector> features, List<DoubleVector> outcome,
      int bestSplitIndex, double splitValue, boolean lower) {

    List<DoubleVector> newFeatures = Lists.newLinkedList();
    List<DoubleVector> newOutcomes = Lists.newLinkedList();

    Iterator<DoubleVector> featureIterator = features.iterator();
    Iterator<DoubleVector> outcomeIterator = outcome.iterator();
    while (featureIterator.hasNext()) {
      DoubleVector feature = featureIterator.next();
      DoubleVector out = outcomeIterator.next();
      if (lower) {
        if (feature.get(bestSplitIndex) <= splitValue) {
          newFeatures.add(feature);
          newOutcomes.add(out);
        }
      } else {
        if (feature.get(bestSplitIndex) > splitValue) {
          newFeatures.add(feature);
          newOutcomes.add(out);
        }
      }

    }

    return new Tuple<>(newFeatures, newOutcomes);
  }

  /**
   * Computes the split of nominal and numerical values.
   *
   * @param overallEntropy the overall entropy at the given time.
   * @param featureIndex the feature index to evaluate on.
   * @param countOutcomeClasses the histogram over all possible outcome classes.
   * @param features the features.
   * @param outcome the outcome.
   * @return a {@link Split} that contains a possible split (either numerical or
   *         categorical) along with the information gain.
   */
  private Split computeSplit(double overallEntropy, int featureIndex,
      int[] countOutcomeClasses, List<DoubleVector> features,
      List<DoubleVector> outcome) {

    if (featureTypes[featureIndex].isNominal()) {
      TIntObjectHashMap<int[]> featureValueOutcomeCount = new TIntObjectHashMap<>();
      TIntIntHashMap rowSums = new TIntIntHashMap();
      int numFeatures = 0;
      Iterator<DoubleVector> featureIterator = features.iterator();
      Iterator<DoubleVector> outcomeIterator = outcome.iterator();
      while (featureIterator.hasNext()) {
        DoubleVector feature = featureIterator.next();
        DoubleVector out = outcomeIterator.next();
        int classIndex = getOutcomeClassIndex(out);
        int nominalFeatureValue = (int) feature.get(featureIndex);
        int[] is = featureValueOutcomeCount.get(nominalFeatureValue);
        if (is == null) {
          is = new int[outcomeDimension];
          featureValueOutcomeCount.put(nominalFeatureValue, is);
        }
        is[classIndex]++;
        rowSums.put(nominalFeatureValue, rowSums.get(nominalFeatureValue) + 1);
        numFeatures++;
      }
      double entropySum = 0d;
      // now we can calculate the entropy
      TIntObjectIterator<int[]> iterator = featureValueOutcomeCount.iterator();
      while (iterator.hasNext()) {
        iterator.advance();
        int[] outcomeCounts = iterator.value();
        double condEntropy = rowSums.get(iterator.key()) / (double) numFeatures
            * getEntropy(outcomeCounts);
        entropySum += condEntropy;
      }
      return new Split(featureIndex, overallEntropy - entropySum);
    } else {
      // numerical case
      Iterator<DoubleVector> featureIterator = features.iterator();
      TDoubleHashSet possibleFeatureValues = new TDoubleHashSet();
      while (featureIterator.hasNext()) {
        DoubleVector feature = featureIterator.next();
        possibleFeatureValues.add(feature.get(featureIndex));
      }
      double bestInfogain = -1;
      double bestSplit = 0.0;
      TDoubleIterator iterator = possibleFeatureValues.iterator();
      while (iterator.hasNext()) {
        double value = iterator.next();
        double ig = computeNumericalInfogain(features, outcome, overallEntropy,
            featureIndex, value);
        if (ig > bestInfogain) {
          bestInfogain = ig;
          bestSplit = value;
        }
      }
      return new Split(featureIndex, bestInfogain, bestSplit);
    }

  }

  /**
   * This method computes the numerical information gain for the given features
   * and outcomes and a featureIndex and its value. This is done by iterating
   * once over all features and outcomes and calculating a table of outcome
   * counts given a higher/lower relationship to the given feature value.
   *
   * @param features the features.
   * @param outcome the outcomes.
   * @param overallEntropy the overall entropy of the selectable features.
   * @param featureIndex the feature index to check.
   * @param value the value that acts as a possible split point between a lower
   *          and higher partition for a given feature.
   * @return the information gain under the feature given a value as a split
   *         point.
   */
  private double computeNumericalInfogain(List<DoubleVector> features,
      List<DoubleVector> outcome, double overallEntropy, int featureIndex,
      double value) {
    double invDatasize = 1d / features.size();
    // 0 denotes lower than or equal, 1 denotes higher
    int[][] counts = new int[2][outcomeDimension];
    int lowCount = 0;
    int highCount = 0;
    Arrays.fill(counts, new int[outcomeDimension]);
    Iterator<DoubleVector> featureIterator = features.iterator();
    Iterator<DoubleVector> outcomeIterator = outcome.iterator();
    while (featureIterator.hasNext()) {
      DoubleVector feature = featureIterator.next();
      DoubleVector out = outcomeIterator.next();
      int idx = getOutcomeClassIndex(out);
      if (feature.get(featureIndex) > value) {
        counts[1][idx]++;
        highCount++;
      } else {
        counts[0][idx]++;
        lowCount++;
      }
    }

    // discount the lower set
    overallEntropy -= (lowCount * invDatasize * getEntropy(counts[0]));
    // and the higher one
    overallEntropy -= (highCount * invDatasize * getEntropy(counts[1]));
    return overallEntropy;
  }

  /**
   * @return the class index, this takes binary classification into account as
   *         well as multi class classification.
   */
  private int getOutcomeClassIndex(DoubleVector out) {
    int classIndex = 0;
    if (binaryClassification) {
      classIndex = (int) out.get(0);
    } else {
      classIndex = out.maxIndex();
    }
    return classIndex;
  }

  /**
   * @return a set of nominal values of that feature index given the examples
   *         that contain this feature.
   */
  private TIntHashSet getNominalValues(int featureIndex,
      List<DoubleVector> features) {
    TIntHashSet uniqueFeatures = new TIntHashSet();
    for (DoubleVector vec : features) {
      int featureValue = (int) vec.get(featureIndex);
      uniqueFeatures.add(featureValue);
    }
    return uniqueFeatures;
  }

  /**
   * @return an array from 0-outcome dimension that has a count on every feature
   *         index representing how often it occurred.
   */
  private int[] getPossibleClasses(List<DoubleVector> outcome) {
    int[] clzs = new int[outcomeDimension];
    for (DoubleVector out : outcome) {
      if (binaryClassification) {
        clzs[(int) out.get(0)]++;
      } else {
        clzs[out.maxIndex()]++;
      }
    }

    return clzs;
  }

  /**
   * Sets the type of feature per index. This should match the inputted number
   * of features in the training method. If this isn't set at all, all
   * attributes are assumed to be nominal.
   *
   * @return this decision tree instance.
   */
  public DecisionTree setFeatureTypes(FeatureType[] featureTypes) {
    this.featureTypes = featureTypes;
    return this;
  }

  /**
   * Sets the number of random features to choose from all features.Zero,
   * negative numbers or numbers greater than the really available features
   * indicate all features to be used.
   *
   * @return this decision tree instance.
   */
  public DecisionTree setNumRandomFeaturesToChoose(int numRandomFeaturesToChoose) {
    this.numRandomFeaturesToChoose = numRandomFeaturesToChoose;
    return this;
  }

  /**
   * If set to true, this tree will be compiled after training time
   * automatically.
   *
   * @return this decision tree instance.
   */
  public DecisionTree setCompiled(boolean compiled) {
    this.compile = compiled;
    return this;
  }

  /**
   * Sets the maximum height of this tree.
   *
   * @return this instance.
   */
  public DecisionTree setMaxHeight(int max) {
    this.maxHeight = max;
    return this;
  }

  /**
   * Sets the seed for a random number generator if used.
   */
  public DecisionTree setSeed(long seed) {
    this.seed = seed;
    return this;
  }

  /*
   * for testing
   */
  void setNumFeatures(int numFeatures) {
    this.numFeatures = numFeatures;
  }

  /**
   * @return the set of possible features
   */
  TIntHashSet getPossibleFeatures() {
    // all features are possible here
    TIntHashSet possibleFeatureIndices = new TIntHashSet();
    for (int i = 0; i < numFeatures; i++) {
      possibleFeatureIndices.add(i);
    }
    return possibleFeatureIndices;
  }

  /**
   * Writes the given tree to the output stream. Note that the stream isn't
   * closed here.
   */
  public static void serialize(DecisionTree tree, DataOutput out)
      throws IOException {
    try {
      out.writeBoolean(tree.binaryClassification);
      WritableUtils.writeVInt(out, tree.outcomeDimension);
      WritableUtils.writeVInt(out, tree.numFeatures);
      for (int i = 0; i < tree.featureTypes.length; i++) {
        WritableUtils.writeVInt(out, tree.featureTypes[i].ordinal());
      }

      if (tree.compiledClass == null) {
        out.writeBoolean(false);
        tree.rootNode.write(out);
      } else {
        out.writeBoolean(true);
        out.writeUTF(tree.compiledName);
        WritableUtils.writeCompressedByteArray(out, tree.compiledClass);
      }
    } catch (Exception e) {
      throw new IOException(e);
    }
  }

  /**
   * Reads a new tree from the given stream. Note that the stream isn't closed
   * here.
   */
  public static DecisionTree deserialize(DataInput in) throws IOException {
    boolean binary = in.readBoolean();
    int outcomeDimension = WritableUtils.readVInt(in);
    int numFeatures = WritableUtils.readVInt(in);
    FeatureType[] arr = new FeatureType[numFeatures];
    for (int i = 0; i < numFeatures; i++) {
      arr[i] = FeatureType.values()[WritableUtils.readVInt(in)];
    }
    if (in.readBoolean()) {
      String name = in.readUTF();
      byte[] compiled = WritableUtils.readCompressedByteArray(in);
      try {
        AbstractTreeNode loadedRoot = TreeCompiler.load(name, compiled);
        return new DecisionTree(loadedRoot, arr, binary, numFeatures,
            outcomeDimension);
      } catch (Exception e) {
        throw new IOException(e);
      }
    } else {
      AbstractTreeNode root = AbstractTreeNode.read(in);
      return new DecisionTree(root, arr, binary, numFeatures, outcomeDimension);
    }
  }

  /**
   * @return a default decision tree with all features beeing nominal.
   */
  public static DecisionTree create() {
    return new DecisionTree();
  }

  /**
   * Creates a new decision tree with the given feature types.
   *
   * @param featureTypes the types of the feature that must match the number of
   *          features in length.
   * @return a default decision tree with all features beeing set to what has
   *         been configured in the given array.
   */
  public static DecisionTree create(FeatureType[] featureTypes) {
    return new DecisionTree().setFeatureTypes(featureTypes);
  }

  /**
   * @return a default compiled decision tree with all features beeing nominal.
   */
  public static DecisionTree createCompiledTree() {
    return new DecisionTree().setCompiled(true);
  }

  /**
   * Creates a new compiled decision tree with the given feature types.
   *
   * @param featureTypes the types of the feature that must match the number of
   *          features in length.
   * @return a default decision tree with all features beeing set to what has
   *         been configured in the given array.
   */
  public static DecisionTree createCompiledTree(FeatureType[] featureTypes) {
    return new DecisionTree().setFeatureTypes(featureTypes).setCompiled(true);
  }

  /**
   * @return the entropy of the given prediction class counts.
   */
  static double getEntropy(int[] outcomeCounter) {
    double entropySum = 0d;
    double sum = 0d;
    for (int x : outcomeCounter) {
      sum += x;
    }
    for (int x : outcomeCounter) {
      if (x == 0) {
        return 0d;
      }
      double conditionalProbability = x / sum;
      entropySum -= (conditionalProbability * log2(conditionalProbability));
    }

    return entropySum;
  }

  /**
   * @return the log2 of the given input.
   */
  private static double log2(double num) {
    return FastMath.log(num) / LOG2;
  }

}
TOP

Related Classes of de.jungblut.classification.tree.DecisionTree

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.