package dkpro.similarity.algorithms.ml;
import java.io.File;
import java.util.List;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.functions.Logistic;
import weka.classifiers.functions.SMO;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import de.tudarmstadt.ukp.dkpro.core.api.metadata.type.DocumentMetaData;
import dkpro.similarity.algorithms.api.JCasTextSimilarityMeasureBase;
import dkpro.similarity.algorithms.api.SimilarityException;
import dkpro.similarity.ml.filters.LogFilter;
/**
 * Runs a machine learning classifier on the provided test data with a model
 * that is trained on the given training data. The available classifiers
 * are Naive Bayes, J48, SMO, and Logistic. Mind that the
 * {@link #getSimilarity(JCas, JCas, Annotation, Annotation) getSimilarity} method
 * classifies the input texts by their ID, not their textual contents. The
 * {@code DocumentID} of the {@code DocumentMetaData} is expected to denote
 * the corresponding input line in the test data.
 */
public class ClassifierSimilarityMeasure
    extends JCasTextSimilarityMeasureBase
{
    // NOTE(review): a mutable public static field is fragile (last-constructed
    // instance wins), but it is part of the public interface, so it stays.
    public static Classifier CLASSIFIER;

    /** The supported Weka classifier types. */
    public enum WekaClassifier
    {
        NAIVE_BAYES,
        J48,
        SMO,
        LOGISTIC
    }

    Classifier filteredClassifier;
    List<String> features;
    Instances test;

    /**
     * Trains a classifier of the given type on the training data and evaluates
     * it on the test data. Both data sets are read from ARFF files and passed
     * through a {@link LogFilter} before training and classification. The
     * evaluation summary and confusion matrix are printed to standard output.
     *
     * @param classifier the Weka classifier type to train
     * @param trainArff ARFF file containing the training instances
     * @param testArff ARFF file containing the test instances
     * @throws Exception if reading, filtering, training, or evaluating fails
     *         (wrapped in a {@link SimilarityException})
     */
    public ClassifierSimilarityMeasure(WekaClassifier classifier, File trainArff, File testArff)
        throws Exception
    {
        CLASSIFIER = getClassifier(classifier);

        // Get all instances
        Instances train = readInstances(trainArff);
        test = readInstances(testArff);

        try {
            // Apply the log filter to both data sets
            // (moved into the try so filter failures are reported consistently
            // as SimilarityException, like every other failure in here)
            Filter logFilter = new LogFilter();
            logFilter.setInputFormat(train);
            train = Filter.useFilter(train, logFilter);
            logFilter.setInputFormat(test);
            test = Filter.useFilter(test, logFilter);

            // Train on a copy so the template classifier remains untouched
            filteredClassifier = AbstractClassifier.makeCopy(CLASSIFIER);
            filteredClassifier.buildClassifier(train);

            // Evaluate on the test data and report the results
            Evaluation eval = new Evaluation(train);
            eval.evaluateModel(filteredClassifier, test);
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toMatrixString());
        }
        catch (Exception e) {
            throw new SimilarityException(e);
        }
    }

    /**
     * Instantiates a Weka classifier of the given type, configured with the
     * standard Weka default options.
     *
     * @param classifier the classifier type to instantiate
     * @return a freshly created, configured classifier
     * @throws IllegalArgumentException if the type is unknown or configuring
     *         the classifier fails
     */
    public static Classifier getClassifier(WekaClassifier classifier)
        throws IllegalArgumentException
    {
        try {
            switch (classifier)
            {
                case NAIVE_BAYES:
                    return new NaiveBayes();
                case J48:
                    J48 j48 = new J48();
                    j48.setOptions(new String[] { "-C", "0.25", "-M", "2" });
                    return j48;
                case SMO:
                    SMO smo = new SMO();
                    smo.setOptions(Utils.splitOptions("-C 1.0 -L 0.001 -P 1.0E-12 -N 0 -V -1 -W 1 -K \"weka.classifiers.functions.supportVector.PolyKernel -C 250007 -E 1.0\""));
                    return smo;
                case LOGISTIC:
                    Logistic logistic = new Logistic();
                    logistic.setOptions(Utils.splitOptions("-R 1.0E-8 -M -1"));
                    return logistic;
                default:
                    throw new IllegalArgumentException("Classifier " + classifier + " not found!");
            }
        }
        catch (Exception e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Reads an ARFF file with Weka and marks the last attribute as the class
     * attribute. (Replaces the formerly duplicated train/test reader methods.)
     *
     * @param arffFile the ARFF file to read
     * @return the instances read from the file
     * @throws SimilarityException if the file cannot be read
     */
    private Instances readInstances(File arffFile)
        throws SimilarityException
    {
        // Read with Weka
        Instances data;
        try {
            data = DataSource.read(arffFile.getAbsolutePath());
        }
        catch (Exception e) {
            throw new SimilarityException(e);
        }
        // By convention, the class attribute is the last one in the ARFF file
        data.setClassIndex(data.numAttributes() - 1);
        return data;
    }

    /**
     * Classifies a document pair by looking up the corresponding row in the
     * test data. The feature generation needs to have happened before! The
     * DocumentID of the first view is expected to have the form
     * {@code <prefix>-<lineNo>}, where lineNo is the 1-based index of the
     * corresponding row in the test ARFF file.
     *
     * @return the predicted class value for the looked-up test instance
     * @throws SimilarityException if classification fails
     */
    @Override
    public double getSimilarity(JCas jcas1, JCas jcas2, Annotation coveringAnnotation1,
            Annotation coveringAnnotation2)
        throws SimilarityException
    {
        DocumentMetaData md = DocumentMetaData.get(jcas1);
        String docId = md.getDocumentId();
        // Take the part after the first "-" as the 1-based test row index
        int id = Integer.parseInt(docId.substring(docId.indexOf("-") + 1));

        Instance testInst = test.get(id - 1);

        try {
            return filteredClassifier.classifyInstance(testInst);
        }
        catch (Exception e) {
            throw new SimilarityException(e);
        }
    }
}