Package ca.uwo.csd.ai.nlp.libsvm.ex

Source Code of ca.uwo.csd.ai.nlp.libsvm.ex.SVMPredictor

package ca.uwo.csd.ai.nlp.libsvm.ex;

import ca.uwo.csd.ai.nlp.libsvm.svm;
import ca.uwo.csd.ai.nlp.libsvm.svm_model;
import ca.uwo.csd.ai.nlp.libsvm.svm_node;
import java.io.IOException;
import java.util.List;

/**
*
* @author Syeed Ibn Faiz
*/
public class SVMPredictor {

    public static double[] predict(List<Instance> instances, svm_model model) {
        return predict(instances, model, true);
    }

    public static double[] predict(List<Instance> instances, svm_model model, boolean displayResult) {
        Instance[] array = new Instance[instances.size()];
        array = instances.toArray(array);
        return predict(array, model, displayResult);
    }

    public static double predict(Instance instance, svm_model model, boolean displayResult) {
        return svm.svm_predict(model, new svm_node(instance.getData()));
    }
   
    public static double predictProbability(Instance instance, svm_model model, double[] probabilities) {
        return svm.svm_predict_probability(model, new svm_node(instance.getData()), probabilities);
    }   
    public static double[] predict(Instance[] instances, svm_model model, boolean displayResult) {
        int total = 0;
        int correct = 0;

        int tp = 0;
        int fp = 0;
        int fn = 0;

        boolean binary = model.nr_class == 2;
        double[] predictions = new double[instances.length];
        int count = 0;

        for (Instance instance : instances) {
            double target = instance.getLabel();
            double p = svm.svm_predict(model, new svm_node(instance.getData()));
            predictions[count++] = p;

            ++total;
            if (p == target) {
                correct++;
                if (target > 0) {
                    tp++;
                }
            } else if (target > 0) {
                fn++;
            } else {
                fp++;
            }
        }
        if (displayResult) {
            System.out.print("Accuracy = " + (double) correct / total * 100
                    + "% (" + correct + "/" + total + ") (classification)\n");

            if (binary) {
                double precision = (double) tp / (tp + fp);
                double recall = (double) tp / (tp + fn);
                System.out.println("Precision: " + precision);
                System.out.println("Recall: " + recall);
                System.out.println("Fscore: " + 2 * precision * recall / (precision + recall));
            }
        }
        return predictions;
    }

    public static svm_model loadModel(String filePath) throws IOException, ClassNotFoundException {
        return svm.svm_load_model(filePath);
    }
}
TOP

Related Classes of ca.uwo.csd.ai.nlp.libsvm.ex.SVMPredictor

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.