// Package com.google.appengine.tools.mapreduce.impl.shardedjob

// Source Code of com.google.appengine.tools.mapreduce.impl.shardedjob.ShardedJobRunner

// Copyright 2012 Google Inc. All Rights Reserved.

package com.google.appengine.tools.mapreduce.impl.shardedjob;

import com.google.appengine.api.backends.BackendServiceFactory;
import com.google.appengine.api.datastore.DatastoreService;
import com.google.appengine.api.datastore.DatastoreServiceFactory;
import com.google.appengine.api.datastore.Entity;
import com.google.appengine.api.datastore.EntityNotFoundException;
import com.google.appengine.api.datastore.Key;
import com.google.appengine.api.datastore.Transaction;
import com.google.appengine.api.taskqueue.QueueFactory;
import com.google.appengine.api.taskqueue.TaskOptions;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

/**
* Contains all logic to manage and run sharded jobs.
*
* This is a helper class for {@link ShardedJobServiceImpl} that implements
* all the functionality but assumes fixed types for {@code <T>} and
* {@code <R>}.
*
* @author ohler@google.com (Christian Ohler)
*
* @param <T> type of tasks that the job being processed consists of
* @param <R> type of intermediate and final results of the job being processed
*/
class ShardedJobRunner<T extends IncrementalTask<T, R>, R extends Serializable> {

  // High-level overview:
  //
  // A sharded job is started with a given number of tasks, and every task
  // has zero or one follow-up task; so the total number of tasks never
  // increases.  We assign each follow-up task the same taskId as its
  // predecessor and reuse the same datastore entity to store it.  So, each
  // taskId really represents an entire chain of tasks, and the set of such task
  // chains is known at startup.
  //
  // Each task chain is its own entity group to avoid contention.
  //
  // (We could extend the API later to allow more than one follow-up task, but
  // that either leads to datastore contention, or makes finding the set of all
  // task entities harder; so, since we don't need more than one follow-up task
  // for now, we use a single entity for each chain of follow-up tasks.)
  //
  // There is also a single entity (in its own entity group) that holds the
  // overall job state.  It is updated only during initialization and from the
  // controller.
  //
  // Partial results of each task and its chain of follow-up tasks are combined
  // incrementally as the tasks complete.  Partial results across chains are
  // combined only when the job completes, or when getJobState() is called.  (We
  // could have the controller store the overall combined result, but there's a
  // risk that it doesn't fit in a single entity, so we don't.
  // ShardedJobStateImpl has a field for it and the Serializer supports it
  // for completeness, but the value in the datastore is always null.)
  //
  // Worker and controller tasks and entities carry a strictly monotonic
  // "sequence number" that allows each task to detect if its work has already
  // been done (useful in case the task queue runs it twice).  We schedule each
  // task in the same datastore transaction that updates the sequence number in
  // the entity.
  //
  // Each task also checks the job state entity to detect if the job has been
  // aborted or deleted, and terminates if so.
  //
  // We make job startup idempotent by letting the caller specify the job id
  // (rather than generating one randomly), and deriving task ids from it in a
  // deterministic fashion.  This makes it possible to schedule sharded jobs
  // from Pipeline jobs with no danger of scheduling a duplicate sharded job if
  // Pipeline or the task queue runs a job twice.  (For example, a caller could
  // derive the job id for the sharded job from the Pipeline job id.)

  @SuppressWarnings("unused")
  private static final Logger log = Logger.getLogger(ShardedJobRunner.class.getName());

  static final String JOB_ID_PARAM = "job";
  static final String TASK_ID_PARAM = "task";
  static final String SEQUENCE_NUMBER_PARAM = "seq";

  private static final DatastoreService DATASTORE = DatastoreServiceFactory.getDatastoreService();

  ShardedJobRunner() {
  }

  private ShardedJobStateImpl<T, R> lookupJobState(Transaction tx, String jobId) {
    Entity entity;
    try {
      entity = DATASTORE.get(tx, ShardedJobStateImpl.ShardedJobSerializer.makeKey(jobId));
    } catch (EntityNotFoundException e) {
      return null;
    }
    return ShardedJobStateImpl.ShardedJobSerializer.<T, R>fromEntity(entity);
  }

  private IncrementalTaskState<T, R> lookupTaskState(Transaction tx, String taskId) {
    Entity entity;
    try {
      entity = DATASTORE.get(tx, IncrementalTaskState.Serializer.makeKey(taskId));
    } catch (EntityNotFoundException e) {
      return null;
    }
    return IncrementalTaskState.Serializer.<T, R>fromEntity(entity);
  }

