Add DisconnectListener, refactor tests

This commit is contained in:
Shikhar Bhushan
2011-03-13 22:51:35 +00:00
parent 38883bf15d
commit ce5fad9809
9 changed files with 325 additions and 77 deletions

View File

@@ -67,7 +67,7 @@ public abstract class AbstractService
}
@Override
public void notifyDisconnect()
public void notifyDisconnect(DisconnectReason reason)
throws SSHException {
log.debug("Was notified of disconnect");
}

View File

@@ -15,6 +15,7 @@
*/
package net.schmizz.sshj;
import net.schmizz.sshj.common.DisconnectReason;
import net.schmizz.sshj.common.ErrorNotifiable;
import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.common.SSHPacketHandler;
@@ -48,7 +49,7 @@ public interface Service
void request()
throws TransportException;
void notifyDisconnect()
void notifyDisconnect(DisconnectReason reason)
throws SSHException;
}

View File

@@ -246,10 +246,10 @@ public class ConnectionImpl
}
@Override
public void notifyDisconnect()
public void notifyDisconnect(DisconnectReason reason)
throws SSHException {
super.notifyDisconnect();
final ConnectionException ex = new ConnectionException("Disconnected.");
super.notifyDisconnect(reason);
final ConnectionException ex = new ConnectionException("Disconnected");
FutureUtils.alertAll(ex, globalReqFutures);
ErrorNotifiable.Util.alertAll(ex, new HashSet<Channel>(channels.values()));
}

View File

@@ -0,0 +1,24 @@
/*
* Copyright 2010 Shikhar Bhushan
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
**/
package net.schmizz.sshj.transport;
import net.schmizz.sshj.common.DisconnectReason;
public interface DisconnectListener {
void notifyDisconnect(DisconnectReason reason);
}

View File

@@ -44,6 +44,7 @@ import net.schmizz.sshj.transport.verification.HostKeyVerifier;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.TimeUnit;
/** Transport layer of the SSH protocol. */
public interface Transport
@@ -171,12 +172,18 @@ public interface Transport
boolean isRunning();
/**
* Joins the thread calling this method to the transport's death. The transport dies of exceptional events.
* Joins the thread calling this method to the transport's death.
*
* @throws TransportException when the transport dies
* @throws TransportException if the transport dies of an exception
*/
void join()
throws TransportException;
/**
* Joins the thread calling this method to the transport's death.
*
* @throws TransportException if the transport dies of an exception
*/
void join(int timeout, TimeUnit unit) throws TransportException;
/** Send a disconnection packet with reason as {@link DisconnectReason#BY_APPLICATION}, and closes this transport. */
void disconnect();
@@ -211,4 +218,17 @@ public interface Transport
*/
long write(SSHPacket payload)
throws TransportException;
/**
* Specify a {@code listener} that will be notified upon disconnection.
*
* @param listener
*/
void setDisconnectListener(DisconnectListener listener);
/**
* @return the current disconnect listener.
*/
DisconnectListener getDisconnectListener();
}

View File

