misc cleanups

This commit is contained in:
Shikhar Bhushan
2011-04-25 15:17:54 +01:00
parent 4de741359e
commit 85abcb7aad
15 changed files with 140 additions and 107 deletions

View File

@@ -37,7 +37,7 @@ public class Exec {
final Session session = ssh.startSession(); final Session session = ssh.startSession();
try { try {
final Command cmd = session.exec("ping -c 1 google.com"); final Command cmd = session.exec("ping -c 1 google.com");
System.out.println(IOUtils.pipeStream(cmd.getInputStream()).toString()); System.out.println(IOUtils.readFully(cmd.getInputStream()).toString());
cmd.join(5, TimeUnit.SECONDS); cmd.join(5, TimeUnit.SECONDS);
System.out.println("\n** exit status: " + cmd.getExitStatus()); System.out.println("\n** exit status: " + cmd.getExitStatus());
} finally { } finally {

View File

@@ -48,22 +48,19 @@ class RudimentaryPTY {
final Shell shell = session.startShell(); final Shell shell = session.startShell();
new StreamCopier("stdout", shell.getInputStream(), System.out) new StreamCopier(shell.getInputStream(), System.out)
.bufSize(shell.getLocalMaxPacketSize()) .bufSize(shell.getLocalMaxPacketSize())
.keepFlushing(true) .spawn("stdout");
.start();
new StreamCopier("stderr", shell.getErrorStream(), System.err) new StreamCopier(shell.getErrorStream(), System.err)
.bufSize(shell.getLocalMaxPacketSize()) .bufSize(shell.getLocalMaxPacketSize())
.keepFlushing(true) .spawn("stderr");
.start();
// Now make System.in act as stdin. To exit, hit Ctrl+D (since that results in an EOF on System.in) // Now make System.in act as stdin. To exit, hit Ctrl+D (since that results in an EOF on System.in)
// This is kinda messy because java only allows console input after you hit return // This is kinda messy because java only allows console input after you hit return
// But this is just an example... a GUI app could implement a proper PTY // But this is just an example... a GUI app could implement a proper PTY
new StreamCopier("stdin", System.in, shell.getOutputStream()) new StreamCopier(System.in, shell.getOutputStream())
.bufSize(shell.getRemoteMaxPacketSize()) .bufSize(shell.getRemoteMaxPacketSize())
.keepFlushing(true)
.copy(); .copy();
} finally { } finally {

View File

@@ -57,8 +57,8 @@ public class X11 {
final Command cmd = sess.exec("/usr/X11/bin/xcalc"); final Command cmd = sess.exec("/usr/X11/bin/xcalc");
new StreamCopier("stdout", cmd.getInputStream(), System.out).start(); new StreamCopier(cmd.getInputStream(), System.out).spawn("stdout");
new StreamCopier("stderr", cmd.getErrorStream(), System.err).start(); new StreamCopier(cmd.getErrorStream(), System.err).spawn("stderr");
// Wait for session & X11 channel to get closed // Wait for session & X11 channel to get closed
ssh.getConnection().join(); ssh.getConnection().join();

View File

@@ -427,20 +427,16 @@ public class Buffer<T extends Buffer<T>> {
} }
public T putString(String string) { public T putString(String string) {
try { return putString(string.getBytes(IOUtils.UTF8));
return putString(string.getBytes("UTF-8"));
} catch (UnsupportedEncodingException e) {
throw new SSHRuntimeException(e);
}
} }
/** /**
* Writes a char-array as an SSH string and then blanks it out. * Writes a char-array as an SSH string and then blanks it out.
* <p/> * <p/>
* This is useful when a plaintext password needs to be sent. If {@code passwd} is {@code null}, an empty string is * This is useful when a plaintext password needs to be sent. If {@code str} is {@code null}, an empty string is
* written. * written.
* *
* @param str (null-ok) the password as a character array * @param str (null-ok) the string as a character array
* *
* @return this * @return this
*/ */

View File

@@ -42,11 +42,14 @@ import java.io.ByteArrayOutputStream;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.nio.charset.Charset;
public class IOUtils { public class IOUtils {
private static final Logger LOG = LoggerFactory.getLogger(IOUtils.class); private static final Logger LOG = LoggerFactory.getLogger(IOUtils.class);
public static final Charset UTF8 = Charset.forName("UTF-8");
public static void closeQuietly(Closeable... closeables) { public static void closeQuietly(Closeable... closeables) {
for (Closeable c : closeables) for (Closeable c : closeables)
try { try {
@@ -57,15 +60,11 @@ public class IOUtils {
} }
} }
public static ByteArrayOutputStream pipeStream(InputStream stream) public static ByteArrayOutputStream readFully(InputStream stream)
throws IOException { throws IOException {
final ByteArrayOutputStream bos = new ByteArrayOutputStream(); final ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] buf = new byte[1024]; new StreamCopier(stream, baos).copy();
int read; return baos;
while ((read = (stream.read(buf))) != -1)
bos.write(buf, 0, read);
bos.flush();
return bos;
} }
} }

View File

@@ -15,6 +15,8 @@
*/ */
package net.schmizz.sshj.common; package net.schmizz.sshj.common;
import net.schmizz.concurrent.Event;
import net.schmizz.concurrent.ExceptionChainer;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@@ -22,11 +24,9 @@ import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.concurrent.TimeUnit;
public class StreamCopier public class StreamCopier {
extends Thread {
private final Logger logger = LoggerFactory.getLogger(getClass());
public interface ErrorCallback { public interface ErrorCallback {
@@ -36,7 +36,7 @@ public class StreamCopier
public interface Listener { public interface Listener {
void reportProgress(long transferred); void reportProgress(long transferred) throws IOException;
} }
@@ -65,17 +65,25 @@ public class StreamCopier
private final InputStream in; private final InputStream in;
private final OutputStream out; private final OutputStream out;
private int bufSize = 1;
private boolean keepFlushing = false;
private long length = -1;
private Listener listener = NULL_LISTENER; private Listener listener = NULL_LISTENER;
private ErrorCallback errCB = NULL_CALLBACK; private ErrorCallback errCB = NULL_CALLBACK;
public StreamCopier(String name, InputStream in, OutputStream out) { private int bufSize = 1;
private boolean keepFlushing = true;
private long length = -1;
private final Event<IOException> doneEvent =
new Event<IOException>("copyDone", new ExceptionChainer<IOException>() {
@Override
public IOException chain(Throwable t) {
return (t instanceof IOException) ? (IOException) t : new IOException(t);
}
});
public StreamCopier(InputStream in, OutputStream out) {
this.in = in; this.in = in;
this.out = out; this.out = out;
setName(name);
} }
public StreamCopier bufSize(int bufSize) { public StreamCopier bufSize(int bufSize) {
@@ -105,8 +113,41 @@ public class StreamCopier
return this; return this;
} }
public StreamCopier daemon(boolean choice) { public StreamCopier spawn(String name) {
setDaemon(choice); return spawn(name, false);
}
public StreamCopier spawnDaemon(String name) {
return spawn(name, true);
}
private StreamCopier spawn(final String name, final boolean daemon) {
new Thread() {
{
setName(name);
setDaemon(daemon);
}
@Override
public void run() {
try {
log.debug("Will copy from {} to {}", in, out);
copy();
log.debug("Done copying from {}", in);
doneEvent.set();
} catch (IOException ioe) {
log.error("In pipe from {} to {}: {}" + ioe.toString(), in, out);
doneEvent.error(ioe);
errCB.onError(ioe);
}
}
}.start();
return this;
}
public StreamCopier join(int timeout, TimeUnit unit)
throws IOException {
doneEvent.await(timeout, unit);
return this; return this;
} }
@@ -131,7 +172,7 @@ public class StreamCopier
final double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0; final double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0;
final double sizeKiB = count / 1024.0; final double sizeKiB = count / 1024.0;
logger.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds)); log.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds));
if (length != -1 && read == -1) if (length != -1 && read == -1)
throw new IOException("Encountered EOF, could not transfer " + length + " bytes"); throw new IOException("Encountered EOF, could not transfer " + length + " bytes");
@@ -149,16 +190,4 @@ public class StreamCopier
return count; return count;
} }
@Override
public void run() {
try {
log.debug("Wil pipe from {} to {}", in, out);
copy();
log.debug("EOF on {}", in);
} catch (IOException ioe) {
log.error("In pipe from {} to {}: " + ioe.toString(), in, out);
errCB.onError(ioe);
}
}
} }

