Package org.eclipse.jetty.websocket.common.test

Source Code of org.eclipse.jetty.websocket.common.test.BlockheadServer$ServerConnection

//
//  ========================================================================
//  Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd.
//  ------------------------------------------------------------------------
//  All rights reserved. This program and the accompanying materials
//  are made available under the terms of the Eclipse Public License v1.0
//  and Apache License v2.0 which accompanies this distribution.
//
//      The Eclipse Public License is available at
//      http://www.eclipse.org/legal/epl-v10.html
//
//      The Apache License v2.0 is available at
//      http://www.opensource.org/licenses/apache2.0.php
//
//  You may elect to redistribute this code under either of these licenses.
//  ========================================================================
//

package org.eclipse.jetty.websocket.common.test;

import static org.hamcrest.Matchers.*;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.eclipse.jetty.io.ByteBufferPool;
import org.eclipse.jetty.io.MappedByteBufferPool;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.IO;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.websocket.api.BatchMode;
import org.eclipse.jetty.websocket.api.WebSocketPolicy;
import org.eclipse.jetty.websocket.api.WriteCallback;
import org.eclipse.jetty.websocket.api.extensions.ExtensionConfig;
import org.eclipse.jetty.websocket.api.extensions.Frame;
import org.eclipse.jetty.websocket.api.extensions.Frame.Type;
import org.eclipse.jetty.websocket.api.extensions.IncomingFrames;
import org.eclipse.jetty.websocket.api.extensions.OutgoingFrames;
import org.eclipse.jetty.websocket.common.AcceptHash;
import org.eclipse.jetty.websocket.common.CloseInfo;
import org.eclipse.jetty.websocket.common.Generator;
import org.eclipse.jetty.websocket.common.OpCode;
import org.eclipse.jetty.websocket.common.Parser;
import org.eclipse.jetty.websocket.common.WebSocketFrame;
import org.eclipse.jetty.websocket.common.extensions.ExtensionStack;
import org.eclipse.jetty.websocket.common.extensions.WebSocketExtensionFactory;
import org.eclipse.jetty.websocket.common.frames.CloseFrame;
import org.junit.Assert;

/**
* An overly simplistic websocket server used during testing.
* <p>
* This is not meant to be performant or accurate. In fact, having the server misbehave is a useful trait during testing.
*/
public class BlockheadServer
{
    public static class ServerConnection implements IncomingFrames, OutgoingFrames, Runnable
    {
        /** Size passed to the buffer pool constructor and used when acquiring read buffers. */
        private final int BUFFER_SIZE = 8192;
        /** Client socket handed to the constructor; source of the input/output streams. */
        private final Socket socket;
        /** Pool shared by the parser, generator and extension registry, and used for read buffers. */
        private final ByteBufferPool bufferPool;
        /** Server policy created in the constructor with max text/binary message sizes of 100000. */
        private final WebSocketPolicy policy;
        /** Capture that receives a copy of every frame delivered to {@link #incomingFrame(Frame)}. */
        private final IncomingFramesCapture incomingFrames;
        /** Parser fed with the raw bytes read from the client. */
        private final Parser parser;
        /** Generator used to produce the header bytes of outgoing frames. */
        private final Generator generator;
        /** Running count of frames delivered to {@link #incomingFrame(Frame)}. */
        private final AtomicInteger parseCount;
        /** Extension factory used to build the extension stack during {@link #upgrade()}. */
        private final WebSocketExtensionFactory extensionRegistry;
        /** While true, the echo loop keeps running and incoming data/continuation frames are written back. */
        private final AtomicBoolean echoing = new AtomicBoolean(false);
        /** Thread created by {@link #startEcho()}; null until then. */
        private Thread echoThread;

        /** Set to true to disable timeouts (for debugging reasons) */
        private boolean debug = false;
        /** Socket output stream, fetched lazily by {@link #getOutputStream()}; null until first requested. */
        private OutputStream out;
        /** Socket input stream, fetched lazily by {@link #getInputStream()}; null until first requested. */
        private InputStream in;

