Package org.springframework.integration.x.ip.websocket

Source Code of org.springframework.integration.x.ip.websocket.WebSocketSerializer

/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.x.ip.websocket;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.commons.codec.binary.Base64;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.core.serializer.Serializer;
import org.springframework.integration.MessagingException;
import org.springframework.integration.ip.tcp.serializer.SoftEndOfStreamException;
import org.springframework.integration.x.ip.serializer.AbstractHttpSwitchingDeserializer;
import org.springframework.integration.x.ip.serializer.DataFrame;
import org.springframework.util.Assert;

/**
* @author Gary Russell
* @since 3.0
*
*/
public class WebSocketSerializer extends AbstractHttpSwitchingDeserializer implements Serializer<Object> {

  private static final String HTTP_1_1_101_WEB_SOCKET_PROTOCOL_HANDSHAKE_SPRING_INTEGRATION =
      "HTTP/1.1 101 Web Socket Protocol Handshake - Spring Integration\r\n";

  private static final Set<Short> INVALID_STATUS = new HashSet<Short>(
    Arrays.asList((short) 1004, (short) 1005, (short) 1006, (short) 1012, (short) 1013, (short) 1014, (short) 1015));

  private volatile boolean server;

  private boolean validateUtf8;

  private volatile Boolean streamChecked;

  private volatile boolean nio;

  private volatile DirectFieldAccessor streamAccessor;

  public void setServer(boolean server) {
    this.server = server;
  }

  /**
   * Validate UTF-8 (required for Autobahn tests).
   * @param validateUtf8
   */
  public void setValidateUtf8(boolean validateUtf8) {
    this.validateUtf8 = validateUtf8;
  }

  @Override
  protected DataFrame createDataFrame(int type, String frameData) {
    return new WebSocketFrame(type, frameData);
  }

  @Override
  protected BasicState createState() {
    return new WebSocketState();
  }

  @Override
  public void serialize(final Object frame, OutputStream outputStream)
      throws IOException {
    String data = "";
    WebSocketFrame theFrame = null;
    if (frame instanceof String) {
      data = (String) frame;
      theFrame = new WebSocketFrame(WebSocketFrame.TYPE_DATA, data);
    }
    else if (frame instanceof WebSocketFrame) {
      theFrame = (WebSocketFrame) frame;
      data = theFrame.getPayload();
    }
    if (data != null && data.startsWith("HTTP/1.1")) {
      outputStream.write(data.getBytes());
      return;
    }
    int lenBytes;
    int payloadLen = this.server ? 0 : 0x80; //masked
    boolean close = theFrame.getType() == WebSocketFrame.TYPE_CLOSE;
    boolean ping = theFrame.getType() == WebSocketFrame.TYPE_PING;
    boolean pong = theFrame.getType() == WebSocketFrame.TYPE_PONG;
    byte[] bytes = theFrame.getBinary() != null ? theFrame.getBinary() : data.getBytes("UTF-8");

    int length = bytes.length;
    if (close) {
      length += 2;
    }
    if (length >= Math.pow(2, 16)) {
      lenBytes = 8;
      payloadLen |= 127;
    }
    else if (length > 125) {
      lenBytes = 2;
      payloadLen |= 126;
    }
    else {
      lenBytes = 0;
      payloadLen |= length;
    }
    int mask = (int) System.currentTimeMillis();
    ByteBuffer buffer = ByteBuffer.allocate(length + 6 + lenBytes);
    if (ping) {
      buffer.put((byte) 0x89);
    }
    else if (pong) {
      buffer.put((byte) 0x8a);
    }
    else if (close) {
      buffer.put((byte) 0x88);
    }
    else if (theFrame.getType() == WebSocketFrame.TYPE_DATA_BINARY) {
      buffer.put((byte) 0x82);
    }
    else {
      // Final fragment; text
      buffer.put((byte) 0x81);
    }
    buffer.put((byte) payloadLen);
    if (lenBytes == 2) {
      buffer.putShort((short) length);
    }
    else if (lenBytes == 8) {
      buffer.putLong(length);
    }

    byte[] maskBytes = new byte[4];
    if (!server) {
      buffer.putInt(mask);
      buffer.position(buffer.position() - 4);
      buffer.get(maskBytes);
    }
    if (close) {
      buffer.putShort(theFrame.getStatus());
      // TODO: mask status when client
    }
    for (int i = 0; i < bytes.length; i++) {
      if (server) {
        buffer.put(bytes[i]);
      }
      else {
        buffer.put((byte) (bytes[i] ^ maskBytes[i % 4]));
      }
    }
    outputStream.write(buffer.array(), 0, buffer.position());
  }

