From 8a66dc5336fc641b6fe7a607241f3dd7d36af5ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Torbj=C3=B8rn=20S=C3=B8iland?=
Date: Tue, 19 Oct 2021 16:34:59 +0200
Subject: [PATCH] Close client connection when remote closes connection +
testing (#686) (#687)
---
.../net/schmizz/sshj/common/StreamCopier.java | 10 +-
.../forwarded/LocalPortForwarderTest.java | 139 ++++++++++++++++++
2 files changed, 147 insertions(+), 2 deletions(-)
create mode 100644 src/test/java/com/hierynomus/sshj/connection/channel/forwarded/LocalPortForwarderTest.java
diff --git a/src/main/java/net/schmizz/sshj/common/StreamCopier.java b/src/main/java/net/schmizz/sshj/common/StreamCopier.java
index 344df039..424d8bf4 100644
--- a/src/main/java/net/schmizz/sshj/common/StreamCopier.java
+++ b/src/main/java/net/schmizz/sshj/common/StreamCopier.java
@@ -145,8 +145,14 @@ public class StreamCopier {
final double sizeKiB = count / 1024.0;
log.debug(String.format("%1$,.1f KiB transferred in %2$,.1f seconds (%3$,.2f KiB/s)", sizeKiB, timeSeconds, (sizeKiB / timeSeconds)));
- if (length != -1 && read == -1)
- throw new IOException("Encountered EOF, could not transfer " + length + " bytes");
+ // Did we encounter EOF?
+ if (read == -1) {
+ // If InputStream was closed we should also close OutputStream
+ out.close();
+
+ if (length != -1)
+ throw new IOException("Encountered EOF, could not transfer " + length + " bytes");
+ }
return count;
}
diff --git a/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/LocalPortForwarderTest.java b/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/LocalPortForwarderTest.java
new file mode 100644
index 00000000..fd48b886
--- /dev/null
+++ b/src/test/java/com/hierynomus/sshj/connection/channel/forwarded/LocalPortForwarderTest.java
@@ -0,0 +1,139 @@
+/*
+ * Copyright (C)2009 - SSHJ Contributors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.hierynomus.sshj.connection.channel.forwarded;
+
+import com.hierynomus.sshj.test.HttpServer;
+import com.hierynomus.sshj.test.SshFixture;
+import com.hierynomus.sshj.test.util.FileUtil;
+import net.schmizz.sshj.SSHClient;
+import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder;
+import net.schmizz.sshj.connection.channel.direct.Parameters;
+import org.apache.http.HttpResponse;
+import org.apache.http.client.HttpClient;
+import org.apache.http.client.methods.HttpGet;
+import org.apache.http.impl.client.HttpClientBuilder;
+import org.apache.sshd.server.forward.AcceptAllForwardingFilter;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.*;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+import java.net.Socket;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+public class LocalPortForwarderTest {
+ private static final Logger log = LoggerFactory.getLogger(LocalPortForwarderTest.class);
+
+ @Rule
+ public SshFixture fixture = new SshFixture();
+
+ @Rule
+ public HttpServer httpServer = new HttpServer();
+
+ @Before
+ public void setUp() throws IOException {
+ fixture.getServer().setForwardingFilter(new AcceptAllForwardingFilter());
+ File file = httpServer.getDocRoot().newFile("index.html");
+ FileUtil.writeToFile(file, "
Hi!
");
+ }
+
+ @Test
+ public void shouldHaveWorkingHttpServer() throws IOException {
+ // Just to check that we have a working http server...
+ assertThat(httpGet("127.0.0.1", 8080), equalTo(200));
+ }
+
+ @Test
+ public void shouldHaveHttpServerThatClosesConnectionAfterResponse() throws IOException {
+ // Just to check that the test server does close connections before we try through the forwarder...
+ httpGetAndAssertConnectionClosedByServer(8080);
+ }
+
+ @Test(timeout = 10_000)
+ public void shouldCloseConnectionWhenRemoteServerClosesConnection() throws IOException {
+ SSHClient sshClient = getFixtureClient();
+
+ ServerSocket serverSocket = new ServerSocket();
+ serverSocket.setReuseAddress(true);
+ serverSocket.bind(new InetSocketAddress("0.0.0.0", 12345));
+ LocalPortForwarder localPortForwarder = sshClient.newLocalPortForwarder(new Parameters("0.0.0.0", 12345, "localhost", 8080), serverSocket);
+ new Thread(() -> {
+ try {
+ localPortForwarder.listen();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }, "local port listener").start();
+
+ // Test once to prove that the local HTTP connection is closed when the remote HTTP connection is closed.
+ httpGetAndAssertConnectionClosedByServer(12345);
+
+ // Test again to prove that the tunnel is still open, even after HTTP connection was closed.
+ httpGetAndAssertConnectionClosedByServer(12345);
+ }
+
+ public static void httpGetAndAssertConnectionClosedByServer(int port) throws IOException {
+ System.out.println("HTTP GET to port: " + port);
+ try (Socket socket = new Socket("localhost", port)) {
+ // Send a basic HTTP GET
+ // It returns 400 Bad Request because it's missing a bunch of info, but the HTTP response doesn't matter, we just want to test the connection closing.
+ OutputStream outputStream = socket.getOutputStream();
+ PrintWriter writer = new PrintWriter(outputStream);
+ writer.println("GET / HTTP/1.1");
+ writer.println("");
+ writer.flush();
+
+ // Read the HTTP response
+ InputStream inputStream = socket.getInputStream();
+ InputStreamReader reader = new InputStreamReader(inputStream);
+ int buf = -2;
+ while (true) {
+ buf = reader.read();
+ System.out.print((char)buf);
+ if (buf == -1) {
+ break;
+ }
+ }
+
+ // Attempt to read more. If the server has closed the connection this will return -1
+ int read = inputStream.read();
+
+ // Assert input stream was closed by server.
+ Assert.assertEquals(-1, read);
+ }
+ }
+
+ private int httpGet(String server, int port) throws IOException {
+ HttpClient client = HttpClientBuilder.create().build();
+ String urlString = "http://" + server + ":" + port;
+ log.info("Trying: GET " + urlString);
+ HttpResponse execute = client.execute(new HttpGet(urlString));
+ return execute.getStatusLine().getStatusCode();
+ }
+
+ private SSHClient getFixtureClient() throws IOException {
+ SSHClient sshClient = fixture.setupConnectedDefaultClient();
+ sshClient.authPassword("jeroen", "jeroen");
+ return sshClient;
+ }
+}