Package io.undertow.websockets.core

Source Code of io.undertow.websockets.core.StreamSourceFrameChannel$Bounds

/*
* JBoss, Home of Professional Open Source.
* Copyright 2014 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
*  Unless required by applicable law or agreed to in writing, software
*  distributed under the License is distributed on an "AS IS" BASIS,
*  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*  See the License for the specific language governing permissions and
*  limitations under the License.
*/

package io.undertow.websockets.core;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.List;

import io.undertow.websockets.core.function.ChannelFunction;
import io.undertow.websockets.core.function.ChannelFunctionFileChannel;
import io.undertow.websockets.core.protocol.version07.Masker;
import io.undertow.websockets.core.protocol.version07.UTF8Checker;
import io.undertow.websockets.extensions.ExtensionByteBuffer;
import io.undertow.websockets.extensions.ExtensionFunction;
import org.xnio.Pooled;
import org.xnio.channels.StreamSinkChannel;

import io.undertow.server.protocol.framed.AbstractFramedStreamSourceChannel;
import io.undertow.server.protocol.framed.FrameHeaderData;

/**
* Base class for processes Frame bases StreamSourceChannels.
*
* @author <a href="mailto:nmaurer@redhat.com">Norman Maurer</a>
*/
public abstract class StreamSourceFrameChannel extends AbstractFramedStreamSourceChannel<WebSocketChannel, StreamSourceFrameChannel, StreamSinkFrameChannel> {

    protected final WebSocketFrameType type;

    private boolean finalFragment;
    private final int rsv;
    private final ChannelFunction[] functions;
    private final List<ExtensionFunction> extensions;
    private ExtensionByteBuffer extensionResult;
    private Masker masker;
    private UTF8Checker checker;

    protected StreamSourceFrameChannel(WebSocketChannel wsChannel, WebSocketFrameType type, Pooled<ByteBuffer> pooled, long frameLength) {
        this(wsChannel, type, 0, true, pooled, frameLength);
    }

    protected StreamSourceFrameChannel(WebSocketChannel wsChannel, WebSocketFrameType type, int rsv, boolean finalFragment, Pooled<ByteBuffer> pooled, long frameLength, ChannelFunction... functions) {
        super(wsChannel, pooled, frameLength);
        this.type = type;
        this.finalFragment = finalFragment;
        this.rsv = rsv;

        this.functions = functions;
        masker = null;
        checker = null;
        for (ChannelFunction func : functions) {
            if (func instanceof Masker) {
                masker = (Masker) func;
            }
            if (func instanceof UTF8Checker) {
                checker = (UTF8Checker) func;
            }
        }
        if (wsChannel.areExtensionsSupported() && wsChannel.getExtensions() != null && !wsChannel.getExtensions().isEmpty()) {
            extensions = wsChannel.getExtensions();
        } else {
            extensions = null;
        }
        this.extensionResult = null;
    }



    /**
     * Return the {@link WebSocketFrameType} or {@code null} if its not known at the calling time.
     */
    public WebSocketFrameType getType() {
        return type;
    }

    /**
     * Flag to indicate if this frame is the final fragment in a message. The first fragment (frame) may also be the
     * final fragment.
     */
    public boolean isFinalFragment() {
        return finalFragment;
    }

    /**
     * Return the rsv which is used for extensions.
     */
    public int getRsv() {
        return rsv;
    }

    int getWebSocketFrameCount() {
        return getReadFrameCount();
    }

    @Override
    protected WebSocketChannel getFramedChannel() {
        return (WebSocketChannel) super.getFramedChannel();
    }

    public WebSocketChannel getWebSocketChannel() {
        return getFramedChannel();
    }

    public void finalFrame() {
        this.lastFrame();
        this.finalFragment = true;
    }

    @Override
    protected void handleHeaderData(FrameHeaderData headerData) {
        super.handleHeaderData(headerData);
        if (((WebSocketFrame) headerData).isFinalFragment()) {
            finalFrame();
        }
        if(functions != null) {
            for(ChannelFunction func : functions) {
                func.newFrame(headerData);
            }
        }
    }


    @Override
    public final long transferTo(long position, long count, FileChannel target) throws IOException {
        long r;
        if (functions != null && functions.length > 0) {
            r = super.transferTo(position, count, new ChannelFunctionFileChannel(target, functions));
        } else {
            r = super.transferTo(position, count, target);
        }
        return r;
    }

    @Override
    public final long transferTo(long count, ByteBuffer throughBuffer, StreamSinkChannel target) throws IOException {
        // use this because of XNIO bug
        // See https://issues.jboss.org/browse/XNIO-185
        return WebSocketUtils.transfer(this, count, throughBuffer, target);
    }

