Package org.chromium.sdk.internal.websocket

Source Code of org.chromium.sdk.internal.websocket.Hybi17WsConnection

// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

package org.chromium.sdk.internal.websocket;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

import org.chromium.sdk.ConnectionLogger;
import org.chromium.sdk.internal.websocket.ManualLoggingSocketWrapper.LoggableInput;
import org.chromium.sdk.internal.websocket.ManualLoggingSocketWrapper.LoggableOutput;
import org.chromium.sdk.util.BasicUtil;

/**
* WebSocket connection. Sends and receives messages. Implements HyBi-17 protocol specification.
* @see http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
*/
public class Hybi17WsConnection extends AbstractWsConnection<LoggableInput, LoggableOutput> {
  private static final Logger LOGGER = Logger.getLogger(Hybi17WsConnection.class.getName());
  private static final Random RANDOM = new Random();

  /**
   * Specifies how outgoing frames get masked. While protocol specification requires that every
   * outgoing frame must be masked (to disable provocative content that socket client may send),
   * this doesn't really make sense when the client is trusted. On the other hand, transparent mask
   * makes debug sniffering easier.
   */
  public enum MaskStrategy {
    /**
     * Directs to use no mask at all. This is explicitly against protocol specification, peer
     * is expected to terminate connection in response.
     */
    NO_MASK() {
      @Override public byte[] generate() {
        return null;
      }

      @Override
      ManualLoggingSocketWrapper.FactoryBase getLogWrapperFactory() {
        return ManualLoggingSocketWrapper.PLAIN_ASCII;
      }
    },
    /**
     * Directs to always use transparent mask (i.e. all zeroes). This makes all frames clear-text.
     * Not suitable when untrusted client uses the WebSocket.
     */
    TRANSPARENT_MASK() {
      private final byte[] bytes = new byte[4];

      @Override public byte[] generate() {
        return bytes;
      }

      @Override
      ManualLoggingSocketWrapper.FactoryBase getLogWrapperFactory() {
        return ManualLoggingSocketWrapper.PLAIN_ASCII;
      }
    },
    /**
     * Directs to use randomly generated masks as specified by specification. As a by-product makes
     * traffic hard to sniff.
     */
    NORMAL_MASK() {
      @Override
      byte[] generate() {
        byte[] result = new byte[4];
        RANDOM.nextBytes(result);
        return result;
      }

      @Override
      ManualLoggingSocketWrapper.FactoryBase getLogWrapperFactory() {
        return ManualLoggingSocketWrapper.ANNOTATED;
      }
    };

    /** @return 4-byte array or null */
    abstract byte[] generate();

    abstract ManualLoggingSocketWrapper.FactoryBase getLogWrapperFactory();
  }

  public static Hybi17WsConnection connect(InetSocketAddress endpoint, int timeout,
      String resourceId, MaskStrategy maskStrategy, ConnectionLogger connectionLogger)
      throws IOException {
    ManualLoggingSocketWrapper socketWrapper = new ManualLoggingSocketWrapper(endpoint, timeout,
        connectionLogger, maskStrategy.getLogWrapperFactory());

    boolean handshakeDone = false;
    Exception handshakeException = null;
    try {
      performHandshakeOrFail(socketWrapper, endpoint, resourceId);
      handshakeDone = true;
    } catch (RuntimeException e) {
      handshakeException = e;
      throw e;
    } catch (IOException e) {
      handshakeException = e;
      throw e;
    } finally {
      if (!handshakeDone) {
        socketWrapper.getShutdownRelay().sendSignal(null, handshakeException);
      }
    }

    return new Hybi17WsConnection(socketWrapper, maskStrategy, connectionLogger);
  }

  private final MaskStrategy maskStrategy;

  private Hybi17WsConnection(ManualLoggingSocketWrapper socketWrapper, MaskStrategy maskStrategy,
      ConnectionLogger connectionLogger) {
    super(socketWrapper, connectionLogger);
    this.maskStrategy = maskStrategy;
  }

  @Override
  public void sendTextualMessage(final String message) throws IOException {
    final byte[] bytes = message.getBytes(UTF_8_CHARSET);

    LoggablePayload payload = new LoggablePayload() {
      @Override void send(LoggableOutput output, byte[] maskBytes) throws IOException {
        output.writeToLog(message, "utf-8 demasked");
        if (maskBytes != null) {
          for (int i = 0; i < bytes.length; i++) {
            bytes[i] = (byte) (bytes[i] ^ maskBytes[i % 4]);
          }
        }
        output.writeBytesNoLogging(bytes);
      }
      @Override int getLength() {
        return bytes.length;
      }
    };

    sendMessage(OpCode.TEXT, payload, false);
  }

