/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import com.sun.security.auth.callback.DialogCallbackHandler;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.security.PrivilegedAction;
import javax.security.auth.Subject;
import javax.security.auth.login.LoginContext;
import org.ietf.jgss.*;
/**
 * Measures the average round-trip time of GSS-API (Kerberos v5) wrap/unwrap
 * exchanges between an in-process client and an in-process echo server.
 *
 * <p>Both sides log in through JAAS using the login configuration entries
 * {@code testServer} and {@code testClient}; those entries (and the Kerberos
 * setup behind them) must be provided by the environment. Once a security
 * context is established over a loopback TCP connection, the client performs
 * {@code <initCalls>} untimed warm-up round trips followed by
 * {@code <repeatCalls>} timed ones.
 *
 * <p>NOTE(review): {@code com.sun.security.auth.callback.DialogCallbackHandler}
 * was removed from the JDK after Java 8, so this program only builds on
 * JDK 8 and earlier — confirm the target JDK.
 */
public class JgssPerfTest {

    /** Mechanism OID of the Kerberos v5 GSS-API mechanism (RFC 1964). */
    private static final String KRB5_MECH_OID = "1.2.840.113554.1.2.2";

    /**
     * Entry point.
     *
     * @param args {@code <serverPrincipal> <initCalls> <repeatCalls>}; the two
     *     counts are parsed with {@link Integer#decode}, so decimal, hex
     *     ({@code 0x...}) and octal (leading {@code 0}) forms are accepted
     * @throws Exception if login, context establishment or messaging fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length != 3) {
            System.out.println("Usage: java JgssPerfTest <serverPrincipal> "
                    + "<initCalls> <repeatCalls>");
            System.exit(1);
        }
        String serverPrincipal = args[0];
        int initCalls = Integer.decode(args[1]).intValue();
        int repeatCalls = Integer.decode(args[2]).intValue();
        if (initCalls < 0 || repeatCalls <= 0) {
            // repeatCalls is the divisor of the average reported below.
            System.out.println("<initCalls> must be >= 0 and "
                    + "<repeatCalls> must be > 0");
            System.exit(1);
        }

        Server server = new Server(serverPrincipal);
        Client client = new Client(serverPrincipal,
                server.serverSocket.getInetAddress(),
                server.serverSocket.getLocalPort());

        byte[] data = new byte[10];
        for (int i = 0; i < data.length; i++) {
            data[i] = (byte) i;
        }

        // Warm-up round trips; not included in the measurement.
        for (int i = 0; i < initCalls; i++) {
            client.callServer(data);
        }

        // Timed run. nanoTime() is monotonic (unaffected by wall-clock
        // adjustments) and fine-grained enough for sub-millisecond trips.
        long start = System.nanoTime();
        for (int i = 0; i < repeatCalls; i++) {
            client.callServer(data);
        }
        long elapsedNanos = System.nanoTime() - start;

        float avgMillis = (float) (elapsedNanos / 1000000.0 / repeatCalls);
        System.out.println("average round trip time: "
                + avgMillis + "msecs (averaged "
                + "over " + repeatCalls + " invocations)\n");
        // The server thread never returns from its echo loop, so the JVM
        // has to be terminated explicitly.
        System.exit(0);
    }

    /**
     * Logs in with the named JAAS login configuration entry, collecting any
     * required input through {@link DialogCallbackHandler} (a GUI dialog).
     *
     * @param loginName name of the JAAS login configuration entry
     * @return the authenticated subject
     * @throws Exception if the login fails
     */
    private static Subject getLoginSubject(String loginName) throws Exception {
        LoginContext ctx =
                new LoginContext(loginName, new DialogCallbackHandler());
        ctx.login();
        return ctx.getSubject();
    }

    /** Writes one length-prefixed token and flushes the stream. */
    private static void sendToken(DataOutputStream out, byte[] token)
            throws IOException {
        out.writeInt(token.length);
        out.write(token);
        out.flush();
    }

    /** Reads one length-prefixed token as written by {@link #sendToken}. */
    private static byte[] receiveToken(DataInputStream in) throws IOException {
        byte[] token = new byte[in.readInt()];
        in.readFully(token);
        return token;
    }

    /**
     * Single-connection echo server. It accepts one client, establishes the
     * acceptor side of a GSS context, then unwraps every incoming message and
     * sends the payload back wrapped, until an I/O or GSS error occurs.
     */
    private static class Server {
        final String serverPrincipal;
        final ServerSocket serverSocket;

        Server(String serverPrincipal) throws Exception {
            this.serverPrincipal = serverPrincipal;
            Subject serverSubject = getLoginSubject("testServer");
            // Bind to loopback only: the client lives in this JVM, and the
            // wildcard address that ServerSocket(0) reports through
            // getInetAddress() is not a valid connect target everywhere.
            // 50 is the default backlog used by ServerSocket(int).
            serverSocket = new ServerSocket(0, 50, InetAddress.getByName(null));
            Subject.doAs(serverSubject, new PrivilegedAction<Void>() {
                @Override
                public Void run() {
                    // Start the thread inside doAs so it inherits the access
                    // control context carrying the server Subject, where the
                    // Kerberos provider finds the acceptor credentials.
                    new Thread(new Runnable() {
                        @Override
                        public void run() {
                            acceptRequests();
                        }
                    }).start();
                    return null;
                }
            });
        }

