mirror of
https://github.com/hierynomus/sshj.git
synced 2025-12-07 15:50:57 +03:00
Simplify the UserAuth.authenticate(..) interface, move the multi-auth-method trial-and-error into SSHClient API
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
*/
|
||||
package net.schmizz.sshj.userauth;
|
||||
|
||||
import net.schmizz.concurrent.Event;
|
||||
import net.schmizz.concurrent.Promise;
|
||||
import net.schmizz.sshj.AbstractService;
|
||||
import net.schmizz.sshj.Service;
|
||||
import net.schmizz.sshj.common.DisconnectReason;
|
||||
@@ -26,11 +26,10 @@ import net.schmizz.sshj.transport.Transport;
|
||||
import net.schmizz.sshj.transport.TransportException;
|
||||
import net.schmizz.sshj.userauth.method.AuthMethod;
|
||||
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.Arrays;
|
||||
import java.util.Deque;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import java.util.Collections;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/** {@link UserAuth} implementation. */
|
||||
@@ -38,85 +37,51 @@ public class UserAuthImpl
|
||||
extends AbstractService
|
||||
implements UserAuth {
|
||||
|
||||
private final Event<UserAuthException> authenticated
|
||||
= new Event<UserAuthException>("authenticated", UserAuthException.chainer);
|
||||
private final Promise<Boolean, UserAuthException> authenticated
|
||||
= new Promise<Boolean, UserAuthException>("authenticated", UserAuthException.chainer);
|
||||
|
||||
// Externally available
|
||||
private final Deque<UserAuthException> savedEx = new ArrayDeque<UserAuthException>();
|
||||
private volatile String banner = "";
|
||||
private volatile boolean partialSuccess;
|
||||
private volatile boolean partialSuccess = false;
|
||||
private volatile List<String> allowedMethods = new LinkedList<String>();
|
||||
|
||||
// Internal state
|
||||
private Set<String> allowedMethods;
|
||||
private AuthMethod currentMethod;
|
||||
|
||||
public UserAuthImpl(Transport trans) {
|
||||
super("ssh-userauth", trans);
|
||||
}
|
||||
|
||||
// synchronized for mutual exclusion; ensure only one authenticate() ever in progress
|
||||
@Override
|
||||
public synchronized void authenticate(final String username,
|
||||
final Service nextService,
|
||||
final Iterable<AuthMethod> methods)
|
||||
public boolean authenticate(String username, Service nextService, AuthMethod method, int timeoutMs)
|
||||
throws UserAuthException, TransportException {
|
||||
savedEx.clear();
|
||||
|
||||
// Request "ssh-userauth" service (if not already active)
|
||||
super.request();
|
||||
|
||||
if (allowedMethods == null) { // Assume all are allowed
|
||||
allowedMethods = new HashSet<String>();
|
||||
for (AuthMethod meth : methods)
|
||||
allowedMethods.add(meth.getName());
|
||||
}
|
||||
final boolean outcome;
|
||||
|
||||
authenticated.lock();
|
||||
try {
|
||||
super.request(); // Request "ssh-userauth" service (if not already active)
|
||||
|
||||
final AuthParams authParams = makeAuthParams(username, nextService);
|
||||
currentMethod = method;
|
||||
currentMethod.init(makeAuthParams(username, nextService));
|
||||
authenticated.clear();
|
||||
log.debug("Trying `{}` auth...", method.getName());
|
||||
currentMethod.request();
|
||||
outcome = authenticated.retrieve(timeoutMs, TimeUnit.MILLISECONDS);
|
||||
|
||||
for (AuthMethod meth : methods) {
|
||||
|
||||
if (!allowedMethods.contains(meth.getName())) {
|
||||
saveException(new UserAuthException(meth.getName() + " auth not allowed by server"));
|
||||
continue;
|
||||
}
|
||||
|
||||
log.debug("Trying `{}` auth...", meth.getName());
|
||||
authenticated.clear();
|
||||
currentMethod = meth;
|
||||
|
||||
try {
|
||||
|
||||
currentMethod.init(authParams);
|
||||
currentMethod.request();
|
||||
authenticated.await(timeout, TimeUnit.SECONDS);
|
||||
|
||||
} catch (UserAuthException e) {
|
||||
log.debug("`{}` auth failed", meth.getName());
|
||||
// Give other methods a shot
|
||||
saveException(e);
|
||||
continue;
|
||||
}
|
||||
|
||||
log.debug("`{}` auth successful", meth.getName());
|
||||
if (outcome) {
|
||||
log.debug("`{}` auth successful", method.getName());
|
||||
trans.setAuthenticated(); // So it can put delayed compression into force if applicable
|
||||
trans.setService(nextService); // We aren't in charge anymore, next service is
|
||||
return;
|
||||
|
||||
} else {
|
||||
log.debug("`{}` auth failed", method.getName());
|
||||
}
|
||||
|
||||
} finally {
|
||||
currentMethod = null;
|
||||
authenticated.unlock();
|
||||
}
|
||||
|
||||
log.debug("Had {} saved exception(s)", savedEx.size());
|
||||
throw new UserAuthException("Exhausted available authentication methods", savedEx.peek());
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized Deque<UserAuthException> getSavedExceptions() {
|
||||
return savedEx;
|
||||
return outcome;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -129,45 +94,54 @@ public class UserAuthImpl
|
||||
return partialSuccess;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterable<String> getAllowedMethods() {
|
||||
return Collections.unmodifiableList(allowedMethods);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handle(Message msg, SSHPacket buf)
|
||||
throws SSHException {
|
||||
if (!msg.in(50, 80)) // ssh-userauth packets have message numbers between 50-80
|
||||
throw new TransportException(DisconnectReason.PROTOCOL_ERROR);
|
||||
|
||||
switch (msg) {
|
||||
authenticated.lock();
|
||||
try {
|
||||
switch (msg) {
|
||||
|
||||
case USERAUTH_BANNER: {
|
||||
banner = buf.readString();
|
||||
}
|
||||
break;
|
||||
|
||||
case USERAUTH_SUCCESS: {
|
||||
authenticated.set();
|
||||
}
|
||||
break;
|
||||
|
||||
case USERAUTH_FAILURE: {
|
||||
allowedMethods.clear();
|
||||
allowedMethods.addAll(Arrays.<String>asList(buf.readString().split(",")));
|
||||
partialSuccess |= buf.readBoolean();
|
||||
if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) {
|
||||
currentMethod.request();
|
||||
} else {
|
||||
authenticated.deliverError(new UserAuthException(currentMethod.getName() + " auth failed"));
|
||||
case USERAUTH_BANNER: {
|
||||
banner = buf.readString();
|
||||
}
|
||||
}
|
||||
break;
|
||||
break;
|
||||
|
||||
default: {
|
||||
log.debug("Asking `{}` method to handle {} packet", currentMethod.getName(), msg);
|
||||
try {
|
||||
currentMethod.handle(msg, buf);
|
||||
} catch (UserAuthException e) {
|
||||
authenticated.deliverError(e);
|
||||
case USERAUTH_SUCCESS: {
|
||||
authenticated.deliver(true);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case USERAUTH_FAILURE: {
|
||||
allowedMethods = Arrays.asList(buf.readString().split(","));
|
||||
partialSuccess |= buf.readBoolean();
|
||||
if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) {
|
||||
currentMethod.request();
|
||||
} else {
|
||||
authenticated.deliver(false);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
default: {
|
||||
log.debug("Asking `{}` method to handle {} packet", currentMethod.getName(), msg);
|
||||
try {
|
||||
currentMethod.handle(msg, buf);
|
||||
} catch (UserAuthException e) {
|
||||
authenticated.deliverError(e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
} finally {
|
||||
authenticated.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,9 +172,4 @@ public class UserAuthImpl
|
||||
};
|
||||
}
|
||||
|
||||
private void saveException(UserAuthException e) {
|
||||
log.debug("Saving for later - {}", e.toString());
|
||||
savedEx.push(e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user