  @Override
  protected CloseReason runListenLoop(LoggableInput loggableReader)
      throws IOException, InterruptedException {
    try {
      return runListenLoopImpl(loggableReader);
    } catch (IOException e) {
      String stackTrace = BasicUtil.getStacktraceString(e);
      sendClosingMessage(StatusCode.PROTOCOL_ERROR, stackTrace);
      throw new IOException(e);
    } catch (IncomingProtocolException e) {
      String stackTrace = BasicUtil.getStacktraceString(e);
      sendClosingMessage(e.getStatusCode(), stackTrace);
      throw new IOException(e);
    }
  }

  private CloseReason runListenLoopImpl(LoggableInput loggableReader)
      throws IOException, InterruptedException, IncomingProtocolException {
    while (true) {
      loggableReader.markSeparatorForLog();
      int firstByte;
      try {
        firstByte = loggableReader.readByteOrEos();
      } catch (IOException e) {
        if (isClosingGracefully()) {
          return CloseReason.USER_REQUEST;
        } else {
          throw e;
        }
      }
      if (firstByte == -1) {
        if (isClosingGracefully()) {
          return CloseReason.USER_REQUEST;
        } else {
          throw new IOException("Unexpected end of stream");
        }
      }

      if ((firstByte & FrameBits.FIN_BIT) == 0) {
        throw new IncomingProtocolException("Fragments unsupported",
            StatusCode.CANNOT_ACCEPT, null);
      }
      if ((firstByte & FrameBits.RESERVED_MASK) != 0) {
        throw new IncomingProtocolException("Unexpected reserved bits",
            StatusCode.PROTOCOL_ERROR, null);
      }

      int opcode = firstByte & FrameBits.OPCODE_MASK;

      IncomingFrameHandler frameHandler;

      switch (opcode) {
      case OpCode.CONTINUATION:
        throw new IncomingProtocolException("Continuation is not supported",
            StatusCode.CANNOT_ACCEPT, null);
      case OpCode.TEXT:
        frameHandler = IncomingFrameHandler.TEXT_MESSAGE;
        break;
      case OpCode.BINARY:
        throw new IncomingProtocolException("Binary is not supported",
            StatusCode.CANNOT_ACCEPT, null);
      case OpCode.CLOSE:
        sendClosingMessage(StatusCode.NORMAL, null);
        return CloseReason.REMOTE_CLOSE_REQUEST;
      case OpCode.PING:
        frameHandler = IncomingFrameHandler.PING;
        break;
      case OpCode.PONG:
        frameHandler = IncomingFrameHandler.PONG;
        break;
      default:
        throw new IncomingProtocolException("Unsupported opcode " + opcode,
            StatusCode.CANNOT_ACCEPT, null);
      }

      int secondByte = readByteOfFail(loggableReader);

      boolean hasMask = (secondByte & FrameBits.MASK_BIT) != 0;

      if (hasMask) {
        throw new IncomingProtocolException("Masked server-to-client message is not supported",
            StatusCode.PROTOCOL_ERROR, null);
      }

      int payloadLenByte = secondByte & FrameBits.LENGTH_MASK;
      int payloadLen;
      if (payloadLenByte == FrameBits.LENGTH_2_BYTE_CODE) {
        int lengthTemp = readByteOfFail(loggableReader);
        lengthTemp <<= 8;
        lengthTemp += readByteOfFail(loggableReader);
        payloadLen = lengthTemp;
      } else if (payloadLenByte == FrameBits.LENGTH_8_BYTE_CODE) {
        for (int i = 0; i < 4; i++) {
          int b = readByteOfFail(loggableReader);
          if (b != 0) {
            throw new IncomingProtocolException("Payload length is too large",
                StatusCode.CANNOT_ACCEPT, null);
          }
        }
        int lengthTemp = readByteOfFail(loggableReader);
        if ((lengthTemp & FrameBits.HIGH_BIT) != 0) {
          throw new IncomingProtocolException("Payload length is too large",
              StatusCode.CANNOT_ACCEPT, null);
        }
        for (int i = 0; i < 3; i++) {
          lengthTemp <<= 8;
          lengthTemp += readByteOfFail(loggableReader);
        }
        payloadLen = lengthTemp;
      } else {
        payloadLen = payloadLenByte;
      }

      byte [] bytes = loggableReader.readBytes(payloadLen);
      frameHandler.process(bytes, this);
    }
  }

