/*
* Copyright (c) 2002-2012 Alibaba Group Holding Limited.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.citrus.service.requestcontext;
import static com.alibaba.citrus.generictype.TypeInfoUtil.*;
import static com.alibaba.citrus.service.requestcontext.util.RequestContextUtil.*;
import static com.alibaba.citrus.test.TestEnvStatic.*;
import static com.alibaba.citrus.util.StringUtil.*;
import static org.junit.Assert.*;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.servlet.http.HttpSession;
import com.alibaba.citrus.springext.support.context.XmlApplicationContext;
import com.alibaba.citrus.test.runner.TestNameAware;
import com.alibaba.citrus.util.internal.Servlet3Util;
import com.alibaba.citrus.util.io.StreamUtil;
import com.meterware.httpunit.GetMethodWebRequest;
import com.meterware.httpunit.IllegalRequestParameterException;
import com.meterware.httpunit.WebForm;
import com.meterware.httpunit.WebRequest;
import com.meterware.httpunit.WebResponse;
import com.meterware.httpunit.protocol.UploadFileSpec;
import com.meterware.servletunit.InvocationContext;
import com.meterware.servletunit.PatchedServletRunner;
import com.meterware.servletunit.ServletRunner;
import com.meterware.servletunit.ServletUnitClient;
import org.junit.After;
import org.junit.Before;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.core.io.FileSystemResource;
import org.springframework.web.context.request.RequestContextHolder;
/**
 * Base class for tests of {@link RequestContextChainingService} and the request contexts it builds.
 * <p>
 * The type argument {@code RC} is the request-context interface under test. It is resolved
 * reflectively from the concrete subclass and is used both to locate the context inside the chain
 * and to derive the default bean name of the chaining service.
 * </p>
 *
 * @author Michael Zhou
 */
@RunWith(TestNameAware.class)
public abstract class AbstractRequestContextsTests<RC extends RequestContext> {
    /**
     * Captures the part of an interface's simple name before "RequestContext", e.g. {@code Foo} in
     * {@code FooRequestContext}. Compiled once because the pattern never changes.
     */
    private static final Pattern REQUEST_CONTEXT_NAME_PATTERN = Pattern.compile("(\\w+)RequestContext");

    /** The request-context interface under test, resolved from the subclass's type argument. */
    @SuppressWarnings("unchecked")
    protected final Class<RC> requestContextInterface = (Class<RC>) resolveParameter(getClass(),
            AbstractRequestContextsTests.class, 0).getRawType();

    // container
    protected static BeanFactory factory;

    // web client
    protected ServletUnitClient client;
    protected WebResponse clientResponse;

    // servlet request/response
    protected InvocationContext invocationContext;
    protected HttpServletRequest request;
    protected HttpServletResponse response;
    protected ServletConfig config;

    // request contexts
    protected RequestContextChainingService requestContexts;
    protected RC requestContext;
    protected HttpServletRequest newRequest;
    protected HttpServletResponse newResponse;

    static {
        // Disable servlet 3 features because httpunit does not support them yet.
        Servlet3Util.setDisableServlet3Features(true);
    }

    /**
     * Creates the shared bean factory.
     *
     * @param configLocation configuration file path, resolved against {@code srcdir}
     */
    protected static void createBeanFactory(String configLocation) {
        factory = new XmlApplicationContext(new FileSystemResource(new File(srcdir, configLocation)));
    }

    /**
     * Creates the web client and registers servlets. Subclass servlets from
     * {@link #registerServlets(ServletRunner)} are registered first, followed by the readfile and
     * noop servlets.
     */
    @Before
    public final void prepareWebClient() throws Exception {
        // Servlet container
        ServletRunner servletRunner = new PatchedServletRunner();

        registerServlets(servletRunner);
        servletRunner.registerServlet("/readfile/*", ReadFileServlet.class.getName());
        servletRunner.registerServlet("/servlet/*", NoopServlet.class.getName());
        servletRunner.registerServlet("*.do", NoopServlet.class.getName());

        // Servlet client
        client = servletRunner.newClient();
    }

