Package org.jboss.aerogear.simplepush.server.netty

Source Code of org.jboss.aerogear.simplepush.server.netty.SimplePushSockJSServiceTest

/**
* JBoss, Home of Professional Open Source
* Copyright Red Hat, Inc., and individual contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.aerogear.simplepush.server.netty;

import static io.netty.handler.codec.http.HttpHeaders.Values.WEBSOCKET;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static io.netty.util.CharsetUtil.UTF_8;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.hasItem;
import static org.hamcrest.CoreMatchers.hasItems;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerInvoker;
import io.netty.channel.ChannelOutboundBuffer;
import io.netty.channel.ChannelPromise;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.HttpHeaders.Names;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.ReferenceCountUtil;

import java.net.SocketAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.UUID;

import org.jboss.aerogear.simplepush.protocol.Ack;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.SockJsConfig;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.SockJsService;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.SockJsServiceFactory;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.handler.CorsInboundHandler;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.handler.CorsOutboundHandler;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.handler.SockJsHandler;
import org.jboss.aerogear.io.netty.handler.codec.sockjs.transport.Transports;
import org.jboss.aerogear.simplepush.protocol.HelloResponse;
import org.jboss.aerogear.simplepush.protocol.MessageType;
import org.jboss.aerogear.simplepush.protocol.PingMessage;
import org.jboss.aerogear.simplepush.protocol.RegisterResponse;
import org.jboss.aerogear.simplepush.protocol.UnregisterResponse;
import org.jboss.aerogear.simplepush.protocol.impl.AckMessageImpl;
import org.jboss.aerogear.simplepush.protocol.impl.HelloResponseImpl;
import org.jboss.aerogear.simplepush.protocol.impl.NotificationMessageImpl;
import org.jboss.aerogear.simplepush.protocol.impl.PingMessageImpl;
import org.jboss.aerogear.simplepush.protocol.impl.RegisterResponseImpl;
import org.jboss.aerogear.simplepush.protocol.impl.UnregisterResponseImpl;
import org.jboss.aerogear.simplepush.protocol.impl.AckImpl;
import org.jboss.aerogear.simplepush.protocol.impl.json.JsonUtil;
import org.jboss.aerogear.simplepush.server.DefaultSimplePushConfig;
import org.jboss.aerogear.simplepush.server.DefaultSimplePushServer;
import org.jboss.aerogear.simplepush.server.SimplePushServer;
import org.jboss.aerogear.simplepush.server.SimplePushServerConfig;
import org.jboss.aerogear.simplepush.server.datastore.ChannelNotFoundException;
import org.jboss.aerogear.simplepush.server.datastore.DataStore;
import org.jboss.aerogear.simplepush.server.datastore.InMemoryDataStore;
import org.jboss.aerogear.simplepush.util.CryptoUtil;
import org.jboss.aerogear.simplepush.util.UUIDUtil;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;

public class SimplePushSockJSServiceTest {

    private SockJsServiceFactory factory;
    private String sessionUrl;

    @Before
    public void setup() {
        factory = defaultFactory();
        sessionUrl = randomSessionIdUrl(factory);
    }

    @Test
    public void xhrPollingOpenFrame() throws Exception {
        final FullHttpResponse openFrameResponse = sendXhrOpenFrameRequest(factory, sessionUrl);
        assertThat(openFrameResponse.getStatus(), is(HttpResponseStatus.OK));
        assertThat(openFrameResponse.content().toString(UTF_8), equalTo("o\n"));
    }

    @Test
    public void xhrPollingHelloWithChannelId() throws Exception {
        final String uaid = UUIDUtil.newUAID();
        final String channelId = UUID.randomUUID().toString();
        sendXhrOpenFrameRequest(factory, sessionUrl);

        final FullHttpResponse sendResponse = sendXhrHelloMessageRequest(factory, sessionUrl, uaid, channelId);
        assertThat(sendResponse.getStatus(), is(HttpResponseStatus.NO_CONTENT));
        final HelloResponseImpl handshakeResponse = pollXhrHelloMessageResponse(factory, sessionUrl);
        assertThat(handshakeResponse.getUAID(), equalTo(uaid));
    }

    @Test
    public void xhrPollingHelloWithInvalidUaid() throws Exception {
        final String uaid = "non-valie2233??";
        final String channelId = UUID.randomUUID().toString();
        sendXhrOpenFrameRequest(factory, sessionUrl);

        final FullHttpResponse sendResponse = sendXhrHelloMessageRequest(factory, sessionUrl, uaid, channelId);
        assertThat(sendResponse.getStatus(), is(HttpResponseStatus.NO_CONTENT));
        final HelloResponseImpl handshakeResponse = pollXhrHelloMessageResponse(factory, sessionUrl);
        assertThat(handshakeResponse.getMessageType(), is(MessageType.Type.HELLO));
        assertThat(handshakeResponse.getUAID(), not(equalTo(uaid)));
    }

    @Test
    public void xhrPollingRegister() throws Exception {
        final String channelId = UUID.randomUUID().toString();
        sendXhrOpenFrameRequest(factory, sessionUrl);
        sendXhrHelloMessageRequest(factory, sessionUrl, UUIDUtil.newUAID());
        pollXhrHelloMessageResponse(factory, sessionUrl);

        final FullHttpResponse registerChannelIdRequest = sendXhrRegisterChannelIdRequest(factory, sessionUrl, channelId);
        assertThat(registerChannelIdRequest.getStatus(), is(HttpResponseStatus.NO_CONTENT));

        final RegisterResponseImpl registerChannelIdResponse = pollXhrRegisterChannelIdResponse(factory, sessionUrl);
        assertThat(registerChannelIdResponse.getChannelId(), equalTo(channelId));
        assertThat(registerChannelIdResponse.getStatus().getCode(), equalTo(200));
        assertThat(registerChannelIdResponse.getPushEndpoint().startsWith("http://127.0.0.1:7777/update/"), is(true));
    }

    @Test
    public void xhrPollingUnregister() throws Exception {
        final String channelId = UUID.randomUUID().toString();
        sendXhrOpenFrameRequest(factory, sessionUrl);
        sendXhrHelloMessageRequest(factory, sessionUrl, UUIDUtil.newUAID());
        pollXhrHelloMessageResponse(factory, sessionUrl);
        sendXhrRegisterChannelIdRequest(factory, sessionUrl, channelId);
        pollXhrRegisterChannelIdResponse(factory, sessionUrl);

        final FullHttpResponse unregisterChannelIdRequest = unregisterChannelIdRequest(factory, sessionUrl, channelId);
        assertThat(unregisterChannelIdRequest.getStatus(), is(HttpResponseStatus.NO_CONTENT));

        final UnregisterResponseImpl unregisterChannelIdResponse = unregisterChannelIdResponse(factory, sessionUrl);
        assertThat(unregisterChannelIdResponse.getStatus().getCode(), is(200));
        assertThat(unregisterChannelIdResponse.getChannelId(), equalTo(channelId));
    }

    @Test
    public void xhrPollingPing() throws Exception {
        sendXhrOpenFrameRequest(factory, sessionUrl);
        sendXhrHelloMessageRequest(factory, sessionUrl, UUIDUtil.newUAID());
        pollXhrHelloMessageResponse(factory, sessionUrl);

        final FullHttpResponse registerChannelIdRequest = sendXhrPingRequest(factory, sessionUrl);
        assertThat(registerChannelIdRequest.getStatus(), is(HttpResponseStatus.NO_CONTENT));

        final PingMessageImpl pingResponse = pollXhrPingMessageResponse(factory, sessionUrl);
        assertThat(pingResponse.getPingMessage(), equalTo(PingMessage.PING_MESSAGE));
    }

    @Test
    public void websocketUpgradeRequest() throws Exception {
        final EmbeddedChannel channel = createChannel(factory);
        final HttpResponse response = websocketHttpUpgradeRequest(sessionUrl, channel);
        assertThat(response.getStatus(), is(HttpResponseStatus.SWITCHING_PROTOCOLS));
        assertThat(response.headers().get(HttpHeaders.Names.UPGRADE), equalTo("websocket"));
        assertThat(response.headers().get(HttpHeaders.Names.CONNECTION), equalTo("Upgrade"));
        assertThat(response.headers().get(Names.SEC_WEBSOCKET_ACCEPT), equalTo("s3pPLMBiTxaQ9kYGzzhZRbK+xOo="));
        channel.close();
    }

    public static HttpResponse decodeHttpResponse(final EmbeddedChannel channel) {
        final EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
        ch.writeInbound(channel.readOutbound());
        return ch.readInbound();
    }

    public static FullHttpResponse decodeFullHttpResponse(final EmbeddedChannel channel) {
        final EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
        ch.writeInbound(channel.outboundMessages().toArray());
        final HttpResponse response = ch.readInbound();
        final HttpContent content = ch.readInbound();
        final DefaultFullHttpResponse fullResponse;
        if (content != null) {
            fullResponse = new DefaultFullHttpResponse(response.getProtocolVersion(), response.getStatus(), content.content());
        } else {
            fullResponse = new DefaultFullHttpResponse(response.getProtocolVersion(), response.getStatus());
        }
        fullResponse.headers().add(response.headers());
        return fullResponse;
    }

    @Test
    public void rawWebSocketUpgradeRequest() throws Exception {
        final SimplePushServerConfig simplePushConfig = DefaultSimplePushConfig.create().password("test").build();
        final SockJsConfig sockjsConf = SockJsConfig.withPrefix("/simplepush").webSocketProtocols("push-notification").build();
        final byte[] privateKey = CryptoUtil.secretKey(simplePushConfig.password(), "someSaltForTesting".getBytes());
        final SimplePushServer pushServer = new DefaultSimplePushServer(new InMemoryDataStore(), simplePushConfig, privateKey);
        final SimplePushServiceFactory factory = new SimplePushServiceFactory(sockjsConf, pushServer);
        final EmbeddedChannel channel = createChannel(factory);
        final FullHttpRequest request = websocketUpgradeRequest(factory.config().prefix() + Transports.Type.WEBSOCKET.path());
        request.headers().set(Names.SEC_WEBSOCKET_PROTOCOL, "push-notification");
        channel.writeInbound(request);
        final FullHttpResponse response = decodeFullHttpResponse(channel);
        assertThat(response.getStatus(), is(HttpResponseStatus.SWITCHING_PROTOCOLS));
        assertThat(response.headers().get(HttpHeaders.Names.UPGRADE), equalTo("websocket"));
        assertThat(response.headers().get(HttpHeaders.Names.CONNECTION), equalTo("Upgrade"));
        assertThat(response.headers().get(Names.SEC_WEBSOCKET_PROTOCOL), equalTo("push-notification"));
        assertThat(response.headers().get(Names.SEC_WEBSOCKET_ACCEPT), equalTo("s3pPLMBiTxaQ9kYGzzhZRbK+xOo="));
        channel.close();
    }

    @Test
    public void websocketHello() {
        final EmbeddedChannel channel = createWebSocketChannel(factory);
        final String uaid = UUIDUtil.newUAID();
        sendWebSocketHttpUpgradeRequest(sessionUrl, channel);

        final HelloResponse response = sendWebSocketHelloFrame(uaid, channel);
        assertThat(response.getMessageType(), equalTo(MessageType.Type.HELLO));
        assertThat(response.getUAID(), equalTo(uaid));
        channel.close();
    }

    @Test
    public void websocketHelloWithInvalidUaid() {
        final String uaid = "non-valie2233??";
        final EmbeddedChannel channel = createWebSocketChannel(factory);
        sendWebSocketHttpUpgradeRequest(sessionUrl, channel);

        final HelloResponse response = sendWebSocketHelloFrame(uaid, channel);
        assertThat(response.getMessageType(), equalTo(MessageType.Type.HELLO));
        assertThat(response.getUAID(), not(equalTo(uaid)));
        channel.close();
    }

    @Test
    public void websocketRegister() {
        final EmbeddedChannel channel = createWebSocketChannel(factory);
        final String channelId = UUID.randomUUID().toString();
        sendWebSocketHttpUpgradeRequest(sessionUrl, channel);
        sendWebSocketHelloFrame(UUIDUtil.newUAID(), channel);

        final RegisterResponse registerResponse = sendWebSocketRegisterFrame(channelId, channel);
        assertThat(registerResponse.getStatus().getCode(), is(200));
        assertThat(registerResponse.getChannelId(), equalTo(channelId));
        channel.close();
    }

    @Test
    public void websocketRegisterDuplicateChannelId() {
        final EmbeddedChannel channel = createWebSocketChannel(factory);
        final String channelId = UUID.randomUUID().toString();
        sendWebSocketHttpUpgradeRequest(sessionUrl, channel);
        sendWebSocketHelloFrame(UUIDUtil.newUAID(), channel);

        assertThat(sendWebSocketRegisterFrame(channelId, channel).getStatus().getCode(), is(200));
        assertThat(sendWebSocketRegisterFrame(channelId, channel).getStatus().getCode(), is(409));
        channel.close();
    }

    @Test
    public void websocketUnregister() {
        final EmbeddedChannel channel = createWebSocketChannel(factory);
        final String channelId = UUID.randomUUID().toString();
        sendWebSocketHttpUpgradeRequest(sessionUrl, channel);
        sendWebSocketHelloFrame(UUIDUtil.newUAID(), channel);
        sendWebSocketRegisterFrame(channelId, channel);

        final UnregisterResponse registerResponse = websocketUnRegisterFrame(channelId, channel);
        assertThat(registerResponse.getStatus().getCode(), is(200));
        channel.close();
    }

    @Test
    public void websocketUnregisterNonRegistered() {
        final EmbeddedChannel channel = createWebSocketChannel(factory);
        sendWebSocketHttpUpgradeRequest(sessionUrl, channel);
        sendWebSocketHelloFrame(UUIDUtil.newUAID(), channel);

        final UnregisterResponse registerResponse = websocketUnRegisterFrame("notRegistered", channel);
        assertThat(registerResponse.getMessageType(), equalTo(MessageType.Type.UNREGISTER));
        assertThat(registerResponse.getChannelId(), equalTo("notRegistered"));
        assertThat(registerResponse.getStatus().getCode(), is(200));
        channel.close();
    }

    @Test
    public void websocketHandleAcknowledgement() throws Exception {
        final SimplePushServer simplePushServer = defaultPushServer();
        final SockJsServiceFactory serviceFactory = defaultFactory(simplePushServer);
        final EmbeddedChannel channel = createWebSocketChannel(serviceFactory);
        final String uaid = UUIDUtil.newUAID();
        final String channelId = UUID.randomUUID().toString();
        sendWebSocketHttpUpgradeRequest(sessionUrl, channel);
        sendWebSocketHelloFrame(uaid, channel);
        final RegisterResponse registerResponse = sendWebSocketRegisterFrame(channelId, channel);
        final String endpointToken = extractEndpointToken(registerResponse.getPushEndpoint());
        sendNotification(endpointToken, 1L, simplePushServer);

        final Set<Ack> unacked = sendAcknowledge(channel, ack(channelId, 1L));
        assertThat(unacked.isEmpty(), is(true));
        channel.close();
    }

    @Test
    public void websocketHandleAcknowledgements() throws Exception {
        final SimplePushServer simplePushServer = defaultPushServer();
        final SockJsServiceFactory serviceFactory = defaultFactory(simplePushServer);
        final EmbeddedChannel channel = createWebSocketChannel(serviceFactory);
        final String uaid = UUIDUtil.newUAID();
        final String channelId1 = UUID.randomUUID().toString();
        final String channelId2 = UUID.randomUUID().toString();
        sendWebSocketHttpUpgradeRequest(sessionUrl, channel);
        sendWebSocketHelloFrame(uaid, channel);
        final RegisterResponse registerResponse1 = sendWebSocketRegisterFrame(channelId1, channel);
        final String endpointToken1 = extractEndpointToken(registerResponse1.getPushEndpoint());
        final RegisterResponse registerResponse2 = sendWebSocketRegisterFrame(channelId2, channel);
        final String endpointToken2 = extractEndpointToken(registerResponse2.getPushEndpoint());
        sendNotification(endpointToken1, 1L, simplePushServer);
        sendNotification(endpointToken2, 1L, simplePushServer);

        final Set<Ack> unacked = sendAcknowledge(channel, ack(channelId1, 1L), ack(channelId2, 1L));
        assertThat(unacked.isEmpty(), is(true));
        channel.close();
    }

    private String extractEndpointToken(final String pushEndpoint) {
        return pushEndpoint.substring(pushEndpoint.lastIndexOf('/') + 1);
    }

    @Test
    @Ignore("Need to figure out how to run a schedules job with the new EmbeddedChannel")
    // https://groups.google.com/forum/#!topic/netty/Q-_wat_9Odo
    public void websocketHandleOneUnacknowledgement() throws Exception {
        final SimplePushServer simplePushServer = defaultPushServer();
        final SockJsServiceFactory serviceFactory = defaultFactory(simplePushServer);
        final EmbeddedChannel channel = createWebSocketChannel(serviceFactory);
        final String uaid = UUIDUtil.newUAID();
        final String channelId1 = UUID.randomUUID().toString();
        final String channelId2 = UUID.randomUUID().toString();
        sendWebSocketHttpUpgradeRequest(sessionUrl, channel);
        sendWebSocketHelloFrame(uaid, channel);
        final RegisterResponse registerResponse1 = sendWebSocketRegisterFrame(channelId1, channel);
        final String endpointToken1 = extractEndpointToken(registerResponse1.getPushEndpoint());
        sendNotification(endpointToken1, 1L, simplePushServer);

        final RegisterResponse registerResponse2 = sendWebSocketRegisterFrame(channelId2, channel);
        final String endpointToken2 = extractEndpointToken(registerResponse2.getPushEndpoint());
        sendNotification(endpointToken2, 1L, simplePushServer);

        final Set<Ack> unacked = sendAcknowledge(channel, ack(channelId1, 1L));
        assertThat(unacked.size(), is(1));
        assertThat(unacked, hasItem(new AckImpl(channelId2, 1L)));
        channel.close();
    }

    @Test
    @Ignore("Need to figure out how to run a schedules job with the new EmbeddedChannel")
    // https://groups.google.com/forum/#!topic/netty/Q-_wat_9Odo
    public void websocketHandleUnacknowledgement() throws Exception {
        final SimplePushServer simplePushServer = defaultPushServer();
        final SockJsServiceFactory serviceFactory = defaultFactory(simplePushServer);
        final EmbeddedChannel channel = createWebSocketChannel(serviceFactory);
        final String uaid = UUIDUtil.newUAID();
        final String channelId1 = UUID.randomUUID().toString();
        final String channelId2 = UUID.randomUUID().toString();
        sendWebSocketHttpUpgradeRequest(sessionUrl, channel);
        sendWebSocketHelloFrame(uaid, channel);
        final RegisterResponse registerResponse1 = sendWebSocketRegisterFrame(channelId1, channel);
        final String endpointToken1 = extractEndpointToken(registerResponse1.getPushEndpoint());
        sendNotification(endpointToken1, 1L, simplePushServer);
        final RegisterResponse registerResponse2 = sendWebSocketRegisterFrame(channelId2, channel);
        final String endpointToken2 = extractEndpointToken(registerResponse2.getPushEndpoint());
        sendNotification(endpointToken2, 1L, simplePushServer);

        final Set<Ack> unacked = sendAcknowledge(channel);
        assertThat(unacked.size(), is(1));
        assertThat(unacked, hasItems(ack(channelId1, 1L), ack(channelId2, 1L)));
        channel.close();
    }

    @Test
    public void websocketPing() {
        final EmbeddedChannel channel = createWebSocketChannel(factory);
        sendWebSocketHttpUpgradeRequest(sessionUrl, channel);
        sendWebSocketHelloFrame(UUIDUtil.newUAID(), channel);

        final PingMessage pingResponse = sendWebSocketPingFrame(channel);
        assertThat(pingResponse.getPingMessage(), equalTo(PingMessage.PING_MESSAGE));
        channel.close();
    }

    private SimplePushServer defaultPushServer() {
        final DataStore store = new InMemoryDataStore();
        final SimplePushServerConfig config = DefaultSimplePushConfig.create().password("test").build();
        final byte[] privateKey = DefaultSimplePushServer.generateAndStorePrivateKey(store, config);
        return new DefaultSimplePushServer(store, config, privateKey);
    }

    private void sendNotification(final String endpointToken, final long version,
            final SimplePushServer simplePushServer) throws ChannelNotFoundException {
        simplePushServer.handleNotification(endpointToken, "version=" + version);
    }

    private Ack ack(final String channelId, final Long version) {
        return new AckImpl(channelId, version);
    }

    private Set<Ack> sendAcknowledge(final EmbeddedChannel channel, final Ack... acks) {
        final Set<Ack> ups = new HashSet<Ack>(Arrays.asList(acks));
        final TextWebSocketFrame ackFrame = ackFrame(ups);
        channel.writeInbound(ackFrame);
        channel.runPendingTasks();

        final Object out = channel.readOutbound();
        if (out == null) {
            return Collections.emptySet();
        }

        final NotificationMessageImpl unacked = responseToType(out, NotificationMessageImpl.class);
        return unacked.getAcks();
    }

    private TextWebSocketFrame ackFrame(final Set<Ack> acks) {
        return new TextWebSocketFrame(JsonUtil.toJson(new AckMessageImpl(acks)));
    }

    private RegisterResponseImpl sendWebSocketRegisterFrame(final String channelId, final EmbeddedChannel ch) {
        ch.writeInbound(TestUtil.registerChannelIdWebSocketFrame(channelId));
        return responseToType(readOutboundDiscardEmpty(ch), RegisterResponseImpl.class);
    }

    private PingMessageImpl sendWebSocketPingFrame(final EmbeddedChannel ch) {
        ch.writeInbound(TestUtil.pingWebSocketFrame());
        return responseToType(ch.readOutbound(), PingMessageImpl.class);
    }

    private UnregisterResponse websocketUnRegisterFrame(final String channelId, final EmbeddedChannel ch) {
        ch.writeInbound(TestUtil.unregisterChannelIdWebSocketFrame(channelId));
        return responseToType(ch.readOutbound(), UnregisterResponseImpl.class);
    }

    private HttpResponse websocketHttpUpgradeRequest(final String sessionUrl, final EmbeddedChannel ch) throws Exception{
        ch.writeInbound(websocketUpgradeRequest(sessionUrl + Transports.Type.WEBSOCKET.path()));
        return decodeHttpResponse(ch);
    }

    private void sendWebSocketHttpUpgradeRequest(final String sessionUrl, final EmbeddedChannel ch) {
        ch.writeInbound(websocketUpgradeRequest(sessionUrl + Transports.Type.WEBSOCKET.path()));
        // Discarding the Http upgrade response
        ch.readOutbound();
        ch.readOutbound();
        // Discard open frame
        ch.readOutbound();
        ch.readOutbound();
        ch.pipeline().remove("wsencoder");
    }

    private HelloResponse sendWebSocketHelloFrame(final String uaid, final EmbeddedChannel ch) {
        ch.writeInbound(TestUtil.helloWebSocketFrame(uaid));
        return responseToType(ch.readOutbound(), HelloResponseImpl.class);
    }

    private Object readOutboundDiscardEmpty(final EmbeddedChannel ch) {
        final Object obj = ch.readOutbound();
        if (obj instanceof ByteBuf) {
            final ByteBuf buf = (ByteBuf) obj;
            if (buf.capacity() == 0) {
                ReferenceCountUtil.release(buf);
                return ch.readOutbound();
            }
        }
        return obj;
    }

    private <T> T responseToType(final Object response, Class<T> type) {
        if (response instanceof TextWebSocketFrame) {
            final TextWebSocketFrame frame = (TextWebSocketFrame) response;
            String content = frame.text();
            if (content.startsWith("a[")) {
                content = TestUtil.extractJsonFromSockJSMessage(content);
            }
            return JsonUtil.fromJson(content, type);
        }
        throw new IllegalArgumentException("Response is expected to be of type TextWebSocketFrame was: " + response);
    }

    private FullHttpResponse sendXhrOpenFrameRequest(final SockJsServiceFactory factory, final String sessionUrl) throws Exception {
        final EmbeddedChannel openChannel = createChannel(factory);
        openChannel.writeInbound(httpGetRequest(sessionUrl + Transports.Type.XHR.path()));
        final FullHttpResponse openFrameResponse = decodeFullHttpResponse(openChannel);
        openChannel.close();
        return openFrameResponse;
    }

    private FullHttpResponse sendXhrHelloMessageRequest(final SockJsServiceFactory factory, final String sessionUrl,
            final String uaid, final String... channelIds) throws Exception {
        return xhrSend(factory, sessionUrl, TestUtil.helloSockJSFrame(uaid, channelIds));
    }

    private HelloResponseImpl pollXhrHelloMessageResponse(final SockJsServiceFactory factory, final String sessionUrl) throws Exception {
        final FullHttpResponse pollResponse = xhrPoll(factory, sessionUrl);
        assertThat(pollResponse.getStatus(), is(HttpResponseStatus.OK));

        final String helloJson = TestUtil.extractJsonFromSockJSMessage(pollResponse.content().toString(UTF_8));
        return JsonUtil.fromJson(helloJson, HelloResponseImpl.class);
    }

    private FullHttpResponse sendXhrRegisterChannelIdRequest(final SockJsServiceFactory factory, final String sessionUrl,
            final String channelId) throws Exception {
        return xhrSend(factory, sessionUrl, TestUtil.registerChannelIdMessageSockJSFrame(channelId));
    }

    private RegisterResponseImpl pollXhrRegisterChannelIdResponse(final SockJsServiceFactory factory, final String sessionUrl) throws Exception {
        final FullHttpResponse pollResponse = xhrPoll(factory, sessionUrl);
        assertThat(pollResponse.getStatus(), is(HttpResponseStatus.OK));

        final String json = TestUtil.extractJsonFromSockJSMessage(pollResponse.content().toString(UTF_8));
        return JsonUtil.fromJson(json, RegisterResponseImpl.class);
    }

    private FullHttpResponse unregisterChannelIdRequest(final SockJsServiceFactory factory, final String sessionUrl,
            final String channelId) throws Exception {
        return xhrSend(factory, sessionUrl, TestUtil.unregisterChannelIdMessageSockJSFrame(channelId));
    }

    private UnregisterResponseImpl unregisterChannelIdResponse(final SockJsServiceFactory factory, final String sessionUrl) throws Exception {
        final FullHttpResponse pollResponse = xhrPoll(factory, sessionUrl);
        assertThat(pollResponse.getStatus(), is(HttpResponseStatus.OK));

        final String json = TestUtil.extractJsonFromSockJSMessage(pollResponse.content().toString(UTF_8));
        return JsonUtil.fromJson(json, UnregisterResponseImpl.class);
    }

    private FullHttpResponse sendXhrPingRequest(final SockJsServiceFactory factory, final String sessionUrl) throws Exception {
        return xhrSend(factory, sessionUrl, TestUtil.pingSockJSFrame());
    }

    private PingMessageImpl pollXhrPingMessageResponse(final SockJsServiceFactory factory, final String sessionUrl) throws Exception {
        final FullHttpResponse pollResponse = xhrPoll(factory, sessionUrl);
        assertThat(pollResponse.getStatus(), is(HttpResponseStatus.OK));

        final String helloJson = TestUtil.extractJsonFromSockJSMessage(pollResponse.content().toString(UTF_8));
        return JsonUtil.fromJson(helloJson, PingMessageImpl.class);
    }

    private FullHttpResponse xhrSend(final SockJsServiceFactory factory, final String sessionUrl, final String content) throws Exception {
        final EmbeddedChannel sendChannel = createChannel(factory);
        final FullHttpRequest sendRequest = httpPostRequest(sessionUrl + Transports.Type.XHR_SEND.path());
        sendRequest.content().writeBytes(Unpooled.copiedBuffer(content, UTF_8));
        sendChannel.writeInbound(sendRequest);
        final FullHttpResponse sendResponse = decodeFullHttpResponse(sendChannel);
        sendChannel.close();
        return sendResponse;

    }

    private FullHttpResponse xhrPoll(final SockJsServiceFactory factory, final String sessionUrl) throws Exception {
        final EmbeddedChannel pollChannel = createChannel(factory);
        pollChannel.writeInbound(httpGetRequest(sessionUrl + Transports.Type.XHR.path()));
        return decodeFullHttpResponse(pollChannel);
    }

    private FullHttpRequest httpGetRequest(final String path) {
        return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, path);
    }

    private FullHttpRequest websocketUpgradeRequest(final String path) {
        final FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, path);
        req.headers().set(Names.HOST, "server.test.com");
        req.headers().set(Names.UPGRADE, WEBSOCKET.toString());
        req.headers().set(Names.CONNECTION, "Upgrade");
        req.headers().set(Names.SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==");
        req.headers().set(Names.SEC_WEBSOCKET_ORIGIN, "http://test.com");
        req.headers().set(Names.SEC_WEBSOCKET_VERSION, "13");
        req.headers().set(Names.CONTENT_LENGTH, "0");
        return req;
    }

    private FullHttpRequest httpPostRequest(final String path) {
        return new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, path);
    }

    private SockJsServiceFactory defaultFactory() {
        final SimplePushServerConfig simplePushConfig = DefaultSimplePushConfig.create().password("test").build();
        final SockJsConfig sockjsConf = SockJsConfig.withPrefix("/simplepush").build();
        final byte[] privateKey = CryptoUtil.secretKey(simplePushConfig.password(), "someSaltForTesting".getBytes());
        final SimplePushServer pushServer = new DefaultSimplePushServer(new InMemoryDataStore(), simplePushConfig, privateKey);
        return new SimplePushServiceFactory(sockjsConf, pushServer);
    }

    private SockJsServiceFactory defaultFactory(final SimplePushServer simplePushServer) {
        final SockJsConfig sockJSConfig = SockJsConfig.withPrefix("/simplepush").build();
        return new SockJsServiceFactory() {
            @Override
            public SockJsService create() {
                return new SimplePushSockJSService(config(), simplePushServer);
            }

            @Override
            public SockJsConfig config() {
                return sockJSConfig;
            }
        };
    }

    private String randomSessionIdUrl(final SockJsServiceFactory factory) {
        return factory.config().prefix() + "/111/" + UUID.randomUUID().toString();
    }

    private EmbeddedChannel createChannel(final SockJsServiceFactory factory) {
        final EmbeddedChannel ch = new TestEmbeddedChannel(
                new HttpRequestDecoder(),
                new HttpResponseEncoder(),
                new CorsInboundHandler(),
                new SockJsHandler(factory),
                new CorsOutboundHandler());
        ch.pipeline().remove("EmbeddedChannel$LastInboundHandler#0");
        return ch;
    }

    private EmbeddedChannel createWebSocketChannel(final SockJsServiceFactory factory) {
        final EmbeddedChannel ch = new TestEmbeddedChannel(
                new HttpRequestDecoder(),
                new HttpResponseEncoder(),
                new CorsInboundHandler(),
                new SockJsHandler(factory),
                new CorsOutboundHandler());
        ch.pipeline().remove("EmbeddedChannel$LastInboundHandler#0");
        return ch;
    }

    private static class TestEmbeddedChannel extends EmbeddedChannel {

        public TestEmbeddedChannel(final ChannelHandler... handlers) {
            super(handlers);
        }

        @Override
        public Unsafe unsafe() {
            final AbstractUnsafe delegate = super.newUnsafe();
            return new TestUnsafe(delegate, new StubEmbeddedEventLoop(super.eventLoop()));
        }

        private class TestUnsafe implements Unsafe {

            private final Unsafe delegate;
            private final ChannelHandlerInvoker invoker;

            public TestUnsafe(final Unsafe delegate, final ChannelHandlerInvoker invoker) {
                this.delegate = delegate;
                this.invoker = invoker;
            }

            @Override
            public ChannelHandlerInvoker invoker() {
                return invoker;
            }

            @Override
            public SocketAddress localAddress() {
                return delegate.localAddress();
            }

            @Override
            public SocketAddress remoteAddress() {
                return delegate.remoteAddress();
            }

            @Override
            public void register(ChannelPromise promise) {
                delegate.register(promise);
            }

            @Override
            public void bind(SocketAddress localAddress, ChannelPromise promise) {
                delegate.bind(localAddress, promise);
            }

            @Override
            public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
                delegate.connect(remoteAddress, localAddress, promise);
            }

            @Override
            public void disconnect(ChannelPromise promise) {
                delegate.disconnect(promise);
            }

            @Override
            public void close(ChannelPromise promise) {
                delegate.close(promise);
            }

            @Override
            public void closeForcibly() {
                delegate.closeForcibly();
            }

            @Override
            public void beginRead() {
                delegate.beginRead();
            }

            @Override
            public void write(Object msg, ChannelPromise promise) {
                delegate.write(msg, promise);
            }

            @Override
            public void flush() {
                delegate.flush();
            }

            @Override
            public ChannelPromise voidPromise() {
                return delegate.voidPromise();
            }

            @Override
            public ChannelOutboundBuffer outboundBuffer() {
                return delegate.outboundBuffer();
            }
        }

    }

}
TOP

Related Classes of org.jboss.aerogear.simplepush.server.netty.SimplePushSockJSServiceTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.