Package org.apache.thrift.transport

Source Code of org.apache.thrift.transport.TestTSaslTransports

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.thrift.transport;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import javax.security.sasl.SaslServerFactory;

import junit.framework.TestCase;

import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.server.ServerTestBase;
import org.apache.thrift.server.TServer;
import org.apache.thrift.server.TSimpleServer;
import org.apache.thrift.server.TServer.Args;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TestTSaslTransports extends TestCase {

  private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class);

  private static final String HOST = "localhost";
  private static final String SERVICE = "thrift-test";
  private static final String PRINCIPAL = "thrift-test-principal";
  private static final String PASSWORD = "super secret password";
  private static final String REALM = "thrift-test-realm";

  private static final String UNWRAPPED_MECHANISM = "CRAM-MD5";
  private static final Map<String, String> UNWRAPPED_PROPS = null;

  private static final String WRAPPED_MECHANISM = "DIGEST-MD5";
  private static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>();

  static {
    WRAPPED_PROPS.put(Sasl.QOP, "auth-int");
    WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM);
  }

  private static final String testMessage1 = "Hello, world! Also, four "
      + "score and seven years ago our fathers brought forth on this "
      + "continent a new nation, conceived in liberty, and dedicated to the "
      + "proposition that all men are created equal.";
 
  private static final String testMessage2 = "I have a dream that one day "
      + "this nation will rise up and live out the true meaning of its creed: "
      + "'We hold these truths to be self-evident, that all men are created equal.'";


  private static class TestSaslCallbackHandler implements CallbackHandler {
    private final String password;

    public TestSaslCallbackHandler(String password) {
      this.password = password;
    }

    @Override
    public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
      for (Callback c : callbacks) {
        if (c instanceof NameCallback) {
          ((NameCallback) c).setName(PRINCIPAL);
        } else if (c instanceof PasswordCallback) {
          ((PasswordCallback) c).setPassword(password.toCharArray());
        } else if (c instanceof AuthorizeCallback) {
          ((AuthorizeCallback) c).setAuthorized(true);
        } else if (c instanceof RealmCallback) {
          ((RealmCallback) c).setText(REALM);
        } else {
          throw new UnsupportedCallbackException(c);
        }
      }
    }
  }

  private class ServerThread extends Thread {
    final String mechanism;
    final Map<String, String> props;
    volatile Throwable thrown;

    public ServerThread(String mechanism, Map<String, String> props) {
      this.mechanism = mechanism;
      this.props = props;
    }

    public void run() {
      try {
        internalRun();
      } catch (Throwable t) {
        thrown = t;
      }
    }

    private void internalRun() throws Exception {
      TServerSocket serverSocket = new TServerSocket(ServerTestBase.PORT);
      try {
        acceptAndWrite(serverSocket);
      } finally {
        serverSocket.close();
      }
    }

    private void acceptAndWrite(TServerSocket serverSocket)
      throws Exception {
      TTransport serverTransport = serverSocket.accept();
      TTransport saslServerTransport = new TSaslServerTransport(
        mechanism, SERVICE, HOST,
        props, new TestSaslCallbackHandler(PASSWORD), serverTransport);

      saslServerTransport.open();

      byte[] inBuf = new byte[testMessage1.getBytes().length];
      // Deliberately read less than the full buffer to ensure
      // that TSaslTransport is correctly buffering reads. This
      // will fail for the WRAPPED test, if it doesn't work.
      saslServerTransport.readAll(inBuf, 0, 5);
      saslServerTransport.readAll(inBuf, 5, 10);
      saslServerTransport.readAll(inBuf, 15, inBuf.length - 15);
      LOGGER.debug("server got: {}", new String(inBuf));
      assertEquals(new String(inBuf), testMessage1);

      LOGGER.debug("server writing: {}", testMessage2);
      saslServerTransport.write(testMessage2.getBytes());
      saslServerTransport.flush();

      saslServerTransport.close();
    }
  }

  private void testSaslOpen(final String mechanism, final Map<String, String> props)
      throws Exception {
    ServerThread serverThread = new ServerThread(mechanism, props);
    serverThread.start();

    try {
      Thread.sleep(1000);
    } catch (InterruptedException e) {
      // Ah well.
    }

    try {
      TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
      TTransport saslClientTransport = new TSaslClientTransport(mechanism,
                                                                PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(PASSWORD), clientSocket);
      saslClientTransport.open();
      LOGGER.debug("client writing: {}", testMessage1);
      saslClientTransport.write(testMessage1.getBytes());
      saslClientTransport.flush();

      byte[] inBuf = new byte[testMessage2.getBytes().length];
      saslClientTransport.readAll(inBuf, 0, inBuf.length);
      LOGGER.debug("client got: {}", new String(inBuf));
      assertEquals(new String(inBuf), testMessage2);

      TTransportException expectedException = null;
      try {
        saslClientTransport.open();
      } catch (TTransportException e) {
        expectedException = e;
      }
      assertNotNull(expectedException);

      saslClientTransport.close();
    } catch (Exception e) {
      LOGGER.warn("Exception caught", e);
      throw e;
    } finally {
      serverThread.interrupt();
      try {
        serverThread.join();
      } catch (InterruptedException e) {
        // Ah well.
      }
      assertNull(serverThread.thrown);
    }
  }

  public void testUnwrappedOpen() throws Exception {
    testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
  }

  public void testWrappedOpen() throws Exception {
    testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS);
  }

  public void testAnonymousOpen() throws Exception {
    testSaslOpen("ANONYMOUS", null);
  }

  /**
   * Test that we get the proper exceptions thrown back the server when
   * the client provides invalid password.
   */
  public void testBadPassword() throws Exception {
    ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
    serverThread.start();

    try {
      Thread.sleep(1000);
    } catch (InterruptedException e) {
      // Ah well.
    }

    boolean clientSidePassed = true;

    try {
      TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
      TTransport saslClientTransport = new TSaslClientTransport(
        UNWRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, UNWRAPPED_PROPS,
        new TestSaslCallbackHandler("NOT THE PASSWORD"), clientSocket);
      saslClientTransport.open();
      clientSidePassed = false;
      fail("Was able to open transport with bad password");
    } catch (TTransportException tte) {
      LOGGER.error("Exception for bad password", tte);
      assertNotNull(tte.getMessage());
      assertTrue(tte.getMessage().contains("Invalid response"));

    } finally {
      serverThread.interrupt();
      serverThread.join();

      if (clientSidePassed) {
        assertNotNull(serverThread.thrown);
        assertTrue(serverThread.thrown.getMessage().contains("Invalid response"));
      }
    }
  }

  public void testWithServer() throws Exception {
    new TestTSaslTransportsWithServer().testIt();
  }

  private static class TestTSaslTransportsWithServer extends ServerTestBase {

    private Thread serverThread;
    private TServer server;

    @Override
    public TTransport getClientTransport(TTransport underlyingTransport) throws Exception {
      return new TSaslClientTransport(
        WRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS,
        new TestSaslCallbackHandler(PASSWORD), underlyingTransport);
    }

    @Override
    public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory) throws Exception {
      serverThread = new Thread() {
        public void run() {
          try {
            // Transport
            TServerSocket socket = new TServerSocket(PORT);

            TTransportFactory factory = new TSaslServerTransport.Factory(
              WRAPPED_MECHANISM, SERVICE, HOST, WRAPPED_PROPS,
              new TestSaslCallbackHandler(PASSWORD));
            server = new TSimpleServer(new Args(socket).processor(processor).transportFactory(factory).protocolFactory(protoFactory));

            // Run it
            LOGGER.debug("Starting the server on port {}", PORT);
            server.serve();
          } catch (Exception e) {
            e.printStackTrace();
            fail();
          }
        }
      };
      serverThread.start();
      Thread.sleep(1000);
    }

    @Override
    public void stopServer() throws Exception {
      server.stop();
      try {
        serverThread.join();
      } catch (InterruptedException e) {}
    }

  }


  /**
   * Implementation of SASL ANONYMOUS, used for testing client-side
   * initial responses.
   */
  private static class AnonymousClient implements SaslClient {
    private final String username;
    private boolean hasProvidedInitialResponse;

    public AnonymousClient(String username) {
      this.username = username;
    }

    public String getMechanismName() { return "ANONYMOUS"; }
    public boolean hasInitialResponse() { return true; }
    public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
      if (hasProvidedInitialResponse) {
        throw new SaslException("Already complete!");
      }

      try {
        hasProvidedInitialResponse = true;
        return username.getBytes("UTF-8");
      } catch (IOException e) {
        throw new SaslException(e.toString());
      }
    }
    public boolean isComplete() { return hasProvidedInitialResponse; }
    public byte[] unwrap(byte[] incoming, int offset, int len) {
      throw new UnsupportedOperationException();
    }
    public byte[] wrap(byte[] outgoing, int offset, int len) {
      throw new UnsupportedOperationException();
    }
    public Object getNegotiatedProperty(String propName) { return null; }
    public void dispose() {}
  }

  private static class AnonymousServer implements SaslServer {
    private String user;
    public String getMechanismName() { return "ANONYMOUS"; }
    public byte[] evaluateResponse(byte[] response) throws SaslException {
      try {
        this.user = new String(response, "UTF-8");
      } catch (IOException e) {
        throw new SaslException(e.toString());
      }
      return null;
    }
    public boolean isComplete() { return user != null; }
    public String getAuthorizationID() { return user; }
    public byte[] unwrap(byte[] incoming, int offset, int len) {
      throw new UnsupportedOperationException();
    }
    public byte[] wrap(byte[] outgoing, int offset, int len) {
      throw new UnsupportedOperationException();
    }
    public Object getNegotiatedProperty(String propName) { return null; }
    public void dispose() {}

  }

  public static class SaslAnonymousFactory
    implements SaslClientFactory, SaslServerFactory {

    public SaslClient createSaslClient(
      String[] mechanisms, String authorizationId, String protocol,
      String serverName, Map<String,?> props, CallbackHandler cbh)
    {
      for (String mech : mechanisms) {
        if ("ANONYMOUS".equals(mech)) {
          return new AnonymousClient(authorizationId);
        }
      }
      return null;
    }

    public SaslServer createSaslServer(
      String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)
    {
      if ("ANONYMOUS".equals(mechanism)) {
        return new AnonymousServer();
      }
      return null;
    }
    public String[] getMechanismNames(Map<String, ?> props) {
      return new String[] { "ANONYMOUS" };
    }
  }

  static {
    java.security.Security.addProvider(new SaslAnonymousProvider());
  }
  public static class SaslAnonymousProvider extends java.security.Provider {
    public SaslAnonymousProvider() {
      super("ThriftSaslAnonymous", 1.0, "Thrift Anonymous SASL provider");
      put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
      put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
    }
  }
}
TOP

Related Classes of org.apache.thrift.transport.TestTSaslTransports

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.