diff --git a/src/main/java/net/schmizz/sshj/SSHClient.java b/src/main/java/net/schmizz/sshj/SSHClient.java index 546a92a7..980dd57e 100644 --- a/src/main/java/net/schmizz/sshj/SSHClient.java +++ b/src/main/java/net/schmizz/sshj/SSHClient.java @@ -68,6 +68,7 @@ import java.net.ServerSocket; import java.security.KeyPair; import java.security.PublicKey; import java.util.Arrays; +import java.util.Deque; import java.util.LinkedList; import java.util.List; @@ -175,6 +176,8 @@ public class SSHClient }); } + // FIXME: there are way too many auth... overrides. Better API needed. + /** * Authenticate {@code username} using the supplied {@code methods}. * @@ -202,7 +205,16 @@ public class SSHClient public void auth(String username, Iterable methods) throws UserAuthException, TransportException { checkConnected(); - auth.authenticate(username, (Service) conn, methods); + final Deque savedEx = new LinkedList(); + for (AuthMethod method: methods) { + try { + if (auth.authenticate(username, (Service) conn, method, trans.getTimeoutMs())) + return; + } catch (UserAuthException e) { + savedEx.push(e); + } + } + throw new UserAuthException("Exhausted available authentication methods", savedEx.peek()); } /** @@ -390,8 +402,7 @@ public class SSHClient /** * @return the associated {@link UserAuth} instance. This allows access to information like the {@link * UserAuth#getBanner() authentication banner}, whether authentication was at least {@link - * UserAuth#hadPartialSuccess() partially successful}, and any {@link UserAuth#getSavedExceptions() saved - * exceptions} that were ignored because there were more authentication method that could be tried. + * UserAuth#hadPartialSuccess() partially successful}. */ public UserAuth getUserAuth() { return auth; diff --git a/src/main/java/net/schmizz/sshj/userauth/UserAuth.java b/src/main/java/net/schmizz/sshj/userauth/UserAuth.java index 952cd1b3..6d5d4b3a 100644 --- a/src/main/java/net/schmizz/sshj/userauth/UserAuth.java +++ b/src/main/java/net/schmizz/sshj/userauth/UserAuth.java @@ -19,8 +19,6 @@ import net.schmizz.sshj.Service; import net.schmizz.sshj.transport.TransportException; import net.schmizz.sshj.userauth.method.AuthMethod; -import java.util.Deque; - /** User authentication API. See RFC 4252. */ public interface UserAuth { @@ -29,9 +27,7 @@ public interface UserAuth { * {@link Service} that will be enabled on successful authentication. *

* Authentication fails if there are no method available, i.e. if all the method failed or there were method - * available but could not be attempted because the server did not allow them. In this case, a {@code - * UserAuthException} is thrown with its cause as the last authentication failure. Other {@code UserAuthException}'s - * which may have been ignored may be accessed via {@link #getSavedExceptions()}. + * available but could not be attempted because the server did not allow them. *

