Package co.cask.cdap.gateway.router.handlers

Source Code of co.cask.cdap.gateway.router.handlers.SecurityAuthenticationHttpHandler

/*
* Copyright © 2014 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package co.cask.cdap.gateway.router.handlers;

import co.cask.cdap.common.conf.CConfiguration;
import co.cask.cdap.common.conf.Constants;
import co.cask.cdap.common.logging.AuditLogEntry;
import co.cask.cdap.security.auth.AccessTokenTransformer;
import co.cask.cdap.security.auth.TokenState;
import co.cask.cdap.security.auth.TokenValidator;
import co.cask.cdap.security.server.GrantAccessToken;
import com.google.common.base.Charsets;
import com.google.common.base.Stopwatch;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import com.ning.org.jboss.netty.handler.codec.http.HttpConstants;
import org.apache.twill.discovery.Discoverable;
import org.apache.twill.discovery.DiscoveryServiceClient;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBufferIndexFinder;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelFuture;
import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelHandler;
import org.jboss.netty.channel.WriteCompletionEvent;
import org.jboss.netty.handler.codec.http.DefaultHttpResponse;
import org.jboss.netty.handler.codec.http.HttpHeaders;
import org.jboss.netty.handler.codec.http.HttpRequest;
import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.HttpVersion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.util.concurrent.TimeUnit;

/**
* Security handler that intercept HTTP message and validates the access token in
* header Authorization field.
*/
public class SecurityAuthenticationHttpHandler extends SimpleChannelHandler {
  private static final Logger LOG = LoggerFactory.getLogger(SecurityAuthenticationHttpHandler.class);
  private static final Logger AUDIT_LOG = LoggerFactory.getLogger("http-access");

  private final TokenValidator tokenValidator;
  private final AccessTokenTransformer accessTokenTransformer;
  private final DiscoveryServiceClient discoveryServiceClient;
  private final Iterable<Discoverable> discoverables;
  private final CConfiguration configuration;
  private final String realm;


  public SecurityAuthenticationHttpHandler(String realm, TokenValidator tokenValidator,
                                           CConfiguration configuration,
                                           AccessTokenTransformer accessTokenTransformer,
                                           DiscoveryServiceClient discoveryServiceClient) {
    this.realm = realm;
    this.tokenValidator = tokenValidator;
    this.accessTokenTransformer = accessTokenTransformer;
    this.discoveryServiceClient = discoveryServiceClient;
    this.discoverables = discoveryServiceClient.discover(Constants.Service.EXTERNAL_AUTHENTICATION);
    this.configuration = configuration;
  }

