Compare commits

..

1 Commits

Author SHA1 Message Date
Jeroen van Erp
1a7255cce2 Moved host verification tests to JUnit5 2023-10-23 12:04:44 +02:00
20 changed files with 43 additions and 1205 deletions

View File

@@ -1,7 +1,7 @@
= sshj - SSHv2 library for Java
Jeroen van Erp
:sshj_groupid: com.hierynomus
:sshj_version: 0.38.0
:sshj_version: 0.37.0
:source-highlighter: pygments
image:https://github.com/hierynomus/sshj/actions/workflows/gradle.yml/badge.svg[link="https://github.com/hierynomus/sshj/actions/workflows/gradle.yml"]
@@ -10,8 +10,6 @@ image:https://codecov.io/gh/hierynomus/sshj/branch/master/graph/badge.svg["codec
image:http://www.javadoc.io/badge/com.hierynomus/sshj.svg?color=blue["JavaDocs", link="http://www.javadoc.io/doc/com.hierynomus/sshj"]
image:https://maven-badges.herokuapp.com/maven-central/com.hierynomus/sshj/badge.svg["Maven Central",link="https://maven-badges.herokuapp.com/maven-central/com.hierynomus/sshj"]
WARNING: SSHJ versions up to and including 0.37.0 are vulnerable to https://nvd.nist.gov/vuln/detail/CVE-2023-48795[CVE-2023-48795 - Terrapin]. Please upgrade to 0.38.0 or higher.
To get started, have a look at one of the examples. Hopefully you will find the API pleasant to work with :)
== Getting SSHJ
@@ -48,7 +46,7 @@ If your project is built using another build tool that uses the Maven Central re
In the `examples` directory, there is a separate Maven project that shows how the library can be used in some sample cases. If you want to run them, follow these guidelines:
. Install http://maven.apache.org/[Maven 2.2.1] or up.
. Clone the SSHJ repository.
. Clone the Overthere repository.
. Go into the `examples` directory and run the command `mvn eclipse:eclipse`.
. Import the `examples` project into Eclipse.
. Change the login details in the example classes (address, username and password) and run them!
@@ -110,14 +108,6 @@ Issue tracker: https://github.com/hierynomus/sshj/issues
Fork away!
== Release history
SSHJ 0.38.0 (2024-01-02)::
* Mitigated CVE-2023-48795 - Terrapin
* Merged https://github.com/hierynomus/sshj/pull/917[#917]: Implement OpenSSH strict key exchange extension
* Merged https://github.com/hierynomus/sshj/pull/903[#903]: Fix for writing known hosts key string
* Merged https://github.com/hierynomus/sshj/pull/913[#913]: Prevent remote port forwarding buffers to grow without bounds
* Moved tess to JUnit5
* Merged https://github.com/hierynomus/sshj/pull/827[#827]: Fallback to posix-rename@openssh.com extension if available
* Merged https://github.com/hierynomus/sshj/pull/904[#904]: Add ChaCha20-Poly1305 support for OpenSSH keys
SSHJ 0.37.0 (2023-10-11)::
* Merged https://github.com/hierynomus/sshj/pull/899[#899]: Add support for AES-GCM OpenSSH private keys
* Merged https://github.com/hierynomus/sshj/pull/901[#901]: Fix ZLib compression bug

View File

