Package io.undertow.servlet.spec

Source Code of io.undertow.servlet.spec.AsyncContextImpl

/*
* JBoss, Home of Professional Open Source.
* Copyright 2012 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.undertow.servlet.spec;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;

import javax.servlet.AsyncContext;
import javax.servlet.AsyncEvent;
import javax.servlet.AsyncListener;
import javax.servlet.DispatcherType;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import io.undertow.server.HttpHandler;
import io.undertow.server.HttpHandlers;
import io.undertow.server.HttpServerExchange;
import io.undertow.servlet.UndertowServletLogger;
import io.undertow.servlet.UndertowServletMessages;
import io.undertow.servlet.api.Deployment;
import io.undertow.servlet.api.InstanceFactory;
import io.undertow.servlet.api.ServletDispatcher;
import io.undertow.servlet.api.ThreadSetupAction;
import io.undertow.servlet.core.CompositeThreadSetupAction;
import io.undertow.servlet.handlers.ServletAttachments;
import io.undertow.servlet.handlers.ServletPathMatch;
import io.undertow.util.AttachmentKey;
import io.undertow.util.SameThreadExecutor;
import org.xnio.XnioExecutor;

/**
* @author Stuart Douglas
*/
public class AsyncContextImpl implements AsyncContext {

    public static final AttachmentKey<Boolean> ASYNC_SUPPORTED = AttachmentKey.create(Boolean.class);
    public static final AttachmentKey<Executor> ASYNC_EXECUTOR = AttachmentKey.create(Executor.class);

    private final List<BoundAsyncListener> asyncListeners = new CopyOnWriteArrayList<BoundAsyncListener>();

    private final HttpServerExchange exchange;
    private final ServletRequest servletRequest;
    private final ServletResponse servletResponse;
    private final TimeoutTask timeoutTask = new TimeoutTask();

    private AsyncContextImpl previousAsyncContext; //the previous async context


    //todo: make default configurable
    private volatile long timeout = 120000;

    private volatile XnioExecutor.Key timeoutKey;

    private boolean dispatched;
    private boolean initialRequestDone;
    private Thread initiatingThread;

    private final Deque<Runnable> asyncTaskQueue = new ArrayDeque<>();
    private boolean processingAsyncTask = false;

    public AsyncContextImpl(final HttpServerExchange exchange, final ServletRequest servletRequest, final ServletResponse servletResponse, final AsyncContextImpl previousAsyncContext) {
        this.exchange = exchange;
        this.servletRequest = servletRequest;
        this.servletResponse = servletResponse;
        this.previousAsyncContext = previousAsyncContext;
        initiatingThread = Thread.currentThread();
        exchange.dispatch(SameThreadExecutor.INSTANCE, new Runnable() {
            @Override
            public void run() {
                exchange.setDispatchExecutor(null);
                initialRequestDone();
            }
        });
    }

    public void updateTimeout() {
        XnioExecutor.Key key = this.timeoutKey;
        if (key != null) {
            if (!key.remove()) {
                return;
            }
        }
        if (timeout > 0) {
            this.timeoutKey = exchange.getIoThread().executeAfter(timeoutTask, timeout, TimeUnit.MILLISECONDS);
        }
    }

    @Override
    public ServletRequest getRequest() {
        return servletRequest;
    }

    @Override
    public ServletResponse getResponse() {
        return servletResponse;
    }

    @Override
    public boolean hasOriginalRequestAndResponse() {
        return servletRequest instanceof HttpServletRequestImpl &&
                servletResponse instanceof HttpServletResponseImpl;
    }

