Fix race condition causing SSH_MSG_UNIMPLEMENTED occasionally during key exchange (#851)

* Fix race condition causing SSH_MSG_UNIMPLEMENTED occasionally during key exchange

* unit tests

* fix unit tests

---------

Co-authored-by: Jeroen van Erp <jeroen@hierynomus.com>
This commit is contained in:
Raul Santelices
2023-09-01 18:54:22 -04:00
committed by GitHub
parent a5fdb29fad
commit a186dbf0bc
5 changed files with 134 additions and 25 deletions

View File

@@ -810,12 +810,7 @@ public class SSHClient
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
keepAliveThread.start();
}
if (trans.isKeyExchangeRequired()) {
log.debug("Initiating Key Exchange for new connection");
doKex();
} else {
log.debug("Key Exchange already completed for new connection");
}
doKex();
}
/**

View File

@@ -136,13 +136,25 @@ final class KeyExchanger
void startKex(boolean waitForDone)
throws TransportException {
if (!kexOngoing.getAndSet(true)) {
done.clear();
sendKexInit();
if (isKeyExchangeAllowed()) {
log.debug("Initiating key exchange");
done.clear();
sendKexInit();
} else {
kexOngoing.set(false);
}
}
if (waitForDone)
waitForDone();
}
/**
* Key exchange can be initiated exactly once while connecting or later after authentication when re-keying.
*/
private boolean isKeyExchangeAllowed() {
return !isKexDone() || transport.isAuthenticated();
}
void waitForDone()
throws TransportException {
done.await(transport.getTimeoutMs(), TimeUnit.MILLISECONDS);

View File

@@ -71,13 +71,6 @@ public interface Transport
void doKex()
throws TransportException;
/**
* Is Key Exchange required based on current transport status
*
* @return Key Exchange required status
*/
boolean isKeyExchangeRequired();
/** @return the version string used by this client to identify itself to an SSH server, e.g. "SSHJ_3_0" */
String getClientVersion();

View File

@@ -254,16 +254,6 @@ public final class TransportImpl
kexer.startKex(true);
}
/**
* Is Key Exchange required returns true when Key Exchange is not done and when Key Exchange is not ongoing
*
* @return Key Exchange required status
*/
@Override
public boolean isKeyExchangeRequired() {
return !kexer.isKexDone() && !kexer.isKexOngoing();
}
public boolean isKexDone() {
return kexer.isKexDone();
}

View File

@@ -0,0 +1,119 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.transport;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.util.Collections;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.transport.kex.KeyExchange;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
public class KeyExchangeRepeatTest {
private TransportImpl transport;
private DefaultConfig config;
private KeyExchanger keyExchanger;
@BeforeEach
public void setup() throws GeneralSecurityException, TransportException {
KeyExchange kex = mock(KeyExchange.class, Mockito.RETURNS_DEEP_STUBS);
transport = mock(TransportImpl.class, Mockito.RETURNS_DEEP_STUBS);
config = new DefaultConfig() {
@Override
protected void initKeyExchangeFactories() {
setKeyExchangeFactories(Collections.singletonList(new Factory.Named<>() {
@Override
public KeyExchange create() {
return kex;
}
@Override
public String getName() {
return "mock-kex";
}
}));
}
};
when(transport.getConfig()).thenReturn(config);
when(transport.getServerID()).thenReturn("some server id");
when(transport.getClientID()).thenReturn("some client id");
when(kex.next(any(), any())).thenReturn(true);
when(kex.getH()).thenReturn(new byte[0]);
when(kex.getK()).thenReturn(BigInteger.ZERO);
when(kex.getHash().digest()).thenReturn(new byte[10]);
keyExchanger = new KeyExchanger(transport);
keyExchanger.addHostKeyVerifier(new PromiscuousVerifier());
assertFalse(transport.isAuthenticated()); // sanity check
assertTrue(!keyExchanger.isKexOngoing() && !keyExchanger.isKexDone()); // sanity check
}
@Test
public void allowOnlyOneKeyExchangeBeforeAuthentication() throws TransportException {
// First key exchange before authentication succeeds.
performAndCheckKeyExchange();
// Second key exchange attempt before authentication is ignored.
keyExchanger.startKex(false);
assertTrue(!keyExchanger.isKexOngoing() && keyExchanger.isKexDone());
}
@Test
public void allowExtraKeyExchangesAfterAuthentication() throws TransportException {
// Key exchange before authentication succeeds.
performAndCheckKeyExchange();
// Simulate authentication.
when(transport.isAuthenticated()).thenReturn(true);
// Key exchange after authentication succeeds too.
performAndCheckKeyExchange();
}
private void performAndCheckKeyExchange() throws TransportException {
// Start key exchange.
keyExchanger.startKex(false);
assertTrue(keyExchanger.isKexOngoing() && !keyExchanger.isKexDone());
// Simulate the arrival of the expected packets from the server while checking the state of the exchange.
keyExchanger.handle(Message.KEXINIT, getKexinitPacket());
assertTrue(keyExchanger.isKexOngoing() && !keyExchanger.isKexDone());
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
assertTrue(keyExchanger.isKexOngoing() && !keyExchanger.isKexDone());
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
assertTrue(!keyExchanger.isKexOngoing() && keyExchanger.isKexDone()); // done
}
private SSHPacket getKexinitPacket() {
SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList()).getPacket();
kexinitPacket.rpos(kexinitPacket.rpos() + 1);
return kexinitPacket;
}
}