Package com.enigmastation.classifier.impl

Source Code of com.enigmastation.classifier.impl.ClassifierImpl

package com.enigmastation.classifier.impl;

import java.util.Collections;
import java.util.Map;
import java.util.Set;

import com.enigmastation.classifier.CategoryIncrement;
import com.enigmastation.classifier.Classifier;
import com.enigmastation.classifier.ClassifierDataModelFactory;
import com.enigmastation.classifier.ClassifierListener;
import com.enigmastation.classifier.FeatureIncrement;
import com.enigmastation.extractors.WordLister;
import com.enigmastation.extractors.impl.StemmingWordLister;
import com.google.common.collect.Sets;

/**
* This is a simple Bayesian calculation class. It was ported from Python
* contained in the book
* "<a href="http://www.oreilly.com/catalog/9780596529321/index
* .html">Programming Collective Intelligence</a>," by Toby Segaran.
*
* @author <a href="mailto:joeo@enigmastation.com">Joseph B. Ottinger</a>
* @version $Revision: 36 $
*/
public class ClassifierImpl implements Classifier {

  /**
   *
   */
  private static final long serialVersionUID = -3349599753555172369L;
  protected WordLister wordLister = null;
  private Set<ClassifierListener> trainingListeners;
  private ClassifierDataModelFactory classifierDataModelFactory;
  private boolean initialized = false;

  public ClassifierImpl() {
    init();
  }

  public synchronized void init() {
    if (!initialized) {
      if (getClassifierDataModelFactory() == null) {
        setClassifierDataModelFactory(new BasicClassifierDataModelFactory());
      }

      if (wordLister == null) {
        wordLister = new StemmingWordLister();
      }

      initialized = true;
    }
  }

  public ClassifierDataModelFactory getClassifierDataModelFactory() {
    return classifierDataModelFactory;
  }

  public void setClassifierDataModelFactory(
      ClassifierDataModelFactory classifierDataModelFactory) {
    if (getClassifierDataModelFactory() != null) {
      throw new IllegalStateException(
          "Cannot set ModelFactory twice; old type is "
              + getClassifierDataModelFactory().getClass()
                  .getName());
    }
    this.classifierDataModelFactory = classifierDataModelFactory;
  }

  public WordLister getWordLister() {
    return wordLister;
  }

  public void setWordLister(WordLister wordLister) {
    this.wordLister = wordLister;
  }

  public void addListener(ClassifierListener listener) {
    if (trainingListeners == null)
      trainingListeners = Sets.newHashSet();
    trainingListeners.add(listener);
  }

  /**
   * Increase the count of a feature/category pair.
   * <p/>
   * Direct port from Segaran's book, including method name
   *
   * @param feature
   *            the feature (the 'word')
   * @param category
   *            the category
   */

  void incf(String feature, String category) {

    // Map<String,Map<String,Integer>> featureMap = this.categoryFeatureMap;
    // Map<String,Integer> fm = featureMap.get(feature);
    Map<String, Integer> fm = this.classifierDataModelFactory
        .getFeatureMap(feature);
    if (fm == null) {
      // throw new
      // IllegalStateException("You must either supply all classifiers in a map or supply a factory");
      throw new IllegalStateException(
          "You must be able to create a feature map");
    }

    incrementCategory(fm, category);

    FeatureIncrement fi = null;
    if (trainingListeners != null) {
      for (ClassifierListener l : trainingListeners) {
        if (fi == null) {
          fi = new FeatureIncrement(feature, category,
              fm.get(category));
        }
        l.handleFeatureUpdate(fi);
      }
    }

  }

  /**
   * Increase the count of a category. Direct port from Segaran's book,
   * including method name
   *
   * @param category
   *            the category to increment
   */
  void incc(String category) {

    incrementCategory(getCategoryDocCount(), category);

    CategoryIncrement ci = null;
    if (trainingListeners != null) {
      for (ClassifierListener l : trainingListeners) {
        if (ci == null) {
          ci = new CategoryIncrement(category, getCategoryDocCount()
              .get(category));
          ci.setCountDelta(1);
        }
        l.handleCategoryUpdate(ci);
      }
    }

  }

