From a186dbf0bc90e47a969ffe1f79ae9e932a83fb1c Mon Sep 17 00:00:00 2001 From: Raul Santelices Date: Fri, 1 Sep 2023 18:54:22 -0400 Subject: [PATCH] Fix race condition causing SSH_MSG_UNIMPLEMENTED occasionally during key exchange (#851) * Fix race condition causing SSH_MSG_UNIMPLEMENTED occasionally during key exchange * unit tests * fix unit tests --------- Co-authored-by: Jeroen van Erp --- src/main/java/net/schmizz/sshj/SSHClient.java | 7 +- .../schmizz/sshj/transport/KeyExchanger.java | 16 ++- .../net/schmizz/sshj/transport/Transport.java | 7 -- .../schmizz/sshj/transport/TransportImpl.java | 10 -- .../sshj/transport/KeyExchangeRepeatTest.java | 119 ++++++++++++++++++ 5 files changed, 134 insertions(+), 25 deletions(-) create mode 100644 src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java diff --git a/src/main/java/net/schmizz/sshj/SSHClient.java b/src/main/java/net/schmizz/sshj/SSHClient.java index aae948bc..5c99d800 100644 --- a/src/main/java/net/schmizz/sshj/SSHClient.java +++ b/src/main/java/net/schmizz/sshj/SSHClient.java @@ -810,12 +810,7 @@ public class SSHClient ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans); keepAliveThread.start(); } - if (trans.isKeyExchangeRequired()) { - log.debug("Initiating Key Exchange for new connection"); - doKex(); - } else { - log.debug("Key Exchange already completed for new connection"); - } + doKex(); } /** diff --git a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java index 6705519f..b8979f7b 100644 --- a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java +++ b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java @@ -136,13 +136,25 @@ final class KeyExchanger void startKex(boolean waitForDone) throws TransportException { if (!kexOngoing.getAndSet(true)) { - done.clear(); - sendKexInit(); + if (isKeyExchangeAllowed()) { + log.debug("Initiating key exchange"); + done.clear(); + sendKexInit(); + } else { + kexOngoing.set(false); + } } if (waitForDone) waitForDone(); } + /** + * Key exchange can be initiated exactly once while connecting or later after authentication when re-keying. + */ + private boolean isKeyExchangeAllowed() { + return !isKexDone() || transport.isAuthenticated(); + } + void waitForDone() throws TransportException { done.await(transport.getTimeoutMs(), TimeUnit.MILLISECONDS); diff --git a/src/main/java/net/schmizz/sshj/transport/Transport.java b/src/main/java/net/schmizz/sshj/transport/Transport.java index d8175698..5ae55968 100644 --- a/src/main/java/net/schmizz/sshj/transport/Transport.java +++ b/src/main/java/net/schmizz/sshj/transport/Transport.java @@ -71,13 +71,6 @@ public interface Transport void doKex() throws TransportException; - /** - * Is Key Exchange required based on current transport status - * - * @return Key Exchange required status - */ - boolean isKeyExchangeRequired(); - /** @return the version string used by this client to identify itself to an SSH server, e.g. "SSHJ_3_0" */ String getClientVersion(); diff --git a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java index edff191c..58107c5b 100644 --- a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java +++ b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java @@ -254,16 +254,6 @@ public final class TransportImpl kexer.startKex(true); } - /** - * Is Key Exchange required returns true when Key Exchange is not done and when Key Exchange is not ongoing - * - * @return Key Exchange required status - */ - @Override - public boolean isKeyExchangeRequired() { - return !kexer.isKexDone() && !kexer.isKexOngoing(); - } - public boolean isKexDone() { return kexer.isKexDone(); } diff --git a/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java b/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java new file mode 100644 index 00000000..c1f8655a --- /dev/null +++ b/src/test/java/net/schmizz/sshj/transport/KeyExchangeRepeatTest.java @@ -0,0 +1,119 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.transport; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.math.BigInteger; +import java.security.GeneralSecurityException; +import java.util.Collections; +import net.schmizz.sshj.DefaultConfig; +import net.schmizz.sshj.common.Factory; +import net.schmizz.sshj.common.Message; +import net.schmizz.sshj.common.SSHPacket; +import net.schmizz.sshj.transport.kex.KeyExchange; +import net.schmizz.sshj.transport.verification.PromiscuousVerifier; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +public class KeyExchangeRepeatTest { + + private TransportImpl transport; + private DefaultConfig config; + private KeyExchanger keyExchanger; + + @BeforeEach + public void setup() throws GeneralSecurityException, TransportException { + KeyExchange kex = mock(KeyExchange.class, Mockito.RETURNS_DEEP_STUBS); + transport = mock(TransportImpl.class, Mockito.RETURNS_DEEP_STUBS); + config = new DefaultConfig() { + @Override + protected void initKeyExchangeFactories() { + setKeyExchangeFactories(Collections.singletonList(new Factory.Named<>() { + @Override + public KeyExchange create() { + return kex; + } + + @Override + public String getName() { + return "mock-kex"; + } + })); + } + }; + when(transport.getConfig()).thenReturn(config); + when(transport.getServerID()).thenReturn("some server id"); + when(transport.getClientID()).thenReturn("some client id"); + when(kex.next(any(), any())).thenReturn(true); + when(kex.getH()).thenReturn(new byte[0]); + when(kex.getK()).thenReturn(BigInteger.ZERO); + when(kex.getHash().digest()).thenReturn(new byte[10]); + + keyExchanger = new KeyExchanger(transport); + keyExchanger.addHostKeyVerifier(new PromiscuousVerifier()); + + assertFalse(transport.isAuthenticated()); // sanity check + assertTrue(!keyExchanger.isKexOngoing() && !keyExchanger.isKexDone()); // sanity check + } + + @Test + public void allowOnlyOneKeyExchangeBeforeAuthentication() throws TransportException { + // First key exchange before authentication succeeds. + performAndCheckKeyExchange(); + + // Second key exchange attempt before authentication is ignored. + keyExchanger.startKex(false); + assertTrue(!keyExchanger.isKexOngoing() && keyExchanger.isKexDone()); + } + + @Test + public void allowExtraKeyExchangesAfterAuthentication() throws TransportException { + // Key exchange before authentication succeeds. + performAndCheckKeyExchange(); + + // Simulate authentication. + when(transport.isAuthenticated()).thenReturn(true); + + // Key exchange after authentication succeeds too. + performAndCheckKeyExchange(); + } + + private void performAndCheckKeyExchange() throws TransportException { + // Start key exchange. + keyExchanger.startKex(false); + assertTrue(keyExchanger.isKexOngoing() && !keyExchanger.isKexDone()); + + // Simulate the arrival of the expected packets from the server while checking the state of the exchange. + keyExchanger.handle(Message.KEXINIT, getKexinitPacket()); + assertTrue(keyExchanger.isKexOngoing() && !keyExchanger.isKexDone()); + keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31)); + assertTrue(keyExchanger.isKexOngoing() && !keyExchanger.isKexDone()); + keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS)); + assertTrue(!keyExchanger.isKexOngoing() && keyExchanger.isKexDone()); // done + } + + private SSHPacket getKexinitPacket() { + SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList()).getPacket(); + kexinitPacket.rpos(kexinitPacket.rpos() + 1); + return kexinitPacket; + } +}