  @Override
  public DataFrame deserialize(InputStream inputStream) throws IOException {
    if (this.streamChecked == null) {
      this.nio = inputStream.getClass().getName().endsWith("TcpNioConnection$ChannelInputStream");
      this.streamAccessor = new DirectFieldAccessor(inputStream);
      this.streamChecked = Boolean.TRUE;
    }
    DataFrame frame = null;
    BasicState state = this.getState(inputStream);
    if (state != null) {
      frame = state.getPendingFrame();
    }
    while (frame == null || (frame.getPayload() == null && frame.getBinary() == null)) {
      frame = doDeserialize(inputStream, frame);
      if (frame.getPayload() == null && frame.getBinary() == null) {
        state.setPendingFrame(frame);
      }
    }
    return frame;
  }

  private DataFrame doDeserialize(InputStream inputStream, DataFrame protoFrame) throws IOException {
    List<DataFrame> headers = checkStreaming(inputStream);
    if (headers != null) {
      return headers.get(0);
    }
    int bite;
    if (logger.isDebugEnabled()) {
      logger.debug("Available to read:" + inputStream.available());
    }
    boolean done = false;
    int len = 0;
    int n = 0;
    int dataInx = 0;
    byte[] buffer = null;
    boolean fin = false;
    boolean ping = false;
    boolean pong = false;
    boolean close = false;
    boolean binary = false;
    boolean invalid = false;
    String invalidText = null;
    boolean fragmentedControl = false;
    int lenBytes = 0;
    byte[] mask = new byte[4];
    int maskInx = 0;
    int rsv = 0;
    while (!done ) {
      bite = inputStream.read();
//      logger.debug("Read:" + Integer.toHexString(bite));
      if (this.nio) {
        bite = checkclosed(bite, inputStream);
      }
      if (bite < 0 && n == 0) {
        throw new SoftEndOfStreamException("Stream closed between payloads");
      }
      checkClosure(bite);
      switch (n++) {
      case 0:
        fin = (bite & 0x80) > 0;
        rsv = (bite & 0x70) >> 4;
        bite &= 0x0f;
        switch (bite) {
        case 0x00:
          logger.debug("Continuation, fin=" + fin);
          if (protoFrame == null) {
            invalid = true;
            invalidText = "Unexpected continuation frame";
          }
          else {
            binary = protoFrame.getType() == WebSocketFrame.TYPE_DATA_BINARY;
          }
          this.getState(inputStream).setPendingFrame(null);
          break;
        case 0x01:
          logger.debug("Text, fin=" + fin);
          if (protoFrame != null) {
            invalid = true;
            invalidText = "Expected continuation frame";
          }
          break;
        case 0x02:
          logger.debug("Binary, fin=" + fin);
          if (protoFrame != null) {
            invalid = true;
            invalidText = "Expected continuation frame";
          }
          binary = true;
          break;
        case 0x08:
          logger.debug("Close, fin=" + fin);
          fragmentedControl = !fin;
          close = true;
          break;
        case 0x09:
          ping = true;
          binary = true;
          fragmentedControl = !fin;
          logger.debug("Ping, fin=" + fin);
          break;
        case 0x0a:
          pong = true;
          fragmentedControl = !fin;
          logger.debug("Pong, fin=" + fin);
          break;
        case 0x03:
        case 0x04:
        case 0x05:
        case 0x06:
        case 0x07:
        case 0x0b:
        case 0x0c:
        case 0x0d:
        case 0x0e:
        case 0x0f:
          invalid = true;
          invalidText = "Reserved opcode " + Integer.toHexString(bite);
          break;
        default:
          throw new IOException("Unexpected opcode " + Integer.toHexString(bite));
        }
        break;
      case 1:
        if (this.server) {
          if ((bite & 0x80) == 0) {
            throw new IOException("Illegal: Expected masked data from client");
          }
          bite &= 0x7f;
        }
        if ((bite & 0x80) > 0) {
          throw new IOException("Illegal: Received masked data from server");
        }
        if (bite < 126) {
          len = bite;
          buffer = new byte[len];
        }
        else if (bite == 126) {
          lenBytes = 2;
        }
        else {
          lenBytes = 8;
        }
        break;
      case 2:
      case 3:
      case 4:
      case 5:
        if (lenBytes > 4 && bite != 0) {
          throw new IOException("Max supported length exceeded");
        }
      case 6:
        if (lenBytes > 3 && (bite & 0x80) > 0) {
          throw new IOException("Max supported length exceeded");
        }
      case 7:
      case 8:
      case 9:
        if (lenBytes-- > 0) {
          len = len << 8 | (bite & 0xff);
          if (lenBytes == 0) {
            buffer = new byte[len];
          }
          break;
        }
      default:
        if (this.server && maskInx < 4) {
          mask[maskInx++] = (byte) bite;
        }
        else {
          if (this.server) {
            bite ^= mask[dataInx % 4];
          }
          buffer[dataInx++] = (byte) bite;
        }
        done = (server ? maskInx == 4 : true) && dataInx >= len;
      }
    };

    WebSocketFrame frame;

    if (fragmentedControl) {
      frame = new WebSocketFrame(WebSocketFrame.TYPE_FRAGMENTED_CONTROL, "Fragmented control frame", buffer);
    }
    else if (invalid) {
      frame = new WebSocketFrame(WebSocketFrame.TYPE_INVALID, invalidText, buffer);
    }
    else if (!fin) {
      List<byte[]> fragments = this.getState(inputStream).getFragments();
      fragments.add(buffer);
      logger.debug("Fragment");
      return new WebSocketFrame(binary ? WebSocketFrame.TYPE_DATA_BINARY : WebSocketFrame.TYPE_DATA, (String) null);
    }
    else if (ping) {
      frame = new WebSocketFrame(WebSocketFrame.TYPE_PING, buffer);
    }
    else if (pong) {
      String data = new String(buffer, "UTF-8");
      frame = new WebSocketFrame(WebSocketFrame.TYPE_PONG, data);
    }
    else if (close) {
      String data = new String(buffer, "UTF-8");
      if (data.length() >= 2) {
        data = data.substring(2);
      }
      WebSocketFrame closeFrame = new WebSocketFrame(WebSocketFrame.TYPE_CLOSE, data);
      short status = 1000;
      if (buffer.length >= 2) {
        status = (short) ((buffer[0] << 8) | (buffer[1] & 0xff));
        closeFrame.setStatus(status);
      }
      if (buffer.length == 1 || buffer.length > 125 ||
          (buffer.length > 2 && !validateUtf8IfNecessary(buffer, 2, data)) ||
          status < 1000 || INVALID_STATUS.contains(status) || (status >= 1016 && status < 3000) || status >= 5000) {
        // Simply close in this case; no close reply
        ((WebSocketState) this.getState(inputStream)).setCloseInitiated(true);
      }
      frame = closeFrame;
    }
    else {
      List<byte[]> fragments = this.getState(inputStream).getFragments();
      if (fragments.size() == 0) {
        if (binary) {
          frame = new WebSocketFrame(WebSocketFrame.TYPE_DATA_BINARY, buffer);
        }
        else {
          String data = new String(buffer, "UTF-8");
          if (!validateUtf8IfNecessary(buffer, 0, data)) {
            frame = new WebSocketFrame(WebSocketFrame.TYPE_INVALID_UTF8, "Invalid UTF-8", buffer);
          }
          else {
            frame = new WebSocketFrame(WebSocketFrame.TYPE_DATA, data);
          }
        }
      }
      else {
        fragments.add(buffer);
        int utf8Len = 0;
        for (byte[] fragment : fragments) {
          utf8Len += fragment.length;
        }
        byte[] reconstructed = new byte[utf8Len];
        int utf8Pos = 0;
        for (byte[] fragment : fragments) {
          System.arraycopy(fragment, 0, reconstructed, utf8Pos, fragment.length);
          utf8Pos += fragment.length;
        }
        fragments.clear();
        if (binary) {
          frame = new WebSocketFrame(WebSocketFrame.TYPE_DATA_BINARY, reconstructed);
        }
        else {
          String data = new String(reconstructed, "UTF-8");
          if (!validateUtf8IfNecessary(reconstructed, 0, data)) {
            frame = new WebSocketFrame(WebSocketFrame.TYPE_INVALID_UTF8, "Invalid UTF-8", reconstructed);
          }
          else {
            frame = new WebSocketFrame(WebSocketFrame.TYPE_DATA, data);
          }
        }
      }
    }
    if (rsv > 0) {
      frame.setRsv(rsv);
    }
    return frame;
  }

