misc cleanups

This commit is contained in:
Shikhar Bhushan
2011-04-25 15:17:54 +01:00
parent 4de741359e
commit 85abcb7aad
15 changed files with 140 additions and 107 deletions

View File

@@ -37,7 +37,7 @@ public class Exec {
final Session session = ssh.startSession();
try {
final Command cmd = session.exec("ping -c 1 google.com");
System.out.println(IOUtils.pipeStream(cmd.getInputStream()).toString());
System.out.println(IOUtils.readFully(cmd.getInputStream()).toString());
cmd.join(5, TimeUnit.SECONDS);
System.out.println("\n** exit status: " + cmd.getExitStatus());
} finally {

View File

@@ -48,22 +48,19 @@ class RudimentaryPTY {
final Shell shell = session.startShell();
new StreamCopier("stdout", shell.getInputStream(), System.out)
new StreamCopier(shell.getInputStream(), System.out)
.bufSize(shell.getLocalMaxPacketSize())
.keepFlushing(true)
.start();
.spawn("stdout");
new StreamCopier("stderr", shell.getErrorStream(), System.err)
new StreamCopier(shell.getErrorStream(), System.err)
.bufSize(shell.getLocalMaxPacketSize())
.keepFlushing(true)
.start();
.spawn("stderr");
// Now make System.in act as stdin. To exit, hit Ctrl+D (since that results in an EOF on System.in)
// This is kinda messy because java only allows console input after you hit return
// But this is just an example... a GUI app could implement a proper PTY
new StreamCopier("stdin", System.in, shell.getOutputStream())
new StreamCopier(System.in, shell.getOutputStream())
.bufSize(shell.getRemoteMaxPacketSize())
.keepFlushing(true)
.copy();
} finally {

View File

@@ -57,8 +57,8 @@ public class X11 {
final Command cmd = sess.exec("/usr/X11/bin/xcalc");
new StreamCopier("stdout", cmd.getInputStream(), System.out).start();
new StreamCopier("stderr", cmd.getErrorStream(), System.err).start();
new StreamCopier(cmd.getInputStream(), System.out).spawn("stdout");
new StreamCopier(cmd.getErrorStream(), System.err).spawn("stderr");
// Wait for session & X11 channel to get closed
ssh.getConnection().join();

View File

@@ -427,20 +427,16 @@ public class Buffer<T extends Buffer<T>> {
}
public T putString(String string) {
try {
return putString(string.getBytes("UTF-8"));
} catch (UnsupportedEncodingException e) {
throw new SSHRuntimeException(e);
}
return putString(string.getBytes(IOUtils.UTF8));
}
/**
* Writes a char-array as an SSH string and then blanks it out.
* <p/>
* This is useful when a plaintext password needs to be sent. If {@code passwd} is {@code null}, an empty string is
* This is useful when a plaintext password needs to be sent. If {@code str} is {@code null}, an empty string is
* written.
*
* @param str (null-ok) the password as a character array
* @param str (null-ok) the string as a character array
*
* @return this
*/

View File

@@ -42,11 +42,14 @@ import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
public class IOUtils {
private static final Logger LOG = LoggerFactory.getLogger(IOUtils.class);
public static final Charset UTF8 = Charset.forName("UTF-8");
public static void closeQuietly(Closeable... closeables) {
for (Closeable c : closeables)
try {
@@ -57,15 +60,11 @@ public class IOUtils {
}
}
public static ByteArrayOutputStream pipeStream(InputStream stream)
public static ByteArrayOutputStream readFully(InputStream stream)
throws IOException {
final ByteArrayOutputStream bos = new ByteArrayOutputStream();
byte[] buf = new byte[1024];
int read;
while ((read = (stream.read(buf))) != -1)
bos.write(buf, 0, read);
bos.flush();
return bos;
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
new StreamCopier(stream, baos).copy();
return baos;
}
}
}

View File

@@ -15,6 +15,8 @@
*/
package net.schmizz.sshj.common;
import net.schmizz.concurrent.Event;
import net.schmizz.concurrent.ExceptionChainer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -22,11 +24,9 @@ import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.TimeUnit;
public class StreamCopier
extends Thread {
private final Logger logger = LoggerFactory.getLogger(getClass());
public class StreamCopier {
public interface ErrorCallback {
@@ -36,7 +36,7 @@ public class StreamCopier
public interface Listener {
void reportProgress(long transferred);
void reportProgress(long transferred) throws IOException;
}
@@ -65,17 +65,25 @@ public class StreamCopier
private final InputStream in;
private final OutputStream out;
private int bufSize = 1;
private boolean keepFlushing = false;
private long length = -1;
private Listener listener = NULL_LISTENER;
private ErrorCallback errCB = NULL_CALLBACK;
public StreamCopier(String name, InputStream in, OutputStream out) {
private int bufSize = 1;
private boolean keepFlushing = true;
private long length = -1;
private final Event<IOException> doneEvent =
new Event<IOException>("copyDone", new ExceptionChainer<IOException>() {
@Override
public IOException chain(Throwable t) {
return (t instanceof IOException) ? (IOException) t : new IOException(t);
}
});
public StreamCopier(InputStream in, OutputStream out) {
this.in = in;
this.out = out;
setName(name);
}
public StreamCopier bufSize(int bufSize) {
@@ -105,8 +113,41 @@ public class StreamCopier
return this;
}
public StreamCopier daemon(boolean choice) {
setDaemon(choice);
public StreamCopier spawn(String name) {
return spawn(name, false);
}
public StreamCopier spawnDaemon(String name) {
return spawn(name, true);
}
private StreamCopier spawn(final String name, final boolean daemon) {
new Thread() {
{
setName(name);
setDaemon(daemon);
}
@Override
public void run() {
try {
log.debug("Will copy from {} to {}", in, out);
copy();
log.debug("Done copying from {}", in);
doneEvent.set();
} catch (IOException ioe) {
log.error("In pipe from {} to {}: {}" + ioe.toString(), in, out);
doneEvent.error(ioe);
errCB.onError(ioe);
}
}
}.start();
return this;
}
public StreamCopier join(int timeout, TimeUnit unit)
throws IOException {
doneEvent.await(timeout, unit);
return this;
}
@@ -131,7 +172,7 @@ public class StreamCopier
final double timeSeconds = (System.currentTimeMillis() - startTime) / 1000.0;
final double sizeKiB = count / 1024.0;
logger.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds));
log.info(sizeKiB + " KiB transferred in {} seconds ({} KiB/s)", timeSeconds, (sizeKiB / timeSeconds));
if (length != -1 && read == -1)
throw new IOException("Encountered EOF, could not transfer " + length + " bytes");
@@ -149,16 +190,4 @@ public class StreamCopier
return count;
}
@Override
public void run() {
try {
log.debug("Wil pipe from {} to {}", in, out);
copy();
log.debug("EOF on {}", in);
} catch (IOException ioe) {
log.error("In pipe from {} to {}: " + ioe.toString(), in, out);
errCB.onError(ioe);
}
}
}

View File

@@ -55,17 +55,15 @@ public class LocalPortForwarder {
}
});
new StreamCopier("chan2soc", getInputStream(), sock.getOutputStream())
new StreamCopier(getInputStream(), sock.getOutputStream())
.bufSize(getLocalMaxPacketSize())
.errorCallback(closer)
.daemon(true)
.start();
.spawnDaemon("chan2soc");
new StreamCopier("soc2chan", sock.getInputStream(), getOutputStream())
new StreamCopier(sock.getInputStream(), getOutputStream())
.bufSize(getRemoteMaxPacketSize())
.errorCallback(closer)
.daemon(true)
.start();
.spawnDaemon("soc2chan");
}
@Override

View File

@@ -247,13 +247,13 @@ public class SessionChannel
@Override
@Deprecated
public String getOutputAsString() throws IOException {
return IOUtils.pipeStream(getInputStream()).toString();
return IOUtils.readFully(getInputStream()).toString();
}
@Override
@Deprecated
public String getErrorAsString() throws IOException {
return IOUtils.pipeStream(getErrorStream()).toString();
return IOUtils.readFully(getErrorStream()).toString();
}
}

View File

@@ -62,17 +62,15 @@ public class SocketForwardingConnectListener
}
});
new StreamCopier("soc2chan", sock.getInputStream(), chan.getOutputStream())
new StreamCopier(sock.getInputStream(), chan.getOutputStream())
.bufSize(chan.getRemoteMaxPacketSize())
.errorCallback(closer)
.daemon(true)
.start();
.spawnDaemon("soc2chan");
new StreamCopier("chan2soc", chan.getInputStream(), sock.getOutputStream())
new StreamCopier(chan.getInputStream(), sock.getOutputStream())
.bufSize(chan.getLocalMaxPacketSize())
.errorCallback(closer)
.daemon(true)
.start();
.spawnDaemon("chan2soc");
}
}