@@ -146,9 +146,8 @@ public class SshdContainer extends GenericContainer<SshdContainer> {
.withFileFromString("sshd_config", sshdConfig.build());
}
@Override
public void accept(@NotNull DockerfileBuilder builder) {
builder.from("alpine:3.19.0");
builder.from("alpine:3.18.3");
builder.run("apk add --no-cache openssh");
builder.expose(22);
builder.copy("entrypoint.sh", "/entrypoint.sh");

View File

@@ -1,111 +0,0 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hierynomus.sshj.transport.kex;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.read.ListAppender;
import com.hierynomus.sshj.SshdContainer;
import net.schmizz.sshj.SSHClient;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.slf4j.LoggerFactory;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Testcontainers
class StrictKeyExchangeTest {
@Container
private static final SshdContainer sshd = new SshdContainer();
private final List<Logger> watchedLoggers = new ArrayList<>();
private final ListAppender<ILoggingEvent> logWatcher = new ListAppender<>();
@BeforeEach
void setUpLogWatcher() {
logWatcher.start();
setUpLogger("net.schmizz.sshj.transport.Decoder");
setUpLogger("net.schmizz.sshj.transport.Encoder");
setUpLogger("net.schmizz.sshj.transport.KeyExchanger");
}
@AfterEach
void tearDown() {
watchedLoggers.forEach(Logger::detachAndStopAllAppenders);
}
private void setUpLogger(String className) {
Logger logger = ((Logger) LoggerFactory.getLogger(className));
logger.addAppender(logWatcher);
watchedLoggers.add(logger);
}
@Test
void strictKeyExchange() throws Throwable {
try (SSHClient client = sshd.getConnectedClient()) {
client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1");
assertTrue(client.isAuthenticated());
}
List<String> keyExchangerLogs = getLogs("KeyExchanger");
assertThat(keyExchangerLogs).containsSequence(
"Initiating key exchange",
"Sending SSH_MSG_KEXINIT",
"Received SSH_MSG_KEXINIT",
"Enabling strict key exchange extension"
);
List<String> decoderLogs = getLogs("Decoder").stream()
.map(log -> log.split(":")[0])
.collect(Collectors.toList());
assertThat(decoderLogs).containsExactly(
"Received packet #0",
"Received packet #1",
"Received packet #2",
"Received packet #0",
"Received packet #1",
"Received packet #2",
"Received packet #3"
);
List<String> encoderLogs = getLogs("Encoder").stream()
.map(log -> log.split(":")[0])
.collect(Collectors.toList());
assertThat(encoderLogs).containsExactly(
"Encoding packet #0",
"Encoding packet #1",
"Encoding packet #2",
"Encoding packet #0",
"Encoding packet #1",
"Encoding packet #2",
"Encoding packet #3"
);
}
private List<String> getLogs(String className) {
return logWatcher.list.stream()
.filter(event -> event.getLoggerName().endsWith(className))
.map(ILoggingEvent::getFormattedMessage)
.collect(Collectors.toList());
}
}

View File

@@ -200,8 +200,4 @@ public interface Config {
* See {@link #isVerifyHostKeyCertificates()}.
*/
void setVerifyHostKeyCertificates(boolean value);
int getMaxCircularBufferSize();
void setMaxCircularBufferSize(int maxCircularBufferSize);
}

View File

@@ -49,8 +49,6 @@ public class ConfigImpl
private boolean waitForServerIdentBeforeSendingClientIdent = false;
private LoggerFactory loggerFactory;
private boolean verifyHostKeyCertificates = true;
// HF-982: default to 16MB buffers.
private int maxCircularBufferSize = 16 * 1024 * 1024;
@Override
public List<Factory.Named<Cipher>> getCipherFactories() {
@@ -177,16 +175,6 @@ public class ConfigImpl
return loggerFactory;
}
@Override
public int getMaxCircularBufferSize() {
return maxCircularBufferSize;
}
@Override
public void setMaxCircularBufferSize(int maxCircularBufferSize) {
this.maxCircularBufferSize = maxCircularBufferSize;
}
@Override
public void setLoggerFactory(LoggerFactory loggerFactory) {
this.loggerFactory = loggerFactory;

View File

@@ -1,194 +0,0 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.common;
public class CircularBuffer<T extends CircularBuffer<T>> {
public static class CircularBufferException
extends SSHException {
public CircularBufferException(String message) {
super(message);
}
}
public static final class PlainCircularBuffer
extends CircularBuffer<PlainCircularBuffer> {
public PlainCircularBuffer(int size, int maxSize) {
super(size, maxSize);
}
}
/**
* Maximum size of the internal array (one plus the maximum capacity of the buffer).
*/
private final int maxSize;
/**
* Internal array for the data. All bytes minus one can be used to avoid empty vs full ambiguity when rpos == wpos.
*/
private byte[] data;
/**
* Next read position. Wraps around the end of the internal array. When it reaches wpos, the buffer becomes empty.
* Can take the value data.length, which is equivalent to 0.
*/
private int rpos;
/**
* Next write position. Wraps around the end of the internal array. If it is equal to rpos, then the buffer is
* empty; the code does not allow wpos to reach rpos from the left. This implies that the buffer can store up to
* data.length - 1 bytes. Can take the value data.length, which is equivalent to 0.
*/
private int wpos;
/**
* Determines the size to which to grow the internal array.
*/
private int getNextSize(int currentSize) {
// Use next power of 2.
int nextSize = 1;
while (nextSize < currentSize) {
nextSize <<= 1;
if (nextSize <= 0) {
return maxSize;
}
}
return Math.min(nextSize, maxSize); // limit to max size
}
/**
* Creates a new circular buffer of the given size. The capacity of the buffer is one less than the size/
*/
public CircularBuffer(int size, int maxSize) {
this.maxSize = maxSize;
if (size > maxSize) {
throw new IllegalArgumentException(
String.format("Initial requested size %d larger than maximum size %d", size, maxSize));
}
int initialSize = getNextSize(size);
this.data = new byte[initialSize];
this.rpos = 0;
this.wpos = 0;
}
/**
* Data available in the buffer for reading.
*/
public int available() {
int available = wpos - rpos;
return available >= 0 ? available : available + data.length; // adjust if wpos is left of rpos
}
private void ensureAvailable(int a)
throws CircularBufferException {
if (available() < a) {
throw new CircularBufferException("Underflow");
}
}
/**
* Returns how many more bytes this buffer can receive.
*/
public int maxPossibleRemainingCapacity() {
// Remaining capacity is one less than remaining space to ensure that wpos does not reach rpos from the left.
int remaining = rpos - wpos - 1;
if (remaining < 0) {
remaining += data.length; // adjust if rpos is left of wpos
}
// Add the maximum amount the internal array can grow.
return remaining + maxSize - data.length;
}
/**
* If the internal array does not have room for "capacity" more bytes, resizes the array to make that room.
*/
void ensureCapacity(int capacity) throws CircularBufferException {
int available = available();
int remaining = data.length - available;
// If capacity fits exactly in the remaining space, expand it; otherwise, wpos would reach rpos from the left.
if (remaining <= capacity) {
int neededSize = available + capacity + 1;
int nextSize = getNextSize(neededSize);
if (nextSize < neededSize) {
throw new CircularBufferException("Attempted overflow");
}
byte[] tmp = new byte[nextSize];
// Copy data to the beginning of the new array.
if (wpos >= rpos) {
System.arraycopy(data, rpos, tmp, 0, available);
wpos -= rpos; // wpos must be relative to the new rpos, which will be 0
} else {
int tail = data.length - rpos;
System.arraycopy(data, rpos, tmp, 0, tail); // segment right of rpos
System.arraycopy(data, 0, tmp, tail, wpos); // segment left of wpos
wpos += tail; // wpos must be relative to the new rpos, which will be 0
}
rpos = 0;
data = tmp;
}
}
/**
* Reads data from this buffer into the provided array.
*/
public void readRawBytes(byte[] destination, int offset, int length) throws CircularBufferException {
ensureAvailable(length);
int rposNext = rpos + length;
if (rposNext <= data.length) {
System.arraycopy(data, rpos, destination, offset, length);
} else {
int tail = data.length - rpos;
System.arraycopy(data, rpos, destination, offset, tail); // segment right of rpos
rposNext = length - tail; // rpos wraps around the end of the buffer
System.arraycopy(data, 0, destination, offset + tail, rposNext); // remainder
}
// This can make rpos equal data.length, which has the same effect as wpos being 0.
rpos = rposNext;
}
/**
* Writes data to this buffer from the provided array.
*/
@SuppressWarnings("unchecked")
public T putRawBytes(byte[] source, int offset, int length) throws CircularBufferException {
ensureCapacity(length);
int wposNext = wpos + length;
if (wposNext <= data.length) {
System.arraycopy(source, offset, data, wpos, length);
} else {
int tail = data.length - wpos;
System.arraycopy(source, offset, data, wpos, tail); // segment right of wpos
wposNext = length - tail; // wpos wraps around the end of the buffer
System.arraycopy(source, offset + tail, data, 0, wposNext); // remainder
}
// This can make wpos equal data.length, which has the same effect as wpos being 0.
wpos = wposNext;
return (T) this;
}
// Used only for testing.
int length() {
return data.length;
}
@Override
public String toString() {
return "CircularBuffer [rpos=" + rpos + ", wpos=" + wpos + ", size=" + data.length + "]";
}
}

View File

@@ -164,7 +164,8 @@ public abstract class AbstractChannel
}
@Override
public void handle(Message msg, SSHPacket buf) throws SSHException {
public void handle(Message msg, SSHPacket buf)
throws ConnectionException, TransportException {
switch (msg) {
case CHANNEL_DATA:
@@ -353,7 +354,7 @@ public abstract class AbstractChannel
}
protected void gotExtendedData(SSHPacket buf)
throws SSHException {
throws ConnectionException, TransportException {
throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR,
"Extended data not supported on " + type + " channel");
}
@@ -374,7 +375,7 @@ public abstract class AbstractChannel
}
protected void receiveInto(ChannelInputStream stream, SSHPacket buf)
throws SSHException {
throws ConnectionException, TransportException {
final int len;
try {
len = buf.readUInt32AsInt();

View File

@@ -38,7 +38,7 @@ public final class ChannelInputStream
private final Channel chan;
private final Transport trans;
private final Window.Local win;
private final CircularBuffer.PlainCircularBuffer buf;
private final Buffer.PlainBuffer buf;
private final byte[] b = new byte[1];
private boolean eof;
@@ -46,11 +46,10 @@ public final class ChannelInputStream
public ChannelInputStream(Channel chan, Transport trans, Window.Local win) {
this.chan = chan;
this.log = chan.getLoggerFactory().getLogger(getClass());
log = chan.getLoggerFactory().getLogger(getClass());
this.trans = trans;
this.win = win;
this.buf = new CircularBuffer.PlainCircularBuffer(
chan.getLocalMaxPacketSize(), trans.getConfig().getMaxCircularBufferSize());
buf = new Buffer.PlainBuffer(chan.getLocalMaxPacketSize());
}
@Override
@@ -114,44 +113,48 @@ public final class ChannelInputStream
len = buf.available();
}
buf.readRawBytes(b, off, len);
if (!chan.getAutoExpand()) {
checkWindow();
if (buf.rpos() > win.getMaxPacketSize() && buf.available() == 0) {
buf.clear();
}
}
if (!chan.getAutoExpand()) {
checkWindow();
}
return len;
}
public void receive(byte[] data, int offset, int len) throws SSHException {
public void receive(byte[] data, int offset, int len)
throws ConnectionException, TransportException {
if (eof) {
throw new ConnectionException("Getting data on EOF'ed stream");
}
synchronized (buf) {
buf.putRawBytes(data, offset, len);
buf.notifyAll();
// Potential fix for #203 (window consumed below 0).
// This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST
// And the window has not expanded yet.
}
// Potential fix for #203 (window consumed below 0).
// This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST
// And the window has not expanded yet.
synchronized (win) {
win.consume(len);
if (chan.getAutoExpand()) {
checkWindow();
}
}
if (chan.getAutoExpand()) {
checkWindow();
}
}
private void checkWindow() throws TransportException {
/*
* Window must fit in remaining buffer capacity. We already expect win.size() amount of data to arrive. The
* difference between that and the remaining capacity is the maximum adjustment we can make to the window.
*/
final long maxAdjustment = buf.maxPossibleRemainingCapacity() - win.getSize();
final long adjustment = Math.min(win.neededAdjustment(), maxAdjustment);
if (adjustment > 0) {
log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment);
trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST)
.putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment));
win.expand(adjustment);
private void checkWindow()
throws TransportException {
synchronized (win) {
final long adjustment = win.neededAdjustment();
if (adjustment > 0) {
log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment);
trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST)
.putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment));
win.expand(adjustment);
}
}
}