        /** Extra headers appended, unvalidated, to the upgrade response. */
        private Map<String, String> extraResponseHeaders = new HashMap<>();
        /** Destination for frames passed to {@link #write(Frame)}; defaults to this connection. */
        private OutgoingFrames outgoing = this;

        public ServerConnection(Socket socket)
        {
            this.socket = socket;
            this.incomingFrames = new IncomingFramesCapture();
            this.policy = WebSocketPolicy.newServerPolicy();
            this.policy.setMaxBinaryMessageSize(100000);
            this.policy.setMaxTextMessageSize(100000);
            // This is a blockhead server connection, no point tracking leaks on this object.
            this.bufferPool = new MappedByteBufferPool(BUFFER_SIZE);
            this.parser = new Parser(policy,bufferPool);
            this.parseCount = new AtomicInteger(0);
            this.generator = new Generator(policy,bufferPool,false);
            this.extensionRegistry = new WebSocketExtensionFactory(policy,bufferPool);
        }

        /**
         * Add an extra header for the upgrade response (from the server). No extra work is done to ensure the key and value are sane for http.
         */
        public void addResponseHeader(String rawkey, String rawvalue)
        {
            extraResponseHeaders.put(rawkey,rawvalue);
        }

        public void close() throws IOException
        {
            write(new CloseFrame());
            flush();
        }

        public void close(int statusCode) throws IOException
        {
            CloseInfo close = new CloseInfo(statusCode);
            write(close.asFrame());
            flush();
        }

        public void disconnect()
        {
            LOG.debug("disconnect");
            IO.close(in);
            IO.close(out);
            if (socket != null)
            {
                try
                {
                    socket.close();
                }
                catch (IOException ignore)
                {
                    /* ignore */
                }
            }
        }

        public void echoMessage(int expectedFrames, int timeoutDuration, TimeUnit timeoutUnit) throws IOException, TimeoutException
        {
            LOG.debug("Echo Frames [expecting {}]",expectedFrames);
            IncomingFramesCapture cap = readFrames(expectedFrames,timeoutDuration,timeoutUnit);
            // now echo them back.
            for (Frame frame : cap.getFrames())
            {
                write(WebSocketFrame.copy(frame).setMasked(false));
            }
        }

        public void flush() throws IOException
        {
            getOutputStream().flush();
        }

        public ByteBufferPool getBufferPool()
        {
            return bufferPool;
        }

        public IncomingFramesCapture getIncomingFrames()
        {
            return incomingFrames;
        }

        public InputStream getInputStream() throws IOException
        {
            if (in == null)
            {
                in = socket.getInputStream();
            }
            return in;
        }

        private OutputStream getOutputStream() throws IOException
        {
            if (out == null)
            {
                out = socket.getOutputStream();
            }
            return out;
        }

        public Parser getParser()
        {
            return parser;
        }

        public WebSocketPolicy getPolicy()
        {
            return policy;
        }

        @Override
        public void incomingError(Throwable e)
        {
            incomingFrames.incomingError(e);
        }

        @Override
        public void incomingFrame(Frame frame)
        {
            LOG.debug("incoming({})",frame);
            int count = parseCount.incrementAndGet();
            if ((count % 10) == 0)
            {
                LOG.info("Server parsed {} frames",count);
            }
            incomingFrames.incomingFrame(WebSocketFrame.copy(frame));

            if (frame.getOpCode() == OpCode.CLOSE)
            {
                CloseInfo close = new CloseInfo(frame);
                LOG.debug("Close frame: {}",close);
            }

            Type type = frame.getType();
            if (echoing.get() && (type.isData() || type.isContinuation()))
            {
                try
                {
                    write(WebSocketFrame.copy(frame).setMasked(false));
                }
                catch (IOException e)
                {
                    LOG.warn(e);
                }
            }
        }