View File

@@ -135,8 +135,9 @@ public class SFTPFileTransfer
try {
final OutputStream os = adjusted.getOutputStream();
try {
new StreamCopier("sftp download", rf.getInputStream(), os)
new StreamCopier(rf.getInputStream(), os)
.bufSize(engine.getSubsystem().getLocalMaxPacketSize())
.keepFlushing(false)
.listener(listener)
.copy();
} finally {
@@ -197,8 +198,9 @@ public class SFTPFileTransfer
try {
final InputStream fis = local.getInputStream();
try {
new StreamCopier("sftp upload", fis, rf.getOutputStream())
new StreamCopier(fis, rf.getOutputStream())
.bufSize(engine.getSubsystem().getRemoteMaxPacketSize() - rf.getOutgoingPacketOverhead())
.keepFlushing(false)
.listener(listener)
.copy();
} finally {

View File

@@ -37,7 +37,17 @@ package net.schmizz.sshj.transport;
import net.schmizz.concurrent.Event;
import net.schmizz.concurrent.FutureUtils;
import net.schmizz.sshj.common.*;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.DisconnectReason;
import net.schmizz.sshj.common.ErrorNotifiable;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.KeyType;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.common.SSHPacketHandler;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.transport.cipher.Cipher;
import net.schmizz.sshj.transport.compression.Compression;
import net.schmizz.sshj.transport.digest.Digest;
@@ -92,7 +102,8 @@ final class KeyExchanger
private Proposal clientProposal;
private NegotiatedAlgorithms negotiatedAlgs;
private final Event<TransportException> kexInitSent = new Event<TransportException>("kexinit sent", TransportException.chainer);
private final Event<TransportException> kexInitSent =
new Event<TransportException>("kexinit sent", TransportException.chainer);
private final Event<TransportException> done;
@@ -208,11 +219,11 @@ final class KeyExchanger
return;
}
throw new TransportException(DisconnectReason.HOST_KEY_NOT_VERIFIABLE, "Could not verify `"
+ KeyType
.fromKey(key) + "` host key with fingerprint `" + SecurityUtils.getFingerprint(key)
+ "` for `" + transport
.getRemoteHost() + "` on port " + transport.getRemotePort());
throw new TransportException(DisconnectReason.HOST_KEY_NOT_VERIFIABLE,
"Could not verify `" + KeyType.fromKey(key)
+ "` host key with fingerprint `" + SecurityUtils.getFingerprint(key)
+ "` for `" + transport.getRemoteHost()
+ "` on port " + transport.getRemotePort());
}
private void setKexDone() {
@@ -229,8 +240,11 @@ final class KeyExchanger
kex = Factory.Named.Util.create(transport.getConfig().getKeyExchangeFactories(), negotiatedAlgs
.getKeyExchangeAlgorithm());
try {
kex.init(transport, transport.getServerID().getBytes(), transport.getClientID().getBytes(), buf
.getCompactData(), clientProposal.getPacket().getCompactData());
kex.init(transport,
transport.getServerID().getBytes(IOUtils.UTF8),
transport.getClientID().getBytes(IOUtils.UTF8),
buf.getCompactData(),
clientProposal.getPacket().getCompactData());
} catch (GeneralSecurityException e) {
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, e);
}
@@ -321,10 +335,12 @@ final class KeyExchanger
negotiatedAlgs.getServer2ClientMACAlgorithm());
mac_S2C.init(integrityKey_S2C);
final Compression compression_S2C = Factory.Named.Util.create(transport.getConfig().getCompressionFactories(),
negotiatedAlgs.getServer2ClientCompressionAlgorithm());
final Compression compression_C2S = Factory.Named.Util.create(transport.getConfig().getCompressionFactories(),
negotiatedAlgs.getClient2ServerCompressionAlgorithm());
final Compression compression_S2C =
Factory.Named.Util.create(transport.getConfig().getCompressionFactories(),
negotiatedAlgs.getServer2ClientCompressionAlgorithm());
final Compression compression_C2S =
Factory.Named.Util.create(transport.getConfig().getCompressionFactories(),
negotiatedAlgs.getClient2ServerCompressionAlgorithm());
transport.getEncoder().setAlgorithms(cipher_C2S, mac_C2S, compression_C2S);
transport.getDecoder().setAlgorithms(cipher_S2C, mac_S2C, compression_S2C);

View File

@@ -148,7 +148,7 @@ public final class TransportImpl
try {
log.info("Client identity string: {}", clientID);
connInfo.out.write((clientID + "\r\n").getBytes());
connInfo.out.write((clientID + "\r\n").getBytes(IOUtils.UTF8));
// Read server's ID
final Buffer.PlainBuffer buf = new Buffer.PlainBuffer();

View File

@@ -196,7 +196,7 @@ public class OpenSSHKnownHosts
private String hashHost(String host)
throws IOException {
sha1.init(getSaltyBytes());
return "|1|" + getSalt() + "|" + Base64.encodeBytes(sha1.doFinal(host.getBytes()));
return "|1|" + getSalt() + "|" + Base64.encodeBytes(sha1.doFinal(host.getBytes(IOUtils.UTF8)));
}
private byte[] getSaltyBytes()
@@ -289,7 +289,7 @@ public class OpenSSHKnownHosts
final BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(khFile));
try {
for (Entry entry : entries)
bos.write((entry.getLine() + LS).getBytes());
bos.write((entry.getLine() + LS).getBytes(IOUtils.UTF8));
} finally {
bos.close();
}

View File

@@ -69,10 +69,10 @@ public final class SCPDownloadClient {
engine.signal("Start status OK");
String msg = engine.readMessage(true);
String msg = engine.readMessage();
do
process(null, msg, targetFile);
while ((msg = engine.readMessage(false)) != null);
while (!(msg = engine.readMessage()).isEmpty());
}
private long parseLong(String longString, String valType)
@@ -102,7 +102,7 @@ public final class SCPDownloadClient {
case 'T':
engine.signal("ACK: T");
process(msg, engine.readMessage(true), f);
process(msg, engine.readMessage(), f);
break;
case 'C':

View File

@@ -24,6 +24,7 @@ import net.schmizz.sshj.xfer.TransferListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@@ -77,7 +78,7 @@ class SCPEngine {
int code = scp.getInputStream().read();
switch (code) {
case -1:
String stderr = IOUtils.pipeStream(scp.getErrorStream()).toString();
String stderr = IOUtils.readFully(scp.getErrorStream()).toString();
if (!stderr.isEmpty())
stderr = ". Additional info: `" + stderr + "`";
throw new SCPException("EOF while expecting response to protocol message" + stderr);
@@ -98,7 +99,7 @@ class SCPEngine {
void execSCPWith(List<Arg> args, String path)
throws SSHException {
StringBuilder cmd = new StringBuilder(SCP_COMMAND);
final StringBuilder cmd = new StringBuilder(SCP_COMMAND);
for (Arg arg : args)
cmd.append(" ").append(arg);
cmd.append(" ");
@@ -130,29 +131,25 @@ class SCPEngine {
String readMessage()
throws IOException {
return readMessage(true);
}
String readMessage(boolean errOnEOF)
throws IOException {
StringBuilder sb = new StringBuilder();
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
int x;
while ((x = scp.getInputStream().read()) != LF)
if (x == -1) {
if (errOnEOF)
throw new IOException("EOF while reading message");
if (baos.size() == 0)
return "";
else
return null;
throw new IOException("EOF while reading message");
} else
sb.append((char) x);
log.debug("Read message: {}", sb);
return sb.toString();
baos.write(x);
final String msg = baos.toString(IOUtils.UTF8.displayName());
log.debug("Read message: `{}`", msg);
return msg;
}
void sendMessage(String msg)
throws IOException {
log.debug("Sending message: {}", msg);
scp.getOutputStream().write((msg + LF).getBytes());
scp.getOutputStream().write((msg + LF).getBytes(IOUtils.UTF8));
scp.getOutputStream().flush();
check("Message ACK received");
}
@@ -164,25 +161,26 @@ class SCPEngine {
scp.getOutputStream().flush();
}
long transferToRemote(final InputStream src, final long length)
long transferToRemote(InputStream src, long length)
throws IOException {
return transfer(src, scp.getOutputStream(), scp.getRemoteMaxPacketSize(), length);
}
long transferFromRemote(final OutputStream dest, final long length)
long transferFromRemote(OutputStream dest, long length)
throws IOException {
return transfer(scp.getInputStream(), dest, scp.getLocalMaxPacketSize(), length);
}
private long transfer(InputStream in, OutputStream out, int bufSize, long len)
throws IOException {
return new StreamCopier("scp engine", in, out)
return new StreamCopier(in, out)
.bufSize(bufSize).length(len)
.keepFlushing(false)
.listener(listener)
.copy();
}
void startedDir(final String dirname) {
void startedDir(String dirname) {
listener.startedDir(dirname);
}
@@ -190,7 +188,7 @@ class SCPEngine {
listener.finishedDir();
}
void startedFile(final String filename, final long length) {
void startedFile(String filename, long length) {
listener.startedFile(filename, length);
}