
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements.  See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package tv.floe.metronome.classification.logisticregression.iterativereduce;

import java.io.IOException;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.sgd.UniformPrior;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

import tv.floe.metronome.classification.logisticregression.POLRModelParameters;
import tv.floe.metronome.classification.logisticregression.ParallelOnlineLogisticRegression;
import tv.floe.metronome.classification.logisticregression.metrics.POLRMetrics;
import tv.floe.metronome.io.records.RCV1RecordFactory;
import tv.floe.metronome.io.records.RecordFactory;


import com.cloudera.iterativereduce.ComputableWorker;
import com.cloudera.iterativereduce.yarn.appworker.ApplicationWorker;

import com.cloudera.iterativereduce.io.RecordParser;
import com.cloudera.iterativereduce.io.TextRecordParser;
import com.google.common.collect.Lists;

/**
 * The worker node for IterativeReduce: runs the SGD training loop locally over
 * its shard of the input data and produces parameter-vector updates that are
 * aggregated by the master node.
 *
 * @author jpatterson
 *
 */
public class POLRWorkerNode extends POLRNodeBase implements
    ComputableWorker<ParameterVectorUpdatable> {
 
  private static final Log LOG = LogFactory.getLog(POLRWorkerNode.class);
 
  int masterTotal = 0;
 
  public ParallelOnlineLogisticRegression polr = null; // lmp.createRegression();
  public POLRModelParameters polr_modelparams;
 
  public String internalID = "0";
  private RecordFactory VectorFactory = null;
 
  private TextRecordParser lineParser = null;
 
  private boolean IterationComplete = false;
  private int CurrentIteration = 0;
 
  // basic stats tracking
  POLRMetrics metrics = new POLRMetrics();
 
  double averageLineCount = 0.0;
  int k = 0;
  double step = 0.0;
  int[] bumps = new int[] {1, 2, 5};
  double lineCount = 0;
 
  /**
   * Builds the update message for the master: copies the local multinomial
   * logistic regression parameter vector (beta) into a ParameterVector along
   * with the worker's current pass count and training statistics.
   */
  public ParameterVector GenerateUpdate() {
   
    ParameterVector gradient = new ParameterVector();
    gradient.parameter_vector = this.polr.getBeta().clone(); // this.polr.getGamma().getMatrix().clone();
    gradient.SrcWorkerPassCount = this.LocalBatchCountForIteration;
   
    if (this.lineParser.hasMoreRecords()) {
      gradient.IterationComplete = 0;
    } else {
      gradient.IterationComplete = 1;
    }
   
    gradient.CurrentIteration = this.CurrentIteration;
   
    gradient.AvgLogLikelihood = (float) metrics.AvgLogLikelihood;
    gradient.PercentCorrect = (float) (metrics.AvgCorrect * 100);
    gradient.TrainedRecords = (int) metrics.TotalRecordsProcessed;
   
    return gradient;
   
  }
 
  /**
   * The IterativeReduce compute() step: trains the local SGD model on the
   * remaining records in this worker's shard and returns a parameter-vector
   * update for the master.
   */
  @Override
  public ParameterVectorUpdatable compute() {
   
    Text value = new Text();
    long batch_vec_factory_time = 0;
   
    boolean result = true;
    //boolean processBatch = false;

    while (this.lineParser.hasMoreRecords()) {
       
        try {
          result = this.lineParser.next(value);
        } catch (IOException e1) {
          LOG.error("Error reading the next record from the input split", e1);
        }
       
        if (result) {
         
          long startTime = System.currentTimeMillis();
         
          Vector v = new RandomAccessSparseVector(this.FeatureVectorSize);
          int actual = -1;
          try {
           
            actual = this.VectorFactory.processLine(value.toString(), v);
          } catch (Exception e) {
            // skip records that fail to vectorize rather than training on a bogus label
            LOG.error("Error processing record: " + value.toString(), e);
            continue;
          }
         
          long endTime = System.currentTimeMillis();
         
          batch_vec_factory_time += (endTime - startTime);
         
          // calc stats ---------
         
          double mu = Math.min(k + 1, 200);
          double ll = this.polr.logLikelihood(actual, v);
         
          metrics.AvgLogLikelihood = metrics.AvgLogLikelihood
              + (ll - metrics.AvgLogLikelihood) / mu;
         
          if (Double.isNaN(metrics.AvgLogLikelihood)) {
            metrics.AvgLogLikelihood = 0;
          }
         
          Vector p = new DenseVector(this.num_categories);
          this.polr.classifyFull(p, v);
          int estimated = p.maxValueIndex();
          int correct = (estimated == actual ? 1 : 0);
          metrics.AvgCorrect = metrics.AvgCorrect
              + (correct - metrics.AvgCorrect) / mu;
          this.polr.train(actual, v);
         
          k++;
          metrics.TotalRecordsProcessed = k;
//          if (x == this.BatchSize - 1) {
           
/*            System.err
                .printf(
                    "Worker %s:\t Iteration: %s, Trained Recs: %10d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n",
                    this.internalID, this.CurrentIteration, k, metrics.AvgLogLikelihood,
                    metrics.AvgCorrect * 100, batch_vec_factory_time);
  */         
//          }
         
          this.polr.close();
         
        } else {
         
//          this.LocalBatchCountForIteration++;
          // this.input_split.ResetToStartOfSplit();
          // nothing else to process in split!
//          break;
         
        } // if (result)
       
      } // while there are more records in this worker's shard
   
    System.err
    .printf(
        "Worker %s:\t Iteration: %s, Trained Recs: %10d, AvgLL: %10.3f, Percent Correct: %10.2f, VF: %d\n",
        this.internalID, this.CurrentIteration, k, metrics.AvgLogLikelihood,
        metrics.AvgCorrect * 100, batch_vec_factory_time);
   
   
   
/*    } else {
      System.err
      .printf(
          "Worker %s:\t Trained Recs: %10d,  AvgLL: %10.3f, Percent Correct: %10.2f, [Done With Iteration]\n",
          this.internalID, k, metrics.AvgLogLikelihood,
          metrics.AvgCorrect * 100);
     
    } // if
  */ 
    return new ParameterVectorUpdatable(this.GenerateUpdate());
  }
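 
  /*
   * Illustrative helper, not part of the original class: the running averages in
   * compute() use the incremental form avg += (x - avg) / mu with
   * mu = min(n + 1, 200), i.e. an exact mean over the first 200 samples and an
   * exponentially weighted moving average (factor 1/200) after that. A minimal
   * sketch of that update, assuming n is the number of samples seen so far:
   */
  private static double cappedRunningAverage(double currentAvg, double newValue, int n) {
    double mu = Math.min(n + 1, 200); // same window cap as the constant used in compute()
    return currentAvg + (newValue - currentAvg) / mu;
  }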
 
  public ParameterVectorUpdatable getResults() {
    return new ParameterVectorUpdatable(GenerateUpdate());
  }
 
  /**
   * Called when we receive an update from the master: replaces the local
   * parameter vector with the new globally aggregated parameter vector.
   */
  @Override
  public void update(ParameterVectorUpdatable t) {
    // masterTotal = t.get();
    ParameterVector global_update = t.get();
   
    // set the local parameter vector to the global aggregate ("beta")
    this.polr.SetBeta(global_update.parameter_vector);
   
    // update global count
    this.GlobalBatchCountForIteration = global_update.GlobalPassCount;
   
    // flush the local gradient delta buffer ("gamma")
//    this.polr.FlushGamma();
   
/*    if (global_update.IterationComplete == 0) {
      this.IterationComplete = false;
    } else {
      this.IterationComplete = true;

     
      // when this happens, it will trip the ApplicationWorkerService loop and iteration will increment
     
    }
  */ 
  }
 
  @Override
  public void setup(Configuration c) {
   
    this.conf = c;
   
    try {
     
      this.num_categories = this.conf.getInt(
          "com.cloudera.knittingboar.setup.numCategories", 2);
     
      // feature vector size
     
      this.FeatureVectorSize = LoadIntConfVarOrException(
          "com.cloudera.knittingboar.setup.FeatureVectorSize",
          "Error loading config: could not load feature vector size");
     
      // batch size / pass count (currently unused)
//      this.BatchSize = this.conf.getInt(
//          "com.cloudera.knittingboar.setup.BatchSize", 200);
     
//      this.NumberPasses = this.conf.getInt(
//          "com.cloudera.knittingboar.setup.NumberPasses", 1);
      // number of global iterations (passes over the data)
      this.NumberIterations = this.conf.getInt("app.iteration.count", 1);
     
      // protected double Lambda = 1.0e-4;
      this.Lambda = Double.parseDouble(this.conf.get(
          "com.cloudera.knittingboar.setup.Lambda", "1.0e-4"));
     
      // protected double LearningRate = 50;
      this.LearningRate = Double.parseDouble(this.conf.get(
          "com.cloudera.knittingboar.setup.LearningRate", "10"));
     
      // maps to either CSV, 20newsgroups, or RCV1
      this.RecordFactoryClassname = LoadStringConfVarOrException(
          "com.cloudera.knittingboar.setup.RecordFactoryClassname",
          "Error loading config: could not load RecordFactory classname");
     
      if (this.RecordFactoryClassname.equals(RecordFactory.CSV_RECORDFACTORY)) {
       
        // so load the CSV specific stuff ----------
       
        // predictor label names
        this.PredictorLabelNames = LoadStringConfVarOrException(
            "com.cloudera.knittingboar.setup.PredictorLabelNames",
            "Error loading config: could not load predictor label names");
       
        // predictor var types
        this.PredictorVariableTypes = LoadStringConfVarOrException(
            "com.cloudera.knittingboar.setup.PredictorVariableTypes",
            "Error loading config: could not load predictor variable types");
       
        // target variables
        this.TargetVariableName = LoadStringConfVarOrException(
            "com.cloudera.knittingboar.setup.TargetVariableName",
            "Error loading config: Target Variable Name");
       
        // column header names
        this.ColumnHeaderNames = LoadStringConfVarOrException(
            "com.cloudera.knittingboar.setup.ColumnHeaderNames",
            "Error loading config: Column Header Names");
       
        // System.out.println("LoadConfig(): " + this.ColumnHeaderNames);
       
      }
     
    } catch (Exception e) {
      LOG.error("Error loading worker configuration", e);
    }
   
    this.SetupPOLR();
  }
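 
  /*
   * Illustrative only: a minimal Configuration sketch covering the keys read by
   * setup() above. The concrete values (20 iterations, 10000 features, RCV1
   * input) are placeholders, not settings taken from the original project. Note
   * that SetupPOLR() splits PredictorLabelNames and PredictorVariableTypes
   * unconditionally, so those keys may also need values for non-CSV factories
   * unless POLRNodeBase supplies defaults.
   */
  static Configuration exampleWorkerConf() {
    Configuration c = new Configuration();
    c.setInt("com.cloudera.knittingboar.setup.numCategories", 2);
    c.setInt("com.cloudera.knittingboar.setup.FeatureVectorSize", 10000); // required
    c.setInt("app.iteration.count", 20);
    c.set("com.cloudera.knittingboar.setup.Lambda", "1.0e-4");
    c.set("com.cloudera.knittingboar.setup.LearningRate", "10");
    c.set("com.cloudera.knittingboar.setup.RecordFactoryClassname",
        RecordFactory.RCV1_RECORDFACTORY); // required
    // the CSV-specific keys (PredictorLabelNames, PredictorVariableTypes,
    // TargetVariableName, ColumnHeaderNames) are only loaded for the CSV factory
    return c;
  }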
 
  /**
   * Builds the record factory, the model parameters, and the local
   * ParallelOnlineLogisticRegression instance from the loaded configuration.
   *
   * TODO: throw a proper misconfiguration exception instead of silently
   * falling through when the record factory classname is not recognized.
   */
  private void SetupPOLR() {
   
    // do splitting strings into arrays here...
    String[] predictor_label_names = this.PredictorLabelNames.split(",");
    String[] variable_types = this.PredictorVariableTypes.split(",");
   
    polr_modelparams = new POLRModelParameters();
    polr_modelparams.setTargetVariable(this.TargetVariableName);
    polr_modelparams.setNumFeatures(this.FeatureVectorSize);
    polr_modelparams.setUseBias(true);
   
    List<String> typeList = Lists.newArrayList(variable_types);
   
    List<String> predictorList = Lists.newArrayList(predictor_label_names);
   
    // where do these come from?
    polr_modelparams.setTypeMap(predictorList, typeList);
    polr_modelparams.setLambda(this.Lambda); // based on defaults - match
                                             // command line
    polr_modelparams.setLearningRate(this.LearningRate); // based on defaults -
                                                         // match command line
   
    // setup record factory stuff here ---------

    // ####### disabled this input format, was not a long term solution ##########
    /*   
    if (RecordFactory.TWENTYNEWSGROUPS_RECORDFACTORY
        .equals(this.RecordFactoryClassname)) {
     
      this.VectorFactory = new TwentyNewsgroupsRecordFactory("\t");
     
    } else */
   
   
    if (RecordFactory.RCV1_RECORDFACTORY
        .equals(this.RecordFactoryClassname)) {
     
      this.VectorFactory = new RCV1RecordFactory();
     
    } else {
     
      // default case: the CSV record factory - currently disabled, which leaves
      // VectorFactory null (the getTargetCategories() call below would then NPE)
/*     
      this.VectorFactory = new CSVBasedDatasetRecordFactory(
          this.TargetVariableName, polr_modelparams.getTypeMap());
     
      ((CSVBasedDatasetRecordFactory) this.VectorFactory)
          .firstLine(this.ColumnHeaderNames);
  */
     
    //  throw new Exception("Invalid Record Factory Class");
     
    }
   
    polr_modelparams.setTargetCategories(this.VectorFactory
        .getTargetCategories());
   
    // ----- this normally is generated from the POLRModelParams ------
   
    this.polr = new ParallelOnlineLogisticRegression(this.num_categories,
        this.FeatureVectorSize, new UniformPrior()).alpha(1).stepOffset(1000)
        .decayExponent(0.9).lambda(this.Lambda).learningRate(this.LearningRate);
   
    polr_modelparams.setPOLR(polr);
   
    // this.bSetup = true;
  }
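 
  /*
   * Illustrative only: assuming ParallelOnlineLogisticRegression follows the
   * Mahout OnlineLogisticRegression convention for these hyperparameters, the
   * chained calls above imply a per-record learning rate of roughly
   * learningRate * alpha^n * (n + stepOffset)^(-decayExponent). A hypothetical
   * helper to eyeball that decay schedule (with alpha = 1 the first factor is 1):
   */
  private static double approxLearningRate(double learningRate, double alpha,
      double stepOffset, double decayExponent, long n) {
    return learningRate * Math.pow(alpha, n) * Math.pow(n + stepOffset, -decayExponent);
  }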
 
  @Override
  public void setRecordParser(RecordParser r) {
    this.lineParser = (TextRecordParser) r;
  }
 
  /**
   * Implemented only for completeness with the ComputableWorker interface;
   * this batch variant is a legacy artifact and simply delegates to compute().
   */
  @Override
  public ParameterVectorUpdatable compute(
      List<ParameterVectorUpdatable> records) {
    return compute();
  }
 
  public static void main(String[] args) throws Exception {
    TextRecordParser parser = new TextRecordParser();
    POLRWorkerNode pwn = new POLRWorkerNode();
    ApplicationWorker<ParameterVectorUpdatable> aw = new ApplicationWorker<ParameterVectorUpdatable>(
        parser, pwn, ParameterVectorUpdatable.class);
   
    ToolRunner.run(aw, args);
  }

/*  @Override
  public int getCurrentGlobalIteration() {
    // TODO Auto-generated method stub
    return 0;
  }
*/
 
  /**
   * Advances to the next iteration and resets the record parser.
   *
   * @return false once all configured iterations have been completed,
   *         true if another pass over the data remains
   */
  @Override
  public boolean IncrementIteration() {
   
    this.CurrentIteration++;
    this.IterationComplete = false;
    this.lineParser.reset();
   
    System.out.println( "IncIteration > " + this.CurrentIteration + ", " + this.NumberIterations );
   
    if (this.CurrentIteration >= this.NumberIterations) {
      System.out.println("POLRWorkerNode: [ done with all iterations ]");
      return false;
    }
   
    return true;
   
  }
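 
  /*
   * Illustrative only: the real control loop lives in IterativeReduce's
   * ApplicationWorkerService (see the commented note in update() above). This
   * hypothetical sketch just shows the compute()/IncrementIteration() contract:
   * do a pass over the local shard, hand the resulting update to the master,
   * and stop once IncrementIteration() reports that all passes are done.
   */
  static void localDriverSketch(POLRWorkerNode worker) {
    do {
      ParameterVectorUpdatable update = worker.compute();
      // in the real system this update is sent to the master and the global
      // aggregate comes back through worker.update(...); omitted in this sketch
    } while (worker.IncrementIteration());
  }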

/*  @Override
  public boolean isStillWorkingOnCurrentIteration() {

   
    //return this.lineParser.hasMoreRecords();
   
    //return this.
    return !this.IterationComplete;
  }
  */
 
}