View File

@@ -210,7 +210,7 @@ public class SessionChannel
@Override
protected void gotExtendedData(SSHPacket buf)
throws SSHException {
throws ConnectionException, TransportException {
try {
final int dataTypeCode = buf.readUInt32AsInt();
if (dataTypeCode == 1)

View File

@@ -51,14 +51,6 @@ abstract class Converter {
return seq;
}
void resetSequenceNumber() {
seq = -1;
}
boolean isSequenceNumberAtMax() {
return seq == 0xffffffffL;
}
void setAlgorithms(Cipher cipher, MAC mac, Compression compression) {
this.cipher = cipher;
this.mac = mac;

View File

@@ -60,10 +60,6 @@ final class KeyExchanger
private final AtomicBoolean kexOngoing = new AtomicBoolean();
private final AtomicBoolean initialKex = new AtomicBoolean(true);
private final AtomicBoolean strictKex = new AtomicBoolean();
/** What we are expecting from the next packet */
private Expected expected = Expected.KEXINIT;
@@ -127,14 +123,6 @@ final class KeyExchanger
return kexOngoing.get();
}
boolean isStrictKex() {
return strictKex.get();
}
boolean isInitialKex() {
return initialKex.get();
}
/**
* Starts key exchange by sending a {@code SSH_MSG_KEXINIT} packet. Key exchange needs to be done once mandatorily
* after initializing the {@link Transport} for it to be usable and may be initiated at any later point e.g. if
@@ -195,7 +183,7 @@ final class KeyExchanger
throws TransportException {
log.debug("Sending SSH_MSG_KEXINIT");
List<String> knownHostAlgs = findKnownHostAlgs(transport.getRemoteHost(), transport.getRemotePort());
clientProposal = new Proposal(transport.getConfig(), knownHostAlgs, initialKex.get());
clientProposal = new Proposal(transport.getConfig(), knownHostAlgs);
transport.write(clientProposal.getPacket());
kexInitSent.set();
}
@@ -214,9 +202,6 @@ final class KeyExchanger
throws TransportException {
log.debug("Sending SSH_MSG_NEWKEYS");
transport.write(new SSHPacket(Message.NEWKEYS));
if (strictKex.get()) {
transport.getEncoder().resetSequenceNumber();
}
}
/**
@@ -249,10 +234,6 @@ final class KeyExchanger
private void setKexDone() {
kexOngoing.set(false);
initialKex.set(false);
if (strictKex.get()) {
transport.getDecoder().resetSequenceNumber();
}
kexInitSent.clear();
done.set();
}
@@ -261,7 +242,6 @@ final class KeyExchanger
throws TransportException {
buf.rpos(buf.rpos() - 1);
final Proposal serverProposal = new Proposal(buf);
gotStrictKexInfo(serverProposal);
negotiatedAlgs = clientProposal.negotiate(serverProposal);
log.debug("Negotiated algorithms: {}", negotiatedAlgs);
for(AlgorithmsVerifier v: algorithmVerifiers) {
@@ -285,18 +265,6 @@ final class KeyExchanger
}
}
private void gotStrictKexInfo(Proposal serverProposal) throws TransportException {
if (initialKex.get() && serverProposal.isStrictKeyExchangeSupportedByServer()) {
strictKex.set(true);
log.debug("Enabling strict key exchange extension");
if (transport.getDecoder().getSequenceNumber() != 0) {
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
"SSH_MSG_KEXINIT was not first package during strict key exchange"
);
}
}
}
/**
* Private method used while putting new keys into use that will resize the key used to initialize the cipher to the
* needed length.

View File

@@ -37,11 +37,8 @@ class Proposal {
private final List<String> s2cComp;
private final SSHPacket packet;
public Proposal(Config config, List<String> knownHostAlgs, boolean initialKex) {
public Proposal(Config config, List<String> knownHostAlgs) {
kex = Factory.Named.Util.getNames(config.getKeyExchangeFactories());
if (initialKex) {
kex.add("kex-strict-c-v00@openssh.com");
}
sig = filterKnownHostKeyAlgorithms(Factory.Named.Util.getNames(config.getKeyAlgorithms()), knownHostAlgs);
c2sCipher = s2cCipher = Factory.Named.Util.getNames(config.getCipherFactories());
c2sMAC = s2cMAC = Factory.Named.Util.getNames(config.getMACFactories());
@@ -94,10 +91,6 @@ class Proposal {
return kex;
}
public boolean isStrictKeyExchangeSupportedByServer() {
return kex.contains("kex-strict-s-v00@openssh.com");
}
public List<String> getHostKeyAlgorithms() {
return sig;
}

View File

@@ -426,7 +426,7 @@ public final class TransportImpl
assert m != Message.KEXINIT;
kexer.waitForDone();
}
} else if (encoder.isSequenceNumberAtMax()) // We get here every 2**32th packet
} else if (encoder.getSequenceNumber() == 0) // We get here every 2**32th packet
kexer.startKex(true);
final long seq = encoder.encode(payload);
@@ -479,20 +479,9 @@ public final class TransportImpl
log.trace("Received packet {}", msg);
if (kexer.isInitialKex()) {
if (decoder.isSequenceNumberAtMax()) {
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
"Sequence number of decoder is about to wrap during initial key exchange");
}
if (kexer.isStrictKex() && !isKexerPacket(msg) && msg != Message.DISCONNECT) {
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
"Unexpected packet type during initial strict key exchange");
}
}
if (msg.geq(50)) { // not a transport layer packet
service.handle(msg, buf);
} else if (isKexerPacket(msg)) {
} else if (msg.in(20, 21) || msg.in(30, 49)) { // kex packet
kexer.handle(msg, buf);
} else {
switch (msg) {
@@ -524,10 +513,6 @@ public final class TransportImpl
}
}
private static boolean isKexerPacket(Message msg) {
return msg.in(20, 21) || msg.in(30, 49);
}
private void gotDebug(SSHPacket buf)
throws TransportException {
try {

View File

@@ -41,7 +41,6 @@ import java.security.KeyFactory;
import java.security.PublicKey;
import java.security.spec.RSAPublicKeySpec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
@@ -469,8 +468,7 @@ public class OpenSSHKnownHosts
}
private String getKeyString(PublicKey pk) {
final Buffer.PlainBuffer buf = new Buffer.PlainBuffer().putPublicKey(pk);
return Base64.getEncoder().encodeToString(Arrays.copyOfRange(buf.array(), buf.rpos(), buf.available()));
return Base64.getEncoder().encodeToString(pk.getEncoded());
}
protected String getHostPart() {

View File

@@ -1,188 +0,0 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hierynomus.sshj.connection.channel.forwarded;
import static org.junit.jupiter.api.Assertions.*;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder.Forward;
import net.schmizz.sshj.connection.channel.forwarded.SocketForwardingConnectListener;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class RemotePFPerformanceTest {
private static final Logger log = LoggerFactory.getLogger(RemotePFPerformanceTest.class);
@Test
@Disabled
public void startPF() throws IOException, InterruptedException {
DefaultConfig config = new DefaultConfig();
config.setMaxCircularBufferSize(16 * 1024 * 1024);
SSHClient client = new SSHClient(config);
client.loadKnownHosts();
client.addHostKeyVerifier("5c:0c:8e:9d:1c:50:a9:ba:a7:05:f6:b1:2b:0b:5f:ba");
client.getConnection().getKeepAlive().setKeepAliveInterval(5);
client.connect("localhost");
client.getConnection().getKeepAlive().setKeepAliveInterval(5);
Object consumerReadyMonitor = new Object();
ConsumerThread consumerThread = new ConsumerThread(consumerReadyMonitor);
ProducerThread producerThread = new ProducerThread();
try {
client.authPassword(System.getenv().get("USERNAME"), System.getenv().get("PASSWORD"));
/*
* We make _server_ listen on port 8080, which forwards all connections to us as a channel, and we further
* forward all such channels to google.com:80
*/
client.getRemotePortForwarder().bind(
// where the server should listen
new Forward(8888),
// what we do with incoming connections that are forwarded to us
new SocketForwardingConnectListener(new InetSocketAddress("localhost", 12345)));
consumerThread.start();
synchronized (consumerReadyMonitor) {
consumerReadyMonitor.wait();
}
producerThread.start();
// Wait for consumer to finish receiving data.
synchronized (consumerReadyMonitor) {
consumerReadyMonitor.wait();
}
} finally {
producerThread.interrupt();
consumerThread.interrupt();
client.disconnect();
}
}
private static class ConsumerThread extends Thread {
private final Object consumerReadyMonitor;
private ConsumerThread(Object consumerReadyMonitor) {
super("Consumer");
this.consumerReadyMonitor = consumerReadyMonitor;
}
@Override
public void run() {
try (ServerSocket serverSocket = new ServerSocket(12345)) {
synchronized (consumerReadyMonitor) {
consumerReadyMonitor.notifyAll();
}
try (Socket acceptedSocket = serverSocket.accept()) {
InputStream in = acceptedSocket.getInputStream();
int numRead;
byte[] buf = new byte[40000];
//byte[] buf = new byte[255 * 4 * 1000];
byte expectedNext = 1;
while ((numRead = in.read(buf)) != 0) {
if (Thread.interrupted()) {
log.info("Consumer thread interrupted");
return;
}
log.info(String.format("Read %d characters; values from %d to %d", numRead, buf[0], buf[numRead - 1]));
if (buf[numRead - 1] == 0) {
verifyData(buf, numRead - 1, expectedNext);
break;
}
expectedNext = verifyData(buf, numRead, expectedNext);
// Slow down consumer to test buffering.
Thread.sleep(Long.parseLong(System.getenv().get("DELAY_MS")));
}
log.info("Consumer read end of stream value: " + numRead);
synchronized (consumerReadyMonitor) {
consumerReadyMonitor.notifyAll();
}
}
} catch (Exception e) {
synchronized (consumerReadyMonitor) {
consumerReadyMonitor.notifyAll();
}
e.printStackTrace();
}
}
private byte verifyData(byte[] buf, int numRead, byte expectedNext) {
for (int i = 0; i < numRead; ++i) {
if (buf[i] != expectedNext) {
fail("Expected buf[" + i + "]=" + buf[i] + " to be " + expectedNext);
}
if (++expectedNext == 0) {
expectedNext = 1;
}
}
return expectedNext;
}
}
private static class ProducerThread extends Thread {
private ProducerThread() {
super("Producer");
}
@Override
public void run() {
try (Socket clientSocket = new Socket("127.0.0.1", 8888);
OutputStream writer = clientSocket.getOutputStream()) {
byte[] buf = getData();
assertEquals(buf[0], 1);
assertEquals(buf[buf.length - 1], -1);
for (int i = 0; i < 1000; ++i) {
writer.write(buf);
if (Thread.interrupted()) {
log.info("Consumer thread interrupted");
return;
}
log.info(String.format("Wrote %d characters; values from %d to %d", buf.length, buf[0], buf[buf.length - 1]));
}
writer.write(0); // end of stream value
log.info("Producer finished sending data");
} catch (Exception e) {
e.printStackTrace();
}
}
private byte[] getData() {
byte[] buf = new byte[255 * 4 * 1000];
byte nextValue = 1;
for (int i = 0; i < buf.length; ++i) {
buf[i] = nextValue++;
// reserve 0 for end of stream
if (nextValue == 0) {
nextValue = 1;
}
}
return buf;
}
}
}

