From ce5fad9809f8472f81fc0c5b28c58ca472b8c12e Mon Sep 17 00:00:00 2001 From: Shikhar Bhushan Date: Sun, 13 Mar 2011 22:51:35 +0000 Subject: [PATCH] Add DisconnectListener, refactor tests --- .../net/schmizz/sshj/AbstractService.java | 2 +- src/main/java/net/schmizz/sshj/Service.java | 3 +- .../sshj/connection/ConnectionImpl.java | 6 +- .../sshj/transport/DisconnectListener.java | 24 ++++ .../net/schmizz/sshj/transport/Transport.java | 24 +++- .../schmizz/sshj/transport/TransportImpl.java | 42 ++++-- src/test/java/net/schmizz/sshj/SmokeTest.java | 73 ++-------- .../schmizz/sshj/transport/Disconnection.java | 97 +++++++++++++ .../net/schmizz/sshj/util/BasicFixture.java | 131 ++++++++++++++++++ 9 files changed, 325 insertions(+), 77 deletions(-) create mode 100644 src/main/java/net/schmizz/sshj/transport/DisconnectListener.java create mode 100644 src/test/java/net/schmizz/sshj/transport/Disconnection.java create mode 100644 src/test/java/net/schmizz/sshj/util/BasicFixture.java diff --git a/src/main/java/net/schmizz/sshj/AbstractService.java b/src/main/java/net/schmizz/sshj/AbstractService.java index c69bcb65..9bdad59d 100644 --- a/src/main/java/net/schmizz/sshj/AbstractService.java +++ b/src/main/java/net/schmizz/sshj/AbstractService.java @@ -67,7 +67,7 @@ public abstract class AbstractService } @Override - public void notifyDisconnect() + public void notifyDisconnect(DisconnectReason reason) throws SSHException { log.debug("Was notified of disconnect"); } diff --git a/src/main/java/net/schmizz/sshj/Service.java b/src/main/java/net/schmizz/sshj/Service.java index dbf1ce66..f9ee6b47 100644 --- a/src/main/java/net/schmizz/sshj/Service.java +++ b/src/main/java/net/schmizz/sshj/Service.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj; +import net.schmizz.sshj.common.DisconnectReason; import net.schmizz.sshj.common.ErrorNotifiable; import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.SSHPacketHandler; @@ -48,7 +49,7 @@ public interface Service void request() throws TransportException; - void notifyDisconnect() + void notifyDisconnect(DisconnectReason reason) throws SSHException; } \ No newline at end of file diff --git a/src/main/java/net/schmizz/sshj/connection/ConnectionImpl.java b/src/main/java/net/schmizz/sshj/connection/ConnectionImpl.java index ac326118..23bfb4f5 100644 --- a/src/main/java/net/schmizz/sshj/connection/ConnectionImpl.java +++ b/src/main/java/net/schmizz/sshj/connection/ConnectionImpl.java @@ -246,10 +246,10 @@ public class ConnectionImpl } @Override - public void notifyDisconnect() + public void notifyDisconnect(DisconnectReason reason) throws SSHException { - super.notifyDisconnect(); - final ConnectionException ex = new ConnectionException("Disconnected."); + super.notifyDisconnect(reason); + final ConnectionException ex = new ConnectionException("Disconnected"); FutureUtils.alertAll(ex, globalReqFutures); ErrorNotifiable.Util.alertAll(ex, new HashSet(channels.values())); } diff --git a/src/main/java/net/schmizz/sshj/transport/DisconnectListener.java b/src/main/java/net/schmizz/sshj/transport/DisconnectListener.java new file mode 100644 index 00000000..7b4199b3 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/transport/DisconnectListener.java @@ -0,0 +1,24 @@ +/* + * Copyright 2010 Shikhar Bhushan + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + **/ +package net.schmizz.sshj.transport; + +import net.schmizz.sshj.common.DisconnectReason; + +public interface DisconnectListener { + + void notifyDisconnect(DisconnectReason reason); + +} diff --git a/src/main/java/net/schmizz/sshj/transport/Transport.java b/src/main/java/net/schmizz/sshj/transport/Transport.java index 3d39bfe3..c1f5f825 100644 --- a/src/main/java/net/schmizz/sshj/transport/Transport.java +++ b/src/main/java/net/schmizz/sshj/transport/Transport.java @@ -44,6 +44,7 @@ import net.schmizz.sshj.transport.verification.HostKeyVerifier; import java.io.InputStream; import java.io.OutputStream; +import java.util.concurrent.TimeUnit; /** Transport layer of the SSH protocol. */ public interface Transport @@ -171,12 +172,18 @@ public interface Transport boolean isRunning(); /** - * Joins the thread calling this method to the transport's death. The transport dies of exceptional events. + * Joins the thread calling this method to the transport's death. * - * @throws TransportException when the transport dies + * @throws TransportException if the transport dies of an exception */ void join() throws TransportException; + /** + * Joins the thread calling this method to the transport's death. + * + * @throws TransportException if the transport dies of an exception + */ + void join(int timeout, TimeUnit unit) throws TransportException; /** Send a disconnection packet with reason as {@link DisconnectReason#BY_APPLICATION}, and closes this transport. */ void disconnect(); @@ -211,4 +218,17 @@ public interface Transport */ long write(SSHPacket payload) throws TransportException; + + /** + * Specify a {@code listener} that will be notified upon disconnection. + * + * @param listener + */ + void setDisconnectListener(DisconnectListener listener); + + /** + * @return the current disconnect listener. + */ + DisconnectListener getDisconnectListener(); + } \ No newline at end of file diff --git a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java index 2a35e769..fbcca625 100644 --- a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java +++ b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java @@ -85,6 +85,13 @@ public final class TransportImpl private final Service nullService = new NullService(this); + private final DisconnectListener nullDisconnectListener = new DisconnectListener() { + @Override + public void notifyDisconnect(DisconnectReason reason) { + log.debug("Default disconnect listener - {}", reason); + } + }; + private final Config config; private final KeyExchanger kexer; @@ -97,11 +104,9 @@ public final class TransportImpl private final Decoder decoder; - private final Event serviceAccept = new Event("service accept", - TransportException.chainer); + private final Event serviceAccept = new Event("service accept", TransportException.chainer); - private final Event close = new Event("transport close", - TransportException.chainer); + private final Event close = new Event("transport close", TransportException.chainer); /** Client version identification string */ private final String clientID; @@ -113,6 +118,8 @@ public final class TransportImpl /** Currently active service e.g. UserAuthService, ConnectionService */ private volatile Service service = nullService; + private DisconnectListener disconnectListener = nullDisconnectListener; + private ConnInfo connInfo; /** Server version identification string */ @@ -210,7 +217,7 @@ public final class TransportImpl if (!ident.startsWith("SSH-2.0-") && !ident.startsWith("SSH-1.99-")) throw new TransportException(DisconnectReason.PROTOCOL_VERSION_NOT_SUPPORTED, - "Server does not support SSHv2, identified as: " + ident); + "Server does not support SSHv2, identified as: " + ident); return ident; } @@ -347,6 +354,12 @@ public final class TransportImpl close.await(); } + @Override + public void join(int timeout, TimeUnit unit) + throws TransportException { + close.await(timeout, unit); + } + @Override public boolean isRunning() { return reader.isAlive() && !close.isSet(); @@ -364,10 +377,11 @@ public final class TransportImpl @Override public void disconnect(DisconnectReason reason, String message) { - close.lock(); // CAS type operation on close + close.lock(); try { + disconnectListener.notifyDisconnect(reason); try { - service.notifyDisconnect(); + service.notifyDisconnect(reason); } catch (SSHException logged) { log.warn("{} did not handle disconnect cleanly: {}", service, logged); } @@ -381,6 +395,16 @@ public final class TransportImpl } } + @Override + public void setDisconnectListener(DisconnectListener listener) { + this.disconnectListener = listener == null ? nullDisconnectListener : listener; + } + + @Override + public DisconnectListener getDisconnectListener() { + return disconnectListener; + } + @Override public long write(SSHPacket payload) throws TransportException { @@ -501,7 +525,7 @@ public final class TransportImpl try { if (!serviceAccept.hasWaiters()) throw new TransportException(DisconnectReason.PROTOCOL_ERROR, - "Got a service accept notification when none was awaited"); + "Got a service accept notification when none was awaited"); serviceAccept.set(); } finally { serviceAccept.unlock(); @@ -540,6 +564,8 @@ public final class TransportImpl final SSHException causeOfDeath = SSHException.chainer.chain(ex); + disconnectListener.notifyDisconnect(causeOfDeath.getDisconnectReason()); + FutureUtils.alertAll(causeOfDeath, close, serviceAccept); kexer.notifyError(causeOfDeath); getService().notifyError(causeOfDeath); diff --git a/src/test/java/net/schmizz/sshj/SmokeTest.java b/src/test/java/net/schmizz/sshj/SmokeTest.java index e1f5a356..41a7b5bf 100644 --- a/src/test/java/net/schmizz/sshj/SmokeTest.java +++ b/src/test/java/net/schmizz/sshj/SmokeTest.java @@ -17,93 +17,42 @@ package net.schmizz.sshj; import net.schmizz.sshj.transport.TransportException; import net.schmizz.sshj.userauth.UserAuthException; -import net.schmizz.sshj.util.BogusPasswordAuthenticator; -import org.apache.sshd.SshServer; -import org.apache.sshd.common.keyprovider.FileKeyPairProvider; +import net.schmizz.sshj.util.BasicFixture; import org.junit.After; import org.junit.Before; import org.junit.Test; import java.io.IOException; -import java.net.ServerSocket; import static org.junit.Assert.assertTrue; /* Kinda basic right now */ - public class SmokeTest { - private SSHClient ssh; - private SshServer sshd; - - private final String hostname = "localhost"; - private int port; - - private static final String hostkey = "src/test/resources/hostkey.pem"; - private static final String fingerprint = "ce:a7:c1:cf:17:3f:96:49:6a:53:1a:05:0b:ba:90:db"; + private final BasicFixture fixture = new BasicFixture(); @Before public void setUp() throws IOException { - ServerSocket s = new ServerSocket(0); - port = s.getLocalPort(); - s.close(); - - sshd = SshServer.setUpDefaultServer(); - sshd.setPort(port); - sshd.setKeyPairProvider(new FileKeyPairProvider(new String[]{hostkey})); - // sshd.setShellFactory(new EchoShellFactory()); - sshd.setPasswordAuthenticator(new BogusPasswordAuthenticator()); - sshd.start(); - - ssh = new SSHClient(); - ssh.addHostKeyVerifier(fingerprint); + fixture.init(false); } @After - public void tearUp() + public void tearDown() throws IOException, InterruptedException { - ssh.disconnect(); - sshd.stop(); + fixture.done(); } @Test - public void testAuthenticate() + public void connected() throws IOException { - connect(); - authenticate(); - assertTrue(ssh.isAuthenticated()); + assertTrue(fixture.getClient().isConnected()); } @Test - public void testConnect() - throws IOException { - connect(); - assertTrue(ssh.isConnected()); + public void authenticated() throws UserAuthException, TransportException { + fixture.dummyAuth(); + assertTrue(fixture.getClient().isAuthenticated()); } - // @Test - // // TODO -- test I/O - // public void testShell() throws IOException - // { - // connect(); - // authenticate(); - // - // Shell shell = ssh.startSession().startShell(); - // assertTrue(shell.isOpen()); - // - // shell.close(); - // assertFalse(shell.isOpen()); - // } - - private void authenticate() - throws UserAuthException, TransportException { - ssh.authPassword("same", "same"); - } - - private void connect() - throws IOException { - ssh.connect(hostname, port); - } - -} +} \ No newline at end of file diff --git a/src/test/java/net/schmizz/sshj/transport/Disconnection.java b/src/test/java/net/schmizz/sshj/transport/Disconnection.java new file mode 100644 index 00000000..525513d0 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/transport/Disconnection.java @@ -0,0 +1,97 @@ +/* + * Copyright 2010 Shikhar Bhushan + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package net.schmizz.sshj.transport; + +import net.schmizz.sshj.common.DisconnectReason; +import net.schmizz.sshj.util.BasicFixture; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class Disconnection { + + private final BasicFixture fixture = new BasicFixture(); + + private boolean notified; + + @Before + public void setUp() + throws IOException { + fixture.init(); + + notified = false; + + fixture.getClient().getTransport().setDisconnectListener(new DisconnectListener() { + @Override + public void notifyDisconnect(DisconnectReason reason) { + notified = true; + } + }); + + } + + @After + public void tearDown() + throws IOException, InterruptedException { + fixture.done(); + } + + private boolean joinToClientTransport(int seconds) { + try { + fixture.getClient().getTransport().join(seconds, TimeUnit.SECONDS); + return true; + } catch (TransportException ignored) { + return false; + } + } + + @Test + public void listenerNotifiedOnClientDisconnect() + throws IOException { + fixture.stopClient(); + assertTrue(notified); + } + + @Test + public void listenerNotifiedOnServerDisconnect() + throws InterruptedException, IOException { + fixture.stopServer(); + joinToClientTransport(2); + assertTrue(notified); + } + + @Test + public void joinNotifiedOnClientDisconnect() + throws IOException { + fixture.stopClient(); + assertTrue(joinToClientTransport(2)); + } + + @Test + public void joinNotifiedOnServerDisconnect() + throws TransportException, InterruptedException { + fixture.stopServer(); + assertFalse(joinToClientTransport(2)); + } + +} \ No newline at end of file diff --git a/src/test/java/net/schmizz/sshj/util/BasicFixture.java b/src/test/java/net/schmizz/sshj/util/BasicFixture.java new file mode 100644 index 00000000..3da14237 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/util/BasicFixture.java @@ -0,0 +1,131 @@ +/* +* Copyright 2010 Shikhar Bhushan +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package net.schmizz.sshj.util; + +import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.transport.TransportException; +import net.schmizz.sshj.userauth.UserAuthException; +import org.apache.sshd.SshServer; +import org.apache.sshd.common.keyprovider.FileKeyPairProvider; +import org.apache.sshd.server.PasswordAuthenticator; +import org.apache.sshd.server.session.ServerSession; + +import java.io.IOException; +import java.net.ServerSocket; + + +public class BasicFixture { + + public static final String hostkey = "src/test/resources/hostkey.pem"; + public static final String fingerprint = "ce:a7:c1:cf:17:3f:96:49:6a:53:1a:05:0b:ba:90:db"; + + public static final String hostname = "localhost"; + public final int port = gimmeAPort(); + + private SSHClient client; + private SshServer server; + + private boolean clientRunning = false; + private boolean serverRunning = false; + + private static int gimmeAPort() { + try { + ServerSocket s = null; + try { + s = new ServerSocket(0); + return s.getLocalPort(); + } finally { + if (s != null) + s.close(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public void init() + throws IOException { + init(false); + } + + public void init(boolean authenticate) + throws IOException { + startServer(); + startClient(authenticate); + } + + public void done() + throws InterruptedException, IOException { + stopClient(); + stopServer(); + } + + public void startServer() + throws IOException { + server = SshServer.setUpDefaultServer(); + server.setPort(port); + server.setKeyPairProvider(new FileKeyPairProvider(new String[]{hostkey})); + server.setPasswordAuthenticator(new PasswordAuthenticator() { + @Override + public boolean authenticate(String u, String p, ServerSession s) { + return false; + } + }); + server.start(); + serverRunning = true; + } + + public void stopServer() + throws InterruptedException { + if (serverRunning) { + server.stop(); + serverRunning = false; + } + } + + public SshServer getServer() { + return server; + } + + public void startClient(boolean authenticate) + throws IOException { + client = new SSHClient(); + client.addHostKeyVerifier(fingerprint); + client.connect(hostname, port); + if (authenticate) + dummyAuth(); + clientRunning = true; + } + + public void stopClient() + throws IOException { + if (clientRunning) { + client.disconnect(); + clientRunning = false; + } + } + + public SSHClient getClient() { + return client; + } + + public void dummyAuth() + throws UserAuthException, TransportException { + server.setPasswordAuthenticator(new BogusPasswordAuthenticator()); + client.authPassword("same", "same"); + } + +}