        @Override
        public void outgoingFrame(Frame frame, WriteCallback callback, BatchMode batchMode)
        {
            ByteBuffer headerBuf = generator.generateHeaderBytes(frame);
            if (LOG.isDebugEnabled())
            {
                LOG.debug("writing out: {}",BufferUtil.toDetailString(headerBuf));
            }

            try
            {
                BufferUtil.writeTo(headerBuf,out);
                if (frame.hasPayload())
                    BufferUtil.writeTo(frame.getPayload(),out);
                out.flush();
                if (callback != null)
                {
                    callback.writeSuccess();
                }

                if (frame.getOpCode() == OpCode.CLOSE)
                {
                    disconnect();
                }
            }
            catch (Throwable t)
            {
                if (callback != null)
                {
                    callback.writeFailed(t);
                }
            }
        }

        public List<ExtensionConfig> parseExtensions(List<String> requestLines)
        {
            List<ExtensionConfig> extensionConfigs = new ArrayList<>();
           
            List<String> hits = regexFind(requestLines, "^Sec-WebSocket-Extensions: (.*)$");

            for (String econf : hits)
            {
                // found extensions
                ExtensionConfig config = ExtensionConfig.parse(econf);
                extensionConfigs.add(config);
            }

            return extensionConfigs;
        }

        public String parseWebSocketKey(List<String> requestLines)
        {
            List<String> hits = regexFind(requestLines,"^Sec-WebSocket-Key: (.*)$");
            if (hits.size() <= 0)
            {
                return null;
            }
           
            Assert.assertThat("Number of Sec-WebSocket-Key headers", hits.size(), is(1));
           
            String key = hits.get(0);
            return key;
        }

        public int read(ByteBuffer buf) throws IOException
        {
            int len = 0;
            while ((in.available() > 0) && (buf.remaining() > 0))
            {
                buf.put((byte)in.read());
                len++;
            }
            return len;
        }

        public IncomingFramesCapture readFrames(int expectedCount, int timeoutDuration, TimeUnit timeoutUnit) throws IOException, TimeoutException
        {
            LOG.debug("Read: waiting for {} frame(s) from client",expectedCount);
            int startCount = incomingFrames.size();

            ByteBuffer buf = bufferPool.acquire(BUFFER_SIZE,false);
            BufferUtil.clearToFill(buf);
            try
            {
                long msDur = TimeUnit.MILLISECONDS.convert(timeoutDuration,timeoutUnit);
                long now = System.currentTimeMillis();
                long expireOn = now + msDur;
                LOG.debug("Now: {} - expireOn: {} ({} ms)",now,expireOn,msDur);

                int len = 0;
                while (incomingFrames.size() < (startCount + expectedCount))
                {
                    BufferUtil.clearToFill(buf);
                    len = read(buf);
                    if (len > 0)
                    {
                        LOG.debug("Read {} bytes",len);
                        BufferUtil.flipToFlush(buf,0);
                        parser.parse(buf);
                    }
                    try
                    {
                        TimeUnit.MILLISECONDS.sleep(20);
                    }
                    catch (InterruptedException gnore)
                    {
                        /* ignore */
                    }
                    if (!debug && (System.currentTimeMillis() > expireOn))
                    {
                        incomingFrames.dump();
                        throw new TimeoutException(String.format("Timeout reading all %d expected frames. (managed to only read %d frame(s))",expectedCount,
                                incomingFrames.size()));
                    }
                }
            }
            finally
            {
                bufferPool.release(buf);
            }

            return incomingFrames;
        }

        public String readRequest() throws IOException
        {
            LOG.debug("Reading client request");
            StringBuilder request = new StringBuilder();
            BufferedReader in = new BufferedReader(new InputStreamReader(getInputStream()));
            for (String line = in.readLine(); line != null; line = in.readLine())
            {
                if (line.length() == 0)
                {
                    break;
                }
                request.append(line).append("\r\n");
                LOG.debug("read line: {}",line);
            }

            LOG.debug("Client Request:{}{}","\n",request);
            return request.toString();
        }

