From 430ebe27eafa083067260a5fd0b1e899569ebcb7 Mon Sep 17 00:00:00 2001 From: Shikhar Bhushan Date: Sat, 30 Apr 2011 22:35:55 +0100 Subject: [PATCH] Future gets tryGet(), Event gets tryAwait(). ErrorListener disappears from StreamCopier. Socket/channel cleanups for local & remote port forwarding done more consistently with a separate monitoring thread. --- src/main/java/examples/RemotePF.java | 6 +- .../java/net/schmizz/concurrent/Event.java | 15 +++++ .../java/net/schmizz/concurrent/Future.java | 30 ++++++++- .../net/schmizz/sshj/common/StreamCopier.java | 66 +++++-------------- .../channel/SocketStreamCopyMonitor.java | 63 ++++++++++++++++++ .../channel/direct/LocalPortForwarder.java | 38 ++++------- .../AbstractForwardedChannelOpener.java | 2 +- .../SocketForwardingConnectListener.java | 21 ++---- 8 files changed, 145 insertions(+), 96 deletions(-) create mode 100644 src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java diff --git a/src/main/java/examples/RemotePF.java b/src/main/java/examples/RemotePF.java index 4a2c5172..9eeb4634 100644 --- a/src/main/java/examples/RemotePF.java +++ b/src/main/java/examples/RemotePF.java @@ -46,11 +46,9 @@ public class RemotePF { // where the server should listen new Forward(8080), // what we do with incoming connections that are forwarded to us - new SocketForwardingConnectListener(new InetSocketAddress("google.com", 80) - )); + new SocketForwardingConnectListener(new InetSocketAddress("google.com", 80))); - client.getTransport() - .setHeartbeatInterval(30); + client.getTransport().setHeartbeatInterval(30); // Something to hang on to so that the forwarding stays client.getTransport().join(); diff --git a/src/main/java/net/schmizz/concurrent/Event.java b/src/main/java/net/schmizz/concurrent/Event.java index f6856594..2b70bf24 100644 --- a/src/main/java/net/schmizz/concurrent/Event.java +++ b/src/main/java/net/schmizz/concurrent/Event.java @@ -84,4 +84,19 @@ public class Event super.get(timeout, unit); } + /** + * Await this event to have a definite {@code true} or {@code false} value, for {@code timeout} duration. + * + * If the definite value is not available by the time timeout expires, returns {@code null}. + * + * @param timeout timeout + * @param unit the time unit for the timeout + * + * @throws T if another thread meanwhile informs this event of an error + */ + public boolean tryAwait(long timeout, TimeUnit unit) + throws T { + return super.tryGet(timeout, unit) != null; + } + } \ No newline at end of file diff --git a/src/main/java/net/schmizz/concurrent/Future.java b/src/main/java/net/schmizz/concurrent/Future.java index cc0b9db5..851698a0 100644 --- a/src/main/java/net/schmizz/concurrent/Future.java +++ b/src/main/java/net/schmizz/concurrent/Future.java @@ -119,7 +119,7 @@ public class Future { */ public V get() throws T { - return get(0, TimeUnit.SECONDS); + return tryGet(0, TimeUnit.SECONDS); } /** @@ -134,6 +134,27 @@ public class Future { */ public V get(long timeout, TimeUnit unit) throws T { + final V value = tryGet(timeout, unit); + if (value == null) + throw chainer.chain(new TimeoutException("Timeout expired")); + else + return value; + } + + /** + * Wait for {@code timeout} duration for this future's value to be set. + * + * If the value is not set by the time the timeout expires, returns {@code null}. + * + * @param timeout the timeout + * @param unit time unit for the timeout + * + * @return the value or {@code null} + * + * @throws T in case another thread informs the future of an error meanwhile + */ + public V tryGet(long timeout, TimeUnit unit) + throws T { lock(); try { if (pendingEx != null) @@ -145,7 +166,7 @@ public class Future { if (timeout == 0) cond.await(); else if (!cond.await(timeout, unit)) - throw chainer.chain(new TimeoutException("Timeout expired")); + return null; if (pendingEx != null) { log.error("<<{}>> woke to: {}", name, pendingEx.toString()); throw pendingEx; @@ -204,4 +225,9 @@ public class Future { lock.unlock(); } + @Override + public String toString() { + return name; + } + } diff --git a/src/main/java/net/schmizz/sshj/common/StreamCopier.java b/src/main/java/net/schmizz/sshj/common/StreamCopier.java index da38c462..e5786e1c 100644 --- a/src/main/java/net/schmizz/sshj/common/StreamCopier.java +++ b/src/main/java/net/schmizz/sshj/common/StreamCopier.java @@ -20,41 +20,19 @@ import net.schmizz.concurrent.ExceptionChainer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.Closeable; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.util.concurrent.TimeUnit; public class StreamCopier { - public interface ErrorCallback { - - void onError(IOException ioe); - - } - public interface Listener { - void reportProgress(long transferred) throws IOException; + void reportProgress(long transferred) + throws IOException; } - public static ErrorCallback closeOnErrorCallback(final Closeable... toClose) { - return new ErrorCallback() { - @Override - public void onError(IOException ioe) { - IOUtils.closeQuietly(toClose); - } - }; - } - - private static final ErrorCallback NULL_CALLBACK = new ErrorCallback() { - @Override - public void onError(IOException ioe) { - } - }; - private static final Listener NULL_LISTENER = new Listener() { @Override public void reportProgress(long transferred) { @@ -67,20 +45,11 @@ public class StreamCopier { private final OutputStream out; private Listener listener = NULL_LISTENER; - private ErrorCallback errCB = NULL_CALLBACK; private int bufSize = 1; private boolean keepFlushing = true; private long length = -1; - private final Event doneEvent = - new Event("copyDone", new ExceptionChainer() { - @Override - public IOException chain(Throwable t) { - return (t instanceof IOException) ? (IOException) t : new IOException(t); - } - }); - public StreamCopier(InputStream in, OutputStream out) { this.in = in; this.out = out; @@ -102,26 +71,28 @@ public class StreamCopier { return this; } - public StreamCopier errorCallback(ErrorCallback errCB) { - if (errCB == null) errCB = NULL_CALLBACK; - this.errCB = errCB; - return this; - } - public StreamCopier length(long length) { this.length = length; return this; } - public StreamCopier spawn(String name) { + public Event spawn(String name) { return spawn(name, false); } - public StreamCopier spawnDaemon(String name) { + public Event spawnDaemon(String name) { return spawn(name, true); } - private StreamCopier spawn(final String name, final boolean daemon) { + private Event spawn(final String name, final boolean daemon) { + final Event doneEvent = + new Event("copyDone", new ExceptionChainer() { + @Override + public IOException chain(Throwable t) { + return (t instanceof IOException) ? (IOException) t : new IOException(t); + } + }); + new Thread() { { setName(name); @@ -136,19 +107,12 @@ public class StreamCopier { log.debug("Done copying from {}", in); doneEvent.set(); } catch (IOException ioe) { - log.error("In pipe from {} to {}: {}" + ioe.toString(), in, out); + log.error("In pipe from {} to {}: " + ioe.toString(), in, out); doneEvent.error(ioe); - errCB.onError(ioe); } } }.start(); - return this; - } - - public StreamCopier join(int timeout, TimeUnit unit) - throws IOException { - doneEvent.await(timeout, unit); - return this; + return doneEvent; } public long copy() diff --git a/src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java b/src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java new file mode 100644 index 00000000..48005cb6 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/connection/channel/SocketStreamCopyMonitor.java @@ -0,0 +1,63 @@ +/* + * Copyright 2010, 2011 sshj contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.connection.channel; + +import net.schmizz.concurrent.Event; +import net.schmizz.sshj.common.IOUtils; + +import java.io.Closeable; +import java.io.IOException; +import java.net.Socket; +import java.util.concurrent.TimeUnit; + +public class SocketStreamCopyMonitor + extends Thread { + + private SocketStreamCopyMonitor(Runnable r) { + super(r); + setName("sockmon"); + setDaemon(true); + } + + private static Closeable wrapSocket(final Socket socket) { + return new Closeable() { + @Override + public void close() + throws IOException { + socket.close(); + } + }; + } + + public static void monitor(final int frequency, final TimeUnit unit, + final Event x, final Event y, + final Channel channel, final Socket socket) { + new SocketStreamCopyMonitor(new Runnable() { + public void run() { + try { + for (Event ev = x; + !ev.tryAwait(frequency, unit); + ev = (ev == x) ? y : x) { + } + } catch (IOException ignored) { + } finally { + IOUtils.closeQuietly(channel, wrapSocket(socket)); + } + } + }).start(); + } + +} diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java index 9821a377..b9e5b553 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java @@ -15,19 +15,20 @@ */ package net.schmizz.sshj.connection.channel.direct; +import net.schmizz.concurrent.Event; import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.StreamCopier; -import net.schmizz.sshj.common.StreamCopier.ErrorCallback; import net.schmizz.sshj.connection.Connection; +import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.net.ServerSocketFactory; -import java.io.Closeable; import java.io.IOException; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketAddress; +import java.util.concurrent.TimeUnit; public class LocalPortForwarder { @@ -45,34 +46,22 @@ public class LocalPortForwarder { throws IOException { sock.setSendBufferSize(getLocalMaxPacketSize()); sock.setReceiveBufferSize(getRemoteMaxPacketSize()); - - final ErrorCallback closer = StreamCopier.closeOnErrorCallback(this, - new Closeable() { - @Override - public void close() - throws IOException { - sock.close(); - } - }); - - new StreamCopier(getInputStream(), sock.getOutputStream()) - .bufSize(getLocalMaxPacketSize()) - .errorCallback(closer) - .spawnDaemon("chan2soc"); - - new StreamCopier(sock.getInputStream(), getOutputStream()) + final Event soc2chan = new StreamCopier(sock.getInputStream(), getOutputStream()) .bufSize(getRemoteMaxPacketSize()) - .errorCallback(closer) .spawnDaemon("soc2chan"); - } + final Event chan2soc = new StreamCopier(getInputStream(), sock.getOutputStream()) + .bufSize(getLocalMaxPacketSize()) + .spawnDaemon("chan2soc"); + SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, soc2chan, chan2soc, this, sock); + } @Override protected SSHPacket buildOpenReq() { return super.buildOpenReq() - .putString(host) - .putInt(port) - .putString(ss.getInetAddress().getHostAddress()) - .putInt(ss.getLocalPort()); + .putString(host) + .putInt(port) + .putString(ss.getInetAddress().getHostAddress()) + .putInt(ss.getLocalPort()); } } @@ -146,6 +135,7 @@ public class LocalPortForwarder { chan.open(); chan.start(); } + log.info("Interrupted!"); } } \ No newline at end of file diff --git a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/AbstractForwardedChannelOpener.java b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/AbstractForwardedChannelOpener.java index 9655ba22..8454c44b 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/AbstractForwardedChannelOpener.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/AbstractForwardedChannelOpener.java @@ -50,7 +50,7 @@ public abstract class AbstractForwardedChannelOpener new Thread() { { - setName("ConnectListener"); + setName("chanopener"); } @Override diff --git a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/SocketForwardingConnectListener.java b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/SocketForwardingConnectListener.java index 3bfe8476..415ecbe4 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/SocketForwardingConnectListener.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/SocketForwardingConnectListener.java @@ -15,16 +15,17 @@ */ package net.schmizz.sshj.connection.channel.forwarded; +import net.schmizz.concurrent.Event; import net.schmizz.sshj.common.StreamCopier; -import net.schmizz.sshj.common.StreamCopier.ErrorCallback; import net.schmizz.sshj.connection.channel.Channel; +import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.Closeable; import java.io.IOException; import java.net.Socket; import java.net.SocketAddress; +import java.util.concurrent.TimeUnit; /** A {@link ConnectListener} that forwards what is received over the channel to a socket and vice-versa. */ public class SocketForwardingConnectListener @@ -54,23 +55,15 @@ public class SocketForwardingConnectListener // ok so far -- could connect, let's confirm the channel chan.confirm(); - final ErrorCallback closer = StreamCopier.closeOnErrorCallback(chan, new Closeable() { - @Override - public void close() - throws IOException { - sock.close(); - } - }); - - new StreamCopier(sock.getInputStream(), chan.getOutputStream()) + final Event soc2chan = new StreamCopier(sock.getInputStream(), chan.getOutputStream()) .bufSize(chan.getRemoteMaxPacketSize()) - .errorCallback(closer) .spawnDaemon("soc2chan"); - new StreamCopier(chan.getInputStream(), sock.getOutputStream()) + final Event chan2soc = new StreamCopier(chan.getInputStream(), sock.getOutputStream()) .bufSize(chan.getLocalMaxPacketSize()) - .errorCallback(closer) .spawnDaemon("chan2soc"); + + SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, chan2soc, soc2chan, chan, sock); } } \ No newline at end of file