Package org.data2semantics.proppred.predictors

Source Code of org.data2semantics.proppred.predictors.SVMPropertyPredictor

package org.data2semantics.proppred.predictors;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import org.data2semantics.proppred.kernels.graphkernels.GraphKernel;
import org.data2semantics.proppred.kernels.graphkernels.WLSubTreeKernel;
import org.data2semantics.proppred.learners.Prediction;
import org.data2semantics.proppred.learners.libsvm.LibSVM;
import org.data2semantics.proppred.learners.libsvm.LibSVMModel;
import org.data2semantics.proppred.learners.libsvm.LibSVMParameters;
import org.data2semantics.tools.graphs.DirectedMultigraphWithRoot;
import org.data2semantics.tools.graphs.Edge;
import org.data2semantics.tools.graphs.GraphFactory;
import org.data2semantics.tools.graphs.Vertex;
import org.data2semantics.tools.rdf.RDFDataSet;
import org.openrdf.model.Resource;
import org.openrdf.model.Statement;
import org.openrdf.model.Value;


/**
* This class is a Support Vector Machine (using {@link libsvm.LibSVM}) and {@link org.data2semantics.proppred.kernels.graphkernels.GraphKernel} based implementation of the {@link PropertyPredictor} Interface.
* Classification, Regression and Outlier Detection (via One-Class SVM) are all supported.
*
* @author Gerben
*
*/
public class SVMPropertyPredictor implements PropertyPredictor {
  private GraphKernel kernel;
  private List<DirectedMultigraphWithRoot<Vertex<String>, Edge<String>>> trainGraphs;
  private List<String> trainLabels;
  private Map<String, Value> valueMap;
  private Map<String, Integer> labelMap;
  private LibSVMModel trainedModel;

  private LibSVMParameters params;
  private int extractionDepth;

 
 
  /**
   * Construct the default SVMPropertyPredictor. A C-SVC support vector machine is used, with a WLSubtreeKernel and extraction depth 2. This setting is good to start with for a new classification task.
   *
   */
  public SVMPropertyPredictor() {
    this(new WLSubTreeKernel(2,true), 2);
    this.setDefaultLibSVMParams();
  }
 
  /**
   * Default C-SVM settings, but one can specify the graph kernel used.
   *
   * @param kernel an instance of a GraphKernel
   */
  public SVMPropertyPredictor(GraphKernel kernel) {
    this(kernel, 2);
  }
 
  /**
   * Default C-SVM settings.
   *
   * @param kernel an instance of GraphKernel
   * @param extractionDepth the depth used in subgraph extraction
   */
  public SVMPropertyPredictor(GraphKernel kernel, int extractionDepth) {
    this.kernel = kernel;
    this.extractionDepth = extractionDepth; 
    this.setDefaultLibSVMParams();
   
    trainGraphs = new ArrayList<DirectedMultigraphWithRoot<Vertex<String>, Edge<String>>>();
    trainLabels = new ArrayList<String>();
    labelMap = new TreeMap<String, Integer>();
    valueMap = new TreeMap<String, Value>();
  }
 

  /**
   * Using the params object different algorithms from LibSVM can be chosen (C-SVC, nu-SVC for classification, nu-SVR, epsilon-SVR for regression and one-class for outlier detection).
   *
   *
   * @param kernel an instance of GraphKernel
   * @param extractionDepth depth used in subgraph extraction
   * @param params parameters for the LibSVM library. When using algorithms with the nu parameter (nu-SVC,nu-SVR and one-class) make sure the iteration parameters are set between 0 and 1.
   */
  public SVMPropertyPredictor(GraphKernel kernel, int extractionDepth, LibSVMParameters params) {
    this(kernel, extractionDepth);
    this.params = params;
  }
 
  private void setDefaultLibSVMParams() {
    this.params = new LibSVMParameters(LibSVMParameters.C_SVC);
    double[] cs = {0.001, 0.01, 0.1, 1.0, 10, 100, 1000};
    this.params.setItParams(cs);
  }
 
 

