From 3026be282a91e1d067721ff4a67c5ab31fb76737 Mon Sep 17 00:00:00 2001 From: Shikhar Bhushan Date: Thu, 5 Jan 2012 22:25:42 +0000 Subject: [PATCH] Refactored the local port forwarding API; give caller control over initializing and cleaning up the server socket used. Also removed 'server socket factory' stuff from SocketClient. --- src/main/java/examples/LocalPF.java | 14 +- src/main/java/net/schmizz/sshj/SSHClient.java | 20 ++- .../java/net/schmizz/sshj/SocketClient.java | 17 +-- .../channel/direct/LocalPortForwarder.java | 137 +++++++++--------- 4 files changed, 89 insertions(+), 99 deletions(-) diff --git a/src/main/java/examples/LocalPF.java b/src/main/java/examples/LocalPF.java index 233ea281..5a77a07d 100644 --- a/src/main/java/examples/LocalPF.java +++ b/src/main/java/examples/LocalPF.java @@ -16,9 +16,11 @@ package examples; import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder; import java.io.IOException; import java.net.InetSocketAddress; +import java.net.ServerSocket; /** * This example demonstrates local port forwarding, i.e. when we listen on a particular address and port; and forward @@ -41,8 +43,16 @@ public class LocalPF { * _We_ listen on localhost:8080 and forward all connections on to server, which then forwards it to * google.com:80 */ - ssh.newLocalPortForwarder(new InetSocketAddress("localhost", 8080), "google.com", 80) - .listen(); + final LocalPortForwarder.Parameters params + = new LocalPortForwarder.Parameters("0.0.0.0", 8080, "google.com", 80); + final ServerSocket ss = new ServerSocket(); + ss.setReuseAddress(true); + ss.bind(new InetSocketAddress(params.getLocalHost(), params.getLocalPort())); + try { + ssh.newLocalPortForwarder(params, ss).listen(); + } finally { + ss.close(); + } } finally { ssh.disconnect(); diff --git a/src/main/java/net/schmizz/sshj/SSHClient.java b/src/main/java/net/schmizz/sshj/SSHClient.java index 6dbd98b6..22343258 100644 --- a/src/main/java/net/schmizz/sshj/SSHClient.java +++ b/src/main/java/net/schmizz/sshj/SSHClient.java @@ -64,7 +64,7 @@ import org.slf4j.LoggerFactory; import java.io.Closeable; import java.io.File; import java.io.IOException; -import java.net.SocketAddress; +import java.net.ServerSocket; import java.security.KeyPair; import java.security.PublicKey; import java.util.Arrays; @@ -568,23 +568,21 @@ public class SSHClient } /** - * Create a {@link LocalPortForwarder} that will listen on {@code address} and forward incoming connections to the - * server; which will further forward them to {@code host:port}. + * Create a {@link LocalPortForwarder} that will listen based on {@code parameters} using the bound + * {@code serverSocket} and forward incoming connections to the server; which will further forward them to + * {@code host:port}. *

* The returned forwarder's {@link LocalPortForwarder#listen() listen()} method should be called to actually start * listening, this method just creates an instance. * - * @param address defines where the {@link LocalPortForwarder} listens - * @param host hostname to which the server will forward - * @param port the port at {@code hostname} to which the server wil forward + * @param parameters parameters for the forwarding setup + * @param serverSocket bound server socket * * @return a {@link LocalPortForwarder} - * - * @throws IOException if there is an error opening a local server socket */ - public LocalPortForwarder newLocalPortForwarder(SocketAddress address, String host, int port) - throws IOException { - return new LocalPortForwarder(getServerSocketFactory(), conn, address, host, port); + public LocalPortForwarder newLocalPortForwarder(LocalPortForwarder.Parameters parameters, + ServerSocket serverSocket) { + return new LocalPortForwarder(conn, parameters, serverSocket); } /** diff --git a/src/main/java/net/schmizz/sshj/SocketClient.java b/src/main/java/net/schmizz/sshj/SocketClient.java index 145ae8ef..dd614c32 100644 --- a/src/main/java/net/schmizz/sshj/SocketClient.java +++ b/src/main/java/net/schmizz/sshj/SocketClient.java @@ -35,6 +35,7 @@ */ package net.schmizz.sshj; +import javax.net.SocketFactory; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -42,10 +43,6 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; -import javax.net.ServerSocketFactory; -import javax.net.SocketFactory; - - public abstract class SocketClient { private final int defaultPort; @@ -55,7 +52,6 @@ public abstract class SocketClient { private OutputStream output; private SocketFactory socketFactory = SocketFactory.getDefault(); - private ServerSocketFactory serverSocketFactory = ServerSocketFactory.getDefault(); private static final int DEFAULT_CONNECT_TIMEOUT = 0; private int connectTimeout = DEFAULT_CONNECT_TIMEOUT; @@ -159,17 +155,6 @@ public abstract class SocketClient { return socketFactory; } - public void setServerSocketFactory(ServerSocketFactory factory) { - if (factory == null) - serverSocketFactory = ServerSocketFactory.getDefault(); - else - serverSocketFactory = factory; - } - - public ServerSocketFactory getServerSocketFactory() { - return serverSocketFactory; - } - public int getConnectTimeout() { return connectTimeout; } diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java index 3466fef5..da2b9927 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java @@ -19,104 +19,104 @@ import net.schmizz.concurrent.Event; import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.StreamCopier; import net.schmizz.sshj.connection.Connection; +import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor; +import net.schmizz.sshj.transport.TransportException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import javax.net.ServerSocketFactory; import java.io.IOException; import java.net.ServerSocket; import java.net.Socket; -import java.net.SocketAddress; import java.util.concurrent.TimeUnit; public class LocalPortForwarder { - private class DirectTCPIPChannel - extends AbstractDirectChannel { + public static class Parameters { - private final Socket sock; + private final String localHost; + private final int localPort; + private final String remoteHost; + private final int remotePort; - private DirectTCPIPChannel(Connection conn, Socket sock) { - super(conn, "direct-tcpip"); - this.sock = sock; + public Parameters(String localHost, int localPort, String remoteHost, int remotePort) { + this.localHost = localHost; + this.localPort = localPort; + this.remoteHost = remoteHost; + this.remotePort = remotePort; } - private void start() + public String getRemoteHost() { + return remoteHost; + } + + public int getRemotePort() { + return remotePort; + } + + public String getLocalHost() { + return localHost; + } + + public int getLocalPort() { + return localPort; + } + + } + + public static class DirectTCPIPChannel + extends AbstractDirectChannel { + + protected final Socket socket; + protected final Parameters parameters; + + public DirectTCPIPChannel(Connection conn, Socket socket, Parameters parameters) { + super(conn, "direct-tcpip"); + this.socket = socket; + this.parameters = parameters; + } + + protected void start() throws IOException { - sock.setSendBufferSize(getLocalMaxPacketSize()); - sock.setReceiveBufferSize(getRemoteMaxPacketSize()); - final Event soc2chan = new StreamCopier(sock.getInputStream(), getOutputStream()) + socket.setSendBufferSize(getLocalMaxPacketSize()); + socket.setReceiveBufferSize(getRemoteMaxPacketSize()); + final Event soc2chan = new StreamCopier(socket.getInputStream(), getOutputStream()) .bufSize(getRemoteMaxPacketSize()) .spawnDaemon("soc2chan"); - final Event chan2soc = new StreamCopier(getInputStream(), sock.getOutputStream()) + final Event chan2soc = new StreamCopier(getInputStream(), socket.getOutputStream()) .bufSize(getLocalMaxPacketSize()) .spawnDaemon("chan2soc"); - SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, soc2chan, chan2soc, this, sock); + SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, soc2chan, chan2soc, this, socket); } @Override protected SSHPacket buildOpenReq() { return super.buildOpenReq() - .putString(host) - .putUInt32(port) - .putString(ss.getInetAddress().getHostAddress()) - .putUInt32(ss.getLocalPort()); + .putString(parameters.getRemoteHost()) + .putUInt32(parameters.getRemotePort()) + .putString(parameters.getLocalHost()) + .putUInt32(parameters.getLocalPort()); } } - private final Logger log = LoggerFactory.getLogger(getClass()); + private final Logger log = LoggerFactory.getLogger(LocalPortForwarder.class); private final Connection conn; - private final ServerSocket ss; - private final String host; - private final int port; + private final Parameters parameters; + private final ServerSocket serverSocket; - /** - * Create a local port forwarder with specified binding ({@code listeningAddr}. It does not, however, start - * listening unless {@link #listen() explicitly told to}. The {@link javax.net.ServerSocketFactory#getDefault() - * default} server socket factory is used. - * - * @param conn {@link Connection} implementation - * @param listeningAddr {@link SocketAddress} this forwarder will listen on, if {@code null} then an ephemeral port - * and valid local address will be picked to bind the server socket - * @param host what host the SSH server will further forward to - * @param port port on {@code toHost} - * - * @throws IOException if there is an error binding on specified {@code listeningAddr} - */ - public LocalPortForwarder(Connection conn, SocketAddress listeningAddr, String host, int port) - throws IOException { - this(ServerSocketFactory.getDefault(), conn, listeningAddr, host, port); - } - - /** - * Create a local port forwarder with specified binding ({@code listeningAddr}. It does not, however, start - * listening unless {@link #listen() explicitly told to}. - * - * @param ssf factory to use for creating the server socket - * @param conn {@link Connection} implementation - * @param listeningAddr {@link SocketAddress} this forwarder will listen on, if {@code null} then an ephemeral port - * and valid local address will be picked to bind the server socket - * @param host what host the SSH server will further forward to - * @param port port on {@code toHost} - * - * @throws IOException if there is an error binding on specified {@code listeningAddr} - */ - public LocalPortForwarder(ServerSocketFactory ssf, Connection conn, SocketAddress listeningAddr, String host, int port) - throws IOException { + public LocalPortForwarder(Connection conn, Parameters parameters, ServerSocket serverSocket) { this.conn = conn; - this.host = host; - this.port = port; - this.ss = ssf.createServerSocket(); - ss.setReceiveBufferSize(conn.getMaxPacketSize()); - ss.bind(listeningAddr); + this.parameters = parameters; + this.serverSocket = serverSocket; } - /** @return the address to which this forwarder is bound for listening */ - public SocketAddress getListeningAddress() { - return ss.getLocalSocketAddress(); + protected DirectTCPIPChannel openChannel(Socket socket) + throws TransportException, ConnectionException { + final DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, socket, parameters); + chan.open(); + return chan; } /** @@ -126,14 +126,11 @@ public class LocalPortForwarder { */ public void listen() throws IOException { - log.info("Listening on {}", ss.getLocalSocketAddress()); - Socket sock; + log.info("Listening on {}", serverSocket.getLocalSocketAddress()); while (!Thread.currentThread().isInterrupted()) { - sock = ss.accept(); - log.info("Got connection from {}", sock.getRemoteSocketAddress()); - DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, sock); - chan.open(); - chan.start(); + final Socket socket = serverSocket.accept(); + log.info("Got connection from {}", socket.getRemoteSocketAddress()); + openChannel(socket).start(); } log.info("Interrupted!"); }