// Package io.undertow.ajp
//
// Source code of io.undertow.ajp.AjpResponseConduit

/*
* JBoss, Home of Professional Open Source.
* Copyright 2012 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.undertow.ajp;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

import io.undertow.UndertowLogger;
import io.undertow.conduits.ConduitListener;
import io.undertow.server.ExchangeCookieUtils;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.HeaderMap;
import io.undertow.util.Headers;
import io.undertow.util.HttpString;
import io.undertow.util.StatusCodes;
import org.jboss.logging.Logger;
import org.xnio.IoUtils;
import org.xnio.Pool;
import org.xnio.Pooled;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.AbstractStreamSinkConduit;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.StreamSinkConduit;

import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.allAreSet;
import static org.xnio.Bits.anyAreSet;

/**
* AJP response channel. For now we are going to assume that the buffers are sized to
* fit complete packets. As AJP packets are limited to 8k this is a reasonable assumption.
*
* @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a>
* @author Stuart Douglas
*/
final class AjpResponseConduit extends AbstractStreamSinkConduit<StreamSinkConduit> {

    private static final Logger log = Logger.getLogger("io.undertow.server.channel.ajp.response");

    private static final int MAX_DATA_SIZE = 8186;

    private static final Map<HttpString, Integer> HEADER_MAP;

    private final Pool<ByteBuffer> pool;

    /**
     * State flags
     */
    @SuppressWarnings("unused")
    private volatile int state = FLAG_START;

    private static final AtomicIntegerFieldUpdater<AjpResponseConduit> stateUpdater = AtomicIntegerFieldUpdater.newUpdater(AjpResponseConduit.class, "state");

    /**
     * The current data buffer. This will be released once it has been written out.
     */
    private Pooled<ByteBuffer> currentDataBuffer;
    /**
     * The current packet header and data buffer combined, in a form that allows them to be written out
     * in a gathering write.
     */
    private ByteBuffer[] packetHeaderAndDataBuffer;

    private final HttpServerExchange exchange;

    private final ConduitListener<? super AjpResponseConduit> finishListener;



    /**
     * An AJP request channel that wants access to the underlying sink channel.
     * <p/>
     * When this is then then any remaining data will be written out, and then ownership of
     * the underlying channel will be transferred to the request channel.
     * <p/>
     * While this field is set attempts to write will always return 0.
     */
    private volatile ByteBuffer readBodyChunkBuffer;

    private static final int FLAG_START = 1; //indicates that the header has not been generated yet.
    private static final int FLAG_SHUTDOWN = 1 << 2;
    private static final int FLAG_DELEGATE_SHUTDOWN = 1 << 3;
    private static final int FLAG_CLOSE_QUEUED = 1 << 4;
    private static final int FLAG_WRITE_ENTERED = 1 << 5;

    static {
        final Map<HttpString, Integer> headers = new HashMap<HttpString, Integer>();
        headers.put(Headers.CONTENT_TYPE, 0xA001);
        headers.put(Headers.CONTENT_LANGUAGE, 0xA002);
        headers.put(Headers.CONTENT_LENGTH, 0xA003);
        headers.put(Headers.DATE, 0xA004);
        headers.put(Headers.LAST_MODIFIED, 0xA005);
        headers.put(Headers.LOCATION, 0xA006);
        headers.put(Headers.SET_COOKIE, 0xA007);
        headers.put(Headers.SET_COOKIE2, 0xA008);
        headers.put(Headers.SERVLET_ENGINE, 0xA009);
        headers.put(Headers.STATUS, 0xA00A);
        headers.put(Headers.WWW_AUTHENTICATE, 0xA00B);
        HEADER_MAP = Collections.unmodifiableMap(headers);
    }

    AjpResponseConduit(final StreamSinkConduit next, final Pool<ByteBuffer> pool, final HttpServerExchange exchange, ConduitListener<? super AjpResponseConduit> finishListener) {
        super(next);
        this.pool = pool;
        this.exchange = exchange;
        this.finishListener = finishListener;
        state = FLAG_START;
    }

