/*
* Copyright 2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ratpack.http.client.internal;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.*;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import ratpack.exec.*;
import ratpack.func.Action;
import ratpack.http.Headers;
import ratpack.http.MutableHeaders;
import ratpack.http.Status;
import ratpack.http.client.HttpClient;
import ratpack.http.client.ReceivedResponse;
import ratpack.http.client.RequestSpec;
import ratpack.http.internal.*;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import java.net.URI;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import static ratpack.util.ExceptionUtils.uncheck;
/**
 * Netty based {@link HttpClient}.
 * <p>
 * Every request opens a new connection and is sent with {@code Connection: close}. The response is aggregated in
 * full (bounded by {@code maxContentLengthBytes}) before the returned promise is fulfilled, and redirects are
 * followed while the request spec's redirect budget allows it.
 */
public class DefaultHttpClient implements HttpClient {

  private final ExecController execController;
  private final ByteBufAllocator byteBufAllocator;
  private final int maxContentLengthBytes;

  /**
   * Creates a client.
   *
   * @param execController provides the current execution and the event loop group connections are made on
   * @param byteBufAllocator allocator passed to the request spec backing of each request
   * @param maxContentLengthBytes maximum size of an aggregated response, as given to {@link HttpObjectAggregator}
   */
  public DefaultHttpClient(ExecController execController, ByteBufAllocator byteBufAllocator, int maxContentLengthBytes) {
    this.execController = execController;
    this.byteBufAllocator = byteBufAllocator;
    this.maxContentLengthBytes = maxContentLengthBytes;
  }

  /**
   * Delegates straight to {@link #request(URI, Action)} without forcing a method, so the spec's default method is
   * used (presumably GET — confirm in {@code RequestSpecBacking}).
   */
  @Override
  public Promise<ReceivedResponse> get(URI uri, Action<? super RequestSpec> requestConfigurer) {
    return request(uri, requestConfigurer);
  }

  /**
   * Sets the request method to {@code POST}. It is joined before the caller's action in {@link #post}, so the caller
   * can still override the method.
   */
  private static class Post implements Action<RequestSpec> {
    @Override
    public void execute(RequestSpec requestSpec) throws Exception {
      requestSpec.method("POST");
    }
  }

  /**
   * Performs a request whose method is set to {@code POST} before {@code action} is applied.
   */
  @Override
  public Promise<ReceivedResponse> post(URI uri, Action<? super RequestSpec> action) {
    return request(uri, Action.join(new Post(), action));
  }

  /**
   * Performs a request against {@code uri}.
   * <p>
   * The configurer is run eagerly (while building the request action), so configuration errors and non-http(s)
   * URIs are reported by throwing from this method rather than through the promise.
   *
   * @param uri absolute http or https URI to request
   * @param requestConfigurer customises method, headers, body and redirect budget
   * @return a promise for the (possibly redirected) response
   * @throws IllegalArgumentException if the URI scheme is not http or https
   */
  @Override
  public Promise<ReceivedResponse> request(URI uri, final Action<? super RequestSpec> requestConfigurer) {
    final ExecControl execControl = execController.getControl();
    final Execution execution = execControl.getExecution();
    final EventLoopGroup eventLoopGroup = execController.getEventLoopGroup();
    try {
      RequestAction requestAction = new RequestAction(requestConfigurer, uri, execution, eventLoopGroup, byteBufAllocator, maxContentLengthBytes);
      return execControl.promise(requestAction);
    } catch (Exception e) {
      throw uncheck(e);
    }
  }

  /**
   * Registers {@code responseBuffer} to be released when {@code execution} is cleaned up, and returns it.
   * The handler below is created with auto-release disabled, so this is what releases the response content.
   */
  private static ByteBuf initBufferReleaseOnExecutionClose(final ByteBuf responseBuffer, Execution execution) {
    execution.onCleanup(responseBuffer::release);
    return responseBuffer;
  }

  /**
   * Builds the HTTP request-target (raw path plus optional raw query) for {@code uri}.
   * An empty or missing path (e.g. {@code http://example.com}) becomes {@code /}; an empty request-target is not a
   * valid HTTP request line.
   */
  private static String getFullPath(URI uri) {
    String path = uri.getRawPath();
    StringBuilder sb = new StringBuilder(path == null || path.isEmpty() ? "/" : path);
    String query = uri.getRawQuery();
    if (query != null) {
      sb.append('?').append(query);
    }
    return sb.toString();
  }