View File

@@ -63,11 +63,6 @@ public class OpenSSHKnownHostsTest {
OpenSSHKnownHosts ohk = new OpenSSHKnownHosts(knownHosts);
assertTrue(ohk.verify("192.168.1.61", 22, k));
assertFalse(ohk.verify("192.168.1.2", 22, k));
ohk.write();
for (OpenSSHKnownHosts.KnownHostEntry entry : ohk.entries()) {
assertEquals("|1|F1E1KeoE/eEWhi10WpGv4OdiO6Y=|3988QV0VE8wmZL7suNrYQLITLCg= ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAQEA6P9Hlwdahh250jGZYKg2snRq2j2lFJVdKSHyxqbJiVy9VX9gTkN3K2MD48qyrYLYOyGs3vTttyUk+cK++JMzURWsrP4piby7LpeOT+3Iq8CQNj4gXZdcH9w15Vuk2qS11at6IsQPVHpKD9HGg9//EFUccI/4w06k4XXLm/IxOGUwj6I2AeWmEOL3aDi+fe07TTosSdLUD6INtR0cyKsg0zC7Da24ixoShT8Oy3x2MpR7CY3PQ1pUVmvPkr79VeA+4qV9F1JM09WdboAMZgWQZ+XrbtuBlGsyhpUHSCQOya+kOJ+bYryS+U7A+6nmTW3C9FX4FgFqTF89UHOC7V0zZQ==",
entry.getLine());
}
}
@Test

