Package eu.stratosphere.test.recordJobs.kmeans

Source Code of eu.stratosphere.test.recordJobs.kmeans.KMeansCross

/***********************************************************************************************************************
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
**********************************************************************************************************************/

package eu.stratosphere.test.recordJobs.kmeans;

import java.util.ArrayList;
import java.util.List;

import eu.stratosphere.api.common.Plan;
import eu.stratosphere.api.common.Program;
import eu.stratosphere.api.common.ProgramDescription;
import eu.stratosphere.api.java.record.operators.BulkIteration;
import eu.stratosphere.api.java.record.operators.FileDataSink;
import eu.stratosphere.api.java.record.operators.FileDataSource;
import eu.stratosphere.api.java.record.operators.CrossOperator;
import eu.stratosphere.api.java.record.operators.ReduceOperator;
import eu.stratosphere.client.LocalExecutor;
import eu.stratosphere.test.recordJobs.kmeans.udfs.ComputeDistance;
import eu.stratosphere.test.recordJobs.kmeans.udfs.FindNearestCenter;
import eu.stratosphere.test.recordJobs.kmeans.udfs.PointInFormat;
import eu.stratosphere.test.recordJobs.kmeans.udfs.PointOutFormat;
import eu.stratosphere.test.recordJobs.kmeans.udfs.RecomputeClusterCenter;
import eu.stratosphere.types.IntValue;


public class KMeansCross implements Program, ProgramDescription {

  private static final long serialVersionUID = 1L;

  @Override
  public Plan getPlan(String... args) {
    // parse job parameters
    final int numSubTasks = (args.length > 0 ? Integer.parseInt(args[0]) : 1);
    final String dataPointInput = (args.length > 1 ? args[1] : "");
    final String clusterInput = (args.length > 2 ? args[2] : "");
    final String output = (args.length > 3 ? args[3] : "");
    final int numIterations = (args.length > 4 ? Integer.parseInt(args[4]) : 1);

    // create DataSourceContract for cluster center input
    FileDataSource initialClusterPoints = new FileDataSource(new PointInFormat(), clusterInput, "Centers");
    initialClusterPoints.setDegreeOfParallelism(1);
   
    BulkIteration iteration = new BulkIteration("K-Means Loop");
    iteration.setInput(initialClusterPoints);
    iteration.setMaximumNumberOfIterations(numIterations);
   
    // create DataSourceContract for data point input
    FileDataSource dataPoints = new FileDataSource(new PointInFormat(), dataPointInput, "Data Points");

    // create CrossOperator for distance computation
    CrossOperator computeDistance = CrossOperator.builder(new ComputeDistance())
        .input1(dataPoints)
        .input2(iteration.getPartialSolution())
        .name("Compute Distances")
        .build();

    // create ReduceOperator for finding the nearest cluster centers
    ReduceOperator findNearestClusterCenters = ReduceOperator.builder(new FindNearestCenter(), IntValue.class, 0)
        .input(computeDistance)
        .name("Find Nearest Centers")
        .build();

    // create ReduceOperator for computing new cluster positions
    ReduceOperator recomputeClusterCenter = ReduceOperator.builder(new RecomputeClusterCenter(), IntValue.class, 0)
        .input(findNearestClusterCenters)
        .name("Recompute Center Positions")
        .build();
    iteration.setNextPartialSolution(recomputeClusterCenter);
   
    // create DataSourceContract for data point input
    FileDataSource dataPoints2 = new FileDataSource(new PointInFormat(), dataPointInput, "Data Points 2");
   
    // compute distance of points to final clusters
    CrossOperator computeFinalDistance = CrossOperator.builder(new ComputeDistance())
        .input1(dataPoints2)
        .input2(iteration)
        .name("Compute Final Distances")
        .build();

    // find nearest final cluster for point
    ReduceOperator findNearestFinalCluster = ReduceOperator.builder(new FindNearestCenter(), IntValue.class, 0)
        .input(computeFinalDistance)
        .name("Find Nearest Final Centers")
        .build();

    // create DataSinkContract for writing the new cluster positions
    FileDataSink finalClusters = new FileDataSink(new PointOutFormat(), output+"/centers", iteration, "Cluster Positions");

    // write assigned clusters
    FileDataSink clusterAssignments = new FileDataSink(new PointOutFormat(), output+"/points", findNearestFinalCluster, "Cluster Assignments");
   
    List<FileDataSink> sinks = new ArrayList<FileDataSink>();
    sinks.add(finalClusters);
    sinks.add(clusterAssignments);
   
    // return the PACT plan
    Plan plan = new Plan(sinks, "Iterative KMeans");
    plan.setDefaultParallelism(numSubTasks);
    return plan;
  }

  @Override
  public String getDescription() {
    return "Parameters: <numSubStasks> <dataPoints> <clusterCenters> <output> <numIterations>";
  }
 
  public static void main(String[] args) throws Exception {
    KMeansCross kmi = new KMeansCross();
   
    if (args.length < 5) {
      System.err.println(kmi.getDescription());
      System.exit(1);
    }
   
    Plan plan = kmi.getPlan(args);
   
    // This will execute the kMeans clustering job embedded in a local context.
    LocalExecutor.execute(plan);

  }
}
TOP

Related Classes of eu.stratosphere.test.recordJobs.kmeans.KMeansCross

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.