Package cc.mallet.classify

Source Code of cc.mallet.classify.C45$Node

/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org.  For further
information, see the file `LICENSE' included with this distribution. */





package cc.mallet.classify;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.logging.Logger;

import cc.mallet.classify.Boostable;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.GainRatio;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;


/**
* A C4.5 Decision Tree classifier.
*
* @see C45Trainer
* @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a>
*/
public class C45 extends Classifier implements Boostable, Serializable
{
  // Logger shared by the classifier and its nested Node class
  // (Node.split() reports the sizes of the two children through it).
  private static Logger logger = MalletLogger.getLogger(C45.class.getName());
  // Root of the decision tree; classify() starts its descent here.
  Node m_root;
 
  public C45 (Pipe instancePipe, C45.Node root)
  {
    super (instancePipe);
    m_root = root;
  }
 
  public Node getRoot ()
  {
    return m_root;
  }
 
  private Node getLeaf (Node node, FeatureVector fv)
  {
    if (node.getLeftChild() == null && node.getRightChild() == null)
      return node;
    else if (fv.value(node.getGainRatio().getMaxValuedIndex()) <= node.getGainRatio().getMaxValuedThreshold())
      return getLeaf(node.getLeftChild(), fv);
    else
      return getLeaf(node.getRightChild(), fv);
  }
 
  public Classification classify (Instance instance)
  {
    FeatureVector fv = (FeatureVector) instance.getData ();
    assert (instancePipe == null || fv.getAlphabet () == this.instancePipe.getDataAlphabet ());
   
    Node leaf = getLeaf(m_root, fv);
    return new Classification (instance, this, leaf.getGainRatio().getBaseLabelDistribution());
  }
 
  /**
   * Prune the tree using minimum description length
   */
  public void prune()
  {
    getRoot().computeCostAndPrune();
  }
 
  /**
   * @return the total number of nodes in this tree
   */
  public int getSize()
  {
    Node root = getRoot();       
    if (root == null)
      return 0;
    return 1+root.getNumDescendants();
  }
 
  /**
   * Prints the tree
   */
  public void print()
  {
    if (getRoot() != null)
      getRoot().print();
  }
 
  public static class Node implements Serializable
  {
    private static final long serialVersionUID = 1L;

    // split statistics created by GainRatio.createGainRatio for the instances at this node
    GainRatio m_gainRatio;
    // the entire set of instances given to the root node
    // (set to null by stopGrowth(), after which split() refuses to run)
    InstanceList m_ilist;
    // indices (into m_ilist) of the instances at this node
    int[] m_instIndices;
    // data vocabulary, taken from the instance list's data alphabet
    Alphabet m_dataDict;
    // minimum number of instances allowed in this node
    int m_minNumInsts;
    // tree links: parent is null at the root; both children are null at a leaf
    Node m_parent, m_leftChild, m_rightChild;
   
    public Node(InstanceList ilist, Node parent, int minNumInsts)
    {
      this(ilist, parent, minNumInsts, null);
    }
   
    public Node(InstanceList ilist, Node parent, int minNumInsts, int[] instIndices)
    {
      if (instIndices == null) {
        instIndices = new int[ilist.size()];
        for (int ii = 0; ii < instIndices.length; ii++)
          instIndices[ii] = ii;
      }
      m_gainRatio = GainRatio.createGainRatio(ilist, instIndices, minNumInsts);
      m_ilist = ilist;
      m_instIndices = instIndices;
      m_dataDict = m_ilist.getDataAlphabet();
      m_minNumInsts = minNumInsts;
      m_parent = parent;
      m_leftChild = m_rightChild = null;
    }
   
    /** The root has depth zero. */
    public int depth ()
    {
      int depth = 0;
      Node p = m_parent;
      while (p != null) {
        p = p.m_parent;
        depth++;
      }
      return depth;
    }
   
    public int getSize() { return m_instIndices.length; }
    public boolean isLeaf() { return (m_leftChild == null && m_rightChild == null); }
    public boolean isRoot() { return m_parent == null; }
    public Node getParent() { return m_parent; }
    public Node getLeftChild() { return m_leftChild; }
    public Node getRightChild() { return m_rightChild; }
    public GainRatio getGainRatio() { return m_gainRatio; }
    public Object getSplitFeature() { return m_dataDict.lookupObject(m_gainRatio.getMaxValuedIndex()); }
   
    public InstanceList getInstances()
    {
      InstanceList ret = new InstanceList(m_ilist.getPipe());
      for (int ii = 0; ii < m_instIndices.length; ii++)
        ret.add(m_ilist.get(m_instIndices[ii]));
      return ret;
    }
   
    /**
     * Count the number of non-leaf descendant nodes
     */
    public int getNumDescendants()
    {
      if (isLeaf())
        return 0;
      int count = 0;
      if (! getLeftChild().isLeaf())
        count += 1 + getLeftChild().getNumDescendants();
      if (! getRightChild().isLeaf())
        count += 1 + getRightChild().getNumDescendants();
      return count;
    }
   
    public void split()
    {
      if (m_ilist == null)
        throw new IllegalStateException ("Frozen.  Cannot split.");
      int numLeftChildren = 0;
      boolean[] toLeftChild = new boolean[m_instIndices.length];
      for (int i = 0; i < m_instIndices.length; i++) {
        Instance instance = m_ilist.get(m_instIndices[i]);
        FeatureVector fv = (FeatureVector) instance.getData();
        if (fv.value (m_gainRatio.getMaxValuedIndex()) <= m_gainRatio.getMaxValuedThreshold()) {
          toLeftChild[i] = true;
          numLeftChildren++;
        }
        else
          toLeftChild[i] = false;
      }
      logger.info("leftChild.size=" + numLeftChildren
          + " rightChild.size=" + (m_instIndices.length-numLeftChildren));
      int[] leftIndices = new int[numLeftChildren];
      int[] rightIndices = new int[m_instIndices.length - numLeftChildren];
      int li = 0, ri = 0;
      for (int i = 0; i < m_instIndices.length; i++) {
        if (toLeftChild[i])
          leftIndices[li++] = m_instIndices[i];
        else
          rightIndices[ri++] = m_instIndices[i];
      }
      m_leftChild = new Node(m_ilist, this, m_minNumInsts, leftIndices);
      m_rightChild = new Node(m_ilist, this, m_minNumInsts, rightIndices);
    }
   
    public double computeCostAndPrune()
    {
      double costS = getMDL();

      if (isLeaf())
        return costS + 1;

      double minCost1 = getLeftChild().computeCostAndPrune();
      double minCost2 = getRightChild().computeCostAndPrune();
      double costSplit = Math.log(m_gainRatio.getNumSplitPointsForBestFeature()) / GainRatio.log2;
      double minCostN = Math.min(costS+1, costSplit+1+minCost1+minCost2);

      if (Maths.almostEquals(minCostN, costS+1))
        m_leftChild = m_rightChild = null;

      return minCostN;
    }
   
    /**
     * Calculates the minimum description length of this node, i.e.,
     * the length of the binary encoding that describes the feature
     * and the split value used at this node
     */
    public double getMDL()
    {
      int numClasses = m_ilist.getTargetAlphabet().size();
      double mdl = getSize() * getGainRatio().getBaseEntropy();
      mdl += ((numClasses-1) * Math.log(getSize() / 2.0)) / (2 * GainRatio.log2);
      double piPow = Math.pow(Math.PI, numClasses/2.0);
      double gammaVal = Maths.gamma(numClasses/2.0);
      mdl += Math.log(piPow/gammaVal) / GainRatio.log2;
      return mdl;
    }
   
    /**
     * Saves memory by allowing ilist to be garbage collected
     * (deletes this node's associated instance list)
     */
    public void stopGrowth ()
    {
      if (m_leftChild != null)
        m_leftChild.stopGrowth();
      if (m_rightChild != null)
        m_rightChild.stopGrowth();   
      m_ilist = null;
    }
   
    public String getName()
    {
      return getStringBufferName().toString();
    }
   
    public StringBuffer getStringBufferName()
    {
      StringBuffer sb = new StringBuffer();
      if (m_parent == null)
        return sb.append("root");
      else if (m_parent.getParent() == null) {
        sb.append("(\"");
        sb.append(m_dataDict.lookupObject(m_parent.getGainRatio().getMaxValuedIndex()).toString());
        sb.append("\"");
        if (m_parent.getLeftChild() == this)
          sb.append(" <= ");
        else
          sb.append(" > ");
        sb.append(m_parent.getGainRatio().getMaxValuedThreshold());
        return sb.append(")");
      }
      else {
        sb.append(m_parent.getStringBufferName());
        sb.append(" && (\"");
        sb.append(m_dataDict.lookupObject(m_parent.getGainRatio().getMaxValuedIndex()).toString());
        sb.append("\"");
        if (m_parent.getLeftChild() == this)
          sb.append(" <= ");
        else
          sb.append(" > ");
        sb.append(m_parent.getGainRatio().getMaxValuedThreshold());
        return sb.append(")");
      }
    }
   
    /**
     * Prints the tree rooted at this node
     */
    public void print()
    {
      print("");
    }
   
    public void print(String prefix)
    {   
      if (isLeaf()) {
        int bestLabelIndex = getGainRatio().getBaseLabelDistribution().getBestIndex();
        int numMajorityLabel = (int) (getGainRatio().getBaseLabelDistribution().value(bestLabelIndex) * getSize());
        System.out.println("root:" + getGainRatio().getBaseLabelDistribution().getBestLabel() + " " + numMajorityLabel + "/" + getSize());
      }
      else {
        String featName = m_dataDict.lookupObject(getGainRatio().getMaxValuedIndex()).toString();
        double threshold = getGainRatio().getMaxValuedThreshold();
        System.out.print(prefix + "\"" + featName + "\" <= " + threshold + ":");
        if (m_leftChild.isLeaf()) {
          int bestLabelIndex = m_leftChild.getGainRatio().getBaseLabelDistribution().getBestIndex();
          int numMajorityLabel = (int) (m_leftChild.getGainRatio().getBaseLabelDistribution().value(bestLabelIndex) * m_leftChild.getSize());
          System.out.println(m_leftChild.getGainRatio().getBaseLabelDistribution().getBestLabel() + " " + numMajorityLabel + "/" + m_leftChild.getSize());
        }
        else {
          System.out.println();
          m_leftChild.print(prefix + "|    ");
        }       
        System.out.print(prefix + "\"" + featName + "\" > " + threshold + ":");
        if (m_rightChild.isLeaf()) {
          int bestLabelIndex = m_rightChild.getGainRatio().getBaseLabelDistribution().getBestIndex();
          int numMajorityLabel = (int) (m_rightChild.getGainRatio().getBaseLabelDistribution().value(bestLabelIndex) * m_rightChild.getSize());
          System.out.println(m_rightChild.getGainRatio().getBaseLabelDistribution().getBestLabel() + " " + numMajorityLabel + "/" + m_rightChild.getSize());
        }
        else {
          System.out.println();
          m_rightChild.print(prefix + "|    ");
        }
      }
    }
   
  }
 
  // Serialization
  // serialVersionUID is overridden to prevent innocuous changes in this
  // class from making the serialization mechanism think the external
  // format has changed.
  // CURRENT_SERIAL_VERSION is written first by writeObject and checked by
  // readObject, which rejects any stream carrying a different value.

  private static final long serialVersionUID = 1;
  private static final int CURRENT_SERIAL_VERSION = 1;
 
  private void writeObject(ObjectOutputStream out) throws IOException
  {
    out.writeInt(CURRENT_SERIAL_VERSION);
    out.writeObject(getInstancePipe());
    out.writeObject(m_root);
  }
 
  private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
    int version = in.readInt();
    if (version != CURRENT_SERIAL_VERSION)
      throw new ClassNotFoundException("Mismatched C45 versions: wanted " +
          CURRENT_SERIAL_VERSION + ", got " +
          version);
    instancePipe = (Pipe) in.readObject();
    m_root = (Node) in.readObject();
   
  }
 
}
TOP

Related Classes of cc.mallet.classify.C45$Node

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.