/*
* Copyright (c) 2002-2012 Alibaba Group Holding Limited.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.citrus.test.util;
import static com.alibaba.citrus.springext.util.DomUtil.*;
import static com.alibaba.citrus.springext.util.SpringExtUtil.*;
import static com.alibaba.citrus.util.Assert.*;
import static com.alibaba.citrus.util.StringUtil.*;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.util.Locale;
import java.util.Map;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.servlet.http.HttpSession;
import com.alibaba.citrus.service.requestcontext.RequestContext;
import com.alibaba.citrus.service.requestcontext.RequestContextChainingService;
import com.alibaba.citrus.springext.support.parser.AbstractNamedBeanDefinitionParser;
import com.meterware.httpunit.GetMethodWebRequest;
import com.meterware.httpunit.WebRequest;
import com.meterware.httpunit.WebResponse;
import com.meterware.servletunit.InvocationContext;
import com.meterware.servletunit.PatchedServletRunner;
import com.meterware.servletunit.ServletRunner;
import com.meterware.servletunit.ServletUnitClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.xml.ParserContext;
import org.springframework.web.context.request.RequestContextHolder;
import org.w3c.dom.Element;
/**
 * Creates a servlet container for testing, with support for the request contexts mechanism.
 * <p>
 * Typical usage: call one of the {@code request(...)} methods to start an invocation, work with
 * the objects exposed by the getters, call {@link #commit()} to push the server-side response to
 * the client, inspect {@link #getClientResponse()}, and finally call {@link #cleanup()}.
 * </p>
 *
 * @author Michael Zhou
 */
public class ServletTestContainer {
    protected final ServletRunner     servletRunner;
    protected final ServletUnitClient client;

    // State of the current invocation; populated by request(...) and dropped by cleanup().
    private InvocationContext   invocationContext;
    private HttpServletRequest  rawRequest;
    private HttpServletResponse rawResponse;
    private ServletContext      servletContext;

    // Optional service; when absent no RequestContext is created and the raw request/response are used.
    private RequestContextChainingService requestContexts;
    private RequestContext                requestContext;

    // Result of the most recent commit().
    private WebResponse clientResponse;

    /**
     * Creates the servlet runner, registers {@link NoopServlet} under {@code /noop/*}, gives
     * subclasses a chance to customize the runner through {@link #init(ServletRunner)}, and then
     * creates the client.
     */
    public ServletTestContainer() {
        servletRunner = new PatchedServletRunner();
        servletRunner.registerServlet("/noop/*", NoopServlet.class.getName());

        init(servletRunner);

        client = servletRunner.newClient();
    }

    /**
     * Extension hook invoked from the constructor before the client is created. The default
     * implementation does nothing.
     * <p>
     * Note: this runs during construction, so fields declared by a subclass are not yet
     * initialized when it is called.
     * </p>
     *
     * @param runner the servlet runner being set up
     */
    protected void init(ServletRunner runner) {
    }

    /**
     * Registers additional servlets on the runner.
     *
     * @param mappings maps a URL pattern to a servlet class name; an empty class name falls back
     *                 to {@link NoopServlet}
     */
    public void setMappings(Map<String, String> mappings) {
        for (Map.Entry<String, String> entry : mappings.entrySet()) {
            String pattern = entry.getKey();
            String servletClass = defaultIfEmpty(entry.getValue(), NoopServlet.class.getName());

            servletRunner.registerServlet(pattern, servletClass);
        }
    }

    /** @return the invocation context of the current request, or {@code null} if there is none */
    public InvocationContext getInvocationContext() {
        return invocationContext;
    }

    /** @return the servlet context of the current request, or {@code null} if there is none */
    public ServletContext getServletContext() {
        return servletContext;
    }

    /** @return the wrapped ({@link MyHttpRequest}) request, not processed by any request context */
    public HttpServletRequest getRawRequest() {
        return rawRequest;
    }

    /** @return the wrapped ({@link MyHttpResponse}) response, not processed by any request context */
    public HttpServletResponse getRawResponse() {
        return rawResponse;
    }

    /** @return the request context of the current request, or {@code null} if none was created */
    public RequestContext getRequestContext() {
        return requestContext;
    }

    /** @return the request of the request context if one exists, otherwise the raw request */
    public HttpServletRequest getRequest() {
        return requestContext == null ? rawRequest : requestContext.getRequest();
    }

    /** @return the response of the request context if one exists, otherwise the raw response */
    public HttpServletResponse getResponse() {
        return requestContext == null ? rawResponse : requestContext.getResponse();
    }

    /** @return the client-side response produced by the most recent {@link #commit()} */
    public WebResponse getClientResponse() {
        return clientResponse;
    }

