Package com.tamingtext.classifier.mlt

Source Code of com.tamingtext.classifier.mlt.MoreLikeThisCategorizer

/*
* Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
*
*    Licensed under the Apache License, Version 2.0 (the "License");
*    you may not use this file except in compliance with the License.
*    You may obtain a copy of the License at
*
*        http://www.apache.org/licenses/LICENSE-2.0
*
*    Unless required by applicable law or agreed to in writing, software
*    distributed under the License is distributed on an "AS IS" BASIS,
*    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*    See the License for the specific language governing permissions and
*    limitations under the License.
* -------------------
* To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
* http://www.manning.com/ingersoll
*/

package com.tamingtext.classifier.mlt;

import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.shingle.ShingleAnalyzerWrapper;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Fieldable;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.similar.MoreLikeThis;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.tamingtext.classifier.mlt.TrainMoreLikeThis.MatchMode;

public class MoreLikeThisCategorizer {
 
  private static final Logger log = LoggerFactory.getLogger(MoreLikeThisCategorizer.class);

  MatchMode matchMode = MatchMode.TFIDF;
  IndexReader indexReader;
  IndexSearcher indexSearcher;
  MoreLikeThis moreLikeThis;
  String categoryFieldName;
  final Set<String> categories = new HashSet<String>();
  boolean captureCategories = false;
  int maxResults = 10;
 
  public MoreLikeThisCategorizer(IndexReader indexReader, String categoryFieldName) throws IOException {
    this.indexReader   = indexReader;
    this.indexSearcher = new IndexSearcher(indexReader);
    this.moreLikeThis  = new MoreLikeThis(indexReader);
    this.categoryFieldName = categoryFieldName;
    loadCategoriesFromIndex();
  }
 
  /** populate the list of categories by reading the values embedded in the index userData, falls back
   *  to scanCategories if the data is not present
   * @throws IOException
   */
  protected void loadCategoriesFromIndex() throws IOException {
    Map<String, String> userData = indexReader.getCommitUserData();
    String categoryString = userData.get(TrainMoreLikeThis.CATEGORY_KEY);
    if (categoryString == null) {
      scanCategories();
      return;
     
    }
   
    String[] parts = categoryString.split("\\|");
   
    if (parts.length < 1) {
      scanCategories();
      return;
    }
   
    categories.addAll(Arrays.asList(parts));
    log.info("Loaded " + categories.size() + " categories from index");
  }
 
  /** populate the list of categories by reading the values from the categoryField in the index */
  protected void scanCategories() throws IOException {
    TermEnum te = indexReader.terms(new Term(categoryFieldName));
    final Set<String> c = categories;
   
    do {
      if (!te.term().field().equals(categoryFieldName)) break;
      c.add(te.term().text());
    } while (te.next());
   
    log.info("Scanned " + c.size() + " categories from index");
  }
 
  public void setMaxResults(int maxResults) {
    this.maxResults = maxResults;
  }
 
  public Collection<String> getCategories() {
    return Collections.unmodifiableSet(categories);
  }
 
  public MatchMode getMatchMode() {
    return matchMode;
  }

  public void setMatchMode(MatchMode matchMode) {
    this.matchMode = matchMode;
  }

  public void setFieldNames(String[] fieldNames) {
    moreLikeThis.setFieldNames(fieldNames);
  }

  public void setAnalyzer(Analyzer analyzer) {
    moreLikeThis.setAnalyzer(analyzer);
  }
 
  public void setNgramSize(int size) {
    if (size <= 1) return;
   
    Analyzer a = moreLikeThis.getAnalyzer();
    ShingleAnalyzerWrapper sw;
    if (a instanceof ShingleAnalyzerWrapper) {
      sw = (ShingleAnalyzerWrapper) a;
    }
    else {
      sw = new ShingleAnalyzerWrapper(a);
      moreLikeThis.setAnalyzer(sw);
    }
   
    sw.setMaxShingleSize(size);
    sw.setMinShingleSize(size);
  }
 
  public CategoryHits[] categorize(Reader reader) throws IOException {
    Query query = moreLikeThis.like(reader);

    HashMap<String, CategoryHits> categoryHash = new HashMap<String, CategoryHits>(25);
   
    for (ScoreDoc sd: indexSearcher.search(query, maxResults).scoreDocs) {
      String cat = getDocClass(sd.doc);
      if (cat == null) continue;
      CategoryHits ch = categoryHash.get(cat);
      if (ch == null) {
        ch = new CategoryHits();
        ch.setLabel(cat);
        categoryHash.put(cat, ch);
      }

      ch.incrementScore(sd.score);
    }

    SortedSet<CategoryHits> sortedCats = new TreeSet<CategoryHits>(CategoryHits.byScoreComparator());
    sortedCats.addAll(categoryHash.values());
    return sortedCats.toArray(new CategoryHits[0]);
  }
  protected String getDocClass(int doc) throws IOException {
    Document d = indexReader.document(doc);
    Fieldable f = d.getFieldable(categoryFieldName);
    if (f == null) return null;
    if (!f.isStored()) throw new IllegalArgumentException("Field " + f.name() + " is not stored.");
    return f.stringValue();
  }
 
