Package org.apache.mahout.df.builder

Source Code of org.apache.mahout.df.builder.DefaultTreeBuilder

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.df.builder;

import java.util.Random;

import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.data.Instance;
import org.apache.mahout.df.data.conditions.Condition;
import org.apache.mahout.df.node.CategoricalNode;
import org.apache.mahout.df.node.Leaf;
import org.apache.mahout.df.node.Node;
import org.apache.mahout.df.node.NumericalNode;
import org.apache.mahout.df.split.IgSplit;
import org.apache.mahout.df.split.OptIgSplit;
import org.apache.mahout.df.split.Split;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Builds a Decision Tree <br>
* Based on the algorithm described in the "Decision Trees" tutorials by Andrew W. Moore, available at:<br>
* <br>
* http://www.cs.cmu.edu/~awm/tutorials
*/
public class DefaultTreeBuilder implements TreeBuilder {
 
  private static final Logger log = LoggerFactory.getLogger(DefaultTreeBuilder.class);

  private static final int[] NO_ATTRIBUTES = new int[0];

  /** indicates which CATEGORICAL attributes have already been selected in the parent nodes */
  private boolean[] selected;
  /** number of attributes to select randomly at each node */
  private int m = 1;
  /** IgSplit implementation */
  private IgSplit igSplit;
 
  public DefaultTreeBuilder() {
    igSplit = new OptIgSplit();
  }
 
  public void setM(int m) {
    this.m = m;
  }
 
  public void setIgSplit(IgSplit igSplit) {
    this.igSplit = igSplit;
  }
 
  @Override
  public Node build(Random rng, Data data) {
   
    if (selected == null) {
      selected = new boolean[data.getDataset().nbAttributes()];
    }
   
    if (data.isEmpty()) {
      return new Leaf(-1);
    }
    if (isIdentical(data)) {
      return new Leaf(data.majorityLabel(rng));
    }
    if (data.identicalLabel()) {
      return new Leaf(data.get(0).getLabel());
    }
   
    int[] attributes = randomAttributes(rng, selected, m);
    if (attributes == null || attributes.length == 0) {
      // we tried all the attributes and could not split the data anymore
      return new Leaf(data.majorityLabel(rng));
    }

    // find the best split
    Split best = null;
    for (int attr : attributes) {
      Split split = igSplit.computeSplit(data, attr);
      if (best == null || best.getIg() < split.getIg()) {
        best = split;
      }
    }
   
    boolean alreadySelected = selected[best.getAttr()];
    if (alreadySelected) {
      // attribute already selected
      log.warn("attribute {} already selected in a parent node", best.getAttr());
    }
   
    Node childNode;
    if (data.getDataset().isNumerical(best.getAttr())) {
      boolean[] temp = null;

      Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
      Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));

      if (loSubset.isEmpty() || hiSubset.isEmpty()) {
        // the selected attribute did not change the data, avoid using it in the child notes
        selected[best.getAttr()] = true;
      } else {
        // the data changed, so we can unselect all previousely selected NUMERICAL attributes
        temp = selected;
        selected = cloneCategoricalAttributes(data.getDataset(), selected);
      }

      Node loChild = build(rng, loSubset);
      Node hiChild = build(rng, hiSubset);

      // restore the selection state of the attributes
      if (temp != null) {
        selected = temp;
      } else {
        selected[best.getAttr()] = alreadySelected;
      }

      childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
    } else { // CATEGORICAL attribute
      selected[best.getAttr()] = true;
     
      double[] values = data.values(best.getAttr());
      Node[] children = new Node[values.length];
     
      for (int index = 0; index < values.length; index++) {
        Data subset = data.subset(Condition.equals(best.getAttr(), values[index]));
        children[index] = build(rng, subset);
      }

      selected[best.getAttr()] = alreadySelected;
     
      childNode = new CategoricalNode(best.getAttr(), values, children);
    }
   
    return childNode;
  }
 
  /**
   * checks if all the vectors have identical attribute values. Ignore selected attributes.
   *
   * @return true is all the vectors are identical or the data is empty<br>
   *         false otherwise
   */
  private boolean isIdentical(Data data) {
    if (data.isEmpty()) {
      return true;
    }
   
    Instance instance = data.get(0);
    for (int attr = 0; attr < selected.length; attr++) {
      if (selected[attr]) {
        continue;
      }
     
      for (int index = 1; index < data.size(); index++) {
        if (data.get(index).get(attr) != instance.get(attr)) {
          return false;
        }
      }
    }
   
    return true;
  }


  /**
   * Make a copy of the selection state of the attributes, unselect all numerical attributes
   * @param dataset
   * @param selected selection state to clone
   * @return cloned selection state
   */
  protected static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
    boolean[] cloned = new boolean[selected.length];

    for (int i = 0; i < selected.length; i++) {
      cloned[i] = !dataset.isNumerical(i) && selected[i];
    }

    return cloned;
  }

  /**
   * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes
   *
   * @param rng
   *          random-numbers generator
   * @param selected
   *          attributes' state (selected or not)
   * @param m
   *          number of attributes to choose
   * @return list of selected attributes' indices, or null if all attributes have already been selected
   */
  protected static int[] randomAttributes(Random rng, boolean[] selected, int m) {
    int nbNonSelected = 0; // number of non selected attributes
    for (boolean sel : selected) {
      if (!sel) {
        nbNonSelected++;
      }
    }
   
    if (nbNonSelected == 0) {
      log.warn("All attributes are selected !");
      return NO_ATTRIBUTES;
    }
   
    int[] result;
    if (nbNonSelected <= m) {
      // return all non selected attributes
      result = new int[nbNonSelected];
      int index = 0;
      for (int attr = 0; attr < selected.length; attr++) {
        if (!selected[attr]) {
          result[index++] = attr;
        }
      }
    } else {
      result = new int[m];
      for (int index = 0; index < m; index++) {
        // randomly choose a "non selected" attribute
        int rind;
        do {
          rind = rng.nextInt(selected.length);
        } while (selected[rind]);
       
        result[index] = rind;
        selected[rind] = true; // temporarily set the chosen attribute to be selected
      }
     
      // the chosen attributes are not yet selected
      for (int attr : result) {
        selected[attr] = false;
      }
    }
   
    return result;
  }
}
TOP

Related Classes of org.apache.mahout.df.builder.DefaultTreeBuilder

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.