* Further attempts may also be made by catching {@code UserAuthException} and retrying with this method. * @@ -39,10 +35,12 @@ public interface UserAuth { * @param nextService the service to set on successful authentication * @param methods the {@link AuthMethod}'s to try * + * @return whether authentication was successful + * * @throws UserAuthException in case of authentication failure * @throws TransportException if there was a transport-layer error */ - void authenticate(String username, Service nextService, Iterable methods) + boolean authenticate(String username, Service nextService, AuthMethod methods, int timeoutMs) throws UserAuthException, TransportException; /** @@ -53,23 +51,13 @@ public interface UserAuth { */ String getBanner(); - /** @return saved exceptions that might have been ignored because there were more authentication method available. */ - Deque getSavedExceptions(); - - /** @return the {@code timeout} for a method to successfully authenticate before it is abandoned. */ - int getTimeout(); - /** * @return whether authentication was partially successful. Some server's may be configured to require multiple * authentications; and this value will be {@code true} if at least one of the method supplied succeeded. */ boolean hadPartialSuccess(); - /** - * Set the {@code timeout} for any method to successfully authenticate before it is abandoned. - * - * @param timeout the timeout in seconds - */ - void setTimeout(int timeout); + /** The available authentication methods. This is only defined once an unsuccessful authentication has taken place. */ + Iterable getAllowedMethods(); } diff --git a/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java b/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java index 944ca998..3247e263 100644 --- a/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java +++ b/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java @@ -15,7 +15,7 @@ */ package net.schmizz.sshj.userauth; -import net.schmizz.concurrent.Event; +import net.schmizz.concurrent.Promise; import net.schmizz.sshj.AbstractService; import net.schmizz.sshj.Service; import net.schmizz.sshj.common.DisconnectReason; @@ -26,11 +26,10 @@ import net.schmizz.sshj.transport.Transport; import net.schmizz.sshj.transport.TransportException; import net.schmizz.sshj.userauth.method.AuthMethod; -import java.util.ArrayDeque; import java.util.Arrays; -import java.util.Deque; -import java.util.HashSet; -import java.util.Set; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; import java.util.concurrent.TimeUnit; /** {@link UserAuth} implementation. */ @@ -38,85 +37,51 @@ public class UserAuthImpl extends AbstractService implements UserAuth { - private final Event authenticated - = new Event("authenticated", UserAuthException.chainer); + private final Promise authenticated + = new Promise("authenticated", UserAuthException.chainer); // Externally available - private final Deque savedEx = new ArrayDeque(); private volatile String banner = ""; - private volatile boolean partialSuccess; + private volatile boolean partialSuccess = false; + private volatile List allowedMethods = new LinkedList(); // Internal state - private Set allowedMethods; private AuthMethod currentMethod; public UserAuthImpl(Transport trans) { super("ssh-userauth", trans); } - // synchronized for mutual exclusion; ensure only one authenticate() ever in progress @Override - public synchronized void authenticate(final String username, - final Service nextService, - final Iterable methods) + public boolean authenticate(String username, Service nextService, AuthMethod method, int timeoutMs) throws UserAuthException, TransportException { - savedEx.clear(); - - // Request "ssh-userauth" service (if not already active) - super.request(); - - if (allowedMethods == null) { // Assume all are allowed - allowedMethods = new HashSet(); - for (AuthMethod meth : methods) - allowedMethods.add(meth.getName()); - } + final boolean outcome; + authenticated.lock(); try { + super.request(); // Request "ssh-userauth" service (if not already active) - final AuthParams authParams = makeAuthParams(username, nextService); + currentMethod = method; + currentMethod.init(makeAuthParams(username, nextService)); + authenticated.clear(); + log.debug("Trying `{}` auth...", method.getName()); + currentMethod.request(); + outcome = authenticated.retrieve(timeoutMs, TimeUnit.MILLISECONDS); - for (AuthMethod meth : methods) { - - if (!allowedMethods.contains(meth.getName())) { - saveException(new UserAuthException(meth.getName() + " auth not allowed by server")); - continue; - } - - log.debug("Trying `{}` auth...", meth.getName()); - authenticated.clear(); - currentMethod = meth; - - try { - - currentMethod.init(authParams); - currentMethod.request(); - authenticated.await(timeout, TimeUnit.SECONDS); - - } catch (UserAuthException e) { - log.debug("`{}` auth failed", meth.getName()); - // Give other methods a shot - saveException(e); - continue; - } - - log.debug("`{}` auth successful", meth.getName()); + if (outcome) { + log.debug("`{}` auth successful", method.getName()); trans.setAuthenticated(); // So it can put delayed compression into force if applicable trans.setService(nextService); // We aren't in charge anymore, next service is - return; - + } else { + log.debug("`{}` auth failed", method.getName()); } } finally { currentMethod = null; + authenticated.unlock(); } - log.debug("Had {} saved exception(s)", savedEx.size()); - throw new UserAuthException("Exhausted available authentication methods", savedEx.peek()); - } - - @Override - public synchronized Deque getSavedExceptions() { - return savedEx; + return outcome; } @Override @@ -129,45 +94,54 @@ public class UserAuthImpl return partialSuccess; } + @Override + public Iterable getAllowedMethods() { + return Collections.unmodifiableList(allowedMethods); + } + @Override public void handle(Message msg, SSHPacket buf) throws SSHException { if (!msg.in(50, 80)) // ssh-userauth packets have message numbers between 50-80 throw new TransportException(DisconnectReason.PROTOCOL_ERROR); - switch (msg) { + authenticated.lock(); + try { + switch (msg) { - case USERAUTH_BANNER: { - banner = buf.readString(); - } - break; - - case USERAUTH_SUCCESS: { - authenticated.set(); - } - break; - - case USERAUTH_FAILURE: { - allowedMethods.clear(); - allowedMethods.addAll(Arrays.asList(buf.readString().split(","))); - partialSuccess |= buf.readBoolean(); - if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) { - currentMethod.request(); - } else { - authenticated.deliverError(new UserAuthException(currentMethod.getName() + " auth failed")); + case USERAUTH_BANNER: { + banner = buf.readString(); } - } - break; + break; - default: { - log.debug("Asking `{}` method to handle {} packet", currentMethod.getName(), msg); - try { - currentMethod.handle(msg, buf); - } catch (UserAuthException e) { - authenticated.deliverError(e); + case USERAUTH_SUCCESS: { + authenticated.deliver(true); } - } + break; + case USERAUTH_FAILURE: { + allowedMethods = Arrays.asList(buf.readString().split(",")); + partialSuccess |= buf.readBoolean(); + if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) { + currentMethod.request(); + } else { + authenticated.deliver(false); + } + } + break; + + default: { + log.debug("Asking `{}` method to handle {} packet", currentMethod.getName(), msg); + try { + currentMethod.handle(msg, buf); + } catch (UserAuthException e) { + authenticated.deliverError(e); + } + } + + } + } finally { + authenticated.unlock(); } } @@ -198,9 +172,4 @@ public class UserAuthImpl }; } - private void saveException(UserAuthException e) { - log.debug("Saving for later - {}", e.toString()); - savedEx.push(e); - } - }