  /**
   * Performs one HTTP exchange (and, by delegation, any redirects) and fulfills the promise at most once.
   * The {@code fired} flag guards this action's own completion; when a redirect is followed, the fulfiller is
   * handed over to a new {@code RequestAction} that has its own guard.
   */
  private static class RequestAction implements Action<Fulfiller<? super ReceivedResponse>> {
    final Execution execution;
    final EventLoopGroup eventLoopGroup;
    final Action<? super RequestSpec> requestConfigurer;
    final boolean finalUseSsl;
    final String host;
    final int port;
    final MutableHeaders headers;
    final RequestSpecBacking requestSpecBacking;
    final URI uri;
    private final ByteBufAllocator byteBufAllocator;
    private final int maxContentLengthBytes;
    private final RequestParams requestParams;
    private final AtomicBoolean fired = new AtomicBoolean();

    /**
     * Applies {@code requestConfigurer} immediately and derives SSL usage, host and port from {@code uri}.
     *
     * @throws IllegalArgumentException if the scheme is missing or is not http/https (compared case-insensitively,
     *     as URI schemes are case-insensitive)
     */
    public RequestAction(Action<? super RequestSpec> requestConfigurer, URI uri, Execution execution, EventLoopGroup eventLoopGroup, ByteBufAllocator byteBufAllocator, int maxContentLengthBytes) {
      this.execution = execution;
      this.eventLoopGroup = eventLoopGroup;
      this.requestConfigurer = requestConfigurer;
      this.byteBufAllocator = byteBufAllocator;
      this.maxContentLengthBytes = maxContentLengthBytes;
      this.uri = uri;
      this.requestParams = new RequestParams();
      headers = new NettyHeadersBackedMutableHeaders(new DefaultHttpHeaders());
      requestSpecBacking = new RequestSpecBacking(headers, uri, byteBufAllocator, requestParams);
      try {
        requestConfigurer.execute(requestSpecBacking.asSpec());
      } catch (Exception e) {
        throw uncheck(e);
      }

      // A null scheme (relative URI) is rejected with the same message rather than failing with an NPE.
      String scheme = this.uri.getScheme();
      boolean useSsl;
      if ("https".equalsIgnoreCase(scheme)) {
        useSsl = true;
      } else if ("http".equalsIgnoreCase(scheme)) {
        useSsl = false;
      } else {
        throw new IllegalArgumentException(String.format("URL '%s' is not a http url", this.uri.toString()));
      }
      finalUseSsl = useSsl;
      host = this.uri.getHost();
      port = this.uri.getPort() < 0 ? (useSsl ? 443 : 80) : this.uri.getPort();
    }

    /**
     * Connects, writes the request and installs a handler that completes {@code fulfiller} with the aggregated
     * response, or with the first connect/write/read failure.
     */
    public void execute(final Fulfiller<? super ReceivedResponse> fulfiller) throws Exception {
      final Bootstrap b = new Bootstrap();
      b.group(eventLoopGroup)
        .channel(NioSocketChannel.class)
        .handler(new ChannelInitializer<SocketChannel>() {
          @Override
          protected void initChannel(SocketChannel ch) throws Exception {
            ChannelPipeline p = ch.pipeline();

            if (finalUseSsl) {
              // Passing the peer host/port lets the JSSE provider send SNI, which name-based virtual hosts need.
              SSLEngine engine = SSLContext.getDefault().createSSLEngine(host, port);
              engine.setUseClientMode(true);
              p.addLast("ssl", new SslHandler(engine));
            }

            p.addLast("codec", new HttpClientCodec());
            p.addLast("aggregator", new HttpObjectAggregator(maxContentLengthBytes));
            p.addLast("readTimeout", new ReadTimeoutHandler(requestParams.readTimeoutNanos, TimeUnit.NANOSECONDS));

            // autoRelease=false: the response content is released on execution cleanup instead.
            p.addLast("handler", new SimpleChannelInboundHandler<HttpObject>(false) {
              @Override
              public void channelRead0(ChannelHandlerContext ctx, HttpObject msg) throws Exception {
                if (msg instanceof FullHttpResponse) {
                  onResponse(ctx, (FullHttpResponse) msg, fulfiller);
                }
              }

              @Override
              public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
                ctx.close();
                error(fulfiller, cause);
              }
            });
          }
        });