    private void putInt(final ByteBuffer buf, int value) {
        buf.put((byte) ((value >> 8) & 0xFF));
        buf.put((byte) (value & 0xFF));
    }

    private void putString(final ByteBuffer buf, String value) {
        final int length = value.length();
        putInt(buf, length);
        for (int i = 0; i < length; ++i) {
            buf.put((byte) value.charAt(i));
        }
        buf.put((byte) 0);
    }

    /**
     * Handles writing out the header data, plus any current buffers. Returns true if the write can proceed,
     * false if there are still cached bufers
     *
     * @return
     * @throws java.io.IOException
     */
    private boolean processWrite() throws IOException {
        int oldState;
        int writeEnteredState;
        do {
            oldState = this.state;
            if ((oldState & FLAG_WRITE_ENTERED) != 0) {
                return false;
            }
            if (anyAreSet(state, FLAG_DELEGATE_SHUTDOWN)) {
                return true;
            }
            writeEnteredState = oldState | FLAG_WRITE_ENTERED;
        } while (!stateUpdater.compareAndSet(this, state, writeEnteredState));
        int newState = writeEnteredState;

        //if currentDataBuffer is set then we just
        if (anyAreSet(oldState, FLAG_START)) {
            if (readBodyChunkBuffer == null) {

                //merge the cookies into the header map
                ExchangeCookieUtils.flattenCookies(exchange);

                currentDataBuffer = pool.allocate();
                final ByteBuffer buffer = currentDataBuffer.getResource();
                packetHeaderAndDataBuffer = new ByteBuffer[1];
                packetHeaderAndDataBuffer[0] = buffer;
                buffer.put((byte) 'A');
                buffer.put((byte) 'B');
                buffer.put((byte) 0); //we fill the size in later
                buffer.put((byte) 0);
                buffer.put((byte) 4);
                putInt(buffer, exchange.getResponseCode());
                putString(buffer, StatusCodes.getReason(exchange.getResponseCode()));

                int headers = 0;
                //we need to count the headers
                final HeaderMap responseHeaders = exchange.getResponseHeaders();
                for (HttpString name : responseHeaders.getHeaderNames()) {
                    headers += responseHeaders.get(name).size();
                }

                putInt(buffer, headers);


                for (final HttpString header : responseHeaders.getHeaderNames()) {
                    for (String headerValue : responseHeaders.get(header)) {
                        Integer headerCode = HEADER_MAP.get(header);
                        if (headerCode != null) {
                            putInt(buffer, headerCode);
                        } else {
                            putString(buffer, header.toString());
                        }
                        putString(buffer, headerValue);
                    }
                }

                int dataLength = buffer.position() - 4;
                buffer.put(2, (byte) ((dataLength >> 8) & 0xFF));
                buffer.put(3, (byte) (dataLength & 0xFF));
                buffer.flip();
                newState = (newState & ~FLAG_START);
            } else {
                //otherwise we just write out the get request body chunk and return
                ByteBuffer readBuffer = readBodyChunkBuffer;
                do {
                    int res = next.write(readBuffer);
                    if (res == 0) {
                        stateUpdater.set(this, newState & ~FLAG_WRITE_ENTERED); //clear the write entered flag
                        return false;
                    }
                } while (readBodyChunkBuffer.hasRemaining());
                readBodyChunkBuffer = null;
                stateUpdater.set(this, newState & ~FLAG_WRITE_ENTERED); //clear the write entered flag
                return true;
            }
        }

        if (currentDataBuffer != null) {
            if (!writeCurrentBuffer()) {
                stateUpdater.set(this, newState & ~FLAG_WRITE_ENTERED); //clear the write entered flag
                return false;
            }
        }

        //now next writing to the active request channel, so it can send
        //its messages
        ByteBuffer readBuffer = readBodyChunkBuffer;
        if (readBuffer != null) {
            do {
                int res = next.write(readBuffer);
                if (res == 0) {
                    stateUpdater.set(this, newState & ~FLAG_WRITE_ENTERED); //clear the write entered flag
                    return false;
                }
            } while (readBodyChunkBuffer.hasRemaining());
            readBodyChunkBuffer = null;
        }

        if (anyAreSet(state, FLAG_SHUTDOWN) && allAreClear(state, FLAG_CLOSE_QUEUED)) {
            newState = newState | FLAG_CLOSE_QUEUED;
            currentDataBuffer = pool.allocate();
            final ByteBuffer buffer = currentDataBuffer.getResource();
            packetHeaderAndDataBuffer = new ByteBuffer[1];
            packetHeaderAndDataBuffer[0] = buffer;
            buffer.put((byte) 'A');
            buffer.put((byte) 'B');
            buffer.put((byte) 0);
            buffer.put((byte) 2);
            buffer.put((byte) 5);
            buffer.put((byte) (exchange.isPersistent() ? 1 : 0)); //reuse
            buffer.flip();
            if (!writeCurrentBuffer()) {
                stateUpdater.set(this, newState & ~FLAG_WRITE_ENTERED); //clear the write entered flag
                return false;
            }
        }
        if (newState != writeEnteredState) {
            stateUpdater.set(this, newState);
        }
        return true;
    }

