Package aima.core.learning.learners

Source Code of aima.core.learning.learners.DecisionTreeLearner

package aima.core.learning.learners;

import java.util.Iterator;
import java.util.List;

import aima.core.learning.framework.DataSet;
import aima.core.learning.framework.Example;
import aima.core.learning.framework.Learner;
import aima.core.learning.inductive.ConstantDecisonTree;
import aima.core.learning.inductive.DecisionTree;
import aima.core.util.Util;

/**
* @author Ravi Mohan
* @author Mike Stampone
*/
public class DecisionTreeLearner implements Learner {
  private DecisionTree tree;

  private String defaultValue;

  public DecisionTreeLearner() {
    this.defaultValue = "Unable To Classify";

  }

  // used when you have to test a non induced tree (eg: for testing)
  public DecisionTreeLearner(DecisionTree tree, String defaultValue) {
    this.tree = tree;
    this.defaultValue = defaultValue;
  }

  //
  // START-Learner

  /**
   * Induces the decision tree from the specified set of examples
   *
   * @param ds
   *            a set of examples for constructing the decision tree
   */
  @Override
  public void train(DataSet ds) {
    List<String> attributes = ds.getNonTargetAttributes();
    this.tree = decisionTreeLearning(ds, attributes,
        new ConstantDecisonTree(defaultValue));
  }

  @Override
  public String predict(Example e) {
    return (String) tree.predict(e);
  }

  @Override
  public int[] test(DataSet ds) {
    int[] results = new int[] { 0, 0 };

    for (Example e : ds.examples) {
      if (e.targetValue().equals(tree.predict(e))) {
        results[0] = results[0] + 1;
      } else {
        results[1] = results[1] + 1;
      }
    }
    return results;
  }

  // END-Learner
  //

  /**
   * Returns the decision tree of this decision tree learner
   *
   * @return the decision tree of this decision tree learner
   */
  public DecisionTree getDecisionTree() {
    return tree;
  }

  //
  // PRIVATE METHODS
  //

  private DecisionTree decisionTreeLearning(DataSet ds,
      List<String> attributeNames, ConstantDecisonTree defaultTree) {
    if (ds.size() == 0) {
      return defaultTree;
    }
    if (allExamplesHaveSameClassification(ds)) {
      return new ConstantDecisonTree(ds.getExample(0).targetValue());
    }
    if (attributeNames.size() == 0) {
      return majorityValue(ds);
    }
    String chosenAttribute = chooseAttribute(ds, attributeNames);

    DecisionTree tree = new DecisionTree(chosenAttribute);
    ConstantDecisonTree m = majorityValue(ds);

    List<String> values = ds.getPossibleAttributeValues(chosenAttribute);
    for (String v : values) {
      DataSet filtered = ds.matchingDataSet(chosenAttribute, v);
      List<String> newAttribs = Util.removeFrom(attributeNames,
          chosenAttribute);
      DecisionTree subTree = decisionTreeLearning(filtered, newAttribs, m);
      tree.addNode(v, subTree);

    }

    return tree;
  }

  private ConstantDecisonTree majorityValue(DataSet ds) {
    Learner learner = new MajorityLearner();
    learner.train(ds);
    return new ConstantDecisonTree(learner.predict(ds.getExample(0)));
  }

  private String chooseAttribute(DataSet ds, List<String> attributeNames) {
    double greatestGain = 0.0;
    String attributeWithGreatestGain = attributeNames.get(0);
    for (String attr : attributeNames) {
      double gain = ds.calculateGainFor(attr);
      if (gain > greatestGain) {
        greatestGain = gain;
        attributeWithGreatestGain = attr;
      }
    }

    return attributeWithGreatestGain;
  }

  private boolean allExamplesHaveSameClassification(DataSet ds) {
    String classification = ds.getExample(0).targetValue();
    Iterator<Example> iter = ds.iterator();
    while (iter.hasNext()) {
      Example element = iter.next();
      if (!(element.targetValue().equals(classification))) {
        return false;
      }

    }
    return true;
  }
}
TOP

Related Classes of aima.core.learning.learners.DecisionTreeLearner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.