Package org.apache.mahout.classifier.df.split

Source Code of org.apache.mahout.classifier.df.split.OptIgSplit

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.df.split;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataUtils;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;

import java.util.Arrays;

/**
* Optimized implementation of IgSplit<br>
* This class can be used when the criterion variable is the categorical attribute.
*/
public class OptIgSplit extends IgSplit {

  private int[][] counts;

  private int[] countAll;

  private int[] countLess;

  @Override
  public Split computeSplit(Data data, int attr) {
    if (data.getDataset().isNumerical(attr)) {
      return numericalSplit(data, attr);
    } else {
      return categoricalSplit(data, attr);
    }
  }

  /**
   * Computes the split for a CATEGORICAL attribute
   */
  private static Split categoricalSplit(Data data, int attr) {
    double[] values = data.values(attr);
    int[][] counts = new int[values.length][data.getDataset().nblabels()];
    int[] countAll = new int[data.getDataset().nblabels()];

    Dataset dataset = data.getDataset();

    // compute frequencies
    for (int index = 0; index < data.size(); index++) {
      Instance instance = data.get(index);
      counts[ArrayUtils.indexOf(values, instance.get(attr))][(int) dataset.getLabel(instance)]++;
      countAll[(int) dataset.getLabel(instance)]++;
    }

    int size = data.size();
    double hy = entropy(countAll, size); // H(Y)
    double hyx = 0.0; // H(Y|X)
    double invDataSize = 1.0 / size;

    for (int index = 0; index < values.length; index++) {
      size = DataUtils.sum(counts[index]);
      hyx += size * invDataSize * entropy(counts[index], size);
    }

    double ig = hy - hyx;
    return new Split(attr, ig);
  }

  /**
   * Return the sorted list of distinct values for the given attribute
   */
  private static double[] sortedValues(Data data, int attr) {
    double[] values = data.values(attr);
    Arrays.sort(values);

    return values;
  }

  /**
   * Instantiates the counting arrays
   */
  void initCounts(Data data, double[] values) {
    counts = new int[values.length][data.getDataset().nblabels()];
    countAll = new int[data.getDataset().nblabels()];
    countLess = new int[data.getDataset().nblabels()];
  }

  void computeFrequencies(Data data, int attr, double[] values) {
    Dataset dataset = data.getDataset();

    for (int index = 0; index < data.size(); index++) {
      Instance instance = data.get(index);
      counts[ArrayUtils.indexOf(values, instance.get(attr))][(int) dataset.getLabel(instance)]++;
      countAll[(int) dataset.getLabel(instance)]++;
    }
  }

  /**
   * Computes the best split for a NUMERICAL attribute
   */
  Split numericalSplit(Data data, int attr) {
    double[] values = sortedValues(data, attr);

    initCounts(data, values);

    computeFrequencies(data, attr, values);

    int size = data.size();
    double hy = entropy(countAll, size);
    double invDataSize = 1.0 / size;

    int best = -1;
    double bestIg = -1.0;

    // try each possible split value
    for (int index = 0; index < values.length; index++) {
      double ig = hy;

      // instance with attribute value < values[index]
      size = DataUtils.sum(countLess);
      ig -= size * invDataSize * entropy(countLess, size);

      // instance with attribute value >= values[index]
      size = DataUtils.sum(countAll);
      ig -= size * invDataSize * entropy(countAll, size);

      if (ig > bestIg) {
        bestIg = ig;
        best = index;
      }

      DataUtils.add(countLess, counts[index]);
      DataUtils.dec(countAll, counts[index]);
    }

    if (best == -1) {
      throw new IllegalStateException("no best split found !");
    }
    return new Split(attr, bestIg, values[best]);
  }

  /**
   * Computes the Entropy
   *
   * @param counts   counts[i] = numInstances with label i
   * @param dataSize numInstances
   */
  private static double entropy(int[] counts, int dataSize) {
    if (dataSize == 0) {
      return 0.0;
    }

    double entropy = 0.0;
    double invDataSize = 1.0 / dataSize;

    for (int count : counts) {
      if (count == 0) {
        continue; // otherwise we get a NaN
      }
      double p = count * invDataSize;
      entropy += -p * Math.log(p) / LOG2;
    }

    return entropy;
  }

}
TOP

Related Classes of org.apache.mahout.classifier.df.split.OptIgSplit

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.