diff --git a/src/main/java/net/schmizz/sshj/SSHClient.java b/src/main/java/net/schmizz/sshj/SSHClient.java index b391a866..daecc5a6 100644 --- a/src/main/java/net/schmizz/sshj/SSHClient.java +++ b/src/main/java/net/schmizz/sshj/SSHClient.java @@ -139,6 +139,8 @@ public class SSHClient /** {@code ssh-connection} service */ protected final Connection conn; + private final List forwarders = new ArrayList(); + /** Default constructor. Initializes this object using {@link DefaultConfig}. */ public SSHClient() { this(new DefaultConfig()); @@ -431,6 +433,14 @@ public class SSHClient @Override public void disconnect() throws IOException { + for (LocalPortForwarder forwarder : forwarders) { + try { + forwarder.close(); + } catch (IOException e) { + log.warn("Error closing forwarder", e); + } + } + forwarders.clear(); trans.disconnect(); super.disconnect(); } @@ -648,7 +658,9 @@ public class SSHClient */ public LocalPortForwarder newLocalPortForwarder(LocalPortForwarder.Parameters parameters, ServerSocket serverSocket) { - return new LocalPortForwarder(conn, parameters, serverSocket); + LocalPortForwarder forwarder = new LocalPortForwarder(conn, parameters, serverSocket); + forwarders.add(forwarder); + return forwarder; } /** diff --git a/src/main/java/net/schmizz/sshj/SocketClient.java b/src/main/java/net/schmizz/sshj/SocketClient.java index 7d7dd9d1..e831bffa 100644 --- a/src/main/java/net/schmizz/sshj/SocketClient.java +++ b/src/main/java/net/schmizz/sshj/SocketClient.java @@ -48,12 +48,50 @@ public abstract class SocketClient { this.defaultPort = defaultPort; } - public void connect(InetAddress host, int port) throws IOException { - socket = socketFactory.createSocket(); - socket.connect(new InetSocketAddress(host, port), connectTimeout); + /** + * Connect to a host via a proxy. + * @param hostname The host name to connect to. + * @param proxy The proxy to connect via. + * @deprecated This method will be removed after v0.12.0. If you want to connect via a proxy, you can do this by injecting a {@link javax.net.SocketFactory} + * into the SocketClient. The SocketFactory should create sockets using the {@link java.net.Socket#Socket(java.net.Proxy)} constructor. + */ + @Deprecated + public void connect(String hostname, Proxy proxy) throws IOException { + connect(hostname, defaultPort, proxy); + } + + /** + * Connect to a host via a proxy. + * @param hostname The host name to connect to. + * @param port The port to connect to. + * @param proxy The proxy to connect via. + * @deprecated This method will be removed after v0.12.0. If you want to connect via a proxy, you can do this by injecting a {@link javax.net.SocketFactory} + * into the SocketClient. The SocketFactory should create sockets using the {@link java.net.Socket#Socket(java.net.Proxy)} constructor. + */ + @Deprecated + public void connect(String hostname, int port, Proxy proxy) throws IOException { + this.hostname = hostname; + if (JavaVersion.isJava7OrEarlier() && proxy.type() == Proxy.Type.HTTP) { + // Java7 and earlier have no support for HTTP Connect proxies, return our custom socket. + socket = new Jdk7HttpProxySocket(proxy); + } else { + socket = new Socket(proxy); + } + socket.connect(new InetSocketAddress(hostname, port), connectTimeout); onConnect(); } + /** + * Connect to a host via a proxy. + * @param host The host address to connect to. + * @param proxy The proxy to connect via. + * @deprecated This method will be removed after v0.12.0. If you want to connect via a proxy, you can do this by injecting a {@link javax.net.SocketFactory} + * into the SocketClient. The SocketFactory should create sockets using the {@link java.net.Socket#Socket(java.net.Proxy)} constructor. + */ + @Deprecated + public void connect(InetAddress host, Proxy proxy) throws IOException { + connect(host, defaultPort, proxy); + } /** * Connect to a host via a proxy. @@ -75,23 +113,41 @@ public abstract class SocketClient { onConnect(); } - public void connect(String hostname, int port) throws IOException { - this.hostname = hostname; - connect(InetAddress.getByName(hostname), port); + public void connect(String hostname) throws IOException { + connect(hostname, defaultPort); } - /** - * Connect to a host via a proxy. - * @param hostname The host name to connect to. - * @param port The port to connect to. - * @param proxy The proxy to connect via. - * @deprecated This method will be removed after v0.12.0. If you want to connect via a proxy, you can do this by injecting a {@link javax.net.SocketFactory} - * into the SocketClient. The SocketFactory should create sockets using the {@link java.net.Socket#Socket(java.net.Proxy)} constructor. - */ - @Deprecated - public void connect(String hostname, int port, Proxy proxy) throws IOException { - this.hostname = hostname; - connect(InetAddress.getByName(hostname), port, proxy); + public void connect(String hostname, int port) throws IOException { + if (hostname == null) { + connect(InetAddress.getByName(null), port); + } else { + this.hostname = hostname; + socket = socketFactory.createSocket(); + socket.connect(new InetSocketAddress(hostname, port), connectTimeout); + onConnect(); + } + } + + public void connect(String hostname, int port, InetAddress localAddr, int localPort) throws IOException { + if (hostname == null) { + connect(InetAddress.getByName(null), port, localAddr, localPort); + } else { + this.hostname = hostname; + socket = socketFactory.createSocket(); + socket.bind(new InetSocketAddress(localAddr, localPort)); + socket.connect(new InetSocketAddress(hostname, port), connectTimeout); + onConnect(); + } + } + + public void connect(InetAddress host) throws IOException { + connect(host, defaultPort); + } + + public void connect(InetAddress host, int port) throws IOException { + socket = socketFactory.createSocket(); + socket.connect(new InetSocketAddress(host, port), connectTimeout); + onConnect(); } public void connect(InetAddress host, int port, InetAddress localAddr, int localPort) @@ -102,43 +158,6 @@ public abstract class SocketClient { onConnect(); } - public void connect(String hostname, int port, InetAddress localAddr, int localPort) throws IOException { - this.hostname = hostname; - connect(InetAddress.getByName(hostname), port, localAddr, localPort); - } - - public void connect(InetAddress host) throws IOException { - connect(host, defaultPort); - } - - public void connect(String hostname) throws IOException { - connect(hostname, defaultPort); - } - - /** - * Connect to a host via a proxy. - * @param host The host address to connect to. - * @param proxy The proxy to connect via. - * @deprecated This method will be removed after v0.12.0. If you want to connect via a proxy, you can do this by injecting a {@link javax.net.SocketFactory} - * into the SocketClient. The SocketFactory should create sockets using the {@link java.net.Socket#Socket(java.net.Proxy)} constructor. - */ - @Deprecated - public void connect(InetAddress host, Proxy proxy) throws IOException { - connect(host, defaultPort, proxy); - } - - /** - * Connect to a host via a proxy. - * @param hostname The host name to connect to. - * @param proxy The proxy to connect via. - * @deprecated This method will be removed after v0.12.0. If you want to connect via a proxy, you can do this by injecting a {@link javax.net.SocketFactory} - * into the SocketClient. The SocketFactory should create sockets using the {@link java.net.Socket#Socket(java.net.Proxy)} constructor. - */ - @Deprecated - public void connect(String hostname, Proxy proxy) throws IOException { - connect(hostname, defaultPort, proxy); - } - public void disconnect() throws IOException { if (socket != null) { socket.close(); diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java index a3511cac..1721e5a8 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java @@ -28,6 +28,7 @@ import org.slf4j.LoggerFactory; import java.io.IOException; import java.net.ServerSocket; import java.net.Socket; +import java.net.SocketException; import java.util.concurrent.TimeUnit; import static com.hierynomus.sshj.backport.Sockets.asCloseable; @@ -134,11 +135,33 @@ public class LocalPortForwarder { throws IOException { log.info("Listening on {}", serverSocket.getLocalSocketAddress()); while (!Thread.currentThread().isInterrupted()) { - final Socket socket = serverSocket.accept(); - log.debug("Got connection from {}", socket.getRemoteSocketAddress()); - startChannel(socket); + try { + final Socket socket = serverSocket.accept(); + log.debug("Got connection from {}", socket.getRemoteSocketAddress()); + startChannel(socket); + } catch (SocketException e) { + if (!serverSocket.isClosed()) { + throw e; + } + } + } + if (serverSocket.isClosed()) { + log.debug("LocalPortForwarder closed"); + } else { + log.debug("LocalPortForwarder interrupted!"); } - log.debug("Interrupted!"); } -} \ No newline at end of file + /** + * Close the ServerSocket that's listening for connections to forward. + * + * @throws IOException + */ + public void close() throws IOException { + if (!serverSocket.isClosed()) { + log.info("Closing listener on {}", serverSocket.getLocalSocketAddress()); + serverSocket.close(); + } + } + +}