From ab705d7f2a180ded74be0cfc6aa7be675a803993 Mon Sep 17 00:00:00 2001 From: Shikhar Bhushan Date: Sun, 24 Apr 2011 17:21:15 +0100 Subject: [PATCH] Consolidated stream copying logic --- src/main/java/examples/RudimentaryPTY.java | 7 +- .../net/schmizz/sshj/common/StreamCopier.java | 136 +++++++++++------- .../schmizz/sshj/sftp/SFTPFileTransfer.java | 12 +- .../net/schmizz/sshj/xfer/scp/SCPEngine.java | 79 +++++----- 4 files changed, 130 insertions(+), 104 deletions(-) diff --git a/src/main/java/examples/RudimentaryPTY.java b/src/main/java/examples/RudimentaryPTY.java index 79c4dab1..5f5cf9a5 100644 --- a/src/main/java/examples/RudimentaryPTY.java +++ b/src/main/java/examples/RudimentaryPTY.java @@ -50,16 +50,21 @@ class RudimentaryPTY { new StreamCopier("stdout", shell.getInputStream(), System.out) .bufSize(shell.getLocalMaxPacketSize()) + .keepFlushing(true) .start(); new StreamCopier("stderr", shell.getErrorStream(), System.err) .bufSize(shell.getLocalMaxPacketSize()) + .keepFlushing(true) .start(); // Now make System.in act as stdin. To exit, hit Ctrl+D (since that results in an EOF on System.in) // This is kinda messy because java only allows console input after you hit return // But this is just an example... a GUI app could implement a proper PTY - StreamCopier.copy(System.in, shell.getOutputStream(), shell.getRemoteMaxPacketSize(), true); + new StreamCopier("stdin", System.in, shell.getOutputStream()) + .bufSize(shell.getRemoteMaxPacketSize()) + .keepFlushing(true) + .copy(); } finally { session.close(); diff --git a/src/main/java/net/schmizz/sshj/common/StreamCopier.java b/src/main/java/net/schmizz/sshj/common/StreamCopier.java index 7382118e..c1eb799e 100644 --- a/src/main/java/net/schmizz/sshj/common/StreamCopier.java +++ b/src/main/java/net/schmizz/sshj/common/StreamCopier.java @@ -26,10 +26,18 @@ import java.io.OutputStream; public class StreamCopier extends Thread { - private static final Logger LOG = LoggerFactory.getLogger(StreamCopier.class); + private final Logger logger = LoggerFactory.getLogger(getClass()); public interface ErrorCallback { + void onError(IOException ioe); + + } + + public interface Listener { + + void reportProgress(long transferred); + } public static ErrorCallback closeOnErrorCallback(final Closeable... toClose) { @@ -41,42 +49,6 @@ public class StreamCopier }; } - public interface Listener { - void reportProgress(long transferred); - } - - public static long copy(InputStream in, OutputStream out, int bufSize, boolean keepFlushing, Listener listener) - throws IOException { - long count = 0; - - final boolean reportProgress = listener != null; - final long startTime = System.currentTimeMillis(); - - final byte[] buf = new byte[bufSize]; - int read; - while ((read = in.read(buf)) != -1) { - out.write(buf, 0, read); - count += read; - if (keepFlushing) - out.flush(); - if (reportProgress) - listener.reportProgress(count); - } - if (!keepFlushing) - out.flush(); - - final double sizeKiB = count / 1024.0; - final double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0; - LOG.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds)); - - return count; - } - - public static long copy(InputStream in, OutputStream out, int bufSize, boolean keepFlushing) - throws IOException { - return copy(in, out, bufSize, keepFlushing, null); - } - public static String copyStreamToString(InputStream stream) throws IOException { final StringBuilder sb = new StringBuilder(); @@ -86,19 +58,28 @@ public class StreamCopier return sb.toString(); } + private static final ErrorCallback NULL_CALLBACK = new ErrorCallback() { + @Override + public void onError(IOException ioe) { + } + }; + + private static final Listener NULL_LISTENER = new Listener() { + @Override + public void reportProgress(long transferred) { + } + }; + private final Logger log = LoggerFactory.getLogger(getClass()); private final InputStream in; private final OutputStream out; - private int bufSize = 1; - private boolean keepFlushing = true; + private boolean keepFlushing = false; + private long length = -1; - private ErrorCallback errCB = new ErrorCallback() { - @Override - public void onError(IOException ioe) { - } - }; // Default null cb + private Listener listener = NULL_LISTENER; + private ErrorCallback errCB = NULL_CALLBACK; public StreamCopier(String name, InputStream in, OutputStream out) { this.in = in; @@ -106,13 +87,30 @@ public class StreamCopier setName(name); } - public StreamCopier bufSize(int size) { - bufSize = size; + public StreamCopier bufSize(int bufSize) { + this.bufSize = bufSize; return this; } - public StreamCopier keepFlushing(boolean choice) { - keepFlushing = choice; + public StreamCopier keepFlushing(boolean keepFlushing) { + this.keepFlushing = keepFlushing; + return this; + } + + public StreamCopier listener(Listener listener) { + if (listener == null) listener = NULL_LISTENER; + this.listener = listener; + return this; + } + + public StreamCopier errorCallback(ErrorCallback errCB) { + if (errCB == null) errCB = NULL_CALLBACK; + this.errCB = errCB; + return this; + } + + public StreamCopier length(long length) { + this.length = length; return this; } @@ -121,16 +119,50 @@ public class StreamCopier return this; } - public StreamCopier errorCallback(ErrorCallback errCB) { - this.errCB = errCB; - return this; + public long copy() + throws IOException { + final byte[] buf = new byte[bufSize]; + long count = 0; + int read = 0; + + final long startTime = System.currentTimeMillis(); + + if (length == -1) { + while ((read = in.read(buf)) != -1) + count = write(buf, count, read); + } else { + while (count < length && (read = in.read(buf, 0, (int) Math.min(bufSize, length - count))) != -1) + count = write(buf, count, read); + } + + if (!keepFlushing) + out.flush(); + + final double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0; + final double sizeKiB = count / 1024.0; + logger.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds)); + + if (length != -1 && read == -1) + throw new IOException("Encountered EOF, could not transfer " + length + " bytes"); + + return count; + } + + private long write(byte[] buf, long count, int read) + throws IOException { + out.write(buf, 0, read); + count += read; + if (keepFlushing) + out.flush(); + listener.reportProgress(count); + return count; } @Override public void run() { try { log.debug("Wil pipe from {} to {}", in, out); - copy(in, out, bufSize, keepFlushing); + copy(); log.debug("EOF on {}", in); } catch (IOException ioe) { log.error("In pipe from {} to {}: " + ioe.toString(), in, out); diff --git a/src/main/java/net/schmizz/sshj/sftp/SFTPFileTransfer.java b/src/main/java/net/schmizz/sshj/sftp/SFTPFileTransfer.java index 0960e1a7..33011338 100644 --- a/src/main/java/net/schmizz/sshj/sftp/SFTPFileTransfer.java +++ b/src/main/java/net/schmizz/sshj/sftp/SFTPFileTransfer.java @@ -135,8 +135,10 @@ public class SFTPFileTransfer try { final OutputStream os = adjusted.getOutputStream(); try { - StreamCopier.copy(rf.getInputStream(), os, - engine.getSubsystem().getLocalMaxPacketSize(), false, listener); + new StreamCopier("sftp download", rf.getInputStream(), os) + .bufSize(engine.getSubsystem().getLocalMaxPacketSize()) + .listener(listener) + .copy(); } finally { os.close(); } @@ -195,8 +197,10 @@ public class SFTPFileTransfer try { final InputStream fis = local.getInputStream(); try { - final int bufSize = engine.getSubsystem().getRemoteMaxPacketSize() - rf.getOutgoingPacketOverhead(); - StreamCopier.copy(fis, rf.getOutputStream(), bufSize, false, listener); + new StreamCopier("sftp upload", fis, rf.getOutputStream()) + .bufSize(engine.getSubsystem().getRemoteMaxPacketSize() - rf.getOutgoingPacketOverhead()) + .listener(listener) + .copy(); } finally { fis.close(); } diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java index 22c0e215..0499fe4d 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java @@ -15,20 +15,20 @@ */ package net.schmizz.sshj.xfer.scp; +import net.schmizz.sshj.common.IOUtils; +import net.schmizz.sshj.common.SSHException; +import net.schmizz.sshj.common.StreamCopier; +import net.schmizz.sshj.connection.channel.direct.Session.Command; +import net.schmizz.sshj.connection.channel.direct.SessionFactory; +import net.schmizz.sshj.xfer.TransferListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.List; -import net.schmizz.sshj.common.IOUtils; -import net.schmizz.sshj.common.SSHException; -import net.schmizz.sshj.connection.channel.direct.Session.Command; -import net.schmizz.sshj.connection.channel.direct.SessionFactory; -import net.schmizz.sshj.xfer.TransferListener; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - /** @see SCP Protocol */ class SCPEngine { @@ -164,53 +164,38 @@ class SCPEngine { scp.getOutputStream().flush(); } - void transferToRemote(final InputStream src, final long length) - throws IOException { - transfer(src, scp.getOutputStream(), scp.getRemoteMaxPacketSize(), length); - } - - void transferFromRemote(final OutputStream dest, final long length) - throws IOException { - transfer(scp.getInputStream(), dest, scp.getLocalMaxPacketSize(), length); - } - - private void transfer(InputStream in, OutputStream out, int bufSize, long len) + long transferToRemote(final InputStream src, final long length) throws IOException { - final byte[] buf = new byte[bufSize]; - long count = 0; - int read = 0; + return transfer(src, scp.getOutputStream(), scp.getRemoteMaxPacketSize(), length); + } - final long startTime = System.currentTimeMillis(); + long transferFromRemote(final OutputStream dest, final long length) + throws IOException { + return transfer(scp.getInputStream(), dest, scp.getLocalMaxPacketSize(), length); + } - while (count < len && (read = in.read(buf, 0, (int) Math.min(bufSize, len - count))) != -1) { - out.write(buf, 0, read); - count += read; - listener.reportProgress(count); - } - out.flush(); - - final double sizeKiB = count / 1024.0; - final double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0; - log.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds)); - - if (read == -1) - throw new IOException("Had EOF before transfer completed"); + private long transfer(InputStream in, OutputStream out, int bufSize, long len) + throws IOException { + return new StreamCopier("scp engine", in, out) + .bufSize(bufSize).length(len) + .listener(listener) + .copy(); } void startedDir(final String dirname) { - listener.startedDir(dirname); + listener.startedDir(dirname); } - void finishedDir() { - listener.finishedDir(); - } + void finishedDir() { + listener.finishedDir(); + } - void startedFile(final String filename, final long length) { - listener.startedFile(filename, length); - } + void startedFile(final String filename, final long length) { + listener.startedFile(filename, length); + } - void finishedFile() { - listener.finishedFile(); - } + void finishedFile() { + listener.finishedFile(); + } }