Package eu.stratosphere.pact.runtime.iterative.task

Source Code of eu.stratosphere.pact.runtime.iterative.task.IterationSynchronizationSinkTask

/***********************************************************************************************************************
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
**********************************************************************************************************************/

package eu.stratosphere.pact.runtime.iterative.task;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.google.common.base.Preconditions;

import eu.stratosphere.api.common.aggregators.Aggregator;
import eu.stratosphere.api.common.aggregators.AggregatorWithName;
import eu.stratosphere.api.common.aggregators.ConvergenceCriterion;
import eu.stratosphere.nephele.event.task.AbstractTaskEvent;
import eu.stratosphere.nephele.execution.librarycache.LibraryCacheManager;
import eu.stratosphere.runtime.io.api.MutableRecordReader;
import eu.stratosphere.nephele.template.AbstractOutputTask;
import eu.stratosphere.nephele.types.IntegerRecord;
import eu.stratosphere.pact.runtime.iterative.event.AllWorkersDoneEvent;
import eu.stratosphere.pact.runtime.iterative.event.TerminationEvent;
import eu.stratosphere.pact.runtime.iterative.event.WorkerDoneEvent;
import eu.stratosphere.pact.runtime.task.RegularPactTask;
import eu.stratosphere.pact.runtime.task.util.TaskConfig;
import eu.stratosphere.types.Value;
import eu.stratosphere.util.InstantiationUtil;

/**
* The task responsible for synchronizing all iteration heads, implemented as an {@link AbstractOutputTask}. This task
* will never see any data.
* In each superstep, it simply waits until it has receiced a {@link WorkerDoneEvent} from each head and will send back
* an {@link AllWorkersDoneEvent} to signal that the next superstep can begin.
*/
public class IterationSynchronizationSinkTask extends AbstractOutputTask implements Terminable {

  private static final Log log = LogFactory.getLog(IterationSynchronizationSinkTask.class);

  private MutableRecordReader<IntegerRecord> headEventReader;
 
  private ClassLoader userCodeClassLoader;
 
  private SyncEventHandler eventHandler;

  private ConvergenceCriterion<Value> convergenceCriterion;
 
  private Map<String, Aggregator<?>> aggregators;

  private String convergenceAggregatorName;

  private int currentIteration = 1;
 
  private int maxNumberOfIterations;

  private final AtomicBoolean terminated = new AtomicBoolean(false);


  // --------------------------------------------------------------------------------------------
 
  @Override
  public void registerInputOutput() {
    this.headEventReader = new MutableRecordReader<IntegerRecord>(this);
  }