View File

@@ -1,221 +0,0 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.common;
import static org.junit.jupiter.api.Assertions.*;
import net.schmizz.sshj.common.CircularBuffer.CircularBufferException;
import net.schmizz.sshj.common.CircularBuffer.PlainCircularBuffer;
import org.junit.jupiter.api.Test;
public class CircularBufferTest {
@Test
public void shouldStoreDataCorrectlyWithoutResizing() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(256, Integer.MAX_VALUE);
byte[] dataToWrite = getData(500);
buffer.putRawBytes(dataToWrite, 0, 100);
buffer.putRawBytes(dataToWrite, 100, 100);
byte[] dataToRead = new byte[500];
buffer.readRawBytes(dataToRead, 0, 80);
buffer.readRawBytes(dataToRead, 80, 80);
buffer.putRawBytes(dataToWrite, 200, 100);
buffer.readRawBytes(dataToRead, 160, 80);
buffer.putRawBytes(dataToWrite, 300, 100);
buffer.readRawBytes(dataToRead, 240, 80);
buffer.putRawBytes(dataToWrite, 400, 100);
buffer.readRawBytes(dataToRead, 320, 80);
buffer.readRawBytes(dataToRead, 400, 100);
assertEquals(256, buffer.length());
assertArrayEquals(dataToWrite, dataToRead);
}
@Test
public void shouldStoreDataCorrectlyWithResizing() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
byte[] dataToWrite = getData(500);
buffer.putRawBytes(dataToWrite, 0, 100);
buffer.putRawBytes(dataToWrite, 100, 100);
byte[] dataToRead = new byte[500];
buffer.readRawBytes(dataToRead, 0, 80);
buffer.readRawBytes(dataToRead, 80, 80);
buffer.putRawBytes(dataToWrite, 200, 100);
buffer.readRawBytes(dataToRead, 160, 80);
buffer.putRawBytes(dataToWrite, 300, 100);
buffer.readRawBytes(dataToRead, 240, 80);
buffer.putRawBytes(dataToWrite, 400, 100);
buffer.readRawBytes(dataToRead, 320, 80);
buffer.readRawBytes(dataToRead, 400, 100);
assertEquals(256, buffer.length());
assertArrayEquals(dataToWrite, dataToRead);
}
@Test
public void shouldNotOverflowWhenWritingFullLengthToTheEnd() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
byte[] dataToWrite = getData(64);
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should write to the end
assertEquals(64, buffer.available());
assertEquals(64 * 2, buffer.length());
}
@Test
public void shouldNotOverflowWhenWritingFullLengthWrapsAround() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
// Move 1 byte forward.
buffer.putRawBytes(new byte[1], 0, 1);
buffer.readRawBytes(new byte[1], 0, 1);
// Force writes to wrap around.
byte[] dataToWrite = getData(64);
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should wrap around the end
assertEquals(64, buffer.available());
assertEquals(64 * 2, buffer.length());
}
@Test
public void shouldAllowWritingMaxCapacityFromZero() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
// Max capacity is always one less than the buffer size.
int maxCapacity = buffer.maxPossibleRemainingCapacity();
assertEquals(buffer.length() - 1, maxCapacity);
byte[] dataToWrite = getData(maxCapacity);
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);
assertEquals(dataToWrite.length, buffer.available());
assertEquals(64, buffer.length());
}
@Test
public void shouldAllowWritingMaxRemainingCapacity() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
final int initiallyWritten = 10;
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
// Max remaining capacity is always one less than the remaining buffer size.
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);
byte[] dataToWrite = getData(maxRemainingCapacity);
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);
assertEquals(dataToWrite.length + initiallyWritten, buffer.available());
assertEquals(64, buffer.length());
}
@Test
public void shouldAllowWritingMaxRemainingCapacityAfterWrappingAround() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
// Cause the internal write pointer to wrap around and be left of the read pointer.
final int initiallyWritten = 40;
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
buffer.readRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
// Max remaining capacity is always one less than the remaining buffer size.
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);
byte[] dataToWrite = getData(maxRemainingCapacity);
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);
assertEquals(dataToWrite.length + initiallyWritten, buffer.available());
assertEquals(64, buffer.length());
}
@Test
public void shouldOverflowWhenWritingOverMaxRemainingCapacity() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
final int initiallyWritten = 10;
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
// Max remaining capacity is always one less than the remaining buffer size.
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);
byte[] dataToWrite = getData(maxRemainingCapacity + 1);
assertThrows(CircularBufferException.class, () -> buffer.putRawBytes(dataToWrite, 0, dataToWrite.length));
}
@Test
public void shouldThrowWhenReadingEmptyBuffer() {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[1], 0, 1));
}
@Test
public void shouldThrowWhenReadingMoreThanAvailable() throws CircularBufferException {
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
buffer.putRawBytes(new byte[1], 0, 1);
assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[2], 0, 2));
}
@Test
public void shouldThrowOnAboveMaximumInitialSize() {
assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(65, 64));
}
@Test
public void shouldThrowOnMaximumInitialSize() {
assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(Integer.MAX_VALUE, 64));
}
@Test
public void shouldAllowFullCapacity() throws CircularBufferException {
int maxSize = 1024;
PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize);
buffer.ensureCapacity(maxSize - 1);
assertEquals(maxSize - 1, buffer.maxPossibleRemainingCapacity());
}
@Test
public void shouldThrowOnTooLargeRequestedCapacity() {
int maxSize = 1024;
PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize);
assertThrows(CircularBufferException.class, () -> buffer.ensureCapacity(maxSize));
}
private static byte[] getData(int length) {
byte[] data = new byte[length];
byte nextValue = 0;
for (int i = 0; i < length; ++i) {
data[i] = nextValue++;
}
return data;
}
}