        public List<String> readRequestLines() throws IOException
        {
            LOG.debug("Reading client request header");
            List<String> lines = new ArrayList<>();

            BufferedReader in = new BufferedReader(new InputStreamReader(getInputStream()));
            for (String line = in.readLine(); line != null; line = in.readLine())
            {
                if (line.length() == 0)
                {
                    break;
                }
                lines.add(line);
            }

            return lines;
        }

        public List<String> regexFind(List<String> lines, String pattern)
        {
            List<String> hits = new ArrayList<>();

            Pattern patKey = Pattern.compile(pattern,Pattern.CASE_INSENSITIVE);

            Matcher mat;
            for (String line : lines)
            {
                mat = patKey.matcher(line);
                if (mat.matches())
                {
                    if (mat.groupCount() >= 1)
                    {
                        hits.add(mat.group(1));
                    }
                    else
                    {
                        hits.add(mat.group(0));
                    }
                }
            }

            return hits;
        }

        public void respond(String rawstr) throws IOException
        {
            LOG.debug("respond(){}{}","\n",rawstr);
            getOutputStream().write(rawstr.getBytes());
            flush();
        }

        @Override
        public void run()
        {
            LOG.debug("Entering echo thread");

            ByteBuffer buf = bufferPool.acquire(BUFFER_SIZE,false);
            BufferUtil.clearToFill(buf);
            long readBytes = 0;
            try
            {
                while (echoing.get())
                {
                    BufferUtil.clearToFill(buf);
                    long len = read(buf);
                    if (len > 0)
                    {
                        readBytes += len;
                        LOG.debug("Read {} bytes",len);
                        BufferUtil.flipToFlush(buf,0);
                        parser.parse(buf);
                    }

                    try
                    {
                        TimeUnit.MILLISECONDS.sleep(20);
                    }
                    catch (InterruptedException gnore)
                    {
                        /* ignore */
                    }
                }
            }
            catch (IOException e)
            {
                LOG.debug("Exception during echo loop",e);
            }
            finally
            {
                LOG.debug("Read {} bytes",readBytes);
                bufferPool.release(buf);
            }
        }

        public void setSoTimeout(int ms) throws SocketException
        {
            socket.setSoTimeout(ms);
        }

        public void startEcho()
        {
            if (echoThread != null)
            {
                throw new IllegalStateException("Echo thread already declared!");
            }
            echoThread = new Thread(this,"BlockheadServer/Echo");
            echoing.set(true);
            echoThread.start();
        }

        public void stopEcho()
        {
            echoing.set(false);
        }

        public List<String> upgrade() throws IOException
        {
            List<String> requestLines = readRequestLines();
            List<ExtensionConfig> extensionConfigs = parseExtensions(requestLines);
            String key = parseWebSocketKey(requestLines);

            LOG.debug("Client Request Extensions: {}",extensionConfigs);
            LOG.debug("Client Request Key: {}",key);

            Assert.assertThat("Request: Sec-WebSocket-Key",key,notNullValue());

            // collect extensions configured in response header
            ExtensionStack extensionStack = new ExtensionStack(extensionRegistry);
            extensionStack.negotiate(extensionConfigs);

            // Start with default routing
            extensionStack.setNextIncoming(this);
            extensionStack.setNextOutgoing(this);

            // Configure Parser / Generator
            extensionStack.configure(parser);
            extensionStack.configure(generator);

            // Start Stack
            try
            {
                extensionStack.start();
            }
            catch (Exception e)
            {
                throw new IOException("Unable to start Extension Stack");
            }

            // Configure Parser
            parser.setIncomingFramesHandler(extensionStack);

            // Setup Response
            StringBuilder resp = new StringBuilder();
            resp.append("HTTP/1.1 101 Upgrade\r\n");
            resp.append("Connection: upgrade\r\n");
            resp.append("Sec-WebSocket-Accept: ");
            resp.append(AcceptHash.hashKey(key)).append("\r\n");
            if (extensionStack.hasNegotiatedExtensions())
            {
                // Respond to used extensions
                resp.append("Sec-WebSocket-Extensions: ");
                boolean delim = false;
                for (ExtensionConfig ext : extensionStack.getNegotiatedExtensions())
                {
                    if (delim)
                    {
                        resp.append(", ");
                    }
                    resp.append(ext.getParameterizedName());
                    delim = true;
                }
                resp.append("\r\n");
            }
            if (extraResponseHeaders.size() > 0)
            {
                for (Map.Entry<String, String> xheader : extraResponseHeaders.entrySet())
                {
                    resp.append(xheader.getKey());
                    resp.append(": ");
                    resp.append(xheader.getValue());
                    resp.append("\r\n");
                }
            }
            resp.append("\r\n");

            // Write Response
            LOG.debug("Response: {}",resp.toString());
            write(resp.toString().getBytes());
            return requestLines;
        }

