/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.integration.x.ip.websocket;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.codec.binary.Base64;
import org.springframework.beans.DirectFieldAccessor;
import org.springframework.core.serializer.Serializer;
import org.springframework.integration.MessagingException;
import org.springframework.integration.ip.tcp.serializer.SoftEndOfStreamException;
import org.springframework.integration.x.ip.serializer.AbstractHttpSwitchingDeserializer;
import org.springframework.integration.x.ip.serializer.DataFrame;
import org.springframework.util.Assert;
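/**
* Serializer and deserializer for WebSocket frames. Outbound objects are written as
* single frames (or as raw HTTP text when the payload is an HTTP/1.1 response); inbound
* bytes are assembled into {@link WebSocketFrame}s, including reassembly of fragmented
* messages.
*
* @author Gary Russell
* @since 3.0
*
*/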
public class WebSocketSerializer extends AbstractHttpSwitchingDeserializer implements Serializer<Object> {
/** Status line placed at the start of the upgrade response built by generateHandshake(). */
private static final String HTTP_1_1_101_WEB_SOCKET_PROTOCOL_HANDSHAKE_SPRING_INTEGRATION =
"HTTP/1.1 101 Web Socket Protocol Handshake - Spring Integration\r\n";
/** Close status codes that doDeserialize() treats as invalid when received in a close frame. */
private static final Set<Short> INVALID_STATUS = new HashSet<Short>(
Arrays.asList((short) 1004, (short) 1005, (short) 1006, (short) 1012, (short) 1013, (short) 1014, (short) 1015));
/** When true, outbound frames are written unmasked and inbound frames are required to be masked. */
private volatile boolean server;
/** When true, decoded text is re-encoded and compared with the received bytes. */
private boolean validateUtf8;
/** Null until the first deserialize() call has inspected its input stream. */
private volatile Boolean streamChecked;
/** Set by the first deserialize() call: whether that stream's class name ends with "TcpNioConnection$ChannelInputStream". */
private volatile boolean nio;
/** Field accessor wrapping the stream passed to the first deserialize() call; read by checkclosed(). */
private volatile DirectFieldAccessor streamAccessor;
/**
 * Select which side of the connection this serializer acts as.
 * @param server true if frames are written unmasked and received frames must be
 * masked; false if frames are written masked and received frames must be unmasked.
 */
public void setServer(final boolean server) {
	this.server = server;
}
/**
 * Switch UTF-8 validation of received text on or off (required for Autobahn tests).
 * @param validateUtf8 true to check that received text bytes survive a decode/encode
 * round trip unchanged.
 */
public void setValidateUtf8(final boolean validateUtf8) {
	this.validateUtf8 = validateUtf8;
}
/**
 * Create the frame implementation used by this serializer.
 * @param type the frame type.
 * @param frameData the frame's text payload.
 * @return a new {@link WebSocketFrame} holding the supplied type and payload.
 */
@Override
protected DataFrame createDataFrame(int type, String frameData) {
	final DataFrame created = new WebSocketFrame(type, frameData);
	return created;
}
/**
 * Create the state object used by this serializer.
 * @return a new {@link WebSocketState}.
 */
@Override
protected BasicState createState() {
	final BasicState created = new WebSocketState();
	return created;
}
/** Source of masking keys; RFC 6455 section 5.3 requires them to be unpredictable. */
private static final java.security.SecureRandom MASK_GENERATOR = new java.security.SecureRandom();

/**
 * Write one unfragmented WebSocket frame to the stream.
 * <p>A String is sent as a text frame. A {@link WebSocketFrame} is sent according to
 * its type (ping, pong, close, binary; anything else as text), using its binary
 * content when present and otherwise its payload encoded as UTF-8. A payload that
 * starts with "HTTP/1.1" is written as-is, without framing (handshake responses).
 * <p>When not acting as a server, the whole payload - including the status code of a
 * close frame - is masked with a freshly generated key.
 * @param frame a String or a {@link WebSocketFrame}.
 * @param outputStream the stream to write to.
 * @throws IOException if the write fails.
 * @throws IllegalArgumentException if frame is null or of an unsupported type.
 */
@Override
public void serialize(final Object frame, OutputStream outputStream)
		throws IOException {
	final String data;
	final WebSocketFrame theFrame;
	if (frame instanceof String) {
		data = (String) frame;
		theFrame = new WebSocketFrame(WebSocketFrame.TYPE_DATA, data);
	}
	else if (frame instanceof WebSocketFrame) {
		theFrame = (WebSocketFrame) frame;
		data = theFrame.getPayload();
	}
	else {
		throw new IllegalArgumentException("Expected a String or WebSocketFrame, received: "
				+ (frame == null ? "null" : frame.getClass().getName()));
	}
	if (data != null && data.startsWith("HTTP/1.1")) {
		// Raw HTTP response; explicit charset so the output does not depend on the platform default
		outputStream.write(data.getBytes("UTF-8"));
		return;
	}

	final boolean asServer = this.server;
	final int type = theFrame.getType();
	final boolean close = type == WebSocketFrame.TYPE_CLOSE;

	byte[] bytes = theFrame.getBinary();
	if (bytes == null) {
		// A frame with neither binary nor text content is sent with an empty body
		bytes = data != null ? data.getBytes("UTF-8") : new byte[0];
	}
	// The payload of a close frame starts with the 2-byte status code
	final int length = bytes.length + (close ? 2 : 0);

	final int lenBytes;
	final int lengthCode;
	if (length > 0xFFFF) {
		lenBytes = 8;
		lengthCode = 127;
	}
	else if (length > 125) {
		lenBytes = 2;
		lengthCode = 126;
	}
	else {
		lenBytes = 0;
		lengthCode = length;
	}

	// FIN bit is always set: frames are never fragmented on output
	final int firstByte;
	if (type == WebSocketFrame.TYPE_PING) {
		firstByte = 0x89;
	}
	else if (type == WebSocketFrame.TYPE_PONG) {
		firstByte = 0x8a;
	}
	else if (close) {
		firstByte = 0x88;
	}
	else if (type == WebSocketFrame.TYPE_DATA_BINARY) {
		firstByte = 0x82;
	}
	else {
		// Final fragment; text
		firstByte = 0x81;
	}

	ByteBuffer buffer = ByteBuffer.allocate(2 + lenBytes + (asServer ? 0 : 4) + length);
	buffer.put((byte) firstByte);
	buffer.put((byte) ((asServer ? 0 : 0x80) | lengthCode));
	if (lenBytes == 2) {
		buffer.putShort((short) length);
	}
	else if (lenBytes == 8) {
		buffer.putLong(length);
	}

	byte[] maskBytes = null;
	if (!asServer) {
		maskBytes = new byte[4];
		MASK_GENERATOR.nextBytes(maskBytes);
		buffer.put(maskBytes);
	}

	final int payloadStart = buffer.position();
	if (close) {
		buffer.putShort(theFrame.getStatus());
	}
	buffer.put(bytes);

	if (maskBytes != null) {
		// Mask the complete payload in place so the key index runs over status + data
		byte[] raw = buffer.array();
		for (int i = 0; i < length; i++) {
			raw[payloadStart + i] ^= maskBytes[i % 4];
		}
	}
	outputStream.write(buffer.array(), 0, buffer.position());
}
/**
 * Read from the stream until a complete frame (one carrying a payload or binary
 * content) is available.
 * <p>On the first call the stream type is inspected and remembered. A frame returned
 * by doDeserialize() without content is a fragment placeholder: it is recorded as the
 * pending frame in the stream's state, when a state exists, and reading continues
 * with that frame as the prototype for the following continuation frames.
 * @param inputStream the stream to read from.
 * @return the assembled frame.
 * @throws IOException if reading fails or the stream closes.
 */
@Override
public DataFrame deserialize(InputStream inputStream) throws IOException {
	if (this.streamChecked == null) {
		this.nio = inputStream.getClass().getName().endsWith("TcpNioConnection$ChannelInputStream");
		this.streamAccessor = new DirectFieldAccessor(inputStream);
		this.streamChecked = Boolean.TRUE;
	}
	BasicState state = this.getState(inputStream);
	DataFrame frame = state != null ? state.getPendingFrame() : null;
	while (frame == null || (frame.getPayload() == null && frame.getBinary() == null)) {
		frame = doDeserialize(inputStream, frame);
		if (frame.getPayload() == null && frame.getBinary() == null) {
			if (state == null) {
				// No state was available before reading; look again rather than dereference null
				state = this.getState(inputStream);
			}
			if (state != null) {
				state.setPendingFrame(frame);
			}
		}
	}
	return frame;
}
/**
 * Read a single frame from the stream, or return the first frame supplied by
 * checkStreaming() when that returns a non-null list.
 * <p>The result is one of:
 * <ul>
 * <li>TYPE_FRAGMENTED_CONTROL - a close, ping or pong frame without the FIN bit;</li>
 * <li>TYPE_INVALID - a reserved opcode, a continuation with no message in progress,
 * or a new text/binary frame while a message is in progress;</li>
 * <li>a data frame with null content - a non-final fragment, whose bytes have been
 * appended to the fragment list in the stream's state;</li>
 * <li>TYPE_PING, TYPE_PONG or TYPE_CLOSE - a complete control frame;</li>
 * <li>TYPE_DATA or TYPE_DATA_BINARY - a complete message, with any stored fragments
 * prepended; or TYPE_INVALID_UTF8 when text fails validation.</li>
 * </ul>
 * @param inputStream the stream to read from.
 * @param protoFrame the placeholder for a fragmented message in progress, or null.
 * @return the frame read.
 * @throws SoftEndOfStreamException if the stream ends before the first byte of a frame.
 * @throws IOException if the stream ends mid-frame, masking does not match the
 * configured role, or the declared length is not supported.
 */
private DataFrame doDeserialize(InputStream inputStream, DataFrame protoFrame) throws IOException {
	List<DataFrame> headers = checkStreaming(inputStream);
	if (headers != null) {
		return headers.get(0);
	}
	int bite;
	if (logger.isDebugEnabled()) {
		logger.debug("Available to read:" + inputStream.available());
	}
	boolean done = false;
	int len = 0;
	int n = 0;
	int dataInx = 0;
	byte[] buffer = null;
	boolean fin = false;
	boolean ping = false;
	boolean pong = false;
	boolean close = false;
	boolean binary = false;
	boolean invalid = false;
	String invalidText = null;
	boolean fragmentedControl = false;
	int lenBytes = 0;
	byte[] mask = new byte[4];
	int maskInx = 0;
	int rsv = 0;
	while (!done) {
		bite = inputStream.read();
		if (this.nio) {
			bite = checkclosed(bite, inputStream);
		}
		if (bite < 0 && n == 0) {
			throw new SoftEndOfStreamException("Stream closed between payloads");
		}
		checkClosure(bite);
		// n is the index of the byte within the frame
		switch (n++) {
			case 0:
				// FIN, RSV1-3 and opcode
				fin = (bite & 0x80) > 0;
				rsv = (bite & 0x70) >> 4;
				bite &= 0x0f;
				switch (bite) {
					case 0x00:
						logger.debug("Continuation, fin=" + fin);
						if (protoFrame == null) {
							invalid = true;
							invalidText = "Unexpected continuation frame";
						}
						else {
							binary = protoFrame.getType() == WebSocketFrame.TYPE_DATA_BINARY;
						}
						this.getState(inputStream).setPendingFrame(null);
						break;
					case 0x01:
						logger.debug("Text, fin=" + fin);
						if (protoFrame != null) {
							invalid = true;
							invalidText = "Expected continuation frame";
						}
						break;
					case 0x02:
						logger.debug("Binary, fin=" + fin);
						if (protoFrame != null) {
							invalid = true;
							invalidText = "Expected continuation frame";
						}
						binary = true;
						break;
					case 0x08:
						logger.debug("Close, fin=" + fin);
						fragmentedControl = !fin;
						close = true;
						break;
					case 0x09:
						ping = true;
						binary = true;
						fragmentedControl = !fin;
						logger.debug("Ping, fin=" + fin);
						break;
					case 0x0a:
						pong = true;
						fragmentedControl = !fin;
						logger.debug("Pong, fin=" + fin);
						break;
					case 0x03:
					case 0x04:
					case 0x05:
					case 0x06:
					case 0x07:
					case 0x0b:
					case 0x0c:
					case 0x0d:
					case 0x0e:
					case 0x0f:
						invalid = true;
						invalidText = "Reserved opcode " + Integer.toHexString(bite);
						break;
					default:
						throw new IOException("Unexpected opcode " + Integer.toHexString(bite));
				}
				break;
			case 1:
				// Mask bit and 7-bit length (or a marker for an extended length)
				if (this.server) {
					if ((bite & 0x80) == 0) {
						throw new IOException("Illegal: Expected masked data from client");
					}
					bite &= 0x7f;
				}
				if ((bite & 0x80) > 0) {
					throw new IOException("Illegal: Received masked data from server");
				}
				if (bite < 126) {
					len = bite;
					buffer = new byte[len];
					// An unmasked, empty frame is complete here; reading on would consume
					// a byte of the next frame and overrun the empty buffer
					done = !this.server && len == 0;
				}
				else if (bite == 126) {
					lenBytes = 2;
				}
				else {
					lenBytes = 8;
				}
				break;
			// Bytes 2-9 hold the extended length when lenBytes > 0; otherwise every
			// case falls through to the mask/payload handling in default
			case 2:
			case 3:
			case 4:
			case 5:
				// Upper four bytes of an 8-byte length must be zero
				if (lenBytes > 4 && bite != 0) {
					throw new IOException("Max supported length exceeded");
				}
			case 6:
				// Top bit of the lower four bytes would make the int length negative
				if (lenBytes > 3 && (bite & 0x80) > 0) {
					throw new IOException("Max supported length exceeded");
				}
			case 7:
			case 8:
			case 9:
				if (lenBytes-- > 0) {
					len = len << 8 | (bite & 0xff);
					if (lenBytes == 0) {
						buffer = new byte[len];
						// Same empty-frame completion rule as for the 7-bit length
						done = !this.server && len == 0;
					}
					break;
				}
			default:
				if (this.server && maskInx < 4) {
					mask[maskInx++] = (byte) bite;
				}
				else {
					if (this.server) {
						bite ^= mask[dataInx % 4];
					}
					buffer[dataInx++] = (byte) bite;
				}
				done = (server ? maskInx == 4 : true) && dataInx >= len;
		}
	}
	WebSocketFrame frame;
	if (fragmentedControl) {
		frame = new WebSocketFrame(WebSocketFrame.TYPE_FRAGMENTED_CONTROL, "Fragmented control frame", buffer);
	}
	else if (invalid) {
		frame = new WebSocketFrame(WebSocketFrame.TYPE_INVALID, invalidText, buffer);
	}
	else if (!fin) {
		List<byte[]> fragments = this.getState(inputStream).getFragments();
		fragments.add(buffer);
		logger.debug("Fragment");
		return new WebSocketFrame(binary ? WebSocketFrame.TYPE_DATA_BINARY : WebSocketFrame.TYPE_DATA, (String) null);
	}
	else if (ping) {
		frame = new WebSocketFrame(WebSocketFrame.TYPE_PING, buffer);
	}
	else if (pong) {
		String data = new String(buffer, "UTF-8");
		frame = new WebSocketFrame(WebSocketFrame.TYPE_PONG, data);
	}
	else if (close) {
		// The first two payload bytes are the binary status code; only the remainder
		// is UTF-8 text, so decode from offset 2 instead of trimming decoded chars
		String data = buffer.length > 2 ? new String(buffer, 2, buffer.length - 2, "UTF-8") : "";
		WebSocketFrame closeFrame = new WebSocketFrame(WebSocketFrame.TYPE_CLOSE, data);
		short status = 1000;
		if (buffer.length >= 2) {
			status = (short) ((buffer[0] << 8) | (buffer[1] & 0xff));
			closeFrame.setStatus(status);
		}
		if (buffer.length == 1 || buffer.length > 125 ||
				(buffer.length > 2 && !validateUtf8IfNecessary(buffer, 2, data)) ||
				status < 1000 || INVALID_STATUS.contains(status) || (status >= 1016 && status < 3000) || status >= 5000) {
			// Simply close in this case; no close reply
			((WebSocketState) this.getState(inputStream)).setCloseInitiated(true);
		}
		frame = closeFrame;
	}
	else {
		List<byte[]> fragments = this.getState(inputStream).getFragments();
		if (fragments.size() == 0) {
			if (binary) {
				frame = new WebSocketFrame(WebSocketFrame.TYPE_DATA_BINARY, buffer);
			}
			else {
				String data = new String(buffer, "UTF-8");
				if (!validateUtf8IfNecessary(buffer, 0, data)) {
					frame = new WebSocketFrame(WebSocketFrame.TYPE_INVALID_UTF8, "Invalid UTF-8", buffer);
				}
				else {
					frame = new WebSocketFrame(WebSocketFrame.TYPE_DATA, data);
				}
			}
		}
		else {
			// Final fragment: concatenate the stored fragments and this one
			fragments.add(buffer);
			int utf8Len = 0;
			for (byte[] fragment : fragments) {
				utf8Len += fragment.length;
			}
			byte[] reconstructed = new byte[utf8Len];
			int utf8Pos = 0;
			for (byte[] fragment : fragments) {
				System.arraycopy(fragment, 0, reconstructed, utf8Pos, fragment.length);
				utf8Pos += fragment.length;
			}
			fragments.clear();
			if (binary) {
				frame = new WebSocketFrame(WebSocketFrame.TYPE_DATA_BINARY, reconstructed);
			}
			else {
				String data = new String(reconstructed, "UTF-8");
				if (!validateUtf8IfNecessary(reconstructed, 0, data)) {
					frame = new WebSocketFrame(WebSocketFrame.TYPE_INVALID_UTF8, "Invalid UTF-8", reconstructed);
				}
				else {
					frame = new WebSocketFrame(WebSocketFrame.TYPE_DATA, data);
				}
			}
		}
	}
	if (rsv > 0) {
		frame.setRsv(rsv);
	}
	return frame;
}
/**
 * TODO: workaround for INT-2936.
 * <p>Interprets a negative value returned by read(): it is reported as end of stream
 * (-1) only when the stream's "isClosed" field is true and nothing is available to
 * read; otherwise it is converted to its unsigned byte value. Non-negative values are
 * returned untouched, as is the original value when the check itself fails.
 * <p>NOTE(review): the accessor is created once, for the stream passed to the first
 * deserialize() call, so "isClosed" may be read from a different stream than
 * inputStream - confirm whether this serializer is shared between connections.
 * @param bite the value returned by read().
 * @param inputStream the stream the value was read from.
 * @return the value to process.
 */
private int checkclosed(int bite, InputStream inputStream) {
	if (bite >= 0) {
		return bite;
	}
	// possibly a closed stream
	try {
		boolean closed = (Boolean) streamAccessor.getPropertyValue("isClosed");
		if (closed && inputStream.available() == 0) {
			return -1;
		}
		return bite & 0xff;
	}
	catch (Exception e) {
		if (logger.isDebugEnabled()) {
			logger.debug("Failed to check closed", e);
		}
		return bite;
	}
}
/**
 * When UTF-8 validation is enabled, check that the decoded text re-encodes to exactly
 * the bytes that were received; malformed input does not survive that round trip.
 * @param buffer the received bytes.
 * @param offset the index in buffer at which the text starts.
 * @param data the text decoded from buffer, starting at offset.
 * @return true if validation is disabled or the bytes match; false otherwise.
 * @throws MessagingException if the UTF-8 charset is unavailable.
 */
private boolean validateUtf8IfNecessary(byte[] buffer, int offset, String data) {
	if (!this.validateUtf8) {
		return true;
	}
	final byte[] bytes;
	try {
		bytes = data.getBytes("UTF-8");
	}
	catch (UnsupportedEncodingException e) {
		// Keep the cause so the failure can be diagnosed
		throw new MessagingException("UTF-8 Conversion error", e);
	}
	if (bytes.length != buffer.length - offset) {
		return false;
	}
	for (int i = 0; i < bytes.length; i++) {
		if (buffer[i + offset] != bytes[i]) {
			return false;
		}
	}
	return true;
}
/**
 * Fail if the value returned by a read indicates the end of the stream.
 * @param bite the value returned by read().
 * @throws IOException if bite is negative.
 */
@Override
protected void checkClosure(int bite) throws IOException {
	if (bite >= 0) {
		return;
	}
	final String reason = "Socket closed during message assembly";
	logger.debug(reason);
	throw new IOException(reason);
}
/**
 * Delegates unchanged to the superclass implementation.
 * NOTE(review): presumably re-declared only to make the method public - confirm
 * against the superclass.
 * @param inputStream the key whose state is to be removed.
 */
@Override
public void removeState(final Object inputStream) {
	super.removeState(inputStream);
}
/**
 * Build the HTTP 101 response that completes the WebSocket upgrade.
 * <p>The Sec-WebSocket-Key and Sec-WebSocket-Version headers are located by exact,
 * case-insensitive name; lines without a colon (such as the request line) are
 * ignored. If a header occurs more than once, the last occurrence wins.
 * @param frame a TYPE_HEADERS frame whose payload holds the CRLF-separated request headers.
 * @return a frame whose payload is the complete handshake response.
 * @throws WebSocketUpgradeException if the key header is missing or empty (400), or
 * the version is not 13 (426).
 * @throws Exception if the accept value cannot be generated.
 */
public WebSocketFrame generateHandshake(WebSocketFrame frame) throws Exception {
	Assert.isTrue(frame.getType() == WebSocketFrame.TYPE_HEADERS, "Expected headers:" + frame);
	String[] headers = frame.getPayload().split("\\r\\n");
	String key = null;
	String version = null;
	for (String header : headers) {
		int colon = header.indexOf(':');
		if (colon < 0) {
			continue;
		}
		String name = header.substring(0, colon).trim();
		String value = header.substring(colon + 1).trim();
		// equalsIgnoreCase is locale-independent and, unlike a prefix match, does not
		// accept the legacy Sec-WebSocket-Key1/Key2 headers
		if ("sec-websocket-key".equalsIgnoreCase(name)) {
			key = value;
		}
		else if ("sec-websocket-version".equalsIgnoreCase(name)) {
			version = value;
		}
	}
	if (key == null || key.length() == 0) {
		throw new WebSocketUpgradeException("400 Bad Request: No sec-websocket-key header detected");
	}
	else if (!"13".equals(version)) {
		throw new WebSocketUpgradeException("426 Upgrade Required", "sec-websocket-version: 13\r\n");
	}
	String handshake = HTTP_1_1_101_WEB_SOCKET_PROTOCOL_HANDSHAKE_SPRING_INTEGRATION +
			"Upgrade: WebSocket\r\n" +
			"Connection: Upgrade\r\n" +
			"Sec-WebSocket-Accept: " + this.generateWebSocketAccept(key) + "\r\n\r\n";
	return new WebSocketFrame(WebSocketFrame.TYPE_DATA, handshake);
}
/**
 * Compute the Sec-WebSocket-Accept value for a client key: the Base64 encoding of the
 * SHA-1 digest of the key concatenated with the fixed WebSocket GUID. For example,
 * RFC 6455 maps the key "dGhlIHNhbXBsZSBub25jZQ==" to "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".
 * @param key the value of the client's Sec-WebSocket-Key header.
 * @return the accept value.
 * @throws NoSuchAlgorithmException if SHA-1 is unavailable.
 */
private String generateWebSocketAccept(String key) throws NoSuchAlgorithmException {
	// Explicit charset: the result must not depend on the platform default encoding
	final java.nio.charset.Charset utf8 = java.nio.charset.Charset.forName("UTF-8");
	MessageDigest md = MessageDigest.getInstance("SHA-1");
	String toDigest = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
	byte[] digest = md.digest(toDigest.getBytes(utf8));
	return new String(Base64.encodeBase64(digest), utf8);
}
/**
 * State extended with two WebSocket-specific flags.
 */
public static class WebSocketState extends BasicState {

	/** True once a close has been initiated for the connection. */
	private volatile boolean closeInitiated;

	/** True while a pong is awaited; presumably maintained by the code that sends pings. */
	private volatile boolean expectingPong;

	public void setCloseInitiated(boolean closeInitiated) {
		this.closeInitiated = closeInitiated;
	}

	public boolean isCloseInitiated() {
		return this.closeInitiated;
	}

	public void setExpectingPong(boolean expectingPong) {
		this.expectingPong = expectingPong;
	}

	public boolean isExpectingPong() {
		return this.expectingPong;
	}

}
}