    /**
     * Injects the optional request context service.
     *
     * @param requestContexts the service used to create and commit request contexts; may be
     *                        {@code null}
     */
    @Autowired(required = false)
    public void setRequestContexts(RequestContextChainingService requestContexts) {
        this.requestContexts = requestContexts;
    }

    /**
     * Issues a GET request.
     *
     * @param uri either a full {@code http://} / {@code https://} URI, or a partial one that is
     *            resolved against {@code http://www.test.com/}; {@code null} is treated as empty.
     *            The resulting URI is normalized before use.
     */
    public void request(String uri) {
        // Only a real http(s) scheme marks a full URI; a relative path that merely starts with
        // the letters "http" (e.g. "httpd/status") must still be resolved against the test host.
        String target = isAbsoluteHttpUri(uri) ? uri : "http://www.test.com/" + trimToEmpty(uri);

        request(new GetMethodWebRequest(URI.create(target).normalize().toString()));
    }

    /** Tells whether the given string starts with the {@code http://} or {@code https://} scheme, ignoring case. */
    private static boolean isAbsoluteHttpUri(String uri) {
        return uri != null
               && (uri.regionMatches(true, 0, "http://", 0, 7) || uri.regionMatches(true, 0, "https://", 0, 8));
    }

    /**
     * Starts a new invocation for the given request and exposes its request, response and servlet
     * context through the getters. If a {@link RequestContextChainingService} has been set, a
     * request context is created as well.
     *
     * @param webRequest the httpunit request to issue
     * @throws RuntimeException if creating the invocation fails; checked exceptions are wrapped
     */
    public void request(WebRequest webRequest) {
        try {
            invocationContext = client.newInvocation(webRequest);
            rawRequest = new MyHttpRequest(invocationContext.getRequest(), webRequest.getURL().toExternalForm());
            rawResponse = new MyHttpResponse(invocationContext.getResponse());
            servletContext = invocationContext.getServlet().getServletConfig().getServletContext();
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        // Never let a request context created for an earlier request leak into this one.
        requestContext = null;

        if (requestContexts != null) {
            requestContext = requestContexts.getRequestContext(servletContext, rawRequest, rawResponse);
        }
    }

    /**
     * Commits the server-side response to the client. The request context, if any, is committed
     * first; afterwards the client response is available from {@link #getClientResponse()}.
     *
     * @throws RuntimeException if committing fails; checked exceptions are wrapped
     */
    public void commit() {
        try {
            if (requestContexts != null && requestContext != null) {
                requestContexts.commitRequestContext(requestContext);
            }

            clientResponse = client.getResponse(invocationContext);
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Resets the request attributes held by Spring's {@code RequestContextHolder} and drops the
     * references to the current invocation, request, response, servlet context and request
     * context. The last client response is left in place.
     */
    public void cleanup() {
        RequestContextHolder.resetRequestAttributes();

        invocationContext = null;
        rawRequest = null;
        rawResponse = null;
        servletContext = null;
        requestContext = null;
    }

    /** A servlet that does nothing. */
    public static class NoopServlet extends HttpServlet {
        private static final long serialVersionUID = 3034658026956449398L;

        @Override
        protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException,
                                                                                                  IOException {
        }
    }

    /**
     * httpunit does not currently implement {@code setCharacterEncoding}, and its
     * {@code getQueryString()} implementation is problematic, so the request has to be wrapped.
     */
    public static class MyHttpRequest extends HttpServletRequestWrapper {
        private String  charset;
        private String  overrideQueryString;
        private String  server = "www.test.com";
        private int     port   = 80;
        private boolean sessionCreated;

        /**
         * @param request the request to wrap
         * @param uri     the requested URI; the part between the first {@code ?} and an optional
         *                {@code #fragment} overrides the query string. May be {@code null}.
         */
        public MyHttpRequest(HttpServletRequest request, String uri) {
            super(request);
            this.overrideQueryString = extractQueryString(uri);
        }

        /**
         * Extracts the query component of a URI.
         *
         * @return the text after the first {@code ?}, without any fragment, or {@code null} if the
         *         URI is {@code null} or has no query component
         */
        private static String extractQueryString(String uri) {
            if (uri == null) {
                return null;
            }

            // A fragment is never part of the query string, so cut it off before looking for '?'.
            int fragmentStart = uri.indexOf('#');
            String withoutFragment = fragmentStart < 0 ? uri : uri.substring(0, fragmentStart);

            int queryStart = withoutFragment.indexOf('?');

            return queryStart < 0 ? null : withoutFragment.substring(queryStart + 1);
        }

        /** @return the query string taken from the URI, or the wrapped request's value if the URI had none */
        @Override
        public String getQueryString() {
            return overrideQueryString != null ? overrideQueryString : super.getQueryString();
        }

        /** @return the charset set through {@link #setCharacterEncoding(String)}, or {@code null} if never set */
        @Override
        public String getCharacterEncoding() {
            return charset;
        }

        /** Records the charset locally; the wrapped request is not touched. */
        @Override
        public void setCharacterEncoding(String charset) throws UnsupportedEncodingException {
            this.charset = charset;
        }

        /** The default implementation always returns localhost, so this method has to be overridden. */
        @Override
        public String getServerName() {
            return server;
        }

        public void setServerName(String server) {
            this.server = server;
        }

        /** The default implementation always returns 0, so this method has to be overridden. */
        @Override
        public int getServerPort() {
            return port;
        }

        public void setServerPort(int port) {
            this.port = port;
        }

        /**
         * Monitors calls to {@code getSession}.
         *
         * @return {@code true} once {@link #getSession()} or {@code getSession(true)} has been called
         */
        public boolean isSessionCreated() {
            return sessionCreated;
        }

        @Override
        public HttpSession getSession() {
            sessionCreated = true;
            return super.getSession();
        }

        @Override
        public HttpSession getSession(boolean create) {
            // Only a creating call counts; getSession(false) never creates a session.
            if (create) {
                sessionCreated = true;
            }

            return super.getSession(create);
        }
    }

    /**
     * httpunit does not currently throw {@code IllegalStateException} after the response has been
     * committed, so the response has to be wrapped.
     */
    public static class MyHttpResponse extends HttpServletResponseWrapper {
        // Set once sendError/sendRedirect has been called on this wrapper.
        private boolean committed;

        public MyHttpResponse(HttpServletResponse response) {
            super(response);
        }

        /** @return {@code true} if the wrapped response is committed or an error/redirect was sent through this wrapper */
        @Override
        public boolean isCommitted() {
            return super.isCommitted() || committed;
        }

        @Override
        public void sendError(int sc, String msg) throws IOException {
            ensureNotCommitted();
            super.sendError(sc, msg);
            committed = true;
        }

        @Override
        public void sendError(int sc) throws IOException {
            ensureNotCommitted();
            super.sendError(sc);
            committed = true;
        }

        @Override
        public void sendRedirect(String location) throws IOException {
            ensureNotCommitted();
            super.sendRedirect(location);
            committed = true;
        }

        @Override
        public ServletOutputStream getOutputStream() throws IOException {
            ensureNotCommitted();
            return super.getOutputStream();
        }

        @Override
        public PrintWriter getWriter() throws IOException {
            ensureNotCommitted();
            return super.getWriter();
        }

        @Override
        public void reset() {
            ensureNotCommitted();
            super.reset();
        }

        @Override
        public void resetBuffer() {
            ensureNotCommitted();
            super.resetBuffer();
        }

        @Override
        public void setLocale(Locale locale) {
            // Deliberately a no-op: prevents an UnsupportedOperationException.
        }

        @Override
        public void setContentType(String type) {
            // A content type without an explicit charset clears the character encoding.
            if (type == null || !type.contains("charset=")) {
                setCharacterEncoding(null);
            }

            super.setContentType(type);
        }

        @Override
        public void setBufferSize(int size) {
            ensureNotCommitted();
            super.setBufferSize(size);
        }

        /** @throws IllegalStateException if the response has already been committed */
        private void ensureNotCommitted() {
            if (isCommitted()) {
                throw new IllegalStateException("response has already been committed");
            }
        }
    }

    /**
     * Bean definition parser for {@link ServletTestContainer}. Each {@code mapping} sub-element
     * contributes one entry (attribute {@code pattern} to attribute {@code servletClass}) to the
     * {@code mappings} property.
     */
    public static class DefinitionParser extends AbstractNamedBeanDefinitionParser<ServletTestContainer> {
        @Override
        protected void doParse(Element element, ParserContext parserContext, BeanDefinitionBuilder builder) {
            Map<Object, Object> mappings = createManagedMap(element, parserContext);

            for (Element mappingElement : subElements(element, and(sameNs(element), name("mapping")))) {
                // The pattern is mandatory; a missing servletClass is stored as null and later
                // replaced by the no-op servlet in setMappings().
                String pattern = assertNotNull(trimToNull(mappingElement.getAttribute("pattern")), "pattern");
                String servletClass = trimToNull(mappingElement.getAttribute("servletClass"));

                mappings.put(pattern, servletClass);
            }

            // Only set the property when at least one mapping was configured.
            if (!mappings.isEmpty()) {
                builder.addPropertyValue("mappings", mappings);
            }
        }

        @Override
        protected String getDefaultName() {
            return "servletTestContainer";
        }
    }
}