Package org.apache.mahout.classifier.bayes.algorithm

Source Code of org.apache.mahout.classifier.bayes.algorithm.CBayesAlgorithm

/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.bayes.algorithm;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;

import org.apache.commons.lang.mutable.MutableDouble;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.bayes.common.ByScoreLabelResultComparator;
import org.apache.mahout.classifier.bayes.exceptions.InvalidDatastoreException;
import org.apache.mahout.classifier.bayes.interfaces.Algorithm;
import org.apache.mahout.classifier.bayes.interfaces.Datastore;
import org.apache.mahout.math.function.ObjectIntProcedure;
import org.apache.mahout.math.map.OpenObjectIntHashMap;

/**
* Class implementing the Complementary Naive Bayes Classifier Algorithm
*
*/
public class CBayesAlgorithm implements Algorithm {
 
  @Override
  public ClassifierResult classifyDocument(String[] document,
                                           Datastore datastore,
                                           String defaultCategory) throws InvalidDatastoreException {
    ClassifierResult result = new ClassifierResult(defaultCategory);
    double max = Double.MIN_VALUE;
    Collection<String> categories = datastore.getKeys("labelWeight");
   
    for (String category : categories) {
      double prob = documentWeight(datastore, category, document);
      if (max < prob) {
        max = prob;
        result.setLabel(category);
      }
    }
    result.setScore(max);
    return result;
  }
 
  @Override
  public ClassifierResult[] classifyDocument(String[] document,
                                             Datastore datastore,
                                             String defaultCategory,
                                             int numResults) throws InvalidDatastoreException {
    Collection<String> categories = datastore.getKeys("labelWeight");
    PriorityQueue<ClassifierResult> pq = new PriorityQueue<ClassifierResult>(numResults,
        new ByScoreLabelResultComparator());
    for (String category : categories) {
      double prob = documentWeight(datastore, category, document);
      if (prob > 0.0) {
        pq.add(new ClassifierResult(category, prob));
        if (pq.size() > numResults) {
          pq.remove();
        }
      }
    }
   
    if (pq.isEmpty()) {
      return new ClassifierResult[] {new ClassifierResult(defaultCategory, 0.0)};
    } else {
      List<ClassifierResult> result = new ArrayList<ClassifierResult>(pq.size());
      while (!pq.isEmpty()) {
        result.add(pq.remove());
      }
      Collections.reverse(result);
      return result.toArray(new ClassifierResult[pq.size()]);
    }
  }
 
  @Override
  public double featureWeight(Datastore datastore, String label, String feature) throws InvalidDatastoreException {
   
    double result = datastore.getWeight("weight", feature, label);
    double vocabCount = datastore.getWeight("sumWeight", "vocabCount");
   
    double featureSum = datastore.getWeight("weight", feature, "sigma_j");
    double totalSum = datastore.getWeight("sumWeight", "sigma_jSigma_k");
    double labelSum = datastore.getWeight("labelWeight", label);
   
    double thetaNormalizer = datastore.getWeight("thetaNormalizer", label);
   
    double numerator = featureSum - result + datastore.getWeight("params", "alpha_i");
    double denominator = totalSum - labelSum + vocabCount;
   
    double weight = Math.log(numerator / denominator);
   
    return weight / thetaNormalizer;
  }
 
  @Override
  public void initialize(Datastore datastore) throws InvalidDatastoreException {
    datastore.getKeys("labelWeight");
  }
 
  @Override
  public double documentWeight(final Datastore datastore,
                               final String label,
                               String[] document) {
    OpenObjectIntHashMap<String> wordList = new OpenObjectIntHashMap<String>(document.length / 2);
    for (String word : document) {
      if (wordList.containsKey(word)) {
        wordList.put(word, wordList.get(word) + 1);
      } else {
        wordList.put(word, 1);
      }
    }
    final MutableDouble result = new MutableDouble(0.0);
   
    wordList.forEachPair(new ObjectIntProcedure<String>() {
     
      @Override
      public boolean apply(String word, int frequency) {
        try {
          result.add(frequency * featureWeight(datastore, label, word));
        } catch (InvalidDatastoreException e) {
          throw new IllegalStateException(e);
        }
        return true;
      }
    });
    return result.doubleValue();
  }
 
  @Override
  public Collection<String> getLabels(Datastore datastore) throws InvalidDatastoreException {
    return datastore.getKeys("labelWeight");
  }
 
}
TOP

Related Classes of org.apache.mahout.classifier.bayes.algorithm.CBayesAlgorithm

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.