    private boolean writeCurrentBuffer() throws IOException {
        long toWrite = 0;
        for (ByteBuffer b : this.packetHeaderAndDataBuffer) {
            toWrite += b.remaining();
        }
        long r = 0;
        do {
            r = next.write(this.packetHeaderAndDataBuffer, 0, this.packetHeaderAndDataBuffer.length);
            if (r == -1) {
                throw new ClosedChannelException();
            } else if (r == 0) {
                return false;
            }
            toWrite -= r;
        } while (toWrite > 0);
        currentDataBuffer.free();
        this.currentDataBuffer = null;
        return true;
    }


    public int write(final ByteBuffer src) throws IOException {
        if (!processWrite()) {
            return 0;
        }
        try {
            int limit = src.limit();
            try {
                if (src.remaining() > MAX_DATA_SIZE) {
                    src.limit(src.position() + MAX_DATA_SIZE);
                }
                final int writeSize = src.remaining();
                final ByteBuffer[] buffers = createHeader(src);
                int toWrite = 0;
                for (ByteBuffer buffer : buffers) {
                    toWrite += buffer.remaining();
                }
                int total = 0;
                long r = 0;
                do {
                    r = next.write(buffers, 0, buffers.length);
                    total += r;
                    toWrite -= r;
                    if (r == -1) {
                        throw new ClosedChannelException();
                    } else if (r == 0) {
                        //we need to copy all the remaining bytes
                        Pooled<ByteBuffer> newPooledBuffer = pool.allocate();
                        while (src.hasRemaining()) {
                            newPooledBuffer.getResource().put(src);
                        }
                        newPooledBuffer.getResource().flip();
                        ByteBuffer[] savedBuffers = new ByteBuffer[3];
                        savedBuffers[0] = buffers[0];
                        savedBuffers[1] = newPooledBuffer.getResource();
                        savedBuffers[2] = buffers[2];
                        this.packetHeaderAndDataBuffer = savedBuffers;
                        this.currentDataBuffer = newPooledBuffer;

                        return writeSize;
                    }
                } while (toWrite > 0);
                return total;
            } finally {
                src.limit(limit);
            }
        } finally {
            exitWrite();
        }

    }

    private ByteBuffer[] createHeader(final ByteBuffer src) {
        int remaining = src.remaining();
        int chunkSize = remaining + 4;
        byte[] header = new byte[7];
        header[0] = (byte) 'A';
        header[1] = (byte) 'B';
        header[2] = (byte) ((chunkSize >> 8) & 0xFF);
        header[3] = (byte) (chunkSize & 0xFF);
        header[4] = (byte) (3 & 0xFF);
        header[5] = (byte) ((remaining >> 8) & 0xFF);
        header[6] = (byte) (remaining & 0xFF);

        byte[] footer = new byte[1];
        footer[0] = 0;

        final ByteBuffer[] buffers = new ByteBuffer[3];
        buffers[0] = ByteBuffer.wrap(header);
        buffers[1] = src;
        buffers[2] = ByteBuffer.wrap(footer);
        return buffers;
    }