  /**
   * Direct port from Segaran's book, including method name
   *
   * @param feature
   *            the feature
   * @param category
   *            the category to query
   * @return the number of times a feature has appeared in a category
   */
  double fcount(String feature, String category) {
    Map<String, Integer> fm = this.classifierDataModelFactory
        .getFeatureMap(feature);
    if (fm != null) {
      Integer count = fm.get(category);
      if (count != null) {
        return count;
      }
    }

    return 0.0;
  }

  /**
   * Direct port from Segaran's book, including method name
   *
   * @param category
   *            the category to count items for
   * @return the number of items in a category
   */
  double catcount(String category) {
    Integer count = getCategoryDocCount().get(category);
    if (count == null) {
      return 0.0;
    }
    return count.doubleValue();
  }

  /**
   * Direct port from Segaran's book, including method name
   *
   * @return the total number of items
   */
  double totalcount() {
    Map<String, Integer> map = getCategoryDocCount();
    return getTotalCount(map);
  }

  double totalcount(String feature) {
    Map<String, Integer> classifierMap = this.classifierDataModelFactory
        .getFeatureMap(feature);
    if (classifierMap != null) {
      return getTotalCount(classifierMap);
    }
    return 0.0;
  }

  double getTotalFeatureCount(String feature) {
    Map<String, Integer> classifierMap = this.classifierDataModelFactory
        .getFeatureMap(feature);
    if (classifierMap != null) {
      return getTotalCount(classifierMap);
    }
    return 0.0;
  }

  /**
   * Direct port from Segaran's book, including method name.
   *
   * @return the list of all getCategories
   */
  public final Set<String> getCategories() {
    return Collections.unmodifiableSet(getCategoryDocCount().keySet());
  }

  public void train(Object item, String category) {
    Set<String> features = wordLister.getUniqueWords(item);

    for (String f : features) {
      incf(f, category);
    }
    incc(category);
  }

  /**
   * Convenience method for descendant classes - aids in porting from
   * Segaran's book.
   * <p/>
   * I want to change this method to use the arithmetic exception *only* if
   * it's rare. It's possible that determining rarity might be even more
   * expensive, though.
   *
   * @param feature
   *            the feature to consider
   * @param category
   *            the category
   * @return the feature probability for the class
   */
  protected double fprob(String feature, String category) {
    try {
      return fcount(feature, category) / catcount(category);
    } catch (ArithmeticException ae) {
      return 0;
    }
  }

  /**
   * @param feature
   *            the feature to consider
   * @param category
   *            the category
   * @return the feature probability for the class
   */
  public final double getFeatureProbability(String feature, String category) {
    return fprob(feature, category);
  }

  private double WEIGHT = 1.0;
  private double ASSUMED_PROBABILITY = 0.5;

  protected double weightedprob(String feature, String category) {
    return getWeightedProbability(feature, category);
  }

  /**
   * @param feature
   *            The feature to consider
   * @param category
   *            the category to consider weight for
   * @return the weighted probability
   */
  public double getWeightedProbability(String feature, String category) {
    double basicprob = getFeatureProbability(feature, category);
    double totals = getTotalFeatureCount(feature);
    return ((WEIGHT * ASSUMED_PROBABILITY) + (totals * basicprob))
        / (WEIGHT + totals);
  }

  public final Map<String, Integer> getCategoryDocCount() {
    return classifierDataModelFactory.getCategoryCountMap();
  }

  private double getTotalCount(Map<String, Integer> map) {
    int totalCount = 0;
    for (Integer i : map.values()) {
      totalCount += i;
    }
    return totalCount;
  }

  private void incrementCategory(Map<String, Integer> map, String category) {
    Integer val = map.get(category);
    if (val != null) {
      map.put(category, val + 1);
    } else {
      map.put(category, 1);
    }
  }
}
TOP

Related Classes of com.enigmastation.classifier.impl.ClassifierImpl

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.