  private static class IncomingProtocolException extends Exception {
    private final int statusCode;

    private IncomingProtocolException(String message, int statusCode, Throwable cause) {
      super(message, cause);
      this.statusCode = statusCode;
    }

    int getStatusCode() {
      return statusCode;
    }
  }

  private static abstract class IncomingFrameHandler {
    abstract void process(byte[] bytes, Hybi17WsConnection hybiWsConnection);

    static final IncomingFrameHandler TEXT_MESSAGE = new IncomingFrameHandler() {
      @Override
      void process(byte[] bytes, Hybi17WsConnection hybiWsConnection) {
        final String text = new String(bytes, UTF_8_CHARSET);
        hybiWsConnection.getDispatchQueue().add(new MessageDispatcher() {
          @Override
          boolean dispatch(Listener userListener) {
            userListener.textMessageRecieved(text);
            return false;
          }
        });
      }
    };

    static final IncomingFrameHandler PING = new IncomingFrameHandler() {
      @Override
      void process(final byte[] bytes, Hybi17WsConnection hybiWsConnection) {
        LoggablePayload payload = new LoggablePayload() {
          @Override
          void send(LoggableOutput output, byte[] maskBytes) throws IOException {
            output.writeBytesToLog(bytes);
            if (maskBytes != null) {
              for (int i = 0; i < bytes.length; i++) {
                bytes[i] = (byte) (bytes[i] ^ maskBytes[i % 4]);
              }
            }
            output.writeBytes(bytes);
            output.markSeparatorForLog();
          }
          @Override int getLength() {
            return bytes.length;
          }
        };
        try {
          // Should we do in this thread or relay it to Dispatch thread?
          hybiWsConnection.sendMessage(OpCode.PONG, payload, false);
        } catch (IOException e) {
          LOGGER.log(Level.WARNING, "Failed to send pong", e);
        }
      }
    };

    static final IncomingFrameHandler PONG = new IncomingFrameHandler() {
      @Override
      void process(byte[] bytes, Hybi17WsConnection hybiWsConnection) {
        // Ignore
      }
    };
  }

  /**
   * Payload that can send and properly log itself. Good logging requires that the body
   * is not masked.
   */
  private static abstract class LoggablePayload {
    abstract void send(LoggableOutput output, byte[] maskBytes) throws IOException;
    abstract int getLength();
  }

  private void sendClosingMessage(final int statusCode, final String message) throws IOException {
    final byte[] bytes;
    if (message == null) {
      bytes = new byte[0];
    } else {
      bytes = message.getBytes(UTF_8_CHARSET);
    }

    LoggablePayload payload = new LoggablePayload() {
      @Override
      void send(LoggableOutput output, byte[] maskBytes) throws IOException {
        byte codeByte1 = (byte) ((statusCode >> 8) & 0xFF);
        byte codeByte2 = (byte) (statusCode & 0xFF);

        byte codeByteMasked1 = codeByte1;
        byte codeByteMasked2 = codeByte2;
        if (maskBytes != null) {
          codeByteMasked1 ^= maskBytes[0];
          codeByteMasked2 ^= maskBytes[1];

          for (int i = 0; i < bytes.length; i++) {
            bytes[i] = (byte) (bytes[i] ^ maskBytes[(i + STATUS_CODE_LENTGH) % 4]);
          }
        }
        output.writeByteNoLogging(codeByteMasked1);
        output.writeByteNoLogging(codeByteMasked2);

        output.writeByteToLog(codeByte1);
        output.writeByteToLog(codeByte2);

        output.writeBytesNoLogging(bytes);
        output.writeToLog(message, "utf-8 demasked");
      }

      @Override int getLength() {
        return STATUS_CODE_LENTGH + bytes.length;
      }
    };

    sendMessage(OpCode.CLOSE, payload, true);
  }

