Package org.encog.ml.svm.training.search

Source Code of org.encog.ml.svm.training.search.SVMSearchJob

package org.encog.ml.svm.training.search;

import java.util.ArrayList;
import java.util.List;

import org.encog.EncogError;
import org.encog.StatusReportable;
import org.encog.mathutil.libsvm.svm;
import org.encog.mathutil.libsvm.svm_problem;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.svm.KernelType;
import org.encog.ml.svm.SVM;
import org.encog.ml.svm.training.EncodeSVMProblem;
import org.encog.ml.svm.training.SVMTrain;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.util.Format;
import org.encog.util.concurrency.job.ConcurrentJob;
import org.encog.util.concurrency.job.JobUnitContext;

public class SVMSearchJob extends ConcurrentJob implements MLTrain {

  /**
   * The default starting number for C.
   */
  public static final double DEFAULT_CONST_BEGIN = -5;

  /**
   * The default ending number for C.
   */
  public static final double DEFAULT_CONST_END = 15;

  /**
   * The default step for C.
   */
  public static final double DEFAULT_CONST_STEP = 2;

  /**
   * The default gamma begin.
   */
  public static final double DEFAULT_GAMMA_BEGIN = -10;

  /**
   * The default gamma end.
   */
  public static final double DEFAULT_GAMMA_END = 10;

  /**
   * The default gamma step.
   */
  public static final double DEFAULT_GAMMA_STEP = 1;

  /**
   * The best values found for C.
   */
  private double bestConst;

  /**
   * The best values found for gamma.
   */
  private double bestGamma;

  /**
   * The best error.
   */
  private double bestError;

  /**
   * The current C.
   */
  private double currentConst;

  /**
   * The current gamma.
   */
  private double currentGamma;

  /**
   * The beginning value for C.
   */
  private double constBegin = SVMSearchJob.DEFAULT_CONST_BEGIN;

  /**
   * The step value for C.
   */
  private double constStep = SVMSearchJob.DEFAULT_CONST_STEP;

  /**
   * The ending value for C.
   */
  private double constEnd = SVMSearchJob.DEFAULT_CONST_END;

  /**
   * The beginning value for gamma.
   */
  private double gammaBegin = SVMSearchJob.DEFAULT_GAMMA_BEGIN;

  /**
   * The ending value for gamma.
   */
  private double gammaEnd = SVMSearchJob.DEFAULT_GAMMA_END;

  /**
   * The step value for gamma.
   */
  private double gammaStep = SVMSearchJob.DEFAULT_GAMMA_STEP;

  private final SVM modelSVM;

  private boolean done;

  private int iterationCount;

  private boolean started;

  /**
   * The problem to train for.
   */
  private final svm_problem problem;

  /**
   * The number of folds.
   */
  private int fold = 0;

  /**
   * Is the network setup.
   */
  private boolean isSetup;

  private double svmTrain;

  private final MLDataSet training;

  public SVMSearchJob(final SVM svm, final MLDataSet dataSet,
      final StatusReportable report) {
    super(report);
    if (svm.getKernelType() != KernelType.RadialBasisFunction) {
      throw new EncogError(
          "To use SVM search train, the SVM kernel must be RBF.");
    }

    this.problem = EncodeSVMProblem.encode(dataSet, 0);
    this.modelSVM = svm;
    this.training = dataSet;
    this.bestError = Double.MAX_VALUE;
  }

  @Override
  public void addStrategy(final Strategy strategy) {
    throw new EncogError("Not supported.");

  }

  @Override
  public boolean canContinue() {
    return false;
  }

  @Override
  public void finishTraining() {
    stop();
  }

  private SVM generateSVM() {
    final SVM svm = new SVM(this.modelSVM.getInputCount(),
        this.modelSVM.getSVMType(), this.modelSVM.getKernelType());
    return svm;
  }

  /**
   * @return the bestConst
   */
  public final double getBestConst() {
    return this.bestConst;
  }

  /**
   * @return the bestError
   */
  public final double getBestError() {
    return this.bestError;
  }

  /**
   * @return the bestGamma
   */
  public final double getBestGamma() {
    return this.bestGamma;
  }

  /**
   * @return the constBegin
   */
  public final double getConstBegin() {
    return this.constBegin;
  }

  /**
   * @return the constEnd
   */
  public final double getConstEnd() {
    return this.constEnd;
  }

  /**
   * @return the constStep
   */
  public final double getConstStep() {
    return this.constStep;
  }

  /**
   * @return the currentConst
   */
  public final double getCurrentConst() {
    return this.currentConst;
  }

  /**
   * @return the currentGamma
   */
  public final double getCurrentGamma() {
    return this.currentGamma;
  }

  @Override
  public double getError() {
    return this.bestError;
  }

  /**
   * @return the fold
   */
  public final int getFold() {
    return this.fold;
  }

  /**
   * @return the gammaBegin
   */
  public final double getGammaBegin() {
    return this.gammaBegin;
  }

  /**
   * @return the gammaEnd
   */
  public final double getGammaEnd() {
    return this.gammaEnd;
  }

  /**
   * @return the gammaStep
   */
  public final double getGammaStep() {
    return this.gammaStep;
  }

