diff --git a/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java b/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java index 577802db..2dd21bb3 100644 --- a/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java +++ b/src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java @@ -15,7 +15,7 @@ */ package net.schmizz.sshj.userauth; -import net.schmizz.concurrent.Promise; +import net.schmizz.concurrent.Event; import net.schmizz.sshj.AbstractService; import net.schmizz.sshj.Service; import net.schmizz.sshj.common.DisconnectReason; @@ -36,71 +36,75 @@ import java.util.concurrent.TimeUnit; /** {@link UserAuth} implementation. */ public class UserAuthImpl extends AbstractService - implements UserAuth, AuthParams { + implements UserAuth { - private final Set allowed = new HashSet(); + private final Event authenticated + = new Event("authenticated", UserAuthException.chainer); + // Externally available private final Deque savedEx = new ArrayDeque(); - - private final Promise result - = new Promise("userauth result", UserAuthException.chainer); - - private String username; - private AuthMethod currentMethod; - private Service nextService; - - private boolean firstAttempt = true; - - private volatile String banner; + private volatile String banner = ""; private volatile boolean partialSuccess; + // Internal state + private Set allowedMethods; + private AuthMethod currentMethod; + public UserAuthImpl(Transport trans) { super("ssh-userauth", trans); } - // synchronized for mutual exclusion; ensure one authenticate() ever in progress - + // synchronized for mutual exclusion; ensure only one authenticate() ever in progress @Override - public synchronized void authenticate(String username, Service nextService, Iterable methods) + public synchronized void authenticate(final String username, + final Service nextService, + final Iterable methods) throws UserAuthException, TransportException { - clearState(); - - this.username = username; - this.nextService = nextService; + savedEx.clear(); // Request "ssh-userauth" service (if not already active) - request(); + super.request(); - if (firstAttempt) { // Assume all allowed + if (allowedMethods == null) { // Assume all are allowed + allowedMethods = new HashSet(); for (AuthMethod meth : methods) - allowed.add(meth.getName()); - firstAttempt = false; + allowedMethods.add(meth.getName()); } try { - for (AuthMethod meth : methods) + final AuthParams authParams = makeAuthParams(username, nextService); - if (allowed.contains(meth.getName())) { + for (AuthMethod meth : methods) { - log.info("Trying `{}` auth...", meth.getName()); + if (!allowedMethods.contains(meth.getName())) { + saveException(new UserAuthException(meth.getName() + " auth not allowed by server")); + continue; + } - boolean success = false; - try { - success = tryWith(meth); - } catch (UserAuthException e) { - // Give other method a shot - saveException(e); - } + log.info("Trying `{}` auth...", meth.getName()); + authenticated.clear(); + currentMethod = meth; - if (success) { - log.info("`{}` auth successful", meth.getName()); - return; - } else - log.info("`{}` auth failed", meth.getName()); + try { - } else - saveException(meth.getName() + " auth not allowed by server"); + currentMethod.init(authParams); + currentMethod.request(); + authenticated.await(timeout, TimeUnit.SECONDS); + + } catch (UserAuthException e) { + log.info("`{}` auth failed", meth.getName()); + // Give other methods a shot + saveException(e); + continue; + } + + log.info("`{}` auth successful", meth.getName()); + trans.setAuthenticated(); // So it can put delayed compression into force if applicable + trans.setService(nextService); // We aren't in charge anymore, next service is + return; + + } } finally { currentMethod = null; @@ -111,34 +115,13 @@ public class UserAuthImpl } @Override - public String getBanner() { - return banner; - } - - @Override - public String getNextServiceName() { - return nextService.getName(); - } - - @Override - public Transport getTransport() { - return trans; - } - - /** - * Returns the exceptions that occured during authentication process but were ignored because more method were - * available for trying. - * - * @return deque of saved exceptions - */ - @Override - public Deque getSavedExceptions() { + public synchronized Deque getSavedExceptions() { return savedEx; } @Override - public String getUsername() { - return username; + public String getBanner() { + return banner; } @Override @@ -153,75 +136,63 @@ public class UserAuthImpl throw new TransportException(DisconnectReason.PROTOCOL_ERROR); switch (msg) { - case USERAUTH_BANNER: - gotBanner(buf); - break; - case USERAUTH_SUCCESS: - gotSuccess(); - break; + case USERAUTH_BANNER: { + banner = buf.readString(); + } break; - case USERAUTH_FAILURE: - gotFailure(buf); - break; + case USERAUTH_SUCCESS: { + authenticated.set(); + } break; + + case USERAUTH_FAILURE: { + allowedMethods.clear(); + allowedMethods.addAll(Arrays.asList(buf.readString().split(","))); + partialSuccess |= buf.readBoolean(); + if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) { + currentMethod.request(); + } else { + authenticated.deliverError(new UserAuthException(currentMethod.getName() + " auth failed")); + } + } break; + + default: { + log.debug("Asking `{}` method to handle {} packet", currentMethod.getName(), msg); + try { + currentMethod.handle(msg, buf); + } catch (UserAuthException e) { + authenticated.deliverError(e); + } + } - default: - gotUnknown(msg, buf); } } @Override public void notifyError(SSHException error) { super.notifyError(error); - result.deliverError(error); + authenticated.deliverError(error); } - private void clearState() { - allowed.clear(); - savedEx.clear(); - banner = null; - } + private AuthParams makeAuthParams(final String username, final Service nextService) { + return new AuthParams() { - private void gotBanner(SSHPacket buf) { - banner = buf.readString(); - } + @Override + public String getNextServiceName() { + return nextService.getName(); + } - private void gotFailure(SSHPacket buf) - throws UserAuthException, TransportException { - allowed.clear(); - allowed.addAll(Arrays.asList(buf.readString().split(","))); - partialSuccess |= buf.readBoolean(); - if (allowed.contains(currentMethod.getName()) && currentMethod.shouldRetry()) - currentMethod.request(); - else { - saveException(currentMethod.getName() + " auth failed"); - result.deliver(false); - } - } + @Override + public Transport getTransport() { + return trans; + } - private void gotSuccess() { - trans.setAuthenticated(); // So it can put delayed compression into force if applicable - trans.setService(nextService); // We aren't in charge anymore, next service is - result.deliver(true); - } + @Override + public String getUsername() { + return username; + } - private void gotUnknown(Message msg, SSHPacket buf) - throws SSHException { - if (currentMethod == null || result == null) { - trans.sendUnimplemented(); - return; - } - - log.debug("Asking {} method to handle {} packet", currentMethod.getName(), msg); - try { - currentMethod.handle(msg, buf); - } catch (UserAuthException e) { - result.deliverError(e); - } - } - - private void saveException(String msg) { - saveException(new UserAuthException(msg)); + }; } private void saveException(UserAuthException e) { @@ -229,13 +200,4 @@ public class UserAuthImpl savedEx.push(e); } - private boolean tryWith(AuthMethod meth) - throws UserAuthException, TransportException { - currentMethod = meth; - result.clear(); - meth.init(this); - meth.request(); - return result.retrieve(timeout, TimeUnit.SECONDS); - } - }