Fixed bug in Forward lookup in which we did not deal with the special cases (Fixes #239)

This commit is contained in:
Jeroen van Erp
2016-04-11 15:05:27 +02:00
parent 4c9ebc306d
commit b01eccda4a
3 changed files with 174 additions and 16 deletions

View File

@@ -65,6 +65,7 @@ dependencies {
testCompile "org.apache.sshd:sshd-core:1.1.0"
testRuntime "ch.qos.logback:logback-classic:1.1.2"
testCompile 'org.glassfish.grizzly:grizzly-http-server:2.3.17'
testCompile 'org.apache.httpcomponents:httpclient:4.5.2'
}

View File

@@ -117,6 +117,34 @@ public class RemotePortForwarder
return address + ":" + port;
}
private boolean handles(ForwardedTCPIPChannel channel) {
Forward channelForward = channel.getParentForward();
if (channelForward.getPort() != port) {
return false;
}
if ("".equals(address)) {
// This forward handles all protocols
return true;
}
if (channelForward.address.equals(address)) {
// Addresses match up
return true;
}
if ("localhost".equals(address) && (channelForward.address.equals("127.0.0.1") || channelForward.address.equals("::1"))) {
// Localhost special case.
return true;
}
if ("::".equals(address) && channelForward.address.indexOf("::") > 0) {
// Listen on all IPv6
return true;
}
if ("0.0.0.0".equals(address) && channelForward.address.indexOf('.') > 0) {
// Listen on all IPv4
return true;
}
return false;
}
}
/** A {@code forwarded-tcpip} channel. */
@@ -224,9 +252,13 @@ public class RemotePortForwarder
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
}
if (listeners.containsKey(chan.getParentForward()))
callListener(listeners.get(chan.getParentForward()), chan);
else
for (Forward forward : listeners.keySet()) {
if (forward.handles(chan)) {
callListener(listeners.get(forward), chan);
return;
}
}
chan.reject(OpenFailException.Reason.ADMINISTRATIVELY_PROHIBITED, "Forwarding was not requested on `"
+ chan.getParentForward() + "`");
}

View File

@@ -4,41 +4,166 @@ import com.hierynomus.sshj.test.HttpServer;
import com.hierynomus.sshj.test.SshFixture;
import com.hierynomus.sshj.test.util.FileUtil;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.channel.forwarded.RemotePortForwarder;
import net.schmizz.sshj.connection.channel.forwarded.SocketForwardingConnectListener;
import org.apache.http.HttpResponse;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.sshd.server.forward.AcceptAllForwardingFilter;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.File;
import java.io.IOException;
import java.net.ConnectException;
import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicInteger;
import static org.junit.Assert.*;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.Assert.assertThat;
public class RemotePortForwarderTest {
// Credentials for an remote SSH Server to test against.
private static final String REMOTE_HOST = "x.x.x.x";
private static final String USER = "xxxx";
private static final String PASSWORD = "yyyy";
private static final PortRange RANGE = new PortRange(9000, 9999);
private static final InetSocketAddress HTTP_SERVER_SOCKET_ADDR = new InetSocketAddress("127.0.0.1", 8080);
@Rule
public SshFixture fixture = new SshFixture();
@Rule
public HttpServer httpServer = new HttpServer();
@Test
public void shouldDynamicallyForwardPort() throws IOException {
@Before
public void setup() throws IOException {
fixture.getServer().setTcpipForwardingFilter(new AcceptAllForwardingFilter());
File file = httpServer.getDocRoot().newFile("index.html");
FileUtil.writeToFile(file, "<html><head/><body><h1>Hi!</h1></body></html>");
}
@Test
public void shouldHaveWorkingHttpServer() throws IOException {
// Just to check that we have a working http server...
httpGet("127.0.0.1", 8080);
}
@Test
public void shouldDynamicallyForwardPortForLocalhost() throws IOException {
SSHClient sshClient = getFixtureClient();
RemotePortForwarder.Forward bind = forwardPort(sshClient, "127.0.0.1", new SinglePort(0));
httpGet("127.0.0.1", bind.getPort());
}
@Test
public void shouldDynamicallyForwardPortForAllIPv4() throws IOException {
SSHClient sshClient = getFixtureClient();
RemotePortForwarder.Forward bind = forwardPort(sshClient, "0.0.0.0", new SinglePort(0));
httpGet("127.0.0.1", bind.getPort());
}
@Test
public void shouldDynamicallyForwardPortForAllProtocols() throws IOException {
SSHClient sshClient = getFixtureClient();
RemotePortForwarder.Forward bind = forwardPort(sshClient, "", new SinglePort(0));
httpGet("127.0.0.1", bind.getPort());
}
@Test
public void shouldForwardPortForLocalhost() throws IOException {
SSHClient sshClient = getFixtureClient();
RemotePortForwarder.Forward bind = forwardPort(sshClient, "127.0.0.1", RANGE);
httpGet("127.0.0.1", bind.getPort());
}
@Test
public void shouldForwardPortForAllIPv4() throws IOException {
SSHClient sshClient = getFixtureClient();
RemotePortForwarder.Forward bind = forwardPort(sshClient, "0.0.0.0", RANGE);
httpGet("127.0.0.1", bind.getPort());
}
@Test
public void shouldForwardPortForAllProtocols() throws IOException {
SSHClient sshClient = getFixtureClient();
RemotePortForwarder.Forward bind = forwardPort(sshClient, "", RANGE);
httpGet("127.0.0.1", bind.getPort());
}
private RemotePortForwarder.Forward forwardPort(SSHClient sshClient, String address, PortRange portRange) throws IOException {
while (true) {
try {
RemotePortForwarder.Forward forward = sshClient.getRemotePortForwarder().bind(
// where the server should listen
new RemotePortForwarder.Forward(address, portRange.nextPort()),
// what we do with incoming connections that are forwarded to us
new SocketForwardingConnectListener(HTTP_SERVER_SOCKET_ADDR));
return forward;
} catch (ConnectionException ce) {
if (!portRange.hasNext()) {
throw ce;
}
}
}
}
private void httpGet(String server, int port) throws IOException {
HttpClient client = HttpClientBuilder.create().build();
String urlString = "http://" + server + ":" + port;
System.out.println("Trying: GET " + urlString);
HttpResponse execute = client.execute(new HttpGet(urlString));
assertThat(execute.getStatusLine().getStatusCode(), equalTo(200));
}
private SSHClient getFixtureClient() throws IOException {
SSHClient sshClient = fixture.setupConnectedDefaultClient();
sshClient.authPassword("jeroen", "jeroen");
sshClient.getRemotePortForwarder().bind(
// where the server should listen
new RemotePortForwarder.Forward(0),
// what we do with incoming connections that are forwarded to us
new SocketForwardingConnectListener(new InetSocketAddress("127.0.0.1", 8080)));
return sshClient;
}
private static class PortRange {
private int upper;
private int current;
public PortRange(int lower, int upper) {
this.upper = upper;
this.current = lower;
}
public int nextPort() {
if (current < upper) {
return current++;
}
throw new IllegalStateException("Out of ports!");
}
public boolean hasNext() {
return current < upper;
}
}
private static class SinglePort extends PortRange {
private final int port;
public SinglePort(int port) {
super(port, port);
this.port = port;
}
@Override
public int nextPort() {
return port;
}
}
}