Consolidated stream copying logic

This commit is contained in:
Shikhar Bhushan
2011-04-24 17:21:15 +01:00
parent f89c0cc2f0
commit ab705d7f2a
4 changed files with 130 additions and 104 deletions

View File

@@ -50,16 +50,21 @@ class RudimentaryPTY {
new StreamCopier("stdout", shell.getInputStream(), System.out)
.bufSize(shell.getLocalMaxPacketSize())
.keepFlushing(true)
.start();
new StreamCopier("stderr", shell.getErrorStream(), System.err)
.bufSize(shell.getLocalMaxPacketSize())
.keepFlushing(true)
.start();
// Now make System.in act as stdin. To exit, hit Ctrl+D (since that results in an EOF on System.in)
// This is kinda messy because java only allows console input after you hit return
// But this is just an example... a GUI app could implement a proper PTY
StreamCopier.copy(System.in, shell.getOutputStream(), shell.getRemoteMaxPacketSize(), true);
new StreamCopier("stdin", System.in, shell.getOutputStream())
.bufSize(shell.getRemoteMaxPacketSize())
.keepFlushing(true)
.copy();
} finally {
session.close();

View File

@@ -26,10 +26,18 @@ import java.io.OutputStream;
public class StreamCopier
extends Thread {
private static final Logger LOG = LoggerFactory.getLogger(StreamCopier.class);
private final Logger logger = LoggerFactory.getLogger(getClass());
public interface ErrorCallback {
void onError(IOException ioe);
}
public interface Listener {
void reportProgress(long transferred);
}
public static ErrorCallback closeOnErrorCallback(final Closeable... toClose) {
@@ -41,42 +49,6 @@ public class StreamCopier
};
}
public interface Listener {
void reportProgress(long transferred);
}
public static long copy(InputStream in, OutputStream out, int bufSize, boolean keepFlushing, Listener listener)
throws IOException {
long count = 0;
final boolean reportProgress = listener != null;
final long startTime = System.currentTimeMillis();
final byte[] buf = new byte[bufSize];
int read;
while ((read = in.read(buf)) != -1) {
out.write(buf, 0, read);
count += read;
if (keepFlushing)
out.flush();
if (reportProgress)
listener.reportProgress(count);
}
if (!keepFlushing)
out.flush();
final double sizeKiB = count / 1024.0;
final double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0;
LOG.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds));
return count;
}
public static long copy(InputStream in, OutputStream out, int bufSize, boolean keepFlushing)
throws IOException {
return copy(in, out, bufSize, keepFlushing, null);
}
public static String copyStreamToString(InputStream stream)
throws IOException {
final StringBuilder sb = new StringBuilder();
@@ -86,19 +58,28 @@ public class StreamCopier
return sb.toString();
}
private static final ErrorCallback NULL_CALLBACK = new ErrorCallback() {
@Override
public void onError(IOException ioe) {
}
};
private static final Listener NULL_LISTENER = new Listener() {
@Override
public void reportProgress(long transferred) {
}
};
private final Logger log = LoggerFactory.getLogger(getClass());
private final InputStream in;
private final OutputStream out;
private int bufSize = 1;
private boolean keepFlushing = true;
private boolean keepFlushing = false;
private long length = -1;
private ErrorCallback errCB = new ErrorCallback() {
@Override
public void onError(IOException ioe) {
}
}; // Default null cb
private Listener listener = NULL_LISTENER;
private ErrorCallback errCB = NULL_CALLBACK;
public StreamCopier(String name, InputStream in, OutputStream out) {
this.in = in;
@@ -106,13 +87,30 @@ public class StreamCopier
setName(name);
}
public StreamCopier bufSize(int size) {
bufSize = size;
public StreamCopier bufSize(int bufSize) {
this.bufSize = bufSize;
return this;
}
public StreamCopier keepFlushing(boolean choice) {
keepFlushing = choice;
public StreamCopier keepFlushing(boolean keepFlushing) {
this.keepFlushing = keepFlushing;
return this;
}
public StreamCopier listener(Listener listener) {
if (listener == null) listener = NULL_LISTENER;
this.listener = listener;
return this;
}
public StreamCopier errorCallback(ErrorCallback errCB) {
if (errCB == null) errCB = NULL_CALLBACK;
this.errCB = errCB;
return this;
}
public StreamCopier length(long length) {
this.length = length;
return this;
}
@@ -121,16 +119,50 @@ public class StreamCopier
return this;
}
public StreamCopier errorCallback(ErrorCallback errCB) {
this.errCB = errCB;
return this;
public long copy()
throws IOException {
final byte[] buf = new byte[bufSize];
long count = 0;
int read = 0;
final long startTime = System.currentTimeMillis();
if (length == -1) {
while ((read = in.read(buf)) != -1)
count = write(buf, count, read);
} else {
while (count < length && (read = in.read(buf, 0, (int) Math.min(bufSize, length - count))) != -1)
count = write(buf, count, read);
}
if (!keepFlushing)
out.flush();
final double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0;
final double sizeKiB = count / 1024.0;
logger.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds));
if (length != -1 && read == -1)
throw new IOException("Encountered EOF, could not transfer " + length + " bytes");
return count;
}
private long write(byte[] buf, long count, int read)
throws IOException {
out.write(buf, 0, read);
count += read;
if (keepFlushing)
out.flush();
listener.reportProgress(count);
return count;
}
@Override
public void run() {
try {
log.debug("Wil pipe from {} to {}", in, out);
copy(in, out, bufSize, keepFlushing);
copy();
log.debug("EOF on {}", in);
} catch (IOException ioe) {
log.error("In pipe from {} to {}: " + ioe.toString(), in, out);

View File

@@ -135,8 +135,10 @@ public class SFTPFileTransfer
try {
final OutputStream os = adjusted.getOutputStream();
try {
StreamCopier.copy(rf.getInputStream(), os,
engine.getSubsystem().getLocalMaxPacketSize(), false, listener);
new StreamCopier("sftp download", rf.getInputStream(), os)
.bufSize(engine.getSubsystem().getLocalMaxPacketSize())
.listener(listener)
.copy();
} finally {
os.close();
}
@@ -195,8 +197,10 @@ public class SFTPFileTransfer
try {
final InputStream fis = local.getInputStream();
try {
final int bufSize = engine.getSubsystem().getRemoteMaxPacketSize() - rf.getOutgoingPacketOverhead();
StreamCopier.copy(fis, rf.getOutputStream(), bufSize, false, listener);
new StreamCopier("sftp upload", fis, rf.getOutputStream())
.bufSize(engine.getSubsystem().getRemoteMaxPacketSize() - rf.getOutgoingPacketOverhead())
.listener(listener)
.copy();
} finally {
fis.close();
}

View File

@@ -15,20 +15,20 @@
*/
package net.schmizz.sshj.xfer.scp;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.common.StreamCopier;
import net.schmizz.sshj.connection.channel.direct.Session.Command;
import net.schmizz.sshj.connection.channel.direct.SessionFactory;
import net.schmizz.sshj.xfer.TransferListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.List;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.connection.channel.direct.Session.Command;
import net.schmizz.sshj.connection.channel.direct.SessionFactory;
import net.schmizz.sshj.xfer.TransferListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** @see <a href="http://blogs.sun.com/janp/entry/how_the_scp_protocol_works">SCP Protocol</a> */
class SCPEngine {
@@ -164,53 +164,38 @@ class SCPEngine {
scp.getOutputStream().flush();
}
void transferToRemote(final InputStream src, final long length)
throws IOException {
transfer(src, scp.getOutputStream(), scp.getRemoteMaxPacketSize(), length);
}
void transferFromRemote(final OutputStream dest, final long length)
throws IOException {
transfer(scp.getInputStream(), dest, scp.getLocalMaxPacketSize(), length);
}
private void transfer(InputStream in, OutputStream out, int bufSize, long len)
long transferToRemote(final InputStream src, final long length)
throws IOException {
final byte[] buf = new byte[bufSize];
long count = 0;
int read = 0;
return transfer(src, scp.getOutputStream(), scp.getRemoteMaxPacketSize(), length);
}
final long startTime = System.currentTimeMillis();
long transferFromRemote(final OutputStream dest, final long length)
throws IOException {
return transfer(scp.getInputStream(), dest, scp.getLocalMaxPacketSize(), length);
}
while (count < len && (read = in.read(buf, 0, (int) Math.min(bufSize, len - count))) != -1) {
out.write(buf, 0, read);
count += read;
listener.reportProgress(count);
}
out.flush();
final double sizeKiB = count / 1024.0;
final double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0;
log.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds));
if (read == -1)
throw new IOException("Had EOF before transfer completed");
private long transfer(InputStream in, OutputStream out, int bufSize, long len)
throws IOException {
return new StreamCopier("scp engine", in, out)
.bufSize(bufSize).length(len)
.listener(listener)
.copy();
}
void startedDir(final String dirname) {
listener.startedDir(dirname);
listener.startedDir(dirname);
}
void finishedDir() {
listener.finishedDir();
}
void finishedDir() {
listener.finishedDir();
}
void startedFile(final String filename, final long length) {
listener.startedFile(filename, length);
}
void startedFile(final String filename, final long length) {
listener.startedFile(filename, length);
}
void finishedFile() {
listener.finishedFile();
}
void finishedFile() {
listener.finishedFile();
}
}