Package de.tudarmstadt.ukp.similarity.experiments.coling2012.util

Source Code of de.tudarmstadt.ukp.similarity.experiments.coling2012.util.Evaluator

/*******************************************************************************
* Copyright 2013
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the GNU Public License v3.0
* which accompanies this distribution, and is available at
* http://www.gnu.org/licenses/gpl-3.0.txt
******************************************************************************/
package de.tudarmstadt.ukp.similarity.experiments.coling2012.util;

import static de.tudarmstadt.ukp.similarity.experiments.coling2012.Pipeline.MODELS_DIR;
import static de.tudarmstadt.ukp.similarity.experiments.coling2012.Pipeline.OUTPUT_DIR;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

import org.apache.commons.io.FileUtils;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.trees.J48;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.supervised.attribute.AddClassification;
import weka.filters.unsupervised.attribute.AddID;
import weka.filters.unsupervised.attribute.Remove;
import de.tudarmstadt.ukp.similarity.experiments.coling2012.Pipeline.Dataset;
import de.tudarmstadt.ukp.similarity.experiments.coling2012.Pipeline.EvaluationMetric;


public class Evaluator
{
  public static final String LF = System.getProperty("line.separator");
 
  public enum WekaClassifier
  {
    NAIVE_BAYES,
    J48
  }
 
 
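  /**
   * Runs 10-fold cross-validation for the given dataset and classifier.
   * Instance IDs are added via Weka's AddID filter so that predictions can be re-aligned
   * with the original instance order after randomization; the ID attribute itself is
   * removed again inside a FilteredClassifier during training. The predicted class label
   * of every instance is written to OUTPUT_DIR/&lt;dataset&gt;/&lt;classifier&gt;/output.csv.
   */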
  public static void runClassifierCV(WekaClassifier wekaClassifier, Dataset dataset)
    throws Exception
  {
    // Set parameters
    int folds = 10;
    Classifier baseClassifier = getClassifier(wekaClassifier);
   
    // Set up the random number generator
    long seed = new Date().getTime();
    Random random = new Random(seed);
     
    // Add IDs to the instances
    AddID.main(new String[] {"-i", MODELS_DIR + "/" + dataset.toString() + ".arff",
                  "-o", MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff" });
    Instances data = DataSource.read(MODELS_DIR + "/" + dataset.toString() + "-plusIDs.arff");
    data.setClassIndex(data.numAttributes() - 1);       
   
    // Instantiate the Remove filter (strips the ID attribute inside the FilteredClassifier)
    Remove removeIDFilter = new Remove();
    removeIDFilter.setAttributeIndices("first");
   
    // Randomize the data
    data.randomize(random);
 
    // Perform cross-validation
    Instances predictedData = null;
    Evaluation eval = new Evaluation(data);

    for (int n = 0; n < folds; n++)
    {
      Instances train = data.trainCV(folds, n, random);
      Instances test = data.testCV(folds, n);

      // Apply log filter
//      Filter logFilter = new LogFilter();
//      logFilter.setInputFormat(train);
//      train = Filter.useFilter(train, logFilter);
//      logFilter.setInputFormat(test);
//      test = Filter.useFilter(test, logFilter);

      // Copy the classifier
      Classifier classifier = AbstractClassifier.makeCopy(baseClassifier);

      // Instantiate the FilteredClassifier
      FilteredClassifier filteredClassifier = new FilteredClassifier();
      filteredClassifier.setFilter(removeIDFilter);
      filteredClassifier.setClassifier(classifier);

      // Build the classifier
      filteredClassifier.buildClassifier(train);

      // Evaluate
      eval.evaluateModel(filteredClassifier, test);

      // Add predictions
      AddClassification filter = new AddClassification();
      filter.setClassifier(filteredClassifier);
      filter.setOutputClassification(true);
      filter.setOutputDistribution(false);
      filter.setOutputErrorFlag(true);
      filter.setInputFormat(train);
      Filter.useFilter(train, filter); // trains the classifier

      Instances pred = Filter.useFilter(test, filter); // performs predictions on the test set
      if (predictedData == null)
        predictedData = new Instances(pred, 0);
      for (int j = 0; j < pred.numInstances(); j++)
        predictedData.add(pred.instance(j));
    }
     
    // Prepare output classification
    String[] scores = new String[predictedData.numInstances()];

    for (Instance predInst : predictedData)
    {
      // The ID attribute added by AddID is 1-based; convert it to a 0-based array index
      int id = (int) predInst.value(predInst.attribute(0)) - 1;

      // The predicted class label added by AddClassification is the second-to-last attribute
      int valueIdx = predictedData.numAttributes() - 2;

      String value = predInst.stringValue(predInst.attribute(valueIdx));

      scores[id] = value;
    }

    // Output
    StringBuilder sb = new StringBuilder();
    for (String score : scores)
      sb.append(score).append(LF);

    FileUtils.writeStringToFile(
      new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/output.csv"),
      sb.toString());
  }
 
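  /**
   * Compares the classifier output in output.csv against the gold standard of the dataset
   * and writes the resulting score to OUTPUT_DIR/&lt;dataset&gt;/&lt;classifier&gt;/&lt;metric&gt;.txt.
   * Accuracy is the fraction of correctly labeled instances; AverageF1 is the macro-averaged
   * F1 score, i.e. F1 = 2 * P * R / (P + R) computed per class from the confusion matrix and
   * then averaged over all classes.
   */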
  @SuppressWarnings("unchecked")
  public static void runEvaluationMetric(WekaClassifier wekaClassifier, EvaluationMetric metric, Dataset dataset)
    throws IOException
  {
    StringBuilder sb = new StringBuilder();
   
    List<String> gold = ColingUtils.readGoldstandard(dataset);
    List<String> exp = FileUtils.readLines(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/output.csv"));
   
    if (metric.equals(EvaluationMetric.Accuracy))
    {     
      double acc = 0.0;
     
      for (int i = 0; i < gold.size(); i++)
      {
        if (gold.get(i).equals(exp.get(i)))
          acc++;
      }
     
      acc /= gold.size();
     
      sb.append(acc);
    }
    else if (metric.equals(EvaluationMetric.AverageF1))
    {
      // Get all classes
      Set<String> classesSet = new HashSet<String>();
      for (String cl : gold)
        classesSet.add(cl);
     
      // Order the classes
      List<String> classes = new ArrayList<String>(classesSet);
     
      // Initialize confusion matrix
      // (rows = gold class, columns = predicted class)
      // gold\pred   A   B
      //   A         x1  x2
      //   B         x3  x4
      int[][] matrix = new int[classes.size()][classes.size()];
     
      // Initialize matrix
      for (int i = 0; i < classes.size(); i++)
        for (int j = 0; j < classes.size(); j++)
          matrix[i][j] = 0;
     
      // Construct confusion matrix
      for (int i = 0; i < gold.size(); i++)
      {
        int goldIndex = classes.indexOf(gold.get(i));
        int expIndex = classes.indexOf(exp.get(i));
       
        matrix[goldIndex][expIndex] += 1;
      }
     
      // Compute precision and recall per class
      double[] prec = new double[classes.size()];
      double[] rec = new double[classes.size()];
     
      for (int i = 0; i < classes.size(); i++)
      {
        double tp = matrix[i][i];
        double fp = 0.0;
        double fn = 0.0;
       
        // FP
        for (int j = 0; j < classes.size(); j++)
        {
          if (i == j)
            continue;
         
          fp += matrix[j][i];
        }
       
        // FN
        for (int j = 0; j < classes.size(); j++)
        {
          if (i == j)
            continue;
         
          fn += matrix[i][j];
        }
       
        // Save
        prec[i] = tp / (tp + fp);
        rec[i] = tp / (tp + fn);       
      }
     
      // Compute average F1 score across all classes
      double f1 = 0.0;
     
      for (int i = 0; i < classes.size(); i++)
      {
        double f1PerClass = (2 * prec[i] * rec[i]) / (prec[i] + rec[i]);
        f1 += f1PerClass;
      }
     
      f1 = f1 / classes.size();
     
      // Output
      sb.append(f1);
    }
   
    FileUtils.writeStringToFile(new File(OUTPUT_DIR + "/" + dataset.toString() + "/" + wekaClassifier.toString() + "/" + metric.toString() + ".txt"), sb.toString());
  }
 
 
 
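  /**
   * Maps a WekaClassifier constant to a configured Weka classifier instance
   * (NaiveBayes with default settings, J48 with -C 0.25 -M 2).
   * Throws an IllegalArgumentException for unsupported constants.
   */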
  public static Classifier getClassifier(WekaClassifier classifier)
    throws IllegalArgumentException
  {
    try {
      switch (classifier)
      {
        case NAIVE_BAYES:
          return new NaiveBayes();
        case J48:
          J48 j48 = new J48();     
          j48.setOptions(new String[] { "-C", "0.25", "-M", "2" });
          return j48;
//        case SMO:
//          SMO smo = new SMO();
//          smo.setOptions(Utils.splitOptions("-C 1.0 -L 0.001 -P 1.0E-12 -N 0 -V -1 -W 1 -K \"weka.classifiers.functions.supportVector.PolyKernel -C 250007 -E 1.0\""));
//          return smo;
//        case LOGISTIC:
//          Logistic logistic = new Logistic();
//          logistic.setOptions(Utils.splitOptions("-R 1.0E-8 -M -1"));
//          return logistic;
        default:
          throw new IllegalArgumentException("Classifier " + classifier + " not found!");
      }
    }
    catch (Exception e) {
      throw new IllegalArgumentException(e);
    }

  }
}
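
A minimal usage sketch (not part of the original source): it assumes the ARFF feature file for the chosen dataset already exists under Pipeline.MODELS_DIR and that the gold standard expected by ColingUtils.readGoldstandard is in place. Since the concrete Pipeline.Dataset constants are not shown on this page, the sketch simply picks the first one.

import de.tudarmstadt.ukp.similarity.experiments.coling2012.Pipeline.Dataset;
import de.tudarmstadt.ukp.similarity.experiments.coling2012.Pipeline.EvaluationMetric;
import de.tudarmstadt.ukp.similarity.experiments.coling2012.util.Evaluator;
import de.tudarmstadt.ukp.similarity.experiments.coling2012.util.Evaluator.WekaClassifier;

public class EvaluatorDemo
{
  public static void main(String[] args)
    throws Exception
  {
    // Placeholder: substitute the Pipeline.Dataset constant you actually want to evaluate
    Dataset dataset = Dataset.values()[0];

    // 10-fold cross-validation with Naive Bayes; reads MODELS_DIR/<dataset>.arff and
    // writes per-instance predictions to OUTPUT_DIR/<dataset>/NAIVE_BAYES/output.csv
    Evaluator.runClassifierCV(WekaClassifier.NAIVE_BAYES, dataset);

    // Score the predictions against the gold standard; writes Accuracy.txt and AverageF1.txt
    Evaluator.runEvaluationMetric(WekaClassifier.NAIVE_BAYES, EvaluationMetric.Accuracy, dataset);
    Evaluator.runEvaluationMetric(WekaClassifier.NAIVE_BAYES, EvaluationMetric.AverageF1, dataset);
  }
}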