  /**
   * Intercepts the HttpMessage for getting the access token in authorization header
   * @param ctx channel handler context delegated from MessageReceived callback
   * @param msg intercepted HTTP message
   * @param inboundChannel
   * @return {@code true} if the HTTP message has valid Access token
   * @throws Exception
   */
  private boolean validateSecuredInterception(ChannelHandlerContext ctx, HttpRequest msg,
                                      Channel inboundChannel, AuditLogEntry logEntry) throws Exception {
    String auth = msg.getHeader(HttpHeaders.Names.AUTHORIZATION);
    String accessToken = null;

    /*
     * Parse the access token from authorization header.  The header will be in the form:
     *     Authorization: Bearer ACCESSTOKEN
     *
     * where ACCESSTOKEN is the base64 encoded serialized AccessToken instance.
     */
    if (auth != null) {
      int spIndex = auth.trim().indexOf(' ');
      if (spIndex != -1) {
        accessToken = auth.substring(spIndex + 1).trim();
      }
    }

    logEntry.setClientIP(((InetSocketAddress) ctx.getChannel().getRemoteAddress()).getAddress());
    logEntry.setRequestLine(msg.getMethod(), msg.getUri(), msg.getProtocolVersion());

    TokenState tokenState = tokenValidator.validate(accessToken);
    if (!tokenState.isValid()) {
      HttpResponse httpResponse = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED);
      logEntry.setResponseCode(HttpResponseStatus.UNAUTHORIZED.getCode());

      JsonObject jsonObject = new JsonObject();
      if (tokenState == TokenState.MISSING) {
        httpResponse.addHeader(HttpHeaders.Names.WWW_AUTHENTICATE,
                                 String.format("Bearer realm=\"%s\"", realm));
        LOG.debug("Authentication failed due to missing token");

      } else {
        httpResponse.addHeader(HttpHeaders.Names.WWW_AUTHENTICATE,
                                 String.format("Bearer realm=\"%s\" error=\"invalid_token\"" +
                                                 " error_description=\"%s\"", realm, tokenState.getMsg()));
        jsonObject.addProperty("error", "invalid_token");
        jsonObject.addProperty("error_description", tokenState.getMsg());
        LOG.debug("Authentication failed due to invalid token, reason={};", tokenState);
      }
      JsonArray externalAuthenticationURIs = new JsonArray();

      //Waiting for service to get discovered
      stopWatchWait(externalAuthenticationURIs);

      jsonObject.add("auth_uri", externalAuthenticationURIs);

      ChannelBuffer content = ChannelBuffers.wrappedBuffer(jsonObject.toString().getBytes(Charsets.UTF_8));
      httpResponse.setContent(content);
      int contentLength = content.readableBytes();
      httpResponse.setHeader(HttpHeaders.Names.CONTENT_LENGTH, contentLength);
      httpResponse.setHeader(HttpHeaders.Names.CONTENT_TYPE, "application/json;charset=UTF-8");
      logEntry.setResponseContentLength(new Long(contentLength));
      ChannelFuture writeFuture = Channels.future(inboundChannel);
      Channels.write(ctx, writeFuture, httpResponse);
      writeFuture.addListener(ChannelFutureListener.CLOSE);
      return false;
    } else {
      AccessTokenTransformer.AccessTokenIdentifierPair accessTokenIdentifierPair =
        accessTokenTransformer.transform(accessToken);
      logEntry.setUserName(accessTokenIdentifierPair.getAccessTokenIdentifierObj().getUsername());
      msg.setHeader(HttpHeaders.Names.AUTHORIZATION,
                    "CDAP-verified " + accessTokenIdentifierPair.getAccessTokenIdentifierStr());
      return true;
    }
  }

  /**
   *
   * @param externalAuthenticationURIs the list that should be populated with discovered with
   *                                   external auth servers URIs
   * @throws Exception
   */
  private void stopWatchWait(JsonArray externalAuthenticationURIs) throws Exception {
    boolean done = false;
    Stopwatch stopwatch = new Stopwatch();
    stopwatch.start();
    String protocol;
    int port;
    if (configuration.getBoolean(Constants.Security.SSL_ENABLED)) {
      protocol = "https";
      port = configuration.getInt(Constants.Security.AuthenticationServer.SSL_PORT);
    } else {
      protocol = "http";
      port = configuration.getInt(Constants.Security.AUTH_SERVER_PORT);
    }

    do {
      for (Discoverable d : discoverables)  {
        String url = String.format("%s://%s:%d/%s", protocol, d.getSocketAddress().getHostName(), port,
                                                                        GrantAccessToken.Paths.GET_TOKEN);
        externalAuthenticationURIs.add(new JsonPrimitive(url));
        done = true;
      }
      if (!done) {
        TimeUnit.MILLISECONDS.sleep(200);
      }
    } while (!done && stopwatch.elapsedTime(TimeUnit.SECONDS) < 2L);
  }


  @Override
  public void messageReceived(ChannelHandlerContext ctx, final MessageEvent event) throws Exception {
    Object msg = event.getMessage();
    if (!(msg instanceof HttpRequest)) {
      super.messageReceived(ctx, event);
    } else {
      AuditLogEntry logEntry = new AuditLogEntry();
      ctx.setAttachment(logEntry);
      if (validateSecuredInterception(ctx, (HttpRequest) msg, event.getChannel(), logEntry)) {
        Channels.fireMessageReceived(ctx, msg, event.getRemoteAddress());
      } else {
        // we write the response directly for authentication failure, so nothing to do
        return;
      }
    }
  }

  @Override
  public void writeRequested(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
    AuditLogEntry logEntry = getLogEntry(ctx);
    Object message = e.getMessage();
    if (message instanceof HttpResponse) {
      HttpResponse response = (HttpResponse) message;
      logEntry.setResponseCode(response.getStatus().getCode());
      if (response.containsHeader(HttpHeaders.Names.CONTENT_LENGTH)) {
        String lengthString = response.getHeader(HttpHeaders.Names.CONTENT_LENGTH);
        try {
          logEntry.setResponseContentLength(Long.valueOf(lengthString));
        } catch (NumberFormatException nfe) {
          LOG.warn("Invalid value for content length in HTTP response message: {}", lengthString, nfe);
        }
      }
    } else if (message instanceof ChannelBuffer) {
      // for chunked responses the response code will only be present on the first chunk
      // so we only look for it the first time around
      if (logEntry.getResponseCode() == null) {
        ChannelBuffer channelBuffer = (ChannelBuffer) message;
        logEntry.setResponseCode(findResponseCode(channelBuffer));
        if (logEntry.getResponseCode() != null) {
          // we currently only look for a Content-Length header in the first buffer on an HTTP response
          // this is a limitation of the implementation that simplifies header parsing
          logEntry.setResponseContentLength(findContentLength(channelBuffer));
        }
      }
    } else {
      LOG.debug("Unhandled response message type: {}", message.getClass());
    }
    super.writeRequested(ctx, e);
  }

  @Override
  public void writeComplete(ChannelHandlerContext ctx, WriteCompletionEvent e) throws Exception {
    AuditLogEntry logEntry = getLogEntry(ctx);
    if (!logEntry.isLogged()) {
      AUDIT_LOG.trace(logEntry.toString());
      logEntry.setLogged(true);
    }
  }

  private Integer findResponseCode(ChannelBuffer buffer) {
    Integer responseCode = null;

    // we assume that the response code should follow the first space in the first line of the response
    int indx = buffer.indexOf(buffer.readerIndex(), buffer.writerIndex(), ChannelBufferIndexFinder.LINEAR_WHITESPACE);
    if (indx >= 0 && indx < buffer.writerIndex() - 4) {
      String codeString = buffer.toString(indx, 4, Charsets.UTF_8).trim();
      try {
        responseCode = Integer.valueOf(codeString);
      } catch (NumberFormatException nfe) {
        LOG.warn("Invalid value for HTTP response code: {}", codeString, nfe);
      }
    } else {
      LOG.debug("Invalid index for space in response: index={}, buffer size={}", indx, buffer.readableBytes());
    }
    return responseCode;
  }

  private Long findContentLength(ChannelBuffer buffer) {
    Long contentLength = null;
    int bufferEnd = buffer.writerIndex();
    int index = buffer.indexOf(buffer.readerIndex(), bufferEnd, CONTENT_LENGTH_FINDER);
    if (index >= 0) {
      // find the following ':'
      int colonIndex = buffer.indexOf(index, bufferEnd, HttpConstants.COLON);
      int eolIndex = buffer.indexOf(index, bufferEnd, ChannelBufferIndexFinder.CRLF);
      if (colonIndex > 0 && colonIndex < eolIndex) {
        String lengthString = buffer.toString(colonIndex + 1, eolIndex - (colonIndex + 1), Charsets.UTF_8).trim();
        try {
          contentLength = Long.valueOf(lengthString);
        } catch (NumberFormatException nfe) {
          LOG.warn("Invalid value for content length in HTTP response message: {}", lengthString, nfe);
        }
      }
    }
    return contentLength;
  }

  private AuditLogEntry getLogEntry(ChannelHandlerContext ctx) {
    Object entryObject = ctx.getAttachment();
    AuditLogEntry logEntry;
    if (entryObject != null && entryObject instanceof AuditLogEntry) {
      logEntry = (AuditLogEntry) entryObject;
    } else {
      logEntry = new AuditLogEntry();
      ctx.setAttachment(logEntry);
    }
    return logEntry;
  }

  private static final ChannelBufferIndexFinder CONTENT_LENGTH_FINDER = new ChannelBufferIndexFinder() {
    private byte[] headerName = HttpHeaders.Names.CONTENT_LENGTH.getBytes(Charsets.UTF_8);

    @Override
    public boolean find(ChannelBuffer buffer, int guessedIndex) {
      if (buffer.capacity() - guessedIndex < headerName.length) {
        return false;
      }

      for (int i = 0; i < headerName.length; i++) {
        if (headerName[i] != buffer.getByte(guessedIndex + i)) {
          return false;
        }
      }
      return true;
    }
  };
}
TOP

Related Classes of co.cask.cdap.gateway.router.handlers.SecurityAuthenticationHttpHandler

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle, Inc. Contact coftware#gmail.com.