  public void train(RDFDataSet dataset, List<Resource> instances,
      List<Value> labels) {
    Map<Resource, List<Statement>> dummyMap = new HashMap<Resource, List<Statement>>();
    for (Resource instance : instances) {
      dummyMap.put(instance, null);
    }
    train(dataset, instances, labels, dummyMap);
  }

 
  /**
   * Used to train an SVM model. For regression SVM models (nu-SVR,epsilon-SVR) the StringValue of the Value labels should be parseable to a double.
   * For one-class SVM the labels do not matter, all instances are considered to be part of the class.   *
   *
   */
  public void train(RDFDataSet dataset, List<Resource> instances,
      List<Value> labels, Map<Resource, List<Statement>> blackLists) {
    DirectedMultigraphWithRoot<Vertex<String>, Edge<String>> subGraph;

    for (Resource instance : instances) {
      subGraph = GraphFactory.copyDirectedGraph2GraphWithRoot(
          GraphFactory.createDirectedGraph(dataset.getSubGraph(
              instance, extractionDepth, false, true,
              blackLists.get(instance))), instance.toString());
      trainGraphs.add(subGraph);
    }
    System.out.println("Constructed dataset.");

    double[][] kernelMatrix = kernel.compute(trainGraphs);
    System.out.println("Computed kernel.");

    double[] target = new double[labels.size()];

    if (params.getAlgorithm() == LibSVMParameters.EPSILON_SVR || params.getAlgorithm() == LibSVMParameters.NU_SVR) {
      for (int i = 0; i < labels.size(); i++) {
        target[i] = Double.parseDouble(labels.get(i).stringValue());
      }
     
    } else {
      for (Value label : labels) {
        trainLabels.add(label.toString());
        valueMap.put(label.toString(), label);
      }
     
     
      target = LibSVM.createTargets(trainLabels, labelMap);
    }
   
    // Just to indicate the performance of the predictor, we run cross-validation first
    Prediction[] pred = LibSVM.crossValidate(kernelMatrix, target, params, 10);
   
    if (params.getAlgorithm() == LibSVMParameters.EPSILON_SVR || params.getAlgorithm() == LibSVMParameters.NU_SVR) {
      System.out.println("10-fold CV MSE: "
          + LibSVM.computeMeanSquaredError(target, LibSVM.extractLabels(pred)));
    } else {
      System.out.println("10-fold CV accuracy: "
          + LibSVM.computeAccuracy(target, LibSVM.extractLabels(pred)));
    }   
   
    trainedModel = LibSVM.trainSVMModel(kernelMatrix, target, params);
    System.out.println("Trained model.");
  }

  public List<Value> predict(RDFDataSet dataset, List<Resource> instances) {
    Map<Resource, List<Statement>> dummyMap = new HashMap<Resource, List<Statement>>();
    for (Resource instance : instances) {
      dummyMap.put(instance, null);
    }
    return predict(dataset, instances, dummyMap);
  }

 
  /**
   * Predict for new instances using a trained SVM model. For regression SVMs the Value's contain a double (as String).
   * For one-class the Values contain a String: "normal" for instances falling within the model and "outlier" for instances outside the model.
   *
   */
  public List<Value> predict(RDFDataSet dataset, List<Resource> instances,
      Map<Resource, List<Statement>> blackLists) {
    List<Value> predictions = new ArrayList<Value>();

    if (trainedModel == null) {
      System.out.println("Please train first.");
    } else {

      DirectedMultigraphWithRoot<Vertex<String>, Edge<String>> subGraph;
      List<DirectedMultigraphWithRoot<Vertex<String>, Edge<String>>> testGraphs = new ArrayList<DirectedMultigraphWithRoot<Vertex<String>, Edge<String>>>();

      for (Resource instance : instances) {
        subGraph = GraphFactory
            .copyDirectedGraph2GraphWithRoot(GraphFactory
                .createDirectedGraph(dataset.getSubGraph(
                    instance, extractionDepth, false,
                    true, blackLists.get(instance))),
                    instance.toString());
        testGraphs.add(subGraph);
      }
      System.out.println("Constructed prediction set.");

      double[][] kernelMatrix = kernel.compute(trainGraphs, testGraphs);
      System.out.println("Computed kernel.");

      double[] pred = LibSVM.extractLabels(LibSVM.testSVMModel(
          trainedModel, kernelMatrix));
      Map<Integer, String> revMap = LibSVM.reverseLabelMap(labelMap);

      for (double p : pred) {
        if (params.getAlgorithm() == LibSVMParameters.EPSILON_SVR || params.getAlgorithm() == LibSVMParameters.NU_SVR) {
          predictions.add(dataset.createLiteral(Double.toString(p)));
        } else if (params.getAlgorithm() == LibSVMParameters.ONE_CLASS) {
          if (p == 1) {
            predictions.add(dataset.createLiteral("normal"));
          } else {
            predictions.add(dataset.createLiteral("outlier"));
          }
        } else {
          predictions.add(valueMap.get(revMap.get((int) p)));
        }
      }
    }

    return predictions;
  }

}
TOP

Related Classes of org.data2semantics.proppred.predictors.SVMPropertyPredictor

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.