    /** Hook for subclasses to register additional servlets; does nothing by default. */
    protected void registerServlets(ServletRunner runner) {
    }

    /** Resets Spring's {@link RequestContextHolder} request attributes after each test. */
    @After
    public final void clearWebEnv() {
        RequestContextHolder.resetRequestAttributes();
    }

    /**
     * Creates an invocation of the noop servlet and captures the wrapped request/response.
     *
     * @param uri either an absolute URI starting with "http", or a path (possibly {@code null})
     *            that is appended to {@code http://www.taobao.com/}; the result is normalized
     */
    protected final void invokeNoopServlet(String uri) throws Exception {
        if (uri != null && uri.startsWith("http")) {
            uri = URI.create(uri).normalize().toString(); // full uri
        } else {
            uri = URI.create("http://www.taobao.com/" + trimToEmpty(uri)).normalize().toString(); // partial uri
        }

        invokeNoopServlet(new GetMethodWebRequest(uri));
    }

    /**
     * Creates an invocation context for {@code req}. The context's request is stored in
     * {@link #request} wrapped in {@link MyHttpRequest}, and its response in {@link #response}
     * wrapped in {@link MyHttpResponse}. The servlet config is stored in {@link #config}.
     */
    protected final void invokeNoopServlet(WebRequest req) throws Exception {
        invocationContext = client.newInvocation(req);
        request = new MyHttpRequest(invocationContext.getRequest(), req.getURL().toExternalForm());
        response = new MyHttpResponse(invocationContext.getResponse());
        config = invocationContext.getServlet().getServletConfig();
    }

    /**
     * Loads {@code htmlfile} through the readfile servlet. It then builds a request for the page's
     * form named "myform", filled with text parameters, file uploads and a cookie, and captures the
     * resulting request/response.
     *
     * @param htmlfile file name passed to the readfile servlet (resolved there against {@code srcdir})
     */
    protected final void invokeReadFileServlet(String htmlfile) throws Exception {
        String uri = URI.create("http://www.taobao.com/readfile" + "?file=" + htmlfile).normalize().toString();

        // Fetch the initial page and its form.
        WebResponse page = client.getResponse(new GetMethodWebRequest(uri));
        WebForm form = page.getFormWithName("myform");

        // Build the request that submits the form.
        WebRequest formRequest = form.getRequest();

        formRequest.setParameter("myparam", new String[] { "hello", "中华人民共和国" });

        try {
            formRequest.setParameter("myfile", new UploadFileSpec[] { //
                    new UploadFileSpec(new File(srcdir, "smallfile.txt")), //
                    new UploadFileSpec(new File(srcdir, "smallfile_.JPG")), //
                    new UploadFileSpec(new File(srcdir, "smallfile.gif")), //
                    new UploadFileSpec(new File(srcdir, "smallfile")), //
            });
        } catch (IllegalRequestParameterException ignored) {
            // Deliberately ignored. httpunit presumably rejects the parameter when the form has no
            // matching file field (TODO confirm); the request is then built without uploads.
        }

        client.putCookie("mycookie", "mycookievalue");

        invocationContext = client.newInvocation(formRequest);
        request = new MyHttpRequest(invocationContext.getRequest(), uri);

        // The page is served as "text/html; charset=UTF-8", so the request must be parsed as UTF-8.
        request.setCharacterEncoding("UTF-8");

        response = new MyHttpResponse(invocationContext.getResponse());
        config = invocationContext.getServlet().getServletConfig();
    }

    /** Builds the request context using the default bean name; see {@link #initRequestContext(String)}. */
    protected final void initRequestContext() throws Exception {
        initRequestContext(null);
    }