    @Override
    public void dispatch() {
        final HttpServletRequestImpl requestImpl = HttpServletRequestImpl.getRequestImpl(servletRequest);
        final ServletPathMatch handler;
        Deployment deployment = requestImpl.getServletContext().getDeployment();
        if (servletRequest instanceof HttpServletRequest) {
            handler = deployment.getServletPaths().getServletHandlerByPath(((HttpServletRequest) servletRequest).getServletPath());
        } else {
            handler = deployment.getServletPaths().getServletHandlerByPath(exchange.getRelativePath());
        }

        final HttpServerExchange exchange = requestImpl.getExchange();

        exchange.putAttachment(HttpServletRequestImpl.DISPATCHER_TYPE_ATTACHMENT_KEY, DispatcherType.ASYNC);

        exchange.putAttachment(HttpServletRequestImpl.ATTACHMENT_KEY, servletRequest);
        exchange.putAttachment(HttpServletResponseImpl.ATTACHMENT_KEY, servletResponse);

        dispatchAsyncRequest(deployment.getServletDispatcher(), handler, exchange);
    }

    private void dispatchAsyncRequest(final ServletDispatcher servletDispatcher, final ServletPathMatch pathInfo, final HttpServerExchange exchange) {
        doDispatch(new Runnable() {
            @Override
            public void run() {
                HttpHandlers.executeRootHandler(new HttpHandler() {
                    @Override
                    public void handleRequest(final HttpServerExchange exchange) throws Exception {
                        servletDispatcher.dispatchToPath(exchange, pathInfo, DispatcherType.ASYNC);
                    }
                }, exchange, false);
            }
        });
    }

    @Override
    public void dispatch(final String path) {
        dispatch(servletRequest.getServletContext(), path);
    }

    @Override
    public void dispatch(final ServletContext context, final String path) {

        HttpServletRequestImpl requestImpl = HttpServletRequestImpl.getRequestImpl(servletRequest);
        HttpServletResponseImpl responseImpl = HttpServletResponseImpl.getResponseImpl(servletResponse);
        final HttpServerExchange exchange = requestImpl.getExchange();

        exchange.putAttachment(HttpServletRequestImpl.DISPATCHER_TYPE_ATTACHMENT_KEY, DispatcherType.ASYNC);

        requestImpl.setAttribute(ASYNC_REQUEST_URI, requestImpl.getRequestURI());
        requestImpl.setAttribute(ASYNC_CONTEXT_PATH, requestImpl.getContextPath());
        requestImpl.setAttribute(ASYNC_SERVLET_PATH, requestImpl.getServletPath());
        requestImpl.setAttribute(ASYNC_QUERY_STRING, requestImpl.getQueryString());

        String newQueryString = "";
        int qsPos = path.indexOf("?");
        String newServletPath = path;
        if (qsPos != -1) {
            newQueryString = newServletPath.substring(qsPos + 1);
            newServletPath = newServletPath.substring(0, qsPos);
        }
        String newRequestUri = context.getContextPath() + newServletPath;

        //todo: a more efficent impl
        Map<String, Deque<String>> newQueryParameters = new HashMap<String, Deque<String>>();
        for (String part : newQueryString.split("&")) {
            String name = part;
            String value = "";
            int equals = part.indexOf('=');
            if (equals != -1) {
                name = part.substring(0, equals);
                value = part.substring(equals + 1);
            }
            Deque<String> queue = newQueryParameters.get(name);
            if (queue == null) {
                newQueryParameters.put(name, queue = new ArrayDeque<String>(1));
            }
            queue.add(value);
        }
        requestImpl.setQueryParameters(newQueryParameters);

        requestImpl.getExchange().setRelativePath(newServletPath);
        requestImpl.getExchange().setQueryString(newQueryString);
        requestImpl.getExchange().setRequestPath(newRequestUri);
        requestImpl.getExchange().setRequestURI(newRequestUri);
        requestImpl.setServletContext((ServletContextImpl) context);
        responseImpl.setServletContext((ServletContextImpl) context);

        Deployment deployment = requestImpl.getServletContext().getDeployment();
        ServletPathMatch info = deployment.getServletPaths().getServletHandlerByPath(newServletPath);
        requestImpl.getExchange().putAttachment(ServletAttachments.SERVLET_PATH_MATCH, info);

        dispatchAsyncRequest(deployment.getServletDispatcher(), info, exchange);
    }

    @Override
    public synchronized void complete() {
        onAsyncComplete();
        completeInternal();
    }

