Merge branch 'bkarge-issue-183'

This commit is contained in:
Jeroen van Erp
2015-06-17 12:37:55 +02:00
3 changed files with 209 additions and 73 deletions

View File

@@ -1,12 +1,12 @@
/** /**
* Copyright 2009 sshj contributors * Copyright 2009 sshj contributors
* * <p/>
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* * <p/>
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* * <p/>
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -19,6 +19,7 @@ import net.schmizz.concurrent.Promise;
import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.sftp.Response.StatusCode; import net.schmizz.sshj.sftp.Response.StatusCode;
import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
@@ -33,37 +34,31 @@ public class RemoteFile
super(requester, path, handle); super(requester, path, handle);
} }
public FileAttributes fetchAttributes() public FileAttributes fetchAttributes() throws IOException {
throws IOException {
return requester.request(newRequest(PacketType.FSTAT)) return requester.request(newRequest(PacketType.FSTAT))
.retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS) .retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS)
.ensurePacketTypeIs(PacketType.ATTRS) .ensurePacketTypeIs(PacketType.ATTRS)
.readFileAttributes(); .readFileAttributes();
} }
public long length() public long length() throws IOException {
throws IOException {
return fetchAttributes().getSize(); return fetchAttributes().getSize();
} }
public void setLength(long len) public void setLength(long len) throws IOException {
throws IOException {
setAttributes(new FileAttributes.Builder().withSize(len).build()); setAttributes(new FileAttributes.Builder().withSize(len).build());
} }
public int read(long fileOffset, byte[] to, int offset, int len) public int read(long fileOffset, byte[] to, int offset, int len) throws IOException {
throws IOException {
final Response res = asyncRead(fileOffset, len).retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS); final Response res = asyncRead(fileOffset, len).retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS);
return checkReadResponse(res, to, offset); return checkReadResponse(res, to, offset);
} }
protected Promise<Response, SFTPException> asyncRead(long fileOffset, int len) protected Promise<Response, SFTPException> asyncRead(long fileOffset, int len) throws IOException {
throws IOException {
return requester.request(newRequest(PacketType.READ).putUInt64(fileOffset).putUInt32(len)); return requester.request(newRequest(PacketType.READ).putUInt64(fileOffset).putUInt32(len));
} }
protected int checkReadResponse(Response res, byte[] to, int offset) protected int checkReadResponse(Response res, byte[] to, int offset) throws Buffer.BufferException, SFTPException {
throws Buffer.BufferException, SFTPException {
switch (res.getType()) { switch (res.getType()) {
case DATA: case DATA:
int recvLen = res.readUInt32AsInt(); int recvLen = res.readUInt32AsInt();
@@ -79,28 +74,25 @@ public class RemoteFile
} }
} }
public void write(long fileOffset, byte[] data, int off, int len) public void write(long fileOffset, byte[] data, int off, int len) throws IOException {
throws IOException {
checkWriteResponse(asyncWrite(fileOffset, data, off, len)); checkWriteResponse(asyncWrite(fileOffset, data, off, len));
} }
protected Promise<Response, SFTPException> asyncWrite(long fileOffset, byte[] data, int off, int len) protected Promise<Response, SFTPException> asyncWrite(long fileOffset, byte[] data, int off, int len)
throws IOException { throws IOException {
return requester.request(newRequest(PacketType.WRITE) return requester.request(newRequest(PacketType.WRITE)
.putUInt64(fileOffset) .putUInt64(fileOffset)
// TODO The SFTP spec claims this field is unneeded...? See #187 // TODO The SFTP spec claims this field is unneeded...? See #187
.putUInt32(len) .putUInt32(len)
.putRawBytes(data, off, len) .putRawBytes(data, off, len)
); );
} }
private void checkWriteResponse(Promise<Response, SFTPException> responsePromise) private void checkWriteResponse(Promise<Response, SFTPException> responsePromise) throws SFTPException {
throws SFTPException {
responsePromise.retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS).ensureStatusPacketIsOK(); responsePromise.retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS).ensureStatusPacketIsOK();
} }
public void setAttributes(FileAttributes attrs) public void setAttributes(FileAttributes attrs) throws IOException {
throws IOException {
requester.request(newRequest(PacketType.FSETSTAT).putFileAttributes(attrs)) requester.request(newRequest(PacketType.FSETSTAT).putFileAttributes(attrs))
.retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS).ensureStatusPacketIsOK(); .retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS).ensureStatusPacketIsOK();
} }
@@ -140,15 +132,13 @@ public class RemoteFile
} }
@Override @Override
public void write(int w) public void write(int w) throws IOException {
throws IOException {
b[0] = (byte) w; b[0] = (byte) w;
write(b, 0, 1); write(b, 0, 1);
} }
@Override @Override
public void write(byte[] buf, int off, int len) public void write(byte[] buf, int off, int len) throws IOException {
throws IOException {
if (unconfirmedWrites.size() > maxUnconfirmedWrites) { if (unconfirmedWrites.size() > maxUnconfirmedWrites) {
checkWriteResponse(unconfirmedWrites.remove()); checkWriteResponse(unconfirmedWrites.remove());
} }
@@ -157,23 +147,20 @@ public class RemoteFile
} }
@Override @Override
public void flush() public void flush() throws IOException {
throws IOException {
while (!unconfirmedWrites.isEmpty()) { while (!unconfirmedWrites.isEmpty()) {
checkWriteResponse(unconfirmedWrites.remove()); checkWriteResponse(unconfirmedWrites.remove());
} }
} }
@Override @Override
public void close() public void close() throws IOException {
throws IOException {
flush(); flush();
} }
} }
public class RemoteFileInputStream public class RemoteFileInputStream extends InputStream {
extends InputStream {
private final byte[] b = new byte[1]; private final byte[] b = new byte[1];
@@ -201,31 +188,29 @@ public class RemoteFile
} }
@Override @Override
public void reset() public void reset() throws IOException {
throws IOException {
fileOffset = markPos; fileOffset = markPos;
} }
@Override @Override
public long skip(long n) public long skip(long n) throws IOException {
throws IOException {
return (this.fileOffset = Math.min(fileOffset + n, length())); return (this.fileOffset = Math.min(fileOffset + n, length()));
} }
@Override @Override
public int read() public int read() throws IOException {
throws IOException {
return read(b, 0, 1) == -1 ? -1 : b[0] & 0xff; return read(b, 0, 1) == -1 ? -1 : b[0] & 0xff;
} }
@Override @Override
public int read(byte[] into, int off, int len) public int read(byte[] into, int off, int len) throws IOException {
throws IOException {
int read = RemoteFile.this.read(fileOffset, into, off, len); int read = RemoteFile.this.read(fileOffset, into, off, len);
if (read != -1) { if (read != -1) {
fileOffset += read; fileOffset += read;
if (markPos != 0 && read > readLimit) // Invalidate mark position if (markPos != 0 && read > readLimit) {
// Invalidate mark position
markPos = 0; markPos = 0;
}
} }
return read; return read;
} }
@@ -238,27 +223,56 @@ public class RemoteFile
private final byte[] b = new byte[1]; private final byte[] b = new byte[1];
private final int maxUnconfirmedReads; private final int maxUnconfirmedReads;
private final Queue<Promise<Response, SFTPException>> unconfirmedReads; private final Queue<Promise<Response, SFTPException>> unconfirmedReads = new LinkedList<Promise<Response, SFTPException>>();
private final Queue<Long> unconfirmedReadOffsets = new LinkedList<Long>();
private long fileOffset; private long requestOffset;
private long responseOffset;
private boolean eof; private boolean eof;
public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads) { public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads) {
assert 0 <= maxUnconfirmedReads;
this.maxUnconfirmedReads = maxUnconfirmedReads; this.maxUnconfirmedReads = maxUnconfirmedReads;
this.unconfirmedReads = new LinkedList<Promise<Response, SFTPException>>();
this.fileOffset = 0;
} }
public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads, long fileOffset) { public ReadAheadRemoteFileInputStream(int maxUnconfirmedReads, long fileOffset) {
assert 0 <= maxUnconfirmedReads;
assert 0 <= fileOffset;
this.maxUnconfirmedReads = maxUnconfirmedReads; this.maxUnconfirmedReads = maxUnconfirmedReads;
this.unconfirmedReads = new LinkedList<Promise<Response, SFTPException>>(); this.requestOffset = this.responseOffset = fileOffset;
this.fileOffset = fileOffset;
} }
@Override private ByteArrayInputStream pending = new ByteArrayInputStream(new byte[0]);
public long skip(long n)
throws IOException { private boolean retrieveUnconfirmedRead(boolean blocking) throws IOException {
throw new IOException("skip is not supported by ReadAheadFileInputStream, use RemoteFileInputStream instead"); if (unconfirmedReads.size() <= 0) {
return false;
}
if (!blocking && !unconfirmedReads.peek().isDelivered()) {
return false;
}
unconfirmedReadOffsets.remove();
final Response res = unconfirmedReads.remove().retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS);
switch (res.getType()) {
case DATA:
int recvLen = res.readUInt32AsInt();
responseOffset += recvLen;
pending = new ByteArrayInputStream(res.array(), res.rpos(), recvLen);
break;
case STATUS:
res.ensureStatusIs(Response.StatusCode.EOF);
eof = true;
break;
default:
throw new SFTPException("Unexpected packet: " + res.getType());
}
return true;
} }
@Override @Override
@@ -268,26 +282,66 @@ public class RemoteFile
} }
@Override @Override
public int read(byte[] into, int off, int len) public int read(byte[] into, int off, int len) throws IOException {
throws IOException {
while (!eof && unconfirmedReads.size() <= maxUnconfirmedReads) { while (!eof && pending.available() <= 0) {
// Send read requests as long as there is no EOF and we have not reached the maximum parallelism
unconfirmedReads.add(asyncRead(fileOffset, len)); // we also need to go here for len <= 0, because pending may be at
fileOffset += len; // EOF in which case it would return -1 instead of 0
while (unconfirmedReads.size() <= maxUnconfirmedReads) {
// Send read requests as long as there is no EOF and we have not reached the maximum parallelism
int reqLen = Math.max(1024, len); // don't be shy!
unconfirmedReads.add(RemoteFile.this.asyncRead(requestOffset, reqLen));
unconfirmedReadOffsets.add(requestOffset);
requestOffset += reqLen;
}
long nextOffset = unconfirmedReadOffsets.peek();
if (responseOffset != nextOffset) {
// the server could not give us all the data we needed, so
// we try to fill the gap synchronously
assert responseOffset < nextOffset;
assert 0 < (nextOffset - responseOffset);
assert (nextOffset - responseOffset) <= Integer.MAX_VALUE;
byte[] buf = new byte[(int) (nextOffset - responseOffset)];
int recvLen = RemoteFile.this.read(responseOffset, buf, 0, buf.length);
if (recvLen < 0) {
eof = true;
return -1;
}
if (0 == recvLen) {
// avoid infinite loops
throw new SFTPException("Unexpected response size (0), bailing out");
}
responseOffset += recvLen;
pending = new ByteArrayInputStream(buf, 0, recvLen);
} else if (!retrieveUnconfirmedRead(true /*blocking*/)) {
// this may happen if we change prefetch strategy
// currently, we should never get here...
throw new IllegalStateException("Could not retrieve data for pending read request");
}
} }
if (unconfirmedReads.isEmpty()) {
assert eof; return pending.read(into, off, len);
return -1;
}
// Retrieve first in
final Response res = unconfirmedReads.remove().retrieve(requester.getTimeoutMs(), TimeUnit.MILLISECONDS);
final int recvLen = checkReadResponse(res, into, off);
if (recvLen == -1) {
eof = true;
}
return recvLen;
} }
@Override
public int available() throws IOException {
boolean lastRead = true;
while (!eof && (pending.available() <= 0) && lastRead) {
lastRead = retrieveUnconfirmedRead(false /*blocking*/);
}
return pending.available();
}
} }
} }

