Simplify the UserAuth.authenticate(..) interface, move the multi-auth-method trial-and-error into SSHClient API

This commit is contained in:
shikhar
2013-04-15 22:56:24 -04:00
parent 0ec6918d7a
commit 0ddd1f38c5
3 changed files with 81 additions and 113 deletions

View File

@@ -68,6 +68,7 @@ import java.net.ServerSocket;
import java.security.KeyPair; import java.security.KeyPair;
import java.security.PublicKey; import java.security.PublicKey;
import java.util.Arrays; import java.util.Arrays;
import java.util.Deque;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
@@ -175,6 +176,8 @@ public class SSHClient
}); });
} }
// FIXME: there are way too many auth... overrides. Better API needed.
/** /**
* Authenticate {@code username} using the supplied {@code methods}. * Authenticate {@code username} using the supplied {@code methods}.
* *
@@ -202,7 +205,16 @@ public class SSHClient
public void auth(String username, Iterable<AuthMethod> methods) public void auth(String username, Iterable<AuthMethod> methods)
throws UserAuthException, TransportException { throws UserAuthException, TransportException {
checkConnected(); checkConnected();
auth.authenticate(username, (Service) conn, methods); final Deque<UserAuthException> savedEx = new LinkedList<UserAuthException>();
for (AuthMethod method: methods) {
try {
if (auth.authenticate(username, (Service) conn, method, trans.getTimeoutMs()))
return;
} catch (UserAuthException e) {
savedEx.push(e);
}
}
throw new UserAuthException("Exhausted available authentication methods", savedEx.peek());
} }
/** /**
@@ -390,8 +402,7 @@ public class SSHClient
/** /**
* @return the associated {@link UserAuth} instance. This allows access to information like the {@link * @return the associated {@link UserAuth} instance. This allows access to information like the {@link
* UserAuth#getBanner() authentication banner}, whether authentication was at least {@link * UserAuth#getBanner() authentication banner}, whether authentication was at least {@link
* UserAuth#hadPartialSuccess() partially successful}, and any {@link UserAuth#getSavedExceptions() saved * UserAuth#hadPartialSuccess() partially successful}.
* exceptions} that were ignored because there were more authentication method that could be tried.
*/ */
public UserAuth getUserAuth() { public UserAuth getUserAuth() {
return auth; return auth;

View File

@@ -19,8 +19,6 @@ import net.schmizz.sshj.Service;
import net.schmizz.sshj.transport.TransportException; import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.userauth.method.AuthMethod; import net.schmizz.sshj.userauth.method.AuthMethod;
import java.util.Deque;
/** User authentication API. See RFC 4252. */ /** User authentication API. See RFC 4252. */
public interface UserAuth { public interface UserAuth {
@@ -29,9 +27,7 @@ public interface UserAuth {
* {@link Service} that will be enabled on successful authentication. * {@link Service} that will be enabled on successful authentication.
* <p/> * <p/>
* Authentication fails if there are no method available, i.e. if all the method failed or there were method * Authentication fails if there are no method available, i.e. if all the method failed or there were method
* available but could not be attempted because the server did not allow them. In this case, a {@code * available but could not be attempted because the server did not allow them.
* UserAuthException} is thrown with its cause as the last authentication failure. Other {@code UserAuthException}'s
* which may have been ignored may be accessed via {@link #getSavedExceptions()}.
* <p/> * <p/>
* Further attempts may also be made by catching {@code UserAuthException} and retrying with this method. * Further attempts may also be made by catching {@code UserAuthException} and retrying with this method.
* *
@@ -39,10 +35,12 @@ public interface UserAuth {
* @param nextService the service to set on successful authentication * @param nextService the service to set on successful authentication
* @param methods the {@link AuthMethod}'s to try * @param methods the {@link AuthMethod}'s to try
* *
* @return whether authentication was successful
*
* @throws UserAuthException in case of authentication failure * @throws UserAuthException in case of authentication failure
* @throws TransportException if there was a transport-layer error * @throws TransportException if there was a transport-layer error
*/ */
void authenticate(String username, Service nextService, Iterable<AuthMethod> methods) boolean authenticate(String username, Service nextService, AuthMethod methods, int timeoutMs)
throws UserAuthException, TransportException; throws UserAuthException, TransportException;
/** /**
@@ -53,23 +51,13 @@ public interface UserAuth {
*/ */
String getBanner(); String getBanner();
/** @return saved exceptions that might have been ignored because there were more authentication method available. */
Deque<UserAuthException> getSavedExceptions();
/** @return the {@code timeout} for a method to successfully authenticate before it is abandoned. */
int getTimeout();
/** /**
* @return whether authentication was partially successful. Some server's may be configured to require multiple * @return whether authentication was partially successful. Some server's may be configured to require multiple
* authentications; and this value will be {@code true} if at least one of the method supplied succeeded. * authentications; and this value will be {@code true} if at least one of the method supplied succeeded.
*/ */
boolean hadPartialSuccess(); boolean hadPartialSuccess();
/** /** The available authentication methods. This is only defined once an unsuccessful authentication has taken place. */
* Set the {@code timeout} for any method to successfully authenticate before it is abandoned. Iterable<String> getAllowedMethods();
*
* @param timeout the timeout in seconds
*/
void setTimeout(int timeout);
} }

View File

@@ -15,7 +15,7 @@
*/ */
package net.schmizz.sshj.userauth; package net.schmizz.sshj.userauth;
import net.schmizz.concurrent.Event; import net.schmizz.concurrent.Promise;
import net.schmizz.sshj.AbstractService; import net.schmizz.sshj.AbstractService;
import net.schmizz.sshj.Service; import net.schmizz.sshj.Service;
import net.schmizz.sshj.common.DisconnectReason; import net.schmizz.sshj.common.DisconnectReason;
@@ -26,11 +26,10 @@ import net.schmizz.sshj.transport.Transport;
import net.schmizz.sshj.transport.TransportException; import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.userauth.method.AuthMethod; import net.schmizz.sshj.userauth.method.AuthMethod;
import java.util.ArrayDeque;
import java.util.Arrays; import java.util.Arrays;
import java.util.Deque; import java.util.Collections;
import java.util.HashSet; import java.util.LinkedList;
import java.util.Set; import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
/** {@link UserAuth} implementation. */ /** {@link UserAuth} implementation. */
@@ -38,85 +37,51 @@ public class UserAuthImpl
extends AbstractService extends AbstractService
implements UserAuth { implements UserAuth {
private final Event<UserAuthException> authenticated private final Promise<Boolean, UserAuthException> authenticated
= new Event<UserAuthException>("authenticated", UserAuthException.chainer); = new Promise<Boolean, UserAuthException>("authenticated", UserAuthException.chainer);
// Externally available // Externally available
private final Deque<UserAuthException> savedEx = new ArrayDeque<UserAuthException>();
private volatile String banner = ""; private volatile String banner = "";
private volatile boolean partialSuccess; private volatile boolean partialSuccess = false;
private volatile List<String> allowedMethods = new LinkedList<String>();
// Internal state // Internal state
private Set<String> allowedMethods;
private AuthMethod currentMethod; private AuthMethod currentMethod;
public UserAuthImpl(Transport trans) { public UserAuthImpl(Transport trans) {
super("ssh-userauth", trans); super("ssh-userauth", trans);
} }
// synchronized for mutual exclusion; ensure only one authenticate() ever in progress
@Override @Override
public synchronized void authenticate(final String username, public boolean authenticate(String username, Service nextService, AuthMethod method, int timeoutMs)
final Service nextService,
final Iterable<AuthMethod> methods)
throws UserAuthException, TransportException { throws UserAuthException, TransportException {
savedEx.clear(); final boolean outcome;
// Request "ssh-userauth" service (if not already active)
super.request();
if (allowedMethods == null) { // Assume all are allowed
allowedMethods = new HashSet<String>();
for (AuthMethod meth : methods)
allowedMethods.add(meth.getName());
}
authenticated.lock();
try { try {
super.request(); // Request "ssh-userauth" service (if not already active)
final AuthParams authParams = makeAuthParams(username, nextService); currentMethod = method;
currentMethod.init(makeAuthParams(username, nextService));
authenticated.clear();
log.debug("Trying `{}` auth...", method.getName());
currentMethod.request();
outcome = authenticated.retrieve(timeoutMs, TimeUnit.MILLISECONDS);
for (AuthMethod meth : methods) { if (outcome) {
log.debug("`{}` auth successful", method.getName());
if (!allowedMethods.contains(meth.getName())) {
saveException(new UserAuthException(meth.getName() + " auth not allowed by server"));
continue;
}
log.debug("Trying `{}` auth...", meth.getName());
authenticated.clear();
currentMethod = meth;
try {
currentMethod.init(authParams);
currentMethod.request();
authenticated.await(timeout, TimeUnit.SECONDS);
} catch (UserAuthException e) {
log.debug("`{}` auth failed", meth.getName());
// Give other methods a shot
saveException(e);
continue;
}
log.debug("`{}` auth successful", meth.getName());
trans.setAuthenticated(); // So it can put delayed compression into force if applicable trans.setAuthenticated(); // So it can put delayed compression into force if applicable
trans.setService(nextService); // We aren't in charge anymore, next service is trans.setService(nextService); // We aren't in charge anymore, next service is
return; } else {
log.debug("`{}` auth failed", method.getName());
} }
} finally { } finally {
currentMethod = null; currentMethod = null;
authenticated.unlock();
} }
log.debug("Had {} saved exception(s)", savedEx.size()); return outcome;
throw new UserAuthException("Exhausted available authentication methods", savedEx.peek());
}
@Override
public synchronized Deque<UserAuthException> getSavedExceptions() {
return savedEx;
} }
@Override @Override
@@ -129,45 +94,54 @@ public class UserAuthImpl
return partialSuccess; return partialSuccess;
} }
@Override
public Iterable<String> getAllowedMethods() {
return Collections.unmodifiableList(allowedMethods);
}
@Override @Override
public void handle(Message msg, SSHPacket buf) public void handle(Message msg, SSHPacket buf)
throws SSHException { throws SSHException {
if (!msg.in(50, 80)) // ssh-userauth packets have message numbers between 50-80 if (!msg.in(50, 80)) // ssh-userauth packets have message numbers between 50-80
throw new TransportException(DisconnectReason.PROTOCOL_ERROR); throw new TransportException(DisconnectReason.PROTOCOL_ERROR);
switch (msg) { authenticated.lock();
try {
switch (msg) {
case USERAUTH_BANNER: { case USERAUTH_BANNER: {
banner = buf.readString(); banner = buf.readString();
}
break;
case USERAUTH_SUCCESS: {
authenticated.set();
}
break;
case USERAUTH_FAILURE: {
allowedMethods.clear();
allowedMethods.addAll(Arrays.<String>asList(buf.readString().split(",")));
partialSuccess |= buf.readBoolean();
if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) {
currentMethod.request();
} else {
authenticated.deliverError(new UserAuthException(currentMethod.getName() + " auth failed"));
} }
} break;
break;
default: { case USERAUTH_SUCCESS: {
log.debug("Asking `{}` method to handle {} packet", currentMethod.getName(), msg); authenticated.deliver(true);
try {
currentMethod.handle(msg, buf);
} catch (UserAuthException e) {
authenticated.deliverError(e);
} }
} break;
case USERAUTH_FAILURE: {
allowedMethods = Arrays.asList(buf.readString().split(","));
partialSuccess |= buf.readBoolean();
if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) {
currentMethod.request();
} else {
authenticated.deliver(false);
}
}
break;
default: {
log.debug("Asking `{}` method to handle {} packet", currentMethod.getName(), msg);
try {
currentMethod.handle(msg, buf);
} catch (UserAuthException e) {
authenticated.deliverError(e);
}
}
}
} finally {
authenticated.unlock();
} }
} }
@@ -198,9 +172,4 @@ public class UserAuthImpl
}; };
} }
private void saveException(UserAuthException e) {
log.debug("Saving for later - {}", e.toString());
savedEx.push(e);
}
} }