Add Transport.isKeyExchangeRequired() to avoid unnecessary KEXINIT (#811)

* Added Transport.isKeyExchangeRequired() to avoid unnecessary KEXINIT

- Updated SSHClient.onConnect() to check isKeyExchangeRequired() before calling doKex()
- Added started timestamp in ThreadNameProvider for improved tracking

* Moved KeepAliveThread State check after authentication to avoid test timing issues
This commit is contained in:
exceptionfactory
2022-09-16 08:04:26 -05:00
committed by GitHub
parent 430cbfcf13
commit 2551f8e559
5 changed files with 27 additions and 3 deletions

View File

@@ -29,7 +29,8 @@ public class ThreadNameProvider {
public static void setThreadName(final Thread thread, final RemoteAddressProvider remoteAddressProvider) { public static void setThreadName(final Thread thread, final RemoteAddressProvider remoteAddressProvider) {
final InetSocketAddress remoteSocketAddress = remoteAddressProvider.getRemoteSocketAddress(); final InetSocketAddress remoteSocketAddress = remoteAddressProvider.getRemoteSocketAddress();
final String address = remoteSocketAddress == null ? DISCONNECTED : remoteSocketAddress.toString(); final String address = remoteSocketAddress == null ? DISCONNECTED : remoteSocketAddress.toString();
final String threadName = String.format("sshj-%s-%s", thread.getClass().getSimpleName(), address); final long started = System.currentTimeMillis();
final String threadName = String.format("sshj-%s-%s-%d", thread.getClass().getSimpleName(), address, started);
thread.setName(threadName); thread.setName(threadName);
} }
} }

View File

@@ -810,7 +810,12 @@ public class SSHClient
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans); ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
keepAliveThread.start(); keepAliveThread.start();
} }
doKex(); if (trans.isKeyExchangeRequired()) {
log.debug("Initiating Key Exchange for new connection");
doKex();
} else {
log.debug("Key Exchange already completed for new connection");
}
} }
/** /**

View File

@@ -71,6 +71,13 @@ public interface Transport
void doKex() void doKex()
throws TransportException; throws TransportException;
/**
* Is Key Exchange required based on current transport status
*
* @return Key Exchange required status
*/
boolean isKeyExchangeRequired();
/** @return the version string used by this client to identify itself to an SSH server, e.g. "SSHJ_3_0" */ /** @return the version string used by this client to identify itself to an SSH server, e.g. "SSHJ_3_0" */
String getClientVersion(); String getClientVersion();

View File

@@ -254,6 +254,16 @@ public final class TransportImpl
kexer.startKex(true); kexer.startKex(true);
} }
/**
* Is Key Exchange required returns true when Key Exchange is not done and when Key Exchange is not ongoing
*
* @return Key Exchange required status
*/
@Override
public boolean isKeyExchangeRequired() {
return !kexer.isKexDone() && !kexer.isKexOngoing();
}
public boolean isKexDone() { public boolean isKexDone() {
return kexer.isKexDone(); return kexer.isKexDone();
} }

View File

@@ -59,10 +59,11 @@ public class KeepAliveThreadTerminationTest {
assertEquals(Thread.State.NEW, keepAlive.getState()); assertEquals(Thread.State.NEW, keepAlive.getState());
fixture.connectClient(sshClient); fixture.connectClient(sshClient);
assertEquals(Thread.State.TIMED_WAITING, keepAlive.getState());
assertThrows(UserAuthException.class, () -> sshClient.authPassword("bad", "credentials")); assertThrows(UserAuthException.class, () -> sshClient.authPassword("bad", "credentials"));
assertEquals(Thread.State.TIMED_WAITING, keepAlive.getState());
fixture.stopClient(); fixture.stopClient();
Thread.sleep(STOP_SLEEP); Thread.sleep(STOP_SLEEP);