mirror of
https://github.com/hierynomus/sshj.git
synced 2025-12-06 07:10:53 +03:00
Don't send keep alive signals before kex is done (#934)
Otherwise, they could interfere with strict key exchange. Co-authored-by: Jeroen van Erp <jeroen@hierynomus.com>
This commit is contained in:
@@ -18,15 +18,26 @@ package com.hierynomus.sshj.transport.kex;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import ch.qos.logback.classic.Logger;
|
||||
import ch.qos.logback.classic.spi.ILoggingEvent;
|
||||
import ch.qos.logback.core.read.ListAppender;
|
||||
import com.hierynomus.sshj.SshdContainer;
|
||||
import net.schmizz.keepalive.KeepAlive;
|
||||
import net.schmizz.keepalive.KeepAliveProvider;
|
||||
import net.schmizz.sshj.Config;
|
||||
import net.schmizz.sshj.DefaultConfig;
|
||||
import net.schmizz.sshj.SSHClient;
|
||||
import net.schmizz.sshj.common.Message;
|
||||
import net.schmizz.sshj.common.SSHPacket;
|
||||
import net.schmizz.sshj.connection.ConnectionImpl;
|
||||
import net.schmizz.sshj.transport.TransportException;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.testcontainers.junit.jupiter.Container;
|
||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||
@@ -62,14 +73,27 @@ class StrictKeyExchangeTest {
|
||||
watchedLoggers.add(logger);
|
||||
}
|
||||
|
||||
@Test
|
||||
void strictKeyExchange() throws Throwable {
|
||||
try (SSHClient client = sshd.getConnectedClient()) {
|
||||
private static Stream<Arguments> strictKeyExchange() {
|
||||
Config defaultConfig = new DefaultConfig();
|
||||
Config heartbeaterConfig = new DefaultConfig();
|
||||
heartbeaterConfig.setKeepAliveProvider(new KeepAliveProvider() {
|
||||
@Override
|
||||
public KeepAlive provide(ConnectionImpl connection) {
|
||||
return new HotLoopHeartbeater(connection);
|
||||
}
|
||||
});
|
||||
return Stream.of(defaultConfig, heartbeaterConfig).map(Arguments::of);
|
||||
}
|
||||
|
||||
@MethodSource
|
||||
@ParameterizedTest
|
||||
void strictKeyExchange(Config config) throws Throwable {
|
||||
try (SSHClient client = sshd.getConnectedClient(config)) {
|
||||
client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1");
|
||||
assertTrue(client.isAuthenticated());
|
||||
}
|
||||
List<String> keyExchangerLogs = getLogs("KeyExchanger");
|
||||
assertThat(keyExchangerLogs).containsSequence(
|
||||
assertThat(keyExchangerLogs).contains(
|
||||
"Initiating key exchange",
|
||||
"Sending SSH_MSG_KEXINIT",
|
||||
"Received SSH_MSG_KEXINIT",
|
||||
@@ -78,7 +102,7 @@ class StrictKeyExchangeTest {
|
||||
List<String> decoderLogs = getLogs("Decoder").stream()
|
||||
.map(log -> log.split(":")[0])
|
||||
.collect(Collectors.toList());
|
||||
assertThat(decoderLogs).containsExactly(
|
||||
assertThat(decoderLogs).startsWith(
|
||||
"Received packet #0",
|
||||
"Received packet #1",
|
||||
"Received packet #2",
|
||||
@@ -90,7 +114,7 @@ class StrictKeyExchangeTest {
|
||||
List<String> encoderLogs = getLogs("Encoder").stream()
|
||||
.map(log -> log.split(":")[0])
|
||||
.collect(Collectors.toList());
|
||||
assertThat(encoderLogs).containsExactly(
|
||||
assertThat(encoderLogs).startsWith(
|
||||
"Encoding packet #0",
|
||||
"Encoding packet #1",
|
||||
"Encoding packet #2",
|
||||
@@ -108,4 +132,22 @@ class StrictKeyExchangeTest {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static class HotLoopHeartbeater extends KeepAlive {
|
||||
|
||||
HotLoopHeartbeater(ConnectionImpl conn) {
|
||||
super(conn, "sshj-Heartbeater");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isEnabled() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doKeepAlive() throws TransportException {
|
||||
conn.getTransport().write(new SSHPacket(Message.IGNORE));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -804,12 +804,12 @@ public class SSHClient
|
||||
throws IOException {
|
||||
super.onConnect();
|
||||
trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream());
|
||||
doKex();
|
||||
final KeepAlive keepAliveThread = conn.getKeepAlive();
|
||||
if (keepAliveThread.isEnabled()) {
|
||||
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
|
||||
keepAliveThread.start();
|
||||
}
|
||||
doKex();
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user