mirror of
https://github.com/hierynomus/sshj.git
synced 2025-12-06 15:20:54 +03:00
Don't send keep alive signals before kex is done (#934)
Otherwise, they could interfere with strict key exchange. Co-authored-by: Jeroen van Erp <jeroen@hierynomus.com>
This commit is contained in:
@@ -18,15 +18,26 @@ package com.hierynomus.sshj.transport.kex;
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
import ch.qos.logback.classic.Logger;
|
import ch.qos.logback.classic.Logger;
|
||||||
import ch.qos.logback.classic.spi.ILoggingEvent;
|
import ch.qos.logback.classic.spi.ILoggingEvent;
|
||||||
import ch.qos.logback.core.read.ListAppender;
|
import ch.qos.logback.core.read.ListAppender;
|
||||||
import com.hierynomus.sshj.SshdContainer;
|
import com.hierynomus.sshj.SshdContainer;
|
||||||
|
import net.schmizz.keepalive.KeepAlive;
|
||||||
|
import net.schmizz.keepalive.KeepAliveProvider;
|
||||||
|
import net.schmizz.sshj.Config;
|
||||||
|
import net.schmizz.sshj.DefaultConfig;
|
||||||
import net.schmizz.sshj.SSHClient;
|
import net.schmizz.sshj.SSHClient;
|
||||||
|
import net.schmizz.sshj.common.Message;
|
||||||
|
import net.schmizz.sshj.common.SSHPacket;
|
||||||
|
import net.schmizz.sshj.connection.ConnectionImpl;
|
||||||
|
import net.schmizz.sshj.transport.TransportException;
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.params.ParameterizedTest;
|
||||||
|
import org.junit.jupiter.params.provider.Arguments;
|
||||||
|
import org.junit.jupiter.params.provider.MethodSource;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.testcontainers.junit.jupiter.Container;
|
import org.testcontainers.junit.jupiter.Container;
|
||||||
import org.testcontainers.junit.jupiter.Testcontainers;
|
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||||
@@ -62,14 +73,27 @@ class StrictKeyExchangeTest {
|
|||||||
watchedLoggers.add(logger);
|
watchedLoggers.add(logger);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
private static Stream<Arguments> strictKeyExchange() {
|
||||||
void strictKeyExchange() throws Throwable {
|
Config defaultConfig = new DefaultConfig();
|
||||||
try (SSHClient client = sshd.getConnectedClient()) {
|
Config heartbeaterConfig = new DefaultConfig();
|
||||||
|
heartbeaterConfig.setKeepAliveProvider(new KeepAliveProvider() {
|
||||||
|
@Override
|
||||||
|
public KeepAlive provide(ConnectionImpl connection) {
|
||||||
|
return new HotLoopHeartbeater(connection);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return Stream.of(defaultConfig, heartbeaterConfig).map(Arguments::of);
|
||||||
|
}
|
||||||
|
|
||||||
|
@MethodSource
|
||||||
|
@ParameterizedTest
|
||||||
|
void strictKeyExchange(Config config) throws Throwable {
|
||||||
|
try (SSHClient client = sshd.getConnectedClient(config)) {
|
||||||
client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1");
|
client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1");
|
||||||
assertTrue(client.isAuthenticated());
|
assertTrue(client.isAuthenticated());
|
||||||
}
|
}
|
||||||
List<String> keyExchangerLogs = getLogs("KeyExchanger");
|
List<String> keyExchangerLogs = getLogs("KeyExchanger");
|
||||||
assertThat(keyExchangerLogs).containsSequence(
|
assertThat(keyExchangerLogs).contains(
|
||||||
"Initiating key exchange",
|
"Initiating key exchange",
|
||||||
"Sending SSH_MSG_KEXINIT",
|
"Sending SSH_MSG_KEXINIT",
|
||||||
"Received SSH_MSG_KEXINIT",
|
"Received SSH_MSG_KEXINIT",
|
||||||
@@ -78,7 +102,7 @@ class StrictKeyExchangeTest {
|
|||||||
List<String> decoderLogs = getLogs("Decoder").stream()
|
List<String> decoderLogs = getLogs("Decoder").stream()
|
||||||
.map(log -> log.split(":")[0])
|
.map(log -> log.split(":")[0])
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
assertThat(decoderLogs).containsExactly(
|
assertThat(decoderLogs).startsWith(
|
||||||
"Received packet #0",
|
"Received packet #0",
|
||||||
"Received packet #1",
|
"Received packet #1",
|
||||||
"Received packet #2",
|
"Received packet #2",
|
||||||
@@ -90,7 +114,7 @@ class StrictKeyExchangeTest {
|
|||||||
List<String> encoderLogs = getLogs("Encoder").stream()
|
List<String> encoderLogs = getLogs("Encoder").stream()
|
||||||
.map(log -> log.split(":")[0])
|
.map(log -> log.split(":")[0])
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
assertThat(encoderLogs).containsExactly(
|
assertThat(encoderLogs).startsWith(
|
||||||
"Encoding packet #0",
|
"Encoding packet #0",
|
||||||
"Encoding packet #1",
|
"Encoding packet #1",
|
||||||
"Encoding packet #2",
|
"Encoding packet #2",
|
||||||
@@ -108,4 +132,22 @@ class StrictKeyExchangeTest {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class HotLoopHeartbeater extends KeepAlive {
|
||||||
|
|
||||||
|
HotLoopHeartbeater(ConnectionImpl conn) {
|
||||||
|
super(conn, "sshj-Heartbeater");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isEnabled() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void doKeepAlive() throws TransportException {
|
||||||
|
conn.getTransport().write(new SSHPacket(Message.IGNORE));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -804,12 +804,12 @@ public class SSHClient
|
|||||||
throws IOException {
|
throws IOException {
|
||||||
super.onConnect();
|
super.onConnect();
|
||||||
trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream());
|
trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream());
|
||||||
|
doKex();
|
||||||
final KeepAlive keepAliveThread = conn.getKeepAlive();
|
final KeepAlive keepAliveThread = conn.getKeepAlive();
|
||||||
if (keepAliveThread.isEnabled()) {
|
if (keepAliveThread.isEnabled()) {
|
||||||
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
|
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
|
||||||
keepAliveThread.start();
|
keepAliveThread.start();
|
||||||
}
|
}
|
||||||
doKex();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
Reference in New Issue
Block a user