  @Override
  public void invoke() throws Exception {
    userCodeClassLoader = LibraryCacheManager.getClassLoader(getEnvironment().getJobID());
    TaskConfig taskConfig = new TaskConfig(getTaskConfiguration());
   
    // instantiate all aggregators
    this.aggregators = new HashMap<String, Aggregator<?>>();
    for (AggregatorWithName<?> aggWithName : taskConfig.getIterationAggregators()) {
      Aggregator<?> agg = InstantiationUtil.instantiate(aggWithName.getAggregator(), Aggregator.class);
      aggregators.put(aggWithName.getName(), agg);
    }
   
    // instantiate the aggregator convergence criterion
    if (taskConfig.usesConvergenceCriterion()) {
      Class<? extends ConvergenceCriterion<Value>> convClass = taskConfig.getConvergenceCriterion();
      convergenceCriterion = InstantiationUtil.instantiate(convClass, ConvergenceCriterion.class);
      convergenceAggregatorName = taskConfig.getConvergenceCriterionAggregatorName();
      Preconditions.checkNotNull(convergenceAggregatorName);
    }
   
    maxNumberOfIterations = taskConfig.getNumberOfIterations();
   
    // set up the event handler
    int numEventsTillEndOfSuperstep = taskConfig.getNumberOfEventsUntilInterruptInIterativeGate(0);
    eventHandler = new SyncEventHandler(numEventsTillEndOfSuperstep, aggregators, userCodeClassLoader);
    headEventReader.subscribeToEvent(eventHandler, WorkerDoneEvent.class);

    IntegerRecord dummy = new IntegerRecord();
   
    while (!terminationRequested()) {

//      notifyMonitor(IterationMonitoring.Event.SYNC_STARTING, currentIteration);
      if (log.isInfoEnabled()) {
        log.info(formatLogString("starting iteration [" + currentIteration + "]"));
      }

      // this call listens for events until the end-of-superstep is reached
      readHeadEventChannel(dummy);

      if (log.isInfoEnabled()) {
        log.info(formatLogString("finishing iteration [" + currentIteration + "]"));
      }

      if (checkForConvergence()) {
        if (log.isInfoEnabled()) {
          log.info(formatLogString("signaling that all workers are to terminate in iteration ["
            + currentIteration + "]"));
        }

        requestTermination();
        sendToAllWorkers(new TerminationEvent());
//        notifyMonitor(IterationMonitoring.Event.SYNC_FINISHED, currentIteration);
      } else {
        if (log.isInfoEnabled()) {
          log.info(formatLogString("signaling that all workers are done in iteration [" + currentIteration
            + "]"));
        }

        AllWorkersDoneEvent allWorkersDoneEvent = new AllWorkersDoneEvent(aggregators);
        sendToAllWorkers(allWorkersDoneEvent);
       
        // reset all aggregators
        for (Aggregator<?> agg : aggregators.values()) {
          agg.reset();
        }
       
//        notifyMonitor(IterationMonitoring.Event.SYNC_FINISHED, currentIteration);
        currentIteration++;
      }
    }
  }

//  protected void notifyMonitor(IterationMonitoring.Event event, int currentIteration) {
//    if (log.isInfoEnabled()) {
//      log.info(IterationMonitoring.logLine(getEnvironment().getJobID(), event, currentIteration, 1));
//    }
//  }

  private boolean checkForConvergence() {
    if (maxNumberOfIterations == currentIteration) {
      if (log.isInfoEnabled()) {
        log.info(formatLogString("maximum number of iterations [" + currentIteration
          + "] reached, terminating..."));
      }
      return true;
    }

    if (convergenceAggregatorName != null) {
      @SuppressWarnings("unchecked")
      Aggregator<Value> aggregator = (Aggregator<Value>) aggregators.get(convergenceAggregatorName);
      if (aggregator == null) {
        throw new RuntimeException("Error: Aggregator for convergence criterion was null.");
      }
     
      Value aggregate = aggregator.getAggregate();

      if (convergenceCriterion.isConverged(currentIteration, aggregate)) {
        if (log.isInfoEnabled()) {
          log.info(formatLogString("convergence reached after [" + currentIteration
            + "] iterations, terminating..."));
        }
        return true;
      }
    }
   
    return false;
  }

  private void readHeadEventChannel(IntegerRecord rec) throws IOException {
    // reset the handler
    eventHandler.resetEndOfSuperstep();
   
    // read (and thereby process all events in the handler's event handling functions)
    try {
      while (this.headEventReader.next(rec)) {
        throw new RuntimeException("Synchronization task must not see any records!");
      }
    } catch (InterruptedException iex) {
      // sanity check
      if (!(eventHandler.isEndOfSuperstep())) {
        throw new RuntimeException("Event handler interrupted without reaching end-of-superstep.");
      }
    }
  }

  private void sendToAllWorkers(AbstractTaskEvent event) throws IOException, InterruptedException {
    headEventReader.publishEvent(event);
  }

  private String formatLogString(String message) {
    return RegularPactTask.constructLogString(message, getEnvironment().getTaskName(), this);
  }
 
  // --------------------------------------------------------------------------------------------
 
  @Override
  public boolean terminationRequested() {
    return terminated.get();
  }

  @Override
  public void requestTermination() {
    terminated.set(true);
  }
}
TOP

Related Classes of eu.stratosphere.pact.runtime.iterative.task.IterationSynchronizationSinkTask

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.