  /**
   * TODO: workaround for INT-2936
   */
  private int checkclosed(int bite, InputStream inputStream) {
    if (bite < 0) { // possibly a closed stream
      try {
        if ((Boolean) streamAccessor.getPropertyValue("isClosed") &&
            inputStream.available() == 0) {
          return -1;
        }
        else {
          return bite & 0xff;
        }
      }
      catch (Exception e) {
        if (logger.isDebugEnabled()) {
          logger.debug("Failed to check closed", e);
        }
        return bite;
      }
    }
    else {
      return bite;
    }
  }

  private boolean validateUtf8IfNecessary(byte[] buffer, int offset, String data) {
    if (this.validateUtf8) {
      try {
        byte[] bytes = data.getBytes("UTF-8");
        if (bytes.length != buffer.length - offset) {
          return false;
        }
        for (int i = 0; i < bytes.length; i++) {
          if (buffer[i + offset] != bytes[i]) {
            return false;
          }
        }
      }
      catch (UnsupportedEncodingException e) {
        throw new MessagingException("UTF-8 Conversion error");
      }
    }
    return true;
  }

  @Override
  protected void checkClosure(int bite) throws IOException {
    if (bite < 0) {
      logger.debug("Socket closed during message assembly");
      throw new IOException("Socket closed during message assembly");
    }
  }

