Refactored the local port forwarding API; give caller control over initializing and cleaning up the server socket used.

Also removed 'server socket factory' stuff from SocketClient.
This commit is contained in:
Shikhar Bhushan
2012-01-05 22:25:42 +00:00
parent 8eedeb25fa
commit 3026be282a
4 changed files with 89 additions and 99 deletions

View File

@@ -16,9 +16,11 @@
package examples; package examples;
import net.schmizz.sshj.SSHClient; import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.ServerSocket;
/** /**
* This example demonstrates local port forwarding, i.e. when we listen on a particular address and port; and forward * This example demonstrates local port forwarding, i.e. when we listen on a particular address and port; and forward
@@ -41,8 +43,16 @@ public class LocalPF {
* _We_ listen on localhost:8080 and forward all connections on to server, which then forwards it to * _We_ listen on localhost:8080 and forward all connections on to server, which then forwards it to
* google.com:80 * google.com:80
*/ */
ssh.newLocalPortForwarder(new InetSocketAddress("localhost", 8080), "google.com", 80) final LocalPortForwarder.Parameters params
.listen(); = new LocalPortForwarder.Parameters("0.0.0.0", 8080, "google.com", 80);
final ServerSocket ss = new ServerSocket();
ss.setReuseAddress(true);
ss.bind(new InetSocketAddress(params.getLocalHost(), params.getLocalPort()));
try {
ssh.newLocalPortForwarder(params, ss).listen();
} finally {
ss.close();
}
} finally { } finally {
ssh.disconnect(); ssh.disconnect();

View File

@@ -64,7 +64,7 @@ import org.slf4j.LoggerFactory;
import java.io.Closeable; import java.io.Closeable;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.net.SocketAddress; import java.net.ServerSocket;
import java.security.KeyPair; import java.security.KeyPair;
import java.security.PublicKey; import java.security.PublicKey;
import java.util.Arrays; import java.util.Arrays;
@@ -568,23 +568,21 @@ public class SSHClient
} }
/** /**
* Create a {@link LocalPortForwarder} that will listen on {@code address} and forward incoming connections to the * Create a {@link LocalPortForwarder} that will listen based on {@code parameters} using the bound
* server; which will further forward them to {@code host:port}. * {@code serverSocket} and forward incoming connections to the server; which will further forward them to
* {@code host:port}.
* <p/> * <p/>
* The returned forwarder's {@link LocalPortForwarder#listen() listen()} method should be called to actually start * The returned forwarder's {@link LocalPortForwarder#listen() listen()} method should be called to actually start
* listening, this method just creates an instance. * listening, this method just creates an instance.
* *
* @param address defines where the {@link LocalPortForwarder} listens * @param parameters parameters for the forwarding setup
* @param host hostname to which the server will forward * @param serverSocket bound server socket
* @param port the port at {@code hostname} to which the server wil forward
* *
* @return a {@link LocalPortForwarder} * @return a {@link LocalPortForwarder}
*
* @throws IOException if there is an error opening a local server socket
*/ */
public LocalPortForwarder newLocalPortForwarder(SocketAddress address, String host, int port) public LocalPortForwarder newLocalPortForwarder(LocalPortForwarder.Parameters parameters,
throws IOException { ServerSocket serverSocket) {
return new LocalPortForwarder(getServerSocketFactory(), conn, address, host, port); return new LocalPortForwarder(conn, parameters, serverSocket);
} }
/** /**

View File

@@ -35,6 +35,7 @@
*/ */
package net.schmizz.sshj; package net.schmizz.sshj;
import javax.net.SocketFactory;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
@@ -42,10 +43,6 @@ import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.Socket; import java.net.Socket;
import javax.net.ServerSocketFactory;
import javax.net.SocketFactory;
public abstract class SocketClient { public abstract class SocketClient {
private final int defaultPort; private final int defaultPort;
@@ -55,7 +52,6 @@ public abstract class SocketClient {
private OutputStream output; private OutputStream output;
private SocketFactory socketFactory = SocketFactory.getDefault(); private SocketFactory socketFactory = SocketFactory.getDefault();
private ServerSocketFactory serverSocketFactory = ServerSocketFactory.getDefault();
private static final int DEFAULT_CONNECT_TIMEOUT = 0; private static final int DEFAULT_CONNECT_TIMEOUT = 0;
private int connectTimeout = DEFAULT_CONNECT_TIMEOUT; private int connectTimeout = DEFAULT_CONNECT_TIMEOUT;
@@ -159,17 +155,6 @@ public abstract class SocketClient {
return socketFactory; return socketFactory;
} }
public void setServerSocketFactory(ServerSocketFactory factory) {
if (factory == null)
serverSocketFactory = ServerSocketFactory.getDefault();
else
serverSocketFactory = factory;
}
public ServerSocketFactory getServerSocketFactory() {
return serverSocketFactory;
}
public int getConnectTimeout() { public int getConnectTimeout() {
return connectTimeout; return connectTimeout;
} }

View File

@@ -19,104 +19,104 @@ import net.schmizz.concurrent.Event;
import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.common.StreamCopier; import net.schmizz.sshj.common.StreamCopier;
import net.schmizz.sshj.connection.Connection; import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor; import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor;
import net.schmizz.sshj.transport.TransportException;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import javax.net.ServerSocketFactory;
import java.io.IOException; import java.io.IOException;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.net.Socket; import java.net.Socket;
import java.net.SocketAddress;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
public class LocalPortForwarder { public class LocalPortForwarder {
private class DirectTCPIPChannel public static class Parameters {
extends AbstractDirectChannel {
private final Socket sock; private final String localHost;
private final int localPort;
private final String remoteHost;
private final int remotePort;
private DirectTCPIPChannel(Connection conn, Socket sock) { public Parameters(String localHost, int localPort, String remoteHost, int remotePort) {
super(conn, "direct-tcpip"); this.localHost = localHost;
this.sock = sock; this.localPort = localPort;
this.remoteHost = remoteHost;
this.remotePort = remotePort;
} }
private void start() public String getRemoteHost() {
return remoteHost;
}
public int getRemotePort() {
return remotePort;
}
public String getLocalHost() {
return localHost;
}
public int getLocalPort() {
return localPort;
}
}
public static class DirectTCPIPChannel
extends AbstractDirectChannel {
protected final Socket socket;
protected final Parameters parameters;
public DirectTCPIPChannel(Connection conn, Socket socket, Parameters parameters) {
super(conn, "direct-tcpip");
this.socket = socket;
this.parameters = parameters;
}
protected void start()
throws IOException { throws IOException {
sock.setSendBufferSize(getLocalMaxPacketSize()); socket.setSendBufferSize(getLocalMaxPacketSize());
sock.setReceiveBufferSize(getRemoteMaxPacketSize()); socket.setReceiveBufferSize(getRemoteMaxPacketSize());
final Event<IOException> soc2chan = new StreamCopier(sock.getInputStream(), getOutputStream()) final Event<IOException> soc2chan = new StreamCopier(socket.getInputStream(), getOutputStream())
.bufSize(getRemoteMaxPacketSize()) .bufSize(getRemoteMaxPacketSize())
.spawnDaemon("soc2chan"); .spawnDaemon("soc2chan");
final Event<IOException> chan2soc = new StreamCopier(getInputStream(), sock.getOutputStream()) final Event<IOException> chan2soc = new StreamCopier(getInputStream(), socket.getOutputStream())
.bufSize(getLocalMaxPacketSize()) .bufSize(getLocalMaxPacketSize())
.spawnDaemon("chan2soc"); .spawnDaemon("chan2soc");
SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, soc2chan, chan2soc, this, sock); SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, soc2chan, chan2soc, this, socket);
} }
@Override @Override
protected SSHPacket buildOpenReq() { protected SSHPacket buildOpenReq() {
return super.buildOpenReq() return super.buildOpenReq()
.putString(host) .putString(parameters.getRemoteHost())
.putUInt32(port) .putUInt32(parameters.getRemotePort())
.putString(ss.getInetAddress().getHostAddress()) .putString(parameters.getLocalHost())
.putUInt32(ss.getLocalPort()); .putUInt32(parameters.getLocalPort());
} }
} }
private final Logger log = LoggerFactory.getLogger(getClass()); private final Logger log = LoggerFactory.getLogger(LocalPortForwarder.class);
private final Connection conn; private final Connection conn;
private final ServerSocket ss; private final Parameters parameters;
private final String host; private final ServerSocket serverSocket;
private final int port;
/** public LocalPortForwarder(Connection conn, Parameters parameters, ServerSocket serverSocket) {
* Create a local port forwarder with specified binding ({@code listeningAddr}. It does not, however, start
* listening unless {@link #listen() explicitly told to}. The {@link javax.net.ServerSocketFactory#getDefault()
* default} server socket factory is used.
*
* @param conn {@link Connection} implementation
* @param listeningAddr {@link SocketAddress} this forwarder will listen on, if {@code null} then an ephemeral port
* and valid local address will be picked to bind the server socket
* @param host what host the SSH server will further forward to
* @param port port on {@code toHost}
*
* @throws IOException if there is an error binding on specified {@code listeningAddr}
*/
public LocalPortForwarder(Connection conn, SocketAddress listeningAddr, String host, int port)
throws IOException {
this(ServerSocketFactory.getDefault(), conn, listeningAddr, host, port);
}
/**
* Create a local port forwarder with specified binding ({@code listeningAddr}. It does not, however, start
* listening unless {@link #listen() explicitly told to}.
*
* @param ssf factory to use for creating the server socket
* @param conn {@link Connection} implementation
* @param listeningAddr {@link SocketAddress} this forwarder will listen on, if {@code null} then an ephemeral port
* and valid local address will be picked to bind the server socket
* @param host what host the SSH server will further forward to
* @param port port on {@code toHost}
*
* @throws IOException if there is an error binding on specified {@code listeningAddr}
*/
public LocalPortForwarder(ServerSocketFactory ssf, Connection conn, SocketAddress listeningAddr, String host, int port)
throws IOException {
this.conn = conn; this.conn = conn;
this.host = host; this.parameters = parameters;
this.port = port; this.serverSocket = serverSocket;
this.ss = ssf.createServerSocket();
ss.setReceiveBufferSize(conn.getMaxPacketSize());
ss.bind(listeningAddr);
} }
/** @return the address to which this forwarder is bound for listening */ protected DirectTCPIPChannel openChannel(Socket socket)
public SocketAddress getListeningAddress() { throws TransportException, ConnectionException {
return ss.getLocalSocketAddress(); final DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, socket, parameters);
chan.open();
return chan;
} }
/** /**
@@ -126,14 +126,11 @@ public class LocalPortForwarder {
*/ */
public void listen() public void listen()
throws IOException { throws IOException {
log.info("Listening on {}", ss.getLocalSocketAddress()); log.info("Listening on {}", serverSocket.getLocalSocketAddress());
Socket sock;
while (!Thread.currentThread().isInterrupted()) { while (!Thread.currentThread().isInterrupted()) {
sock = ss.accept(); final Socket socket = serverSocket.accept();
log.info("Got connection from {}", sock.getRemoteSocketAddress()); log.info("Got connection from {}", socket.getRemoteSocketAddress());
DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, sock); openChannel(socket).start();
chan.open();
chan.start();
} }
log.info("Interrupted!"); log.info("Interrupted!");
} }