View File

@@ -3,15 +3,21 @@ package com.hierynomus.sshj;
import net.schmizz.sshj.Config; import net.schmizz.sshj.Config;
import net.schmizz.sshj.DefaultConfig; import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.SSHClient; import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.userauth.UserAuthException;
import net.schmizz.sshj.util.gss.BogusGSSAuthenticator; import net.schmizz.sshj.util.gss.BogusGSSAuthenticator;
import org.apache.sshd.SshServer; import org.apache.sshd.SshServer;
import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.common.keyprovider.FileKeyPairProvider; import org.apache.sshd.common.keyprovider.FileKeyPairProvider;
import org.apache.sshd.server.Command;
import org.apache.sshd.server.PasswordAuthenticator; import org.apache.sshd.server.PasswordAuthenticator;
import org.apache.sshd.server.session.ServerSession; import org.apache.sshd.server.session.ServerSession;
import org.apache.sshd.server.sftp.SftpSubsystem;
import org.junit.rules.ExternalResource; import org.junit.rules.ExternalResource;
import java.io.IOException; import java.io.IOException;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.util.Collections;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
/** /**
@@ -91,6 +97,7 @@ public class SshFixture extends ExternalResource {
} }
}); });
sshServer.setGSSAuthenticator(new BogusGSSAuthenticator()); sshServer.setGSSAuthenticator(new BogusGSSAuthenticator());
sshServer.setSubsystemFactories(Collections.<NamedFactory<Command>>singletonList(new SftpSubsystem.Factory()));
return sshServer; return sshServer;
} }

View File

@@ -0,0 +1,75 @@
package com.hierynomus.sshj.sftp;
import com.hierynomus.sshj.SshFixture;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.sftp.OpenMode;
import net.schmizz.sshj.sftp.RemoteFile;
import net.schmizz.sshj.sftp.SFTPEngine;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Random;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
public class RemoteFileTest {
@Rule
public SshFixture fixture = new SshFixture();
@Rule
public TemporaryFolder temp = new TemporaryFolder();
@Test
public void shouldNotGoOutOfBoundsInReadAheadInputStream() throws IOException {
SSHClient ssh = fixture.setupConnectedDefaultClient();
ssh.authPassword("test", "test");
SFTPEngine sftp = new SFTPEngine(ssh).init();
RemoteFile rf;
File file = temp.newFile("SftpReadAheadTest.bin");
rf = sftp.open(file.getPath(), EnumSet.of(OpenMode.WRITE, OpenMode.CREAT));
byte[] data = new byte[8192];
new Random(53).nextBytes(data);
data[3072] = 1;
rf.write(0, data, 0, data.length);
rf.close();
assertThat("The file should exist", file.exists());
rf = sftp.open(file.getPath());
InputStream rs = rf.new ReadAheadRemoteFileInputStream(16 /*maxUnconfirmedReads*/);
byte[] test = new byte[4097];
int n = 0;
while (n < 2048) {
n += rs.read(test, n, 2048 - n);
}
while (n < 3072) {
n += rs.read(test, n, 3072 - n);
}
if (test[3072] != 0) {
System.err.println("buffer overrun!");
}
n += rs.read(test, n, test.length - n); // --> ArrayIndexOutOfBoundsException
byte[] test2 = new byte[data.length];
System.arraycopy(test, 0, test2, 0, test.length);
while (n < data.length) {
n += rs.read(test2, n, data.length - n);
}
assertThat("The written and received data should match", data, equalTo(test2));
}
}