  public static void main(String[] args) throws Exception {
    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
    ArgumentBuilder abuilder = new ArgumentBuilder();
    GroupBuilder gbuilder = new GroupBuilder();
   
    Option helpOpt = DefaultOptionCreator.helpOption();
   
    Option inputDirOpt = obuilder.withLongName("input").withRequired(true).withArgument(
      abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
      "The input file to classify")
        .withShortName("i").create();
   
    Option modelOpt = obuilder.withLongName("model").withRequired(true).withArgument(
      abuilder.withName("index").withMinimum(1).withMaximum(1).create()).withDescription(
      "The directory containing the index model").withShortName("m").create();
   
    Option categoryFieldOpt = obuilder.withLongName("categoryField").withRequired(true).withArgument(
        abuilder.withName("index").withMinimum(1).withMaximum(1).create()).withDescription(
        "Name of the field containing category information").withShortName("catf").create();

    Option contentFieldOpt = obuilder.withLongName("contentField").withRequired(true).withArgument(
        abuilder.withName("index").withMinimum(1).withMaximum(1).create()).withDescription(
        "Name of the field containing content information").withShortName("contf").create();
   
    Option maxResultsOpt = obuilder.withLongName("maxResults").withRequired(false).withArgument(
        abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create()).withDescription(
        "Number of results to retrive, default: 10 ").withShortName("r").create();
   
    Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(false).withArgument(
      abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create()).withDescription(
      "Size of the n-gram. Default Value: 1 ").withShortName("ng").create();
   
    Option typeOpt = obuilder.withLongName("classifierType").withRequired(false).withArgument(
      abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create()).withDescription(
      "Type of classifier: knn|tfidf. Default: bayes").withShortName("type").create();
   
    Group group = gbuilder.withName("Options").withOption(gramSizeOpt).withOption(helpOpt).withOption(
        inputDirOpt).withOption(modelOpt).withOption(typeOpt).withOption(contentFieldOpt)
        .withOption(categoryFieldOpt).withOption(maxResultsOpt)
        .create();
   
    try {
      Parser parser = new Parser();
     
      parser.setGroup(group);
      parser.setHelpOption(helpOpt);
      CommandLine cmdLine = parser.parse(args);
      if (cmdLine.hasOption(helpOpt)) {
        CommandLineUtil.printHelp(group);
        return;
      }
     
      String classifierType = (String) cmdLine.getValue(typeOpt);
     
      if (cmdLine.hasOption(gramSizeOpt)) {
       
      }
     
      int gramSize = 1;
      if (cmdLine.hasOption(gramSizeOpt)) {
        gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt));
      }

      int maxResults = 10;
      if (cmdLine.hasOption(maxResultsOpt)) {
        maxResults = Integer.parseInt((String) cmdLine.getValue(maxResultsOpt));
      }
     
      String inputPath  = (String) cmdLine.getValue(inputDirOpt);
      String modelPath = (String) cmdLine.getValue(modelOpt);
      String categoryField = (String) cmdLine.getValue(categoryFieldOpt);
      String contentField = (String) cmdLine.getValue(contentFieldOpt);
     
      MatchMode mode;
     
      if ("knn".equalsIgnoreCase(classifierType)) {
        mode = MatchMode.KNN;
      }
      else if ("tfidf".equalsIgnoreCase(classifierType)) {
        mode = MatchMode.TFIDF;
      }
      else {
        throw new IllegalArgumentException("Unkown classifierType: " + classifierType);
      }

      Reader reader = new FileReader(inputPath);
      Directory directory = FSDirectory.open(new File(modelPath));
      IndexReader indexReader = IndexReader.open(directory);
      MoreLikeThisCategorizer categorizer = new MoreLikeThisCategorizer(indexReader, categoryField);
      categorizer.setMatchMode(mode);
      categorizer.setFieldNames(new String[]{ contentField });
      categorizer.setMaxResults(maxResults);
     
      if (gramSize > 1)
        categorizer.setNgramSize(gramSize);
     
     
      CategoryHits[] categories = categorizer.categorize(reader);
      for (CategoryHits c: categories) {
        System.out.println(c.getLabel()+ "\t" + c.getHits() + "\t" + c.getScore());
      }
     
    } catch (OptionException e) {
      log.error("Error while parsing options", e);
    }
  } 
}
TOP

Related Classes of com.tamingtext.classifier.mlt.MoreLikeThisCategorizer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.