diff --git a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java index 4d78084e..8b208faf 100644 --- a/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java +++ b/src/main/java/net/schmizz/sshj/sftp/RemoteFile.java @@ -15,15 +15,24 @@ */ package net.schmizz.sshj.sftp; +import net.schmizz.concurrent.Promise; import net.schmizz.sshj.sftp.Response.StatusCode; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.Queue; +import java.util.concurrent.LinkedBlockingQueue; public class RemoteFile extends RemoteResource { + public static final int DEFAULT_CONCURRENT_REQUESTS = 10; + + protected volatile int concurrentRequests = DEFAULT_CONCURRENT_REQUESTS; + + private Queue> writeRequestsQueue = new LinkedBlockingQueue>(); + public RemoteFile(Requester requester, String path, String handle) { super(requester, path, handle); } @@ -73,11 +82,14 @@ public class RemoteFile public void write(long fileOffset, byte[] data, int off, int len) throws IOException { - requester.doRequest(newRequest(PacketType.WRITE) + Request request = newRequest(PacketType.WRITE) .putUInt64(fileOffset) .putUInt32(len - off) - .putRawBytes(data, off, len) - ).ensureStatusPacketIsOK(); + .putRawBytes(data, off, len); + writeRequestsQueue.add(requester.request(request)); + while (writeRequestsQueue.size() >= getConcurrentRequests()) { + requester.retrieve(writeRequestsQueue.remove()); + } } public void setAttributes(FileAttributes attrs) @@ -186,4 +198,23 @@ public class RemoteFile } } + + @Override + public void close() throws IOException { + try { + while(!writeRequestsQueue.isEmpty()) { + requester.retrieve(writeRequestsQueue.remove()); + } + } finally { + super.close(); + } + } + + public void setConcurrentRequests(int concurrentRequests) { + this.concurrentRequests = concurrentRequests; + } + + public int getConcurrentRequests() { + return concurrentRequests; + } } diff --git a/src/main/java/net/schmizz/sshj/sftp/Requester.java b/src/main/java/net/schmizz/sshj/sftp/Requester.java index 2cb4cba7..077f303f 100644 --- a/src/main/java/net/schmizz/sshj/sftp/Requester.java +++ b/src/main/java/net/schmizz/sshj/sftp/Requester.java @@ -15,6 +15,8 @@ */ package net.schmizz.sshj.sftp; +import net.schmizz.concurrent.Promise; + import java.io.IOException; public interface Requester { @@ -26,4 +28,10 @@ public interface Requester { Response doRequest(Request req) throws IOException; + Promise request(Request request) + throws IOException; + + void retrieve(Promise response) + throws IOException; + } diff --git a/src/main/java/net/schmizz/sshj/sftp/SFTPEngine.java b/src/main/java/net/schmizz/sshj/sftp/SFTPEngine.java index 2814f973..0d4759c0 100644 --- a/src/main/java/net/schmizz/sshj/sftp/SFTPEngine.java +++ b/src/main/java/net/schmizz/sshj/sftp/SFTPEngine.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj.sftp; +import net.schmizz.concurrent.Promise; import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.connection.channel.direct.Session.Subsystem; import net.schmizz.sshj.connection.channel.direct.SessionFactory; @@ -257,7 +258,7 @@ public class SFTPEngine throw new SFTPException("Unexpected data in " + res.getType() + " packet"); } - protected synchronized void transmit(SFTPPacket payload) + private synchronized void transmit(SFTPPacket payload) throws IOException { final int len = payload.available(); out.write((len >>> 24) & 0xff); @@ -268,4 +269,17 @@ public class SFTPEngine out.flush(); } + @Override + public Promise request(Request req) throws IOException { + reader.expectResponseTo(req); + log.debug("Sending {}", req); + transmit(req); + return req.getResponsePromise(); + } + + @Override + public void retrieve(Promise request) + throws IOException { + request.retrieve(timeout, TimeUnit.SECONDS).ensureStatusPacketIsOK(); + } }