Package net.sf.cindy.filter

Source Code of net.sf.cindy.filter.SSLFilter$QueuedPacket

/*
* Copyright 2004-2006 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.sf.cindy.filter;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;

import net.sf.cindy.Buffer;
import net.sf.cindy.Packet;
import net.sf.cindy.Session;
import net.sf.cindy.SessionFilterAdapter;
import net.sf.cindy.SessionFilterChain;
import net.sf.cindy.buffer.BufferFactory;
import net.sf.cindy.packet.DefaultPacket;
import net.sf.cindy.session.SessionException;
import edu.emory.mathcs.backport.java.util.LinkedList;
import edu.emory.mathcs.backport.java.util.Queue;

/**
* SSL/TLS filter; requires Java 5.0.
*
* @author <a href="chenrui@gmail.com">Roger Chen</a>
* @version $id$
*/
public class SSLFilter extends SessionFilterAdapter {

    /**
     * SSL context supplied at construction time; used in
     * {@link #sessionStarted(SessionFilterChain)} to create the SSL engine.
     */
    private final SSLContext sslc;

    /**
     * Holds outgoing packets (as {@link QueuedPacket} entries) that were
     * submitted while the handshake was not yet completed.
     */
    private final Queue tempQueue = new LinkedList();

    private static class QueuedPacket {

        private final SessionFilterChain chain;
        private final Packet packet;

        public QueuedPacket(SessionFilterChain chain, Packet packet) {
            this.chain = chain;
            this.packet = packet;
        }

    }

    /** Whether the engine is switched to client mode in sessionStarted. */
    private boolean clientMode;
    /** Whether the engine is told to require client authentication. */
    private boolean needClientAuth;

    /** Engine of the current session; created in sessionStarted, reset to null in sessionClosed. */
    private SSLEngine engine;
    /** Set to true by handshakeCompleted(); reset to false in sessionClosed. */
    private boolean handshakeCompleted = false;
    /** Application buffer size reported by the engine's SSL session. */
    private int appBufferSize;
    /** Packet (network) buffer size reported by the engine's SSL session. */
    private int netBufferSize;

    /**
     * Creates a filter that builds its SSL engine from the given context.
     *
     * @param context
     *            the SSL context used to create the engine when a session
     *            starts
     */
    public SSLFilter(SSLContext context) {
        this.sslc = context;
    }

    /**
     * Returns the SSL context this filter was constructed with.
     *
     * @return the SSL context passed to the constructor
     */
    public SSLContext getSSLContext() {
        return sslc;
    }

    /**
     * Returns whether the engine will be used in client mode.
     *
     * @return the current client-mode setting
     */
    public boolean isClientMode() {
        return clientMode;
    }

    /**
     * Sets whether the engine is used in client mode. The value is read when
     * a session starts, where it is applied to the newly created engine.
     *
     * @param clientMode
     *            true to handshake as client
     */
    public void setClientMode(boolean clientMode) {
        this.clientMode = clientMode;
    }

    /**
     * Returns whether the engine will require client authentication.
     *
     * @return the current need-client-auth setting
     */
    public boolean isNeedClientAuth() {
        return needClientAuth;
    }

    /**
     * Sets whether the engine requires client authentication. The value is
     * read when a session starts, where it is applied to the newly created
     * engine.
     *
     * @param needClientAuth
     *            true to require client authentication
     */
    public void setNeedClientAuth(boolean needClientAuth) {
        this.needClientAuth = needClientAuth;
    }

    /**
     * Handshake packet: a marker packet whose content is a buffer allocated
     * with size 0. {@link SSLFilter#packetSend(SessionFilterChain, Packet)}
     * never queues packets of this type, so they are encoded even while the
     * handshake is still in progress.
     *
     * @author <a href="chenrui@gmail.com">Roger Chen</a>
     * @version $id$
     */
    private static class HandshakePacket extends DefaultPacket {

        public HandshakePacket() {
            super(BufferFactory.allocate(0));
        }

    }

    public void sessionStarted(SessionFilterChain filterChain) throws Exception {
        // init engine
        Session session = filterChain.getSession();
        SocketAddress address = session.getRemoteAddress();
        if (clientMode && (address instanceof InetSocketAddress)) {
            InetSocketAddress netAddress = (InetSocketAddress) address;
            engine = sslc.createSSLEngine(netAddress.getAddress()
                    .getHostAddress(), netAddress.getPort());
        } else {
            engine = sslc.createSSLEngine();
        }
        engine.setUseClientMode(clientMode);
        engine.setNeedClientAuth(needClientAuth);
        initSSLEngine(engine);

        appBufferSize = engine.getSession().getApplicationBufferSize();
        netBufferSize = engine.getSession().getPacketBufferSize();

        // handshake
        engine.beginHandshake();

        super.sessionStarted(filterChain);
        if (clientMode)
            session.flush(new HandshakePacket());
    }

    public void sessionClosed(SessionFilterChain filterChain) throws Exception {
        try {
            if (engine != null) {
                engine.closeOutbound();
                if (!engine.isInboundDone()) {
                    Packet encodedPacket = encode(filterChain.getSession(),
                            new HandshakePacket());
                    if (encodedPacket != null)
                        encodedPacket.getContent().release();
                }
                engine.closeInbound();
            }
        } catch (SSLException ssle) {
            throw new SessionException(ssle);
        } finally {
            handshakeCompleted = false;
            engine = null;
            QueuedPacket packet = null;
            while ((packet = (QueuedPacket) tempQueue.poll()) != null) {
                packet.chain.packetSend(packet.packet);
            }
            super.sessionClosed(filterChain);
        }
    }

