Retry authentication with all remaining auth methods after partial success

Signed-off-by: Jeroen van Erp <jeroen@hierynomus.com>
This commit is contained in:
Jeroen van Erp
2022-09-23 22:42:57 +02:00
parent d628c47bae
commit d7dd73b9c8
4 changed files with 53 additions and 11 deletions

View File

@@ -40,6 +40,7 @@ import net.schmizz.sshj.transport.verification.AlgorithmsVerifier;
import net.schmizz.sshj.transport.verification.FingerprintVerifier; import net.schmizz.sshj.transport.verification.FingerprintVerifier;
import net.schmizz.sshj.transport.verification.HostKeyVerifier; import net.schmizz.sshj.transport.verification.HostKeyVerifier;
import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts; import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts;
import net.schmizz.sshj.userauth.AuthResult;
import net.schmizz.sshj.userauth.UserAuth; import net.schmizz.sshj.userauth.UserAuth;
import net.schmizz.sshj.userauth.UserAuthException; import net.schmizz.sshj.userauth.UserAuthException;
import net.schmizz.sshj.userauth.UserAuthImpl; import net.schmizz.sshj.userauth.UserAuthImpl;
@@ -218,13 +219,30 @@ public class SSHClient
throws UserAuthException, TransportException { throws UserAuthException, TransportException {
checkConnected(); checkConnected();
final Deque<UserAuthException> savedEx = new LinkedList<UserAuthException>(); final Deque<UserAuthException> savedEx = new LinkedList<UserAuthException>();
for (AuthMethod method: methods) { final List<AuthMethod> tried = new LinkedList<AuthMethod>();
for (Iterator<AuthMethod> it = methods.iterator(); it.hasNext();) {
AuthMethod method = it.next();
method.setLoggerFactory(loggerFactory); method.setLoggerFactory(loggerFactory);
try { try {
if (auth.authenticate(username, (Service) conn, method, trans.getTimeoutMs())) AuthResult result = auth.authenticate(username, (Service) conn, method, trans.getTimeoutMs());
if (result == AuthResult.SUCCESS) {
return; return;
} else if (result == AuthResult.PARTIAL) {
// Put all remaining methods in the tried list, so that we can try them for the second round of authentication
while (it.hasNext()) {
tried.add(it.next());
}
auth(username, tried);
return;
}
tried.add(method);
} catch (UserAuthException e) { } catch (UserAuthException e) {
savedEx.push(e); savedEx.push(e);
tried.add(method);
} }
} }
throw new UserAuthException("Exhausted available authentication methods", savedEx.peek()); throw new UserAuthException("Exhausted available authentication methods", savedEx.peek());

View File

@@ -0,0 +1,22 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.userauth;
public enum AuthResult {
SUCCESS,
FAILURE,
PARTIAL
}

View File

@@ -37,12 +37,12 @@ public interface UserAuth {
* @param nextService the service to set on successful authentication * @param nextService the service to set on successful authentication
* @param methods the {@link AuthMethod}'s to try * @param methods the {@link AuthMethod}'s to try
* *
* @return whether authentication was successful * @return whether authentication was successful, failed, or partially successful
* *
* @throws UserAuthException in case of authentication failure * @throws UserAuthException in case of authentication failure
* @throws TransportException if there was a transport-layer error * @throws TransportException if there was a transport-layer error
*/ */
boolean authenticate(String username, Service nextService, AuthMethod methods, int timeoutMs) AuthResult authenticate(String username, Service nextService, AuthMethod methods, int timeoutMs)
throws UserAuthException, TransportException; throws UserAuthException, TransportException;
/** /**

View File

@@ -40,7 +40,7 @@ public class UserAuthImpl
extends AbstractService extends AbstractService
implements UserAuth { implements UserAuth {
private final Promise<Boolean, UserAuthException> authenticated; private final Promise<AuthResult, UserAuthException> authenticated;
// Externally available // Externally available
private volatile String banner = ""; private volatile String banner = "";
@@ -53,13 +53,13 @@ public class UserAuthImpl
public UserAuthImpl(Transport trans) { public UserAuthImpl(Transport trans) {
super("ssh-userauth", trans); super("ssh-userauth", trans);
authenticated = new Promise<Boolean, UserAuthException>("authenticated", UserAuthException.chainer, trans.getConfig().getLoggerFactory()); authenticated = new Promise<AuthResult, UserAuthException>("authenticated", UserAuthException.chainer, trans.getConfig().getLoggerFactory());
} }
@Override @Override
public boolean authenticate(String username, Service nextService, AuthMethod method, int timeoutMs) public AuthResult authenticate(String username, Service nextService, AuthMethod method, int timeoutMs)
throws UserAuthException, TransportException { throws UserAuthException, TransportException {
final boolean outcome; final AuthResult outcome;
authenticated.lock(); authenticated.lock();
try { try {
@@ -73,8 +73,10 @@ public class UserAuthImpl
currentMethod.request(); currentMethod.request();
outcome = authenticated.retrieve(timeoutMs, TimeUnit.MILLISECONDS); outcome = authenticated.retrieve(timeoutMs, TimeUnit.MILLISECONDS);
if (outcome) { if (outcome == AuthResult.SUCCESS) {
log.debug("`{}` auth successful", method.getName()); log.debug("`{}` auth successful", method.getName());
} else if (outcome == AuthResult.PARTIAL) {
log.debug("`{}` auth partially successful", method.getName());
} else { } else {
log.debug("`{}` auth failed", method.getName()); log.debug("`{}` auth failed", method.getName());
} }
@@ -124,7 +126,7 @@ public class UserAuthImpl
// Should fix https://github.com/hierynomus/sshj/issues/237 // Should fix https://github.com/hierynomus/sshj/issues/237
trans.setAuthenticated(); // So it can put delayed compression into force if applicable trans.setAuthenticated(); // So it can put delayed compression into force if applicable
trans.setService(nextService); // We aren't in charge anymore, next service is trans.setService(nextService); // We aren't in charge anymore, next service is
authenticated.deliver(true); authenticated.deliver(AuthResult.SUCCESS);
break; break;
case USERAUTH_FAILURE: case USERAUTH_FAILURE:
@@ -133,7 +135,7 @@ public class UserAuthImpl
if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) { if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) {
currentMethod.request(); currentMethod.request();
} else { } else {
authenticated.deliver(false); authenticated.deliver(partialSuccess ? AuthResult.PARTIAL : AuthResult.FAILURE);
} }
break; break;