    public synchronized void completeInternal() {

        if (!initialRequestDone && Thread.currentThread() == initiatingThread) {
            //the context was stopped in the same request context it was started, we don't do anything
            if (dispatched) {
                throw UndertowServletMessages.MESSAGES.asyncRequestAlreadyDispatched();
            }
            exchange.unDispatch();
            dispatched = true;
            HttpServletRequestImpl request = HttpServletRequestImpl.getRequestImpl(servletRequest);
            initialRequestDone();
            request.asyncRequestDispatched();
        } else {
            doDispatch(new Runnable() {
                @Override
                public void run() {
                    //we do not run the ServletRequestListeners here, as the request does not come into the scope
                    //of a web application, as defined by the javadoc on ServletRequestListener
                    HttpServletResponseImpl response = HttpServletResponseImpl.getResponseImpl(servletResponse);
                    response.responseDone();
                }
            });
        }
    }

    @Override
    public void start(final Runnable run) {
        Executor executor = asyncExecutor();
        final CompositeThreadSetupAction setup = HttpServletRequestImpl.getRequestImpl(servletRequest).getServletContext().getDeployment().getThreadSetupAction();
        executor.execute(new Runnable() {
            @Override
            public void run() {
                ThreadSetupAction.Handle handle = setup.setup(null);
                try {
                    run.run();
                } finally {
                    handle.tearDown();
                }
            }
        });

    }

    private Executor asyncExecutor() {
        Executor executor = exchange.getAttachment(ASYNC_EXECUTOR);
        if (executor == null) {
            executor = exchange.getDispatchExecutor();
        }
        if (executor == null) {
            executor = exchange.getConnection().getWorker();
        }
        return executor;
    }


    @Override
    public void addListener(final AsyncListener listener) {
        asyncListeners.add(new BoundAsyncListener(listener, servletRequest, servletResponse));
    }

    @Override
    public void addListener(final AsyncListener listener, final ServletRequest servletRequest, final ServletResponse servletResponse) {
        asyncListeners.add(new BoundAsyncListener(listener, servletRequest, servletResponse));
    }

    public boolean isDispatched() {
        return dispatched;
    }

    @Override
    public <T extends AsyncListener> T createListener(final Class<T> clazz) throws ServletException {
        try {
            InstanceFactory<T> factory = ((ServletContextImpl) this.servletRequest.getServletContext()).getDeployment().getDeploymentInfo().getClassIntrospecter().createInstanceFactory(clazz);
            return factory.createInstance().getInstance();
        } catch (NoSuchMethodException e) {
            throw new ServletException(e);
        } catch (InstantiationException e) {
            throw new ServletException(e);
        }
    }

    @Override
    public void setTimeout(final long timeout) {
        this.timeout = timeout;
    }

    @Override
    public long getTimeout() {
        return timeout;
    }

    public void handleError(final Throwable error) {
        dispatched = false; //we reset the dispatched state
        onAsyncError(error);
        if(!dispatched) {
            servletRequest.setAttribute(RequestDispatcher.ERROR_EXCEPTION, error);
            try {
                ((HttpServletResponse)servletResponse).sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
            } catch (IOException e) {
                //ignore, not much we can do here
            }
            if(!dispatched) {
                complete();
            }
        }
    }

    /**
     * Called by the container when the initial request is finished.
     * If this request has a dispatch or complete call pending then
     * this will be started.
     */
    public synchronized void initialRequestDone() {
        initialRequestDone = true;
        if(previousAsyncContext != null) {
            previousAsyncContext.onAsyncStart(this);
            previousAsyncContext = null;
        }
        if (!processingAsyncTask) {
            processAsyncTask();
        }
        initiatingThread = null;
    }


    private synchronized void doDispatch(final Runnable runnable) {
        if (dispatched) {
            throw UndertowServletMessages.MESSAGES.asyncRequestAlreadyDispatched();
        }
        dispatched = true;
        final HttpServletRequestImpl request = HttpServletRequestImpl.getRequestImpl(servletRequest);
        addAsyncTask(new Runnable() {
            @Override
            public void run() {
                request.asyncRequestDispatched();
                runnable.run();
            }
        });
        if (timeoutKey != null) {
            timeoutKey.remove();
        }
    }


