// Package distributedRedditAnalyser.bolt

// Source code of distributedRedditAnalyser.bolt.StringToWordVectorBolt

package distributedRedditAnalyser.bolt;

import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.Semaphore;

import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.StringToWordVector;

import distributedRedditAnalyser.reddit.Post;

import backtype.storm.spout.SpoutOutputCollector;
import backtype.storm.task.OutputCollector;
import backtype.storm.task.TopologyContext;
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.topology.base.BaseRichBolt;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Tuple;
import backtype.storm.tuple.Values;

/**
* Takes string instances and turns them into word vectors
*
* @author Luke Barnett 1109967
* @author Tony Chen 1111377
*
*/
public class StringToWordVectorBolt extends BaseRichBolt{

  private static final long serialVersionUID = -7494062164103601417L;
  private final int BATCH_SIZE;
  private final int MAX_NUMBER_OF_WORDS_TO_KEEP;
  private Instances INST_HEADERS;
  private final ArrayBlockingQueue<Instance> BATCH_QUEUE;
  private OutputCollector collector;
  private Semaphore semaphore;
  private StringToWordVector filter;
  private Boolean training = true;
 
  public StringToWordVectorBolt(int batchSize, int maxNumberOfWordsToKeep, Instances instHeaders){
    //We need to have a batch size that is at least 1
    if(batchSize < 1){
      throw new IllegalArgumentException("Batch Size is less than 1");
    }
    if(maxNumberOfWordsToKeep < 1){
      throw new IllegalArgumentException("Max number of words to keep is less than 1");
    }
   
    //Store all the variables we need and set things up
    MAX_NUMBER_OF_WORDS_TO_KEEP = maxNumberOfWordsToKeep;
    BATCH_SIZE = batchSize;
    BATCH_QUEUE = new ArrayBlockingQueue<Instance>(BATCH_SIZE);
    INST_HEADERS = instHeaders;
    semaphore = new Semaphore(1);
    filter = new StringToWordVector(MAX_NUMBER_OF_WORDS_TO_KEEP);
    filter.setOutputWordCounts(true);
  }

  @Override
  public void declareOutputFields(OutputFieldsDeclarer declarer) {
    declarer.declare(new Fields("StringVectors"));
   
  }

  @Override
  public void prepare(Map stormConf, TopologyContext context,  OutputCollector collector) {
    this.collector = collector;
   
  }

  @Override
  public void execute(Tuple input) {
    //Get the instance
    DenseInstance inst = (DenseInstance) input.getValue(0);
    INST_HEADERS = inst.dataset();
   
    //Retrieve the semaphore
    try {
      semaphore.acquire();
    } catch (InterruptedException e2) {
      e2.printStackTrace();
    }
   
    /*
     * If we are training then we add it to the batch until it's full
     * At which point the model is created and we just stream instances through the model
     */
    if(training){
      try{
        BATCH_QUEUE.add(inst);
      }catch(IllegalStateException e){
        //Queue is full so we should train the filter
        try {
         
          //Add all the instances to the batch
          Instances data = new Instances(INST_HEADERS);
         
          for(Instance i : BATCH_QUEUE){
            data.add(i);
          }
         
          //Set up the filter
          filter.setInputFormat(data);
         
          //Run the model creation
          Instances filter_training_set = Filter.useFilter(data, filter);
         
          //emit the instances used to train the filter
          for(int i=0; i<filter_training_set.numInstances(); i++){
            collector.emit(new Values(filter_training_set.get(i)));
          }
         
          training = false;
          //Empty the queue for memories
          BATCH_QUEUE.clear();
        } catch (Exception e1) {
          e1.printStackTrace();
        }
       
      }
    }else{
      //Filter through the model and emit
      try {
        filter.input(inst);
      } catch (Exception e) {
        e.printStackTrace();
      }
     
      Instance filteredValue;
      while((filteredValue = filter.output()) != null){
        collector.emit(new Values(filteredValue));
      }
    }
   
    //Always acknowledge the tuple we have processed so it isn't sent somewhere else
    collector.ack(input);
    semaphore.release();
  }

}
// TOP

// Related Classes of distributedRedditAnalyser.bolt.StringToWordVectorBolt

// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.