Package org.jboss.aerogear.io.netty.handler.codec.sockjs.transport

Source Code of org.jboss.aerogear.io.netty.handler.codec.sockjs.transport.WebSocketTransport

/*
* Copyright 2013 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License, version
* 2.0 (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package org.jboss.aerogear.io.netty.handler.codec.sockjs.transport;

import static io.netty.handler.codec.http.HttpMethod.GET;
import static org.jboss.aerogear.io.netty.handler.codec.sockjs.transport.Transports.badRequestResponse;
import static org.jboss.aerogear.io.netty.handler.codec.sockjs.transport.Transports.internalServerErrorResponse;
import static org.jboss.aerogear.io.netty.handler.codec.sockjs.transport.Transports.methodNotAllowedResponse;

import com.fasterxml.jackson.core.JsonParseException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker;
import io.netty.handler.codec.http.websocketx.WebSocketServerHandshakerFactory;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.SockJsConfig;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.handler.CorsInboundHandler;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.handler.CorsOutboundHandler;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.handler.SessionHandler.Event;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.handler.SockJsHandler;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.util.JsonUtil;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

/**
* WebSocketTransport is responsible for the WebSocket handshake and
* also for receiving WebSocket frames.
*/
public class WebSocketTransport extends SimpleChannelInboundHandler<Object> {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketTransport.class);
    private static final AttributeKey<HttpRequest> REQUEST_KEY = AttributeKey.valueOf("request.key");
    private final SockJsConfig config;
    private WebSocketServerHandshaker handshaker;

    public WebSocketTransport(final SockJsConfig config) {
        this.config = config;
    }

    @Override
    protected void messageReceived(final ChannelHandlerContext ctx, final Object msg) throws Exception {
        if (msg instanceof FullHttpRequest) {
            handleHttpRequest(ctx, (FullHttpRequest) msg);
        } else if (msg instanceof WebSocketFrame) {
            handleWebSocketFrame(ctx, (WebSocketFrame) msg);
        }
    }

    private static boolean checkRequestHeaders(final ChannelHandlerContext ctx, final HttpRequest req) {
        if (req.getMethod() != GET) {
            logger.debug("Request was not of type GET, was {}", req.getMethod());
            ctx.writeAndFlush(methodNotAllowedResponse(req.getProtocolVersion()))
            .addListener(ChannelFutureListener.CLOSE);
            return false;
        }

        final String upgradeHeader = req.headers().get(HttpHeaders.Names.UPGRADE);
        if (upgradeHeader == null || !"websocket".equals(upgradeHeader.toLowerCase())) {
            logger.debug("Upgrade header was not 'websocket' was: {}", upgradeHeader);
            ctx.writeAndFlush(badRequestResponse(req.getProtocolVersion(), "Can \"Upgrade\" only to \"WebSocket\"."))
            .addListener(ChannelFutureListener.CLOSE);
            return false;
        }

        String connectHeader = req.headers().get(HttpHeaders.Names.CONNECTION);
        if (connectHeader != null && "keep-alive, upgrade".equals(connectHeader.toLowerCase())) {
            logger.debug("Connection header was not 'keep-alive, upgrade' was: {}", connectHeader);
            req.headers().set(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.UPGRADE);
            connectHeader = HttpHeaders.Values.UPGRADE.toString();
        }
        if (connectHeader == null || !"upgrade".equals(connectHeader.toLowerCase())) {
            logger.debug("Connection header was not 'upgrade' was: {}", connectHeader);
            ctx.writeAndFlush(badRequestResponse(req.getProtocolVersion(), "\"Connection\" must be \"Upgrade\"."))
            .addListener(ChannelFutureListener.CLOSE);
            return false;
        }
        return true;
    }

    private void handleHttpRequest(final ChannelHandlerContext ctx, FullHttpRequest req) {
        if (!checkRequestHeaders(ctx, req)) {
            return;
        }
        ctx.attr(REQUEST_KEY).set(req);

        if (WebSocketHAProxyHandshaker.isHAProxyReqeust(req)) {
            final String wsUrl = getWebSocketLocation(config.isTls(), req);
            final WebSocketHAProxyHandshaker haHandshaker = new WebSocketHAProxyHandshaker(wsUrl, null, 65365);
            final ChannelFuture handshakeFuture = haHandshaker.handshake(ctx.channel(), req);
            handshakeFuture.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(final ChannelFuture future) throws Exception {
                    if (future.isSuccess()) {
                        final ChannelPipeline pipeline = future.channel().pipeline();
                        pipeline.remove(SockJsHandler.class);
                        pipeline.remove(CorsInboundHandler.class);
                        pipeline.remove(CorsOutboundHandler.class);
                        pipeline.replace(WebSocketTransport.class, "websocket-ha-proxy",
                                new WebSocketHAProxyTransport(haHandshaker));
                        pipeline.addLast(new WebSocketSendHandler());
                    }
                }
            });
            return;
        }
        final String wsUrl = getWebSocketLocation(config.isTls(), req, Transports.Type.WEBSOCKET.path());
        final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(wsUrl, null, false);
        handshaker = wsFactory.newHandshaker(req);

        if (handshaker == null) {
            WebSocketServerHandshakerFactory.sendUnsupportedWebSocketVersionResponse(ctx.channel());
        } else {
            final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req);
            handshakeFuture.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(final ChannelFuture future) throws Exception {
                    if (future.isSuccess()) {
                        ctx.pipeline().remove(SockJsHandler.class);
                        ctx.pipeline().remove(CorsInboundHandler.class);
                        ctx.pipeline().remove(CorsOutboundHandler.class);
                        ctx.pipeline().addLast(new WebSocketSendHandler());
                    } else {
                        logger.error("Handshake error", future.cause());
                    }
                }
            });
            ctx.fireChannelRead(ReferenceCountUtil.retain(req));
        }
    }

    private static String getWebSocketLocation(final boolean tls, final FullHttpRequest req) {
        final String protocol = tls ? "wss://" : "ws://";
        return protocol + req.headers().get(HttpHeaders.Names.HOST) + req.getUri();
    }

    private static String getWebSocketLocation(final boolean tls, final FullHttpRequest req, final String path) {
        final String protocol = tls ? "wss://" : "ws://";
        return protocol + req.headers().get(HttpHeaders.Names.HOST) + path;
    }

    private void handleWebSocketFrame(final ChannelHandlerContext ctx, final WebSocketFrame wsFrame) throws Exception {
        if (wsFrame instanceof CloseWebSocketFrame) {
            wsFrame.retain();
            logger.debug("CloseWebSocketFrame received");
            handshaker.close(ctx.channel(), (CloseWebSocketFrame) wsFrame);
            ctx.fireUserEventTriggered(Event.CLOSE_SESSION);
            return;
        }
        if (wsFrame instanceof PingWebSocketFrame) {
            logger.debug("PingWebSocketFrame received");
            wsFrame.content().retain();
            ctx.channel().writeAndFlush(new PongWebSocketFrame(wsFrame.content()));
            return;
        }
        if (!(wsFrame instanceof TextWebSocketFrame)) {
            throw new UnsupportedOperationException(String.format("%s frame types not supported",
                    wsFrame.getClass().getName()));
        }
        final String[] messages = JsonUtil.decode((TextWebSocketFrame) wsFrame);
        for (String message : messages) {
            ctx.fireChannelRead(message);
        }
    }

    @Override
    public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) throws Exception {
        if (cause instanceof JsonParseException) {
            logger.trace("Failed to part JSON", cause);
            ctx.close();
        } else if (cause instanceof WebSocketHandshakeException) {
            final HttpRequest request = ctx.attr(REQUEST_KEY).get();
            logger.error("Failed with ws handshake for request: " + request, cause);
            ctx.writeAndFlush(internalServerErrorResponse(request.getProtocolVersion(), cause.getMessage()))
            .addListener(ChannelFutureListener.CLOSE);
        } else {
            ctx.fireExceptionCaught(cause);
        }
    }

}
TOP

Related Classes of org.jboss.aerogear.io.netty.handler.codec.sockjs.transport.WebSocketTransport

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.