        /**
         * Accepts the single client connection and serves it. Both sockets
         * are closed if serving fails, so the client sees the connection
         * close instead of blocking forever on a read.
         */
        void acceptRequests() {
            Socket sock = null;
            try {
                sock = serverSocket.accept();
                sock.setTcpNoDelay(true);
                sock.setKeepAlive(true);
                handleConnection(sock);
            } catch (Exception e) {
                throw new RuntimeException(
                        "Exception thrown while server is "
                        + "handling an inbound request", e);
            } finally {
                if (sock != null) {
                    try {
                        sock.close();
                    } catch (IOException ignored) {
                        // best-effort cleanup; the original failure matters
                    }
                }
                try {
                    serverSocket.close();
                } catch (IOException ignored) {
                    // best-effort cleanup; the original failure matters
                }
            }
        }

        /**
         * Establishes the GSS acceptor context on {@code sock}, then runs the
         * echo loop. Only returns by throwing.
         */
        void handleConnection(Socket sock) throws Exception {
            DataInputStream dis = new DataInputStream(sock.getInputStream());
            DataOutputStream dos =
                    new DataOutputStream(sock.getOutputStream());

            GSSManager gssManager = GSSManager.getInstance();
            GSSName serverName =
                    gssManager.createName(serverPrincipal, null);
            GSSCredential serverCred = gssManager.createCredential(
                    serverName, GSSCredential.INDEFINITE_LIFETIME,
                    new Oid(KRB5_MECH_OID), GSSCredential.ACCEPT_ONLY);
            GSSContext gssContext = gssManager.createContext(serverCred);

            // Establish the context: consume client tokens until done,
            // replying whenever acceptSecContext produces an output token.
            while (!gssContext.isEstablished()) {
                byte[] token = receiveToken(dis);
                byte[] reply =
                        gssContext.acceptSecContext(token, 0, token.length);
                if (reply != null) {
                    sendToken(dos, reply);
                }
            }

            // Echo loop: unwrap each request, re-wrap its payload with
            // QOP 0 and privacy requested, and send it back.
            while (true) {
                byte[] request = receiveToken(dis);
                MessageProp prop = new MessageProp(0, true);
                byte[] payload =
                        gssContext.unwrap(request, 0, request.length, prop);
                prop = new MessageProp(0, true);
                byte[] response =
                        gssContext.wrap(payload, 0, payload.length, prop);
                sendToken(dos, response);
            }
        }
    }

    /**
     * Initiator side: connects to the server, establishes a mutually
     * authenticated GSS context, and performs wrap/unwrap round trips.
     */
    private static class Client {
        final String serverPrincipal;
        final InetAddress serverHost;
        final int serverPort;
        GSSContext gssContext;
        Socket sock;
        DataInputStream dis;
        DataOutputStream dos;

        Client(String serverPrincipal, InetAddress serverHost, int serverPort)
                throws Exception {
            this.serverPrincipal = serverPrincipal;
            this.serverHost = serverHost;
            this.serverPort = serverPort;
            Subject clientSubject = getLoginSubject("testClient");
            // Context establishment runs as the client Subject so the
            // Kerberos provider can use its credentials.
            Subject.doAs(clientSubject, new PrivilegedAction<Void>() {
                @Override
                public Void run() {
                    try {
                        connectToServer();
                    } catch (Exception e) {
                        throw new RuntimeException(
                                "Exception thrown while client connecting "
                                + "to server", e);
                    }
                    return null;
                }
            });
        }

        /** Opens the TCP connection and runs the initiator handshake. */
        void connectToServer() throws Exception {
            GSSManager gssManager = GSSManager.getInstance();
            GSSName serverName =
                    gssManager.createName(serverPrincipal, null);
            gssContext = gssManager.createContext(
                    serverName, new Oid(KRB5_MECH_OID), null,
                    GSSContext.DEFAULT_LIFETIME);
            gssContext.requestMutualAuth(true);  // Mutual authentication
            gssContext.requestInteg(true);       // Will use integrity later
            gssContext.requestConf(true);        // Always enable encryption on ctx
            gssContext.requestCredDeleg(false);
            gssContext.requestReplayDet(true);
            gssContext.requestSequenceDet(true);

            sock = new Socket(serverHost, serverPort);
            sock.setTcpNoDelay(true);
            sock.setKeepAlive(true);
            dis = new DataInputStream(sock.getInputStream());
            dos = new DataOutputStream(sock.getOutputStream());

            // Context establishment loop; the first call gets an empty token.
            byte[] token = new byte[0];
            while (true) {
                token = gssContext.initSecContext(token, 0, token.length);
                if (token != null) {
                    sendToken(dos, token);
                }
                if (gssContext.isEstablished()) {
                    break;
                }
                token = receiveToken(dis);
            }
        }

        /**
         * Sends {@code data} wrapped (QOP 0, privacy requested) and returns
         * the unwrapped payload of the server's reply.
         */
        byte[] callServer(byte[] data) throws Exception {
            MessageProp prop = new MessageProp(0, true);
            byte[] token = gssContext.wrap(data, 0, data.length, prop);
            sendToken(dos, token);
            token = receiveToken(dis);
            // unwrap() overwrites prop with the reply's actual properties.
            return gssContext.unwrap(token, 0, token.length, prop);
        }
    }
}