    /**
     * Looks up the chaining service and builds the context chain for the current request/response.
     * It then extracts the context implementing {@link #requestContextInterface} and stores its
     * request/response in {@link #newRequest}/{@link #newResponse}.
     * {@link #afterInitRequestContext()} is called last.
     *
     * @param beanName bean name of the chaining service, or {@code null} to use
     *                 {@link #getDefaultBeanName()}
     */
    protected final void initRequestContext(String beanName) throws Exception {
        if (beanName == null) {
            beanName = getDefaultBeanName();
        }

        requestContexts = (RequestContextChainingService) factory.getBean(beanName);

        RequestContext topRC = requestContexts.getRequestContext(config.getServletContext(), request, response);

        assertNotNull(topRC);

        requestContext = findRequestContext(topRC, requestContextInterface);

        assertNotNull(requestContextInterface.getName(), requestContext);

        newRequest = requestContext.getRequest();
        newResponse = requestContext.getResponse();

        afterInitRequestContext();
    }

    /** Hook invoked at the end of {@link #initRequestContext(String)}; does nothing by default. */
    protected void afterInitRequestContext() throws Exception {
    }

    /** Sends the server-side invocation to the client and stores the result in {@link #clientResponse}. */
    protected final void commitToClient() throws Exception {
        clientResponse = client.getResponse(invocationContext);
    }

    /**
     * Derives the default bean name from the request-context interface. For example,
     * {@code FooRequestContext} yields {@code toCamelCase("Foo")}.
     * Fails the test if the interface's simple name does not contain "RequestContext".
     */
    protected String getDefaultBeanName() {
        String name = requestContextInterface.getSimpleName();
        Matcher matcher = REQUEST_CONTEXT_NAME_PATTERN.matcher(name);

        assertTrue(name, matcher.find());

        return com.alibaba.citrus.util.StringUtil.toCamelCase(matcher.group(1));
    }

    /** A servlet that does nothing. */
    public static class NoopServlet extends HttpServlet {
        private static final long serialVersionUID = 3034658026956449398L;

        @Override
        protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException,
                IOException {
        }

        @Override
        protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException,
                IOException {
            doGet(request, response);
        }
    }

    /**
     * A servlet that writes back the content of the file named by the "file" parameter. The file is
     * resolved against {@code srcdir}, read as GBK, and written as "text/html; charset=UTF-8".
     */
    public static class ReadFileServlet extends HttpServlet {
        private static final long serialVersionUID = 3689913963685360948L;

        @Override
        protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException,
                IOException {
            response.setContentType("text/html; charset=UTF-8");

            PrintWriter out = response.getWriter();

            // The trailing "true" presumably asks StreamUtil to close the stream (verify in StreamUtil).
            String html = StreamUtil.readText(new FileInputStream(new File(srcdir, request.getParameter("file"))),
                    "GBK", true);

            out.println(html);
        }

        @Override
        protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException,
                IOException {
            doGet(request, response);
        }
    }

    /**
     * Request wrapper needed because httpunit does not implement {@code setCharacterEncoding}, and
     * its {@code getQueryString()} is also implemented incorrectly.
     * <p>
     * The wrapper also overrides the server name/port. It records whether a session was requested
     * with creation enabled.
     * </p>
     */
    public static class MyHttpRequest extends HttpServletRequestWrapper {
        private String charset;
        private String overrideQueryString;
        private String server = "www.taobao.com";
        private int port = 80;
        private boolean sessionCreated;

        /**
         * @param request the request to wrap
         * @param uri     if non-null and it contains '?', everything after the first '?' is returned
         *                by {@link #getQueryString()}
         */
        public MyHttpRequest(HttpServletRequest request, String uri) {
            super(request);

            if (uri != null) {
                int index = uri.indexOf('?');

                if (index >= 0) {
                    this.overrideQueryString = uri.substring(index + 1);
                }
            }
        }