    @Override
    public int read(ByteBuffer dst) throws IOException {
        int r;
        int position = dst.position();
        if (extensionResult == null) {
            r = super.read(dst);
            if (r > 0) {
                masker(dst, position, r);
            }
            if (getRsv() > 0) {
                extensionResult = applyExtensions(dst, position, r);
            }
            if (r > 0) {
                boolean complete = isComplete() && extensionResult == null;
                checker(dst, position, dst.position() - position, complete);
            }
            return r;
        } else {
            r = extensionResult.flushExtra(dst);
            if (!extensionResult.hasExtra()) {
                extensionResult.free();
                extensionResult = null;
            }
            if (r > 0) {
                checker(dst, position, dst.position() - position, isComplete() && extensionResult == null);
            }
            return r;
        }
    }

    @Override
    public final long read(ByteBuffer[] dsts) throws IOException {
        return read(dsts, 0, dsts.length);
    }

    @Override
    public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
        Bounds[] old = new Bounds[length];
        for (int i = offset; i < length; i++) {
            ByteBuffer dst = dsts[i];
            old[i - offset] = new Bounds(dst.position(), dst.limit());
        }
        long b = super.read(dsts, offset, length);
        if (b > 0) {
            for (int i = offset; i < length; i++) {
                ByteBuffer dst = dsts[i];
                int oldPos = old[i - offset].position;
                afterRead(dst, oldPos, dst.position() - oldPos);
            }
        }
        return b;
    }

    /**
     * Called after data was read into the {@link ByteBuffer}
     *
     * @param buffer   the {@link ByteBuffer} into which the data was read
     * @param position the position it was written to
     * @param length   the number of bytes there were written
     * @throws IOException thrown if an error occurs
     */
    protected void afterRead(ByteBuffer buffer, int position, int length) throws IOException {
        try {
            for (ChannelFunction func : functions) {
                func.afterRead(buffer, position, length);
            }
            if (isComplete()) {
                try {
                    for (ChannelFunction func : functions) {
                        func.complete();
                    }
                } catch (UnsupportedEncodingException e) {
                    getFramedChannel().markReadsBroken(e);
                    throw e;
                }
            }
        } catch (UnsupportedEncodingException e) {
            getFramedChannel().markReadsBroken(e);
            throw e;
        }
    }

    protected void masker(ByteBuffer buffer, int position, int length) throws IOException {
        if (masker == null) {
            return;
        }
        masker.afterRead(buffer, position, length);
    }

    protected void checker(ByteBuffer buffer, int position, int length, boolean complete) throws IOException {
        if (checker == null) {
            return;
        }
        try {
            checker.afterRead(buffer, position, length);
            if (complete) {
                try {
                    checker.complete();
                } catch (UnsupportedEncodingException e) {
                    getFramedChannel().markReadsBroken(e);
                    throw e;
                }
            }
        } catch (UnsupportedEncodingException e) {
            getFramedChannel().markReadsBroken(e);
            throw e;
        }
    }

    /**
     * Process Extensions chain after a read operation.
     * <p>
     * An extension can modify original content beyond {@code ByteBuffer} capacity,then original buffer is wrapped with
     * {@link ExtensionByteBuffer} class. {@code ExtensionByteBuffer} stores extra buffer to manage overflow of original
     * {@code ByteBuffer} .
     *
     * @param buffer    the buffer to operate on
     * @param position  the index in the buffer to start from
     * @param length    the number of bytes to operate on
     * @return          a {@link ExtensionByteBuffer} instance as a wrapper of original buffer with extra buffers;
     *                  {@code null} if no extra buffers needed
     * @throws IOException
     */
    protected ExtensionByteBuffer applyExtensions(final ByteBuffer buffer, final int position, final int length) throws IOException {
        ExtensionByteBuffer extBuffer = new ExtensionByteBuffer(getWebSocketChannel(), buffer, position);
        int newLength = length;
        if (extensions != null) {
            for (ExtensionFunction ext : extensions) {
                ext.afterRead(this, extBuffer, position, newLength);
                if (extBuffer.getFilled() == 0) {
                    buffer.position(position);
                    newLength = 0;
                } else if (extBuffer.getFilled() != newLength) {
                    newLength = extBuffer.getFilled();
                }
            }
        }
        if (!extBuffer.hasExtra()) {
            return null;
        }
        return extBuffer;
    }

    private static class Bounds {
        final int position;
        final int limit;

        Bounds(int position, int limit) {
            this.position = position;
            this.limit = limit;
        }
    }
}
TOP

Related Classes of io.undertow.websockets.core.StreamSourceFrameChannel$Bounds

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.