package org.jgroups.auth.sasl;
import java.security.Provider;
import java.security.Security;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslServerFactory;
/**
* Utility methods for handling SASL authentication
*
* @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a>
* @author Tristan Tarrant
*/
public final class SaslUtils {
private SaslUtils() {
}
/**
* Returns an iterator of all of the registered {@code SaslServerFactory}s where the order is
* based on the order of the Provider registration and/or class path order. Class path providers
* are listed before global providers; in the event of a name conflict, the class path provider
* is preferred.
*
* @param classLoader
* the class loader to use
* @param includeGlobal
* {@code true} to include globally registered providers, {@code false} to exclude
* them
* @return the {@code Iterator} of {@code SaslServerFactory}s
*/
public static Iterator<SaslServerFactory> getSaslServerFactories(ClassLoader classLoader, boolean includeGlobal) {
return getFactories(SaslServerFactory.class, classLoader, includeGlobal);
}
/**
* Returns an iterator of all of the registered {@code SaslServerFactory}s where the order is
* based on the order of the Provider registration and/or class path order.
*
* @return the {@code Iterator} of {@code SaslServerFactory}s
*/
public static Iterator<SaslServerFactory> getSaslServerFactories() {
return getFactories(SaslServerFactory.class, null, true);
}
/**
* Returns an iterator of all of the registered {@code SaslClientFactory}s where the order is
* based on the order of the Provider registration and/or class path order. Class path providers
* are listed before global providers; in the event of a name conflict, the class path provider
* is preferred.
*
* @param classLoader
* the class loader to use
* @param includeGlobal
* {@code true} to include globally registered providers, {@code false} to exclude
* them
* @return the {@code Iterator} of {@code SaslClientFactory}s
*/
public static Iterator<SaslClientFactory> getSaslClientFactories(ClassLoader classLoader, boolean includeGlobal) {
return getFactories(SaslClientFactory.class, classLoader, includeGlobal);
}
/**
* Returns an iterator of all of the registered {@code SaslClientFactory}s where the order is
* based on the order of the Provider registration and/or class path order.
*
* @return the {@code Iterator} of {@code SaslClientFactory}s
*/
public static Iterator<SaslClientFactory> getSaslClientFactories() {
return getFactories(SaslClientFactory.class, null, true);
}
private static <T> Iterator<T> getFactories(Class<T> type, ClassLoader classLoader, boolean includeGlobal) {
Set<T> factories = new LinkedHashSet<T>();
final ServiceLoader<T> loader = ServiceLoader.load(type, classLoader);
for (T factory : loader) {
factories.add(factory);
}
if (includeGlobal) {
Set<String> loadedClasses = new HashSet<String>();
final String filter = type.getSimpleName() + ".";
Provider[] providers = Security.getProviders();
for (Provider currentProvider : providers) {
final ClassLoader cl = currentProvider.getClass().getClassLoader();
for (Object currentKey : currentProvider.keySet()) {
if (currentKey instanceof String && ((String) currentKey).startsWith(filter)
&& ((String) currentKey).indexOf(' ') < 0) {
String className = currentProvider.getProperty((String) currentKey);
if (className != null && loadedClasses.add(className)) {
try {
factories.add(Class.forName(className, true, cl).asSubclass(type).newInstance());
} catch (ClassNotFoundException e) {
} catch (ClassCastException e) {
} catch (InstantiationException e) {
} catch (IllegalAccessException e) {
}
}
}
}
}
}
return factories.iterator();
}
public static SaslServerFactory getSaslServerFactory(String mech, Map<String, ?> props) {
Iterator<SaslServerFactory> saslFactories = SaslUtils.getSaslServerFactories(SaslUtils.class.getClassLoader(), true);
while (saslFactories.hasNext()) {
SaslServerFactory saslFactory = saslFactories.next();
for (String supportedMech : saslFactory.getMechanismNames(props)) {
if (supportedMech.equals(mech)) {
return saslFactory;
}
}
}
throw new IllegalArgumentException("No SASL server factory for mech " + mech);
}
public static SaslClientFactory getSaslClientFactory(String mech, Map<String, ?> props) {
Iterator<SaslClientFactory> saslFactories = SaslUtils.getSaslClientFactories(SaslUtils.class.getClassLoader(), true);
while (saslFactories.hasNext()) {
SaslClientFactory saslFactory = saslFactories.next();
for (String supportedMech : saslFactory.getMechanismNames(props)) {
if (mech.equals(supportedMech)) {
return saslFactory;
}
}
}
throw new IllegalArgumentException("No SASL client factory for mech " + mech);
}
}