Package io.prediction.examples.java.parallel

Source Code of io.prediction.examples.java.parallel.Algorithm$ReadingAndCount

package io.prediction.examples.java.parallel;

import io.prediction.controller.java.EmptyParams;
import io.prediction.controller.java.PJavaAlgorithm;

import java.io.Serializable;
import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import scala.Tuple2;

public class Algorithm extends PJavaAlgorithm<
  EmptyParams, JavaPairRDD<String, Float>, Model, Query, Float> {

  final static Logger logger = LoggerFactory.getLogger(Algorithm.class);

  public static class ReadingAndCount implements Serializable {
    public float reading;
    public int count;

    public ReadingAndCount(float reading, int count) {
      this.reading = reading;
      this.count = count;
    }

    public ReadingAndCount(float reading) {
      this(reading, 1);
    }

    @Override
    public String toString() {
      return "(reading = " + reading + ", count = " + count + ")";
    }
  }

  @Override
  public Model train(JavaPairRDD<String, Float> data) {
    // take averages just like the local helloworld program
    JavaPairRDD<String, Float> averages = data.mapValues(
      new Function<Float, ReadingAndCount>() {
        @Override
        public ReadingAndCount call(Float reading) {
          return new ReadingAndCount(reading);
        }
      }).reduceByKey(
      new Function2<ReadingAndCount, ReadingAndCount, ReadingAndCount>() {
        @Override
        public ReadingAndCount call(ReadingAndCount rac1, ReadingAndCount rac2) {
          return new ReadingAndCount(rac1.reading + rac2.reading, rac1.count + rac2.count);
        }
      }).mapValues(
      new Function<ReadingAndCount, Float>() {
        @Override
        public Float call(ReadingAndCount rac) {
          return rac.reading / rac.count;
        }
      });
    return new Model(averages);
  }

  @Override
  public JavaPairRDD<Object, Float> batchPredict(Model model,
      JavaPairRDD<Object, Query> indexedQueries) {
    return model.temperatures.join(indexedQueries.mapToPair(
        new PairFunction<Tuple2<Object, Query>, String, Object>() {
          @Override   // reverse the query tuples, then join
          public Tuple2 call(Tuple2<Object, Query> tuple) {
            return new Tuple2(tuple._2.day, tuple._1);
          }
        })).mapToPair(
        new PairFunction<Tuple2<String, Tuple2<Float, Object>>, Object, Float>() {
          @Override   // map result back to predictions, dropping the day
          public Tuple2 call(Tuple2<String, Tuple2<Float, Object>> tuple) {
            return new Tuple2(tuple._2._2, tuple._2._1);
          }
        });
  }

  @Override
  public Float predict(Model model, Query query) {
    final String day = query.day;
    List<Float> reading = model.temperatures.lookup(day);
    if (reading.size() == 0) {
      return -10000f; // JSON does not support NaN
    }
    return reading.get(0);
  }
}
TOP

Related Classes of io.prediction.examples.java.parallel.Algorithm$ReadingAndCount

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.