mirror of
https://github.com/hierynomus/sshj.git
synced 2025-12-06 15:20:54 +03:00
* Fix for Remote port forwarding buffers can grow without limits (issue #658) * Update test classes to use JUnit 5 * Fix MB computation
This commit is contained in:
@@ -200,4 +200,8 @@ public interface Config {
|
|||||||
* See {@link #isVerifyHostKeyCertificates()}.
|
* See {@link #isVerifyHostKeyCertificates()}.
|
||||||
*/
|
*/
|
||||||
void setVerifyHostKeyCertificates(boolean value);
|
void setVerifyHostKeyCertificates(boolean value);
|
||||||
|
|
||||||
|
int getMaxCircularBufferSize();
|
||||||
|
|
||||||
|
void setMaxCircularBufferSize(int maxCircularBufferSize);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ public class ConfigImpl
|
|||||||
private boolean waitForServerIdentBeforeSendingClientIdent = false;
|
private boolean waitForServerIdentBeforeSendingClientIdent = false;
|
||||||
private LoggerFactory loggerFactory;
|
private LoggerFactory loggerFactory;
|
||||||
private boolean verifyHostKeyCertificates = true;
|
private boolean verifyHostKeyCertificates = true;
|
||||||
|
// HF-982: default to 16MB buffers.
|
||||||
|
private int maxCircularBufferSize = 16 * 1024 * 1024;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Factory.Named<Cipher>> getCipherFactories() {
|
public List<Factory.Named<Cipher>> getCipherFactories() {
|
||||||
@@ -175,6 +177,16 @@ public class ConfigImpl
|
|||||||
return loggerFactory;
|
return loggerFactory;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getMaxCircularBufferSize() {
|
||||||
|
return maxCircularBufferSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setMaxCircularBufferSize(int maxCircularBufferSize) {
|
||||||
|
this.maxCircularBufferSize = maxCircularBufferSize;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setLoggerFactory(LoggerFactory loggerFactory) {
|
public void setLoggerFactory(LoggerFactory loggerFactory) {
|
||||||
this.loggerFactory = loggerFactory;
|
this.loggerFactory = loggerFactory;
|
||||||
|
|||||||
194
src/main/java/net/schmizz/sshj/common/CircularBuffer.java
Normal file
194
src/main/java/net/schmizz/sshj/common/CircularBuffer.java
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (C)2009 - SSHJ Contributors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package net.schmizz.sshj.common;
|
||||||
|
|
||||||
|
public class CircularBuffer<T extends CircularBuffer<T>> {
|
||||||
|
|
||||||
|
public static class CircularBufferException
|
||||||
|
extends SSHException {
|
||||||
|
|
||||||
|
public CircularBufferException(String message) {
|
||||||
|
super(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static final class PlainCircularBuffer
|
||||||
|
extends CircularBuffer<PlainCircularBuffer> {
|
||||||
|
|
||||||
|
public PlainCircularBuffer(int size, int maxSize) {
|
||||||
|
super(size, maxSize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maximum size of the internal array (one plus the maximum capacity of the buffer).
|
||||||
|
*/
|
||||||
|
private final int maxSize;
|
||||||
|
/**
|
||||||
|
* Internal array for the data. All bytes minus one can be used to avoid empty vs full ambiguity when rpos == wpos.
|
||||||
|
*/
|
||||||
|
private byte[] data;
|
||||||
|
/**
|
||||||
|
* Next read position. Wraps around the end of the internal array. When it reaches wpos, the buffer becomes empty.
|
||||||
|
* Can take the value data.length, which is equivalent to 0.
|
||||||
|
*/
|
||||||
|
private int rpos;
|
||||||
|
/**
|
||||||
|
* Next write position. Wraps around the end of the internal array. If it is equal to rpos, then the buffer is
|
||||||
|
* empty; the code does not allow wpos to reach rpos from the left. This implies that the buffer can store up to
|
||||||
|
* data.length - 1 bytes. Can take the value data.length, which is equivalent to 0.
|
||||||
|
*/
|
||||||
|
private int wpos;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Determines the size to which to grow the internal array.
|
||||||
|
*/
|
||||||
|
private int getNextSize(int currentSize) {
|
||||||
|
// Use next power of 2.
|
||||||
|
int nextSize = 1;
|
||||||
|
while (nextSize < currentSize) {
|
||||||
|
nextSize <<= 1;
|
||||||
|
if (nextSize <= 0) {
|
||||||
|
return maxSize;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Math.min(nextSize, maxSize); // limit to max size
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new circular buffer of the given size. The capacity of the buffer is one less than the size/
|
||||||
|
*/
|
||||||
|
public CircularBuffer(int size, int maxSize) {
|
||||||
|
this.maxSize = maxSize;
|
||||||
|
if (size > maxSize) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
String.format("Initial requested size %d larger than maximum size %d", size, maxSize));
|
||||||
|
}
|
||||||
|
int initialSize = getNextSize(size);
|
||||||
|
this.data = new byte[initialSize];
|
||||||
|
this.rpos = 0;
|
||||||
|
this.wpos = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Data available in the buffer for reading.
|
||||||
|
*/
|
||||||
|
public int available() {
|
||||||
|
int available = wpos - rpos;
|
||||||
|
return available >= 0 ? available : available + data.length; // adjust if wpos is left of rpos
|
||||||
|
}
|
||||||
|
|
||||||
|
private void ensureAvailable(int a)
|
||||||
|
throws CircularBufferException {
|
||||||
|
if (available() < a) {
|
||||||
|
throw new CircularBufferException("Underflow");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns how many more bytes this buffer can receive.
|
||||||
|
*/
|
||||||
|
public int maxPossibleRemainingCapacity() {
|
||||||
|
// Remaining capacity is one less than remaining space to ensure that wpos does not reach rpos from the left.
|
||||||
|
int remaining = rpos - wpos - 1;
|
||||||
|
if (remaining < 0) {
|
||||||
|
remaining += data.length; // adjust if rpos is left of wpos
|
||||||
|
}
|
||||||
|
// Add the maximum amount the internal array can grow.
|
||||||
|
return remaining + maxSize - data.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If the internal array does not have room for "capacity" more bytes, resizes the array to make that room.
|
||||||
|
*/
|
||||||
|
void ensureCapacity(int capacity) throws CircularBufferException {
|
||||||
|
int available = available();
|
||||||
|
int remaining = data.length - available;
|
||||||
|
// If capacity fits exactly in the remaining space, expand it; otherwise, wpos would reach rpos from the left.
|
||||||
|
if (remaining <= capacity) {
|
||||||
|
int neededSize = available + capacity + 1;
|
||||||
|
int nextSize = getNextSize(neededSize);
|
||||||
|
if (nextSize < neededSize) {
|
||||||
|
throw new CircularBufferException("Attempted overflow");
|
||||||
|
}
|
||||||
|
byte[] tmp = new byte[nextSize];
|
||||||
|
// Copy data to the beginning of the new array.
|
||||||
|
if (wpos >= rpos) {
|
||||||
|
System.arraycopy(data, rpos, tmp, 0, available);
|
||||||
|
wpos -= rpos; // wpos must be relative to the new rpos, which will be 0
|
||||||
|
} else {
|
||||||
|
int tail = data.length - rpos;
|
||||||
|
System.arraycopy(data, rpos, tmp, 0, tail); // segment right of rpos
|
||||||
|
System.arraycopy(data, 0, tmp, tail, wpos); // segment left of wpos
|
||||||
|
wpos += tail; // wpos must be relative to the new rpos, which will be 0
|
||||||
|
}
|
||||||
|
rpos = 0;
|
||||||
|
data = tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reads data from this buffer into the provided array.
|
||||||
|
*/
|
||||||
|
public void readRawBytes(byte[] destination, int offset, int length) throws CircularBufferException {
|
||||||
|
ensureAvailable(length);
|
||||||
|
|
||||||
|
int rposNext = rpos + length;
|
||||||
|
if (rposNext <= data.length) {
|
||||||
|
System.arraycopy(data, rpos, destination, offset, length);
|
||||||
|
} else {
|
||||||
|
int tail = data.length - rpos;
|
||||||
|
System.arraycopy(data, rpos, destination, offset, tail); // segment right of rpos
|
||||||
|
rposNext = length - tail; // rpos wraps around the end of the buffer
|
||||||
|
System.arraycopy(data, 0, destination, offset + tail, rposNext); // remainder
|
||||||
|
}
|
||||||
|
// This can make rpos equal data.length, which has the same effect as wpos being 0.
|
||||||
|
rpos = rposNext;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Writes data to this buffer from the provided array.
|
||||||
|
*/
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public T putRawBytes(byte[] source, int offset, int length) throws CircularBufferException {
|
||||||
|
ensureCapacity(length);
|
||||||
|
|
||||||
|
int wposNext = wpos + length;
|
||||||
|
if (wposNext <= data.length) {
|
||||||
|
System.arraycopy(source, offset, data, wpos, length);
|
||||||
|
} else {
|
||||||
|
int tail = data.length - wpos;
|
||||||
|
System.arraycopy(source, offset, data, wpos, tail); // segment right of wpos
|
||||||
|
wposNext = length - tail; // wpos wraps around the end of the buffer
|
||||||
|
System.arraycopy(source, offset + tail, data, 0, wposNext); // remainder
|
||||||
|
}
|
||||||
|
// This can make wpos equal data.length, which has the same effect as wpos being 0.
|
||||||
|
wpos = wposNext;
|
||||||
|
|
||||||
|
return (T) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used only for testing.
|
||||||
|
int length() {
|
||||||
|
return data.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "CircularBuffer [rpos=" + rpos + ", wpos=" + wpos + ", size=" + data.length + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -164,8 +164,7 @@ public abstract class AbstractChannel
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void handle(Message msg, SSHPacket buf)
|
public void handle(Message msg, SSHPacket buf) throws SSHException {
|
||||||
throws ConnectionException, TransportException {
|
|
||||||
switch (msg) {
|
switch (msg) {
|
||||||
|
|
||||||
case CHANNEL_DATA:
|
case CHANNEL_DATA:
|
||||||
@@ -354,7 +353,7 @@ public abstract class AbstractChannel
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected void gotExtendedData(SSHPacket buf)
|
protected void gotExtendedData(SSHPacket buf)
|
||||||
throws ConnectionException, TransportException {
|
throws SSHException {
|
||||||
throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR,
|
throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR,
|
||||||
"Extended data not supported on " + type + " channel");
|
"Extended data not supported on " + type + " channel");
|
||||||
}
|
}
|
||||||
@@ -375,7 +374,7 @@ public abstract class AbstractChannel
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected void receiveInto(ChannelInputStream stream, SSHPacket buf)
|
protected void receiveInto(ChannelInputStream stream, SSHPacket buf)
|
||||||
throws ConnectionException, TransportException {
|
throws SSHException {
|
||||||
final int len;
|
final int len;
|
||||||
try {
|
try {
|
||||||
len = buf.readUInt32AsInt();
|
len = buf.readUInt32AsInt();
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ public final class ChannelInputStream
|
|||||||
private final Channel chan;
|
private final Channel chan;
|
||||||
private final Transport trans;
|
private final Transport trans;
|
||||||
private final Window.Local win;
|
private final Window.Local win;
|
||||||
private final Buffer.PlainBuffer buf;
|
private final CircularBuffer.PlainCircularBuffer buf;
|
||||||
private final byte[] b = new byte[1];
|
private final byte[] b = new byte[1];
|
||||||
|
|
||||||
private boolean eof;
|
private boolean eof;
|
||||||
@@ -46,10 +46,11 @@ public final class ChannelInputStream
|
|||||||
|
|
||||||
public ChannelInputStream(Channel chan, Transport trans, Window.Local win) {
|
public ChannelInputStream(Channel chan, Transport trans, Window.Local win) {
|
||||||
this.chan = chan;
|
this.chan = chan;
|
||||||
log = chan.getLoggerFactory().getLogger(getClass());
|
this.log = chan.getLoggerFactory().getLogger(getClass());
|
||||||
this.trans = trans;
|
this.trans = trans;
|
||||||
this.win = win;
|
this.win = win;
|
||||||
buf = new Buffer.PlainBuffer(chan.getLocalMaxPacketSize());
|
this.buf = new CircularBuffer.PlainCircularBuffer(
|
||||||
|
chan.getLocalMaxPacketSize(), trans.getConfig().getMaxCircularBufferSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -113,48 +114,44 @@ public final class ChannelInputStream
|
|||||||
len = buf.available();
|
len = buf.available();
|
||||||
}
|
}
|
||||||
buf.readRawBytes(b, off, len);
|
buf.readRawBytes(b, off, len);
|
||||||
if (buf.rpos() > win.getMaxPacketSize() && buf.available() == 0) {
|
|
||||||
buf.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!chan.getAutoExpand()) {
|
if (!chan.getAutoExpand()) {
|
||||||
checkWindow();
|
checkWindow();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return len;
|
return len;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void receive(byte[] data, int offset, int len)
|
public void receive(byte[] data, int offset, int len) throws SSHException {
|
||||||
throws ConnectionException, TransportException {
|
|
||||||
if (eof) {
|
if (eof) {
|
||||||
throw new ConnectionException("Getting data on EOF'ed stream");
|
throw new ConnectionException("Getting data on EOF'ed stream");
|
||||||
}
|
}
|
||||||
synchronized (buf) {
|
synchronized (buf) {
|
||||||
buf.putRawBytes(data, offset, len);
|
buf.putRawBytes(data, offset, len);
|
||||||
buf.notifyAll();
|
buf.notifyAll();
|
||||||
}
|
// Potential fix for #203 (window consumed below 0).
|
||||||
// Potential fix for #203 (window consumed below 0).
|
// This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST
|
||||||
// This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST
|
// And the window has not expanded yet.
|
||||||
// And the window has not expanded yet.
|
|
||||||
synchronized (win) {
|
|
||||||
win.consume(len);
|
win.consume(len);
|
||||||
}
|
if (chan.getAutoExpand()) {
|
||||||
if (chan.getAutoExpand()) {
|
checkWindow();
|
||||||
checkWindow();
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void checkWindow()
|
private void checkWindow() throws TransportException {
|
||||||
throws TransportException {
|
/*
|
||||||
synchronized (win) {
|
* Window must fit in remaining buffer capacity. We already expect win.size() amount of data to arrive. The
|
||||||
final long adjustment = win.neededAdjustment();
|
* difference between that and the remaining capacity is the maximum adjustment we can make to the window.
|
||||||
if (adjustment > 0) {
|
*/
|
||||||
log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment);
|
final long maxAdjustment = buf.maxPossibleRemainingCapacity() - win.getSize();
|
||||||
trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST)
|
final long adjustment = Math.min(win.neededAdjustment(), maxAdjustment);
|
||||||
.putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment));
|
if (adjustment > 0) {
|
||||||
win.expand(adjustment);
|
log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment);
|
||||||
}
|
trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST)
|
||||||
|
.putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment));
|
||||||
|
win.expand(adjustment);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -210,7 +210,7 @@ public class SessionChannel
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void gotExtendedData(SSHPacket buf)
|
protected void gotExtendedData(SSHPacket buf)
|
||||||
throws ConnectionException, TransportException {
|
throws SSHException {
|
||||||
try {
|
try {
|
||||||
final int dataTypeCode = buf.readUInt32AsInt();
|
final int dataTypeCode = buf.readUInt32AsInt();
|
||||||
if (dataTypeCode == 1)
|
if (dataTypeCode == 1)
|
||||||
|
|||||||
@@ -0,0 +1,188 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (C)2009 - SSHJ Contributors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package com.hierynomus.sshj.connection.channel.forwarded;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.io.OutputStream;
|
||||||
|
import java.net.InetSocketAddress;
|
||||||
|
import java.net.ServerSocket;
|
||||||
|
import java.net.Socket;
|
||||||
|
import net.schmizz.sshj.DefaultConfig;
|
||||||
|
import net.schmizz.sshj.SSHClient;
|
||||||
|
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder.Forward;
|
||||||
|
import net.schmizz.sshj.connection.channel.forwarded.SocketForwardingConnectListener;
|
||||||
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
public class RemotePFPerformanceTest {
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(RemotePFPerformanceTest.class);
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Disabled
|
||||||
|
public void startPF() throws IOException, InterruptedException {
|
||||||
|
DefaultConfig config = new DefaultConfig();
|
||||||
|
config.setMaxCircularBufferSize(16 * 1024 * 1024);
|
||||||
|
SSHClient client = new SSHClient(config);
|
||||||
|
client.loadKnownHosts();
|
||||||
|
client.addHostKeyVerifier("5c:0c:8e:9d:1c:50:a9:ba:a7:05:f6:b1:2b:0b:5f:ba");
|
||||||
|
|
||||||
|
client.getConnection().getKeepAlive().setKeepAliveInterval(5);
|
||||||
|
client.connect("localhost");
|
||||||
|
client.getConnection().getKeepAlive().setKeepAliveInterval(5);
|
||||||
|
|
||||||
|
Object consumerReadyMonitor = new Object();
|
||||||
|
ConsumerThread consumerThread = new ConsumerThread(consumerReadyMonitor);
|
||||||
|
ProducerThread producerThread = new ProducerThread();
|
||||||
|
try {
|
||||||
|
|
||||||
|
client.authPassword(System.getenv().get("USERNAME"), System.getenv().get("PASSWORD"));
|
||||||
|
|
||||||
|
/*
|
||||||
|
* We make _server_ listen on port 8080, which forwards all connections to us as a channel, and we further
|
||||||
|
* forward all such channels to google.com:80
|
||||||
|
*/
|
||||||
|
client.getRemotePortForwarder().bind(
|
||||||
|
// where the server should listen
|
||||||
|
new Forward(8888),
|
||||||
|
// what we do with incoming connections that are forwarded to us
|
||||||
|
new SocketForwardingConnectListener(new InetSocketAddress("localhost", 12345)));
|
||||||
|
|
||||||
|
consumerThread.start();
|
||||||
|
synchronized (consumerReadyMonitor) {
|
||||||
|
consumerReadyMonitor.wait();
|
||||||
|
}
|
||||||
|
producerThread.start();
|
||||||
|
|
||||||
|
// Wait for consumer to finish receiving data.
|
||||||
|
synchronized (consumerReadyMonitor) {
|
||||||
|
consumerReadyMonitor.wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
} finally {
|
||||||
|
producerThread.interrupt();
|
||||||
|
consumerThread.interrupt();
|
||||||
|
client.disconnect();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class ConsumerThread extends Thread {
|
||||||
|
private final Object consumerReadyMonitor;
|
||||||
|
|
||||||
|
private ConsumerThread(Object consumerReadyMonitor) {
|
||||||
|
super("Consumer");
|
||||||
|
this.consumerReadyMonitor = consumerReadyMonitor;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
try (ServerSocket serverSocket = new ServerSocket(12345)) {
|
||||||
|
synchronized (consumerReadyMonitor) {
|
||||||
|
consumerReadyMonitor.notifyAll();
|
||||||
|
}
|
||||||
|
try (Socket acceptedSocket = serverSocket.accept()) {
|
||||||
|
InputStream in = acceptedSocket.getInputStream();
|
||||||
|
int numRead;
|
||||||
|
byte[] buf = new byte[40000];
|
||||||
|
//byte[] buf = new byte[255 * 4 * 1000];
|
||||||
|
byte expectedNext = 1;
|
||||||
|
while ((numRead = in.read(buf)) != 0) {
|
||||||
|
if (Thread.interrupted()) {
|
||||||
|
log.info("Consumer thread interrupted");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
log.info(String.format("Read %d characters; values from %d to %d", numRead, buf[0], buf[numRead - 1]));
|
||||||
|
if (buf[numRead - 1] == 0) {
|
||||||
|
verifyData(buf, numRead - 1, expectedNext);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
expectedNext = verifyData(buf, numRead, expectedNext);
|
||||||
|
// Slow down consumer to test buffering.
|
||||||
|
Thread.sleep(Long.parseLong(System.getenv().get("DELAY_MS")));
|
||||||
|
}
|
||||||
|
log.info("Consumer read end of stream value: " + numRead);
|
||||||
|
synchronized (consumerReadyMonitor) {
|
||||||
|
consumerReadyMonitor.notifyAll();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
synchronized (consumerReadyMonitor) {
|
||||||
|
consumerReadyMonitor.notifyAll();
|
||||||
|
}
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private byte verifyData(byte[] buf, int numRead, byte expectedNext) {
|
||||||
|
for (int i = 0; i < numRead; ++i) {
|
||||||
|
if (buf[i] != expectedNext) {
|
||||||
|
fail("Expected buf[" + i + "]=" + buf[i] + " to be " + expectedNext);
|
||||||
|
}
|
||||||
|
if (++expectedNext == 0) {
|
||||||
|
expectedNext = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return expectedNext;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class ProducerThread extends Thread {
|
||||||
|
private ProducerThread() {
|
||||||
|
super("Producer");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
try (Socket clientSocket = new Socket("127.0.0.1", 8888);
|
||||||
|
OutputStream writer = clientSocket.getOutputStream()) {
|
||||||
|
byte[] buf = getData();
|
||||||
|
assertEquals(buf[0], 1);
|
||||||
|
assertEquals(buf[buf.length - 1], -1);
|
||||||
|
for (int i = 0; i < 1000; ++i) {
|
||||||
|
writer.write(buf);
|
||||||
|
if (Thread.interrupted()) {
|
||||||
|
log.info("Consumer thread interrupted");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
log.info(String.format("Wrote %d characters; values from %d to %d", buf.length, buf[0], buf[buf.length - 1]));
|
||||||
|
}
|
||||||
|
writer.write(0); // end of stream value
|
||||||
|
log.info("Producer finished sending data");
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private byte[] getData() {
|
||||||
|
byte[] buf = new byte[255 * 4 * 1000];
|
||||||
|
byte nextValue = 1;
|
||||||
|
for (int i = 0; i < buf.length; ++i) {
|
||||||
|
buf[i] = nextValue++;
|
||||||
|
// reserve 0 for end of stream
|
||||||
|
if (nextValue == 0) {
|
||||||
|
nextValue = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
221
src/test/java/net/schmizz/sshj/common/CircularBufferTest.java
Normal file
221
src/test/java/net/schmizz/sshj/common/CircularBufferTest.java
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (C)2009 - SSHJ Contributors
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package net.schmizz.sshj.common;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
|
import net.schmizz.sshj.common.CircularBuffer.CircularBufferException;
|
||||||
|
import net.schmizz.sshj.common.CircularBuffer.PlainCircularBuffer;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
public class CircularBufferTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldStoreDataCorrectlyWithoutResizing() throws CircularBufferException {
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(256, Integer.MAX_VALUE);
|
||||||
|
|
||||||
|
byte[] dataToWrite = getData(500);
|
||||||
|
buffer.putRawBytes(dataToWrite, 0, 100);
|
||||||
|
buffer.putRawBytes(dataToWrite, 100, 100);
|
||||||
|
|
||||||
|
byte[] dataToRead = new byte[500];
|
||||||
|
buffer.readRawBytes(dataToRead, 0, 80);
|
||||||
|
buffer.readRawBytes(dataToRead, 80, 80);
|
||||||
|
|
||||||
|
buffer.putRawBytes(dataToWrite, 200, 100);
|
||||||
|
buffer.readRawBytes(dataToRead, 160, 80);
|
||||||
|
|
||||||
|
buffer.putRawBytes(dataToWrite, 300, 100);
|
||||||
|
buffer.readRawBytes(dataToRead, 240, 80);
|
||||||
|
|
||||||
|
buffer.putRawBytes(dataToWrite, 400, 100);
|
||||||
|
buffer.readRawBytes(dataToRead, 320, 80);
|
||||||
|
buffer.readRawBytes(dataToRead, 400, 100);
|
||||||
|
|
||||||
|
assertEquals(256, buffer.length());
|
||||||
|
assertArrayEquals(dataToWrite, dataToRead);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldStoreDataCorrectlyWithResizing() throws CircularBufferException {
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
|
||||||
|
|
||||||
|
byte[] dataToWrite = getData(500);
|
||||||
|
buffer.putRawBytes(dataToWrite, 0, 100);
|
||||||
|
buffer.putRawBytes(dataToWrite, 100, 100);
|
||||||
|
|
||||||
|
byte[] dataToRead = new byte[500];
|
||||||
|
buffer.readRawBytes(dataToRead, 0, 80);
|
||||||
|
buffer.readRawBytes(dataToRead, 80, 80);
|
||||||
|
|
||||||
|
buffer.putRawBytes(dataToWrite, 200, 100);
|
||||||
|
buffer.readRawBytes(dataToRead, 160, 80);
|
||||||
|
|
||||||
|
buffer.putRawBytes(dataToWrite, 300, 100);
|
||||||
|
buffer.readRawBytes(dataToRead, 240, 80);
|
||||||
|
|
||||||
|
buffer.putRawBytes(dataToWrite, 400, 100);
|
||||||
|
buffer.readRawBytes(dataToRead, 320, 80);
|
||||||
|
|
||||||
|
buffer.readRawBytes(dataToRead, 400, 100);
|
||||||
|
|
||||||
|
assertEquals(256, buffer.length());
|
||||||
|
assertArrayEquals(dataToWrite, dataToRead);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldNotOverflowWhenWritingFullLengthToTheEnd() throws CircularBufferException {
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
|
||||||
|
|
||||||
|
byte[] dataToWrite = getData(64);
|
||||||
|
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should write to the end
|
||||||
|
|
||||||
|
assertEquals(64, buffer.available());
|
||||||
|
assertEquals(64 * 2, buffer.length());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldNotOverflowWhenWritingFullLengthWrapsAround() throws CircularBufferException {
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
|
||||||
|
|
||||||
|
// Move 1 byte forward.
|
||||||
|
buffer.putRawBytes(new byte[1], 0, 1);
|
||||||
|
buffer.readRawBytes(new byte[1], 0, 1);
|
||||||
|
|
||||||
|
// Force writes to wrap around.
|
||||||
|
byte[] dataToWrite = getData(64);
|
||||||
|
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should wrap around the end
|
||||||
|
|
||||||
|
assertEquals(64, buffer.available());
|
||||||
|
assertEquals(64 * 2, buffer.length());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldAllowWritingMaxCapacityFromZero() throws CircularBufferException {
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
|
||||||
|
|
||||||
|
// Max capacity is always one less than the buffer size.
|
||||||
|
int maxCapacity = buffer.maxPossibleRemainingCapacity();
|
||||||
|
assertEquals(buffer.length() - 1, maxCapacity);
|
||||||
|
|
||||||
|
byte[] dataToWrite = getData(maxCapacity);
|
||||||
|
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);
|
||||||
|
|
||||||
|
assertEquals(dataToWrite.length, buffer.available());
|
||||||
|
assertEquals(64, buffer.length());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldAllowWritingMaxRemainingCapacity() throws CircularBufferException {
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
|
||||||
|
|
||||||
|
final int initiallyWritten = 10;
|
||||||
|
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
|
||||||
|
|
||||||
|
// Max remaining capacity is always one less than the remaining buffer size.
|
||||||
|
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
|
||||||
|
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);
|
||||||
|
|
||||||
|
byte[] dataToWrite = getData(maxRemainingCapacity);
|
||||||
|
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);
|
||||||
|
|
||||||
|
assertEquals(dataToWrite.length + initiallyWritten, buffer.available());
|
||||||
|
assertEquals(64, buffer.length());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldAllowWritingMaxRemainingCapacityAfterWrappingAround() throws CircularBufferException {
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
|
||||||
|
|
||||||
|
// Cause the internal write pointer to wrap around and be left of the read pointer.
|
||||||
|
final int initiallyWritten = 40;
|
||||||
|
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
|
||||||
|
buffer.readRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
|
||||||
|
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
|
||||||
|
|
||||||
|
// Max remaining capacity is always one less than the remaining buffer size.
|
||||||
|
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
|
||||||
|
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);
|
||||||
|
|
||||||
|
byte[] dataToWrite = getData(maxRemainingCapacity);
|
||||||
|
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);
|
||||||
|
|
||||||
|
assertEquals(dataToWrite.length + initiallyWritten, buffer.available());
|
||||||
|
assertEquals(64, buffer.length());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldOverflowWhenWritingOverMaxRemainingCapacity() throws CircularBufferException {
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
|
||||||
|
|
||||||
|
final int initiallyWritten = 10;
|
||||||
|
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
|
||||||
|
|
||||||
|
// Max remaining capacity is always one less than the remaining buffer size.
|
||||||
|
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
|
||||||
|
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);
|
||||||
|
|
||||||
|
byte[] dataToWrite = getData(maxRemainingCapacity + 1);
|
||||||
|
assertThrows(CircularBufferException.class, () -> buffer.putRawBytes(dataToWrite, 0, dataToWrite.length));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldThrowWhenReadingEmptyBuffer() {
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
|
||||||
|
assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[1], 0, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldThrowWhenReadingMoreThanAvailable() throws CircularBufferException {
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
|
||||||
|
buffer.putRawBytes(new byte[1], 0, 1);
|
||||||
|
assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[2], 0, 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldThrowOnAboveMaximumInitialSize() {
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(65, 64));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldThrowOnMaximumInitialSize() {
|
||||||
|
assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(Integer.MAX_VALUE, 64));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldAllowFullCapacity() throws CircularBufferException {
|
||||||
|
int maxSize = 1024;
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize);
|
||||||
|
buffer.ensureCapacity(maxSize - 1);
|
||||||
|
assertEquals(maxSize - 1, buffer.maxPossibleRemainingCapacity());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void shouldThrowOnTooLargeRequestedCapacity() {
|
||||||
|
int maxSize = 1024;
|
||||||
|
PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize);
|
||||||
|
assertThrows(CircularBufferException.class, () -> buffer.ensureCapacity(maxSize));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static byte[] getData(int length) {
|
||||||
|
byte[] data = new byte[length];
|
||||||
|
byte nextValue = 0;
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
data[i] = nextValue++;
|
||||||
|
}
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user