/*
* Copyright 2008-2011 Grant Ingersoll, Thomas Morton and Drew Farris
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* -------------------
* To purchase or learn more about Taming Text, by Grant Ingersoll, Thomas Morton and Drew Farris, visit
* http://www.manning.com/ingersoll
*/
package com.tamingtext.classifier.mlt;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.shingle.ShingleAnalyzerWrapper;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Fieldable;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermEnum;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.similar.MoreLikeThis;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.tamingtext.classifier.mlt.TrainMoreLikeThis.MatchMode;
public class MoreLikeThisCategorizer {
private static final Logger log = LoggerFactory.getLogger(MoreLikeThisCategorizer.class);
MatchMode matchMode = MatchMode.TFIDF;
IndexReader indexReader;
IndexSearcher indexSearcher;
MoreLikeThis moreLikeThis;
String categoryFieldName;
final Set<String> categories = new HashSet<String>();
boolean captureCategories = false;
int maxResults = 10;
public MoreLikeThisCategorizer(IndexReader indexReader, String categoryFieldName) throws IOException {
this.indexReader = indexReader;
this.indexSearcher = new IndexSearcher(indexReader);
this.moreLikeThis = new MoreLikeThis(indexReader);
this.categoryFieldName = categoryFieldName;
loadCategoriesFromIndex();
}
/** populate the list of categories by reading the values embedded in the index userData, falls back
* to scanCategories if the data is not present
* @throws IOException
*/
protected void loadCategoriesFromIndex() throws IOException {
Map<String, String> userData = indexReader.getCommitUserData();
String categoryString = userData.get(TrainMoreLikeThis.CATEGORY_KEY);
if (categoryString == null) {
scanCategories();
return;
}
String[] parts = categoryString.split("\\|");
if (parts.length < 1) {
scanCategories();
return;
}
categories.addAll(Arrays.asList(parts));
log.info("Loaded " + categories.size() + " categories from index");
}
/** populate the list of categories by reading the values from the categoryField in the index */
protected void scanCategories() throws IOException {
TermEnum te = indexReader.terms(new Term(categoryFieldName));
final Set<String> c = categories;
do {
if (!te.term().field().equals(categoryFieldName)) break;
c.add(te.term().text());
} while (te.next());
log.info("Scanned " + c.size() + " categories from index");
}
public void setMaxResults(int maxResults) {
this.maxResults = maxResults;
}
public Collection<String> getCategories() {
return Collections.unmodifiableSet(categories);
}
public MatchMode getMatchMode() {
return matchMode;
}
public void setMatchMode(MatchMode matchMode) {
this.matchMode = matchMode;
}
public void setFieldNames(String[] fieldNames) {
moreLikeThis.setFieldNames(fieldNames);
}
public void setAnalyzer(Analyzer analyzer) {
moreLikeThis.setAnalyzer(analyzer);
}
public void setNgramSize(int size) {
if (size <= 1) return;
Analyzer a = moreLikeThis.getAnalyzer();
ShingleAnalyzerWrapper sw;
if (a instanceof ShingleAnalyzerWrapper) {
sw = (ShingleAnalyzerWrapper) a;
}
else {
sw = new ShingleAnalyzerWrapper(a);
moreLikeThis.setAnalyzer(sw);
}
sw.setMaxShingleSize(size);
sw.setMinShingleSize(size);
}
public CategoryHits[] categorize(Reader reader) throws IOException {
Query query = moreLikeThis.like(reader);
HashMap<String, CategoryHits> categoryHash = new HashMap<String, CategoryHits>(25);
for (ScoreDoc sd: indexSearcher.search(query, maxResults).scoreDocs) {
String cat = getDocClass(sd.doc);
if (cat == null) continue;
CategoryHits ch = categoryHash.get(cat);
if (ch == null) {
ch = new CategoryHits();
ch.setLabel(cat);
categoryHash.put(cat, ch);
}
ch.incrementScore(sd.score);
}
SortedSet<CategoryHits> sortedCats = new TreeSet<CategoryHits>(CategoryHits.byScoreComparator());
sortedCats.addAll(categoryHash.values());
return sortedCats.toArray(new CategoryHits[0]);
}
protected String getDocClass(int doc) throws IOException {
Document d = indexReader.document(doc);
Fieldable f = d.getFieldable(categoryFieldName);
if (f == null) return null;
if (!f.isStored()) throw new IllegalArgumentException("Field " + f.name() + " is not stored.");
return f.stringValue();
}
public static void main(String[] args) throws Exception {
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
Option helpOpt = DefaultOptionCreator.helpOption();
Option inputDirOpt = obuilder.withLongName("input").withRequired(true).withArgument(
abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
"The input file to classify")
.withShortName("i").create();
Option modelOpt = obuilder.withLongName("model").withRequired(true).withArgument(
abuilder.withName("index").withMinimum(1).withMaximum(1).create()).withDescription(
"The directory containing the index model").withShortName("m").create();
Option categoryFieldOpt = obuilder.withLongName("categoryField").withRequired(true).withArgument(
abuilder.withName("index").withMinimum(1).withMaximum(1).create()).withDescription(
"Name of the field containing category information").withShortName("catf").create();
Option contentFieldOpt = obuilder.withLongName("contentField").withRequired(true).withArgument(
abuilder.withName("index").withMinimum(1).withMaximum(1).create()).withDescription(
"Name of the field containing content information").withShortName("contf").create();
Option maxResultsOpt = obuilder.withLongName("maxResults").withRequired(false).withArgument(
abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create()).withDescription(
"Number of results to retrive, default: 10 ").withShortName("r").create();
Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(false).withArgument(
abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create()).withDescription(
"Size of the n-gram. Default Value: 1 ").withShortName("ng").create();
Option typeOpt = obuilder.withLongName("classifierType").withRequired(false).withArgument(
abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create()).withDescription(
"Type of classifier: knn|tfidf. Default: bayes").withShortName("type").create();
Group group = gbuilder.withName("Options").withOption(gramSizeOpt).withOption(helpOpt).withOption(
inputDirOpt).withOption(modelOpt).withOption(typeOpt).withOption(contentFieldOpt)
.withOption(categoryFieldOpt).withOption(maxResultsOpt)
.create();
try {
Parser parser = new Parser();
parser.setGroup(group);
parser.setHelpOption(helpOpt);
CommandLine cmdLine = parser.parse(args);
if (cmdLine.hasOption(helpOpt)) {
CommandLineUtil.printHelp(group);
return;
}
String classifierType = (String) cmdLine.getValue(typeOpt);
if (cmdLine.hasOption(gramSizeOpt)) {
}
int gramSize = 1;
if (cmdLine.hasOption(gramSizeOpt)) {
gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt));
}
int maxResults = 10;
if (cmdLine.hasOption(maxResultsOpt)) {
maxResults = Integer.parseInt((String) cmdLine.getValue(maxResultsOpt));
}
String inputPath = (String) cmdLine.getValue(inputDirOpt);
String modelPath = (String) cmdLine.getValue(modelOpt);
String categoryField = (String) cmdLine.getValue(categoryFieldOpt);
String contentField = (String) cmdLine.getValue(contentFieldOpt);
MatchMode mode;
if ("knn".equalsIgnoreCase(classifierType)) {
mode = MatchMode.KNN;
}
else if ("tfidf".equalsIgnoreCase(classifierType)) {
mode = MatchMode.TFIDF;
}
else {
throw new IllegalArgumentException("Unkown classifierType: " + classifierType);
}
Reader reader = new FileReader(inputPath);
Directory directory = FSDirectory.open(new File(modelPath));
IndexReader indexReader = IndexReader.open(directory);
MoreLikeThisCategorizer categorizer = new MoreLikeThisCategorizer(indexReader, categoryField);
categorizer.setMatchMode(mode);
categorizer.setFieldNames(new String[]{ contentField });
categorizer.setMaxResults(maxResults);
if (gramSize > 1)
categorizer.setNgramSize(gramSize);
CategoryHits[] categories = categorizer.categorize(reader);
for (CategoryHits c: categories) {
System.out.println(c.getLabel()+ "\t" + c.getHits() + "\t" + c.getScore());
}
} catch (OptionException e) {
log.error("Error while parsing options", e);
}
}
}