From b01eccda4a930827f10b62ac581c92b6b0294643 Mon Sep 17 00:00:00 2001
From: Jeroen van Erp
Date: Mon, 11 Apr 2016 15:05:27 +0200
Subject: [PATCH] Fixed bug in Forward lookup in which we did not deal with the
special cases (Fixes #239)
---
build.gradle | 1 +
.../forwarded/RemotePortForwarder.java | 44 +++++-
.../forwarded/RemotePortForwarderTest.java | 145 ++++++++++++++++--
3 files changed, 174 insertions(+), 16 deletions(-)
diff --git a/build.gradle b/build.gradle
index 7465080c..8107687a 100644
--- a/build.gradle
+++ b/build.gradle
@@ -65,6 +65,7 @@ dependencies {
testCompile "org.apache.sshd:sshd-core:1.1.0"
testRuntime "ch.qos.logback:logback-classic:1.1.2"
testCompile 'org.glassfish.grizzly:grizzly-http-server:2.3.17'
+ testCompile 'org.apache.httpcomponents:httpclient:4.5.2'
}
diff --git a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java
index 00ed1fa5..e52d7838 100644
--- a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java
+++ b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java
@@ -117,6 +117,34 @@ public class RemotePortForwarder
return address + ":" + port;
}
+ private boolean handles(ForwardedTCPIPChannel channel) {
+ Forward channelForward = channel.getParentForward();
+ if (channelForward.getPort() != port) {
+ return false;
+ }
+ if ("".equals(address)) {
+ // This forward handles all protocols
+ return true;
+ }
+ if (channelForward.address.equals(address)) {
+ // Addresses match up
+ return true;
+ }
+ if ("localhost".equals(address) && (channelForward.address.equals("127.0.0.1") || channelForward.address.equals("::1"))) {
+ // Localhost special case.
+ return true;
+ }
+ if ("::".equals(address) && channelForward.address.indexOf("::") > 0) {
+ // Listen on all IPv6
+ return true;
+ }
+ if ("0.0.0.0".equals(address) && channelForward.address.indexOf('.') > 0) {
+ // Listen on all IPv4
+ return true;
+ }
+ return false;
+ }
+
}
/** A {@code forwarded-tcpip} channel. */
@@ -224,11 +252,15 @@ public class RemotePortForwarder
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
}
- if (listeners.containsKey(chan.getParentForward()))
- callListener(listeners.get(chan.getParentForward()), chan);
- else
- chan.reject(OpenFailException.Reason.ADMINISTRATIVELY_PROHIBITED, "Forwarding was not requested on `"
- + chan.getParentForward() + "`");
+
+ for (Forward forward : listeners.keySet()) {
+ if (forward.handles(chan)) {
+ callListener(listeners.get(forward), chan);
+ return;
+ }
+ }
+ chan.reject(OpenFailException.Reason.ADMINISTRATIVELY_PROHIBITED, "Forwarding was not requested on `"
+ + chan.getParentForward() + "`");
}
-}
\ No newline at end of file
+}
diff --git a/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePortForwarderTest.java b/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePortForwarderTest.java
index 579f06a9..4ecf3914 100644
--- a/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePortForwarderTest.java
+++ b/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePortForwarderTest.java
@@ -4,41 +4,166 @@ import com.hierynomus.sshj.test.HttpServer;
import com.hierynomus.sshj.test.SshFixture;
import com.hierynomus.sshj.test.util.FileUtil;
import net.schmizz.sshj.SSHClient;
+import net.schmizz.sshj.connection.Connection;
+import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder;
import net.schmizz.sshj.connection.channel.forwarded.SocketForwardingConnectListener;
+import org.apache.http.HttpResponse;
+import org.apache.http.client.HttpClient;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.sshd.server.forward.AcceptAllForwardingFilter;
-import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
import java.io.File;
import java.io.IOException;
+import java.net.ConnectException;
import java.net.InetSocketAddress;
+import java.util.concurrent.atomic.AtomicInteger;
-import static org.junit.Assert.*;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.junit.Assert.assertThat;
public class RemotePortForwarderTest {
+ // Credentials for an remote SSH Server to test against.
+ private static final String REMOTE_HOST = "x.x.x.x";
+ private static final String USER = "xxxx";
+ private static final String PASSWORD = "yyyy";
+
+ private static final PortRange RANGE = new PortRange(9000, 9999);
+ private static final InetSocketAddress HTTP_SERVER_SOCKET_ADDR = new InetSocketAddress("127.0.0.1", 8080);
+
@Rule
public SshFixture fixture = new SshFixture();
@Rule
public HttpServer httpServer = new HttpServer();
- @Test
- public void shouldDynamicallyForwardPort() throws IOException {
+ @Before
+ public void setup() throws IOException {
fixture.getServer().setTcpipForwardingFilter(new AcceptAllForwardingFilter());
File file = httpServer.getDocRoot().newFile("index.html");
FileUtil.writeToFile(file, "
Hi!
");
+ }
+
+ @Test
+ public void shouldHaveWorkingHttpServer() throws IOException {
+ // Just to check that we have a working http server...
+ httpGet("127.0.0.1", 8080);
+ }
+
+ @Test
+ public void shouldDynamicallyForwardPortForLocalhost() throws IOException {
+ SSHClient sshClient = getFixtureClient();
+ RemotePortForwarder.Forward bind = forwardPort(sshClient, "127.0.0.1", new SinglePort(0));
+ httpGet("127.0.0.1", bind.getPort());
+ }
+
+ @Test
+ public void shouldDynamicallyForwardPortForAllIPv4() throws IOException {
+ SSHClient sshClient = getFixtureClient();
+ RemotePortForwarder.Forward bind = forwardPort(sshClient, "0.0.0.0", new SinglePort(0));
+ httpGet("127.0.0.1", bind.getPort());
+ }
+
+ @Test
+ public void shouldDynamicallyForwardPortForAllProtocols() throws IOException {
+ SSHClient sshClient = getFixtureClient();
+ RemotePortForwarder.Forward bind = forwardPort(sshClient, "", new SinglePort(0));
+ httpGet("127.0.0.1", bind.getPort());
+ }
+
+ @Test
+ public void shouldForwardPortForLocalhost() throws IOException {
+ SSHClient sshClient = getFixtureClient();
+ RemotePortForwarder.Forward bind = forwardPort(sshClient, "127.0.0.1", RANGE);
+ httpGet("127.0.0.1", bind.getPort());
+ }
+
+ @Test
+ public void shouldForwardPortForAllIPv4() throws IOException {
+ SSHClient sshClient = getFixtureClient();
+ RemotePortForwarder.Forward bind = forwardPort(sshClient, "0.0.0.0", RANGE);
+ httpGet("127.0.0.1", bind.getPort());
+ }
+
+ @Test
+ public void shouldForwardPortForAllProtocols() throws IOException {
+ SSHClient sshClient = getFixtureClient();
+ RemotePortForwarder.Forward bind = forwardPort(sshClient, "", RANGE);
+ httpGet("127.0.0.1", bind.getPort());
+ }
+
+ private RemotePortForwarder.Forward forwardPort(SSHClient sshClient, String address, PortRange portRange) throws IOException {
+ while (true) {
+ try {
+ RemotePortForwarder.Forward forward = sshClient.getRemotePortForwarder().bind(
+ // where the server should listen
+ new RemotePortForwarder.Forward(address, portRange.nextPort()),
+ // what we do with incoming connections that are forwarded to us
+ new SocketForwardingConnectListener(HTTP_SERVER_SOCKET_ADDR));
+
+ return forward;
+ } catch (ConnectionException ce) {
+ if (!portRange.hasNext()) {
+ throw ce;
+ }
+ }
+ }
+ }
+
+ private void httpGet(String server, int port) throws IOException {
+ HttpClient client = HttpClientBuilder.create().build();
+ String urlString = "http://" + server + ":" + port;
+ System.out.println("Trying: GET " + urlString);
+ HttpResponse execute = client.execute(new HttpGet(urlString));
+ assertThat(execute.getStatusLine().getStatusCode(), equalTo(200));
+ }
+
+ private SSHClient getFixtureClient() throws IOException {
SSHClient sshClient = fixture.setupConnectedDefaultClient();
sshClient.authPassword("jeroen", "jeroen");
- sshClient.getRemotePortForwarder().bind(
- // where the server should listen
- new RemotePortForwarder.Forward(0),
- // what we do with incoming connections that are forwarded to us
- new SocketForwardingConnectListener(new InetSocketAddress("127.0.0.1", 8080)));
+ return sshClient;
+ }
+
+ private static class PortRange {
+ private int upper;
+ private int current;
+
+ public PortRange(int lower, int upper) {
+ this.upper = upper;
+ this.current = lower;
+ }
+
+ public int nextPort() {
+ if (current < upper) {
+ return current++;
+ }
+ throw new IllegalStateException("Out of ports!");
+ }
+
+ public boolean hasNext() {
+ return current < upper;
+ }
+ }
+
+ private static class SinglePort extends PortRange {
+ private final int port;
+
+ public SinglePort(int port) {
+ super(port, port);
+ this.port = port;
+ }
+
+ @Override
+ public int nextPort() {
+ return port;
+ }
+
}
+
}