View File

@@ -112,7 +112,7 @@ public class KeyExchangeRepeatTest {
}
private SSHPacket getKexinitPacket() {
SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList(), false).getPacket();
SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList()).getPacket();
kexinitPacket.rpos(kexinitPacket.rpos() + 1);
return kexinitPacket;
}

View File

@@ -1,236 +0,0 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.transport;
import java.math.BigInteger;
import java.util.Collections;
import java.util.List;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.common.DisconnectReason;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.transport.kex.KeyExchange;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
class KeyExchangerStrictKeyExchangeTest {
private TransportImpl transport;
private DefaultConfig config;
private KeyExchanger keyExchanger;
@BeforeEach
void setUp() throws Exception {
KeyExchange kex = mock(KeyExchange.class, Mockito.RETURNS_DEEP_STUBS);
transport = mock(TransportImpl.class, Mockito.RETURNS_DEEP_STUBS);
config = new DefaultConfig() {
@Override
protected void initKeyExchangeFactories() {
setKeyExchangeFactories(Collections.singletonList(new Factory.Named<>() {
@Override
public KeyExchange create() {
return kex;
}
@Override
public String getName() {
return "mock-kex";
}
}));
}
};
when(transport.getConfig()).thenReturn(config);
when(transport.getServerID()).thenReturn("some server id");
when(transport.getClientID()).thenReturn("some client id");
when(kex.next(any(), any())).thenReturn(true);
when(kex.getH()).thenReturn(new byte[0]);
when(kex.getK()).thenReturn(BigInteger.ZERO);
when(kex.getHash().digest()).thenReturn(new byte[10]);
keyExchanger = new KeyExchanger(transport);
keyExchanger.addHostKeyVerifier(new PromiscuousVerifier());
}
@Test
void initialConditions() {
assertThat(keyExchanger.isKexDone()).isFalse();
assertThat(keyExchanger.isKexOngoing()).isFalse();
assertThat(keyExchanger.isStrictKex()).isFalse();
assertThat(keyExchanger.isInitialKex()).isTrue();
}
@Test
void startInitialKex() throws Exception {
ArgumentCaptor<SSHPacket> sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class);
when(transport.write(sshPacketCaptor.capture())).thenReturn(0L);
keyExchanger.startKex(false);
assertThat(keyExchanger.isKexDone()).isFalse();
assertThat(keyExchanger.isKexOngoing()).isTrue();
assertThat(keyExchanger.isStrictKex()).isFalse();
assertThat(keyExchanger.isInitialKex()).isTrue();
SSHPacket sshPacket = sshPacketCaptor.getValue();
List<String> kex = new Proposal(sshPacket).getKeyExchangeAlgorithms();
assertThat(kex).endsWith("kex-strict-c-v00@openssh.com");
}
@Test
void receiveKexInitWithoutServerFlag() throws Exception {
keyExchanger.startKex(false);
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false));
assertThat(keyExchanger.isKexDone()).isFalse();
assertThat(keyExchanger.isKexOngoing()).isTrue();
assertThat(keyExchanger.isStrictKex()).isFalse();
assertThat(keyExchanger.isInitialKex()).isTrue();
}
@Test
void finishNonStrictKex() throws Exception {
keyExchanger.startKex(false);
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false));
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
assertThat(keyExchanger.isKexDone()).isTrue();
assertThat(keyExchanger.isKexOngoing()).isFalse();
assertThat(keyExchanger.isStrictKex()).isFalse();
assertThat(keyExchanger.isInitialKex()).isFalse();
verify(transport.getEncoder(), never()).resetSequenceNumber();
verify(transport.getDecoder(), never()).resetSequenceNumber();
}
@Test
void receiveKexInitWithServerFlag() throws Exception {
keyExchanger.startKex(false);
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
assertThat(keyExchanger.isKexDone()).isFalse();
assertThat(keyExchanger.isKexOngoing()).isTrue();
assertThat(keyExchanger.isStrictKex()).isTrue();
assertThat(keyExchanger.isInitialKex()).isTrue();
}
@Test
void strictKexInitIsNotFirstPacket() throws Exception {
when(transport.getDecoder().getSequenceNumber()).thenReturn(1L);
keyExchanger.startKex(false);
assertThatExceptionOfType(TransportException.class).isThrownBy(
() -> keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true))
).satisfies(e -> {
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED);
assertThat(e.getMessage()).isEqualTo("SSH_MSG_KEXINIT was not first package during strict key exchange");
});
}
@Test
void finishStrictKex() throws Exception {
keyExchanger.startKex(false);
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
verify(transport.getEncoder(), never()).resetSequenceNumber();
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
verify(transport.getEncoder()).resetSequenceNumber();
verify(transport.getDecoder(), never()).resetSequenceNumber();
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
verify(transport.getDecoder()).resetSequenceNumber();
assertThat(keyExchanger.isKexDone()).isTrue();
assertThat(keyExchanger.isKexOngoing()).isFalse();
assertThat(keyExchanger.isStrictKex()).isTrue();
assertThat(keyExchanger.isInitialKex()).isFalse();
}
@Test
void noClientFlagInSecondStrictKex() throws Exception {
keyExchanger.startKex(false);
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
ArgumentCaptor<SSHPacket> sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class);
when(transport.write(sshPacketCaptor.capture())).thenReturn(0L);
when(transport.isAuthenticated()).thenReturn(true);
keyExchanger.startKex(false);
assertThat(keyExchanger.isKexDone()).isFalse();
assertThat(keyExchanger.isKexOngoing()).isTrue();
assertThat(keyExchanger.isStrictKex()).isTrue();
assertThat(keyExchanger.isInitialKex()).isFalse();
SSHPacket sshPacket = sshPacketCaptor.getValue();
List<String> kex = new Proposal(sshPacket).getKeyExchangeAlgorithms();
assertThat(kex).doesNotContain("kex-strict-c-v00@openssh.com");
}
@Test
void serverFlagIsIgnoredInSecondKex() throws Exception {
keyExchanger.startKex(false);
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(false));
keyExchanger.handle(Message.KEXDH_31, new SSHPacket(Message.KEXDH_31));
keyExchanger.handle(Message.NEWKEYS, new SSHPacket(Message.NEWKEYS));
ArgumentCaptor<SSHPacket> sshPacketCaptor = ArgumentCaptor.forClass(SSHPacket.class);
when(transport.write(sshPacketCaptor.capture())).thenReturn(0L);
when(transport.isAuthenticated()).thenReturn(true);
keyExchanger.startKex(false);
keyExchanger.handle(Message.KEXINIT, getKexInitPacket(true));
assertThat(keyExchanger.isKexDone()).isFalse();
assertThat(keyExchanger.isKexOngoing()).isTrue();
assertThat(keyExchanger.isStrictKex()).isFalse();
assertThat(keyExchanger.isInitialKex()).isFalse();
SSHPacket sshPacket = sshPacketCaptor.getValue();
List<String> kex = new Proposal(sshPacket).getKeyExchangeAlgorithms();
assertThat(kex).doesNotContain("kex-strict-c-v00@openssh.com");
}
private SSHPacket getKexInitPacket(boolean withServerFlag) {
SSHPacket kexinitPacket = new Proposal(config, Collections.emptyList(), true).getPacket();
if (withServerFlag) {
int finalWpos = kexinitPacket.wpos();
kexinitPacket.wpos(22);
kexinitPacket.putString("mock-kex,kex-strict-s-v00@openssh.com");
kexinitPacket.wpos(finalWpos);
}
kexinitPacket.rpos(kexinitPacket.rpos() + 1);
return kexinitPacket;
}
}