    private final class TimeoutTask implements Runnable {

        @Override
        public void run() {
            synchronized (AsyncContextImpl.this) {
                if (!dispatched) {
                    UndertowServletLogger.REQUEST_LOGGER.debug("Async request timed out");
                    onAsyncTimeout();
                    if(!dispatched) {
                        //servlet
                        try {
                            ((HttpServletResponse)servletResponse).sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
                        } catch (IOException e) {
                            //ignore
                        }
                        if(!dispatched) {
                            complete();
                        }
                    }
                }
            }
        }
    }

    private synchronized void processAsyncTask() {
        if (!initialRequestDone) {
            return;
        }
        updateTimeout();
        final Runnable task = asyncTaskQueue.poll();
        if (task != null) {
            processingAsyncTask = true;
            asyncExecutor().execute(new TaskDispatchRunnable(task));
        } else {
            processingAsyncTask = false;
        }
    }

    /**
     * Adds a task to be run to the async context. These tasks are run one at a time,
     * after the initial request is finished. If the request is dispatched before the initial
     * request is complete then these tasks will not be run
     * <p/>
     * <p/>
     * This method is intended to be used to queue read and write tasks for async streams,
     * to make sure that multiple threads do not end up working on the same exchange at once
     *
     * @param runnable The runnable
     */
    public synchronized void addAsyncTask(final Runnable runnable) {
        asyncTaskQueue.add(runnable);
        if (!processingAsyncTask) {
            processAsyncTask();
        }
    }

    private class TaskDispatchRunnable implements Runnable {

        private final Runnable task;

        private TaskDispatchRunnable(final Runnable task) {
            this.task = task;
        }

        @Override
        public void run() {
            try {
                task.run();
            } finally {
                processAsyncTask();
            }
        }
    }


    private void onAsyncComplete() {
        for (final BoundAsyncListener listener : asyncListeners) {
            AsyncEvent event = new AsyncEvent(this, listener.servletRequest, listener.servletResponse);
            try {
                listener.asyncListener.onComplete(event);
            } catch (IOException e) {
                UndertowServletLogger.REQUEST_LOGGER.ioExceptionDispatchingAsyncEvent(e);
            }
        }
    }

    private void onAsyncTimeout() {
        for (final BoundAsyncListener listener : asyncListeners) {
            AsyncEvent event = new AsyncEvent(this, listener.servletRequest, listener.servletResponse);
            try {
                listener.asyncListener.onTimeout(event);
            } catch (IOException e) {
                UndertowServletLogger.REQUEST_LOGGER.ioExceptionDispatchingAsyncEvent(e);
            }
        }
    }

    private void onAsyncStart(AsyncContext newAsyncContext) {
        for (final BoundAsyncListener listener : asyncListeners) {
            //make sure we use the new async context
            AsyncEvent event = new AsyncEvent(newAsyncContext, listener.servletRequest, listener.servletResponse);
            try {
                listener.asyncListener.onStartAsync(event);
            } catch (IOException e) {
                UndertowServletLogger.REQUEST_LOGGER.ioExceptionDispatchingAsyncEvent(e);
            }
        }
    }

    private void onAsyncError(Throwable t) {
        for (final BoundAsyncListener listener : asyncListeners) {
            AsyncEvent event = new AsyncEvent(this, listener.servletRequest, listener.servletResponse, t);
            try {
                listener.asyncListener.onError(event);
            } catch (IOException e) {
                UndertowServletLogger.REQUEST_LOGGER.ioExceptionDispatchingAsyncEvent(e);
            }
        }
    }

    private final class BoundAsyncListener {
        final AsyncListener asyncListener;
        final ServletRequest servletRequest;
        final ServletResponse servletResponse;

        private BoundAsyncListener(final AsyncListener asyncListener, final ServletRequest servletRequest, final ServletResponse servletResponse) {
            this.asyncListener = asyncListener;
            this.servletRequest = servletRequest;
            this.servletResponse = servletResponse;
        }
    }
}
TOP

Related Classes of io.undertow.servlet.spec.AsyncContextImpl

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.