Package com.google.gerrit.sshd

Source Code of com.google.gerrit.sshd.BaseCommand$TaskThunk

// Copyright (C) 2009 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.gerrit.sshd;

import com.google.gerrit.server.CurrentUser;
import com.google.gerrit.server.IdentifiedUser;
import com.google.gerrit.server.RequestCleanup;
import com.google.gerrit.server.git.ProjectRunnable;
import com.google.gerrit.server.git.WorkQueue;
import com.google.gerrit.server.git.WorkQueue.CancelableRunnable;
import com.google.gerrit.server.project.NoSuchProjectException;
import com.google.gerrit.sshd.SshScope.Context;
import com.google.gerrit.util.cli.CmdLineParser;
import com.google.gerrit.util.cli.EndOfOptionsHandler;
import com.google.inject.Inject;
import com.google.inject.Provider;
import org.apache.sshd.common.SshException;
import org.apache.sshd.server.Command;
import org.apache.sshd.server.Environment;
import org.apache.sshd.server.ExitCallback;
import org.kohsuke.args4j.Argument;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.Option;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.util.concurrent.Future;

public abstract class BaseCommand implements Command {
  private static final Logger log = LoggerFactory.getLogger(BaseCommand.class);
  public static final String ENC = "UTF-8";

  private static final int PRIVATE_STATUS = 1 << 30;
  static final int STATUS_CANCEL = PRIVATE_STATUS | 1;
  static final int STATUS_NOT_FOUND = PRIVATE_STATUS | 2;
  static final int STATUS_NOT_ADMIN = PRIVATE_STATUS | 3;

  @Option(name = "--help", usage = "display this help text", aliases = {"-h"})
  private boolean help;

  @SuppressWarnings("unused")
  @Option(name = "--", usage = "end of options", handler = EndOfOptionsHandler.class)
  private boolean endOfOptions;

  protected InputStream in;
  protected OutputStream out;
  protected OutputStream err;

  private ExitCallback exit;

  @Inject
  private CmdLineParser.Factory cmdLineParserFactory;

  @Inject
  private RequestCleanup cleanup;

  @Inject
  @CommandExecutor
  private WorkQueue.Executor executor;

  @Inject
  private Provider<CurrentUser> userProvider;

  @Inject
  private Provider<SshScope.Context> contextProvider;

  /** The task, as scheduled on a worker thread. */
  private Future<?> task;

  /** Text of the command line which lead up to invoking this instance. */
  private String commandName = "";

  /** Unparsed command line options. */
  private String[] argv;

  public void setInputStream(final InputStream in) {
    this.in = in;
  }

  public void setOutputStream(final OutputStream out) {
    this.out = out;
  }

  public void setErrorStream(final OutputStream err) {
    this.err = err;
  }

  public void setExitCallback(final ExitCallback callback) {
    this.exit = callback;
  }

  void setName(final String prefix) {
    this.commandName = prefix;
  }

  public void setArguments(final String[] argv) {
    this.argv = argv;
  }

  @Override
  public void destroy() {
    if (task != null && !task.isDone()) {
      task.cancel(true);
    }
  }

  /**
   * Pass all state into the command, then run its start method.
   * <p>
   * This method copies all critical state, like the input and output streams,
   * into the supplied command. The caller must still invoke {@code cmd.start()}
   * if wants to pass control to the command.
   *
   * @param cmd the command that will receive the current state.
   */
  protected void provideStateTo(final Command cmd) {
    cmd.setInputStream(in);
    cmd.setOutputStream(out);
    cmd.setErrorStream(err);
    cmd.setExitCallback(exit);
  }

  /**
   * Parses the command line argument, injecting parsed values into fields.
   * <p>
   * This method must be explicitly invoked to cause a parse.
   *
   * @throws UnloggedFailure if the command line arguments were invalid.
   * @see Option
   * @see Argument
   */
  protected void parseCommandLine() throws UnloggedFailure {
    final CmdLineParser clp = newCmdLineParser();
    try {
      clp.parseArgument(argv);
    } catch (IllegalArgumentException err) {
      if (!help) {
        throw new UnloggedFailure(1, "fatal: " + err.getMessage());
      }
    } catch (CmdLineException err) {
      if (!help) {
        throw new UnloggedFailure(1, "fatal: " + err.getMessage());
      }
    }

    if (help) {
      final StringWriter msg = new StringWriter();
      msg.write(commandName);
      clp.printSingleLineUsage(msg, null);
      msg.write('\n');

      msg.write('\n');
      clp.printUsage(msg, null);
      msg.write('\n');
      msg.write(usage());
      throw new UnloggedFailure(1, msg.toString());
    }
  }