@@ -85,6 +85,13 @@ public final class TransportImpl
private final Service nullService = new NullService(this);
private final DisconnectListener nullDisconnectListener = new DisconnectListener() {
@Override
public void notifyDisconnect(DisconnectReason reason) {
log.debug("Default disconnect listener - {}", reason);
}
};
private final Config config;
private final KeyExchanger kexer;
@@ -97,11 +104,9 @@ public final class TransportImpl
private final Decoder decoder;
private final Event<TransportException> serviceAccept = new Event<TransportException>("service accept",
TransportException.chainer);
private final Event<TransportException> serviceAccept = new Event<TransportException>("service accept", TransportException.chainer);
private final Event<TransportException> close = new Event<TransportException>("transport close",
TransportException.chainer);
private final Event<TransportException> close = new Event<TransportException>("transport close", TransportException.chainer);
/** Client version identification string */
private final String clientID;
@@ -113,6 +118,8 @@ public final class TransportImpl
/** Currently active service e.g. UserAuthService, ConnectionService */
private volatile Service service = nullService;
private DisconnectListener disconnectListener = nullDisconnectListener;
private ConnInfo connInfo;
/** Server version identification string */
@@ -210,7 +217,7 @@ public final class TransportImpl
if (!ident.startsWith("SSH-2.0-") && !ident.startsWith("SSH-1.99-"))
throw new TransportException(DisconnectReason.PROTOCOL_VERSION_NOT_SUPPORTED,
"Server does not support SSHv2, identified as: " + ident);
"Server does not support SSHv2, identified as: " + ident);
return ident;
}
@@ -347,6 +354,12 @@ public final class TransportImpl
close.await();
}
@Override
public void join(int timeout, TimeUnit unit)
throws TransportException {
close.await(timeout, unit);
}
@Override
public boolean isRunning() {
return reader.isAlive() && !close.isSet();
@@ -364,10 +377,11 @@ public final class TransportImpl
@Override
public void disconnect(DisconnectReason reason, String message) {
close.lock(); // CAS type operation on close
close.lock();
try {
disconnectListener.notifyDisconnect(reason);
try {
service.notifyDisconnect();
service.notifyDisconnect(reason);
} catch (SSHException logged) {
log.warn("{} did not handle disconnect cleanly: {}", service, logged);
}
@@ -381,6 +395,16 @@ public final class TransportImpl
}
}
@Override
public void setDisconnectListener(DisconnectListener listener) {
this.disconnectListener = listener == null ? nullDisconnectListener : listener;
}
@Override
public DisconnectListener getDisconnectListener() {
return disconnectListener;
}
@Override
public long write(SSHPacket payload)
throws TransportException {
@@ -501,7 +525,7 @@ public final class TransportImpl
try {
if (!serviceAccept.hasWaiters())
throw new TransportException(DisconnectReason.PROTOCOL_ERROR,
"Got a service accept notification when none was awaited");
"Got a service accept notification when none was awaited");
serviceAccept.set();
} finally {
serviceAccept.unlock();
@@ -540,6 +564,8 @@ public final class TransportImpl
final SSHException causeOfDeath = SSHException.chainer.chain(ex);
disconnectListener.notifyDisconnect(causeOfDeath.getDisconnectReason());
FutureUtils.alertAll(causeOfDeath, close, serviceAccept);
kexer.notifyError(causeOfDeath);
getService().notifyError(causeOfDeath);

View File

@@ -17,93 +17,42 @@ package net.schmizz.sshj;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.userauth.UserAuthException;
import net.schmizz.sshj.util.BogusPasswordAuthenticator;
import org.apache.sshd.SshServer;
import org.apache.sshd.common.keyprovider.FileKeyPairProvider;
import net.schmizz.sshj.util.BasicFixture;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.net.ServerSocket;
import static org.junit.Assert.assertTrue;
/* Kinda basic right now */
public class SmokeTest {
private SSHClient ssh;
private SshServer sshd;
private final String hostname = "localhost";
private int port;
private static final String hostkey = "src/test/resources/hostkey.pem";
private static final String fingerprint = "ce:a7:c1:cf:17:3f:96:49:6a:53:1a:05:0b:ba:90:db";
private final BasicFixture fixture = new BasicFixture();
@Before
public void setUp()
throws IOException {
ServerSocket s = new ServerSocket(0);
port = s.getLocalPort();
s.close();
sshd = SshServer.setUpDefaultServer();
sshd.setPort(port);
sshd.setKeyPairProvider(new FileKeyPairProvider(new String[]{hostkey}));
// sshd.setShellFactory(new EchoShellFactory());
sshd.setPasswordAuthenticator(new BogusPasswordAuthenticator());
sshd.start();
ssh = new SSHClient();
ssh.addHostKeyVerifier(fingerprint);
fixture.init(false);
}
@After
public void tearUp()
public void tearDown()
throws IOException, InterruptedException {
ssh.disconnect();
sshd.stop();
fixture.done();
}
@Test
public void testAuthenticate()
public void connected()
throws IOException {
connect();
authenticate();
assertTrue(ssh.isAuthenticated());
assertTrue(fixture.getClient().isConnected());
}
@Test
public void testConnect()
throws IOException {
connect();
assertTrue(ssh.isConnected());
public void authenticated() throws UserAuthException, TransportException {
fixture.dummyAuth();
assertTrue(fixture.getClient().isAuthenticated());
}
// @Test
// // TODO -- test I/O
// public void testShell() throws IOException
// {
// connect();
// authenticate();
//
// Shell shell = ssh.startSession().startShell();
// assertTrue(shell.isOpen());
//
// shell.close();
// assertFalse(shell.isOpen());
// }
private void authenticate()
throws UserAuthException, TransportException {
ssh.authPassword("same", "same");
}
private void connect()
throws IOException {
ssh.connect(hostname, port);
}
}
}

View File

@@ -0,0 +1,97 @@
/*
* Copyright 2010 Shikhar Bhushan
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.transport;
import net.schmizz.sshj.common.DisconnectReason;
import net.schmizz.sshj.util.BasicFixture;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class Disconnection {
private final BasicFixture fixture = new BasicFixture();
private boolean notified;
@Before
public void setUp()
throws IOException {
fixture.init();
notified = false;
fixture.getClient().getTransport().setDisconnectListener(new DisconnectListener() {
@Override
public void notifyDisconnect(DisconnectReason reason) {
notified = true;
}
});
}
@After
public void tearDown()
throws IOException, InterruptedException {
fixture.done();
}
private boolean joinToClientTransport(int seconds) {
try {
fixture.getClient().getTransport().join(seconds, TimeUnit.SECONDS);
return true;
} catch (TransportException ignored) {
return false;
}
}
@Test
public void listenerNotifiedOnClientDisconnect()
throws IOException {
fixture.stopClient();
assertTrue(notified);
}
@Test
public void listenerNotifiedOnServerDisconnect()
throws InterruptedException, IOException {
fixture.stopServer();
joinToClientTransport(2);
assertTrue(notified);
}
@Test
public void joinNotifiedOnClientDisconnect()
throws IOException {
fixture.stopClient();
assertTrue(joinToClientTransport(2));
}
@Test
public void joinNotifiedOnServerDisconnect()
throws TransportException, InterruptedException {
fixture.stopServer();
assertFalse(joinToClientTransport(2));
}
}

View File

@@ -0,0 +1,131 @@
/*
* Copyright 2010 Shikhar Bhushan
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.util;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.userauth.UserAuthException;
import org.apache.sshd.SshServer;
import org.apache.sshd.common.keyprovider.FileKeyPairProvider;
import org.apache.sshd.server.PasswordAuthenticator;
import org.apache.sshd.server.session.ServerSession;
import java.io.IOException;
import java.net.ServerSocket;
public class BasicFixture {
public static final String hostkey = "src/test/resources/hostkey.pem";
public static final String fingerprint = "ce:a7:c1:cf:17:3f:96:49:6a:53:1a:05:0b:ba:90:db";
public static final String hostname = "localhost";
public final int port = gimmeAPort();
private SSHClient client;
private SshServer server;
private boolean clientRunning = false;
private boolean serverRunning = false;
private static int gimmeAPort() {
try {
ServerSocket s = null;
try {
s = new ServerSocket(0);
return s.getLocalPort();
} finally {
if (s != null)
s.close();
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public void init()
throws IOException {
init(false);
}
public void init(boolean authenticate)
throws IOException {
startServer();
startClient(authenticate);
}
public void done()
throws InterruptedException, IOException {
stopClient();
stopServer();
}
public void startServer()
throws IOException {
server = SshServer.setUpDefaultServer();
server.setPort(port);
server.setKeyPairProvider(new FileKeyPairProvider(new String[]{hostkey}));
server.setPasswordAuthenticator(new PasswordAuthenticator() {
@Override
public boolean authenticate(String u, String p, ServerSession s) {
return false;
}
});
server.start();
serverRunning = true;
}
public void stopServer()
throws InterruptedException {
if (serverRunning) {
server.stop();
serverRunning = false;
}
}
public SshServer getServer() {
return server;
}
public void startClient(boolean authenticate)
throws IOException {
client = new SSHClient();
client.addHostKeyVerifier(fingerprint);
client.connect(hostname, port);
if (authenticate)
dummyAuth();
clientRunning = true;
}
public void stopClient()
throws IOException {
if (clientRunning) {
client.disconnect();
clientRunning = false;
}
}
public SSHClient getClient() {
return client;
}
public void dummyAuth()
throws UserAuthException, TransportException {
server.setPasswordAuthenticator(new BogusPasswordAuthenticator());
client.authPassword("same", "same");
}
}