diff --git a/src/main/java/examples/Exec.java b/src/main/java/examples/Exec.java index 38676bc1..d6273357 100644 --- a/src/main/java/examples/Exec.java +++ b/src/main/java/examples/Exec.java @@ -37,7 +37,7 @@ public class Exec { final Session session = ssh.startSession(); try { final Command cmd = session.exec("ping -c 1 google.com"); - System.out.println(IOUtils.pipeStream(cmd.getInputStream()).toString()); + System.out.println(IOUtils.readFully(cmd.getInputStream()).toString()); cmd.join(5, TimeUnit.SECONDS); System.out.println("\n** exit status: " + cmd.getExitStatus()); } finally { diff --git a/src/main/java/examples/RudimentaryPTY.java b/src/main/java/examples/RudimentaryPTY.java index 5f5cf9a5..4c6bb136 100644 --- a/src/main/java/examples/RudimentaryPTY.java +++ b/src/main/java/examples/RudimentaryPTY.java @@ -48,22 +48,19 @@ class RudimentaryPTY { final Shell shell = session.startShell(); - new StreamCopier("stdout", shell.getInputStream(), System.out) + new StreamCopier(shell.getInputStream(), System.out) .bufSize(shell.getLocalMaxPacketSize()) - .keepFlushing(true) - .start(); + .spawn("stdout"); - new StreamCopier("stderr", shell.getErrorStream(), System.err) + new StreamCopier(shell.getErrorStream(), System.err) .bufSize(shell.getLocalMaxPacketSize()) - .keepFlushing(true) - .start(); + .spawn("stderr"); // Now make System.in act as stdin. To exit, hit Ctrl+D (since that results in an EOF on System.in) // This is kinda messy because java only allows console input after you hit return // But this is just an example... a GUI app could implement a proper PTY - new StreamCopier("stdin", System.in, shell.getOutputStream()) + new StreamCopier(System.in, shell.getOutputStream()) .bufSize(shell.getRemoteMaxPacketSize()) - .keepFlushing(true) .copy(); } finally { diff --git a/src/main/java/examples/X11.java b/src/main/java/examples/X11.java index 653429a9..b02dc138 100644 --- a/src/main/java/examples/X11.java +++ b/src/main/java/examples/X11.java @@ -57,8 +57,8 @@ public class X11 { final Command cmd = sess.exec("/usr/X11/bin/xcalc"); - new StreamCopier("stdout", cmd.getInputStream(), System.out).start(); - new StreamCopier("stderr", cmd.getErrorStream(), System.err).start(); + new StreamCopier(cmd.getInputStream(), System.out).spawn("stdout"); + new StreamCopier(cmd.getErrorStream(), System.err).spawn("stderr"); // Wait for session & X11 channel to get closed ssh.getConnection().join(); diff --git a/src/main/java/net/schmizz/sshj/common/Buffer.java b/src/main/java/net/schmizz/sshj/common/Buffer.java index d258590f..7bbe43eb 100644 --- a/src/main/java/net/schmizz/sshj/common/Buffer.java +++ b/src/main/java/net/schmizz/sshj/common/Buffer.java @@ -427,20 +427,16 @@ public class Buffer> { } public T putString(String string) { - try { - return putString(string.getBytes("UTF-8")); - } catch (UnsupportedEncodingException e) { - throw new SSHRuntimeException(e); - } + return putString(string.getBytes(IOUtils.UTF8)); } /** * Writes a char-array as an SSH string and then blanks it out. *

- * This is useful when a plaintext password needs to be sent. If {@code passwd} is {@code null}, an empty string is + * This is useful when a plaintext password needs to be sent. If {@code str} is {@code null}, an empty string is * written. * - * @param str (null-ok) the password as a character array + * @param str (null-ok) the string as a character array * * @return this */ diff --git a/src/main/java/net/schmizz/sshj/common/IOUtils.java b/src/main/java/net/schmizz/sshj/common/IOUtils.java index 7909c005..273d9403 100644 --- a/src/main/java/net/schmizz/sshj/common/IOUtils.java +++ b/src/main/java/net/schmizz/sshj/common/IOUtils.java @@ -42,11 +42,14 @@ import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.IOException; import java.io.InputStream; +import java.nio.charset.Charset; public class IOUtils { private static final Logger LOG = LoggerFactory.getLogger(IOUtils.class); + public static final Charset UTF8 = Charset.forName("UTF-8"); + public static void closeQuietly(Closeable... closeables) { for (Closeable c : closeables) try { @@ -57,15 +60,11 @@ public class IOUtils { } } - public static ByteArrayOutputStream pipeStream(InputStream stream) + public static ByteArrayOutputStream readFully(InputStream stream) throws IOException { - final ByteArrayOutputStream bos = new ByteArrayOutputStream(); - byte[] buf = new byte[1024]; - int read; - while ((read = (stream.read(buf))) != -1) - bos.write(buf, 0, read); - bos.flush(); - return bos; + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + new StreamCopier(stream, baos).copy(); + return baos; } -} +} \ No newline at end of file diff --git a/src/main/java/net/schmizz/sshj/common/StreamCopier.java b/src/main/java/net/schmizz/sshj/common/StreamCopier.java index c21ca378..da38c462 100644 --- a/src/main/java/net/schmizz/sshj/common/StreamCopier.java +++ b/src/main/java/net/schmizz/sshj/common/StreamCopier.java @@ -15,6 +15,8 @@ */ package net.schmizz.sshj.common; +import net.schmizz.concurrent.Event; +import net.schmizz.concurrent.ExceptionChainer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -22,11 +24,9 @@ import java.io.Closeable; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.concurrent.TimeUnit; -public class StreamCopier - extends Thread { - - private final Logger logger = LoggerFactory.getLogger(getClass()); +public class StreamCopier { public interface ErrorCallback { @@ -36,7 +36,7 @@ public class StreamCopier public interface Listener { - void reportProgress(long transferred); + void reportProgress(long transferred) throws IOException; } @@ -65,17 +65,25 @@ public class StreamCopier private final InputStream in; private final OutputStream out; - private int bufSize = 1; - private boolean keepFlushing = false; - private long length = -1; private Listener listener = NULL_LISTENER; private ErrorCallback errCB = NULL_CALLBACK; - public StreamCopier(String name, InputStream in, OutputStream out) { + private int bufSize = 1; + private boolean keepFlushing = true; + private long length = -1; + + private final Event doneEvent = + new Event("copyDone", new ExceptionChainer() { + @Override + public IOException chain(Throwable t) { + return (t instanceof IOException) ? (IOException) t : new IOException(t); + } + }); + + public StreamCopier(InputStream in, OutputStream out) { this.in = in; this.out = out; - setName(name); } public StreamCopier bufSize(int bufSize) { @@ -105,8 +113,41 @@ public class StreamCopier return this; } - public StreamCopier daemon(boolean choice) { - setDaemon(choice); + public StreamCopier spawn(String name) { + return spawn(name, false); + } + + public StreamCopier spawnDaemon(String name) { + return spawn(name, true); + } + + private StreamCopier spawn(final String name, final boolean daemon) { + new Thread() { + { + setName(name); + setDaemon(daemon); + } + + @Override + public void run() { + try { + log.debug("Will copy from {} to {}", in, out); + copy(); + log.debug("Done copying from {}", in); + doneEvent.set(); + } catch (IOException ioe) { + log.error("In pipe from {} to {}: {}" + ioe.toString(), in, out); + doneEvent.error(ioe); + errCB.onError(ioe); + } + } + }.start(); + return this; + } + + public StreamCopier join(int timeout, TimeUnit unit) + throws IOException { + doneEvent.await(timeout, unit); return this; } @@ -131,7 +172,7 @@ public class StreamCopier final double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0; final double sizeKiB = count / 1024.0; - logger.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds)); + log.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds)); if (length != -1 && read == -1) throw new IOException("Encountered EOF, could not transfer " + length + " bytes"); @@ -149,16 +190,4 @@ public class StreamCopier return count; } - @Override - public void run() { - try { - log.debug("Wil pipe from {} to {}", in, out); - copy(); - log.debug("EOF on {}", in); - } catch (IOException ioe) { - log.error("In pipe from {} to {}: " + ioe.toString(), in, out); - errCB.onError(ioe); - } - } - } \ No newline at end of file diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java index 53ca6c06..9821a377 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java @@ -55,17 +55,15 @@ public class LocalPortForwarder { } }); - new StreamCopier("chan2soc", getInputStream(), sock.getOutputStream()) + new StreamCopier(getInputStream(), sock.getOutputStream()) .bufSize(getLocalMaxPacketSize()) .errorCallback(closer) - .daemon(true) - .start(); + .spawnDaemon("chan2soc"); - new StreamCopier("soc2chan", sock.getInputStream(), getOutputStream()) + new StreamCopier(sock.getInputStream(), getOutputStream()) .bufSize(getRemoteMaxPacketSize()) .errorCallback(closer) - .daemon(true) - .start(); + .spawnDaemon("soc2chan"); } @Override diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java index b25f775e..c7e1b801 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java @@ -247,13 +247,13 @@ public class SessionChannel @Override @Deprecated public String getOutputAsString() throws IOException { - return IOUtils.pipeStream(getInputStream()).toString(); + return IOUtils.readFully(getInputStream()).toString(); } @Override @Deprecated public String getErrorAsString() throws IOException { - return IOUtils.pipeStream(getErrorStream()).toString(); + return IOUtils.readFully(getErrorStream()).toString(); } } diff --git a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/SocketForwardingConnectListener.java b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/SocketForwardingConnectListener.java index 123dcb69..3bfe8476 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/SocketForwardingConnectListener.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/SocketForwardingConnectListener.java @@ -62,17 +62,15 @@ public class SocketForwardingConnectListener } }); - new StreamCopier("soc2chan", sock.getInputStream(), chan.getOutputStream()) + new StreamCopier(sock.getInputStream(), chan.getOutputStream()) .bufSize(chan.getRemoteMaxPacketSize()) .errorCallback(closer) - .daemon(true) - .start(); + .spawnDaemon("soc2chan"); - new StreamCopier("chan2soc", chan.getInputStream(), sock.getOutputStream()) + new StreamCopier(chan.getInputStream(), sock.getOutputStream()) .bufSize(chan.getLocalMaxPacketSize()) .errorCallback(closer) - .daemon(true) - .start(); + .spawnDaemon("chan2soc"); } } \ No newline at end of file diff --git a/src/main/java/net/schmizz/sshj/sftp/SFTPFileTransfer.java b/src/main/java/net/schmizz/sshj/sftp/SFTPFileTransfer.java index 33011338..0d57aea1 100644 --- a/src/main/java/net/schmizz/sshj/sftp/SFTPFileTransfer.java +++ b/src/main/java/net/schmizz/sshj/sftp/SFTPFileTransfer.java @@ -135,8 +135,9 @@ public class SFTPFileTransfer try { final OutputStream os = adjusted.getOutputStream(); try { - new StreamCopier("sftp download", rf.getInputStream(), os) + new StreamCopier(rf.getInputStream(), os) .bufSize(engine.getSubsystem().getLocalMaxPacketSize()) + .keepFlushing(false) .listener(listener) .copy(); } finally { @@ -197,8 +198,9 @@ public class SFTPFileTransfer try { final InputStream fis = local.getInputStream(); try { - new StreamCopier("sftp upload", fis, rf.getOutputStream()) + new StreamCopier(fis, rf.getOutputStream()) .bufSize(engine.getSubsystem().getRemoteMaxPacketSize() - rf.getOutgoingPacketOverhead()) + .keepFlushing(false) .listener(listener) .copy(); } finally { diff --git a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java index 3d28bdc1..48fccac0 100644 --- a/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java +++ b/src/main/java/net/schmizz/sshj/transport/KeyExchanger.java @@ -37,7 +37,17 @@ package net.schmizz.sshj.transport; import net.schmizz.concurrent.Event; import net.schmizz.concurrent.FutureUtils; -import net.schmizz.sshj.common.*; +import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.common.DisconnectReason; +import net.schmizz.sshj.common.ErrorNotifiable; +import net.schmizz.sshj.common.Factory; +import net.schmizz.sshj.common.IOUtils; +import net.schmizz.sshj.common.KeyType; +import net.schmizz.sshj.common.Message; +import net.schmizz.sshj.common.SSHException; +import net.schmizz.sshj.common.SSHPacket; +import net.schmizz.sshj.common.SSHPacketHandler; +import net.schmizz.sshj.common.SecurityUtils; import net.schmizz.sshj.transport.cipher.Cipher; import net.schmizz.sshj.transport.compression.Compression; import net.schmizz.sshj.transport.digest.Digest; @@ -92,7 +102,8 @@ final class KeyExchanger private Proposal clientProposal; private NegotiatedAlgorithms negotiatedAlgs; - private final Event kexInitSent = new Event("kexinit sent", TransportException.chainer); + private final Event kexInitSent = + new Event("kexinit sent", TransportException.chainer); private final Event done; @@ -208,11 +219,11 @@ final class KeyExchanger return; } - throw new TransportException(DisconnectReason.HOST_KEY_NOT_VERIFIABLE, "Could not verify `" - + KeyType - .fromKey(key) + "` host key with fingerprint `" + SecurityUtils.getFingerprint(key) - + "` for `" + transport - .getRemoteHost() + "` on port " + transport.getRemotePort()); + throw new TransportException(DisconnectReason.HOST_KEY_NOT_VERIFIABLE, + "Could not verify `" + KeyType.fromKey(key) + + "` host key with fingerprint `" + SecurityUtils.getFingerprint(key) + + "` for `" + transport.getRemoteHost() + + "` on port " + transport.getRemotePort()); } private void setKexDone() { @@ -229,8 +240,11 @@ final class KeyExchanger kex = Factory.Named.Util.create(transport.getConfig().getKeyExchangeFactories(), negotiatedAlgs .getKeyExchangeAlgorithm()); try { - kex.init(transport, transport.getServerID().getBytes(), transport.getClientID().getBytes(), buf - .getCompactData(), clientProposal.getPacket().getCompactData()); + kex.init(transport, + transport.getServerID().getBytes(IOUtils.UTF8), + transport.getClientID().getBytes(IOUtils.UTF8), + buf.getCompactData(), + clientProposal.getPacket().getCompactData()); } catch (GeneralSecurityException e) { throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, e); } @@ -321,10 +335,12 @@ final class KeyExchanger negotiatedAlgs.getServer2ClientMACAlgorithm()); mac_S2C.init(integrityKey_S2C); - final Compression compression_S2C = Factory.Named.Util.create(transport.getConfig().getCompressionFactories(), - negotiatedAlgs.getServer2ClientCompressionAlgorithm()); - final Compression compression_C2S = Factory.Named.Util.create(transport.getConfig().getCompressionFactories(), - negotiatedAlgs.getClient2ServerCompressionAlgorithm()); + final Compression compression_S2C = + Factory.Named.Util.create(transport.getConfig().getCompressionFactories(), + negotiatedAlgs.getServer2ClientCompressionAlgorithm()); + final Compression compression_C2S = + Factory.Named.Util.create(transport.getConfig().getCompressionFactories(), + negotiatedAlgs.getClient2ServerCompressionAlgorithm()); transport.getEncoder().setAlgorithms(cipher_C2S, mac_C2S, compression_C2S); transport.getDecoder().setAlgorithms(cipher_S2C, mac_S2C, compression_S2C); diff --git a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java index 4c43524b..9fcb3a4e 100644 --- a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java +++ b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java @@ -148,7 +148,7 @@ public final class TransportImpl try { log.info("Client identity string: {}", clientID); - connInfo.out.write((clientID + "\r\n").getBytes()); + connInfo.out.write((clientID + "\r\n").getBytes(IOUtils.UTF8)); // Read server's ID final Buffer.PlainBuffer buf = new Buffer.PlainBuffer(); diff --git a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java index 9a2d485e..1e657cb6 100644 --- a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java +++ b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java @@ -196,7 +196,7 @@ public class OpenSSHKnownHosts private String hashHost(String host) throws IOException { sha1.init(getSaltyBytes()); - return "|1|" + getSalt() + "|" + Base64.encodeBytes(sha1.doFinal(host.getBytes())); + return "|1|" + getSalt() + "|" + Base64.encodeBytes(sha1.doFinal(host.getBytes(IOUtils.UTF8))); } private byte[] getSaltyBytes() @@ -289,7 +289,7 @@ public class OpenSSHKnownHosts final BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(khFile)); try { for (Entry entry : entries) - bos.write((entry.getLine() + LS).getBytes()); + bos.write((entry.getLine() + LS).getBytes(IOUtils.UTF8)); } finally { bos.close(); } diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java index 55db6107..92255197 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java @@ -69,10 +69,10 @@ public final class SCPDownloadClient { engine.signal("Start status OK"); - String msg = engine.readMessage(true); + String msg = engine.readMessage(); do process(null, msg, targetFile); - while ((msg = engine.readMessage(false)) != null); + while (!(msg = engine.readMessage()).isEmpty()); } private long parseLong(String longString, String valType) @@ -102,7 +102,7 @@ public final class SCPDownloadClient { case 'T': engine.signal("ACK: T"); - process(msg, engine.readMessage(true), f); + process(msg, engine.readMessage(), f); break; case 'C': diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java index 15652083..b7fc94e0 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java @@ -24,6 +24,7 @@ import net.schmizz.sshj.xfer.TransferListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -77,7 +78,7 @@ class SCPEngine { int code = scp.getInputStream().read(); switch (code) { case -1: - String stderr = IOUtils.pipeStream(scp.getErrorStream()).toString(); + String stderr = IOUtils.readFully(scp.getErrorStream()).toString(); if (!stderr.isEmpty()) stderr = ". Additional info: `" + stderr + "`"; throw new SCPException("EOF while expecting response to protocol message" + stderr); @@ -98,7 +99,7 @@ class SCPEngine { void execSCPWith(List args, String path) throws SSHException { - StringBuilder cmd = new StringBuilder(SCP_COMMAND); + final StringBuilder cmd = new StringBuilder(SCP_COMMAND); for (Arg arg : args) cmd.append(" ").append(arg); cmd.append(" "); @@ -130,29 +131,25 @@ class SCPEngine { String readMessage() throws IOException { - return readMessage(true); - } - - String readMessage(boolean errOnEOF) - throws IOException { - StringBuilder sb = new StringBuilder(); + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); int x; while ((x = scp.getInputStream().read()) != LF) if (x == -1) { - if (errOnEOF) - throw new IOException("EOF while reading message"); + if (baos.size() == 0) + return ""; else - return null; + throw new IOException("EOF while reading message"); } else - sb.append((char) x); - log.debug("Read message: {}", sb); - return sb.toString(); + baos.write(x); + final String msg = baos.toString(IOUtils.UTF8.displayName()); + log.debug("Read message: `{}`", msg); + return msg; } void sendMessage(String msg) throws IOException { log.debug("Sending message: {}", msg); - scp.getOutputStream().write((msg + LF).getBytes()); + scp.getOutputStream().write((msg + LF).getBytes(IOUtils.UTF8)); scp.getOutputStream().flush(); check("Message ACK received"); } @@ -164,25 +161,26 @@ class SCPEngine { scp.getOutputStream().flush(); } - long transferToRemote(final InputStream src, final long length) + long transferToRemote(InputStream src, long length) throws IOException { return transfer(src, scp.getOutputStream(), scp.getRemoteMaxPacketSize(), length); } - long transferFromRemote(final OutputStream dest, final long length) + long transferFromRemote(OutputStream dest, long length) throws IOException { return transfer(scp.getInputStream(), dest, scp.getLocalMaxPacketSize(), length); } private long transfer(InputStream in, OutputStream out, int bufSize, long len) throws IOException { - return new StreamCopier("scp engine", in, out) + return new StreamCopier(in, out) .bufSize(bufSize).length(len) + .keepFlushing(false) .listener(listener) .copy(); } - void startedDir(final String dirname) { + void startedDir(String dirname) { listener.startedDir(dirname); } @@ -190,7 +188,7 @@ class SCPEngine { listener.finishedDir(); } - void startedFile(final String filename, final long length) { + void startedFile(String filename, long length) { listener.startedFile(filename, length); }