Future gets tryGet(), Event gets tryAwait(). ErrorListener disappears from StreamCopier. Socket/channel cleanups for local & remote port forwarding done more consistently with a separate monitoring thread.

This commit is contained in:
Shikhar Bhushan
2011-04-30 22:35:55 +01:00
parent a0109dd8fa
commit 430ebe27ea
8 changed files with 145 additions and 96 deletions

View File

@@ -46,11 +46,9 @@ public class RemotePF {
// where the server should listen // where the server should listen
new Forward(8080), new Forward(8080),
// what we do with incoming connections that are forwarded to us // what we do with incoming connections that are forwarded to us
new SocketForwardingConnectListener(new InetSocketAddress("google.com", 80) new SocketForwardingConnectListener(new InetSocketAddress("google.com", 80)));
));
client.getTransport() client.getTransport().setHeartbeatInterval(30);
.setHeartbeatInterval(30);
// Something to hang on to so that the forwarding stays // Something to hang on to so that the forwarding stays
client.getTransport().join(); client.getTransport().join();

View File

@@ -84,4 +84,19 @@ public class Event<T extends Throwable>
super.get(timeout, unit); super.get(timeout, unit);
} }
/**
* Await this event to have a definite {@code true} or {@code false} value, for {@code timeout} duration.
*
* If the definite value is not available by the time timeout expires, returns {@code null}.
*
* @param timeout timeout
* @param unit the time unit for the timeout
*
* @throws T if another thread meanwhile informs this event of an error
*/
public boolean tryAwait(long timeout, TimeUnit unit)
throws T {
return super.tryGet(timeout, unit) != null;
}
} }

View File

