mirror of
https://github.com/hierynomus/sshj.git
synced 2025-12-06 07:10:53 +03:00
Add unit tests of strict key exchange extension (#918)
This commit is contained in:
@@ -0,0 +1,236 @@
|
||||
/*
|
||||
* Copyright (C)2009 - SSHJ Contributors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package net.schmizz.sshj.transport;
|
||||
|
||||
import java.math.BigInteger;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import net.schmizz.sshj.DefaultConfig;
|
||||
import net.schmizz.sshj.common.DisconnectReason;
|
||||
import net.schmizz.sshj.common.Factory;
|
||||
import net.schmizz.sshj.common.Message;
|
||||
import net.schmizz.sshj.common.SSHPacket;
|
||||
import net.schmizz.sshj.transport.kex.KeyExchange;
|
||||
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class KeyExchangerStrictKeyExchangeTest {
|
||||
|
||||
private TransportImpl transport;
|
||||
private DefaultConfig config;
|
||||
private KeyExchanger keyExchanger;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
KeyExchange kex = mock(KeyExchange.class, Mockito.RETURNS_DEEP_STUBS);
|
||||
transport = mock(TransportImpl.class, Mockito.RETURNS_DEEP_STUBS);
|
||||
config = new DefaultConfig() {
|
||||
@Override
|
||||
protected void initKeyExchangeFactories() {
|
||||
setKeyExchangeFactories(Collections.singletonList(new Factory.Named<>() {
|
||||
@Override
|
||||
public KeyExchange create() {
|
||||
return kex;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return "mock-kex";
|
||||
}
|
||||
}));
|
||||
}
|
||||
};
|
||||
when(transport.getConfig()).thenReturn(config);
|
||||
when(transport.getServerID()).thenReturn("some server id");
|
||||
when(transport.getClientID()).thenReturn("some client id");
|
||||
when(kex.next(any(), any())).thenReturn(true);
|
||||
when(kex.getH()).thenReturn(new byte[0]);
|
||||
when(kex.getK()).thenReturn(BigInteger.ZERO);
|
||||
when(kex.getHash().digest()).thenReturn(new byte[10]);
|
||||
|
||||
keyExchanger = new KeyExchanger(transport);
|
||||
keyExchanger.addHostKeyVerifier(new PromiscuousVerifier());
|
||||
}
|
||||
|
||||
@Test
|
||||
void initialConditions() {
|
||||
assertThat(keyExchanger.isKexDone()).isFalse();
|
||||
assertThat(keyExchanger.isKexOngoing()).isFalse();
|
||||
assertThat(keyExchanger.isStrictKex()).isFalse();
|
||||
assertThat(keyExchanger.isInitialKex()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
void startInitialKex() throws Exception {
|
||||
ArgumentCaptor<SSHPacket> sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class);
|
||||
when(transport.write(sshPacketCaptor.capture())).thenReturn(0L);
|
||||
|
||||
keyExchanger.startKex(false);
|
||||
|
||||
assertThat(keyExchanger.isKexDone()).isFalse();
|
||||
assertThat(keyExchanger.isKexOngoing()).isTrue();
|
||||
assertThat(keyExchanger.isStrictKex()).isFalse();
|
||||
assertThat(keyExchanger.isInitialKex()).isTrue();
|
||||
|
||||
SSHPacket sshPacket = sshPacketCaptor.getValue();
|
||||
List<String> kex = new Proposal(sshPacket).getKeyExchangeAlgorithms();
|
||||
assertThat(kex).endsWith("kex-strict-c-v00@openssh.com");
|
||||
}
|
||||
|
||||
@Test
|
||||
void receiveKexInitWithoutServerFlag() throws Exception {
|
||||
keyExchanger.startKex(false);
|
||||
|
||||
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false));
|
||||
|
||||
assertThat(keyExchanger.isKexDone()).isFalse();
|
||||
assertThat(keyExchanger.isKexOngoing()).isTrue();
|
||||
assertThat(keyExchanger.isStrictKex()).isFalse();
|
||||
assertThat(keyExchanger.isInitialKex()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
void finishNonStrictKex() throws Exception {
|
||||
keyExchanger.startKex(false);
|
||||
|
||||
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false));
|
||||
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
|
||||
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
|
||||
|
||||
assertThat(keyExchanger.isKexDone()).isTrue();
|
||||
assertThat(keyExchanger.isKexOngoing()).isFalse();
|
||||
assertThat(keyExchanger.isStrictKex()).isFalse();
|
||||
assertThat(keyExchanger.isInitialKex()).isFalse();
|
||||
|
||||
verify(transport.getEncoder(), never()).resetSequenceNumber();
|
||||
verify(transport.getDecoder(), never()).resetSequenceNumber();
|
||||
}
|
||||
|
||||
@Test
|
||||
void receiveKexInitWithServerFlag() throws Exception {
|
||||
keyExchanger.startKex(false);
|
||||
|
||||
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
|
||||
|
||||
assertThat(keyExchanger.isKexDone()).isFalse();
|
||||
assertThat(keyExchanger.isKexOngoing()).isTrue();
|
||||
assertThat(keyExchanger.isStrictKex()).isTrue();
|
||||
assertThat(keyExchanger.isInitialKex()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
void strictKexInitIsNotFirstPacket() throws Exception {
|
||||
when(transport.getDecoder().getSequenceNumber()).thenReturn(1L);
|
||||
keyExchanger.startKex(false);
|
||||
|
||||
assertThatExceptionOfType(TransportException.class).isThrownBy(
|
||||
() -> keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true))
|
||||
).satisfies(e -> {
|
||||
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED);
|
||||
assertThat(e.getMessage()).isEqualTo("SSH_MSG_KEXINIT was not first package during strict key exchange");
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
void finishStrictKex() throws Exception {
|
||||
keyExchanger.startKex(false);
|
||||
|
||||
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
|
||||
verify(transport.getEncoder(), never()).resetSequenceNumber();
|
||||
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
|
||||
verify(transport.getEncoder()).resetSequenceNumber();
|
||||
verify(transport.getDecoder(), never()).resetSequenceNumber();
|
||||
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
|
||||
verify(transport.getDecoder()).resetSequenceNumber();
|
||||
|
||||
assertThat(keyExchanger.isKexDone()).isTrue();
|
||||
assertThat(keyExchanger.isKexOngoing()).isFalse();
|
||||
assertThat(keyExchanger.isStrictKex()).isTrue();
|
||||
assertThat(keyExchanger.isInitialKex()).isFalse();
|
||||
}
|
||||
|
||||
@Test
|
||||
void noClientFlagInSecondStrictKex() throws Exception {
|
||||
keyExchanger.startKex(false);
|
||||
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
|
||||
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
|
||||
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
|
||||
|
||||
ArgumentCaptor<SSHPacket> sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class);
|
||||
when(transport.write(sshPacketCaptor.capture())).thenReturn(0L);
|
||||
when(transport.isAuthenticated()).thenReturn(true);
|
||||
|
||||
keyExchanger.startKex(false);
|
||||
|
||||
assertThat(keyExchanger.isKexDone()).isFalse();
|
||||
assertThat(keyExchanger.isKexOngoing()).isTrue();
|
||||
assertThat(keyExchanger.isStrictKex()).isTrue();
|
||||
assertThat(keyExchanger.isInitialKex()).isFalse();
|
||||
|
||||
SSHPacket sshPacket = sshPacketCaptor.getValue();
|
||||
List<String> kex = new Proposal(sshPacket).getKeyExchangeAlgorithms();
|
||||
assertThat(kex).doesNotContain("kex-strict-c-v00@openssh.com");
|
||||
}
|
||||
|
||||
@Test
|
||||
void serverFlagIsIgnoredInSecondKex() throws Exception {
|
||||
keyExchanger.startKex(false);
|
||||
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false));
|
||||
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
|
||||
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
|
||||
|
||||
ArgumentCaptor<SSHPacket> sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class);
|
||||
when(transport.write(sshPacketCaptor.capture())).thenReturn(0L);
|
||||
when(transport.isAuthenticated()).thenReturn(true);
|
||||
|
||||
keyExchanger.startKex(false);
|
||||
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
|
||||
|
||||
assertThat(keyExchanger.isKexDone()).isFalse();
|
||||
assertThat(keyExchanger.isKexOngoing()).isTrue();
|
||||
assertThat(keyExchanger.isStrictKex()).isFalse();
|
||||
assertThat(keyExchanger.isInitialKex()).isFalse();
|
||||
|
||||
SSHPacket sshPacket = sshPacketCaptor.getValue();
|
||||
List<String> kex = new Proposal(sshPacket).getKeyExchangeAlgorithms();
|
||||
assertThat(kex).doesNotContain("kex-strict-c-v00@openssh.com");
|
||||
}
|
||||
|
||||
private SSHPacket getKexInitPacket(boolean withServerFlag) {
|
||||
SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList(), true).getPacket();
|
||||
if (withServerFlag) {
|
||||
int finalWpos = kexinitPacket.wpos();
|
||||
kexinitPacket.wpos(22);
|
||||
kexinitPacket.putString("mock-kex,kex-strict-s-v00@openssh.com");
|
||||
kexinitPacket.wpos(finalWpos);
|
||||
}
|
||||
kexinitPacket.rpos(kexinitPacket.rpos() + 1);
|
||||
return kexinitPacket;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
/*
|
||||
* Copyright (C)2009 - SSHJ Contributors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package net.schmizz.sshj.transport;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
|
||||
import net.schmizz.sshj.Config;
|
||||
import net.schmizz.sshj.DefaultConfig;
|
||||
import net.schmizz.sshj.common.DisconnectReason;
|
||||
import net.schmizz.sshj.common.Message;
|
||||
import net.schmizz.sshj.common.SSHPacket;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.EnumSource;
|
||||
import org.junit.jupiter.params.provider.EnumSource.Mode;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatCode;
|
||||
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class TransportImplStrictKeyExchangeTest {
|
||||
|
||||
private final Config config = new DefaultConfig();
|
||||
private final Transport transport = new TransportImpl(config);
|
||||
private final KeyExchanger kexer = mock(KeyExchanger.class);
|
||||
private final Decoder decoder = mock(Decoder.class);
|
||||
|
||||
@BeforeEach
|
||||
void setUp() throws Exception {
|
||||
Field kexerField = TransportImpl.class.getDeclaredField("kexer");
|
||||
kexerField.setAccessible(true);
|
||||
kexerField.set(transport, kexer);
|
||||
Field decoderField = TransportImpl.class.getDeclaredField("decoder");
|
||||
decoderField.setAccessible(true);
|
||||
decoderField.set(transport, decoder);
|
||||
}
|
||||
|
||||
@Test
|
||||
void throwExceptionOnWrapDuringInitialKex() {
|
||||
when(kexer.isInitialKex()).thenReturn(true);
|
||||
when(decoder.isSequenceNumberAtMax()).thenReturn(true);
|
||||
|
||||
assertThatExceptionOfType(TransportException.class).isThrownBy(
|
||||
() -> transport.handle(Message.KEXINIT, new SSHPacket(Message.KEXINIT))
|
||||
).satisfies(e -> {
|
||||
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED);
|
||||
assertThat(e.getMessage()).isEqualTo("Sequence number of decoder is about to wrap during initial key exchange");
|
||||
});
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(value = Message.class, mode = Mode.EXCLUDE, names = {
|
||||
"DISCONNECT", "KEXINIT", "NEWKEYS", "KEXDH_INIT", "KEXDH_31", "KEX_DH_GEX_INIT", "KEX_DH_GEX_REPLY", "KEX_DH_GEX_REQUEST"
|
||||
})
|
||||
void forbidUnexpectedPacketsDuringStrictKeyExchange(Message message) {
|
||||
when(kexer.isInitialKex()).thenReturn(true);
|
||||
when(decoder.isSequenceNumberAtMax()).thenReturn(false);
|
||||
when(kexer.isStrictKex()).thenReturn(true);
|
||||
|
||||
assertThatExceptionOfType(TransportException.class).isThrownBy(
|
||||
() -> transport.handle(message, new SSHPacket(message))
|
||||
).satisfies(e -> {
|
||||
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED);
|
||||
assertThat(e.getMessage()).isEqualTo("Unexpected packet type during initial strict key exchange");
|
||||
});
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@EnumSource(value = Message.class, mode = Mode.INCLUDE, names = {
|
||||
"KEXINIT", "NEWKEYS", "KEXDH_INIT", "KEXDH_31", "KEX_DH_GEX_INIT", "KEX_DH_GEX_REPLY", "KEX_DH_GEX_REQUEST"
|
||||
})
|
||||
void expectedPacketsDuringStrictKeyExchangeAreHandled(Message message) throws Exception {
|
||||
when(kexer.isInitialKex()).thenReturn(true);
|
||||
when(decoder.isSequenceNumberAtMax()).thenReturn(false);
|
||||
when(kexer.isStrictKex()).thenReturn(true);
|
||||
SSHPacket sshPacket = new SSHPacket(message);
|
||||
|
||||
assertThatCode(
|
||||
() -> transport.handle(message, sshPacket)
|
||||
).doesNotThrowAnyException();
|
||||
|
||||
verify(kexer).handle(message, sshPacket);
|
||||
}
|
||||
|
||||
@Test
|
||||
void disconnectIsAllowedDuringStrictKeyExchange() {
|
||||
when(kexer.isInitialKex()).thenReturn(true);
|
||||
when(decoder.isSequenceNumberAtMax()).thenReturn(false);
|
||||
when(kexer.isStrictKex()).thenReturn(true);
|
||||
|
||||
SSHPacket sshPacket = new SSHPacket();
|
||||
sshPacket.putUInt32(DisconnectReason.SERVICE_NOT_AVAILABLE.toInt());
|
||||
sshPacket.putString("service is down for maintenance");
|
||||
|
||||
assertThatExceptionOfType(TransportException.class).isThrownBy(
|
||||
() -> transport.handle(Message.DISCONNECT, sshPacket)
|
||||
).satisfies(e -> {
|
||||
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.SERVICE_NOT_AVAILABLE);
|
||||
assertThat(e.getMessage()).isEqualTo("service is down for maintenance");
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user