  @Override
  public void removeState(Object inputStream) {
    super.removeState(inputStream);
  }

  public WebSocketFrame generateHandshake(WebSocketFrame frame) throws Exception {
    Assert.isTrue(frame.getType() == WebSocketFrame.TYPE_HEADERS, "Expected headers:" + frame);
    String[] headers = frame.getPayload().split("\\r\\n");
    String key = null;
    String version = null;
    for (String header : headers) {
      if (header.toLowerCase().startsWith("sec-websocket-key")) {
        key = header.split(":")[1].trim();
      }
      else if (header.toLowerCase().startsWith("sec-websocket-version")) {
        version = header.split(":")[1].trim();
      }
    }
    if (key == null) {
      throw new WebSocketUpgradeException("400 Bad Request: No sec-websocket-key header detected");
    }
    else if (!"13".equals(version)) {
      throw new WebSocketUpgradeException("426 Upgrade Required", "sec-websocket-version: 13\r\n");
    }
    String handshake = HTTP_1_1_101_WEB_SOCKET_PROTOCOL_HANDSHAKE_SPRING_INTEGRATION +
               "Upgrade: WebSocket\r\n" +
               "Connection: Upgrade\r\n" +
               "Sec-WebSocket-Accept: " + this.generateWebSocketAccept(key) + "\r\n\r\n";
    return new WebSocketFrame(WebSocketFrame.TYPE_DATA, handshake);
  }

  private String generateWebSocketAccept(String key) throws NoSuchAlgorithmException  {
    MessageDigest md = MessageDigest.getInstance("SHA-1");
    String toDigest = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    byte[] acceptStringBytes  = md.digest(toDigest.getBytes());
    acceptStringBytes = Base64.encodeBase64(acceptStringBytes);
    String acceptString = new String(acceptStringBytes);
    return acceptString;
  }

  public static class WebSocketState extends BasicState {

    private volatile boolean closeInitiated;

    private volatile boolean expectingPong;

    public boolean isCloseInitiated() {
      return this.closeInitiated;
    }

    public void setCloseInitiated(boolean closeInitiated) {
      this.closeInitiated = closeInitiated;
    }

    public boolean isExpectingPong() {
      return this.expectingPong;
    }

    public void setExpectingPong(boolean expectingPong) {
      this.expectingPong = expectingPong;
    }

  }
}
TOP

Related Classes of org.springframework.integration.x.ip.websocket.WebSocketSerializer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.