From 8cf63a96a9dc3e42343dab4a3f0e40c35f535dfb Mon Sep 17 00:00:00 2001 From: David Kocher Date: Thu, 23 Dec 2021 22:24:52 +0100 Subject: [PATCH] =?UTF-8?q?Add=20parameter=20to=20limit=20read=20ahead=20t?= =?UTF-8?q?o=20maximum=20length.=20Allows=20to=20use=20mu=E2=80=A6=20(#724?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add parameter to limit read ahead to maximum length. Allows to use multiple concurrent threads reading from the same file with an offset without reading too much ahead for a single segment. * Review and add tests. Signed-off-by: David Kocher Co-authored-by: Yves Langisch --- .../net/schmizz/sshj/sftp/RemoteFile.java | 26 ++++- .../hierynomus/sshj/sftp/RemoteFileTest.java | 90 +++++++++++++++++++ 2 files changed, 113 insertions(+), 3 deletions(-) diff --git a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java index 25463bfb..a5558030 100644 --- a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java +++ b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java @@ -224,6 +224,7 @@ public class RemoteFile private final byte[] b = new byte[1]; private final int maxUnconfirmedReads; + private final long readAheadLimit; private final Queue> unconfirmedReads = new LinkedList>(); private final Queue unconfirmedReadOffsets = new LinkedList(); @@ -232,17 +233,27 @@ public class RemoteFile private boolean eof; public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads) { - assert 0 <= maxUnconfirmedReads; - - this.maxUnconfirmedReads = maxUnconfirmedReads; + this(maxUnconfirmedReads, 0L, -1L); } public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads, long fileOffset) { + this(maxUnconfirmedReads, fileOffset, -1L); + } + + /** + * Creates a read-ahead stream that starts at fileOffset and stops sending read-ahead requests once readAheadLimit bytes past that offset have been requested. + * + * @param maxUnconfirmedReads Maximum number of unconfirmed requests to send + * @param fileOffset Initial offset in file to read from + * @param readAheadLimit Number of bytes, counted from fileOffset, after which read ahead is disabled; zero, a negative value, or a value that would overflow when added to fileOffset means no limit + */ + public
ReadAheadRemoteFileInputStream(int maxUnconfirmedReads, long fileOffset, long readAheadLimit) { assert 0 <= maxUnconfirmedReads; assert 0 <= fileOffset; this.maxUnconfirmedReads = maxUnconfirmedReads; this.requestOffset = this.responseOffset = fileOffset; + this.readAheadLimit = (readAheadLimit > 0 && readAheadLimit <= Long.MAX_VALUE - fileOffset) ? fileOffset + readAheadLimit : Long.MAX_VALUE; /* saturate: the absolute limit must not wrap around to a negative offset */ } private ByteArrayInputStream pending = new ByteArrayInputStream(new byte[0]); @@ -293,9 +304,18 @@ public class RemoteFile while (unconfirmedReads.size() <= maxUnconfirmedReads) { // Send read requests as long as there is no EOF and we have not reached the maximum parallelism int reqLen = Math.max(1024, len); // don't be shy! + if (readAheadLimit > requestOffset) { + long remaining = readAheadLimit - requestOffset; + if (reqLen > remaining) { + reqLen = (int) remaining; + } + } unconfirmedReads.add(RemoteFile.this.asyncRead(requestOffset, reqLen)); unconfirmedReadOffsets.add(requestOffset); requestOffset += reqLen; + if (requestOffset >= readAheadLimit) { + break; + } } long nextOffset = unconfirmedReadOffsets.peek(); diff --git a/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java b/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java index 3436af42..949a917c 100644 --- a/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java +++ b/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java @@ -20,6 +20,7 @@ import net.schmizz.sshj.SSHClient; import net.schmizz.sshj.sftp.OpenMode; import net.schmizz.sshj.sftp.RemoteFile; import net.schmizz.sshj.sftp.SFTPEngine; +import net.schmizz.sshj.sftp.SFTPException; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -32,6 +33,7 @@ import java.util.Random; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.fail; public class RemoteFileTest { @Rule @@ -84,4 +86,92 @@ public class RemoteFileTest { assertThat("The written and received data should
match", data, equalTo(test2)); } + + @Test + public void shouldNotReadAheadAfterLimitInputStream() throws IOException { /* Checks that nothing past the 3072-byte read-ahead limit is fetched early: once the remote handle is closed, bytes up to the limit must still be readable, while a read beyond it is expected to fail with SFTPException. */ + SSHClient ssh = fixture.setupConnectedDefaultClient(); + ssh.authPassword("test", "test"); + SFTPEngine sftp = new SFTPEngine(ssh).init(); + + RemoteFile rf; + File file = temp.newFile("SftpReadAheadLimitTest.bin"); + rf = sftp.open(file.getPath(), EnumSet.of(OpenMode.WRITE, OpenMode.CREAT)); + byte[] data = new byte[8192]; + new Random(53).nextBytes(data); + data[3072] = 1; /* non-zero marker at the first offset past the limit */ + rf.write(0, data, 0, data.length); + rf.close(); + + assertThat("The file should exist", file.exists()); + + rf = sftp.open(file.getPath()); + InputStream rs = rf.new ReadAheadRemoteFileInputStream(16 /*maxUnconfirmedReads*/,0, 3072); + + byte[] test = new byte[4097]; + int n = 0; + + while (n < 2048) { + n += rs.read(test, n, 2048 - n); + } + + rf.close(); /* close the handle mid-stream; the remaining bytes up to the limit are expected to have been requested already */ + + while (n < 3072) { + n += rs.read(test, n, 3072 - n); + } + + assertThat("buffer overrun", test[3072] == 0); /* data[3072] is 1, so a zero here shows no byte past the limit reached the buffer */ + + try { + rs.read(test, n, test.length - n); + fail("Content must not be buffered"); + } catch (SFTPException e){ + // expected + } + } + + @Test + public void limitedReadAheadInputStream() throws IOException { /* Checks that reading continues correctly past the 3072-byte read-ahead limit while the file stays open: all 8192 bytes read back must equal what was written. */ + SSHClient ssh = fixture.setupConnectedDefaultClient(); + ssh.authPassword("test", "test"); + SFTPEngine sftp = new SFTPEngine(ssh).init(); + + RemoteFile rf; + File file = temp.newFile("SftpReadAheadLimitedTest.bin"); + rf = sftp.open(file.getPath(), EnumSet.of(OpenMode.WRITE, OpenMode.CREAT)); + byte[] data = new byte[8192]; + new Random(53).nextBytes(data); + data[3072] = 1; /* non-zero marker at the first offset past the limit */ + rf.write(0, data, 0, data.length); + rf.close(); + + assertThat("The file should exist", file.exists()); + + rf = sftp.open(file.getPath()); + InputStream rs = rf.new ReadAheadRemoteFileInputStream(16 /*maxUnconfirmedReads*/,0, 3072); + + byte[] test = new byte[4097]; + int n = 0; + + while (n < 2048) { + n += rs.read(test, n, 2048 - n); + } + + while (n < 3072) { + n += rs.read(test, n, 3072 - n); + } + +
assertThat("buffer overrun", test[3072] == 0); /* data[3072] is 1, so a zero here shows no byte past the limit reached the buffer */ + + n += rs.read(test, n, test.length - n); /* NOTE(review): the trailing marker presumably records an exception this call used to raise; confirm whether it is still relevant. The return value is added to n unchecked, so an EOF result of -1 is assumed not to occur here. */ // --> ArrayIndexOutOfBoundsException + + byte[] test2 = new byte[data.length]; + System.arraycopy(test, 0, test2, 0, test.length); /* carry over the 4097-byte buffer filled so far; the loop below continues writing from index n */ + + while (n < data.length) { + n += rs.read(test2, n, data.length - n); + } + + assertThat("The written and received data should match", data, equalTo(test2)); + } }