@@ -119,7 +119,7 @@ public class Future<V, T extends Throwable> {
*/ */
public V get() public V get()
throws T { throws T {
return get(0, TimeUnit.SECONDS); return tryGet(0, TimeUnit.SECONDS);
} }
/** /**
@@ -134,6 +134,27 @@ public class Future<V, T extends Throwable> {
*/ */
public V get(long timeout, TimeUnit unit) public V get(long timeout, TimeUnit unit)
throws T { throws T {
final V value = tryGet(timeout, unit);
if (value == null)
throw chainer.chain(new TimeoutException("Timeout expired"));
else
return value;
}
/**
* Wait for {@code timeout} duration for this future's value to be set.
*
* If the value is not set by the time the timeout expires, returns {@code null}.
*
* @param timeout the timeout
* @param unit time unit for the timeout
*
* @return the value or {@code null}
*
* @throws T in case another thread informs the future of an error meanwhile
*/
public V tryGet(long timeout, TimeUnit unit)
throws T {
lock(); lock();
try { try {
if (pendingEx != null) if (pendingEx != null)
@@ -145,7 +166,7 @@ public class Future<V, T extends Throwable> {
if (timeout == 0) if (timeout == 0)
cond.await(); cond.await();
else if (!cond.await(timeout, unit)) else if (!cond.await(timeout, unit))
throw chainer.chain(new TimeoutException("Timeout expired")); return null;
if (pendingEx != null) { if (pendingEx != null) {
log.error("<<{}>> woke to: {}", name, pendingEx.toString()); log.error("<<{}>> woke to: {}", name, pendingEx.toString());
throw pendingEx; throw pendingEx;
@@ -204,4 +225,9 @@ public class Future<V, T extends Throwable> {
lock.unlock(); lock.unlock();
} }
@Override
public String toString() {
return name;
}
} }

View File

@@ -20,41 +20,19 @@ import net.schmizz.concurrent.ExceptionChainer;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.concurrent.TimeUnit;
public class StreamCopier { public class StreamCopier {
public interface ErrorCallback {
void onError(IOException ioe);
}
public interface Listener { public interface Listener {
void reportProgress(long transferred) throws IOException; void reportProgress(long transferred)
throws IOException;
} }
public static ErrorCallback closeOnErrorCallback(final Closeable... toClose) {
return new ErrorCallback() {
@Override
public void onError(IOException ioe) {
IOUtils.closeQuietly(toClose);
}
};
}
private static final ErrorCallback NULL_CALLBACK = new ErrorCallback() {
@Override
public void onError(IOException ioe) {
}
};
private static final Listener NULL_LISTENER = new Listener() { private static final Listener NULL_LISTENER = new Listener() {
@Override @Override
public void reportProgress(long transferred) { public void reportProgress(long transferred) {
@@ -67,20 +45,11 @@ public class StreamCopier {
private final OutputStream out; private final OutputStream out;
private Listener listener = NULL_LISTENER; private Listener listener = NULL_LISTENER;
private ErrorCallback errCB = NULL_CALLBACK;
private int bufSize = 1; private int bufSize = 1;
private boolean keepFlushing = true; private boolean keepFlushing = true;
private long length = -1; private long length = -1;
private final Event<IOException> doneEvent =
new Event<IOException>("copyDone", new ExceptionChainer<IOException>() {
@Override
public IOException chain(Throwable t) {
return (t instanceof IOException) ? (IOException) t : new IOException(t);
}
});
public StreamCopier(InputStream in, OutputStream out) { public StreamCopier(InputStream in, OutputStream out) {
this.in = in; this.in = in;
this.out = out; this.out = out;
@@ -102,26 +71,28 @@ public class StreamCopier {
return this; return this;
} }
public StreamCopier errorCallback(ErrorCallback errCB) {
if (errCB == null) errCB = NULL_CALLBACK;
this.errCB = errCB;
return this;
}
public StreamCopier length(long length) { public StreamCopier length(long length) {
this.length = length; this.length = length;
return this; return this;
} }
public StreamCopier spawn(String name) { public Event<IOException> spawn(String name) {
return spawn(name, false); return spawn(name, false);
} }
public StreamCopier spawnDaemon(String name) { public Event<IOException> spawnDaemon(String name) {
return spawn(name, true); return spawn(name, true);
} }
private StreamCopier spawn(final String name, final boolean daemon) { private Event<IOException> spawn(final String name, final boolean daemon) {
final Event<IOException> doneEvent =
new Event<IOException>("copyDone", new ExceptionChainer<IOException>() {
@Override
public IOException chain(Throwable t) {
return (t instanceof IOException) ? (IOException) t : new IOException(t);
}
});
new Thread() { new Thread() {
{ {
setName(name); setName(name);
@@ -136,19 +107,12 @@ public class StreamCopier {
log.debug("Done copying from {}", in); log.debug("Done copying from {}", in);
doneEvent.set(); doneEvent.set();
} catch (IOException ioe) { } catch (IOException ioe) {
log.error("In pipe from {} to {}: {}" + ioe.toString(), in, out); log.error("In pipe from {} to {}: " + ioe.toString(), in, out);
doneEvent.error(ioe); doneEvent.error(ioe);
errCB.onError(ioe);
} }
} }
}.start(); }.start();
return this; return doneEvent;
}
public StreamCopier join(int timeout, TimeUnit unit)
throws IOException {
doneEvent.await(timeout, unit);
return this;
} }
public long copy() public long copy()

View File

@@ -0,0 +1,63 @@
/*
* Copyright 2010, 2011 sshj contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.connection.channel;
import net.schmizz.concurrent.Event;
import net.schmizz.sshj.common.IOUtils;
import java.io.Closeable;
import java.io.IOException;
import java.net.Socket;
import java.util.concurrent.TimeUnit;
public class SocketStreamCopyMonitor
extends Thread {
private SocketStreamCopyMonitor(Runnable r) {
super(r);
setName("sockmon");
setDaemon(true);
}
private static Closeable wrapSocket(final Socket socket) {
return new Closeable() {
@Override
public void close()
throws IOException {
socket.close();
}
};
}
public static void monitor(final int frequency, final TimeUnit unit,
final Event<IOException> x, final Event<IOException> y,
final Channel channel, final Socket socket) {
new SocketStreamCopyMonitor(new Runnable() {
public void run() {
try {
for (Event<IOException> ev = x;
!ev.tryAwait(frequency, unit);
ev = (ev == x) ? y : x) {
}
} catch (IOException ignored) {
} finally {
IOUtils.closeQuietly(channel, wrapSocket(socket));
}
}
}).start();
}
}

View File

@@ -15,19 +15,20 @@
*/ */
package net.schmizz.sshj.connection.channel.direct; package net.schmizz.sshj.connection.channel.direct;
import net.schmizz.concurrent.Event;
import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.common.StreamCopier; import net.schmizz.sshj.common.StreamCopier;
import net.schmizz.sshj.common.StreamCopier.ErrorCallback;
import net.schmizz.sshj.connection.Connection; import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import javax.net.ServerSocketFactory; import javax.net.ServerSocketFactory;
import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.net.Socket; import java.net.Socket;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.concurrent.TimeUnit;
public class LocalPortForwarder { public class LocalPortForwarder {
@@ -45,25 +46,13 @@ public class LocalPortForwarder {
throws IOException { throws IOException {
sock.setSendBufferSize(getLocalMaxPacketSize()); sock.setSendBufferSize(getLocalMaxPacketSize());
sock.setReceiveBufferSize(getRemoteMaxPacketSize()); sock.setReceiveBufferSize(getRemoteMaxPacketSize());
final Event<IOException> soc2chan = new StreamCopier(sock.getInputStream(), getOutputStream())
final ErrorCallback closer = StreamCopier.closeOnErrorCallback(this,
new Closeable() {
@Override
public void close()
throws IOException {
sock.close();
}
});
new StreamCopier(getInputStream(), sock.getOutputStream())
.bufSize(getLocalMaxPacketSize())
.errorCallback(closer)
.spawnDaemon("chan2soc");
new StreamCopier(sock.getInputStream(), getOutputStream())
.bufSize(getRemoteMaxPacketSize()) .bufSize(getRemoteMaxPacketSize())
.errorCallback(closer)
.spawnDaemon("soc2chan"); .spawnDaemon("soc2chan");
final Event<IOException> chan2soc = new StreamCopier(getInputStream(), sock.getOutputStream())
.bufSize(getLocalMaxPacketSize())
.spawnDaemon("chan2soc");
SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, soc2chan, chan2soc, this, sock);
} }
@Override @Override
@@ -146,6 +135,7 @@ public class LocalPortForwarder {
chan.open(); chan.open();
chan.start(); chan.start();
} }
log.info("Interrupted!");
} }
} }

View File

@@ -50,7 +50,7 @@ public abstract class AbstractForwardedChannelOpener
new Thread() { new Thread() {
{ {
setName("ConnectListener"); setName("chanopener");
} }
@Override @Override

View File

@@ -15,16 +15,17 @@
*/ */
package net.schmizz.sshj.connection.channel.forwarded; package net.schmizz.sshj.connection.channel.forwarded;
import net.schmizz.concurrent.Event;
import net.schmizz.sshj.common.StreamCopier; import net.schmizz.sshj.common.StreamCopier;
import net.schmizz.sshj.common.StreamCopier.ErrorCallback;
import net.schmizz.sshj.connection.channel.Channel; import net.schmizz.sshj.connection.channel.Channel;
import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.net.Socket; import java.net.Socket;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.concurrent.TimeUnit;
/** A {@link ConnectListener} that forwards what is received over the channel to a socket and vice-versa. */ /** A {@link ConnectListener} that forwards what is received over the channel to a socket and vice-versa. */
public class SocketForwardingConnectListener public class SocketForwardingConnectListener
@@ -54,23 +55,15 @@ public class SocketForwardingConnectListener
// ok so far -- could connect, let's confirm the channel // ok so far -- could connect, let's confirm the channel
chan.confirm(); chan.confirm();
final ErrorCallback closer = StreamCopier.closeOnErrorCallback(chan, new Closeable() { final Event<IOException> soc2chan = new StreamCopier(sock.getInputStream(), chan.getOutputStream())
@Override
public void close()
throws IOException {
sock.close();
}
});
new StreamCopier(sock.getInputStream(), chan.getOutputStream())
.bufSize(chan.getRemoteMaxPacketSize()) .bufSize(chan.getRemoteMaxPacketSize())
.errorCallback(closer)
.spawnDaemon("soc2chan"); .spawnDaemon("soc2chan");
new StreamCopier(chan.getInputStream(), sock.getOutputStream()) final Event<IOException> chan2soc = new StreamCopier(chan.getInputStream(), sock.getOutputStream())
.bufSize(chan.getLocalMaxPacketSize()) .bufSize(chan.getLocalMaxPacketSize())
.errorCallback(closer)
.spawnDaemon("chan2soc"); .spawnDaemon("chan2soc");
SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, chan2soc, soc2chan, chan, sock);
} }
} }