      ChannelFuture connectFuture = b.connect(host, port);
      connectFuture.addListener(f1 -> {
        if (connectFuture.isSuccess()) {
          String fullPath = getFullPath(uri);
          FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.valueOf(requestSpecBacking.getMethod()), fullPath, requestSpecBacking.getBody());
          if (headers.get(HttpHeaderConstants.HOST) == null) {
            headers.set(HttpHeaderConstants.HOST, host);
          }
          headers.set(HttpHeaderConstants.CONNECTION, HttpHeaders.Values.CLOSE);
          int contentLength = request.content().readableBytes();
          if (contentLength > 0) {
            headers.set(HttpHeaderConstants.CONTENT_LENGTH, Integer.toString(contentLength, 10));
          }

          HttpHeaders requestHeaders = request.headers();
          for (String name : headers.getNames()) {
            requestHeaders.set(name, headers.getAll(name));
          }

          ChannelFuture writeFuture = connectFuture.channel().writeAndFlush(request);
          writeFuture.addListener(f2 -> {
            if (!writeFuture.isSuccess()) {
              writeFuture.channel().close();
              error(fulfiller, writeFuture.cause());
            }
          });
        } else {
          connectFuture.channel().close();
          error(fulfiller, connectFuture.cause());
        }
      });
    }

    /**
     * Handles a fully aggregated response: either follows a redirect or fulfills the promise with it.
     * <p>
     * A relative {@code Location} is resolved against this request's URI. 301/302/303 redirects are re-issued
     * as GET (RFC 7231 requires GET for 303); 307 keeps the original method.
     */
    private void onResponse(ChannelHandlerContext ctx, FullHttpResponse response, Fulfiller<? super ReceivedResponse> fulfiller) throws Exception {
      // The request was sent with Connection: close and the response is complete, so this channel is done.
      ctx.close();

      final Headers responseHeaders = new NettyHeadersBackedHeaders(response.headers());
      String contentType = responseHeaders.get(HttpHeaderConstants.CONTENT_TYPE.toString());
      ByteBuf responseBuffer = initBufferReleaseOnExecutionClose(response.content(), execution);
      final ByteBufBackedTypedData typedData = new ByteBufBackedTypedData(responseBuffer, DefaultMediaType.get(contentType));
      final Status status = new DefaultStatus(response.getStatus());

      int maxRedirects = requestSpecBacking.getMaxRedirects();
      String locationValue = responseHeaders.get("Location");
      URI locationUrl = locationValue == null ? null : uri.resolve(locationValue);

      if (shouldRedirect(status) && maxRedirects > 0 && locationUrl != null) {
        Action<? super RequestSpec> redirectRequestConfig = Action.join(requestConfigurer, s -> {
          int code = status.getCode();
          if (code == 301 || code == 302 || code == 303) {
            s.method("GET");
          }
          s.redirects(maxRedirects - 1);
        });
        // Built before claiming the guard: if the redirect target is invalid, the exception reaches
        // exceptionCaught and this action can still report it.
        RequestAction redirectAction = new RequestAction(redirectRequestConfig, locationUrl, execution, eventLoopGroup, byteBufAllocator, maxContentLengthBytes);
        // Claim this action's guard so that late events on this channel cannot fulfill the promise a second time
        // once the redirect action owns it.
        if (fired.compareAndSet(false, true)) {
          try {
            redirectAction.execute(fulfiller);
          } catch (Exception e) {
            redirectAction.error(fulfiller, e);
          }
        }
      } else {
        success(fulfiller, new DefaultReceivedResponse(status, responseHeaders, typedData));
      }
    }

    /** Fulfills with {@code value} unless this action has already completed. */
    private <T> void success(Fulfiller<? super T> fulfiller, T value) {
      if (fired.compareAndSet(false, true)) {
        fulfiller.success(value);
      }
    }

    /** Fails with {@code error} unless this action has already completed. */
    private void error(Fulfiller<?> fulfiller, Throwable error) {
      if (fired.compareAndSet(false, true)) {
        fulfiller.error(error);
      }
    }

    /** Returns whether {@code status} is one of the redirect codes this client follows (301, 302, 303, 307). */
    private static boolean shouldRedirect(Status status) {
      int code = status.getCode();
      return code == 301 || code == 302 || code == 303 || code == 307;
    }
  }
}