        /** Returns the query string taken from the constructor's uri, else the wrapped request's. */
        @Override
        public String getQueryString() {
            if (overrideQueryString == null) {
                return super.getQueryString();
            } else {
                return overrideQueryString;
            }
        }

        /**
         * Returns only the encoding set through {@link #setCharacterEncoding(String)}, which is
         * {@code null} until set. The wrapped request's encoding is not consulted.
         */
        @Override
        public String getCharacterEncoding() {
            return charset;
        }

        /** Records the encoding locally without validating it or forwarding it to the wrapped request. */
        @Override
        public void setCharacterEncoding(String charset) throws UnsupportedEncodingException {
            this.charset = charset;
        }

        /** The default implementation always returns localhost, so this method is overridden. */
        @Override
        public String getServerName() {
            return server;
        }

        public void setServerName(String server) {
            this.server = server;
        }

        /** The default implementation always returns 0, so this method is overridden. */
        @Override
        public int getServerPort() {
            return port;
        }

        public void setServerPort(int port) {
            this.port = port;
        }

        /**
         * Monitors calls to the getSession methods. Returns whether {@code getSession()} or
         * {@code getSession(true)} has been called on this wrapper.
         */
        public boolean isSessionCreated() {
            return sessionCreated;
        }

        @Override
        public HttpSession getSession() {
            sessionCreated = true;
            return super.getSession();
        }

        @Override
        public HttpSession getSession(boolean create) {
            if (create) {
                sessionCreated = true;
            }

            return super.getSession(create);
        }
    }

    /** Hook invoked after {@code MyHttpResponse.reset()} has reset the wrapped response; does nothing by default. */
    protected void resetCalled() {
    }

    /**
     * httpunit does not yet throw {@link IllegalStateException} after a response is committed, so
     * this wrapper tracks commit state itself. It is a non-static inner class because
     * {@link #reset()} notifies the enclosing test through {@code resetCalled()}.
     */
    public class MyHttpResponse extends HttpServletResponseWrapper {
        private boolean committed;

        public MyHttpResponse(HttpServletResponse response) {
            super(response);
        }

        /** Committed if either the wrapped response is, or this wrapper has sent an error/redirect. */
        @Override
        public boolean isCommitted() {
            return super.isCommitted() || committed;
        }

        @Override
        public void sendError(int sc, String msg) throws IOException {
            ensureNotCommitted();
            super.sendError(sc, msg);
            committed = true;
        }

        @Override
        public void sendError(int sc) throws IOException {
            ensureNotCommitted();
            super.sendError(sc);
            committed = true;
        }

        @Override
        public void sendRedirect(String location) throws IOException {
            ensureNotCommitted();
            super.sendRedirect(location);
            committed = true;
        }

        @Override
        public ServletOutputStream getOutputStream() throws IOException {
            ensureNotCommitted();
            return super.getOutputStream();
        }

        @Override
        public PrintWriter getWriter() throws IOException {
            ensureNotCommitted();
            return super.getWriter();
        }

        @Override
        public void reset() {
            ensureNotCommitted();
            super.reset();
            resetCalled();
        }

        @Override
        public void resetBuffer() {
            ensureNotCommitted();
            super.resetBuffer();
        }

        /** Intentionally a no-op, to prevent an unsupported operation exception. */
        @Override
        public void setLocale(Locale locale) {
        }

        /** Clears the character encoding when the new content type carries no charset, then delegates. */
        @Override
        public void setContentType(String type) {
            if (type == null || type.indexOf("charset=") == -1) {
                setCharacterEncoding(null);
            }

            super.setContentType(type);
        }

        @Override
        public void setBufferSize(int size) {
            ensureNotCommitted();
            super.setBufferSize(size);
        }

        /** Throws {@link IllegalStateException} if the response is already committed. */
        private void ensureNotCommitted() {
            if (isCommitted()) {
                throw new IllegalStateException();
            }
        }
    }
}