Package org.springframework.web.socket.messaging

Source Code of org.springframework.web.socket.messaging.StompSubProtocolHandlerTests

/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.web.socket.messaging;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.SimpAttributes;
import org.springframework.messaging.simp.SimpAttributesContextHolder;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.TestPrincipal;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompDecoder;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.simp.user.DefaultUserSessionRegistry;
import org.springframework.messaging.simp.user.DestinationUserNameProvider;
import org.springframework.messaging.simp.user.UserSessionRegistry;
import org.springframework.messaging.support.ChannelInterceptorAdapter;
import org.springframework.messaging.support.ExecutorSubscribableChannel;
import org.springframework.messaging.support.ImmutableMessageChannelInterceptor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.handler.TestWebSocketSession;
import org.springframework.web.socket.sockjs.transport.SockJsSession;

import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

/**
* Test fixture for {@link StompSubProtocolHandler} tests.
*
* @author Rossen Stoyanchev
*/
public class StompSubProtocolHandlerTests {

  public static final byte[] EMPTY_PAYLOAD = new byte[0];

  private StompSubProtocolHandler protocolHandler;

  private TestWebSocketSession session;

  private MessageChannel channel;

  @SuppressWarnings("rawtypes")
  private ArgumentCaptor<Message> messageCaptor;


  @Before
  public void setup() {
    this.protocolHandler = new StompSubProtocolHandler();
    this.channel = Mockito.mock(MessageChannel.class);
    this.messageCaptor = ArgumentCaptor.forClass(Message.class);

    this.session = new TestWebSocketSession();
    this.session.setId("s1");
    this.session.setPrincipal(new TestPrincipal("joe"));
  }

  @Test
  public void handleMessageToClientConnected() {

    UserSessionRegistry registry = new DefaultUserSessionRegistry();
    this.protocolHandler.setUserSessionRegistry(registry);

    StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
    Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
    this.protocolHandler.handleMessageToClient(this.session, message);

    assertEquals(1, this.session.getSentMessages().size());
    WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
    assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload());

