Package distributedRedditAnalyser

Source Code of distributedRedditAnalyser.OzaBoost

/*
*    OzaBoost.java
*    @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
*    @author Luke Barnett (luke@barnett.net.nz)
*
*    This program is free software; you can redistribute it and/or modify
*    it under the terms of the GNU General Public License as published by
*    the Free Software Foundation; either version 3 of the License, or
*    (at your option) any later version.
*
*    This program is distributed in the hope that it will be useful,
*    but WITHOUT ANY WARRANTY; without even the implied warranty of
*    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*    GNU General Public License for more details.
*
*    You should have received a copy of the GNU General Public License
*    along with this program. If not, see <http://www.gnu.org/licenses/>.
*   
*/
package distributedRedditAnalyser;

import java.util.ArrayDeque;
import java.util.Random;
import java.util.concurrent.Semaphore;

import weka.core.Instance;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Classifier;
import moa.core.DoubleVector;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.options.ClassOption;
import moa.options.FlagOption;
import moa.options.IntOption;

/**
* Rewrite of the MOA implementation of OzaBoost to accommodate a distributed setting and the sharing of classifiers.
*
* Keeps only the latest K classifiers in the ensemble.
*
* Largely derived from the MOA implementation:
* http://code.google.com/p/moa/source/browse/moa/src/main/java/moa/classifiers/meta/OzaBoost.java
*
* @author Luke Barnett 1109967
* @author Tony Chen 1111377
*
*/
public class OzaBoost extends AbstractClassifier {
    private static final long serialVersionUID = -4456874021287021340L;

    // Guards all access to the shared ensemble of classifiers.
    private final Semaphore lock = new Semaphore(1);

    @Override
    public String getPurposeString() {
        return "Incremental on-line boosting of Oza and Russell.";
    }

    public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
            "Classifier to train.", Classifier.class, "trees.HoeffdingTree");

    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
            "The max number of models to boost.", 10, 1, Integer.MAX_VALUE);

    public FlagOption pureBoostOption = new FlagOption("pureBoost", 'p',
            "Boost with weights only; no poisson.");

    protected ArrayDeque<ClassifierInstance> ensemble;

    @Override
    public void resetLearningImpl() {
        try {
            lock.acquire();
            this.ensemble = new ArrayDeque<ClassifierInstance>(ensembleSizeOption.getValue());
            Classifier baseLearner = (Classifier) getPreparedClassOption(this.baseLearnerOption);
            baseLearner.resetLearning();
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            lock.release();
        }
    }

    @Override
    public void trainOnInstanceImpl(Instance inst) {
        try {
            lock.acquire();
            // Add a fresh copy of the base learner for this instance
            Classifier newClassifier = ((Classifier) getPreparedClassOption(this.baseLearnerOption)).copy();
            ensemble.add(new ClassifierInstance(newClassifier));

            // Drop the oldest classifiers once the ensemble exceeds its maximum size
            while (ensemble.size() > ensembleSizeOption.getValue()) {
                ensemble.pollFirst();
            }

            // Oza/Russell online boosting: each member trains on the instance with a
            // Poisson(lambda_d) weight; lambda_d is raised after a misclassification and
            // lowered after a correct one, so later members focus on the harder instances.
            double lambda_d = 1.0;
            for (ClassifierInstance c : ensemble) {
                double k = this.pureBoostOption.isSet() ? lambda_d : MiscUtils.poisson(lambda_d, this.classifierRandom);
                if (k > 0.0) {
                    Instance weightedInst = (Instance) inst.copy();
                    weightedInst.setWeight(inst.weight() * k);
                    c.getClassifier().trainOnInstance(weightedInst);
                }
                if (c.getClassifier().correctlyClassifies(inst)) {
                    c.setScms(c.getScms() + lambda_d);
                    lambda_d *= this.trainingWeightSeenByModel / (2 * c.getScms());
                } else {
                    c.setSwms(c.getSwms() + lambda_d);
                    lambda_d *= this.trainingWeightSeenByModel / (2 * c.getSwms());
                }
            }
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            lock.release();
        }
    }

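    /**
     * AdaBoost-style voting weight for an ensemble member: em is the member's
     * weighted error fraction, swms / (scms + swms). Members with zero error or
     * worse-than-chance error get weight 0; otherwise the weight is log((1 - em) / em).
     */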
    protected double getEnsembleMemberWeight(ClassifierInstance i) {
        double em = i.getSwms() / (i.getScms() + i.getSwms());
        if ((em == 0.0) || (em > 0.5)) {
            return 0.0;
        }
        double Bm = em / (1.0 - em);
        return Math.log(1.0 / Bm);
    }

    @Override
    public double[] getVotesForInstance(Instance inst) {
        DoubleVector combinedVote = new DoubleVector();
        try {
            lock.acquire();
            for (ClassifierInstance c : ensemble) {
                double memberWeight = getEnsembleMemberWeight(c);
                if (memberWeight > 0.0) {
                    DoubleVector vote = new DoubleVector(c.getClassifier().getVotesForInstance(inst));
                    if (vote.sumOfValues() > 0.0) {
                        vote.normalize();
                        vote.scaleValues(memberWeight);
                        combinedVote.addValues(vote);
                    }
                } else {
                    break;
                }
            }
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            lock.release();
        }
        return combinedVote.getArrayRef();
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {}

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return new Measurement[]{new Measurement("ensemble size",
                    this.ensemble != null ? this.ensemble.size() : 0)};
    }

    @Override
    public Classifier[] getSubClassifiers() {
        Classifier[] classifiers;
        try {
            lock.acquire();
            // Size and fill the snapshot while holding the lock so it matches the ensemble
            classifiers = new Classifier[ensemble.size()];
            int i = 0;
            for (ClassifierInstance c : ensemble) {
                classifiers[i++] = c.getClassifier().copy();
            }
        } catch (InterruptedException e) {
            e.printStackTrace();
            classifiers = new Classifier[0];
        } finally {
            lock.release();
        }
        return classifiers;
    }
   
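    // Hooks for the distributed setting described in the class comment:
    // getLatestClassifier exposes the newest ensemble member so it can be shared
    // with other instances, and addClassifier merges a shared member into this ensemble.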
    public ClassifierInstance getLatestClassifier() {
        return ensemble.peekLast();
    }

    public void addClassifier(ClassifierInstance c) {
        try {
            lock.acquire();
            ensemble.add(c.clone());

            // Drop the oldest classifiers once the ensemble exceeds its maximum size
            while (ensemble.size() > ensembleSizeOption.getValue()) {
                ensemble.pollFirst();
            }
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            lock.release();
        }
    }
}
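
Usage Sketch for distributedRedditAnalyser.OzaBoost

A minimal sketch of how the class above could be driven, assuming the WEKA-backed MOA API that it imports (prepareForUse, setModelContext, trainOnInstance) and the ClassifierInstance helper defined elsewhere in this package. The RandomTreeGenerator stream and the sharing schedule are illustrative stand-ins, not part of the original project.

package distributedRedditAnalyser;

import weka.core.Instance;
import moa.streams.generators.RandomTreeGenerator;

public class OzaBoostUsageSketch {
    public static void main(String[] args) {
        // Synthetic stream as a stand-in for the real Reddit instance stream (assumption)
        RandomTreeGenerator stream = new RandomTreeGenerator();
        stream.prepareForUse();

        // Two boosters, standing in for two workers of the distributed system
        OzaBoost workerA = new OzaBoost();
        workerA.prepareForUse();
        workerA.setModelContext(stream.getHeader());
        workerA.resetLearning();

        OzaBoost workerB = new OzaBoost();
        workerB.prepareForUse();
        workerB.setModelContext(stream.getHeader());
        workerB.resetLearning();

        for (int i = 1; i <= 1000; i++) {
            Instance inst = stream.nextInstance();
            workerA.trainOnInstance(inst);
            // Periodically pass workerA's newest classifier to workerB,
            // the way classifiers are shared between instances in the distributed setting
            if (i % 100 == 0) {
                workerB.addClassifier(workerA.getLatestClassifier());
            }
        }

        // workerB votes using only the classifiers it received from workerA
        double[] votes = workerB.getVotesForInstance(stream.nextInstance());
        System.out.println("Number of class votes: " + votes.length);
    }
}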