  private void sendMessage(int opCode, LoggablePayload loggablePayload, boolean isClosingMessage)
      throws IOException {
    int length = loggablePayload.getLength();
    LoggableOutput output = getSocketWrapper().getLoggableOutput();

    byte[] maskBytes = maskStrategy.generate();

    synchronized (this) {
      if (isOutputClosed()) {
        throw new IOException("WebSocket is already closed for output");
      }

      byte firstByte = (byte) (FrameBits.FIN_BIT | OpCode.TEXT);

      output.writeByte(firstByte);

      int maskFlag = maskBytes == null ? 0 : FrameBits.MASK_BIT;

      if (length <= 125) {
        output.writeByte((byte) (length | maskFlag));
      } else if (length <= FrameBits.MAX_TWO_BYTE_INT) {
        output.writeByte((byte) (FrameBits.LENGTH_2_BYTE_CODE | maskFlag));
        output.writeByte((byte) ((length >> 8) & 0xFF));
        output.writeByte((byte) (length & 0xFF));
      } else {
        output.writeByte((byte) (FrameBits.LENGTH_8_BYTE_CODE | maskFlag));
        output.writeByte((byte) 0);
        output.writeByte((byte) 0);
        output.writeByte((byte) 0);
        output.writeByte((byte) 0);
        output.writeByte((byte) (length >>> 24));
        output.writeByte((byte) ((length >> 16) & 0xFF));
        output.writeByte((byte) ((length >> 8) & 0xFF));
        output.writeByte((byte) (length & 0xFF));
      }

      if (maskBytes != null) {
        output.writeBytes(maskBytes);
      }
      loggablePayload.send(output, maskBytes);

      if (isClosingMessage) {
        setOutputClosed(true);
      }
    }

    output.markSeparatorForLog();
  }

  private static void performHandshakeOrFail(ManualLoggingSocketWrapper socket,
      InetSocketAddress endpoint, String resourceId) throws IOException {
    Hybi17Handshake.Result result =
        Hybi17Handshake.performHandshake(socket, endpoint, resourceId, RANDOM);
    result.accept(HANDSHAKE_RESULT_VISITOR).get();
  }

  private static final Hybi17Handshake.Result.Visitor<DataOrException<Void>>
      HANDSHAKE_RESULT_VISITOR =
      new Hybi17Handshake.Result.Visitor<DataOrException<Void>>() {
        @Override
        public DataOrException<Void> visitConnected() {
          return new DataOrException<Void>() {
            @Override Void get() throws IOException {
              return null;
            }
          };
        }

        @Override
        public DataOrException<Void> visitUnknownError(final Exception exception) {
          return new DataOrException<Void>() {
            @Override Void get() throws IOException {
              throw new IOException("Failed to establish WebSocket connection", exception);
            }
          };
        }

        @Override
        public DataOrException<Void> visitErrorMessage(final int code,
            final String errorName, final String text) {
          return new DataOrException<Void>() {
            @Override Void get() throws IOException {
              throw new IOException("Failed to establish WebSocket connection: " + code + " " +
                  errorName + " | " + text);
            }
          };
        }
      };

  /**
   * This class is used solely to put IOException through Visitor.
   */
  private static abstract class DataOrException<T> {
    abstract T get() throws IOException;
  }

  private static int readByteOfFail(LoggableInput loggableReader) throws IOException {
    int b = loggableReader.readByteOrEos();
    if (b == -1) {
      throw new IOException("Unexpected EOS");
    }
    return b;
  }

  private interface FrameBits {
    // First byte bits.
    int FIN_BIT = 1 << 7;
    int MASK_BIT = 1 << 7;

    // Second byte bits.
    int OPCODE_LENGTH = 4;
    int OPCODE_MASK = (1 << OPCODE_LENGTH) - 1;
    int RESERVED_MASK = ((1 << 3) - 1) << OPCODE_LENGTH ;

    int LENGTH_MASK = (1 << 7) - 1;
    int LENGTH_2_BYTE_CODE = 126;
    int LENGTH_8_BYTE_CODE = 127;

    // Length bytes.
    int HIGH_BIT = 1 << 7;
    int MAX_TWO_BYTE_INT = 1 << 16 - 1;
  }

  private interface OpCode {
    int CONTINUATION = 0x0;
    int TEXT = 0x1;
    int BINARY = 0x2;
    int CLOSE = 0x8;
    int PING = 0x9;
    int PONG = 0xA;
  }

  private interface StatusCode {
    int NORMAL = 1000;
    int PROTOCOL_ERROR = 1002;
    int CANNOT_ACCEPT = 1003;
  }

  private static final int STATUS_CODE_LENTGH = 2;
}
TOP

Related Classes of org.chromium.sdk.internal.websocket.Hybi17WsConnection

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.