mirror of
https://github.com/hierynomus/sshj.git
synced 2025-12-07 07:40:55 +03:00
Implement OpenSSH strict key exchange extension (#917)
This commit is contained in:
@@ -146,8 +146,9 @@ public class SshdContainer extends GenericContainer<SshdContainer> {
|
|||||||
.withFileFromString("sshd_config", sshdConfig.build());
|
.withFileFromString("sshd_config", sshdConfig.build());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public void accept(@NotNull DockerfileBuilder builder) {
|
public void accept(@NotNull DockerfileBuilder builder) {
|
||||||
builder.from("alpine:3.18.3");
|
builder.from("alpine:3.19.0");
|
||||||
builder.run("apk add --no-cache openssh");
|
builder.run("apk add --no-cache openssh");
|
||||||
builder.expose(22);
|
builder.expose(22);
|
||||||
builder.copy("entrypoint.sh", "/entrypoint.sh");
|
builder.copy("entrypoint.sh", "/entrypoint.sh");
|
||||||
|
|||||||
@@ -0,0 +1,111 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (C)2009 - SSHJ Contributors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package com.hierynomus.sshj.transport.kex;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import ch.qos.logback.classic.Logger;
|
||||||
|
import ch.qos.logback.classic.spi.ILoggingEvent;
|
||||||
|
import ch.qos.logback.core.read.ListAppender;
|
||||||
|
import com.hierynomus.sshj.SshdContainer;
|
||||||
|
import net.schmizz.sshj.SSHClient;
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.testcontainers.junit.jupiter.Container;
|
||||||
|
import org.testcontainers.junit.jupiter.Testcontainers;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
|
@Testcontainers
|
||||||
|
class StrictKeyExchangeTest {
|
||||||
|
|
||||||
|
@Container
|
||||||
|
private static final SshdContainer sshd = new SshdContainer();
|
||||||
|
|
||||||
|
private final List<Logger> watchedLoggers = new ArrayList<>();
|
||||||
|
private final ListAppender<ILoggingEvent> logWatcher = new ListAppender<>();
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUpLogWatcher() {
|
||||||
|
logWatcher.start();
|
||||||
|
setUpLogger("net.schmizz.sshj.transport.Decoder");
|
||||||
|
setUpLogger("net.schmizz.sshj.transport.Encoder");
|
||||||
|
setUpLogger("net.schmizz.sshj.transport.KeyExchanger");
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() {
|
||||||
|
watchedLoggers.forEach(Logger::detachAndStopAllAppenders);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setUpLogger(String className) {
|
||||||
|
Logger logger = ((Logger) LoggerFactory.getLogger(className));
|
||||||
|
logger.addAppender(logWatcher);
|
||||||
|
watchedLoggers.add(logger);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void strictKeyExchange() throws Throwable {
|
||||||
|
try (SSHClient client = sshd.getConnectedClient()) {
|
||||||
|
client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1");
|
||||||
|
assertTrue(client.isAuthenticated());
|
||||||
|
}
|
||||||
|
List<String> keyExchangerLogs = getLogs("KeyExchanger");
|
||||||
|
assertThat(keyExchangerLogs).containsSequence(
|
||||||
|
"Initiating key exchange",
|
||||||
|
"Sending SSH_MSG_KEXINIT",
|
||||||
|
"Received SSH_MSG_KEXINIT",
|
||||||
|
"Enabling strict key exchange extension"
|
||||||
|
);
|
||||||
|
List<String> decoderLogs = getLogs("Decoder").stream()
|
||||||
|
.map(log -> log.split(":")[0])
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
assertThat(decoderLogs).containsExactly(
|
||||||
|
"Received packet #0",
|
||||||
|
"Received packet #1",
|
||||||
|
"Received packet #2",
|
||||||
|
"Received packet #0",
|
||||||
|
"Received packet #1",
|
||||||
|
"Received packet #2",
|
||||||
|
"Received packet #3"
|
||||||
|
);
|
||||||
|
List<String> encoderLogs = getLogs("Encoder").stream()
|
||||||
|
.map(log -> log.split(":")[0])
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
assertThat(encoderLogs).containsExactly(
|
||||||
|
"Encoding packet #0",
|
||||||
|
"Encoding packet #1",
|
||||||
|
"Encoding packet #2",
|
||||||
|
"Encoding packet #0",
|
||||||
|
"Encoding packet #1",
|
||||||
|
"Encoding packet #2",
|
||||||
|
"Encoding packet #3"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> getLogs(String className) {
|
||||||
|
return logWatcher.list.stream()
|
||||||
|
.filter(event -> event.getLoggerName().endsWith(className))
|
||||||
|
.map(ILoggingEvent::getFormattedMessage)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -51,6 +51,14 @@ abstract class Converter {
|
|||||||
return seq;
|
return seq;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void resetSequenceNumber() {
|
||||||
|
seq = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean isSequenceNumberAtMax() {
|
||||||
|
return seq == 0xffffffffL;
|
||||||
|
}
|
||||||
|
|
||||||
void setAlgorithms(Cipher cipher, MAC mac, Compression compression) {
|
void setAlgorithms(Cipher cipher, MAC mac, Compression compression) {
|
||||||
this.cipher = cipher;
|
this.cipher = cipher;
|
||||||
this.mac = mac;
|
this.mac = mac;
|
||||||
|
|||||||
@@ -60,6 +60,10 @@ final class KeyExchanger
|
|||||||
|
|
||||||
private final AtomicBoolean kexOngoing = new AtomicBoolean();
|
private final AtomicBoolean kexOngoing = new AtomicBoolean();
|
||||||
|
|
||||||
|
private final AtomicBoolean initialKex = new AtomicBoolean(true);
|
||||||
|
|
||||||
|
private final AtomicBoolean strictKex = new AtomicBoolean();
|
||||||
|
|
||||||
/** What we are expecting from the next packet */
|
/** What we are expecting from the next packet */
|
||||||
private Expected expected = Expected.KEXINIT;
|
private Expected expected = Expected.KEXINIT;
|
||||||
|
|
||||||
@@ -123,6 +127,14 @@ final class KeyExchanger
|
|||||||
return kexOngoing.get();
|
return kexOngoing.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean isStrictKex() {
|
||||||
|
return strictKex.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
boolean isInitialKex() {
|
||||||
|
return initialKex.get();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Starts key exchange by sending a {@code SSH_MSG_KEXINIT} packet. Key exchange needs to be done once mandatorily
|
* Starts key exchange by sending a {@code SSH_MSG_KEXINIT} packet. Key exchange needs to be done once mandatorily
|
||||||
* after initializing the {@link Transport} for it to be usable and may be initiated at any later point e.g. if
|
* after initializing the {@link Transport} for it to be usable and may be initiated at any later point e.g. if
|
||||||
@@ -183,7 +195,7 @@ final class KeyExchanger
|
|||||||
throws TransportException {
|
throws TransportException {
|
||||||
log.debug("Sending SSH_MSG_KEXINIT");
|
log.debug("Sending SSH_MSG_KEXINIT");
|
||||||
List<String> knownHostAlgs = findKnownHostAlgs(transport.getRemoteHost(), transport.getRemotePort());
|
List<String> knownHostAlgs = findKnownHostAlgs(transport.getRemoteHost(), transport.getRemotePort());
|
||||||
clientProposal = new Proposal(transport.getConfig(), knownHostAlgs);
|
clientProposal = new Proposal(transport.getConfig(), knownHostAlgs, initialKex.get());
|
||||||
transport.write(clientProposal.getPacket());
|
transport.write(clientProposal.getPacket());
|
||||||
kexInitSent.set();
|
kexInitSent.set();
|
||||||
}
|
}
|
||||||
@@ -202,6 +214,9 @@ final class KeyExchanger
|
|||||||
throws TransportException {
|
throws TransportException {
|
||||||
log.debug("Sending SSH_MSG_NEWKEYS");
|
log.debug("Sending SSH_MSG_NEWKEYS");
|
||||||
transport.write(new SSHPacket(Message.NEWKEYS));
|
transport.write(new SSHPacket(Message.NEWKEYS));
|
||||||
|
if (strictKex.get()) {
|
||||||
|
transport.getEncoder().resetSequenceNumber();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -234,6 +249,10 @@ final class KeyExchanger
|
|||||||
|
|
||||||
private void setKexDone() {
|
private void setKexDone() {
|
||||||
kexOngoing.set(false);
|
kexOngoing.set(false);
|
||||||
|
initialKex.set(false);
|
||||||
|
if (strictKex.get()) {
|
||||||
|
transport.getDecoder().resetSequenceNumber();
|
||||||
|
}
|
||||||
kexInitSent.clear();
|
kexInitSent.clear();
|
||||||
done.set();
|
done.set();
|
||||||
}
|
}
|
||||||
@@ -242,6 +261,7 @@ final class KeyExchanger
|
|||||||
throws TransportException {
|
throws TransportException {
|
||||||
buf.rpos(buf.rpos() - 1);
|
buf.rpos(buf.rpos() - 1);
|
||||||
final Proposal serverProposal = new Proposal(buf);
|
final Proposal serverProposal = new Proposal(buf);
|
||||||
|
gotStrictKexInfo(serverProposal);
|
||||||
negotiatedAlgs = clientProposal.negotiate(serverProposal);
|
negotiatedAlgs = clientProposal.negotiate(serverProposal);
|
||||||
log.debug("Negotiated algorithms: {}", negotiatedAlgs);
|
log.debug("Negotiated algorithms: {}", negotiatedAlgs);
|
||||||
for(AlgorithmsVerifier v: algorithmVerifiers) {
|
for(AlgorithmsVerifier v: algorithmVerifiers) {
|
||||||
@@ -265,6 +285,18 @@ final class KeyExchanger
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void gotStrictKexInfo(Proposal serverProposal) throws TransportException {
|
||||||
|
if (initialKex.get() && serverProposal.isStrictKeyExchangeSupportedByServer()) {
|
||||||
|
strictKex.set(true);
|
||||||
|
log.debug("Enabling strict key exchange extension");
|
||||||
|
if (transport.getDecoder().getSequenceNumber() != 0) {
|
||||||
|
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
|
||||||
|
"SSH_MSG_KEXINIT was not first package during strict key exchange"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Private method used while putting new keys into use that will resize the key used to initialize the cipher to the
|
* Private method used while putting new keys into use that will resize the key used to initialize the cipher to the
|
||||||
* needed length.
|
* needed length.
|
||||||
|
|||||||
@@ -37,8 +37,11 @@ class Proposal {
|
|||||||
private final List<String> s2cComp;
|
private final List<String> s2cComp;
|
||||||
private final SSHPacket packet;
|
private final SSHPacket packet;
|
||||||
|
|
||||||
public Proposal(Config config, List<String> knownHostAlgs) {
|
public Proposal(Config config, List<String> knownHostAlgs, boolean initialKex) {
|
||||||
kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories());
|
kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories());
|
||||||
|
if (initialKex) {
|
||||||
|
kex.add("kex-strict-c-v00@openssh.com");
|
||||||
|
}
|
||||||
sig = filterKnownHostKeyAlgorithms(Factory.Named.Util.getNames(config.getKeyAlgorithms()), knownHostAlgs);
|
sig = filterKnownHostKeyAlgorithms(Factory.Named.Util.getNames(config.getKeyAlgorithms()), knownHostAlgs);
|
||||||
c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories());
|
c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories());
|
||||||
c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories());
|
c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories());
|
||||||
@@ -91,6 +94,10 @@ class Proposal {
|
|||||||
return kex;
|
return kex;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean isStrictKeyExchangeSupportedByServer() {
|
||||||
|
return kex.contains("kex-strict-s-v00@openssh.com");
|
||||||
|
}
|
||||||
|
|
||||||
public List<String> getHostKeyAlgorithms() {
|
public List<String> getHostKeyAlgorithms() {
|
||||||
return sig;
|
return sig;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -426,7 +426,7 @@ public final class TransportImpl
|
|||||||
assert m != Message.KEXINIT;
|
assert m != Message.KEXINIT;
|
||||||
kexer.waitForDone();
|
kexer.waitForDone();
|
||||||
}
|
}
|
||||||
} else if (encoder.getSequenceNumber() == 0) // We get here every 2**32th packet
|
} else if (encoder.isSequenceNumberAtMax()) // We get here every 2**32th packet
|
||||||
kexer.startKex(true);
|
kexer.startKex(true);
|
||||||
|
|
||||||
final long seq = encoder.encode(payload);
|
final long seq = encoder.encode(payload);
|
||||||
@@ -479,9 +479,20 @@ public final class TransportImpl
|
|||||||
|
|
||||||
log.trace("Received packet {}", msg);
|
log.trace("Received packet {}", msg);
|
||||||
|
|
||||||
|
if (kexer.isInitialKex()) {
|
||||||
|
if (decoder.isSequenceNumberAtMax()) {
|
||||||
|
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
|
||||||
|
"Sequence number of decoder is about to wrap during initial key exchange");
|
||||||
|
}
|
||||||
|
if (kexer.isStrictKex() && !isKexerPacket(msg) && msg != Message.DISCONNECT) {
|
||||||
|
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
|
||||||
|
"Unexpected packet type during initial strict key exchange");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (msg.geq(50)) { // not a transport layer packet
|
if (msg.geq(50)) { // not a transport layer packet
|
||||||
service.handle(msg, buf);
|
service.handle(msg, buf);
|
||||||
} else if (msg.in(20, 21) || msg.in(30, 49)) { // kex packet
|
} else if (isKexerPacket(msg)) {
|
||||||
kexer.handle(msg, buf);
|
kexer.handle(msg, buf);
|
||||||
} else {
|
} else {
|
||||||
switch (msg) {
|
switch (msg) {
|
||||||
@@ -513,6 +524,10 @@ public final class TransportImpl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static boolean isKexerPacket(Message msg) {
|
||||||
|
return msg.in(20, 21) || msg.in(30, 49);
|
||||||
|
}
|
||||||
|
|
||||||
private void gotDebug(SSHPacket buf)
|
private void gotDebug(SSHPacket buf)
|
||||||
throws TransportException {
|
throws TransportException {
|
||||||
try {
|
try {
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ public class KeyExchangeRepeatTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private SSHPacket getKexinitPacket() {
|
private SSHPacket getKexinitPacket() {
|
||||||
SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList()).getPacket();
|
SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList(), false).getPacket();
|
||||||
kexinitPacket.rpos(kexinitPacket.rpos() + 1);
|
kexinitPacket.rpos(kexinitPacket.rpos() + 1);
|
||||||
return kexinitPacket;
|
return kexinitPacket;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user