View File

@@ -55,17 +55,15 @@ public class LocalPortForwarder {
} }
}); });
new StreamCopier("chan2soc", getInputStream(), sock.getOutputStream()) new StreamCopier(getInputStream(), sock.getOutputStream())
.bufSize(getLocalMaxPacketSize()) .bufSize(getLocalMaxPacketSize())
.errorCallback(closer) .errorCallback(closer)
.daemon(true) .spawnDaemon("chan2soc");
.start();
new StreamCopier("soc2chan", sock.getInputStream(), getOutputStream()) new StreamCopier(sock.getInputStream(), getOutputStream())
.bufSize(getRemoteMaxPacketSize()) .bufSize(getRemoteMaxPacketSize())
.errorCallback(closer) .errorCallback(closer)
.daemon(true) .spawnDaemon("soc2chan");
.start();
} }
@Override @Override

View File

@@ -247,13 +247,13 @@ public class SessionChannel
@Override @Override
@Deprecated @Deprecated
public String getOutputAsString() throws IOException { public String getOutputAsString() throws IOException {
return IOUtils.pipeStream(getInputStream()).toString(); return IOUtils.readFully(getInputStream()).toString();
} }
@Override @Override
@Deprecated @Deprecated
public String getErrorAsString() throws IOException { public String getErrorAsString() throws IOException {
return IOUtils.pipeStream(getErrorStream()).toString(); return IOUtils.readFully(getErrorStream()).toString();
} }
} }

