From 20879a4aa559e21da3a753de9a8aaebb6667c377 Mon Sep 17 00:00:00 2001 From: Jeroen van Erp Date: Thu, 29 Dec 2016 16:06:57 +0100 Subject: [PATCH] LocalPortForwarder interrupts its thread on close() --- .../channel/direct/LocalPortForwarder.java | 18 +++++- .../direct/LocalPortForwarderSpec.groovy | 55 +++++++++++++++++++ 2 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 src/test/groovy/com/hierynomus/sshj/connection/channel/direct/LocalPortForwarderSpec.groovy diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java index 2369a658..3407e368 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/LocalPortForwarder.java @@ -107,6 +107,7 @@ public class LocalPortForwarder { private final Connection conn; private final Parameters parameters; private final ServerSocket serverSocket; + private Thread runningThread; public LocalPortForwarder(Connection conn, Parameters parameters, ServerSocket serverSocket, LoggerFactory loggerFactory) { this.conn = conn; @@ -132,10 +133,20 @@ public class LocalPortForwarder { * * @throws IOException */ - public void listen() - throws IOException { + public void listen() throws IOException { + listen(Thread.currentThread()); + } + + /** + * Start listening for incoming connections and forward to remote host as a channel and ensure that the thread is registered. + * This is useful if for instance {@link #close() is called from another thread} + * + * @throws IOException + */ + public void listen(Thread runningThread) throws IOException { + this.runningThread = runningThread; log.info("Listening on {}", serverSocket.getLocalSocketAddress()); - while (!Thread.currentThread().isInterrupted()) { + while (!runningThread.isInterrupted()) { try { final Socket socket = serverSocket.accept(); log.debug("Got connection from {}", socket.getRemoteSocketAddress()); @@ -162,6 +173,7 @@ public class LocalPortForwarder { if (!serverSocket.isClosed()) { log.info("Closing listener on {}", serverSocket.getLocalSocketAddress()); serverSocket.close(); + runningThread.interrupt(); } } diff --git a/src/test/groovy/com/hierynomus/sshj/connection/channel/direct/LocalPortForwarderSpec.groovy b/src/test/groovy/com/hierynomus/sshj/connection/channel/direct/LocalPortForwarderSpec.groovy new file mode 100644 index 00000000..4c6ac9d6 --- /dev/null +++ b/src/test/groovy/com/hierynomus/sshj/connection/channel/direct/LocalPortForwarderSpec.groovy @@ -0,0 +1,55 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.connection.channel.direct + +import com.hierynomus.sshj.test.SshFixture +import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder +import org.junit.Rule +import spock.lang.Specification + +class LocalPortForwarderSpec extends Specification { + @Rule + SshFixture tunnelFixture = new SshFixture() + + @Rule + SshFixture realServer = new SshFixture() + + def "should not hang when disconnect tunnel"() { + given: + def client = tunnelFixture.setupConnectedDefaultClient() + client.authPassword("test", "test") + def socket = new ServerSocket(0) + def lpf = client.newLocalPortForwarder(new LocalPortForwarder.Parameters("localhost", socket.getLocalPort(), "localhost", realServer.server.port), socket) + def thread = new Thread(new Runnable() { + @Override + void run() { + lpf.listen() + } + }) + + when: + thread.start() + + then: + thread.isAlive() + + when: + lpf.close() + + then: + socket.isClosed() + } +}