  protected String usage() {
    return "";
  }

  /** Construct a new parser for this command's received command line. */
  protected CmdLineParser newCmdLineParser() {
    return cmdLineParserFactory.create(this);
  }

  /**
   * Spawn a function into its own thread.
   * <p>
   * Typically this should be invoked within {@link Command#start(Environment)},
   * such as:
   *
   * <pre>
   * startThread(new Runnable() {
   *   public void run() {
   *     runImp();
   *   }
   * });
   * </pre>
   *
   * @param thunk the runnable to execute on the thread, performing the
   *        command's logic.
   */
  protected void startThread(final Runnable thunk) {
    startThread(new CommandRunnable() {
      @Override
      public void run() throws Exception {
        thunk.run();
      }
    });
  }

  /**
   * Spawn a function into its own thread.
   * <p>
   * Typically this should be invoked within {@link Command#start(Environment)},
   * such as:
   *
   * <pre>
   * startThread(new CommandRunnable() {
   *   public void run() throws Exception {
   *     runImp();
   *   }
   * });
   * </pre>
   * <p>
   * If the function throws an exception, it is translated to a simple message
   * for the client, a non-zero exit code, and the stack trace is logged.
   *
   * @param thunk the runnable to execute on the thread, performing the
   *        command's logic.
   */
  protected synchronized void startThread(final CommandRunnable thunk) {
    final TaskThunk tt = new TaskThunk(thunk);

    if (isAdminCommand()||(isAdminHighPriorityCommand() && userProvider.get().isAdministrator())) {
      // Admin commands should not block the main work threads (there
      // might be an interactive shell there), nor should they wait
      // for the main work threads.
      //
      new Thread(tt, tt.toString()).start();
    } else {
      task = executor.submit(tt);
    }
  }

  private final boolean isAdminCommand() {
    return getClass().getAnnotation(AdminCommand.class) != null;
  }

  private final boolean isAdminHighPriorityCommand() {
    return getClass().getAnnotation(AdminHighPriorityCommand.class) != null;
  }

  /**
   * Terminate this command and return a result code to the remote client.
   * <p>
   * Commands should invoke this at most once. Once invoked, the command may
   * lose access to request based resources as any callbacks previously
   * registered with {@link RequestCleanup} will fire.
   *
   * @param rc exit code for the remote client.
   */
  protected void onExit(final int rc) {
    exit.onExit(rc);
    cleanup.run();
  }

  /** Wrap the supplied output stream in a UTF-8 encoded PrintWriter. */
  protected static PrintWriter toPrintWriter(final OutputStream o) {
    try {
      return new PrintWriter(new BufferedWriter(new OutputStreamWriter(o, ENC)));
    } catch (UnsupportedEncodingException e) {
      // Our default encoding is required by the specifications for the
      // runtime APIs, this should never, ever happen.
      //
      throw new RuntimeException("JVM lacks " + ENC + " encoding", e);
    }
  }

  private int handleError(final Throwable e) {
    if (e.getClass() == IOException.class
        && "Pipe closed".equals(e.getMessage())) {
      // This is sshd telling us the client just dropped off while
      // we were waiting for a read or a write to complete. Either
      // way its not really a fatal error. Don't log it.
      //
      return 127;
    }

    if (e.getClass() == SshException.class
        && "Already closed".equals(e.getMessage())) {
      // This is sshd telling us the client just dropped off while
      // we were waiting for a read or a write to complete. Either
      // way its not really a fatal error. Don't log it.
      //
      return 127;
    }

    if (e instanceof UnloggedFailure) {
    } else {
      final StringBuilder m = new StringBuilder();
      m.append("Internal server error");
      if (userProvider.get() instanceof IdentifiedUser) {
        final IdentifiedUser u = (IdentifiedUser) userProvider.get();
        m.append(" (user ");
        m.append(u.getUserName());
        m.append(")");
      }
      m.append(" during ");
      m.append(contextProvider.get().getCommandLine());
      log.error(m.toString(), e);
    }

    if (e instanceof Failure) {
      final Failure f = (Failure) e;
      try {
        err.write((f.getMessage() + "\n").getBytes(ENC));
        err.flush();
      } catch (IOException e2) {
      } catch (Throwable e2) {
        log.warn("Cannot send failure message to client", e2);
      }
      return f.exitCode;

    } else {
      try {
        err.write("fatal: internal server error\n".getBytes(ENC));
        err.flush();
      } catch (IOException e2) {
      } catch (Throwable e2) {
        log.warn("Cannot send internal server error message to client", e2);
      }
      return 128;
    }
  }

