Package com.alibaba.citrus.service.requestcontext.impl

Source Code of com.alibaba.citrus.service.requestcontext.impl.RequestContextChainingServiceImpl

/*
* Copyright (c) 2002-2012 Alibaba Group Holding Limited.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*    http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.alibaba.citrus.service.requestcontext.impl;

import static com.alibaba.citrus.service.requestcontext.util.RequestContextUtil.*;
import static com.alibaba.citrus.util.ArrayUtil.*;
import static com.alibaba.citrus.util.Assert.*;
import static com.alibaba.citrus.util.CollectionUtil.*;
import static com.alibaba.citrus.util.ObjectUtil.*;
import static com.alibaba.citrus.util.internal.Servlet3Util.*;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.alibaba.citrus.service.AbstractService;
import com.alibaba.citrus.service.requestcontext.RequestContext;
import com.alibaba.citrus.service.requestcontext.RequestContextChainingService;
import com.alibaba.citrus.service.requestcontext.RequestContextException;
import com.alibaba.citrus.service.requestcontext.RequestContextFactory;
import com.alibaba.citrus.service.requestcontext.RequestContextInfo;
import com.alibaba.citrus.service.requestcontext.RequestContextInfo.AfterFeature;
import com.alibaba.citrus.service.requestcontext.RequestContextInfo.BeforeFeature;
import com.alibaba.citrus.service.requestcontext.RequestContextInfo.FeatureOrder;
import com.alibaba.citrus.service.requestcontext.RequestContextInfo.RequiresFeature;
import com.alibaba.citrus.service.requestcontext.TwoPhaseCommitRequestContext;
import com.alibaba.citrus.service.requestcontext.util.RequestContextUtil;
import com.alibaba.citrus.util.ToStringBuilder;
import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

/**
* 将<code>RequestContext</code>对象串联起来的service。
* <p>
* 通过它可以实现多重包装的HTTP request和response。
* </p>
*/
public class RequestContextChainingServiceImpl extends AbstractService<RequestContextChainingService>
        implements RequestContextChainingService {
    private List<RequestContextFactory<?>> factories;
    private boolean                        sort;
    private boolean                        threadContextInheritable;

    public void setFactories(List<RequestContextFactory<?>> factories) {
        this.factories = createArrayList(assertNotNull(factories, "factories"));
    }

    public void setSort(boolean sort) {
        this.sort = sort;
    }

    public void setThreadContextInheritable(boolean threadContextInheritable) {
        this.threadContextInheritable = threadContextInheritable;
    }

    @Override
    protected void init() {
        if (factories == null) {
            getLogger().warn("no request context factory defined");
            factories = createArrayList();
        }

        if (sort) {
            // 排序
            Map<Integer, RequestContextFactory<?>> result = createLinkedHashMap();
            Set<Integer> processing = createHashSet();

            for (int i = 0; i < factories.size(); i++) {
                if (!result.containsKey(i)) {
                    sort(i, factories.get(i), result, processing);
                }
            }

            factories = createArrayList(result.values());

            // 检查缺失的features
            Set<String> usableFeatures = createHashSet();

            for (RequestContextFactory<?> f : factories) {
                usableFeatures.addAll(asList(f.getFeatures()));

                FeatureOrder[] featureOrders = f.featureOrders();

                if (featureOrders != null) {
                    for (FeatureOrder requiresFeature : featureOrders) {
                        if (requiresFeature instanceof RequiresFeature) {
                            if (!usableFeatures.contains(requiresFeature.feature)) {
                                throw new IllegalArgumentException(String.format(
                                        "Missing feature of %s, which is required by %s", requiresFeature.feature, f));
                            }
                        }
                    }
                }
            }
        }

        getLogger().debug("Initialized {}", this);
    }

    /** 对factories进行排序,以便符合各自的限定。 */
    private void sort(int index, RequestContextFactory<?> f, Map<Integer, RequestContextFactory<?>> result,
                      Set<Integer> processing) {
        if (!processing.contains(index)) {
            processing.add(index);

            for (Map.Entry<Integer, RequestContextFactory<?>> entry : getFactoriesBefore(f).entrySet()) {
                sort(entry.getKey(), entry.getValue(), result, processing);
            }

            result.put(index, f);
            processing.remove(index);
        }
    }

    /** 取得所有在f之前的factories。 */
    private Map<Integer, RequestContextFactory<?>> getFactoriesBefore(RequestContextFactory<?> f) {
        Map<Integer, RequestContextFactory<?>> allBefore = createLinkedHashMap();

        for (int i = 0; i < factories.size(); i++) {
            RequestContextFactory<?> test = factories.get(i);

            if (f != test && compare(test, f) < 0) {
                allBefore.put(i, test);
            }
        }

        return allBefore;
    }

    /**
     * 比较f1和f2,如f1在前,返回-1,如f1在后,则返回1,不确定则返回0。
     * <p>
     * 明确指定的胜于用*指定的。
     * </p>
     */
    private int compare(RequestContextFactory<?> f1, RequestContextFactory<?> f2) {
        boolean f1BeforeF2 = compare(f1, BeforeFeature.class, f2);
        boolean f2AfterF1 = compare(f2, AfterFeature.class, f1);

        if (f1BeforeF2 || f2AfterF1) {
            return -1;
        }

        boolean f2BeforeF1 = compare(f2, BeforeFeature.class, f1);
        boolean f1AfterF2 = compare(f1, AfterFeature.class, f2);

        if (f2BeforeF1 || f1AfterF2) {
            return 1;
        }

        boolean f1BeforeAll = compareWithAll(f1, BeforeFeature.class);
        boolean f2AfterAll = compareWithAll(f2, AfterFeature.class);

        if (f1BeforeAll || f2AfterAll) {
            return -1;
        }

        boolean f2BeforeAll = compareWithAll(f2, BeforeFeature.class);
        boolean f1AfterAll = compareWithAll(f1, AfterFeature.class);

        if (f2BeforeAll || f1AfterAll) {
            return 1;
        }

        return 0;
    }

    private boolean compare(RequestContextFactory<?> f1, Class<? extends FeatureOrder> type, RequestContextFactory<?> f2) {
        FeatureOrder[] featureOrders = f1.featureOrders();
        String[] features = f2.getFeatures();

        if (isEmptyArray(features) || isEmptyArray(featureOrders)) {
            return false;
        }

        for (String feature : features) {
            for (FeatureOrder featureOrder : featureOrders) {
                if (type.isInstance(featureOrder)) {
                    if (isEquals(featureOrder.feature, feature)) {
                        return true;
                    }
                }
            }
        }

        return false;
    }

    private boolean compareWithAll(RequestContextFactory<?> f1, Class<? extends FeatureOrder> type) {
        FeatureOrder[] featureOrders = f1.featureOrders();

        if (isEmptyArray(featureOrders)) {
            return false;
        }

        for (FeatureOrder featureOrder : featureOrders) {
            if (type.isInstance(featureOrder)) {
                if (isEquals(featureOrder.feature, "*")) {
                    return true;
                }
            }
        }

        return false;
    }

    /** 取得所有的request context的信息。 */
    public RequestContextInfo<?>[] getRequestContextInfos() {
        return factories.toArray(new RequestContextInfo<?>[factories.size()]);
    }

    /**
     * 取得<code>RequestContext</code>串。
     *
     * @param servletContext <code>ServletContext</code>对象
     * @param request        <code>HttpServletRequest</code>对象
     * @param response       <code>HttpServletResponse</code>对象
     * @return request context
     */
    public RequestContext getRequestContext(ServletContext servletContext, HttpServletRequest request,
                                            HttpServletResponse response) {
        assertInitialized();

        // 异步请求(dispatcherType == ASYNC)开始时,如果已经存在request context,则直接取得并返回之。
        boolean asyncDispatcher = request_isDispatcherType(request, DISPATCHER_TYPE_ASYNC);
        RequestContext requestContext = null;

        if (asyncDispatcher) {
            requestContext = RequestContextUtil.getRequestContext(request);
        }

        if (requestContext == null) {
            SimpleRequestContext innerReuestContext = new SimpleRequestContext(servletContext, request, response, this);

            requestContext = innerReuestContext;

            // 将requestContext放入request中,以便今后只需要用request就可以取得requestContext。
            // 及早设置setRequestContext,以便随后的prepareRequestContext就能使用。
            RequestContextUtil.setRequestContext(requestContext);

            for (RequestContextFactory<?> factory : factories) {
                requestContext = factory.getRequestContextWrapper(requestContext);

                // 调用<code>requestContext.prepare()</code>方法
                prepareRequestContext(requestContext);

                // 将requestContext放入request中,以便今后只需要用request就可以取得requestContext。
                RequestContextUtil.setRequestContext(requestContext);
            }

            innerReuestContext.setTopRequestContext(requestContext);

            getLogger().debug("Created a new request context: {}", requestContext);
        }

        // 无论是否是从async恢复,都需要重新设置thread的上下文环境。
        bind(requestContext.getRequest());

        return requestContext;
    }

    public void bind(HttpServletRequest request) {
        setupSpringWebEnvironment(request);
    }

    /**
     * 调用<code>requestContext.prepare()</code>方法。
     *
     * @param requestContext 要初始化的request context
     */
    private void prepareRequestContext(RequestContext requestContext) {
        if (getLogger().isTraceEnabled()) {
            getLogger().trace("Preparing request context: {}", requestContext.getClass().getSimpleName());
        }

        requestContext.prepare();
    }

    /**
     * 由外到内地调用<code>afterRequest()</code>方法。
     *
     * @param requestContext 要初始化的request context
     * @throws RequestContextException 如果失败
     */
    public void commitRequestContext(final RequestContext requestContext) throws RequestContextException {
        assertInitialized();

        // 异步处理状态:
        // - request.isAsyncStarted()    == true    - 当request.startAsync()被调用以后,且asyncContext.dispatch()及complete()未被调用前。
        // - request.getDispatcherType() == ASYNC   - 当asyncContext.dispatch()被调用以后,从dispatcher线程返回以前。
        //
        // 有四种情况会执行到这里:
        // 1. 普通请求(dispatcherType == REQUEST)结束时,request.startAsync()已被调用(isAsyncStarted == true)
        //    此时,需要注册AsyncListener.onComplete,使request context延迟到async结束时才被提交。
        // 2. 普通请求(dispatcherType == REQUEST)结束时,request.startAsync()未被调用或者异步已经结束(isAsyncStarted == false)
        //    此时,立即提交request context。
        // 3. 异步请求(dispatcherType == ASYNC)结束时,request.startAsync()再次被调用(isAsyncStarted == true)
        //    此时,不需要在这里注册AsyncListener,而是在前一个AsyncListener.onAsyncStarted()中重置listener,使request context延迟到async结束时才被提交。
        // 4. 异步请求(dispatcherType == ASYNC)结束时,request.startAsync()未被调用或者异步已经结束(isAsyncStarted == false)
        //    此时不需要做什么,因为request context已经被提交。

        HttpServletRequest request = requestContext.getRequest();
        boolean asyncDispatcher = request_isDispatcherType(request, DISPATCHER_TYPE_ASYNC);
        boolean asyncStarted = request_isAsyncStarted(request);

        if (!asyncDispatcher) {
            if (asyncStarted) {
                // 情况1
                if (getLogger().isDebugEnabled()) {
                    getLogger().debug("Keep request context open for asynchronous process");
                }

                request_registerAsyncListener(request, new Object() {
                    private Object thisListener;

                    public void setThisProxy(Object listener) {
                        this.thisListener = listener;
                    }

                    public void onComplete(Object /* AsyncEvent */ event) throws IOException {
                        getLogger().debug("Async task completed.");
                        doCommit(requestContext);
                    }

                    public void onTimeout(Object /* AsyncEvent */ event) throws IOException {
                    }

                    public void onError(Object /* AsyncEvent */ event) throws IOException {
                    }

                    public void onStartAsync(Object /* AsyncEvent */ event) throws IOException {
                        // 在情况3时自动调用
                        Object /* AsyncContext */ asyncContext = asyncEvent_getAsyncContext(event);
                        asyncContext_addAsyncListener(asyncContext, thisListener);
                    }
                });
            } else {
                // 情况2
                doCommit(requestContext);
            }
        }

        unbind(request);
    }

    public void unbind(HttpServletRequest request) {
        cleanupSpringWebEnvironment(request);
    }

    private void doCommit(RequestContext requestContext) {
        CommitMonitor monitor = getCommitMonitor(requestContext);

        synchronized (monitor) {
            if (!monitor.isCommitted()) {
                boolean doCommitHeaders = !monitor.isHeadersCommitted();
                monitor.setCommitted(true);

                HttpServletRequest request = requestContext.getRequest();

                for (RequestContext rc = requestContext; rc != null; rc = rc.getWrappedRequestContext()) {
                    if (getLogger().isTraceEnabled()) {
                        getLogger().trace("Committing request context: {}", rc.getClass().getSimpleName());
                    }

                    if (rc instanceof TwoPhaseCommitRequestContext && doCommitHeaders) {
                        ((TwoPhaseCommitRequestContext) rc).commitHeaders();
                    }

                    rc.commit();
                }

                // 将request和requestContext断开
                RequestContextUtil.removeRequestContext(request);

                getLogger().debug("Committed request: {}", request);
            }
        }
    }

    public void commitHeaders(RequestContext requestContext) {
        CommitMonitor monitor = getCommitMonitor(requestContext);

        synchronized (monitor) {
            if (!monitor.isHeadersCommitted()) {
                monitor.setHeadersCommitted(true);

                for (RequestContext rc = requestContext; rc != null; rc = rc.getWrappedRequestContext()) {
                    if (rc instanceof TwoPhaseCommitRequestContext) {
                        TwoPhaseCommitRequestContext tpc = (TwoPhaseCommitRequestContext) rc;

                        if (getLogger().isTraceEnabled()) {
                            getLogger().trace("Committing headers: {}", tpc.getClass().getSimpleName());
                        }

                        tpc.commitHeaders();
                    }
                }
            }
        }
    }

    private CommitMonitor getCommitMonitor(RequestContext requestContext) {
        CommitMonitor monitor = findRequestContext(requestContext, SimpleRequestContext.class);
        return assertNotNull(monitor, "no monitor");
    }

    private void setupSpringWebEnvironment(HttpServletRequest request) {
        ServletRequestAttributes attributes = new ServletRequestAttributes(request);
        LocaleContextHolder.setLocale(request.getLocale(), threadContextInheritable);
        RequestContextHolder.setRequestAttributes(attributes, threadContextInheritable);

        getLogger().debug("Bound request context to thread: {}", request);
    }

    private void cleanupSpringWebEnvironment(HttpServletRequest request) {
        RequestAttributes attributes = RequestContextHolder.getRequestAttributes();

        RequestContextHolder.resetRequestAttributes();
        LocaleContextHolder.resetLocaleContext();

        if (attributes instanceof ServletRequestAttributes) {
            ((ServletRequestAttributes) attributes).requestCompleted();
        }

        getLogger().debug("Cleared thread-bound request context: {}", request);
    }

    @Override
    public String toString() {
        ToStringBuilder buf = new ToStringBuilder();

        buf.append(getBeanDescription()).start();

        for (RequestContextFactory<?> factory : factories) {
            buf.format("  %s with features %s, ordering %s\n", factory.getRequestContextInterface().getSimpleName(),
                       asList(factory.getFeatures()), asList(factory.featureOrders()));
        }

        return buf.end().toString();
    }
}
TOP

Related Classes of com.alibaba.citrus.service.requestcontext.impl.RequestContextChainingServiceImpl

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.