Fix #805: Prevent CHANNEL_CLOSE to be sent between Channel.isOpen and… (#813)

* Fix #805: Prevent CHANNEL_CLOSE to be sent between Channel.isOpen and a Transport.write call

Otherwise, a disconnect with a "packet referred to nonexistent channel" message can occur.

This particularly happens when the transport.Reader thread passes an eof from the server to the ChannelInputStream, the reading library-user thread returns, and closes the channel at the same time as the transport.Reader thread receives the subsequent CHANNEL_CLOSE from the server.

* Add integration test for #805
This commit is contained in:
kegelh
2022-09-17 07:11:11 +02:00
committed by GitHub
parent 2551f8e559
commit d5d6096d5d
3 changed files with 119 additions and 8 deletions

View File

@@ -0,0 +1,74 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hierynomus.sshj
import net.schmizz.sshj.SSHClient
import net.schmizz.sshj.common.IOUtils
import net.schmizz.sshj.connection.channel.direct.Session
import spock.lang.Specification
import java.util.concurrent.*
import static org.codehaus.groovy.runtime.IOGroovyMethods.withCloseable
class ManyChannelsSpec extends Specification {
def "should work with many channels without nonexistent channel error (GH issue #805)"() {
given:
SshdContainer sshd = new SshdContainer.Builder()
.withSshdConfig("""${SshdContainer.Builder.DEFAULT_SSHD_CONFIG}
MaxSessions 200
""".stripMargin())
.build()
sshd.start()
SSHClient client = sshd.getConnectedClient()
client.authPublickey("sshj", "src/test/resources/id_rsa")
when:
List<Future<Exception>> futures = []
ExecutorService executorService = Executors.newCachedThreadPool()
for (int i in 0..20) {
futures.add(executorService.submit((Callable<Exception>) {
return execute(client)
}))
}
executorService.shutdown()
executorService.awaitTermination(1, TimeUnit.DAYS)
then:
futures*.get().findAll { it != null }.empty
cleanup:
client.close()
}
private static Exception execute(SSHClient sshClient) {
try {
for (def i in 0..100) {
withCloseable (sshClient.startSession()) {sshSession ->
Session.Command sshCommand = sshSession.exec("ls -la")
IOUtils.readFully(sshCommand.getInputStream()).toString()
sshCommand.close()
}
}
} catch (Exception e) {
return e
}
return null
}
}

View File

@@ -304,6 +304,25 @@ public abstract class AbstractChannel
} }
} }
// Prevent CHANNEL_CLOSE to be sent between isOpen and a Transport.write call in the runnable, otherwise
// a disconnect with a "packet referred to nonexistent channel" message can occur.
//
// This particularly happens when the transport.Reader thread passes an eof from the server to the
// ChannelInputStream, the reading library-user thread returns, and closes the channel at the same time as the
// transport.Reader thread receives the subsequent CHANNEL_CLOSE from the server.
boolean whileOpen(TransportRunnable runnable) throws TransportException, ConnectionException {
openCloseLock.lock();
try {
if (isOpen()) {
runnable.run();
return true;
}
} finally {
openCloseLock.unlock();
}
return false;
}
private void gotChannelRequest(SSHPacket buf) private void gotChannelRequest(SSHPacket buf)
throws ConnectionException, TransportException { throws ConnectionException, TransportException {
final String reqType; final String reqType;
@@ -427,5 +446,8 @@ public abstract class AbstractChannel
+ rwin + " >"; + rwin + " >";
} }
public interface TransportRunnable {
void run() throws TransportException, ConnectionException;
}
} }

View File

@@ -30,7 +30,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
*/ */
public final class ChannelOutputStream extends OutputStream implements ErrorNotifiable { public final class ChannelOutputStream extends OutputStream implements ErrorNotifiable {
private final Channel chan; private final AbstractChannel chan;
private final Transport trans; private final Transport trans;
private final Window.Remote win; private final Window.Remote win;
@@ -47,6 +47,12 @@ public final class ChannelOutputStream extends OutputStream implements ErrorNoti
private final SSHPacket packet = new SSHPacket(Message.CHANNEL_DATA); private final SSHPacket packet = new SSHPacket(Message.CHANNEL_DATA);
private final Buffer.PlainBuffer leftOvers = new Buffer.PlainBuffer(); private final Buffer.PlainBuffer leftOvers = new Buffer.PlainBuffer();
private final AbstractChannel.TransportRunnable packetWriteRunnable = new AbstractChannel.TransportRunnable() {
@Override
public void run() throws TransportException {
trans.write(packet);
}
};
DataBuffer() { DataBuffer() {
headerOffset = packet.rpos(); headerOffset = packet.rpos();
@@ -99,8 +105,9 @@ public final class ChannelOutputStream extends OutputStream implements ErrorNoti
if (leftOverBytes > 0) { if (leftOverBytes > 0) {
leftOvers.putRawBytes(packet.array(), packet.wpos(), leftOverBytes); leftOvers.putRawBytes(packet.array(), packet.wpos(), leftOverBytes);
} }
if (!chan.whileOpen(packetWriteRunnable)) {
trans.write(packet); throwStreamClosed();
}
win.consume(writeNow); win.consume(writeNow);
packet.rpos(headerOffset); packet.rpos(headerOffset);
@@ -119,7 +126,7 @@ public final class ChannelOutputStream extends OutputStream implements ErrorNoti
} }
public ChannelOutputStream(Channel chan, Transport trans, Window.Remote win) { public ChannelOutputStream(AbstractChannel chan, Transport trans, Window.Remote win) {
this.chan = chan; this.chan = chan;
this.trans = trans; this.trans = trans;
this.win = win; this.win = win;
@@ -157,7 +164,7 @@ public final class ChannelOutputStream extends OutputStream implements ErrorNoti
if (error != null) { if (error != null) {
throw error; throw error;
} else { } else {
throw new ConnectionException("Stream closed"); throwStreamClosed();
} }
} }
} }
@@ -165,9 +172,14 @@ public final class ChannelOutputStream extends OutputStream implements ErrorNoti
@Override @Override
public synchronized void close() throws IOException { public synchronized void close() throws IOException {
// Not closed yet, and underlying channel is open to flush the data to. // Not closed yet, and underlying channel is open to flush the data to.
if (!closed.getAndSet(true) && chan.isOpen()) { if (!closed.getAndSet(true)) {
buffer.flush(false); chan.whileOpen(new AbstractChannel.TransportRunnable() {
trans.write(new SSHPacket(Message.CHANNEL_EOF).putUInt32(chan.getRecipient())); @Override
public void run() throws TransportException, ConnectionException {
buffer.flush(false);
trans.write(new SSHPacket(Message.CHANNEL_EOF).putUInt32(chan.getRecipient()));
}
});
} }
} }
@@ -188,4 +200,7 @@ public final class ChannelOutputStream extends OutputStream implements ErrorNoti
return "< ChannelOutputStream for Channel #" + chan.getID() + " >"; return "< ChannelOutputStream for Channel #" + chan.getID() + " >";
} }
private static void throwStreamClosed() throws ConnectionException {
throw new ConnectionException("Stream closed");
}
} }