View File

@@ -1,120 +0,0 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.transport;
import java.lang.reflect.Field;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.common.DisconnectReason;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.EnumSource.Mode;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
class TransportImplStrictKeyExchangeTest {
private final Config config = new DefaultConfig();
private final Transport transport = new TransportImpl(config);
private final KeyExchanger kexer = mock(KeyExchanger.class);
private final Decoder decoder = mock(Decoder.class);
@BeforeEach
void setUp() throws Exception {
Field kexerField = TransportImpl.class.getDeclaredField("kexer");
kexerField.setAccessible(true);
kexerField.set(transport, kexer);
Field decoderField = TransportImpl.class.getDeclaredField("decoder");
decoderField.setAccessible(true);
decoderField.set(transport, decoder);
}
@Test
void throwExceptionOnWrapDuringInitialKex() {
when(kexer.isInitialKex()).thenReturn(true);
when(decoder.isSequenceNumberAtMax()).thenReturn(true);
assertThatExceptionOfType(TransportException.class).isThrownBy(
() -> transport.handle(Message.KEXINIT, new SSHPacket(Message.KEXINIT))
).satisfies(e -> {
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED);
assertThat(e.getMessage()).isEqualTo("Sequence number of decoder is about to wrap during initial key exchange");
});
}
@ParameterizedTest
@EnumSource(value = Message.class, mode = Mode.EXCLUDE, names = {
"DISCONNECT", "KEXINIT", "NEWKEYS", "KEXDH_INIT", "KEXDH_31", "KEX_DH_GEX_INIT", "KEX_DH_GEX_REPLY", "KEX_DH_GEX_REQUEST"
})
void forbidUnexpectedPacketsDuringStrictKeyExchange(Message message) {
when(kexer.isInitialKex()).thenReturn(true);
when(decoder.isSequenceNumberAtMax()).thenReturn(false);
when(kexer.isStrictKex()).thenReturn(true);
assertThatExceptionOfType(TransportException.class).isThrownBy(
() -> transport.handle(message, new SSHPacket(message))
).satisfies(e -> {
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.KEY_EXCHANGE_FAILED);
assertThat(e.getMessage()).isEqualTo("Unexpected packet type during initial strict key exchange");
});
}
@ParameterizedTest
@EnumSource(value = Message.class, mode = Mode.INCLUDE, names = {
"KEXINIT", "NEWKEYS", "KEXDH_INIT", "KEXDH_31", "KEX_DH_GEX_INIT", "KEX_DH_GEX_REPLY", "KEX_DH_GEX_REQUEST"
})
void expectedPacketsDuringStrictKeyExchangeAreHandled(Message message) throws Exception {
when(kexer.isInitialKex()).thenReturn(true);
when(decoder.isSequenceNumberAtMax()).thenReturn(false);
when(kexer.isStrictKex()).thenReturn(true);
SSHPacket sshPacket = new SSHPacket(message);
assertThatCode(
() -> transport.handle(message, sshPacket)
).doesNotThrowAnyException();
verify(kexer).handle(message, sshPacket);
}
@Test
void disconnectIsAllowedDuringStrictKeyExchange() {
when(kexer.isInitialKex()).thenReturn(true);
when(decoder.isSequenceNumberAtMax()).thenReturn(false);
when(kexer.isStrictKex()).thenReturn(true);
SSHPacket sshPacket = new SSHPacket();
sshPacket.putUInt32(DisconnectReason.SERVICE_NOT_AVAILABLE.toInt());
sshPacket.putString("service is down for maintenance");
assertThatExceptionOfType(TransportException.class).isThrownBy(
() -> transport.handle(Message.DISCONNECT, sshPacket)
).satisfies(e -> {
assertThat(e.getDisconnectReason()).isEqualTo(DisconnectReason.SERVICE_NOT_AVAILABLE);
assertThat(e.getMessage()).isEqualTo("service is down for maintenance");
});
}
}