  @Override
  public TrainingImplementationType getImplementationType() {
    // TODO Auto-generated method stub
    return TrainingImplementationType.Background;
  }

  @Override
  public int getIteration() {
    return this.iterationCount;
  }

  /**
   * This method creates, and trains, a SVM with the best const and gamma.
   * @return The best SVM.
   */
  @Override
  public MLMethod getMethod() {
    final SVM result = generateSVM();
    result.getParams().C = this.bestConst;
    result.getParams().gamma = this.bestGamma;
    result.setModel(svm.svm_train(this.problem, result.getParams()));
    return result;
  }

  @Override
  public List<Strategy> getStrategies() {
    return new ArrayList<Strategy>();
  }

  @Override
  public MLDataSet getTraining() {
    // TODO Auto-generated method stub
    return this.training;
  }

  @Override
  public boolean isTrainingDone() {
    return this.done && !this.isRunning();
  }

  @Override
  public void iteration() {
    if (!this.started) {
      processBackground();
      this.started = true;
      this.iterationCount++;
    } else {
      try {
        Thread.sleep(10000);
      } catch (final InterruptedException e) {
      }
    }

    this.iterationCount++;

  }

  @Override
  public void iteration(final int count) {
    iteration();
  }

  @Override
  public int loadWorkload() {
    double d = (this.gammaEnd - this.gammaBegin) / this.gammaStep;
    d += (this.constEnd - this.constBegin) / this.constStep;
    return (int) d;
  }

  @Override
  public TrainingContinuation pause() {
    return null;
  }

  @Override
  public void performJobUnit(final JobUnitContext context) {
    final SVMJobPackage pack = (SVMJobPackage) context.getJobUnit();
    final double[] target = new double[this.problem.l];

    // set params
    pack.getSvm().getParams().gamma = pack.getGamma();
    pack.getSvm().getParams().C = pack.getC();

    double error;
   
    pack.getSvm().getParams().C = this.currentConst;
    pack.getSvm().getParams().gamma = this.currentGamma;
   
    if( fold==0 ) {
      // train it     
      pack.getSvm().setModel(svm.svm_train(this.problem, pack.getSvm().getParams()));
      error = pack.getSvm().calculateError(getTraining());
    } else {
      // cross validate it
      svm.svm_cross_validation(this.problem, pack.getSvm().getParams(),
          this.fold, target);
     
      error = SVMTrain.evaluate(pack.getSvm().getParams(),
          this.problem, target);
    }
   
    // new best error?
    if (!Double.isNaN(error)) {
      if (error < this.bestError) {
        this.bestConst = pack.getC();
        this.bestGamma = pack.getGamma();
        this.bestError = error;
      }
    }

    // report progress
    final StringBuilder message = new StringBuilder();

    message.append("Current: gamma= ");
    message.append(Format.formatDouble(this.currentGamma, 2));
    message.append("; Const: ");
    message.append(Format.formatDouble(this.currentConst, 2));
    message.append("; Best Error: " + Format.formatPercent(this.bestError));

    reportStatus(context, message.toString());

  }

  @Override
  public Object requestNextTask() {
    if (this.done || getShouldStop()) {
      return null;
    }

    final SVM svm = generateSVM();

    // advance
    this.currentConst += this.constStep;
    if (this.currentConst > this.constEnd) {
      this.currentConst = this.constBegin;
      this.currentGamma += this.gammaStep;
      if (this.currentGamma > this.gammaEnd) {
        this.done = true;
      }
    }

    return new SVMJobPackage(svm, this.problem, this.currentConst,
        this.currentGamma, this.fold);
  }

  @Override
  public void resume(final TrainingContinuation state) {
  }

  /**
   * @param constBegin the constBegin to set
   */
  public final void setConstBegin(final double constBegin) {
    this.constBegin = constBegin;
  }

  /**
   * @param constEnd the constEnd to set
   */
  public final void setConstEnd(final double constEnd) {
    this.constEnd = constEnd;
  }

  /**
   * @param constStep the constStep to set
   */
  public final void setConstStep(final double constStep) {
    this.constStep = constStep;
  }

  /**
   * @param currentConst the currentConst to set
   */
  public final void setCurrentConst(final double currentConst) {
    this.currentConst = currentConst;
  }

  /**
   * @param currentGamma the currentGamma to set
   */
  public final void setCurrentGamma(final double currentGamma) {
    this.currentGamma = currentGamma;
  }

  @Override
  public void setError(final double error) {
    // TODO Auto-generated method stub

  }

  /**
   * @param fold the fold to set
   */
  public final void setFold(final int fold) {
    this.fold = fold;
  }

  /**
   * @param gammaBegin the gammaBegin to set
   */
  public final void setGammaBegin(final double gammaBegin) {
    this.gammaBegin = gammaBegin;
  }

  /**
   * @param gammaEnd the gammaEnd to set
   */
  public final void setGammaEnd(final double gammaEnd) {
    this.gammaEnd = gammaEnd;
  }

  /**
   * @param gammaStep the gammaStep to set
   */
  public final void setGammaStep(final double gammaStep) {
    this.gammaStep = gammaStep;
  }

  @Override
  public void setIteration(final int iteration) {
    this.iterationCount = iteration;
  }
}
TOP

Related Classes of org.encog.ml.svm.training.search.SVMSearchJob

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.