    assertEquals(Collections.singleton("s1"), registry.getSessionIds("joe"));
  }

  @Test
  public void handleMessageToClientConnectedUniqueUserName() {

    this.session.setPrincipal(new UniqueUser("joe"));

    UserSessionRegistry registry = new DefaultUserSessionRegistry();
    this.protocolHandler.setUserSessionRegistry(registry);

    StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
    Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
    this.protocolHandler.handleMessageToClient(this.session, message);

    assertEquals(1, this.session.getSentMessages().size());
    WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
    assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload());

    assertEquals(Collections.<String>emptySet(), registry.getSessionIds("joe"));
    assertEquals(Collections.singleton("s1"), registry.getSessionIds("Me myself and I"));
  }

  @Test
  public void handleMessageToClientConnectedWithHeartbeats() {

    SockJsSession sockJsSession = Mockito.mock(SockJsSession.class);

    StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
    headers.setHeartbeat(0,10);
    Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
    this.protocolHandler.handleMessageToClient(sockJsSession, message);

    verify(sockJsSession).disableHeartbeat();
  }

  @Test
  public void handleMessageToClientConnectAck() {

    StompHeaderAccessor connectHeaders = StompHeaderAccessor.create(StompCommand.CONNECT);
    connectHeaders.setHeartbeat(10000, 10000);
    connectHeaders.setAcceptVersion("1.0,1.1");
    Message<?> connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectHeaders.getMessageHeaders());

    SimpMessageHeaderAccessor connectAckHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
    connectAckHeaders.setHeader(SimpMessageHeaderAccessor.CONNECT_MESSAGE_HEADER, connectMessage);
    Message<byte[]> connectAckMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, connectAckHeaders.getMessageHeaders());

    this.protocolHandler.handleMessageToClient(this.session, connectAckMessage);

    verifyNoMoreInteractions(this.channel);

    // Check CONNECTED reply

    assertEquals(1, this.session.getSentMessages().size());
    TextMessage textMessage = (TextMessage) this.session.getSentMessages().get(0);

    List<Message<byte[]>> messages = new StompDecoder().decode(ByteBuffer.wrap(textMessage.getPayload().getBytes()));
    assertEquals(1, messages.size());
    StompHeaderAccessor replyHeaders = StompHeaderAccessor.wrap(messages.get(0));

    assertEquals(StompCommand.CONNECTED, replyHeaders.getCommand());
    assertEquals("1.1", replyHeaders.getVersion());
    assertArrayEquals(new long[] {0, 0}, replyHeaders.getHeartbeat());
    assertEquals("joe", replyHeaders.getNativeHeader("user-name").get(0));
  }

  @Test
  public void eventPublication() {

    TestPublisher publisher = new TestPublisher();

    UserSessionRegistry registry = new DefaultUserSessionRegistry();
    this.protocolHandler.setUserSessionRegistry(registry);
    this.protocolHandler.setApplicationEventPublisher(publisher);
    this.protocolHandler.afterSessionStarted(this.session, this.channel);

    StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
    Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
    TextMessage textMessage = new TextMessage(new StompEncoder().encode(message));
    this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);

    headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
    message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
    this.protocolHandler.handleMessageToClient(this.session, message);

    headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
    message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
    textMessage = new TextMessage(new StompEncoder().encode(message));
    this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);

    headers = StompHeaderAccessor.create(StompCommand.UNSUBSCRIBE);
    message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
    textMessage = new TextMessage(new StompEncoder().encode(message));
    this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);

    this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, this.channel);

    assertEquals("Unexpected events " + publisher.events, 5, publisher.events.size());
    assertEquals(SessionConnectEvent.class, publisher.events.get(0).getClass());
    assertEquals(SessionConnectedEvent.class, publisher.events.get(1).getClass());
    assertEquals(SessionSubscribeEvent.class, publisher.events.get(2).getClass());
    assertEquals(SessionUnsubscribeEvent.class, publisher.events.get(3).getClass());
    assertEquals(SessionDisconnectEvent.class, publisher.events.get(4).getClass());
  }

  @Test
  public void eventPublicationWithExceptions() {

    ApplicationEventPublisher publisher = mock(ApplicationEventPublisher.class);

    UserSessionRegistry registry = new DefaultUserSessionRegistry();
    this.protocolHandler.setUserSessionRegistry(registry);
    this.protocolHandler.setApplicationEventPublisher(publisher);
    this.protocolHandler.afterSessionStarted(this.session, this.channel);

    StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
    Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
    TextMessage textMessage = new TextMessage(new StompEncoder().encode(message));
    this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);

    verify(this.channel).send(this.messageCaptor.capture());
    Message<?> actual = this.messageCaptor.getValue();
    assertNotNull(actual);
    assertEquals(StompCommand.CONNECT, StompHeaderAccessor.wrap(actual).getCommand());
    reset(this.channel);

    headers = StompHeaderAccessor.create(StompCommand.CONNECTED);
    message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
    this.protocolHandler.handleMessageToClient(this.session, message);

    assertEquals(1, this.session.getSentMessages().size());
    textMessage = (TextMessage) this.session.getSentMessages().get(0);
    assertEquals("CONNECTED\n" + "user-name:joe\n" + "\n" + "\u0000", textMessage.getPayload());

    this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, this.channel);

    verify(this.channel).send(this.messageCaptor.capture());
    actual = this.messageCaptor.getValue();
    assertNotNull(actual);
    StompHeaderAccessor accessor = StompHeaderAccessor.wrap(actual);
    assertEquals(StompCommand.DISCONNECT, accessor.getCommand());
    assertEquals("s1", accessor.getSessionId());
    assertEquals("joe", accessor.getUser().getName());
  }

  @Test
  public void handleMessageToClientUserDestination() {

    StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.MESSAGE);
    headers.setMessageId("mess0");
    headers.setSubscriptionId("sub0");
    headers.setDestination("/queue/foo-user123");
    headers.setNativeHeader(StompHeaderAccessor.ORIGINAL_DESTINATION, "/user/queue/foo");
    Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
    this.protocolHandler.handleMessageToClient(this.session, message);

    assertEquals(1, this.session.getSentMessages().size());
    WebSocketMessage<?> textMessage = this.session.getSentMessages().get(0);
    assertTrue(((String) textMessage.getPayload()).contains("destination:/user/queue/foo\n"));
    assertFalse(((String) textMessage.getPayload()).contains(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION));
  }

  @Test
  public void handleMessageFromClient() {

    TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).headers(
        "login:guest", "passcode:guest", "accept-version:1.1,1.0", "heart-beat:10000,10000").build();

    this.protocolHandler.afterSessionStarted(this.session, this.channel);
    this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);

    verify(this.channel).send(this.messageCaptor.capture());
    Message<?> actual = this.messageCaptor.getValue();
    assertNotNull(actual);

    StompHeaderAccessor headers = StompHeaderAccessor.wrap(actual);
    assertEquals(StompCommand.CONNECT, headers.getCommand());
    assertEquals("s1", headers.getSessionId());
    assertNotNull(headers.getSessionAttributes());
    assertEquals("joe", headers.getUser().getName());
    assertEquals("guest", headers.getLogin());
    assertEquals("guest", headers.getPasscode());
    assertArrayEquals(new long[] {10000, 10000}, headers.getHeartbeat());
    assertEquals(new HashSet<>(Arrays.asList("1.1","1.0")), headers.getAcceptVersion());

    assertEquals(0, this.session.getSentMessages().size());
  }

  @Test
  public void handleMessageFromClientWithImmutableMessageInterceptor() {
    AtomicReference<Boolean> mutable = new AtomicReference<>();
    ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
    channel.addInterceptor(new ChannelInterceptorAdapter() {
      @Override
      public Message<?> preSend(Message<?> message, MessageChannel channel) {
        mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable());
        return message;
      }
    });
    channel.addInterceptor(new ImmutableMessageChannelInterceptor());

    StompSubProtocolHandler handler = new StompSubProtocolHandler();
    handler.afterSessionStarted(this.session, channel);

    TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build();
    handler.handleMessageFromClient(this.session, message, channel);
    assertNotNull(mutable.get());
    assertTrue(mutable.get());
  }

  @Test
  public void handleMessageFromClientWithoutImmutableMessageInterceptor() {
    AtomicReference<Boolean> mutable = new AtomicReference<>();
    ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
    channel.addInterceptor(new ChannelInterceptorAdapter() {
      @Override
      public Message<?> preSend(Message<?> message, MessageChannel channel) {
        mutable.set(MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class).isMutable());
        return message;
      }
    });

    StompSubProtocolHandler handler = new StompSubProtocolHandler();
    handler.afterSessionStarted(this.session, channel);

    TextMessage message = StompTextMessageBuilder.create(StompCommand.CONNECT).build();
    handler.handleMessageFromClient(this.session, message, channel);
    assertNotNull(mutable.get());
    assertFalse(mutable.get());
  }
  @Test
  public void handleMessageFromClientInvalidStompCommand() {

    TextMessage textMessage = new TextMessage("FOO\n\n\0");

    this.protocolHandler.afterSessionStarted(this.session, this.channel);
    this.protocolHandler.handleMessageFromClient(this.session, textMessage, this.channel);

    verifyZeroInteractions(this.channel);
    assertEquals(1, this.session.getSentMessages().size());
    TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
    assertTrue(actual.getPayload().startsWith("ERROR"));
  }

  @Test
  public void webSocketScope() {

    Runnable runnable = Mockito.mock(Runnable.class);
    SimpAttributes simpAttributes = new SimpAttributes(this.session.getId(), this.session.getAttributes());
    simpAttributes.setAttribute("name", "value");
    simpAttributes.registerDestructionCallback("name", runnable);

    MessageChannel testChannel = new MessageChannel() {
      @Override
      public boolean send(Message<?> message) {
        SimpAttributes simpAttributes = SimpAttributesContextHolder.currentAttributes();
        assertThat(simpAttributes.getAttribute("name"), is("value"));
        return true;
      }
      @Override
      public boolean send(Message<?> message, long timeout) {
        return false;
      }
    };

    this.protocolHandler.afterSessionStarted(this.session, this.channel);

    StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.CONNECT);
    Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, headers.getMessageHeaders());
    TextMessage textMessage = new TextMessage(new StompEncoder().encode(message));

    this.protocolHandler.handleMessageFromClient(this.session, textMessage, testChannel);
    assertEquals(Collections.emptyList(), session.getSentMessages());

    this.protocolHandler.afterSessionEnded(this.session, CloseStatus.BAD_DATA, testChannel);
    assertEquals(Collections.emptyList(), session.getSentMessages());
    verify(runnable, times(1)).run();
  }


  private static class UniqueUser extends TestPrincipal implements DestinationUserNameProvider {

    private UniqueUser(String name) {
      super(name);
    }

    @Override
    public String getDestinationUserName() {
      return "Me myself and I";
    }
  }

  private static class TestPublisher implements ApplicationEventPublisher {

    private final List<ApplicationEvent> events = new ArrayList<ApplicationEvent>();

    @Override
    public void publishEvent(ApplicationEvent event) {
      events.add(event);
    }
  }

}
TOP

Related Classes of org.springframework.web.socket.messaging.StompSubProtocolHandlerTests

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.