  private List<IncrementalTaskState<T, R>> lookupTasks(ShardedJobState jobState) {
    ImmutableList.Builder<Key> b = ImmutableList.builder();
    for (int i = 0; i < jobState.getTotalTaskCount(); i++) {
     b.add(IncrementalTaskState.Serializer.makeKey(getTaskId(jobState.getJobId(), i)));
    }
    List<Key> keys = b.build();
    Map<Key, Entity> entities = DATASTORE.get(keys);
    ImmutableList.Builder<IncrementalTaskState<T, R>> out = ImmutableList.builder();
    // Iterate over keys rather than entities.entrySet() to ensure ordering.
    for (Key key : keys) {
      Entity entity = entities.get(key);
      Preconditions.checkState(entity != null, "%s: Missing task: %s", this, key);
      out.add(IncrementalTaskState.Serializer.<T, R>fromEntity(entity));
    }
    return out.build();
  }

  private int countActiveTasks(List<IncrementalTaskState<T, R>> taskStates) {
    int count = 0;
    for (IncrementalTaskState taskState : taskStates) {
      if (taskState.getNextTask() != null) {
        count++;
      }
    }
    return count;
  }

  private R aggregateState(ShardedJobController<T, R> controller,
      List<IncrementalTaskState<T, R>> taskStates) {
    ImmutableList.Builder<R> results = ImmutableList.builder();
    for (IncrementalTaskState<T, R> taskState : taskStates) {
      results.add(taskState.getPartialResult());
    }
    return controller.combineResults(results.build());
  }

  private void scheduleControllerTask(Transaction tx, ShardedJobStateImpl<T, R> state) {
    ShardedJobSettings settings = state.getSettings();
    TaskOptions taskOptions = TaskOptions.Builder.withMethod(TaskOptions.Method.POST)
        .url(settings.getControllerPath())
        .param(JOB_ID_PARAM, state.getJobId())
        .param(SEQUENCE_NUMBER_PARAM, "" + state.getNextSequenceNumber())
        .countdownMillis(settings.getMillisBetweenPolls());
    if (settings.getControllerBackend() != null) {
      taskOptions.header("Host",
          BackendServiceFactory.getBackendService().getBackendAddress(
              settings.getControllerBackend()));
    }
    QueueFactory.getQueue(settings.getControllerQueueName()).add(tx, taskOptions);
  }

  private void scheduleWorkerTask(Transaction tx,
      ShardedJobSettings settings, IncrementalTaskState<T, R> state) {
    TaskOptions taskOptions = TaskOptions.Builder.withMethod(TaskOptions.Method.POST)
        .url(settings.getWorkerPath())
        .param(TASK_ID_PARAM, state.getTaskId())
        .param(JOB_ID_PARAM, state.getJobId())
        .param(SEQUENCE_NUMBER_PARAM, "" + state.getNextSequenceNumber());
    if (settings.getWorkerBackend() != null) {
      taskOptions.header("Host",
          BackendServiceFactory.getBackendService().getBackendAddress(settings.getWorkerBackend()));
    }
    QueueFactory.getQueue(settings.getWorkerQueueName()).add(tx, taskOptions);
  }

  void pollTaskStates(String jobId, int sequenceNumber) {
    ShardedJobStateImpl<T, R> jobState = lookupJobState(null, jobId);
    if (jobState == null) {
      log.info(jobId + ": Job gone");
      return;
    }
    log.info("Polling task states for job " + jobId + ", sequence number " + sequenceNumber);
    Preconditions.checkState(jobState.getStatus() != Status.INITIALIZING,
        "Should be done initializing: %s", jobState);
    if (!jobState.getStatus().isActive()) {
      log.info(jobId + ": Job no longer active: " + jobState);
      return;
    }
    if (jobState.getNextSequenceNumber() != sequenceNumber) {
      Preconditions.checkState(jobState.getNextSequenceNumber() > sequenceNumber,
          "%s: Job state is from the past: %s", jobId, jobState);
      log.info(jobId + ": Poll sequence number " + sequenceNumber
          + " already completed: " + jobState);
      return;
    }

    long currentPollTimeMillis = System.currentTimeMillis();

    List<IncrementalTaskState<T, R>> taskStates = lookupTasks(jobState);
    int activeTasks = countActiveTasks(taskStates);

    jobState.setMostRecentUpdateTimeMillis(currentPollTimeMillis);
    jobState.setActiveTaskCount(activeTasks);
    if (activeTasks == 0) {
      jobState.setStatus(Status.DONE);
      R aggregateResult = aggregateState(jobState.getController(), taskStates);
      jobState.getController().completed(aggregateResult);
    } else {
      jobState.setNextSequenceNumber(sequenceNumber + 1);
    }

    log.fine(jobId + ": Writing " + jobState);
    Transaction tx = DATASTORE.beginTransaction();
    try {
      ShardedJobStateImpl existing = lookupJobState(tx, jobId);
      if (existing == null) {
        log.info(jobId + ": Job gone after poll");
        return;
      }
      if (existing.getNextSequenceNumber() != sequenceNumber) {
        log.info(jobId + ": Job processed concurrently; was sequence number " + sequenceNumber
            + ", now " + existing.getNextSequenceNumber());
        return;
      }
      DATASTORE.put(tx, ShardedJobStateImpl.ShardedJobSerializer.toEntity(jobState));
      if (jobState.getStatus().isActive()) {
        scheduleControllerTask(tx, jobState);
      }
      tx.commit();
    } finally {
      if (tx.isActive()) {
        tx.rollback();
      }
    }
  }

