UserAuthImpl made plenty cleaner...

This commit is contained in:
Shikhar Bhushan
2011-05-11 00:08:28 +01:00
parent 2d49cb4d77
commit 4f7b29da0d

View File

@@ -15,7 +15,7 @@
*/
package net.schmizz.sshj.userauth;
import net.schmizz.concurrent.Promise;
import net.schmizz.concurrent.Event;
import net.schmizz.sshj.AbstractService;
import net.schmizz.sshj.Service;
import net.schmizz.sshj.common.DisconnectReason;
@@ -36,71 +36,75 @@ import java.util.concurrent.TimeUnit;
/** {@link UserAuth} implementation. */
public class UserAuthImpl
extends AbstractService
implements UserAuth, AuthParams {
implements UserAuth {
private final Set<String> allowed = new HashSet<String>();
private final Event<UserAuthException> authenticated
= new Event<UserAuthException>("authenticated", UserAuthException.chainer);
// Externally available
private final Deque<UserAuthException> savedEx = new ArrayDeque<UserAuthException>();
private final Promise<Boolean, UserAuthException> result
= new Promise<Boolean, UserAuthException>("userauth result", UserAuthException.chainer);
private String username;
private AuthMethod currentMethod;
private Service nextService;
private boolean firstAttempt = true;
private volatile String banner;
private volatile String banner = "";
private volatile boolean partialSuccess;
// Internal state
private Set<String> allowedMethods;
private AuthMethod currentMethod;
public UserAuthImpl(Transport trans) {
super("ssh-userauth", trans);
}
// synchronized for mutual exclusion; ensure one authenticate() ever in progress
// synchronized for mutual exclusion; ensure only one authenticate() ever in progress
@Override
public synchronized void authenticate(String username, Service nextService, Iterable<AuthMethod> methods)
public synchronized void authenticate(final String username,
final Service nextService,
final Iterable<AuthMethod> methods)
throws UserAuthException, TransportException {
clearState();
this.username = username;
this.nextService = nextService;
savedEx.clear();
// Request "ssh-userauth" service (if not already active)
request();
super.request();
if (firstAttempt) { // Assume all allowed
if (allowedMethods == null) { // Assume all are allowed
allowedMethods = new HashSet<String>();
for (AuthMethod meth : methods)
allowed.add(meth.getName());
firstAttempt = false;
allowedMethods.add(meth.getName());
}
try {
for (AuthMethod meth : methods)
final AuthParams authParams = makeAuthParams(username, nextService);
if (allowed.contains(meth.getName())) {
for (AuthMethod meth : methods) {
log.info("Trying `{}` auth...", meth.getName());
if (!allowedMethods.contains(meth.getName())) {
saveException(new UserAuthException(meth.getName() + " auth not allowed by server"));
continue;
}
boolean success = false;
try {
success = tryWith(meth);
} catch (UserAuthException e) {
// Give other method a shot
saveException(e);
}
log.info("Trying `{}` auth...", meth.getName());
authenticated.clear();
currentMethod = meth;
if (success) {
log.info("`{}` auth successful", meth.getName());
return;
} else
log.info("`{}` auth failed", meth.getName());
try {
} else
saveException(meth.getName() + " auth not allowed by server");
currentMethod.init(authParams);
currentMethod.request();
authenticated.await(timeout, TimeUnit.SECONDS);
} catch (UserAuthException e) {
log.info("`{}` auth failed", meth.getName());
// Give other methods a shot
saveException(e);
continue;
}
log.info("`{}` auth successful", meth.getName());
trans.setAuthenticated(); // So it can put delayed compression into force if applicable
trans.setService(nextService); // We aren't in charge anymore, next service is
return;
}
} finally {
currentMethod = null;
@@ -111,34 +115,13 @@ public class UserAuthImpl
}
@Override
public String getBanner() {
return banner;
}
@Override
public String getNextServiceName() {
return nextService.getName();
}
@Override
public Transport getTransport() {
return trans;
}
/**
* Returns the exceptions that occured during authentication process but were ignored because more method were
* available for trying.
*
* @return deque of saved exceptions
*/
@Override
public Deque<UserAuthException> getSavedExceptions() {
public synchronized Deque<UserAuthException> getSavedExceptions() {
return savedEx;
}
@Override
public String getUsername() {
return username;
public String getBanner() {
return banner;
}
@Override
@@ -153,75 +136,63 @@ public class UserAuthImpl
throw new TransportException(DisconnectReason.PROTOCOL_ERROR);
switch (msg) {
case USERAUTH_BANNER:
gotBanner(buf);
break;
case USERAUTH_SUCCESS:
gotSuccess();
break;
case USERAUTH_BANNER: {
banner = buf.readString();
} break;
case USERAUTH_FAILURE:
gotFailure(buf);
break;
case USERAUTH_SUCCESS: {
authenticated.set();
} break;
case USERAUTH_FAILURE: {
allowedMethods.clear();
allowedMethods.addAll(Arrays.<String>asList(buf.readString().split(",")));
partialSuccess |= buf.readBoolean();
if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) {
currentMethod.request();
} else {
authenticated.deliverError(new UserAuthException(currentMethod.getName() + " auth failed"));
}
} break;
default: {
log.debug("Asking `{}` method to handle {} packet", currentMethod.getName(), msg);
try {
currentMethod.handle(msg, buf);
} catch (UserAuthException e) {
authenticated.deliverError(e);
}
}
default:
gotUnknown(msg, buf);
}
}
@Override
public void notifyError(SSHException error) {
super.notifyError(error);
result.deliverError(error);
authenticated.deliverError(error);
}
private void clearState() {
allowed.clear();
savedEx.clear();
banner = null;
}
private AuthParams makeAuthParams(final String username, final Service nextService) {
return new AuthParams() {
private void gotBanner(SSHPacket buf) {
banner = buf.readString();
}
@Override
public String getNextServiceName() {
return nextService.getName();
}
private void gotFailure(SSHPacket buf)
throws UserAuthException, TransportException {
allowed.clear();
allowed.addAll(Arrays.<String>asList(buf.readString().split(",")));
partialSuccess |= buf.readBoolean();
if (allowed.contains(currentMethod.getName()) && currentMethod.shouldRetry())
currentMethod.request();
else {
saveException(currentMethod.getName() + " auth failed");
result.deliver(false);
}
}
@Override
public Transport getTransport() {
return trans;
}
private void gotSuccess() {
trans.setAuthenticated(); // So it can put delayed compression into force if applicable
trans.setService(nextService); // We aren't in charge anymore, next service is
result.deliver(true);
}
@Override
public String getUsername() {
return username;
}
private void gotUnknown(Message msg, SSHPacket buf)
throws SSHException {
if (currentMethod == null || result == null) {
trans.sendUnimplemented();
return;
}
log.debug("Asking {} method to handle {} packet", currentMethod.getName(), msg);
try {
currentMethod.handle(msg, buf);
} catch (UserAuthException e) {
result.deliverError(e);
}
}
private void saveException(String msg) {
saveException(new UserAuthException(msg));
};
}
private void saveException(UserAuthException e) {
@@ -229,13 +200,4 @@ public class UserAuthImpl
savedEx.push(e);
}
private boolean tryWith(AuthMethod meth)
throws UserAuthException, TransportException {
currentMethod = meth;
result.clear();
meth.init(this);
meth.request();
return result.retrieve(timeout, TimeUnit.SECONDS);
}
}