Don't send keep alive signals before kex is done (#934)

Otherwise, they could interfere with strict key exchange.

Co-authored-by: Jeroen van Erp <jeroen@hierynomus.com>
This commit is contained in:
Henning Pöttker
2024-04-15 09:29:06 +02:00
committed by GitHub
parent 70af58d199
commit 81d77d277c
2 changed files with 50 additions and 8 deletions

View File

@@ -18,15 +18,26 @@ package com.hierynomus.sshj.transport.kex;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream;
import ch.qos.logback.classic.Logger; import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.spi.ILoggingEvent; import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.read.ListAppender; import ch.qos.logback.core.read.ListAppender;
import com.hierynomus.sshj.SshdContainer; import com.hierynomus.sshj.SshdContainer;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.SSHClient; import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.TransportException;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.junit.jupiter.Testcontainers;
@@ -62,14 +73,27 @@ class StrictKeyExchangeTest {
watchedLoggers.add(logger); watchedLoggers.add(logger);
} }
@Test private static Stream<Arguments> strictKeyExchange() {
void strictKeyExchange() throws Throwable { Config defaultConfig = new DefaultConfig();
try (SSHClient client = sshd.getConnectedClient()) { Config heartbeaterConfig = new DefaultConfig();
heartbeaterConfig.setKeepAliveProvider(new KeepAliveProvider() {
@Override
public KeepAlive provide(ConnectionImpl connection) {
return new HotLoopHeartbeater(connection);
}
});
return Stream.of(defaultConfig, heartbeaterConfig).map(Arguments::of);
}
@MethodSource
@ParameterizedTest
void strictKeyExchange(Config config) throws Throwable {
try (SSHClient client = sshd.getConnectedClient(config)) {
client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1"); client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1");
assertTrue(client.isAuthenticated()); assertTrue(client.isAuthenticated());
} }
List<String> keyExchangerLogs = getLogs("KeyExchanger"); List<String> keyExchangerLogs = getLogs("KeyExchanger");
assertThat(keyExchangerLogs).containsSequence( assertThat(keyExchangerLogs).contains(
"Initiating key exchange", "Initiating key exchange",
"Sending SSH_MSG_KEXINIT", "Sending SSH_MSG_KEXINIT",
"Received SSH_MSG_KEXINIT", "Received SSH_MSG_KEXINIT",
@@ -78,7 +102,7 @@ class StrictKeyExchangeTest {
List<String> decoderLogs = getLogs("Decoder").stream() List<String> decoderLogs = getLogs("Decoder").stream()
.map(log -> log.split(":")[0]) .map(log -> log.split(":")[0])
.collect(Collectors.toList()); .collect(Collectors.toList());
assertThat(decoderLogs).containsExactly( assertThat(decoderLogs).startsWith(
"Received packet #0", "Received packet #0",
"Received packet #1", "Received packet #1",
"Received packet #2", "Received packet #2",
@@ -90,7 +114,7 @@ class StrictKeyExchangeTest {
List<String> encoderLogs = getLogs("Encoder").stream() List<String> encoderLogs = getLogs("Encoder").stream()
.map(log -> log.split(":")[0]) .map(log -> log.split(":")[0])
.collect(Collectors.toList()); .collect(Collectors.toList());
assertThat(encoderLogs).containsExactly( assertThat(encoderLogs).startsWith(
"Encoding packet #0", "Encoding packet #0",
"Encoding packet #1", "Encoding packet #1",
"Encoding packet #2", "Encoding packet #2",
@@ -108,4 +132,22 @@ class StrictKeyExchangeTest {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
private static class HotLoopHeartbeater extends KeepAlive {
HotLoopHeartbeater(ConnectionImpl conn) {
super(conn, "sshj-Heartbeater");
}
@Override
public boolean isEnabled() {
return true;
}
@Override
protected void doKeepAlive() throws TransportException {
conn.getTransport().write(new SSHPacket(Message.IGNORE));
}
}
} }

View File

@@ -804,12 +804,12 @@ public class SSHClient
throws IOException { throws IOException {
super.onConnect(); super.onConnect();
trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream()); trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream());
doKex();
final KeepAlive keepAliveThread = conn.getKeepAlive(); final KeepAlive keepAliveThread = conn.getKeepAlive();
if (keepAliveThread.isEnabled()) { if (keepAliveThread.isEnabled()) {
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans); ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
keepAliveThread.start(); keepAliveThread.start();
} }
doKex();
} }
/** /**