mirror of
https://github.com/hierynomus/sshj.git
synced 2025-12-06 15:20:54 +03:00
* Fix for Remote port forwarding buffers can grow without limits (issue #658) * Update test classes to use JUnit 5 * Fix MB computation
This commit is contained in:
@@ -200,4 +200,8 @@ public interface Config {
|
||||
* See {@link #isVerifyHostKeyCertificates()}.
|
||||
*/
|
||||
void setVerifyHostKeyCertificates(boolean value);
|
||||
|
||||
int getMaxCircularBufferSize();
|
||||
|
||||
void setMaxCircularBufferSize(int maxCircularBufferSize);
|
||||
}
|
||||
|
||||
@@ -49,6 +49,8 @@ public class ConfigImpl
|
||||
private boolean waitForServerIdentBeforeSendingClientIdent = false;
|
||||
private LoggerFactory loggerFactory;
|
||||
private boolean verifyHostKeyCertificates = true;
|
||||
// HF-982: default to 16MB buffers.
|
||||
private int maxCircularBufferSize = 16 * 1024 * 1024;
|
||||
|
||||
@Override
|
||||
public List<Factory.Named<Cipher>> getCipherFactories() {
|
||||
@@ -175,6 +177,16 @@ public class ConfigImpl
|
||||
return loggerFactory;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxCircularBufferSize() {
|
||||
return maxCircularBufferSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setMaxCircularBufferSize(int maxCircularBufferSize) {
|
||||
this.maxCircularBufferSize = maxCircularBufferSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLoggerFactory(LoggerFactory loggerFactory) {
|
||||
this.loggerFactory = loggerFactory;
|
||||
|
||||
194
src/main/java/net/schmizz/sshj/common/CircularBuffer.java
Normal file
194
src/main/java/net/schmizz/sshj/common/CircularBuffer.java
Normal file
@@ -0,0 +1,194 @@
|
||||
/*
|
||||
* Copyright (C)2009 - SSHJ Contributors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package net.schmizz.sshj.common;
|
||||
|
||||
public class CircularBuffer<T extends CircularBuffer<T>> {
|
||||
|
||||
public static class CircularBufferException
|
||||
extends SSHException {
|
||||
|
||||
public CircularBufferException(String message) {
|
||||
super(message);
|
||||
}
|
||||
}
|
||||
|
||||
public static final class PlainCircularBuffer
|
||||
extends CircularBuffer<PlainCircularBuffer> {
|
||||
|
||||
public PlainCircularBuffer(int size, int maxSize) {
|
||||
super(size, maxSize);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Maximum size of the internal array (one plus the maximum capacity of the buffer).
|
||||
*/
|
||||
private final int maxSize;
|
||||
/**
|
||||
* Internal array for the data. All bytes minus one can be used to avoid empty vs full ambiguity when rpos == wpos.
|
||||
*/
|
||||
private byte[] data;
|
||||
/**
|
||||
* Next read position. Wraps around the end of the internal array. When it reaches wpos, the buffer becomes empty.
|
||||
* Can take the value data.length, which is equivalent to 0.
|
||||
*/
|
||||
private int rpos;
|
||||
/**
|
||||
* Next write position. Wraps around the end of the internal array. If it is equal to rpos, then the buffer is
|
||||
* empty; the code does not allow wpos to reach rpos from the left. This implies that the buffer can store up to
|
||||
* data.length - 1 bytes. Can take the value data.length, which is equivalent to 0.
|
||||
*/
|
||||
private int wpos;
|
||||
|
||||
/**
|
||||
* Determines the size to which to grow the internal array.
|
||||
*/
|
||||
private int getNextSize(int currentSize) {
|
||||
// Use next power of 2.
|
||||
int nextSize = 1;
|
||||
while (nextSize < currentSize) {
|
||||
nextSize <<= 1;
|
||||
if (nextSize <= 0) {
|
||||
return maxSize;
|
||||
}
|
||||
}
|
||||
return Math.min(nextSize, maxSize); // limit to max size
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new circular buffer of the given size. The capacity of the buffer is one less than the size/
|
||||
*/
|
||||
public CircularBuffer(int size, int maxSize) {
|
||||
this.maxSize = maxSize;
|
||||
if (size > maxSize) {
|
||||
throw new IllegalArgumentException(
|
||||
String.format("Initial requested size %d larger than maximum size %d", size, maxSize));
|
||||
}
|
||||
int initialSize = getNextSize(size);
|
||||
this.data = new byte[initialSize];
|
||||
this.rpos = 0;
|
||||
this.wpos = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Data available in the buffer for reading.
|
||||
*/
|
||||
public int available() {
|
||||
int available = wpos - rpos;
|
||||
return available >= 0 ? available : available + data.length; // adjust if wpos is left of rpos
|
||||
}
|
||||
|
||||
private void ensureAvailable(int a)
|
||||
throws CircularBufferException {
|
||||
if (available() < a) {
|
||||
throw new CircularBufferException("Underflow");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns how many more bytes this buffer can receive.
|
||||
*/
|
||||
public int maxPossibleRemainingCapacity() {
|
||||
// Remaining capacity is one less than remaining space to ensure that wpos does not reach rpos from the left.
|
||||
int remaining = rpos - wpos - 1;
|
||||
if (remaining < 0) {
|
||||
remaining += data.length; // adjust if rpos is left of wpos
|
||||
}
|
||||
// Add the maximum amount the internal array can grow.
|
||||
return remaining + maxSize - data.length;
|
||||
}
|
||||
|
||||
/**
|
||||
* If the internal array does not have room for "capacity" more bytes, resizes the array to make that room.
|
||||
*/
|
||||
void ensureCapacity(int capacity) throws CircularBufferException {
|
||||
int available = available();
|
||||
int remaining = data.length - available;
|
||||
// If capacity fits exactly in the remaining space, expand it; otherwise, wpos would reach rpos from the left.
|
||||
if (remaining <= capacity) {
|
||||
int neededSize = available + capacity + 1;
|
||||
int nextSize = getNextSize(neededSize);
|
||||
if (nextSize < neededSize) {
|
||||
throw new CircularBufferException("Attempted overflow");
|
||||
}
|
||||
byte[] tmp = new byte[nextSize];
|
||||
// Copy data to the beginning of the new array.
|
||||
if (wpos >= rpos) {
|
||||
System.arraycopy(data, rpos, tmp, 0, available);
|
||||
wpos -= rpos; // wpos must be relative to the new rpos, which will be 0
|
||||
} else {
|
||||
int tail = data.length - rpos;
|
||||
System.arraycopy(data, rpos, tmp, 0, tail); // segment right of rpos
|
||||
System.arraycopy(data, 0, tmp, tail, wpos); // segment left of wpos
|
||||
wpos += tail; // wpos must be relative to the new rpos, which will be 0
|
||||
}
|
||||
rpos = 0;
|
||||
data = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads data from this buffer into the provided array.
|
||||
*/
|
||||
public void readRawBytes(byte[] destination, int offset, int length) throws CircularBufferException {
|
||||
ensureAvailable(length);
|
||||
|
||||
int rposNext = rpos + length;
|
||||
if (rposNext <= data.length) {
|
||||
System.arraycopy(data, rpos, destination, offset, length);
|
||||
} else {
|
||||
int tail = data.length - rpos;
|
||||
System.arraycopy(data, rpos, destination, offset, tail); // segment right of rpos
|
||||
rposNext = length - tail; // rpos wraps around the end of the buffer
|
||||
System.arraycopy(data, 0, destination, offset + tail, rposNext); // remainder
|
||||
}
|
||||
// This can make rpos equal data.length, which has the same effect as wpos being 0.
|
||||
rpos = rposNext;
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes data to this buffer from the provided array.
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public T putRawBytes(byte[] source, int offset, int length) throws CircularBufferException {
|
||||
ensureCapacity(length);
|
||||
|
||||
int wposNext = wpos + length;
|
||||
if (wposNext <= data.length) {
|
||||
System.arraycopy(source, offset, data, wpos, length);
|
||||
} else {
|
||||
int tail = data.length - wpos;
|
||||
System.arraycopy(source, offset, data, wpos, tail); // segment right of wpos
|
||||
wposNext = length - tail; // wpos wraps around the end of the buffer
|
||||
System.arraycopy(source, offset + tail, data, 0, wposNext); // remainder
|
||||
}
|
||||
// This can make wpos equal data.length, which has the same effect as wpos being 0.
|
||||
wpos = wposNext;
|
||||
|
||||
return (T) this;
|
||||
}
|
||||
|
||||
// Used only for testing.
|
||||
int length() {
|
||||
return data.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "CircularBuffer [rpos=" + rpos + ", wpos=" + wpos + ", size=" + data.length + "]";
|
||||
}
|
||||
|
||||
}
|
||||
@@ -164,8 +164,7 @@ public abstract class AbstractChannel
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handle(Message msg, SSHPacket buf)
|
||||
throws ConnectionException, TransportException {
|
||||
public void handle(Message msg, SSHPacket buf) throws SSHException {
|
||||
switch (msg) {
|
||||
|
||||
case CHANNEL_DATA:
|
||||
@@ -354,7 +353,7 @@ public abstract class AbstractChannel
|
||||
}
|
||||
|
||||
protected void gotExtendedData(SSHPacket buf)
|
||||
throws ConnectionException, TransportException {
|
||||
throws SSHException {
|
||||
throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR,
|
||||
"Extended data not supported on " + type + " channel");
|
||||
}
|
||||
@@ -375,7 +374,7 @@ public abstract class AbstractChannel
|
||||
}
|
||||
|
||||
protected void receiveInto(ChannelInputStream stream, SSHPacket buf)
|
||||
throws ConnectionException, TransportException {
|
||||
throws SSHException {
|
||||
final int len;
|
||||
try {
|
||||
len = buf.readUInt32AsInt();
|
||||
|
||||
@@ -38,7 +38,7 @@ public final class ChannelInputStream
|
||||
private final Channel chan;
|
||||
private final Transport trans;
|
||||
private final Window.Local win;
|
||||
private final Buffer.PlainBuffer buf;
|
||||
private final CircularBuffer.PlainCircularBuffer buf;
|
||||
private final byte[] b = new byte[1];
|
||||
|
||||
private boolean eof;
|
||||
@@ -46,10 +46,11 @@ public final class ChannelInputStream
|
||||
|
||||
public ChannelInputStream(Channel chan, Transport trans, Window.Local win) {
|
||||
this.chan = chan;
|
||||
log = chan.getLoggerFactory().getLogger(getClass());
|
||||
this.log = chan.getLoggerFactory().getLogger(getClass());
|
||||
this.trans = trans;
|
||||
this.win = win;
|
||||
buf = new Buffer.PlainBuffer(chan.getLocalMaxPacketSize());
|
||||
this.buf = new CircularBuffer.PlainCircularBuffer(
|
||||
chan.getLocalMaxPacketSize(), trans.getConfig().getMaxCircularBufferSize());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -113,48 +114,44 @@ public final class ChannelInputStream
|
||||
len = buf.available();
|
||||
}
|
||||
buf.readRawBytes(b, off, len);
|
||||
if (buf.rpos() > win.getMaxPacketSize() && buf.available() == 0) {
|
||||
buf.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (!chan.getAutoExpand()) {
|
||||
checkWindow();
|
||||
if (!chan.getAutoExpand()) {
|
||||
checkWindow();
|
||||
}
|
||||
}
|
||||
|
||||
return len;
|
||||
}
|
||||
|
||||
public void receive(byte[] data, int offset, int len)
|
||||
throws ConnectionException, TransportException {
|
||||
public void receive(byte[] data, int offset, int len) throws SSHException {
|
||||
if (eof) {
|
||||
throw new ConnectionException("Getting data on EOF'ed stream");
|
||||
}
|
||||
synchronized (buf) {
|
||||
buf.putRawBytes(data, offset, len);
|
||||
buf.notifyAll();
|
||||
}
|
||||
// Potential fix for #203 (window consumed below 0).
|
||||
// This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST
|
||||
// And the window has not expanded yet.
|
||||
synchronized (win) {
|
||||
// Potential fix for #203 (window consumed below 0).
|
||||
// This seems to be a race condition if we receive more data, while we're already sending a SSH_MSG_CHANNEL_WINDOW_ADJUST
|
||||
// And the window has not expanded yet.
|
||||
win.consume(len);
|
||||
}
|
||||
if (chan.getAutoExpand()) {
|
||||
checkWindow();
|
||||
if (chan.getAutoExpand()) {
|
||||
checkWindow();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void checkWindow()
|
||||
throws TransportException {
|
||||
synchronized (win) {
|
||||
final long adjustment = win.neededAdjustment();
|
||||
if (adjustment > 0) {
|
||||
log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment);
|
||||
trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST)
|
||||
.putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment));
|
||||
win.expand(adjustment);
|
||||
}
|
||||
private void checkWindow() throws TransportException {
|
||||
/*
|
||||
* Window must fit in remaining buffer capacity. We already expect win.size() amount of data to arrive. The
|
||||
* difference between that and the remaining capacity is the maximum adjustment we can make to the window.
|
||||
*/
|
||||
final long maxAdjustment = buf.maxPossibleRemainingCapacity() - win.getSize();
|
||||
final long adjustment = Math.min(win.neededAdjustment(), maxAdjustment);
|
||||
if (adjustment > 0) {
|
||||
log.debug("Sending SSH_MSG_CHANNEL_WINDOW_ADJUST to #{} for {} bytes", chan.getRecipient(), adjustment);
|
||||
trans.write(new SSHPacket(Message.CHANNEL_WINDOW_ADJUST)
|
||||
.putUInt32FromInt(chan.getRecipient()).putUInt32(adjustment));
|
||||
win.expand(adjustment);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -210,7 +210,7 @@ public class SessionChannel
|
||||
|
||||
@Override
|
||||
protected void gotExtendedData(SSHPacket buf)
|
||||
throws ConnectionException, TransportException {
|
||||
throws SSHException {
|
||||
try {
|
||||
final int dataTypeCode = buf.readUInt32AsInt();
|
||||
if (dataTypeCode == 1)
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
/*
|
||||
* Copyright (C)2009 - SSHJ Contributors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.hierynomus.sshj.connection.channel.forwarded;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.net.ServerSocket;
|
||||
import java.net.Socket;
|
||||
import net.schmizz.sshj.DefaultConfig;
|
||||
import net.schmizz.sshj.SSHClient;
|
||||
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder.Forward;
|
||||
import net.schmizz.sshj.connection.channel.forwarded.SocketForwardingConnectListener;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class RemotePFPerformanceTest {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(RemotePFPerformanceTest.class);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void startPF() throws IOException, InterruptedException {
|
||||
DefaultConfig config = new DefaultConfig();
|
||||
config.setMaxCircularBufferSize(16 * 1024 * 1024);
|
||||
SSHClient client = new SSHClient(config);
|
||||
client.loadKnownHosts();
|
||||
client.addHostKeyVerifier("5c:0c:8e:9d:1c:50:a9:ba:a7:05:f6:b1:2b:0b:5f:ba");
|
||||
|
||||
client.getConnection().getKeepAlive().setKeepAliveInterval(5);
|
||||
client.connect("localhost");
|
||||
client.getConnection().getKeepAlive().setKeepAliveInterval(5);
|
||||
|
||||
Object consumerReadyMonitor = new Object();
|
||||
ConsumerThread consumerThread = new ConsumerThread(consumerReadyMonitor);
|
||||
ProducerThread producerThread = new ProducerThread();
|
||||
try {
|
||||
|
||||
client.authPassword(System.getenv().get("USERNAME"), System.getenv().get("PASSWORD"));
|
||||
|
||||
/*
|
||||
* We make _server_ listen on port 8080, which forwards all connections to us as a channel, and we further
|
||||
* forward all such channels to google.com:80
|
||||
*/
|
||||
client.getRemotePortForwarder().bind(
|
||||
// where the server should listen
|
||||
new Forward(8888),
|
||||
// what we do with incoming connections that are forwarded to us
|
||||
new SocketForwardingConnectListener(new InetSocketAddress("localhost", 12345)));
|
||||
|
||||
consumerThread.start();
|
||||
synchronized (consumerReadyMonitor) {
|
||||
consumerReadyMonitor.wait();
|
||||
}
|
||||
producerThread.start();
|
||||
|
||||
// Wait for consumer to finish receiving data.
|
||||
synchronized (consumerReadyMonitor) {
|
||||
consumerReadyMonitor.wait();
|
||||
}
|
||||
|
||||
} finally {
|
||||
producerThread.interrupt();
|
||||
consumerThread.interrupt();
|
||||
client.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
private static class ConsumerThread extends Thread {
|
||||
private final Object consumerReadyMonitor;
|
||||
|
||||
private ConsumerThread(Object consumerReadyMonitor) {
|
||||
super("Consumer");
|
||||
this.consumerReadyMonitor = consumerReadyMonitor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
try (ServerSocket serverSocket = new ServerSocket(12345)) {
|
||||
synchronized (consumerReadyMonitor) {
|
||||
consumerReadyMonitor.notifyAll();
|
||||
}
|
||||
try (Socket acceptedSocket = serverSocket.accept()) {
|
||||
InputStream in = acceptedSocket.getInputStream();
|
||||
int numRead;
|
||||
byte[] buf = new byte[40000];
|
||||
//byte[] buf = new byte[255 * 4 * 1000];
|
||||
byte expectedNext = 1;
|
||||
while ((numRead = in.read(buf)) != 0) {
|
||||
if (Thread.interrupted()) {
|
||||
log.info("Consumer thread interrupted");
|
||||
return;
|
||||
}
|
||||
log.info(String.format("Read %d characters; values from %d to %d", numRead, buf[0], buf[numRead - 1]));
|
||||
if (buf[numRead - 1] == 0) {
|
||||
verifyData(buf, numRead - 1, expectedNext);
|
||||
break;
|
||||
}
|
||||
expectedNext = verifyData(buf, numRead, expectedNext);
|
||||
// Slow down consumer to test buffering.
|
||||
Thread.sleep(Long.parseLong(System.getenv().get("DELAY_MS")));
|
||||
}
|
||||
log.info("Consumer read end of stream value: " + numRead);
|
||||
synchronized (consumerReadyMonitor) {
|
||||
consumerReadyMonitor.notifyAll();
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
synchronized (consumerReadyMonitor) {
|
||||
consumerReadyMonitor.notifyAll();
|
||||
}
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
private byte verifyData(byte[] buf, int numRead, byte expectedNext) {
|
||||
for (int i = 0; i < numRead; ++i) {
|
||||
if (buf[i] != expectedNext) {
|
||||
fail("Expected buf[" + i + "]=" + buf[i] + " to be " + expectedNext);
|
||||
}
|
||||
if (++expectedNext == 0) {
|
||||
expectedNext = 1;
|
||||
}
|
||||
}
|
||||
return expectedNext;
|
||||
}
|
||||
}
|
||||
|
||||
private static class ProducerThread extends Thread {
|
||||
private ProducerThread() {
|
||||
super("Producer");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
try (Socket clientSocket = new Socket("127.0.0.1", 8888);
|
||||
OutputStream writer = clientSocket.getOutputStream()) {
|
||||
byte[] buf = getData();
|
||||
assertEquals(buf[0], 1);
|
||||
assertEquals(buf[buf.length - 1], -1);
|
||||
for (int i = 0; i < 1000; ++i) {
|
||||
writer.write(buf);
|
||||
if (Thread.interrupted()) {
|
||||
log.info("Consumer thread interrupted");
|
||||
return;
|
||||
}
|
||||
log.info(String.format("Wrote %d characters; values from %d to %d", buf.length, buf[0], buf[buf.length - 1]));
|
||||
}
|
||||
writer.write(0); // end of stream value
|
||||
log.info("Producer finished sending data");
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
private byte[] getData() {
|
||||
byte[] buf = new byte[255 * 4 * 1000];
|
||||
byte nextValue = 1;
|
||||
for (int i = 0; i < buf.length; ++i) {
|
||||
buf[i] = nextValue++;
|
||||
// reserve 0 for end of stream
|
||||
if (nextValue == 0) {
|
||||
nextValue = 1;
|
||||
}
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
221
src/test/java/net/schmizz/sshj/common/CircularBufferTest.java
Normal file
221
src/test/java/net/schmizz/sshj/common/CircularBufferTest.java
Normal file
@@ -0,0 +1,221 @@
|
||||
/*
|
||||
* Copyright (C)2009 - SSHJ Contributors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package net.schmizz.sshj.common;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import net.schmizz.sshj.common.CircularBuffer.CircularBufferException;
|
||||
import net.schmizz.sshj.common.CircularBuffer.PlainCircularBuffer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class CircularBufferTest {
|
||||
|
||||
@Test
|
||||
public void shouldStoreDataCorrectlyWithoutResizing() throws CircularBufferException {
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(256, Integer.MAX_VALUE);
|
||||
|
||||
byte[] dataToWrite = getData(500);
|
||||
buffer.putRawBytes(dataToWrite, 0, 100);
|
||||
buffer.putRawBytes(dataToWrite, 100, 100);
|
||||
|
||||
byte[] dataToRead = new byte[500];
|
||||
buffer.readRawBytes(dataToRead, 0, 80);
|
||||
buffer.readRawBytes(dataToRead, 80, 80);
|
||||
|
||||
buffer.putRawBytes(dataToWrite, 200, 100);
|
||||
buffer.readRawBytes(dataToRead, 160, 80);
|
||||
|
||||
buffer.putRawBytes(dataToWrite, 300, 100);
|
||||
buffer.readRawBytes(dataToRead, 240, 80);
|
||||
|
||||
buffer.putRawBytes(dataToWrite, 400, 100);
|
||||
buffer.readRawBytes(dataToRead, 320, 80);
|
||||
buffer.readRawBytes(dataToRead, 400, 100);
|
||||
|
||||
assertEquals(256, buffer.length());
|
||||
assertArrayEquals(dataToWrite, dataToRead);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldStoreDataCorrectlyWithResizing() throws CircularBufferException {
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
|
||||
|
||||
byte[] dataToWrite = getData(500);
|
||||
buffer.putRawBytes(dataToWrite, 0, 100);
|
||||
buffer.putRawBytes(dataToWrite, 100, 100);
|
||||
|
||||
byte[] dataToRead = new byte[500];
|
||||
buffer.readRawBytes(dataToRead, 0, 80);
|
||||
buffer.readRawBytes(dataToRead, 80, 80);
|
||||
|
||||
buffer.putRawBytes(dataToWrite, 200, 100);
|
||||
buffer.readRawBytes(dataToRead, 160, 80);
|
||||
|
||||
buffer.putRawBytes(dataToWrite, 300, 100);
|
||||
buffer.readRawBytes(dataToRead, 240, 80);
|
||||
|
||||
buffer.putRawBytes(dataToWrite, 400, 100);
|
||||
buffer.readRawBytes(dataToRead, 320, 80);
|
||||
|
||||
buffer.readRawBytes(dataToRead, 400, 100);
|
||||
|
||||
assertEquals(256, buffer.length());
|
||||
assertArrayEquals(dataToWrite, dataToRead);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldNotOverflowWhenWritingFullLengthToTheEnd() throws CircularBufferException {
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
|
||||
|
||||
byte[] dataToWrite = getData(64);
|
||||
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should write to the end
|
||||
|
||||
assertEquals(64, buffer.available());
|
||||
assertEquals(64 * 2, buffer.length());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldNotOverflowWhenWritingFullLengthWrapsAround() throws CircularBufferException {
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
|
||||
|
||||
// Move 1 byte forward.
|
||||
buffer.putRawBytes(new byte[1], 0, 1);
|
||||
buffer.readRawBytes(new byte[1], 0, 1);
|
||||
|
||||
// Force writes to wrap around.
|
||||
byte[] dataToWrite = getData(64);
|
||||
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length); // should wrap around the end
|
||||
|
||||
assertEquals(64, buffer.available());
|
||||
assertEquals(64 * 2, buffer.length());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldAllowWritingMaxCapacityFromZero() throws CircularBufferException {
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
|
||||
|
||||
// Max capacity is always one less than the buffer size.
|
||||
int maxCapacity = buffer.maxPossibleRemainingCapacity();
|
||||
assertEquals(buffer.length() - 1, maxCapacity);
|
||||
|
||||
byte[] dataToWrite = getData(maxCapacity);
|
||||
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);
|
||||
|
||||
assertEquals(dataToWrite.length, buffer.available());
|
||||
assertEquals(64, buffer.length());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldAllowWritingMaxRemainingCapacity() throws CircularBufferException {
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
|
||||
|
||||
final int initiallyWritten = 10;
|
||||
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
|
||||
|
||||
// Max remaining capacity is always one less than the remaining buffer size.
|
||||
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
|
||||
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);
|
||||
|
||||
byte[] dataToWrite = getData(maxRemainingCapacity);
|
||||
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);
|
||||
|
||||
assertEquals(dataToWrite.length + initiallyWritten, buffer.available());
|
||||
assertEquals(64, buffer.length());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldAllowWritingMaxRemainingCapacityAfterWrappingAround() throws CircularBufferException {
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
|
||||
|
||||
// Cause the internal write pointer to wrap around and be left of the read pointer.
|
||||
final int initiallyWritten = 40;
|
||||
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
|
||||
buffer.readRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
|
||||
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
|
||||
|
||||
// Max remaining capacity is always one less than the remaining buffer size.
|
||||
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
|
||||
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);
|
||||
|
||||
byte[] dataToWrite = getData(maxRemainingCapacity);
|
||||
buffer.putRawBytes(dataToWrite, 0, dataToWrite.length);
|
||||
|
||||
assertEquals(dataToWrite.length + initiallyWritten, buffer.available());
|
||||
assertEquals(64, buffer.length());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldOverflowWhenWritingOverMaxRemainingCapacity() throws CircularBufferException {
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(64, 64);
|
||||
|
||||
final int initiallyWritten = 10;
|
||||
buffer.putRawBytes(new byte[initiallyWritten], 0, initiallyWritten);
|
||||
|
||||
// Max remaining capacity is always one less than the remaining buffer size.
|
||||
int maxRemainingCapacity = buffer.maxPossibleRemainingCapacity();
|
||||
assertEquals(buffer.length() - 1 - initiallyWritten, maxRemainingCapacity);
|
||||
|
||||
byte[] dataToWrite = getData(maxRemainingCapacity + 1);
|
||||
assertThrows(CircularBufferException.class, () -> buffer.putRawBytes(dataToWrite, 0, dataToWrite.length));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldThrowWhenReadingEmptyBuffer() {
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
|
||||
assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[1], 0, 1));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldThrowWhenReadingMoreThanAvailable() throws CircularBufferException {
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(64, Integer.MAX_VALUE);
|
||||
buffer.putRawBytes(new byte[1], 0, 1);
|
||||
assertThrows(CircularBufferException.class, () -> buffer.readRawBytes(new byte[2], 0, 2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldThrowOnAboveMaximumInitialSize() {
|
||||
assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(65, 64));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldThrowOnMaximumInitialSize() {
|
||||
assertThrows(IllegalArgumentException.class, () -> new PlainCircularBuffer(Integer.MAX_VALUE, 64));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldAllowFullCapacity() throws CircularBufferException {
|
||||
int maxSize = 1024;
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize);
|
||||
buffer.ensureCapacity(maxSize - 1);
|
||||
assertEquals(maxSize - 1, buffer.maxPossibleRemainingCapacity());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldThrowOnTooLargeRequestedCapacity() {
|
||||
int maxSize = 1024;
|
||||
PlainCircularBuffer buffer = new PlainCircularBuffer(256, maxSize);
|
||||
assertThrows(CircularBufferException.class, () -> buffer.ensureCapacity(maxSize));
|
||||
}
|
||||
|
||||
private static byte[] getData(int length) {
|
||||
byte[] data = new byte[length];
|
||||
byte nextValue = 0;
|
||||
for (int i = 0; i < length; ++i) {
|
||||
data[i] = nextValue++;
|
||||
}
|
||||
return data;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user