    private void exitWrite() {
        stateUpdater.set(this, state & ~FLAG_WRITE_ENTERED);
    }

    public long write(final ByteBuffer[] srcs) throws IOException {
        return write(srcs, 0, srcs.length);
    }

    public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
        long total = 0;
        for (int i = offset; i < offset + length; ++i) {
            while (srcs[i].hasRemaining()) {
                int written = write(srcs[i]);
                if (written <= 0 && total == 0) {
                    return written;
                } else if (written <= 0) {
                    return total;
                }
                total += written;
            }
        }
        return total;
    }

    public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
        return src.transferTo(position, count, new ConduitWritableByteChannel(this));
    }

    public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
        return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
    }

    public boolean flush() throws IOException {
        if (!processWrite()) {
            return false;
        }
        try {
            int state = this.state;
            if (allAreSet(state, FLAG_SHUTDOWN) && allAreClear(state, FLAG_DELEGATE_SHUTDOWN)) {
                if(!exchange.isPersistent()) {
                    next.terminateWrites();
                }
                if(finishListener != null) {
                    finishListener.handleEvent(this);
                }
                stateUpdater.set(this, state | FLAG_DELEGATE_SHUTDOWN);
            }
            return next.flush();
        } finally {
            exitWrite();
        }
    }

    public void suspendWrites() {
        log.trace("suspend");
        next.suspendWrites();
    }

    public void resumeWrites() {
        log.trace("resume");
        next.resumeWrites();
    }

    public boolean isWriteResumed() {
        return next.isWriteResumed();
    }

    public void wakeupWrites() {
        log.trace("wakeup");
        next.wakeupWrites();
    }

    public void terminateWrites() throws IOException {
        int oldState = 0, newState = 0;
        do {
            oldState = this.state;
            if (anyAreSet(oldState, FLAG_SHUTDOWN)) {
                return;
            }
            newState = oldState | FLAG_SHUTDOWN;
        } while (!stateUpdater.compareAndSet(this, oldState, newState));
        if (allAreClear(oldState, FLAG_START) &&
                readBodyChunkBuffer == null &&
                packetHeaderAndDataBuffer == null) {
            if(!exchange.isPersistent()) {
                next.terminateWrites();
            }
            if(finishListener != null) {
                finishListener.handleEvent(this);
            }
            newState |= FLAG_DELEGATE_SHUTDOWN;
            while (stateUpdater.compareAndSet(this, oldState, newState)) {
                oldState = state;
                newState = oldState | FLAG_DELEGATE_SHUTDOWN;
            }
        }
    }

    public void awaitWritable() throws IOException {
        next.awaitWritable();
    }

    public void awaitWritable(final long time, final TimeUnit timeUnit) throws IOException {
        next.awaitWritable(time, timeUnit);
    }

    public boolean doGetRequestBodyChunk(ByteBuffer buffer, final AjpRequestConduit requestChannel) throws IOException {
        this.readBodyChunkBuffer = buffer;
        boolean result = processWrite();
        if (result) {
            exitWrite();
        } else {
            //if this write does not work we spawn a thread to force it out.
            //this is not great, but there is not really a great deal we can do here
            //there is probably a better way to deal with this, but I am not really sure what it is
            this.exchange.getConnection().getWorker().submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        while (AjpResponseConduit.this.readBodyChunkBuffer != null) {
                            next.awaitWritable();
                            boolean result = processWrite();
                            if (result) {
                                exitWrite();
                            }
                        }
                    } catch (IOException e) {
                        if (requestChannel.isReadResumed()) {
                            requestChannel.wakeupReads();
                        }
                        if (isWriteResumed()) {
                            next.wakeupWrites();
                        }
                        UndertowLogger.REQUEST_IO_LOGGER.ioException(e);
                    }
                }
            });
        }

        return result;
    }
}
// TOP
//
// Related Classes of io.undertow.ajp.AjpResponseConduit
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.