diff --git a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java
index 5bb48362..dcee1326 100644
--- a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java
+++ b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java
@@ -1,12 +1,12 @@
/**
* Copyright 2009 sshj contributors
- *
+ *
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -19,6 +19,7 @@ import net.schmizz.concurrent.Promise;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.sftp.Response.StatusCode;
+import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@@ -33,37 +34,31 @@ public class RemoteFile
super(requester, path, handle);
}
- public FileAttributes fetchAttributes()
- throws IOException {
+ public FileAttributes fetchAttributes() throws IOException {
return requester.request(newRequest(PacketType.FSTAT))
.retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS)
.ensurePacketTypeIs(PacketType.ATTRS)
.readFileAttributes();
}
- public long length()
- throws IOException {
+ public long length() throws IOException {
return fetchAttributes().getSize();
}
- public void setLength(long len)
- throws IOException {
+ public void setLength(long len) throws IOException {
setAttributes(new FileAttributes.Builder().withSize(len).build());
}
- public int read(long fileOffset, byte[] to, int offset, int len)
- throws IOException {
+ public int read(long fileOffset, byte[] to, int offset, int len) throws IOException {
final Response res = asyncRead(fileOffset, len).retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS);
return checkReadResponse(res, to, offset);
}
- protected Promise asyncRead(long fileOffset, int len)
- throws IOException {
+ protected Promise asyncRead(long fileOffset, int len) throws IOException {
return requester.request(newRequest(PacketType.READ).putUInt64(fileOffset).putUInt32(len));
}
- protected int checkReadResponse(Response res, byte[] to, int offset)
- throws Buffer.BufferException, SFTPException {
+ protected int checkReadResponse(Response res, byte[] to, int offset) throws Buffer.BufferException, SFTPException {
switch (res.getType()) {
case DATA:
int recvLen = res.readUInt32AsInt();
@@ -79,28 +74,25 @@ public class RemoteFile
}
}
- public void write(long fileOffset, byte[] data, int off, int len)
- throws IOException {
+ public void write(long fileOffset, byte[] data, int off, int len) throws IOException {
checkWriteResponse(asyncWrite(fileOffset, data, off, len));
}
protected Promise asyncWrite(long fileOffset, byte[] data, int off, int len)
throws IOException {
return requester.request(newRequest(PacketType.WRITE)
- .putUInt64(fileOffset)
- // TODO The SFTP spec claims this field is unneeded...? See #187
- .putUInt32(len)
- .putRawBytes(data, off, len)
+ .putUInt64(fileOffset)
+ // TODO The SFTP spec claims this field is unneeded...? See #187
+ .putUInt32(len)
+ .putRawBytes(data, off, len)
);
}
- private void checkWriteResponse(Promise responsePromise)
- throws SFTPException {
+ private void checkWriteResponse(Promise responsePromise) throws SFTPException {
responsePromise.retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS).ensureStatusPacketIsOK();
}
- public void setAttributes(FileAttributes attrs)
- throws IOException {
+ public void setAttributes(FileAttributes attrs) throws IOException {
requester.request(newRequest(PacketType.FSETSTAT).putFileAttributes(attrs))
.retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS).ensureStatusPacketIsOK();
}
@@ -140,15 +132,13 @@ public class RemoteFile
}
@Override
- public void write(int w)
- throws IOException {
+ public void write(int w) throws IOException {
b[0] = (byte) w;
write(b, 0, 1);
}
@Override
- public void write(byte[] buf, int off, int len)
- throws IOException {
+ public void write(byte[] buf, int off, int len) throws IOException {
if (unconfirmedWrites.size() > maxUnconfirmedWrites) {
checkWriteResponse(unconfirmedWrites.remove());
}
@@ -157,23 +147,20 @@ public class RemoteFile
}
@Override
- public void flush()
- throws IOException {
+ public void flush() throws IOException {
while (!unconfirmedWrites.isEmpty()) {
checkWriteResponse(unconfirmedWrites.remove());
}
}
@Override
- public void close()
- throws IOException {
+ public void close() throws IOException {
flush();
}
}
- public class RemoteFileInputStream
- extends InputStream {
+ public class RemoteFileInputStream extends InputStream {
private final byte[] b = new byte[1];
@@ -201,31 +188,29 @@ public class RemoteFile
}
@Override
- public void reset()
- throws IOException {
+ public void reset() throws IOException {
fileOffset = markPos;
}
@Override
- public long skip(long n)
- throws IOException {
+ public long skip(long n) throws IOException {
return (this.fileOffset = Math.min(fileOffset + n, length()));
}
@Override
- public int read()
- throws IOException {
+ public int read() throws IOException {
return read(b, 0, 1) == -1 ? -1 : b[0] & 0xff;
}
@Override
- public int read(byte[] into, int off, int len)
- throws IOException {
+ public int read(byte[] into, int off, int len) throws IOException {
int read = RemoteFile.this.read(fileOffset, into, off, len);
if (read != -1) {
fileOffset += read;
- if (markPos != 0 && read > readLimit) // Invalidate mark position
+ if (markPos != 0 && read > readLimit) {
+ // Invalidate mark position
markPos = 0;
+ }
}
return read;
}
@@ -238,27 +223,56 @@ public class RemoteFile
private final byte[] b = new byte[1];
private final int maxUnconfirmedReads;
- private final Queue> unconfirmedReads;
+ private final Queue> unconfirmedReads = new LinkedList>();
+ private final Queue unconfirmedReadOffsets = new LinkedList();
- private long fileOffset;
+ private long requestOffset;
+ private long responseOffset;
private boolean eof;
public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads) {
+ assert 0 <= maxUnconfirmedReads;
+
this.maxUnconfirmedReads = maxUnconfirmedReads;
- this.unconfirmedReads = new LinkedList>();
- this.fileOffset = 0;
}
public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads, long fileOffset) {
+ assert 0 <= maxUnconfirmedReads;
+ assert 0 <= fileOffset;
+
this.maxUnconfirmedReads = maxUnconfirmedReads;
- this.unconfirmedReads = new LinkedList>();
- this.fileOffset = fileOffset;
+ this.requestOffset = this.responseOffset = fileOffset;
}
- @Override
- public long skip(long n)
- throws IOException {
- throw new IOException("skip is not supported by ReadAheadFileInputStream, use RemoteFileInputStream instead");
+ private ByteArrayInputStream pending = new ByteArrayInputStream(new byte[0]);
+
+ private boolean retrieveUnconfirmedRead(boolean blocking) throws IOException {
+ if (unconfirmedReads.size() <= 0) {
+ return false;
+ }
+
+ if (!blocking && !unconfirmedReads.peek().isDelivered()) {
+ return false;
+ }
+
+ unconfirmedReadOffsets.remove();
+ final Response res = unconfirmedReads.remove().retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS);
+ switch (res.getType()) {
+ case DATA:
+ int recvLen = res.readUInt32AsInt();
+ responseOffset += recvLen;
+ pending = new ByteArrayInputStream(res.array(), res.rpos(), recvLen);
+ break;
+
+ case STATUS:
+ res.ensureStatusIs(Response.StatusCode.EOF);
+ eof = true;
+ break;
+
+ default:
+ throw new SFTPException("Unexpected packet: " + res.getType());
+ }
+ return true;
}
@Override
@@ -268,26 +282,66 @@ public class RemoteFile
}
@Override
- public int read(byte[] into, int off, int len)
- throws IOException {
- while (!eof && unconfirmedReads.size() <= maxUnconfirmedReads) {
- // Send read requests as long as there is no EOF and we have not reached the maximum parallelism
- unconfirmedReads.add(asyncRead(fileOffset, len));
- fileOffset += len;
+ public int read(byte[] into, int off, int len) throws IOException {
+
+ while (!eof && pending.available() <= 0) {
+
+ // we also need to go here for len <= 0, because pending may be at
+ // EOF in which case it would return -1 instead of 0
+
+ while (unconfirmedReads.size() <= maxUnconfirmedReads) {
+ // Send read requests as long as there is no EOF and we have not reached the maximum parallelism
+ int reqLen = Math.max(1024, len); // don't be shy!
+ unconfirmedReads.add(RemoteFile.this.asyncRead(requestOffset, reqLen));
+ unconfirmedReadOffsets.add(requestOffset);
+ requestOffset += reqLen;
+ }
+
+ long nextOffset = unconfirmedReadOffsets.peek();
+ if (responseOffset != nextOffset) {
+
+ // the server could not give us all the data we needed, so
+ // we try to fill the gap synchronously
+
+ assert responseOffset < nextOffset;
+ assert 0 < (nextOffset - responseOffset);
+ assert (nextOffset - responseOffset) <= Integer.MAX_VALUE;
+
+ byte[] buf = new byte[(int) (nextOffset - responseOffset)];
+ int recvLen = RemoteFile.this.read(responseOffset, buf, 0, buf.length);
+
+ if (recvLen < 0) {
+ eof = true;
+ return -1;
+ }
+
+ if (0 == recvLen) {
+ // avoid infinite loops
+ throw new SFTPException("Unexpected response size (0), bailing out");
+ }
+
+ responseOffset += recvLen;
+ pending = new ByteArrayInputStream(buf, 0, recvLen);
+ } else if (!retrieveUnconfirmedRead(true /*blocking*/)) {
+
+ // this may happen if we change prefetch strategy
+ // currently, we should never get here...
+
+ throw new IllegalStateException("Could not retrieve data for pending read request");
+ }
}
- if (unconfirmedReads.isEmpty()) {
- assert eof;
- return -1;
- }
- // Retrieve first in
- final Response res = unconfirmedReads.remove().retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS);
- final int recvLen = checkReadResponse(res, into, off);
- if (recvLen == -1) {
- eof = true;
- }
- return recvLen;
+
+ return pending.read(into, off, len);
}
+ @Override
+ public int available() throws IOException {
+ boolean lastRead = true;
+ while (!eof && (pending.available() <= 0) && lastRead) {
+ lastRead = retrieveUnconfirmedRead(false /*blocking*/);
+ }
+ return pending.available();
+ }
}
+}
-}
\ No newline at end of file
diff --git a/src/test/java/com/hierynomus/sshj/SshFixture.java b/src/test/java/com/hierynomus/sshj/SshFixture.java
index 39cfc5b8..417961a5 100644
--- a/src/test/java/com/hierynomus/sshj/SshFixture.java
+++ b/src/test/java/com/hierynomus/sshj/SshFixture.java
@@ -3,15 +3,21 @@ package com.hierynomus.sshj;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.SSHClient;
+import net.schmizz.sshj.transport.TransportException;
+import net.schmizz.sshj.userauth.UserAuthException;
import net.schmizz.sshj.util.gss.BogusGSSAuthenticator;
import org.apache.sshd.SshServer;
+import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.common.keyprovider.FileKeyPairProvider;
+import org.apache.sshd.server.Command;
import org.apache.sshd.server.PasswordAuthenticator;
import org.apache.sshd.server.session.ServerSession;
+import org.apache.sshd.server.sftp.SftpSubsystem;
import org.junit.rules.ExternalResource;
import java.io.IOException;
import java.net.ServerSocket;
+import java.util.Collections;
import java.util.concurrent.atomic.AtomicBoolean;
/**
@@ -91,6 +97,7 @@ public class SshFixture extends ExternalResource {
}
});
sshServer.setGSSAuthenticator(new BogusGSSAuthenticator());
+ sshServer.setSubsystemFactories(Collections.>singletonList(new SftpSubsystem.Factory()));
return sshServer;
}
diff --git a/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java b/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java
new file mode 100644
index 00000000..fadb0355
--- /dev/null
+++ b/src/test/java/com/hierynomus/sshj/sftp/RemoteFileTest.java
@@ -0,0 +1,75 @@
+package com.hierynomus.sshj.sftp;
+
+import com.hierynomus.sshj.SshFixture;
+import net.schmizz.sshj.SSHClient;
+import net.schmizz.sshj.sftp.OpenMode;
+import net.schmizz.sshj.sftp.RemoteFile;
+import net.schmizz.sshj.sftp.SFTPEngine;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.EnumSet;
+import java.util.Random;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+public class RemoteFileTest {
+ @Rule
+ public SshFixture fixture = new SshFixture();
+
+ @Rule
+ public TemporaryFolder temp = new TemporaryFolder();
+
+ @Test
+ public void shouldNotGoOutOfBoundsInReadAheadInputStream() throws IOException {
+ SSHClient ssh = fixture.setupConnectedDefaultClient();
+ ssh.authPassword("test", "test");
+ SFTPEngine sftp = new SFTPEngine(ssh).init();
+
+ RemoteFile rf;
+ File file = temp.newFile("SftpReadAheadTest.bin");
+ rf = sftp.open(file.getPath(), EnumSet.of(OpenMode.WRITE, OpenMode.CREAT));
+ byte[] data = new byte[8192];
+ new Random(53).nextBytes(data);
+ data[3072] = 1;
+ rf.write(0, data, 0, data.length);
+ rf.close();
+
+ assertThat("The file should exist", file.exists());
+
+ rf = sftp.open(file.getPath());
+ InputStream rs = rf.new ReadAheadRemoteFileInputStream(16 /*maxUnconfirmedReads*/);
+
+ byte[] test = new byte[4097];
+ int n = 0;
+
+ while (n < 2048) {
+ n += rs.read(test, n, 2048 - n);
+ }
+
+ while (n < 3072) {
+ n += rs.read(test, n, 3072 - n);
+ }
+
+ if (test[3072] != 0) {
+ System.err.println("buffer overrun!");
+ }
+
+ n += rs.read(test, n, test.length - n); // --> ArrayIndexOutOfBoundsException
+
+ byte[] test2 = new byte[data.length];
+ System.arraycopy(test, 0, test2, 0, test.length);
+
+ while (n < data.length) {
+ n += rs.read(test2, n, data.length - n);
+ }
+
+ assertThat("The written and received data should match", data, equalTo(test2));
+ }
+}