View File

@@ -62,17 +62,15 @@ public class SocketForwardingConnectListener
} }
}); });
new StreamCopier("soc2chan", sock.getInputStream(), chan.getOutputStream()) new StreamCopier(sock.getInputStream(), chan.getOutputStream())
.bufSize(chan.getRemoteMaxPacketSize()) .bufSize(chan.getRemoteMaxPacketSize())
.errorCallback(closer) .errorCallback(closer)
.daemon(true) .spawnDaemon("soc2chan");
.start();
new StreamCopier("chan2soc", chan.getInputStream(), sock.getOutputStream()) new StreamCopier(chan.getInputStream(), sock.getOutputStream())
.bufSize(chan.getLocalMaxPacketSize()) .bufSize(chan.getLocalMaxPacketSize())
.errorCallback(closer) .errorCallback(closer)
.daemon(true) .spawnDaemon("chan2soc");
.start();
} }
} }

View File

@@ -135,8 +135,9 @@ public class SFTPFileTransfer
try { try {
final OutputStream os = adjusted.getOutputStream(); final OutputStream os = adjusted.getOutputStream();
try { try {
new StreamCopier("sftp download", rf.getInputStream(), os) new StreamCopier(rf.getInputStream(), os)
.bufSize(engine.getSubsystem().getLocalMaxPacketSize()) .bufSize(engine.getSubsystem().getLocalMaxPacketSize())
.keepFlushing(false)
.listener(listener) .listener(listener)
.copy(); .copy();
} finally { } finally {
@@ -197,8 +198,9 @@ public class SFTPFileTransfer
try { try {
final InputStream fis = local.getInputStream(); final InputStream fis = local.getInputStream();
try { try {
new StreamCopier("sftp upload", fis, rf.getOutputStream()) new StreamCopier(fis, rf.getOutputStream())
.bufSize(engine.getSubsystem().getRemoteMaxPacketSize() - rf.getOutgoingPacketOverhead()) .bufSize(engine.getSubsystem().getRemoteMaxPacketSize() - rf.getOutgoingPacketOverhead())
.keepFlushing(false)
.listener(listener) .listener(listener)
.copy(); .copy();
} finally { } finally {

View File

@@ -37,7 +37,17 @@ package net.schmizz.sshj.transport;
import net.schmizz.concurrent.Event; import net.schmizz.concurrent.Event;
import net.schmizz.concurrent.FutureUtils; import net.schmizz.concurrent.FutureUtils;
import net.schmizz.sshj.common.*; import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.DisconnectReason;
import net.schmizz.sshj.common.ErrorNotifiable;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.KeyType;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.common.SSHPacketHandler;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.transport.cipher.Cipher; import net.schmizz.sshj.transport.cipher.Cipher;
import net.schmizz.sshj.transport.compression.Compression; import net.schmizz.sshj.transport.compression.Compression;
import net.schmizz.sshj.transport.digest.Digest; import net.schmizz.sshj.transport.digest.Digest;
@@ -92,7 +102,8 @@ final class KeyExchanger
private Proposal clientProposal; private Proposal clientProposal;
private NegotiatedAlgorithms negotiatedAlgs; private NegotiatedAlgorithms negotiatedAlgs;
private final Event<TransportException> kexInitSent = new Event<TransportException>("kexinit sent", TransportException.chainer); private final Event<TransportException> kexInitSent =
new Event<TransportException>("kexinit sent", TransportException.chainer);
private final Event<TransportException> done; private final Event<TransportException> done;
@@ -208,11 +219,11 @@ final class KeyExchanger
return; return;
} }
throw new TransportException(DisconnectReason.HOST_KEY_NOT_VERIFIABLE, "Could not verify `" throw new TransportException(DisconnectReason.HOST_KEY_NOT_VERIFIABLE,
+ KeyType "Could not verify `" + KeyType.fromKey(key)
.fromKey(key) + "` host key with fingerprint `" + SecurityUtils.getFingerprint(key) + "` host key with fingerprint `" + SecurityUtils.getFingerprint(key)
+ "` for `" + transport + "` for `" + transport.getRemoteHost()
.getRemoteHost() + "` on port " + transport.getRemotePort()); + "` on port " + transport.getRemotePort());
} }
private void setKexDone() { private void setKexDone() {
@@ -229,8 +240,11 @@ final class KeyExchanger
kex = Factory.Named.Util.create(transport.getConfig().getKeyExchangeFactories(), negotiatedAlgs kex = Factory.Named.Util.create(transport.getConfig().getKeyExchangeFactories(), negotiatedAlgs
.getKeyExchangeAlgorithm()); .getKeyExchangeAlgorithm());
try { try {
kex.init(transport, transport.getServerID().getBytes(), transport.getClientID().getBytes(), buf kex.init(transport,
.getCompactData(), clientProposal.getPacket().getCompactData()); transport.getServerID().getBytes(IOUtils.UTF8),
transport.getClientID().getBytes(IOUtils.UTF8),
buf.getCompactData(),
clientProposal.getPacket().getCompactData());
} catch (GeneralSecurityException e) { } catch (GeneralSecurityException e) {
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, e); throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, e);
} }
@@ -321,10 +335,12 @@ final class KeyExchanger
negotiatedAlgs.getServer2ClientMACAlgorithm()); negotiatedAlgs.getServer2ClientMACAlgorithm());
mac_S2C.init(integrityKey_S2C); mac_S2C.init(integrityKey_S2C);
final Compression compression_S2C = Factory.Named.Util.create(transport.getConfig().getCompressionFactories(), final Compression compression_S2C =
negotiatedAlgs.getServer2ClientCompressionAlgorithm()); Factory.Named.Util.create(transport.getConfig().getCompressionFactories(),
final Compression compression_C2S = Factory.Named.Util.create(transport.getConfig().getCompressionFactories(), negotiatedAlgs.getServer2ClientCompressionAlgorithm());
negotiatedAlgs.getClient2ServerCompressionAlgorithm()); final Compression compression_C2S =
Factory.Named.Util.create(transport.getConfig().getCompressionFactories(),
negotiatedAlgs.getClient2ServerCompressionAlgorithm());
transport.getEncoder().setAlgorithms(cipher_C2S, mac_C2S, compression_C2S); transport.getEncoder().setAlgorithms(cipher_C2S, mac_C2S, compression_C2S);
transport.getDecoder().setAlgorithms(cipher_S2C, mac_S2C, compression_S2C); transport.getDecoder().setAlgorithms(cipher_S2C, mac_S2C, compression_S2C);

View File

@@ -148,7 +148,7 @@ public final class TransportImpl
try { try {
log.info("Client identity string: {}", clientID); log.info("Client identity string: {}", clientID);
connInfo.out.write((clientID + "\r\n").getBytes()); connInfo.out.write((clientID + "\r\n").getBytes(IOUtils.UTF8));
// Read server's ID // Read server's ID
final Buffer.PlainBuffer buf = new Buffer.PlainBuffer(); final Buffer.PlainBuffer buf = new Buffer.PlainBuffer();

View File

@@ -196,7 +196,7 @@ public class OpenSSHKnownHosts
private String hashHost(String host) private String hashHost(String host)
throws IOException { throws IOException {
sha1.init(getSaltyBytes()); sha1.init(getSaltyBytes());
return "|1|" + getSalt() + "|" + Base64.encodeBytes(sha1.doFinal(host.getBytes())); return "|1|" + getSalt() + "|" + Base64.encodeBytes(sha1.doFinal(host.getBytes(IOUtils.UTF8)));
} }
private byte[] getSaltyBytes() private byte[] getSaltyBytes()
@@ -289,7 +289,7 @@ public class OpenSSHKnownHosts
final BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(khFile)); final BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(khFile));
try { try {
for (Entry entry : entries) for (Entry entry : entries)
bos.write((entry.getLine() + LS).getBytes()); bos.write((entry.getLine() + LS).getBytes(IOUtils.UTF8));
} finally { } finally {
bos.close(); bos.close();
} }

View File

@@ -69,10 +69,10 @@ public final class SCPDownloadClient {
engine.signal("Start status OK"); engine.signal("Start status OK");
String msg = engine.readMessage(true); String msg = engine.readMessage();
do do
process(null, msg, targetFile); process(null, msg, targetFile);
while ((msg = engine.readMessage(false)) != null); while (!(msg = engine.readMessage()).isEmpty());
} }
private long parseLong(String longString, String valType) private long parseLong(String longString, String valType)
@@ -102,7 +102,7 @@ public final class SCPDownloadClient {
case 'T': case 'T':
engine.signal("ACK: T"); engine.signal("ACK: T");
process(msg, engine.readMessage(true), f); process(msg, engine.readMessage(), f);
break; break;
case 'C': case 'C':

View File

@@ -24,6 +24,7 @@ import net.schmizz.sshj.xfer.TransferListener;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
@@ -77,7 +78,7 @@ class SCPEngine {
int code = scp.getInputStream().read(); int code = scp.getInputStream().read();
switch (code) { switch (code) {
case -1: case -1:
String stderr = IOUtils.pipeStream(scp.getErrorStream()).toString(); String stderr = IOUtils.readFully(scp.getErrorStream()).toString();
if (!stderr.isEmpty()) if (!stderr.isEmpty())
stderr = ". Additional info: `" + stderr + "`"; stderr = ". Additional info: `" + stderr + "`";
throw new SCPException("EOF while expecting response to protocol message" + stderr); throw new SCPException("EOF while expecting response to protocol message" + stderr);
@@ -98,7 +99,7 @@ class SCPEngine {
void execSCPWith(List<Arg> args, String path) void execSCPWith(List<Arg> args, String path)
throws SSHException { throws SSHException {
StringBuilder cmd = new StringBuilder(SCP_COMMAND); final StringBuilder cmd = new StringBuilder(SCP_COMMAND);
for (Arg arg : args) for (Arg arg : args)
cmd.append(" ").append(arg); cmd.append(" ").append(arg);
cmd.append(" "); cmd.append(" ");
@@ -130,29 +131,25 @@ class SCPEngine {
String readMessage() String readMessage()
throws IOException { throws IOException {
return readMessage(true); final ByteArrayOutputStream baos = new ByteArrayOutputStream();
}
String readMessage(boolean errOnEOF)
throws IOException {
StringBuilder sb = new StringBuilder();
int x; int x;
while ((x = scp.getInputStream().read()) != LF) while ((x = scp.getInputStream().read()) != LF)
if (x == -1) { if (x == -1) {
if (errOnEOF) if (baos.size() == 0)
throw new IOException("EOF while reading message"); return "";
else else
return null; throw new IOException("EOF while reading message");
} else } else
sb.append((char) x); baos.write(x);
log.debug("Read message: {}", sb); final String msg = baos.toString(IOUtils.UTF8.displayName());
return sb.toString(); log.debug("Read message: `{}`", msg);
return msg;
} }
void sendMessage(String msg) void sendMessage(String msg)
throws IOException { throws IOException {
log.debug("Sending message: {}", msg); log.debug("Sending message: {}", msg);
scp.getOutputStream().write((msg + LF).getBytes()); scp.getOutputStream().write((msg + LF).getBytes(IOUtils.UTF8));
scp.getOutputStream().flush(); scp.getOutputStream().flush();
check("Message ACK received"); check("Message ACK received");
} }
@@ -164,25 +161,26 @@ class SCPEngine {
scp.getOutputStream().flush(); scp.getOutputStream().flush();
} }
long transferToRemote(final InputStream src, final long length) long transferToRemote(InputStream src, long length)
throws IOException { throws IOException {
return transfer(src, scp.getOutputStream(), scp.getRemoteMaxPacketSize(), length); return transfer(src, scp.getOutputStream(), scp.getRemoteMaxPacketSize(), length);
} }
long transferFromRemote(final OutputStream dest, final long length) long transferFromRemote(OutputStream dest, long length)
throws IOException { throws IOException {
return transfer(scp.getInputStream(), dest, scp.getLocalMaxPacketSize(), length); return transfer(scp.getInputStream(), dest, scp.getLocalMaxPacketSize(), length);
} }
private long transfer(InputStream in, OutputStream out, int bufSize, long len) private long transfer(InputStream in, OutputStream out, int bufSize, long len)
throws IOException { throws IOException {
return new StreamCopier("scp engine", in, out) return new StreamCopier(in, out)
.bufSize(bufSize).length(len) .bufSize(bufSize).length(len)
.keepFlushing(false)
.listener(listener) .listener(listener)
.copy(); .copy();
} }
void startedDir(final String dirname) { void startedDir(String dirname) {
listener.startedDir(dirname); listener.startedDir(dirname);
} }
@@ -190,7 +188,7 @@ class SCPEngine {
listener.finishedDir(); listener.finishedDir();
} }
void startedFile(final String filename, final long length) { void startedFile(String filename, long length) {
listener.startedFile(filename, length); listener.startedFile(filename, length);
} }