Package weka.classifiers.trees

Source Code of weka.classifiers.trees.SimpleCart

/*
*    This program is free software; you can redistribute it and/or modify
*    it under the terms of the GNU General Public License as published by
*    the Free Software Foundation; either version 2 of the License, or
*    (at your option) any later version.
*
*    This program is distributed in the hope that it will be useful,
*    but WITHOUT ANY WARRANTY; without even the implied warranty of
*    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*    GNU General Public License for more details.
*
*    You should have received a copy of the GNU General Public License
*    along with this program; if not, write to the Free Software
*    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/

/*
* SimpleCart.java
* Copyright (C) 2007 Haijian Shi
*
*/

package weka.classifiers.trees;

import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.matrix.Matrix;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
<!-- globalinfo-start -->
* Class implementing minimal cost-complexity pruning.<br/>
* Note: when dealing with missing values, the "fractional instances" method is used instead of the surrogate split method.<br/>
* <br/>
* For more information, see:<br/>
* <br/>
* Leo Breiman, Jerome H. Friedman, Richard A. Olshen, Charles J. Stone (1984). Classification and Regression Trees. Wadsworth International Group, Belmont, California.
* <p/>
<!-- globalinfo-end --> 
*
<!-- technical-bibtex-start -->
* BibTeX:
* <pre>
* &#64;book{Breiman1984,
*    address = {Belmont, California},
*    author = {Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone},
*    publisher = {Wadsworth International Group},
*    title = {Classification and Regression Trees},
*    year = {1984}
* }
* </pre>
* <p/>
<!-- technical-bibtex-end -->
*
<!-- options-start -->
* Valid options are: <p/>
*
* <pre> -S &lt;num&gt;
*  Random number seed.
*  (default 1)</pre>
*
* <pre> -D
*  If set, classifier is run in debug mode and
*  may output additional info to the console</pre>
*
* <pre> -M &lt;min no&gt;
*  The minimal number of instances at the terminal nodes.
*  (default 2)</pre>
*
* <pre> -N &lt;num folds&gt;
*  The number of folds used in the minimal cost-complexity pruning.
*  (default 5)</pre>
*
* <pre> -U
*  Don't use the minimal cost-complexity pruning.
*  (default yes).</pre>
*
* <pre> -H
*  Don't use the heuristic method for binary split.
*  (default true).</pre>
*
* <pre> -A
*  Use 1 SE rule to make pruning decision.
*  (default no).</pre>
*
* <pre> -C
*  Percentage of training data size (0-1].
*  (default 1).</pre>
*
<!-- options-end -->
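 *
 * Example usage from code (a minimal sketch; the dataset file name and the
 * option values below are illustrative only, not part of this class):
 * <pre>
 * SimpleCart cart = new SimpleCart();
 * cart.setOptions(new String[]{"-M", "2", "-N", "5"});
 * Instances data = new Instances(new java.io.BufferedReader(
 *     new java.io.FileReader("weather.nominal.arff")));
 * data.setClassIndex(data.numAttributes() - 1);
 * cart.buildClassifier(data);
 * System.out.println(cart);
 * </pre>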
*
* @author Haijian Shi (hs69@cs.waikato.ac.nz)
* @version $Revision: 1.4 $
*/
public class SimpleCart
  extends RandomizableClassifier
  implements AdditionalMeasureProducer, TechnicalInformationHandler {

  /** For serialization.   */
  private static final long serialVersionUID = 4154189200352566053L;

  /** Training data.  */
  protected Instances m_train;

  /** Successor nodes. */
  protected SimpleCart[] m_Successors;

  /** Attribute used to split data. */
  protected Attribute m_Attribute;

  /** Split point for a numeric attribute. */
  protected double m_SplitValue;

  /** Split subset used to split data for nominal attributes. */
  protected String m_SplitString;

  /** Class value if the node is leaf. */
  protected double m_ClassValue;

  /** Class attribute of the data. */
  protected Attribute m_ClassAttribute;

  /** Minimum number of instances at the terminal nodes. */
  protected double m_minNumObj = 2;

  /** Number of folds for minimal cost-complexity pruning. */
  protected int m_numFoldsPruning = 5;

  /** Alpha-value (for pruning) at the node. */
  protected double m_Alpha;

  /** Number of training examples misclassified when the node is treated as a leaf (subtree replaced). */
  protected double m_numIncorrectModel;

  /** Number of training examples misclassified by the subtree rooted at the node. */
  protected double m_numIncorrectTree;

  /** Indicate if the node is a leaf node. */
  protected boolean m_isLeaf;

  /** Whether to use minimal cost-complexity pruning. */
  protected boolean m_Prune = true;

  /** Total number of instances used to build the classifier. */
  protected int m_totalTrainInstances;

  /** Proportion for each branch. */
  protected double[] m_Props;

  /** Class probabilities. */
  protected double[] m_ClassProbs = null;

  /** Distribution of the leaf node (or temporary leaf node during minimal cost-complexity pruning). */
  protected double[] m_Distribution;

  /** Whether to use heuristic search for nominal attributes in multi-class problems (default true). */
  protected boolean m_Heuristic = true;

  /** Whether to use the 1SE rule to select the final tree. */
  protected boolean m_UseOneSE = false;

  /** Fraction of the training set to use (0-1]. */
  protected double m_SizePer = 1;

  /**
   * Return a description suitable for displaying in the explorer/experimenter.
   *
   * @return     a description suitable for displaying in the
   *       explorer/experimenter
   */
  public String globalInfo() {
    return 
        "Class implementing minimal cost-complexity pruning.\n"
      + "Note when dealing with missing values, use \"fractional "
      + "instances\" method instead of surrogate split method.\n\n"
      + "For more information, see:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return     the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation   result;
   
    result = new TechnicalInformation(Type.BOOK);
    result.setValue(Field.AUTHOR, "Leo Breiman and Jerome H. Friedman and Richard A. Olshen and Charles J. Stone");
    result.setValue(Field.YEAR, "1984");
    result.setValue(Field.TITLE, "Classification and Regression Trees");
    result.setValue(Field.PUBLISHER, "Wadsworth International Group");
    result.setValue(Field.ADDRESS, "Belmont, California");
   
    return result;
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return     the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.NOMINAL_CLASS);

    return result;
  }

  /**
   * Build the classifier.
   *
   * @param data   the training instances
   * @throws Exception   if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    getCapabilities().testWithFail(data);
    data = new Instances(data);       
    data.deleteWithMissingClass();

    // unpruned CART decision tree
    if (!m_Prune) {

      // calculate sorted indices and weights, and compute initial class counts.
      int[][] sortedIndices = new int[data.numAttributes()][0];
      double[][] weights = new double[data.numAttributes()][0];
      double[] classProbs = new double[data.numClasses()];
      double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);

      makeTree(data, data.numInstances(), sortedIndices, weights, classProbs,
          totalWeight, m_minNumObj, m_Heuristic);
      return;
    }

    Random random = new Random(m_Seed);
    Instances cvData = new Instances(data);
    cvData.randomize(random);
    cvData = new Instances(cvData,0,(int)(cvData.numInstances()*m_SizePer)-1);
    cvData.stratify(m_numFoldsPruning);

    double[][] alphas = new double[m_numFoldsPruning][];
    double[][] errors = new double[m_numFoldsPruning][];

    // calculate errors and alphas for each fold
    for (int i = 0; i < m_numFoldsPruning; i++) {

      //for every fold, grow tree on training set and fix error on test set.
      Instances train = cvData.trainCV(m_numFoldsPruning, i);
      Instances test = cvData.testCV(m_numFoldsPruning, i);

      // calculate sorted indices and weights, and compute initial class counts for each fold
      int[][] sortedIndices = new int[train.numAttributes()][0];
      double[][] weights = new double[train.numAttributes()][0];
      double[] classProbs = new double[train.numClasses()];
      double totalWeight = computeSortedInfo(train,sortedIndices, weights,classProbs);

      makeTree(train, train.numInstances(), sortedIndices, weights, classProbs,
          totalWeight, m_minNumObj, m_Heuristic);

      int numNodes = numInnerNodes();
      alphas[i] = new double[numNodes + 2];
      errors[i] = new double[numNodes + 2];

      // prune back and log alpha-values and errors on test set
      prune(alphas[i], errors[i], test);
    }

    // calculate sorted indices and weights, and compute initial class counts on all training instances
    int[][] sortedIndices = new int[data.numAttributes()][0];
    double[][] weights = new double[data.numAttributes()][0];
    double[] classProbs = new double[data.numClasses()];
    double totalWeight = computeSortedInfo(data,sortedIndices, weights,classProbs);

    // build tree using all the data
    makeTree(data, data.numInstances(), sortedIndices, weights, classProbs,
        totalWeight, m_minNumObj, m_Heuristic);

    int numNodes = numInnerNodes();

    double[] treeAlphas = new double[numNodes + 2];

    // prune back and log alpha-values
    int iterations = prune(treeAlphas, null, null);

    double[] treeErrors = new double[numNodes + 2];

    // for each pruned subtree, find the cross-validated error
    for (int i = 0; i <= iterations; i++){
      // compute midpoint alphas (geometric mean of successive alpha-values)
      double alpha = Math.sqrt(treeAlphas[i] * treeAlphas[i+1]);
      double error = 0;
      for (int k = 0; k < m_numFoldsPruning; k++) {
        int l = 0;
        while (alphas[k][l] <= alpha) l++;
        error += errors[k][l - 1];
      }
      treeErrors[i] = error/m_numFoldsPruning;
    }

    // find best alpha
    int best = -1;
    double bestError = Double.MAX_VALUE;
    for (int i = iterations; i >= 0; i--) {
      if (treeErrors[i] < bestError) {
  bestError = treeErrors[i];
  best = i;
      }
    }

    // 1 SE rule: choose the smallest tree whose error is within one
    // standard error of the best error
    if (m_UseOneSE) {
      double oneSE = Math.sqrt(bestError*(1-bestError)/(data.numInstances()));
      for (int i = iterations; i >= 0; i--) {
        if (treeErrors[i] <= bestError+oneSE) {
          best = i;
          break;
        }
      }
    }

    double bestAlpha = Math.sqrt(treeAlphas[best] * treeAlphas[best + 1]);

    //"unprune" final tree (faster than regrowing it)
    unprune();
    prune(bestAlpha);       
  }

  /**
   * Make binary decision tree recursively.
   *
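   * A node is turned into a leaf when its total weight is below
   * 2 * minNumObj, its best split has zero Gini gain, one branch would get
   * zero weight, or an actual split would leave a branch with fewer than
   * minNumObj instances.
   *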
   * @param data     the training instances
   * @param totalInstances   total number of instances
   * @param sortedIndices   sorted indices of the instances
   * @param weights     weights of the instances
   * @param classProbs     class probabilities
   * @param totalWeight   total weight of instances
   * @param minNumObj     minimal number of instances at leaf nodes
   * @param useHeuristic   whether to use heuristic search for nominal attributes in multi-class problems
   * @throws Exception     if something goes wrong
   */
  protected void makeTree(Instances data, int totalInstances, int[][] sortedIndices,
      double[][] weights, double[] classProbs, double totalWeight, double minNumObj,
      boolean useHeuristic) throws Exception{

    // if no instances have reached this node (normally won't happen)
    if (totalWeight == 0){
      m_Attribute = null;
      m_ClassValue = Instance.missingValue();
      m_Distribution = new double[data.numClasses()];
      return;
    }

    m_totalTrainInstances = totalInstances;
    m_isLeaf = true;

    m_ClassProbs = new double[classProbs.length];
    m_Distribution = new double[classProbs.length];
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
    System.arraycopy(classProbs, 0, m_Distribution, 0, classProbs.length);
    if (Utils.sum(m_ClassProbs)!=0) Utils.normalize(m_ClassProbs);

    // Compute class distributions and value of splitting
    // criterion for each attribute
    double[][][] dists = new double[data.numAttributes()][0][0];
    double[][] props = new double[data.numAttributes()][0];
    double[][] totalSubsetWeights = new double[data.numAttributes()][2];
    double[] splits = new double[data.numAttributes()];
    String[] splitString = new String[data.numAttributes()];
    double[] giniGains = new double[data.numAttributes()];

    // for each attribute find split information
    for (int i = 0; i < data.numAttributes(); i++) {
      Attribute att = data.attribute(i);
      if (i==data.classIndex()) continue;
      if (att.isNumeric()) {
        // numeric attribute
        splits[i] = numericDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, giniGains, data);
      } else {
        // nominal attribute
        splitString[i] = nominalDistribution(props, dists, att, sortedIndices[i],
            weights[i], totalSubsetWeights, giniGains, data, useHeuristic);
      }
    }

    // Find best attribute (split with maximum Gini gain)
    int attIndex = Utils.maxIndex(giniGains);
    m_Attribute = data.attribute(attIndex);

    m_train = new Instances(data, sortedIndices[attIndex].length);
    for (int i=0; i<sortedIndices[attIndex].length; i++) {
      Instance inst = data.instance(sortedIndices[attIndex][i]);
      Instance instCopy = (Instance)inst.copy();
      instCopy.setWeight(weights[attIndex][i]);
      m_train.add(instCopy);
    }

    // If the node does not contain enough instances, cannot be split,
    // or is pure, make it a leaf.
    if (totalWeight < 2 * minNumObj || giniGains[attIndex]==0 ||
  props[attIndex][0]==0 || props[attIndex][1]==0) {
      makeLeaf(data);
    }

    else {           
      m_Props = props[attIndex];
      int[][][] subsetIndices = new int[2][data.numAttributes()][0];
      double[][][] subsetWeights = new double[2][data.numAttributes()][0];

      // numeric split
      if (m_Attribute.isNumeric()) m_SplitValue = splits[attIndex];

      // nominal split
      else m_SplitString = splitString[attIndex];

      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitValue,
          m_SplitString, sortedIndices, weights, data);

      // If the split results in a branch with fewer than the minimal number
      // of instances, make the node a leaf node.
      if (subsetIndices[0][attIndex].length < minNumObj ||
          subsetIndices[1][attIndex].length < minNumObj) {
        makeLeaf(data);
        return;
      }

      // Otherwise, split the node.
      m_isLeaf = false;
      m_Successors = new SimpleCart[2];
      for (int i = 0; i < 2; i++) {
        m_Successors[i] = new SimpleCart();
        m_Successors[i].makeTree(data, m_totalTrainInstances, subsetIndices[i],
            subsetWeights[i], dists[attIndex][i], totalSubsetWeights[attIndex][i],
            minNumObj, useHeuristic);
      }
    }
  }

  /**
   * Prunes the original tree using the CART pruning scheme, given a
   * cost-complexity parameter alpha.
   *
   * @param alpha   the cost-complexity parameter
   * @throws Exception   if something goes wrong
   */
  public void prune(double alpha) throws Exception {

    Vector nodeList;

    // determine training error of pruned subtrees (both with and without replacing a subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      // select node with minimum alpha
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // want to prune if its alpha is smaller than alpha
      if (nodeToPrune.m_Alpha > alpha) {
        break;
      }

      nodeToPrune.makeLeaf(nodeToPrune.m_train);

      // normally would not happen
      if (nodeToPrune.m_Alpha==preAlpha) {
        nodeToPrune.makeLeaf(nodeToPrune.m_train);
        treeErrors();
        calculateAlphas();
        nodeList = getInnerNodes();
        prune = (nodeList.size() > 0);
        continue;
      }
      preAlpha = nodeToPrune.m_Alpha;

      //update tree errors and alphas
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }
  }

  /**
   * Method for performing one fold in the cross-validation of minimal
   * cost-complexity pruning. Generates a sequence of alpha-values with error
   * estimates for the corresponding (partially pruned) trees, given the test
   * set of that fold.
   *
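   * The alphas and errors arrays are allocated with length
   * numInnerNodes() + 2 by the caller: alphas[0] is 0 (the unpruned tree)
   * and the entry after the last recorded iteration is set to 1.0 as an
   * end marker.
   *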
   * @param alphas   array to hold the generated alpha-values
   * @param errors   array to hold the corresponding error estimates
   * @param test   test set of that fold (to obtain error estimates)
   * @return     the iteration of the pruning
   * @throws Exception   if something goes wrong
   */
  public int prune(double[] alphas, double[] errors, Instances test)
    throws Exception {

    Vector nodeList;

    // determine training error of subtrees (both with and without replacing a subtree),
    // and calculate alpha-values from them
    modelErrors();
    treeErrors();
    calculateAlphas();

    // get list of all inner nodes in the tree
    nodeList = getInnerNodes();

    boolean prune = (nodeList.size() > 0);

    //alpha_0 is always zero (unpruned tree)
    alphas[0] = 0;

    Evaluation eval;

    // error of unpruned tree
    if (errors != null) {
      eval = new Evaluation(test);
      eval.evaluateModel(this, test);
      errors[0] = eval.errorRate();
    }

    int iteration = 0;
    double preAlpha = Double.MAX_VALUE;
    while (prune) {

      iteration++;

      // get node with minimum alpha
      SimpleCart nodeToPrune = nodeToPrune(nodeList);

      // do not set m_sons null, want to unprune
      nodeToPrune.m_isLeaf = true;

      // normally would not happen
      if (nodeToPrune.m_Alpha==preAlpha) {
        iteration--;
        treeErrors();
        calculateAlphas();
        nodeList = getInnerNodes();
        prune = (nodeList.size() > 0);
        continue;
      }

      // get alpha-value of node
      alphas[iteration] = nodeToPrune.m_Alpha;

      // log error
      if (errors != null) {
        eval = new Evaluation(test);
        eval.evaluateModel(this, test);
        errors[iteration] = eval.errorRate();
      }
      preAlpha = nodeToPrune.m_Alpha;

      //update errors/alphas
      treeErrors();
      calculateAlphas();

      nodeList = getInnerNodes();
      prune = (nodeList.size() > 0);
    }

    //set last alpha 1 to indicate end
    alphas[iteration + 1] = 1.0;
    return iteration;
  }

  /**
   * Method to "unprune" the CART tree. Sets all leaf-fields to false.
   * Faster than re-growing the tree because CART do not have to be fit again.
   */
  protected void unprune() {
    if (m_Successors != null) {
      m_isLeaf = false;
      for (int i = 0; i < m_Successors.length; i++) m_Successors[i].unprune();
    }
  }

  /**
   * Compute distributions, proportions and total weights of two successor
   * nodes for a given numeric attribute.
   *
   * @param props     proportions of each two branches for each attribute
   * @param dists     class distributions of two branches for each attribute
   * @param att     numeric attribute to split on
   * @param sortedIndices   sorted indices of instances for the attribute
   * @param weights     weights of instances for the attribute
   * @param subsetWeights   total weight of the two branches split based on the attribute
   * @param giniGains     Gini gains for each attribute
   * @param data     training instances
   * @return       Gini gain for the given numeric attribute
   * @throws Exception     if something goes wrong
   */
  protected double numericDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] giniGains, Instances data)
    throws Exception {

    double splitPoint = Double.NaN;
    double[][] dist = null;
    int numClasses = data.numClasses();
    int i; // index used to separate instances with and without missing values

    double[][] currDist = new double[2][numClasses];
    dist = new double[2][numClasses];

    // Move all instances without missing values into second subset
    double[] parentDist = new double[numClasses];
    int missingStart = 0;
    for (int j = 0; j < sortedIndices.length; j++) {
      Instance inst = data.instance(sortedIndices[j]);
      if (!inst.isMissing(att)) {
        missingStart ++;
        currDist[1][(int)inst.classValue()] += weights[j];
      }
      parentDist[(int)inst.classValue()] += weights[j];
    }
    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

    // Try all possible split points
    double currSplit = data.instance(sortedIndices[0]).value(att);
    double currGiniGain;
    double bestGiniGain = -Double.MAX_VALUE;

    for (i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (inst.isMissing(att)) {
        break;
      }
      if (inst.value(att) > currSplit) {

        double[][] tempDist = new double[2][numClasses];
        for (int k=0; k<2; k++) {
          System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
        }

        double[] tempProps = new double[2];
        for (int k=0; k<2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps) != 0) Utils.normalize(tempProps);

        // distribute the weight of missing-value instances across both branches
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        currGiniGain = computeGiniGain(parentDist,tempDist);

        if (currGiniGain > bestGiniGain) {
          bestGiniGain = currGiniGain;

          // round the split point (midpoint of adjacent values) to 5 decimal places
          splitPoint = Math.rint((inst.value(att) + currSplit)/2.0*100000)/100000.0;

          for (int j = 0; j < currDist.length; j++) {
            System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length);
          }
        }
      }
      currSplit = inst.value(att);
      currDist[0][(int)inst.classValue()] += weights[i];
      currDist[1][(int)inst.classValue()] -= weights[i];
    }

    // Compute weights
    int attIndex = att.index();
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }
    if (Utils.sum(props[attIndex]) != 0) Utils.normalize(props[attIndex]);

    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // round the Gini gain to 7 decimal places
    giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;
    dists[attIndex] = dist;

    return splitPoint;
  }

  /**
   * Compute distributions, proportions and total weights of two successor
   * nodes for a given nominal attribute.
   *
   * @param props     proportions of each two branches for each attribute
   * @param dists     class distributions of two branches for each attribute
   * @param att     nominal attribute to split on
   * @param sortedIndices   sorted indices of instances for the attribute
   * @param weights     weights of instances for the attribute
   * @param subsetWeights   total weight of the two branches split based on the attribute
   * @param giniGains     Gini gains for each attribute
   * @param data     training instances
   * @param useHeuristic   whether to use heuristic search
   * @return       Gini gain for the given nominal attribute
   * @throws Exception     if something goes wrong
   */
  protected String nominalDistribution(double[][] props, double[][][] dists,
      Attribute att, int[] sortedIndices, double[] weights, double[][] subsetWeights,
      double[] giniGains, Instances data, boolean useHeuristic)
    throws Exception {

    String[] values = new String[att.numValues()];
    int numCat = values.length; // number of values of the attribute
    int numClasses = data.numClasses();

    String bestSplitString = "";
    double bestGiniGain = -Double.MAX_VALUE;

    // class frequency for each value
    int[] classFreq = new int[numCat];
    for (int j=0; j<numCat; j++) classFreq[j] = 0;

    double[] parentDist = new double[numClasses];
    double[][] currDist = new double[2][numClasses];
    double[][] dist = new double[2][numClasses];
    int missingStart = 0;

    for (int i = 0; i < sortedIndices.length; i++) {
      Instance inst = data.instance(sortedIndices[i]);
      if (!inst.isMissing(att)) {
        missingStart++;
        classFreq[(int)inst.value(att)] ++;
      }
      parentDist[(int)inst.classValue()] += weights[i];
    }

    // count the number of attribute values whose frequency is not 0
    int nonEmpty = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]!=0) nonEmpty ++;
    }

    // attribute values whose frequency is not 0
    String[] nonEmptyValues = new String[nonEmpty];
    int nonEmptyIndex = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]!=0) {
        nonEmptyValues[nonEmptyIndex] = att.value(j);
        nonEmptyIndex ++;
      }
    }

    // attribute values whose frequency is 0
    int empty = numCat - nonEmpty;
    String[] emptyValues = new String[empty];
    int emptyIndex = 0;
    for (int j=0; j<numCat; j++) {
      if (classFreq[j]==0) {
        emptyValues[emptyIndex] = att.value(j);
        emptyIndex ++;
      }
    }

    if (nonEmpty<=1) {
      giniGains[att.index()] = 0;
      return "";
    }

    // for two-class problems
    if (data.numClasses()==2) {

      // First, consider the attribute values whose frequency is not zero

      // probability of class 0 for each attribute value
      double[] pClass0 = new double[nonEmpty];
      // class distribution for each attribute value
      double[][] valDist = new double[nonEmpty][2];

      for (int j=0; j<nonEmpty; j++) {
        for (int k=0; k<2; k++) {
          valDist[j][k] = 0;
        }
      }

      for (int i = 0; i < sortedIndices.length; i++) {
        Instance inst = data.instance(sortedIndices[i]);
        if (inst.isMissing(att)) {
          break;
        }

        for (int j=0; j<nonEmpty; j++) {
          if (att.value((int)inst.value(att)).compareTo(nonEmptyValues[j])==0) {
            valDist[j][(int)inst.classValue()] += inst.weight();
            break;
          }
        }
      }

      for (int j=0; j<nonEmpty; j++) {
        double distSum = Utils.sum(valDist[j]);
        if (distSum==0) pClass0[j]=0;
        else pClass0[j] = valDist[j][0]/distSum;
      }

      // sort the attribute values by the probability of the first class
      String[] sortedValues = new String[nonEmpty];
      for (int j=0; j<nonEmpty; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pClass0)];
        pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
      }

      // Find the subset of attribute values that maximizes the Gini decrease;
      // only the nonEmpty-1 splits along the sorted value order are checked.
      String tempStr = "";

      for (int j=0; j<nonEmpty-1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr.equals("")) tempStr = "(" + sortedValues[j] + ")";
        else tempStr += "|" + "(" + sortedValues[j] + ")";
        for (int i=0; i<sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf("(" + att.value((int)inst.value(att)) + ")")!=-1) {
            currDist[0][(int)inst.classValue()] += weights[i];
          } else currDist[1][(int)inst.classValue()] += weights[i];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk=0; kk<2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk=0; kk<2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);

        // distribute the weight of missing-value instances across both branches
        int mstart = missingStart;
        while (mstart < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[mstart]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
          }
          mstart++;
        }

        double currGiniGain = computeGiniGain(parentDist,tempDist);

        if (currGiniGain>bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int jj = 0; jj < 2; jj++) {
            System.arraycopy(tempDist[jj], 0, dist[jj], 0, dist[jj].length);
          }
        }
      }
    }

    // multi-class problems: exhaustive search
    else if (!useHeuristic || nonEmpty<=4) {

      // First, consider the attribute values whose frequency is not zero
      for (int i=0; i<(int)Math.pow(2,nonEmpty-1); i++) {
        String tempStr="";
        currDist = new double[2][numClasses];
        int mod;
        int bit10 = i;
        for (int j=nonEmpty-1; j>=0; j--) {
          mod = bit10%2; // convert from decimal to binary
          if (mod==1) {
            if (tempStr.equals("")) tempStr = "("+nonEmptyValues[j]+")";
            else tempStr += "|" + "("+nonEmptyValues[j]+")";
          }
          bit10 = bit10/2;
        }
        for (int j=0; j<sortedIndices.length; j++) {
          Instance inst = data.instance(sortedIndices[j]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf("("+att.value((int)inst.value(att))+")")!=-1) {
            currDist[0][(int)inst.classValue()] += weights[j];
          } else currDist[1][(int)inst.classValue()] += weights[j];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int k=0; k<2; k++) {
          tempDist[k] = currDist[k];
        }

        double[] tempProps = new double[2];
        for (int k=0; k<2; k++) {
          tempProps[k] = Utils.sum(tempDist[k]);
        }

        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);

        // distribute the weight of missing-value instances across both branches
        int index = missingStart;
        while (index < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[index]);
          for (int j = 0; j < 2; j++) {
            tempDist[j][(int)insta.classValue()] += tempProps[j] * weights[index];
          }
          index++;
        }

        double currGiniGain = computeGiniGain(parentDist,tempDist);

        if (currGiniGain>bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int j = 0; j < 2; j++) {
            System.arraycopy(tempDist[j], 0, dist[j], 0, dist[j].length);
          }
        }
      }
    }

    // heuristic search for multi-class problems
    else {
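      // Heuristic: order the attribute values by their scores on the first
      // principal component of the class probability matrix, then evaluate
      // only the n-1 splits along that ordering (as in the two-class case).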
      // First, consider the attribute values whose frequency is not zero
      int n = nonEmpty;
      int k = data.numClasses();            // number of classes of the data
      double[][] P = new double[n][k];      // class probability matrix
      int[] numInstancesValue = new int[n]; // number of instances for an attribute value
      double[] meanClass = new double[k];   // vector of mean class probability
      int numInstances = data.numInstances(); // total number of instances

      // initialize the vector of mean class probability
      for (int j=0; j<meanClass.length; j++) meanClass[j]=0;

      for (int j=0; j<numInstances; j++) {
        Instance inst = (Instance)data.instance(j);
        int valueIndex = 0; // attribute value index in nonEmptyValues
        for (int i=0; i<nonEmpty; i++) {
          if (att.value((int)inst.value(att)).compareToIgnoreCase(nonEmptyValues[i])==0){
            valueIndex = i;
            break;
          }
        }
        P[valueIndex][(int)inst.classValue()]++;
        numInstancesValue[valueIndex]++;
        meanClass[(int)inst.classValue()]++;
      }

      // calculate the class probability matrix
      for (int i=0; i<P.length; i++) {
        for (int j=0; j<P[0].length; j++) {
          if (numInstancesValue[i]==0) P[i][j]=0;
          else P[i][j]/=numInstancesValue[i];
        }
      }

      //calculate the vector of mean class probability
      for (int i=0; i<meanClass.length; i++) {
        meanClass[i]/=numInstances;
      }

      // calculate the covariance matrix
      double[][] covariance = new double[k][k];
      for (int i1=0; i1<k; i1++) {
        for (int i2=0; i2<k; i2++) {
          double element = 0;
          for (int j=0; j<n; j++) {
            element += (P[j][i2]-meanClass[i2])*(P[j][i1]-meanClass[i1])
                *numInstancesValue[j];
          }
          covariance[i1][i2] = element;
        }
      }

      Matrix matrix = new Matrix(covariance);
      weka.core.matrix.EigenvalueDecomposition eigen =
          new weka.core.matrix.EigenvalueDecomposition(matrix);
      double[] eigenValues = eigen.getRealEigenvalues();

      // find index of the largest eigenvalue
      int index=0;
      double largest = eigenValues[0];
      for (int i=1; i<eigenValues.length; i++) {
        if (eigenValues[i]>largest) {
          index=i;
          largest = eigenValues[i];
        }
      }

      // calculate the first principal component
      double[] FPC = new double[k];
      Matrix eigenVector = eigen.getV();
      double[][] vectorArray = eigenVector.getArray();
      for (int i=0; i<FPC.length; i++) {
        FPC[i] = vectorArray[i][index];
      }

      // calculate the first principal component scores
      double[] Sa = new double[n];
      for (int i=0; i<Sa.length; i++) {
  Sa[i]=0;
  for (int j=0; j<k; j++) {
    Sa[i] += FPC[j]*P[i][j];
  }
      }

      // sort the attribute values by their first principal component scores
      double[] pCopy = new double[n];
      System.arraycopy(Sa,0,pCopy,0,n);
      String[] sortedValues = new String[n];
      Arrays.sort(Sa);

      for (int j=0; j<n; j++) {
        sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
        pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
      }

      // check the n-1 ordered splits over the values whose frequency is not 0
      String tempStr = "";

      for (int j=0; j<nonEmpty-1; j++) {
        currDist = new double[2][numClasses];
        if (tempStr.equals("")) tempStr = "(" + sortedValues[j] + ")";
        else tempStr += "|" + "(" + sortedValues[j] + ")";
        for (int i=0; i<sortedIndices.length; i++) {
          Instance inst = data.instance(sortedIndices[i]);
          if (inst.isMissing(att)) {
            break;
          }

          if (tempStr.indexOf("(" + att.value((int)inst.value(att)) + ")")!=-1) {
            currDist[0][(int)inst.classValue()] += weights[i];
          } else currDist[1][(int)inst.classValue()] += weights[i];
        }

        double[][] tempDist = new double[2][numClasses];
        for (int kk=0; kk<2; kk++) {
          tempDist[kk] = currDist[kk];
        }

        double[] tempProps = new double[2];
        for (int kk=0; kk<2; kk++) {
          tempProps[kk] = Utils.sum(tempDist[kk]);
        }

        if (Utils.sum(tempProps)!=0) Utils.normalize(tempProps);

        // distribute the weight of missing-value instances across both branches
        int mstart = missingStart;
        while (mstart < sortedIndices.length) {
          Instance insta = data.instance(sortedIndices[mstart]);
          for (int jj = 0; jj < 2; jj++) {
            tempDist[jj][(int)insta.classValue()] += tempProps[jj] * weights[mstart];
          }
          mstart++;
        }

        double currGiniGain = computeGiniGain(parentDist,tempDist);

        if (currGiniGain>bestGiniGain) {
          bestGiniGain = currGiniGain;
          bestSplitString = tempStr;
          for (int jj = 0; jj < 2; jj++) {
            System.arraycopy(tempDist[jj], 0, dist[jj], 0, dist[jj].length);
          }
        }
      }
    }

    // Compute weights
    int attIndex = att.index();       
    props[attIndex] = new double[2];
    for (int k = 0; k < 2; k++) {
      props[attIndex][k] = Utils.sum(dist[k]);
    }

    if (!(Utils.sum(props[attIndex]) > 0)) {
      for (int k = 0; k < props[attIndex].length; k++) {
        props[attIndex][k] = 1.0 / (double)props[attIndex].length;
      }
    } else {
      Utils.normalize(props[attIndex]);
    }


    // Compute subset weights
    subsetWeights[attIndex] = new double[2];
    for (int j = 0; j < 2; j++) {
      subsetWeights[attIndex][j] += Utils.sum(dist[j]);
    }

    // Then send the attribute values whose frequency is 0 down the
    // more frequent branch
    for (int j=0; j<empty; j++) {
      if (props[attIndex][0]>=props[attIndex][1]) {
        if (bestSplitString.equals("")) bestSplitString = "(" + emptyValues[j] + ")";
        else bestSplitString += "|" + "(" + emptyValues[j] + ")";
      }
    }

    // round the Gini gain for the attribute to 7 decimal places
    giniGains[attIndex] = Math.rint(bestGiniGain*10000000)/10000000.0;

    dists[attIndex] = dist;
    return bestSplitString;
  }


  /**
   * Split data into two subsets and store sorted indices and weights for two
   * successor nodes.
   *
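   * Instances with a missing value for the split attribute are sent down
   * both branches ("fractional instances"), with their weights scaled by
   * the corresponding branch proportions in m_Props.
   *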
   * @param subsetIndices   sorted indices of instances for each attribute
   *         for the two successor nodes
   * @param subsetWeights   weights of instances for each attribute for
   *         the two successor nodes
   * @param att     attribute the split is based on
   * @param splitPoint     split point of the split if att is numeric
   * @param splitStr     split subset of the split if att is nominal
   * @param sortedIndices   sorted indices of the instances to be split
   * @param weights     weights of the instances to be split
   * @param data     training data
   * @throws Exception     if something goes wrong 
   */
  protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
      Attribute att, double splitPoint, String splitStr, int[][] sortedIndices,
      double[][] weights, Instances data) throws Exception {

    int j;
    // For each attribute
    for (int i = 0; i < data.numAttributes(); i++) {
      if (i==data.classIndex()) continue;
      int[] num = new int[2];
      for (int k = 0; k < 2; k++) {
        subsetIndices[k][i] = new int[sortedIndices[i].length];
        subsetWeights[k][i] = new double[weights[i].length];
      }

      for (j = 0; j < sortedIndices[i].length; j++) {
        Instance inst = data.instance(sortedIndices[i][j]);
        if (inst.isMissing(att)) {
          // Split the instance up between both branches
          for (int k = 0; k < 2; k++) {
            if (m_Props[k] > 0) {
              subsetIndices[k][i][num[k]] = sortedIndices[i][j];
              subsetWeights[k][i][num[k]] = m_Props[k] * weights[i][j];
              num[k]++;
            }
          }
        } else {
          int subset;
          if (att.isNumeric())  {
            subset = (inst.value(att) < splitPoint) ? 0 : 1;
          } else { // nominal attribute
            if (splitStr.indexOf("(" + att.value((int)inst.value(att.index()))+")")!=-1) {
              subset = 0;
            } else subset = 1;
          }
          subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
          subsetWeights[subset][i][num[subset]] = weights[i][j];
          num[subset]++;
        }
      }

      // Trim arrays
      for (int k = 0; k < 2; k++) {
        int[] copy = new int[num[k]];
        System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
        subsetIndices[k][i] = copy;
        double[] copyWeights = new double[num[k]];
        System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]);
        subsetWeights[k][i] = copyWeights;
      }
    }
  }

  /**
   * Updates the numIncorrectModel field for all nodes when subtree (to be
   * pruned) is rooted. This is needed for calculating the alpha-values.
   *
   * @throws Exception   if something goes wrong
   */
  public void modelErrors() throws Exception{
    Evaluation eval = new Evaluation(m_train);

    if (!m_isLeaf) {
      m_isLeaf = true; //temporarily make leaf

      // calculate distribution for evaluation
      eval.evaluateModel(this, m_train);
      m_numIncorrectModel = eval.incorrect();

      m_isLeaf = false;

      for (int i = 0; i < m_Successors.length; i++)
        m_Successors[i].modelErrors();

    } else {
      eval.evaluateModel(this, m_train);
      m_numIncorrectModel = eval.incorrect();
    }      
  }

  /**
   * Updates the numIncorrectTree field for all nodes. This is needed for
   * calculating the alpha-values.
   *
   * @throws Exception   if something goes wrong
   */
  public void treeErrors() throws Exception {
    if (m_isLeaf) {
      m_numIncorrectTree = m_numIncorrectModel;
    } else {
      m_numIncorrectTree = 0;
      for (int i = 0; i < m_Successors.length; i++) {
        m_Successors[i].treeErrors();
        m_numIncorrectTree += m_Successors[i].m_numIncorrectTree;
      }
    }
  }

  /**
   * Updates the alpha field for all nodes.
   *
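   * For an inner node the alpha-value is the reduction in training error
   * per pruned-away leaf:
   * <pre>
   *   alpha = (numIncorrectModel - numIncorrectTree)
   *           / (totalTrainInstances * (numLeaves() - 1))
   * </pre>
   * (rounded to 10 decimal places); leaves get alpha = Double.MAX_VALUE.
   *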
   * @throws Exception   if something goes wrong
   */
  public void calculateAlphas() throws Exception {

    if (!m_isLeaf) {
      double errorDiff = m_numIncorrectModel - m_numIncorrectTree;
      if (errorDiff <= 0) {
        // split increases training error (should not normally happen):
        // prune it instantly.
        makeLeaf(m_train);
        m_Alpha = Double.MAX_VALUE;
      } else {
        // compute alpha
        errorDiff /= m_totalTrainInstances;
        m_Alpha = errorDiff / (double)(numLeaves() - 1);
        long alphaLong = Math.round(m_Alpha*Math.pow(10,10));
        m_Alpha = (double)alphaLong/Math.pow(10,10);
        for (int i = 0; i < m_Successors.length; i++) {
          m_Successors[i].calculateAlphas();
        }
      }
    } else {
      // alpha = infinity for leaves (do not want to prune)
      m_Alpha = Double.MAX_VALUE;
    }
  }

  /**
   * Find the node with the minimal alpha-value. If two nodes have the same
   * alpha, choose the one with more leaf nodes.
   *
   * @param nodeList   list of inner nodes
   * @return     the node to be pruned
   */
  protected SimpleCart nodeToPrune(Vector nodeList) {
    if (nodeList.size()==0) return null;
    if (nodeList.size()==1) return (SimpleCart)nodeList.elementAt(0);
    SimpleCart returnNode = (SimpleCart)nodeList.elementAt(0);
    double baseAlpha = returnNode.m_Alpha;
    for (int i=1; i<nodeList.size(); i++) {
      SimpleCart node = (SimpleCart)nodeList.elementAt(i);
      if (node.m_Alpha < baseAlpha) {
        baseAlpha = node.m_Alpha;
        returnNode = node;
      } else if (node.m_Alpha == baseAlpha) { // break tie
        if (node.numLeaves()>returnNode.numLeaves()) {
          returnNode = node;
        }
      }
    }
    return returnNode;
  }

  /**
   * Compute sorted indices, weights and class probabilities for a given
   * dataset. Return total weights of the data at the node.
   *
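   * For each attribute, indices of instances with missing values are
   * placed at the end of the sorted index array, so the split search can
   * stop at the first missing value.
   *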
   * @param data     training data
   * @param sortedIndices   sorted indices of instances at the node
   * @param weights     weights of instances at the node
   * @param classProbs     class probabilities at the node
   * @return     total weight of instances at the node
   * @throws Exception     if something goes wrong
   */
  protected double computeSortedInfo(Instances data, int[][] sortedIndices, double[][] weights,
      double[] classProbs) throws Exception {

    // Create array of sorted indices and weights
    double[] vals = new double[data.numInstances()];
    for (int j = 0; j < data.numAttributes(); j++) {
      if (j==data.classIndex()) continue;
      weights[j] = new double[data.numInstances()];

      if (data.attribute(j).isNominal()) {

        // Handle nominal attributes. Put indices of
        // instances with missing values at the end.
        sortedIndices[j] = new int[data.numInstances()];
        int count = 0;
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          if (!inst.isMissing(j)) {
            sortedIndices[j][count] = i;
            weights[j][count] = inst.weight();
            count++;
          }
        }
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          if (inst.isMissing(j)) {
            sortedIndices[j][count] = i;
            weights[j][count] = inst.weight();
            count++;
          }
        }
      } else {

        // Sorted indices are computed for numeric attributes;
        // instances with missing values are put at the end.
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          vals[i] = inst.value(j);
        }
        sortedIndices[j] = Utils.sort(vals);
        for (int i = 0; i < data.numInstances(); i++) {
          weights[j][i] = data.instance(sortedIndices[j][i]).weight();
        }
      }
    }

    // Compute initial class counts
    double totalWeight = 0;
    for (int i = 0; i < data.numInstances(); i++) {
      Instance inst = data.instance(i);
      classProbs[(int)inst.classValue()] += inst.weight();
      totalWeight += inst.weight();
    }

    return totalWeight;
  }

  /**
   * Compute and return gini gain for given distributions of a node and its
   * successor nodes.
   *
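   * For example, a parent distribution {6, 4} (Gini 0.48) split into
   * children {4, 0} (Gini 0) and {2, 4} (Gini 4/9) gives a gain of
   * 0.48 - 0.4*0 - 0.6*(4/9) = 0.2133...
   *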
   * @param parentDist   class distributions of parent node
   * @param childDist   class distributions of successor nodes
   * @return     Gini gain computed
   */
  protected double computeGiniGain(double[] parentDist, double[][] childDist) {
    double totalWeight = Utils.sum(parentDist);
    if (totalWeight==0) return 0;

    double leftWeight = Utils.sum(childDist[0]);
    double rightWeight = Utils.sum(childDist[1]);

    double parentGini = computeGini(parentDist, totalWeight);
    double leftGini = computeGini(childDist[0],leftWeight);
    double rightGini = computeGini(childDist[1], rightWeight);

    return parentGini - leftWeight/totalWeight*leftGini -
    rightWeight/totalWeight*rightGini;
  }

  /**
   * Compute and return gini index for a given distribution of a node.
   *
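   * For example, dist = {3, 1} with total 4 gives
   * 1 - (0.75*0.75 + 0.25*0.25) = 0.375.
   *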
   * @param dist   class distributions
   * @param total   total weight of the distribution
   * @return     Gini index of the class distributions
   */
  protected double computeGini(double[] dist, double total) {
    if (total==0) return 0;
    double val = 0;
    for (int i=0; i<dist.length; i++) {
      val += (dist[i]/total)*(dist[i]/total);
    }
    return 1- val;
  }

  /**
   * Computes class probabilities for instance using the decision tree.
   *
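   * If the value of the split attribute is missing, the returned
   * distribution is the average of the successors' distributions, weighted
   * by the branch proportions m_Props.
   *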
   * @param instance   the instance for which class probabilities are to be computed
   * @return     the class probabilities for the given instance
   * @throws Exception   if something goes wrong
   */
  public double[] distributionForInstance(Instance instance)
  throws Exception {
    if (!m_isLeaf) {
      // value of the split attribute is missing
      if (instance.isMissing(m_Attribute)) {
        double[] returnedDist = new double[m_ClassProbs.length];

        for (int i = 0; i < m_Successors.length; i++) {
          double[] help = m_Successors[i].distributionForInstance(instance);
          if (help != null) {
            for (int j = 0; j < help.length; j++) {
              returnedDist[j] += m_Props[i] * help[j];
            }
          }
        }
        return returnedDist;
      }

      // split attribute is nominal
      else if (m_Attribute.isNominal()) {
        if (m_SplitString.indexOf("(" +
            m_Attribute.value((int)instance.value(m_Attribute)) + ")")!=-1)
          return m_Successors[0].distributionForInstance(instance);
        else return m_Successors[1].distributionForInstance(instance);
      }

      // split attribute is numeric
      else {
        if (instance.value(m_Attribute) < m_SplitValue)
          return m_Successors[0].distributionForInstance(instance);
        else
          return m_Successors[1].distributionForInstance(instance);
      }
    }

    // leaf node
    else return m_ClassProbs;
  }

  /**
   * Make the node a leaf node.
   *
   * @param data   training data
   */
  protected void makeLeaf(Instances data) {
    m_Attribute = null;
    m_isLeaf = true;
    m_ClassValue=Utils.maxIndex(m_ClassProbs);
    m_ClassAttribute = data.classAttribute();
  }

  /**
   * Prints the decision tree using the protected toString method from below.
   *
   * @return     a textual description of the classifier
   */
  public String toString() {
    if ((m_ClassProbs == null) && (m_Successors == null)) {
      return "CART Tree: No model built yet.";
    }

    return "CART Decision Tree\n" + toString(0)+"\n\n"
    +"Number of Leaf Nodes: "+numLeaves()+"\n\n" +
    "Size of the Tree: "+numNodes();
  }

  /**
   * Outputs a tree at a certain level.
   *
   * @param level   the level at which the tree is to be printed
   * @return     a tree at a certain level
   */
  protected String toString(int level) {

    StringBuffer text = new StringBuffer();
    // if leaf node
    if (m_Attribute == null) {
      if (Instance.isMissingValue(m_ClassValue)) {
        text.append(": null");
      } else {
        double correctNum = (int)(m_Distribution[Utils.maxIndex(m_Distribution)]*100)/100.0;
        double wrongNum = (int)((Utils.sum(m_Distribution) -
            m_Distribution[Utils.maxIndex(m_Distribution)])*100)/100.0;
        String str = "(" + correctNum + "/" + wrongNum + ")";
        text.append(": " + m_ClassAttribute.value((int)m_ClassValue) + str);
      }
    } else {
      for (int j = 0; j < 2; j++) {
        text.append("\n");
        for (int i = 0; i < level; i++) {
          text.append("|  ");
        }
        if (j==0) {
          if (m_Attribute.isNumeric())
            text.append(m_Attribute.name() + " < " + m_SplitValue);
          else
            text.append(m_Attribute.name() + "=" + m_SplitString);
        } else {
          if (m_Attribute.isNumeric())
            text.append(m_Attribute.name() + " >= " + m_SplitValue);
          else
            text.append(m_Attribute.name() + "!=" + m_SplitString);
        }
        text.append(m_Successors[j].toString(level + 1));
      }
    }
    return text.toString();
  }

  /**
   * Compute size of the tree.
   *
   * @return     size of the tree
   */
  public int numNodes() {
    if (m_isLeaf) {
      return 1;
    } else {
      int size = 1;
      for (int i=0;i<m_Successors.length;i++) {
        size += m_Successors[i].numNodes();
      }
      return size;
    }
  }

  /**
   * Method to count the number of inner nodes in the tree.
   *
   * @return     the number of inner nodes
   */
  public int numInnerNodes(){
    if (m_Attribute==null) return 0;
    int numNodes = 1;
    for (int i = 0; i < m_Successors.length; i++)
      numNodes += m_Successors[i].numInnerNodes();
    return numNodes;
  }

  /**
   * Return a list of all inner nodes in the tree.
   *
   * @return     the list of all inner nodes
   */
  protected Vector getInnerNodes(){
    Vector nodeList = new Vector();
    fillInnerNodes(nodeList);
    return nodeList;
  }

  /**
   * Fills a list with all inner nodes in the tree.
   *
   * @param nodeList   the list to be filled
   */
  protected void fillInnerNodes(Vector nodeList) {
    if (!m_isLeaf) {
      nodeList.add(this);
      for (int i = 0; i < m_Successors.length; i++)
        m_Successors[i].fillInnerNodes(nodeList);
    }
  }

  /**
   * Compute number of leaf nodes.
   *
   * @return     number of leaf nodes
   */
  public int numLeaves() {
    if (m_isLeaf) return 1;
    else {
      int size = 0;
      for (int i=0;i<m_Successors.length;i++) {
        size += m_Successors[i].numLeaves();
      }
      return size;
    }
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return     an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector   result;
    Enumeration  en;
   
    result = new Vector();
   
    en = super.listOptions();
    while (en.hasMoreElements())
      result.addElement(en.nextElement());

    result.addElement(new Option(
        "\tThe minimal number of instances at the terminal nodes.\n"
        + "\t(default 2)",
        "M", 1, "-M <min no>"));

    result.addElement(new Option(
        "\tThe number of folds used in the minimal cost-complexity pruning.\n"
        + "\t(default 5)",
        "N", 1, "-N <num folds>"));

    result.addElement(new Option(
        "\tDon't use the minimal cost-complexity pruning.\n"
        + "\t(default yes).",
        "U", 0, "-U"));

    result.addElement(new Option(
        "\tDon't use the heuristic method for binary split.\n"
        + "\t(default true).",
        "H", 0, "-H"));

    result.addElement(new Option(
        "\tUse 1 SE rule to make pruning decision.\n"
        + "\t(default no).",
        "A", 0, "-A"));

    result.addElement(new Option(
        "\tPercentage of training data size (0-1].\n"
        + "\t(default 1).",
        "C", 1, "-C"));

    return result.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)</pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   * <pre> -M &lt;min no&gt;
   *  The minimal number of instances at the terminal nodes.
   *  (default 2)</pre>
   *
   * <pre> -N &lt;num folds&gt;
   *  The number of folds used in the minimal cost-complexity pruning.
   *  (default 5)</pre>
   *
   * <pre> -U
   *  Don't use the minimal cost-complexity pruning.
   *  (default yes).</pre>
   *
   * <pre> -H
   *  Don't use the heuristic method for binary split.
   *  (default true).</pre>
   *
   * <pre> -A
   *  Use 1 SE rule to make pruning decision.
   *  (default no).</pre>
   *
   * <pre> -C
   *  Percentage of training data size (0-1].
   *  (default 1).</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String  tmpStr;
   
    super.setOptions(options);
   
    tmpStr = Utils.getOption('M', options);
    if (tmpStr.length() != 0)
      setMinNumObj(Double.parseDouble(tmpStr));
    else
      setMinNumObj(2);

    tmpStr = Utils.getOption('N', options);
    if (tmpStr.length()!=0)
      setNumFoldsPruning(Integer.parseInt(tmpStr));
    else
      setNumFoldsPruning(5);

    setUsePrune(!Utils.getFlag('U',options));
    setHeuristic(!Utils.getFlag('H',options));
    setUseOneSE(Utils.getFlag('A',options));

    tmpStr = Utils.getOption('C', options);
    if (tmpStr.length()!=0)
      setSizePer(Double.parseDouble(tmpStr));
    else
      setSizePer(1);

    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current settings of the classifier.
   *
   * @return     the current setting of the classifier
   */
  public String[] getOptions() {
    int         i;
    Vector      result;
    String[]    options;

    result = new Vector();

    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    result.add("-M");
    result.add("" + getMinNumObj());
   
    result.add("-N");
    result.add("" + getNumFoldsPruning());
   
    if (!getUsePrune())
      result.add("-U");
   
    if (!getHeuristic())
      result.add("-H");
   
    if (getUseOneSE())
      result.add("-A");
   
    result.add("-C");
    result.add("" + getSizePer());

    return (String[]) result.toArray(new String[result.size()]);   
  }

  /**
   * Return an enumeration of the measure names.
   *
   * @return     an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    Vector result = new Vector();
   
    result.addElement("measureTreeSize");
   
    return result.elements();
  }

  /**
   * Return the size of the tree.
   *
   * @return     the size of the tree
   */
  public double measureTreeSize() {
    return numNodes();
  }

  /**
   * Returns the value of the named measure.
   *
   * @param additionalMeasureName   the name of the measure to query for its value
   * @return         the value of the named measure
   * @throws IllegalArgumentException   if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    if (additionalMeasureName.compareToIgnoreCase("measureTreeSize") == 0) {
      return measureTreeSize();
    } else {
      throw new IllegalArgumentException(additionalMeasureName
    + " not supported (Cart pruning)");
    }
  }

  /**
   * Returns the tip text for this property
   *
   * @return     tip text for this property suitable for
   *       displaying in the explorer/experimenter gui
   */
  public String minNumObjTipText() {
    return "The minimal number of observations at the terminal nodes (default 2).";
  }

  /**
   * Set minimal number of instances at the terminal nodes.
   *
   * @param value   minimal number of instances at the terminal nodes
   */
  public void setMinNumObj(double value) {
    m_minNumObj = value;
  }

  /**
   * Get minimal number of instances at the terminal nodes.
   *
   * @return     minimal number of instances at the terminal nodes
   */
  public double getMinNumObj() {
    return m_minNumObj;
  }

  /**
   * Returns the tip text for this property
   *
   * @return     tip text for this property suitable for
   *       displaying in the explorer/experimenter gui
   */
  public String numFoldsPruningTipText() {
    return "The number of folds in the internal cross-validation (default 5).";
  }

  /**
   * Set number of folds in internal cross-validation.
   *
   * @param value   number of folds in internal cross-validation.
   */
  public void setNumFoldsPruning(int value) {
    m_numFoldsPruning = value;
  }

  /**
   * Get number of folds in internal cross-validation.
   *
   * @return     number of folds in internal cross-validation.
   */
  public int getNumFoldsPruning() {
    return m_numFoldsPruning;
  }

  /**
   * Return the tip text for this property
   *
   * @return     tip text for this property suitable for displaying in
   *       the explorer/experimenter gui.
   */
  public String usePruneTipText() {
    return "Use minimal cost-complexity pruning (default yes).";
  }

  /**
   * Set whether to use minimal cost-complexity pruning.
   *
   * @param value   whether to use minimal cost-complexity pruning
   */
  public void setUsePrune(boolean value) {
    m_Prune = value;
  }

  /**
   * Get whether minimal cost-complexity pruning is used.
   *
   * @return     whether minimal cost-complexity pruning is used
   */
  public boolean getUsePrune() {
    return m_Prune;
  }

  /**
   * Returns the tip text for this property
   *
   * @return     tip text for this property suitable for
   *       displaying in the explorer/experimenter gui.
   */
  public String heuristicTipText() {
    return
        "If heuristic search is used for binary split for nominal attributes "
      + "in multi-class problems (default yes).";
  }

  /**
   * Set whether to use heuristic search for nominal attributes in
   * multi-class problems.
   *
   * @param value   whether to use heuristic search for nominal attributes
   *       in multi-class problems
   */
  public void setHeuristic(boolean value) {
    m_Heuristic = value;
  }

  /**
   * Get whether heuristic search is used for nominal attributes in
   * multi-class problems.
   *
   * @return     whether heuristic search is used for nominal attributes
   *       in multi-class problems
   */
  public boolean getHeuristic() {
    return m_Heuristic;
  }

  /**
   * Returns the tip text for this property
   *
   * @return     tip text for this property suitable for
   *       displaying in the explorer/experimenter gui.
   */
  public String useOneSETipText() {
    return "Use the 1SE rule to make pruning decisoin.";
  }

  /**
   * Set whether to use the 1SE rule to choose the final model.
   *
   * @param value   whether to use the 1SE rule to choose the final model
   */
  public void setUseOneSE(boolean value) {
    m_UseOneSE = value;
  }

  /**
   * Get whether the 1SE rule is used to choose the final model.
   *
   * @return     whether the 1SE rule is used to choose the final model
   */
  public boolean getUseOneSE() {
    return m_UseOneSE;
  }

  /**
   * Returns the tip text for this property
   *
   * @return     tip text for this property suitable for
   *       displaying in the explorer/experimenter gui.
   */
  public String sizePerTipText() {
    return "The percentage of the training set size (0-1, 0 not included).";
  }

  /**
   * Set the fraction of the training set to use.
   *
   * @param value   fraction of the training set (0-1]
   */
  public void setSizePer(double value) {
    if ((value <= 0) || (value > 1))
      System.err.println(
          "The percentage of the training set size must be in range 0 to 1 "
          + "(0 not included) - ignored!");
    else
      m_SizePer = value;
  }

  /**
   * Get the fraction of the training set to use.
   *
   * @return     fraction of the training set
   */
  public double getSizePer() {
    return m_SizePer;
  }
 
  /**
   * Returns the revision string.
   *
   * @return    the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.4 $");
  }

  /**
   * Main method.
   * @param args the options for the classifier
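   *
   * Example command line (a typical invocation via Weka's classifier
   * runner; the dataset path is illustrative):
   * <pre>
   * java weka.classifiers.trees.SimpleCart -t weather.arff -M 2 -N 5
   * </pre>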
   */
  public static void main(String[] args) {
    runClassifier(new SimpleCart(), args);
  }
}