        private void write(byte[] bytes) throws IOException
        {
            getOutputStream().write(bytes);
        }

        public void write(byte[] buf, int offset, int length) throws IOException
        {
            getOutputStream().write(buf,offset,length);
        }

        public void write(Frame frame) throws IOException
        {
            LOG.debug("write(Frame->{}) to {}",frame,outgoing);
            outgoing.outgoingFrame(frame,null,BatchMode.OFF);
        }

        public void write(int b) throws IOException
        {
            getOutputStream().write(b);
        }

        public void write(ByteBuffer buf) throws IOException
        {
            byte arr[] = BufferUtil.toArray(buf);
            if ((arr != null) && (arr.length > 0))
            {
                getOutputStream().write(arr);
            }
        }
    }

    /** Logger shared by the server and its connections. */
    private static final Logger LOG = Log.getLogger(BlockheadServer.class);
    /** Listening socket, created and bound by {@link #start()}; null before that. */
    private ServerSocket serverSocket;
    /** The ws:// URI of the bound server socket, set by {@link #start()}. */
    private URI wsUri;

    public ServerConnection accept() throws IOException
    {
        LOG.debug(".accept()");
        assertIsStarted();
        Socket socket = serverSocket.accept();
        return new ServerConnection(socket);
    }

    private void assertIsStarted()
    {
        Assert.assertThat("ServerSocket",serverSocket,notNullValue());
        Assert.assertThat("ServerSocket.isBound",serverSocket.isBound(),is(true));
        Assert.assertThat("ServerSocket.isClosed",serverSocket.isClosed(),is(false));

        Assert.assertThat("WsUri",wsUri,notNullValue());
    }

    public URI getWsUri()
    {
        return wsUri;
    }

    public void respondToClient(Socket connection, String serverResponse) throws IOException
    {
        InputStream in = null;
        InputStreamReader isr = null;
        BufferedReader buf = null;
        OutputStream out = null;
        try
        {
            in = connection.getInputStream();
            isr = new InputStreamReader(in);
            buf = new BufferedReader(isr);
            String line;
            while ((line = buf.readLine()) != null)
            {
                // System.err.println(line);
                if (line.length() == 0)
                {
                    // Got the "\r\n" line.
                    break;
                }
            }

            // System.out.println("[Server-Out] " + serverResponse);
            out = connection.getOutputStream();
            out.write(serverResponse.getBytes());
            out.flush();
        }
        finally
        {
            IO.close(buf);
            IO.close(isr);
            IO.close(in);
            IO.close(out);
        }
    }

    public void start() throws IOException
    {
        InetAddress addr = InetAddress.getByName("localhost");
        serverSocket = new ServerSocket();
        InetSocketAddress endpoint = new InetSocketAddress(addr,0);
        serverSocket.bind(endpoint,1);
        int port = serverSocket.getLocalPort();
        String uri = String.format("ws://%s:%d/",addr.getHostAddress(),port);
        wsUri = URI.create(uri);
        LOG.debug("Server Started on {} -> {}",endpoint,wsUri);
    }

    public void stop()
    {
        LOG.debug("Stopping Server");
        try
        {
            serverSocket.close();
        }
        catch (IOException ignore)
        {
            /* ignore */
        }
    }
}
TOP

Related Classes of org.eclipse.jetty.websocket.common.test.BlockheadServer$ServerConnection

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.