  void runTask(String taskId, String jobId, int sequenceNumber) {
    ShardedJobState<T, R> jobState = lookupJobState(null, jobId);
    if (jobState == null) {
      log.info(taskId + ": Job gone");
      return;
    }
    if (!jobState.getStatus().isActive()) {
      log.info(taskId + ": Job no longer active: " + jobState);
      return;
    }
    IncrementalTaskState<T, R> taskState = lookupTaskState(null, taskId);
    if (taskState == null) {
      log.info(taskId + ": Task gone");
      return;
    }
    log.info("Running task " + taskId + " (job " + jobId + "), sequence number " + sequenceNumber);
    if (taskState.getNextSequenceNumber() != sequenceNumber) {
      Preconditions.checkState(taskState.getNextSequenceNumber() > sequenceNumber,
          "%s: Task state is from the past: %s", taskId, taskState);
      log.info(taskId + ": Task sequence number " + sequenceNumber
          + " already completed: " + taskState);
      return;
    }
    Preconditions.checkState(taskState.getNextTask() != null, "%s: Next task is null",
        taskState);

    log.fine("About to run task: " + taskState);

    IncrementalTask.RunResult<T, R> result = taskState.getNextTask().run();
    taskState.setPartialResult(
        jobState.getController().combineResults(
            ImmutableList.of(taskState.getPartialResult(), result.getPartialResult())));
    taskState.setNextTask(result.getFollowupTask());
    taskState.setMostRecentUpdateMillis(System.currentTimeMillis());
    taskState.setNextSequenceNumber(sequenceNumber + 1);
    // TOOD: retries.  we should only have one writer, so should have no
    // concurrency exceptions, but we should guard against other RPC failures.
    Transaction tx = DATASTORE.beginTransaction();
    try {
      IncrementalTaskState existing = lookupTaskState(tx, taskId);
      if (existing == null) {
        log.info(taskId + ": Task disappeared while processing");
        return;
      }
      if (existing.getNextSequenceNumber() != sequenceNumber) {
        log.info(taskId + ": Task processed concurrently; was sequence number " + sequenceNumber
            + ", now " + existing.getNextSequenceNumber());
        return;
      }
      DATASTORE.put(tx, IncrementalTaskState.Serializer.toEntity(taskState));
      if (result.getFollowupTask() != null) {
        scheduleWorkerTask(tx, jobState.getSettings(), taskState);
      }
      tx.commit();
    } finally {
      if (tx.isActive()) {
        tx.rollback();
      }
    }
  }

  private String getTaskId(String jobId, int taskNumber) {
    return jobId + "-task" + taskNumber;
  }

  private void createTasks(ShardedJobController<T, R> controller,
      ShardedJobSettings settings, String jobId,
      List<? extends T> initialTasks,
      long startTimeMillis) {
    log.info(jobId + ": Creating " + initialTasks.size() + " tasks");
    // It's tempting to try to do a large batch put and a large batch enqueue,
    // but I don't see a way to make the batch put idempotent; we don't want to
    // clobber entities that record progress that has already been made.
    for (int i = 0; i < initialTasks.size(); i++) {
      String taskId = getTaskId(jobId, i);
      Transaction tx = DATASTORE.beginTransaction();
      try {
        IncrementalTaskState<T, R> existing = lookupTaskState(tx, taskId);
        if (existing != null) {
          log.info(jobId + ": Task already exists: " + existing);
          continue;
        }
        IncrementalTaskState<T, R> taskState =
            new IncrementalTaskState<T, R>(taskId, jobId, startTimeMillis,
                initialTasks.get(i), controller.combineResults(ImmutableList.<R>of()));
        DATASTORE.put(tx, IncrementalTaskState.Serializer.toEntity(taskState));
        scheduleWorkerTask(tx, settings, taskState);
        tx.commit();
      } finally {
        if (tx.isActive()) {
          tx.rollback();
        }
      }
    }
  }