    /**
     * Template method: called from sessionStarted after the engine has been
     * created and its client mode / client-auth settings applied, and before
     * the handshake begins. The default implementation does nothing.
     *
     * @param engine
     *            the newly created engine
     */
    protected void initSSLEngine(SSLEngine engine) {
    }

    private synchronized void handshakeCompleted() throws Exception {
        handshakeCompleted = true;
        QueuedPacket packet = null;
        while ((packet = (QueuedPacket) tempQueue.poll()) != null) {
            this.packetSend(packet.chain, packet.packet);
        }
    }

    private void runTask() {
        Runnable runnable = null;
        while ((runnable = engine.getDelegatedTask()) != null)
            runnable.run();
    }

    private Packet decode(Session session, Packet packet) throws Exception {
        Buffer src = packet.getContent();
        ByteBuffer srcBuffer = src.asByteBuffer();

        int size = appBufferSize
                * (int) Math.ceil((double) (src.remaining() + 1)
                        / netBufferSize);
        Buffer dest = BufferFactory.allocate(size);
        ByteBuffer destBuffer = dest.asByteBuffer();

        boolean hasAppData = false;

        // decode
        while (true) {
            SSLEngineResult result = engine.unwrap(srcBuffer, destBuffer);
            Status status = result.getStatus();
            if (status == Status.OK) {
                HandshakeStatus handshakeStatus = result.getHandshakeStatus();
                if (handshakeStatus == HandshakeStatus.NOT_HANDSHAKING) {
                    hasAppData = true;
                    if (!srcBuffer.hasRemaining()) // decode completed
                        break;
                } else if (handshakeStatus == HandshakeStatus.NEED_TASK) {
                    runTask();
                } else if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) {
                    // unwrap again
                } else if (handshakeStatus == HandshakeStatus.NEED_WRAP) {
                    // need send some handshake packet
                    session.flush(new HandshakePacket());
                    break;
                } else if (handshakeStatus == HandshakeStatus.FINISHED) {
                    handshakeCompleted();
                }
            } else if (status == Status.BUFFER_UNDERFLOW) {
                // need more content, wait next time
                break;
            } else if (status == Status.BUFFER_OVERFLOW) {
                // increase dest capacity
                destBuffer.flip();
                Buffer newDest = BufferFactory.allocate(
                        appBufferSize + destBuffer.remaining()).put(destBuffer);
                dest.release();
                dest = newDest;
                destBuffer = dest.asByteBuffer();
            } else if (status == Status.CLOSED) {
                // send close message if possible
                if (result.getHandshakeStatus() == HandshakeStatus.NEED_WRAP)
                    session.flush(new HandshakePacket());
                break;
            }
        }

        src.release();
        if (hasAppData) {
            dest.position(0).limit(destBuffer.position());
            return new DefaultPacket(dest, packet.getAddress());
        }
        return null;
    }

    public void packetReceived(SessionFilterChain filterChain, Packet packet)
            throws Exception {
        Packet decodedPacket = decode(filterChain.getSession(), packet);
        if (decodedPacket != null)
            super.packetReceived(filterChain, decodedPacket);
    }

    private Packet encode(Session session, Packet packet) throws Exception {
        Buffer src = packet.getContent();
        ByteBuffer srcBuffer = src.asByteBuffer();

        int size = netBufferSize
                * (int) Math.ceil((double) (src.remaining() + 1)
                        / appBufferSize);
        Buffer dest = BufferFactory.allocate(size);
        ByteBuffer destBuffer = dest.asByteBuffer();

        // encode
        while (true) {
            SSLEngineResult result = engine.wrap(srcBuffer, destBuffer);
            Status status = result.getStatus();
            if (status == Status.OK) {
                HandshakeStatus handshakeStatus = result.getHandshakeStatus();
                if (handshakeStatus == HandshakeStatus.NOT_HANDSHAKING) {
                    if (!srcBuffer.hasRemaining()) // encode completed
                        break;
                } else if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) {
                    break; // wait receive
                } else if (handshakeStatus == HandshakeStatus.NEED_WRAP) {
                    // wrap again
                } else if (handshakeStatus == HandshakeStatus.NEED_TASK) {
                    runTask();
                } else if (handshakeStatus == HandshakeStatus.FINISHED) {
                    handshakeCompleted();
                    break;
                }
            } else if (status == Status.BUFFER_OVERFLOW) {
                // increase dest capacity
                destBuffer.flip();
                Buffer newDest = BufferFactory.allocate(
                        netBufferSize + destBuffer.remaining()).put(destBuffer);
                dest.release();
                dest = newDest;
                destBuffer = dest.asByteBuffer();
            } else if (status == Status.CLOSED) {
                break;
            } else if (status == Status.BUFFER_UNDERFLOW) {
                // should never happen
                src.release();
                session.close();
                return null;
            }
        }

        src.release();
        dest.position(0).limit(destBuffer.position());
        return new DefaultPacket(dest, packet.getAddress());
    }

    public void packetSend(SessionFilterChain filterChain, Packet packet)
            throws Exception {
        boolean queued = !handshakeCompleted
                && !(packet instanceof HandshakePacket);
        if (queued) {
            tempQueue.offer(new QueuedPacket(filterChain, packet));
            return;
        }

        Packet encodedPacket = encode(filterChain.getSession(), packet);
        if (encodedPacket != null)
            super.packetSend(filterChain, encodedPacket);
    }
}
TOP

Related Classes of net.sf.cindy.filter.SSLFilter$QueuedPacket

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.