  protected UnloggedFailure die(String msg) {
    return new UnloggedFailure(1, "fatal: " + msg);
  }

  private final class TaskThunk implements CancelableRunnable, ProjectRunnable {
    private final CommandRunnable thunk;
    private final Context context;
    private final String taskName;
    private String projectName;

    private TaskThunk(final CommandRunnable thunk) {
      this.thunk = thunk;
      this.context = contextProvider.get();

      StringBuilder m = new StringBuilder();
      m.append(context.getCommandLine());
      if (userProvider.get() instanceof IdentifiedUser) {
        IdentifiedUser u = (IdentifiedUser) userProvider.get();
        m.append(" (" + u.getUserName() + ")");
      }
      this.taskName = m.toString();
    }

    @Override
    public void cancel() {
      final Context old = SshScope.set(context);
      try {
        onExit(STATUS_CANCEL);
      } finally {
        SshScope.set(old);
      }
    }

    @Override
    public void run() {
      final Thread thisThread = Thread.currentThread();
      final String thisName = thisThread.getName();
      int rc = 0;
      final Context old = SshScope.set(context);
      try {
        context.started = System.currentTimeMillis();
        thisThread.setName("SSH " + taskName);

        if (thunk instanceof ProjectCommandRunnable) {
          ((ProjectCommandRunnable) thunk).executeParseCommand();
          projectName = ((ProjectCommandRunnable) thunk).getProjectName();
        }

        try {
          thunk.run();
        } catch (NoSuchProjectException e) {
          throw new UnloggedFailure(1, e.getMessage() + " no such project");
        }

        out.flush();
        err.flush();
      } catch (Throwable e) {
        try {
          out.flush();
        } catch (Throwable e2) {
        }
        try {
          err.flush();
        } catch (Throwable e2) {
        }
        rc = handleError(e);
      } finally {
        try {
          onExit(rc);
        } finally {
          SshScope.set(old);
          thisThread.setName(thisName);
        }
      }
    }

    @Override
    public String toString() {
      return taskName;
    }

    @Override
    public String getProjectNameKey() {
      return projectName;
    }

    @Override
    public String getRemoteName() {
      return null;
    }

    @Override
    public boolean hasCustomizedPrint() {
      return false;
    }
  }

  /** Runnable function which can throw an exception. */
  public static interface CommandRunnable {
    public void run() throws Exception;
  }

  /** Runnable function which can retrieve a project name related to the task */
  public static interface ProjectCommandRunnable extends CommandRunnable {
    // execute parser command before running, in order to be able to retrieve
    // project name
    public void executeParseCommand() throws Exception;

    public String getProjectName();
  }

  /** Thrown from {@link CommandRunnable#run()} with client message and code. */
  public static class Failure extends Exception {
    private static final long serialVersionUID = 1L;

    final int exitCode;

    /**
     * Create a new failure.
     *
     * @param exitCode exit code to return the client, which indicates the
     *        failure status of this command. Should be between 1 and 255,
     *        inclusive.
     * @param msg message to also send to the client's stderr.
     */
    public Failure(final int exitCode, final String msg) {
      this(exitCode, msg, null);
    }

    /**
     * Create a new failure.
     *
     * @param exitCode exit code to return the client, which indicates the
     *        failure status of this command. Should be between 1 and 255,
     *        inclusive.
     * @param msg message to also send to the client's stderr.
     * @param why stack trace to include in the server's log, but is not sent to
     *        the client's stderr.
     */
    public Failure(final int exitCode, final String msg, final Throwable why) {
      super(msg, why);
      this.exitCode = exitCode;
    }
  }

  /** Thrown from {@link CommandRunnable#run()} with client message and code. */
  public static class UnloggedFailure extends Failure {
    private static final long serialVersionUID = 1L;

    /**
     * Create a new failure.
     *
     * @param exitCode exit code to return the client, which indicates the
     *        failure status of this command. Should be between 1 and 255,
     *        inclusive.
     * @param msg message to also send to the client's stderr.
     */
    public UnloggedFailure(final int exitCode, final String msg) {
      this(exitCode, msg, null);
    }

    /**
     * Create a new failure.
     *
     * @param exitCode exit code to return the client, which indicates the
     *        failure status of this command. Should be between 1 and 255,
     *        inclusive.
     * @param msg message to also send to the client's stderr.
     * @param why stack trace to include in the server's log, but is not sent to
     *        the client's stderr.
     */
    public UnloggedFailure(final int exitCode, final String msg,
        final Throwable why) {
      super(exitCode, msg, why);
    }
  }
}
TOP

Related Classes of com.google.gerrit.sshd.BaseCommand$TaskThunk

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.