  // Returns true on success, false if the job is already finished according to
  // the datastore.
  private boolean writeInitialJobState(ShardedJobStateImpl<T, R> jobState) {
    String jobId = jobState.getJobId();
    log.fine(jobId + ": Writing initial job state");
    Transaction tx = DATASTORE.beginTransaction();
    try {
      ShardedJobState<T, R> existing = lookupJobState(tx, jobId);
      if (existing == null) {
        DATASTORE.put(tx, ShardedJobStateImpl.ShardedJobSerializer.toEntity(jobState));
        tx.commit();
      } else {
        if (!existing.getStatus().isActive()) {
          // Maybe a concurrent initialization finished before us, and the
          // job was very quick and has already completed as well.
          log.info(jobId + ": Attempt to reinitialize inactive job: " + existing);
          return false;
        }
        log.info(jobId + ": Reinitializing job: " + existing);
      }
      return true;
    } finally {
      if (tx.isActive()) {
        tx.rollback();
      }
    }
  }

  private void scheduleControllerAndMarkActive(ShardedJobStateImpl<T, R> jobState) {
    String jobId = jobState.getJobId();
    log.fine(jobId + ": Scheduling controller and marking active");
    Transaction tx = DATASTORE.beginTransaction();
    try {
      ShardedJobStateImpl<T, R> existing = lookupJobState(tx, jobId);
      if (existing == null) {
        // Someone cleaned it up while we were still initializing.
        log.warning(jobId + ": Job disappeared while initializing");
        return;
      }
      if (existing.getStatus() != Status.INITIALIZING) {
        // Maybe a concurrent initialization finished first, or someone
        // cancelled it while we were initializing.
        log.info(jobId + ": Job changed status while initializing: " + jobState);
        return;
      }
      DATASTORE.put(tx, ShardedJobStateImpl.ShardedJobSerializer.toEntity(jobState));
      scheduleControllerTask(tx, jobState);
      tx.commit();
    } finally {
      if (tx.isActive()) {
        tx.rollback();
      }
    }
    log.info(jobId + ": Started");
  }

  void startJob(String jobId,
      List<? extends T> rawInitialTasks,
      ShardedJobController<T, R> controller,
      ShardedJobSettings settings) {
    long startTimeMillis = System.currentTimeMillis();
    // ImmutableList.copyOf() checks for null elements.
    List<? extends T> initialTasks = ImmutableList.copyOf(rawInitialTasks);
    ShardedJobStateImpl<T, R> jobState = new ShardedJobStateImpl<T, R>(
        jobId, controller, settings, initialTasks.size(), startTimeMillis, Status.INITIALIZING,
        null);
    if (initialTasks.isEmpty()) {
      log.info(jobId + ": No tasks, immediately complete: " + controller);
      jobState.setStatus(Status.DONE);
      DATASTORE.put(ShardedJobStateImpl.ShardedJobSerializer.toEntity(jobState));
      controller.completed(controller.combineResults(ImmutableList.<R>of()));
      return;
    }
    if (!writeInitialJobState(jobState)) {
      return;
    }
    createTasks(controller, settings, jobId, initialTasks, startTimeMillis);
    jobState.setStatus(Status.RUNNING);
    scheduleControllerAndMarkActive(jobState);
  }

  ShardedJobState<T, R> getJobState(String jobId) {
    ShardedJobStateImpl<T, R> jobState = lookupJobState(null, jobId);
    if (jobState == null) {
      return null;
    }
    // We don't pre-aggregate the result across all tasks since it might not fit
    // in one entity.  The stored value is always null.
    Preconditions.checkState(jobState.getAggregateResult() == null,
        "%s: Non-null aggregate result: %s", jobState, jobState.getAggregateResult());
    List<IncrementalTaskState<T, R>> taskStates = lookupTasks(jobState);
    R aggregateResult = aggregateState(jobState.getController(), taskStates);
    jobState.setAggregateResult(aggregateResult);
    return jobState;
  }

  void abortJob(String jobId) {
    log.info(jobId + ": Aborting");
    Transaction tx = DATASTORE.beginTransaction();
    try {
      ShardedJobStateImpl<T, R> jobState = lookupJobState(tx, jobId);
      if (jobState == null || !jobState.getStatus().isActive()) {
        log.info(jobId + ": Job not active, not aborting: " + jobState);
        return;
      }
      jobState.setStatus(Status.ABORTED);
      DATASTORE.put(tx, ShardedJobStateImpl.ShardedJobSerializer.toEntity(jobState));
      tx.commit();
    } finally {
      if (tx.isActive()) {
        tx.rollback();
      }
    }
  }

}
// TOP
//
// Related Classes of com.google.appengine.tools.mapreduce.impl.shardedjob.ShardedJobRunner
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.