From b01eccda4a930827f10b62ac581c92b6b0294643 Mon Sep 17 00:00:00 2001 From: Jeroen van Erp Date: Mon, 11 Apr 2016 15:05:27 +0200 Subject: [PATCH] Fixed bug in Forward lookup in which we did not deal with the special cases (Fixes #239) --- build.gradle | 1 + .../forwarded/RemotePortForwarder.java | 44 +++++- .../forwarded/RemotePortForwarderTest.java | 145 ++++++++++++++++-- 3 files changed, 174 insertions(+), 16 deletions(-) diff --git a/build.gradle b/build.gradle index 7465080c..8107687a 100644 --- a/build.gradle +++ b/build.gradle @@ -65,6 +65,7 @@ dependencies { testCompile "org.apache.sshd:sshd-core:1.1.0" testRuntime "ch.qos.logback:logback-classic:1.1.2" testCompile 'org.glassfish.grizzly:grizzly-http-server:2.3.17' + testCompile 'org.apache.httpcomponents:httpclient:4.5.2' } diff --git a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java index 00ed1fa5..e52d7838 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java @@ -117,6 +117,34 @@ public class RemotePortForwarder return address + ":" + port; } + private boolean handles(ForwardedTCPIPChannel channel) { + Forward channelForward = channel.getParentForward(); + if (channelForward.getPort() != port) { + return false; + } + if ("".equals(address)) { + // This forward handles all protocols + return true; + } + if (channelForward.address.equals(address)) { + // Addresses match up + return true; + } + if ("localhost".equals(address) && (channelForward.address.equals("127.0.0.1") || channelForward.address.equals("::1"))) { + // Localhost special case. + return true; + } + if ("::".equals(address) && channelForward.address.indexOf("::") > 0) { + // Listen on all IPv6 + return true; + } + if ("0.0.0.0".equals(address) && channelForward.address.indexOf('.') > 0) { + // Listen on all IPv4 + return true; + } + return false; + } + } /** A {@code forwarded-tcpip} channel. */ @@ -224,11 +252,15 @@ public class RemotePortForwarder } catch (Buffer.BufferException be) { throw new ConnectionException(be); } - if (listeners.containsKey(chan.getParentForward())) - callListener(listeners.get(chan.getParentForward()), chan); - else - chan.reject(OpenFailException.Reason.ADMINISTRATIVELY_PROHIBITED, "Forwarding was not requested on `" - + chan.getParentForward() + "`"); + + for (Forward forward : listeners.keySet()) { + if (forward.handles(chan)) { + callListener(listeners.get(forward), chan); + return; + } + } + chan.reject(OpenFailException.Reason.ADMINISTRATIVELY_PROHIBITED, "Forwarding was not requested on `" + + chan.getParentForward() + "`"); } -} \ No newline at end of file +} diff --git a/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePortForwarderTest.java b/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePortForwarderTest.java index 579f06a9..4ecf3914 100644 --- a/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePortForwarderTest.java +++ b/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/RemotePortForwarderTest.java @@ -4,41 +4,166 @@ import com.hierynomus.sshj.test.HttpServer; import com.hierynomus.sshj.test.SshFixture; import com.hierynomus.sshj.test.util.FileUtil; import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.connection.Connection; +import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder; import net.schmizz.sshj.connection.channel.forwarded.SocketForwardingConnectListener; +import org.apache.http.HttpResponse; +import org.apache.http.client.HttpClient; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.impl.client.HttpClientBuilder; import org.apache.sshd.server.forward.AcceptAllForwardingFilter; -import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.TemporaryFolder; import java.io.File; import java.io.IOException; +import java.net.ConnectException; import java.net.InetSocketAddress; +import java.util.concurrent.atomic.AtomicInteger; -import static org.junit.Assert.*; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.Assert.assertThat; public class RemotePortForwarderTest { + // Credentials for an remote SSH Server to test against. + private static final String REMOTE_HOST = "x.x.x.x"; + private static final String USER = "xxxx"; + private static final String PASSWORD = "yyyy"; + + private static final PortRange RANGE = new PortRange(9000, 9999); + private static final InetSocketAddress HTTP_SERVER_SOCKET_ADDR = new InetSocketAddress("127.0.0.1", 8080); + @Rule public SshFixture fixture = new SshFixture(); @Rule public HttpServer httpServer = new HttpServer(); - @Test - public void shouldDynamicallyForwardPort() throws IOException { + @Before + public void setup() throws IOException { fixture.getServer().setTcpipForwardingFilter(new AcceptAllForwardingFilter()); File file = httpServer.getDocRoot().newFile("index.html"); FileUtil.writeToFile(file, "

Hi!

"); + } + + @Test + public void shouldHaveWorkingHttpServer() throws IOException { + // Just to check that we have a working http server... + httpGet("127.0.0.1", 8080); + } + + @Test + public void shouldDynamicallyForwardPortForLocalhost() throws IOException { + SSHClient sshClient = getFixtureClient(); + RemotePortForwarder.Forward bind = forwardPort(sshClient, "127.0.0.1", new SinglePort(0)); + httpGet("127.0.0.1", bind.getPort()); + } + + @Test + public void shouldDynamicallyForwardPortForAllIPv4() throws IOException { + SSHClient sshClient = getFixtureClient(); + RemotePortForwarder.Forward bind = forwardPort(sshClient, "0.0.0.0", new SinglePort(0)); + httpGet("127.0.0.1", bind.getPort()); + } + + @Test + public void shouldDynamicallyForwardPortForAllProtocols() throws IOException { + SSHClient sshClient = getFixtureClient(); + RemotePortForwarder.Forward bind = forwardPort(sshClient, "", new SinglePort(0)); + httpGet("127.0.0.1", bind.getPort()); + } + + @Test + public void shouldForwardPortForLocalhost() throws IOException { + SSHClient sshClient = getFixtureClient(); + RemotePortForwarder.Forward bind = forwardPort(sshClient, "127.0.0.1", RANGE); + httpGet("127.0.0.1", bind.getPort()); + } + + @Test + public void shouldForwardPortForAllIPv4() throws IOException { + SSHClient sshClient = getFixtureClient(); + RemotePortForwarder.Forward bind = forwardPort(sshClient, "0.0.0.0", RANGE); + httpGet("127.0.0.1", bind.getPort()); + } + + @Test + public void shouldForwardPortForAllProtocols() throws IOException { + SSHClient sshClient = getFixtureClient(); + RemotePortForwarder.Forward bind = forwardPort(sshClient, "", RANGE); + httpGet("127.0.0.1", bind.getPort()); + } + + private RemotePortForwarder.Forward forwardPort(SSHClient sshClient, String address, PortRange portRange) throws IOException { + while (true) { + try { + RemotePortForwarder.Forward forward = sshClient.getRemotePortForwarder().bind( + // where the server should listen + new RemotePortForwarder.Forward(address, portRange.nextPort()), + // what we do with incoming connections that are forwarded to us + new SocketForwardingConnectListener(HTTP_SERVER_SOCKET_ADDR)); + + return forward; + } catch (ConnectionException ce) { + if (!portRange.hasNext()) { + throw ce; + } + } + } + } + + private void httpGet(String server, int port) throws IOException { + HttpClient client = HttpClientBuilder.create().build(); + String urlString = "http://" + server + ":" + port; + System.out.println("Trying: GET " + urlString); + HttpResponse execute = client.execute(new HttpGet(urlString)); + assertThat(execute.getStatusLine().getStatusCode(), equalTo(200)); + } + + private SSHClient getFixtureClient() throws IOException { SSHClient sshClient = fixture.setupConnectedDefaultClient(); sshClient.authPassword("jeroen", "jeroen"); - sshClient.getRemotePortForwarder().bind( - // where the server should listen - new RemotePortForwarder.Forward(0), - // what we do with incoming connections that are forwarded to us - new SocketForwardingConnectListener(new InetSocketAddress("127.0.0.1", 8080))); + return sshClient; + } + + private static class PortRange { + private int upper; + private int current; + + public PortRange(int lower, int upper) { + this.upper = upper; + this.current = lower; + } + + public int nextPort() { + if (current < upper) { + return current++; + } + throw new IllegalStateException("Out of ports!"); + } + + public boolean hasNext() { + return current < upper; + } + } + + private static class SinglePort extends PortRange { + private final int port; + + public SinglePort(int port) { + super(port, port); + this.port = port; + } + + @Override + public int nextPort() { + return port; + } + } + }