Refactored the local port forwarding API; give caller control over initializing and cleaning up the server socket used.

Also removed 'server socket factory' stuff from SocketClient.
This commit is contained in:
Shikhar Bhushan
2012-01-05 22:25:42 +00:00
parent 8eedeb25fa
commit 3026be282a
4 changed files with 89 additions and 99 deletions

View File

@@ -16,9 +16,11 @@
package examples;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
/**
* This example demonstrates local port forwarding, i.e. when we listen on a particular address and port; and forward
@@ -41,8 +43,16 @@ public class LocalPF {
* _We_ listen on localhost:8080 and forward all connections on to server, which then forwards it to
* google.com:80
*/
ssh.newLocalPortForwarder(new InetSocketAddress("localhost", 8080), "google.com", 80)
.listen();
final LocalPortForwarder.Parameters params
= new LocalPortForwarder.Parameters("0.0.0.0", 8080, "google.com", 80);
final ServerSocket ss = new ServerSocket();
ss.setReuseAddress(true);
ss.bind(new InetSocketAddress(params.getLocalHost(), params.getLocalPort()));
try {
ssh.newLocalPortForwarder(params, ss).listen();
} finally {
ss.close();
}
} finally {
ssh.disconnect();

View File

@@ -64,7 +64,7 @@ import org.slf4j.LoggerFactory;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.net.SocketAddress;
import java.net.ServerSocket;
import java.security.KeyPair;
import java.security.PublicKey;
import java.util.Arrays;
@@ -568,23 +568,21 @@ public class SSHClient
}
/**
* Create a {@link LocalPortForwarder} that will listen on {@code address} and forward incoming connections to the
* server; which will further forward them to {@code host:port}.
* Create a {@link LocalPortForwarder} that will listen based on {@code parameters} using the bound
* {@code serverSocket} and forward incoming connections to the server; which will further forward them to
* {@code host:port}.
* <p/>
* The returned forwarder's {@link LocalPortForwarder#listen() listen()} method should be called to actually start
* listening, this method just creates an instance.
*
* @param address defines where the {@link LocalPortForwarder} listens
* @param host hostname to which the server will forward
* @param port the port at {@code hostname} to which the server wil forward
* @param parameters parameters for the forwarding setup
* @param serverSocket bound server socket
*
* @return a {@link LocalPortForwarder}
*
* @throws IOException if there is an error opening a local server socket
*/
public LocalPortForwarder newLocalPortForwarder(SocketAddress address, String host, int port)
throws IOException {
return new LocalPortForwarder(getServerSocketFactory(), conn, address, host, port);
public LocalPortForwarder newLocalPortForwarder(LocalPortForwarder.Parameters parameters,
ServerSocket serverSocket) {
return new LocalPortForwarder(conn, parameters, serverSocket);
}
/**

View File

@@ -35,6 +35,7 @@
*/
package net.schmizz.sshj;
import javax.net.SocketFactory;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@@ -42,10 +43,6 @@ import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import javax.net.ServerSocketFactory;
import javax.net.SocketFactory;
public abstract class SocketClient {
private final int defaultPort;
@@ -55,7 +52,6 @@ public abstract class SocketClient {
private OutputStream output;
private SocketFactory socketFactory = SocketFactory.getDefault();
private ServerSocketFactory serverSocketFactory = ServerSocketFactory.getDefault();
private static final int DEFAULT_CONNECT_TIMEOUT = 0;
private int connectTimeout = DEFAULT_CONNECT_TIMEOUT;
@@ -159,17 +155,6 @@ public abstract class SocketClient {
return socketFactory;
}
public void setServerSocketFactory(ServerSocketFactory factory) {
if (factory == null)
serverSocketFactory = ServerSocketFactory.getDefault();
else
serverSocketFactory = factory;
}
public ServerSocketFactory getServerSocketFactory() {
return serverSocketFactory;
}
public int getConnectTimeout() {
return connectTimeout;
}

View File

@@ -19,104 +19,104 @@ import net.schmizz.concurrent.Event;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.common.StreamCopier;
import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.channel.SocketStreamCopyMonitor;
import net.schmizz.sshj.transport.TransportException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ServerSocketFactory;
import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.util.concurrent.TimeUnit;
public class LocalPortForwarder {
private class DirectTCPIPChannel
extends AbstractDirectChannel {
public static class Parameters {
private final Socket sock;
private final String localHost;
private final int localPort;
private final String remoteHost;
private final int remotePort;
private DirectTCPIPChannel(Connection conn, Socket sock) {
super(conn, "direct-tcpip");
this.sock = sock;
public Parameters(String localHost, int localPort, String remoteHost, int remotePort) {
this.localHost = localHost;
this.localPort = localPort;
this.remoteHost = remoteHost;
this.remotePort = remotePort;
}
private void start()
public String getRemoteHost() {
return remoteHost;
}
public int getRemotePort() {
return remotePort;
}
public String getLocalHost() {
return localHost;
}
public int getLocalPort() {
return localPort;
}
}
public static class DirectTCPIPChannel
extends AbstractDirectChannel {
protected final Socket socket;
protected final Parameters parameters;
public DirectTCPIPChannel(Connection conn, Socket socket, Parameters parameters) {
super(conn, "direct-tcpip");
this.socket = socket;
this.parameters = parameters;
}
protected void start()
throws IOException {
sock.setSendBufferSize(getLocalMaxPacketSize());
sock.setReceiveBufferSize(getRemoteMaxPacketSize());
final Event<IOException> soc2chan = new StreamCopier(sock.getInputStream(), getOutputStream())
socket.setSendBufferSize(getLocalMaxPacketSize());
socket.setReceiveBufferSize(getRemoteMaxPacketSize());
final Event<IOException> soc2chan = new StreamCopier(socket.getInputStream(), getOutputStream())
.bufSize(getRemoteMaxPacketSize())
.spawnDaemon("soc2chan");
final Event<IOException> chan2soc = new StreamCopier(getInputStream(), sock.getOutputStream())
final Event<IOException> chan2soc = new StreamCopier(getInputStream(), socket.getOutputStream())
.bufSize(getLocalMaxPacketSize())
.spawnDaemon("chan2soc");
SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, soc2chan, chan2soc, this, sock);
SocketStreamCopyMonitor.monitor(5, TimeUnit.SECONDS, soc2chan, chan2soc, this, socket);
}
@Override
protected SSHPacket buildOpenReq() {
return super.buildOpenReq()
.putString(host)
.putUInt32(port)
.putString(ss.getInetAddress().getHostAddress())
.putUInt32(ss.getLocalPort());
.putString(parameters.getRemoteHost())
.putUInt32(parameters.getRemotePort())
.putString(parameters.getLocalHost())
.putUInt32(parameters.getLocalPort());
}
}
private final Logger log = LoggerFactory.getLogger(getClass());
private final Logger log = LoggerFactory.getLogger(LocalPortForwarder.class);
private final Connection conn;
private final ServerSocket ss;
private final String host;
private final int port;
private final Parameters parameters;
private final ServerSocket serverSocket;
/**
* Create a local port forwarder with specified binding ({@code listeningAddr}. It does not, however, start
* listening unless {@link #listen() explicitly told to}. The {@link javax.net.ServerSocketFactory#getDefault()
* default} server socket factory is used.
*
* @param conn {@link Connection} implementation
* @param listeningAddr {@link SocketAddress} this forwarder will listen on, if {@code null} then an ephemeral port
* and valid local address will be picked to bind the server socket
* @param host what host the SSH server will further forward to
* @param port port on {@code toHost}
*
* @throws IOException if there is an error binding on specified {@code listeningAddr}
*/
public LocalPortForwarder(Connection conn, SocketAddress listeningAddr, String host, int port)
throws IOException {
this(ServerSocketFactory.getDefault(), conn, listeningAddr, host, port);
}
/**
* Create a local port forwarder with specified binding ({@code listeningAddr}. It does not, however, start
* listening unless {@link #listen() explicitly told to}.
*
* @param ssf factory to use for creating the server socket
* @param conn {@link Connection} implementation
* @param listeningAddr {@link SocketAddress} this forwarder will listen on, if {@code null} then an ephemeral port
* and valid local address will be picked to bind the server socket
* @param host what host the SSH server will further forward to
* @param port port on {@code toHost}
*
* @throws IOException if there is an error binding on specified {@code listeningAddr}
*/
public LocalPortForwarder(ServerSocketFactory ssf, Connection conn, SocketAddress listeningAddr, String host, int port)
throws IOException {
public LocalPortForwarder(Connection conn, Parameters parameters, ServerSocket serverSocket) {
this.conn = conn;
this.host = host;
this.port = port;
this.ss = ssf.createServerSocket();
ss.setReceiveBufferSize(conn.getMaxPacketSize());
ss.bind(listeningAddr);
this.parameters = parameters;
this.serverSocket = serverSocket;
}
/** @return the address to which this forwarder is bound for listening */
public SocketAddress getListeningAddress() {
return ss.getLocalSocketAddress();
protected DirectTCPIPChannel openChannel(Socket socket)
throws TransportException, ConnectionException {
final DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, socket, parameters);
chan.open();
return chan;
}
/**
@@ -126,14 +126,11 @@ public class LocalPortForwarder {
*/
public void listen()
throws IOException {
log.info("Listening on {}", ss.getLocalSocketAddress());
Socket sock;
log.info("Listening on {}", serverSocket.getLocalSocketAddress());
while (!Thread.currentThread().isInterrupted()) {
sock = ss.accept();
log.info("Got connection from {}", sock.getRemoteSocketAddress());
DirectTCPIPChannel chan = new DirectTCPIPChannel(conn, sock);
chan.open();
chan.start();
final Socket socket = serverSocket.accept();
log.info("Got connection from {}", socket.getRemoteSocketAddress());
openChannel(socket).start();
}
log.info("Interrupted!");
}