Merge branch 'heartbeat'

This commit is contained in:
hierynomus
2015-01-19 10:06:23 +01:00
16 changed files with 318 additions and 92 deletions

View File

@@ -202,6 +202,16 @@ public class Promise<V, T extends Throwable> {
}
}
/** @return whether this promise was fulfilled with either a value or an error. */
public boolean isFulfilled() {
lock.lock();
try {
return pendingEx != null || val != null;
} finally {
lock.unlock();
}
}
/** @return whether this promise has threads waiting on it. */
public boolean hasWaiters() {
lock.lock();

View File

@@ -0,0 +1,34 @@
/**
* Copyright 2009 sshj contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.keepalive;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.TransportException;
final class Heartbeater
extends KeepAlive {
Heartbeater(ConnectionImpl conn) {
super(conn, "heartbeater");
}
@Override
protected void doKeepAlive() throws TransportException {
conn.getTransport().write(new SSHPacket(Message.IGNORE));
}
}

View File

@@ -0,0 +1,68 @@
package net.schmizz.keepalive;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.TransportException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public abstract class KeepAlive extends Thread {
protected final Logger log = LoggerFactory.getLogger(getClass());
protected final ConnectionImpl conn;
protected int keepAliveInterval = 0;
protected KeepAlive(ConnectionImpl conn, String name) {
this.conn = conn;
setName(name);
}
public synchronized int getKeepAliveInterval() {
return keepAliveInterval;
}
public synchronized void setKeepAliveInterval(int keepAliveInterval) {
this.keepAliveInterval = keepAliveInterval;
if (keepAliveInterval > 0 && getState() == State.NEW) {
start();
}
notify();
}
synchronized protected int getPositiveInterval()
throws InterruptedException {
while (keepAliveInterval <= 0) {
wait();
}
return keepAliveInterval;
}
@Override
public void run() {
log.debug("Starting {}, sending keep-alive every {} seconds", getClass().getSimpleName(), keepAliveInterval);
try {
while (!isInterrupted()) {
final int hi = getPositiveInterval();
if (conn.getTransport().isRunning()) {
log.debug("Sending keep-alive since {} seconds elapsed", hi);
doKeepAlive();
}
Thread.sleep(hi * 1000);
}
} catch (Exception e) {
// If we weren't interrupted, kill the transport, then this exception was unexpected.
// Else we're in shutdown-mode already, so don't forcibly kill the transport.
if (!isInterrupted()) {
conn.getTransport().die(e);
}
}
log.debug("Stopping {}", getClass().getSimpleName());
}
protected abstract void doKeepAlive() throws TransportException, ConnectionException;
}

View File

@@ -0,0 +1,24 @@
package net.schmizz.keepalive;
import net.schmizz.sshj.connection.ConnectionImpl;
public abstract class KeepAliveProvider {
public static final KeepAliveProvider HEARTBEAT = new KeepAliveProvider() {
@Override
public KeepAlive provide(ConnectionImpl connection) {
return new Heartbeater(connection);
}
};
public static final KeepAliveProvider KEEP_ALIVE = new KeepAliveProvider() {
@Override
public KeepAlive provide(ConnectionImpl connection) {
return new KeepAliveRunner(connection);
}
};
public abstract KeepAlive provide(ConnectionImpl connection);
}

View File

@@ -0,0 +1,60 @@
package net.schmizz.keepalive;
import net.schmizz.concurrent.Promise;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.TransportException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.LinkedList;
import java.util.Queue;
import static java.lang.String.format;
import static net.schmizz.sshj.common.DisconnectReason.CONNECTION_LOST;
public class KeepAliveRunner extends KeepAlive {
/** The max number of keep-alives that should be unanswered before killing the connection. */
private int maxAliveCount = 5;
/** The queue of promises. */
private final Queue<Promise<SSHPacket, ConnectionException>> queue =
new LinkedList<Promise<SSHPacket, ConnectionException>>();
KeepAliveRunner(ConnectionImpl conn) {
super(conn, "keep-alive");
}
synchronized public int getMaxAliveCount() {
return maxAliveCount;
}
synchronized public void setMaxAliveCount(int maxAliveCount) {
this.maxAliveCount = maxAliveCount;
}
@Override
protected void doKeepAlive() throws TransportException, ConnectionException {
emptyQueue(queue);
checkMaxReached(queue);
queue.add(conn.sendGlobalRequest("keepalive@openssh.com", true, new byte[0]));
}
private void checkMaxReached(Queue<Promise<SSHPacket, ConnectionException>> queue) throws ConnectionException {
if (queue.size() >= maxAliveCount) {
throw new ConnectionException(CONNECTION_LOST,
format("Did not receive any keep-alive response for %s seconds", maxAliveCount * keepAliveInterval));
}
}
private void emptyQueue(Queue<Promise<SSHPacket, ConnectionException>> queue) {
Promise<SSHPacket, ConnectionException> peek = queue.peek();
while (peek != null && peek.isFulfilled()) {
log.debug("Received response from server to our keep-alive.");
queue.remove();
peek = queue.peek();
}
}
}

View File

@@ -15,6 +15,7 @@
*/
package net.schmizz.sshj;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.signature.Signature;
import net.schmizz.sshj.transport.cipher.Cipher;
@@ -144,4 +145,14 @@ public interface Config {
*/
void setVersion(String version);
/**
* @return The provider that creates the keep-alive implementation of choice.
*/
KeepAliveProvider getKeepAliveProvider();
/**
* Set the provider that provides the keep-alive implementation.
* @param keepAliveProvider keep-alive provider
*/
void setKeepAliveProvider(KeepAliveProvider keepAliveProvider);
}

View File

@@ -15,6 +15,7 @@
*/
package net.schmizz.sshj;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.signature.Signature;
import net.schmizz.sshj.transport.cipher.Cipher;
@@ -34,6 +35,7 @@ public class ConfigImpl
private String version;
private Factory<Random> randomFactory;
private KeepAliveProvider keepAliveProvider;
private List<Factory.Named<KeyExchange>> kexFactories;
private List<Factory.Named<Cipher>> cipherFactories;
@@ -146,4 +148,13 @@ public class ConfigImpl
this.version = version;
}
@Override
public KeepAliveProvider getKeepAliveProvider() {
return keepAliveProvider;
}
@Override
public void setKeepAliveProvider(KeepAliveProvider keepAliveProvider) {
this.keepAliveProvider = keepAliveProvider;
}
}

View File

@@ -15,6 +15,7 @@
*/
package net.schmizz.sshj;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.signature.SignatureDSA;
@@ -92,6 +93,7 @@ public class DefaultConfig
initCompressionFactories();
initMACFactories();
initSignatureFactories();
setKeepAliveProvider(KeepAliveProvider.HEARTBEAT);
}
protected void initKeyExchangeFactories(boolean bouncyCastleRegistered) {

View File

@@ -144,9 +144,9 @@ public class SSHClient
*/
public SSHClient(Config config) {
super(DEFAULT_PORT);
this.trans = new TransportImpl(config);
this.trans = new TransportImpl(config, this);
this.auth = new UserAuthImpl(trans);
this.conn = new ConnectionImpl(trans);
this.conn = new ConnectionImpl(trans, config.getKeepAliveProvider());
}
/**

View File

@@ -16,6 +16,7 @@
package net.schmizz.sshj.connection;
import net.schmizz.concurrent.Promise;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.channel.Channel;
import net.schmizz.sshj.connection.channel.OpenFailException;
@@ -150,4 +151,9 @@ public interface Connection {
* @param timeout timeout in milliseconds
*/
void setTimeoutMs(int timeout);
/**
* @return The configured {@link net.schmizz.keepalive.KeepAlive} mechanism.
*/
KeepAlive getKeepAlive();
}

View File

@@ -17,6 +17,8 @@ package net.schmizz.sshj.connection;
import net.schmizz.concurrent.ErrorDeliveryUtil;
import net.schmizz.concurrent.Promise;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.AbstractService;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.DisconnectReason;
@@ -51,6 +53,9 @@ public class ConnectionImpl
private final Queue<Promise<SSHPacket, ConnectionException>> globalReqPromises = new LinkedList<Promise<SSHPacket, ConnectionException>>();
/** {@code keep-alive} mechanism */
private final KeepAlive keepAlive;
private long windowSize = 2048 * 1024;
private int maxPacketSize = 32 * 1024;
@@ -59,11 +64,14 @@ public class ConnectionImpl
/**
* Create with an associated {@link Transport}.
*
* @param config the ssh config
* @param trans transport layer
* @param keepAlive
*/
public ConnectionImpl(Transport trans) {
public ConnectionImpl(Transport trans, KeepAliveProvider keepAlive) {
super("ssh-connection", trans);
timeoutMs = trans.getTimeoutMs();
this.keepAlive = keepAlive.provide(this);
}
@Override
@@ -250,6 +258,7 @@ public class ConnectionImpl
ErrorDeliveryUtil.alertPromises(error, globalReqPromises);
globalReqPromises.clear();
}
keepAlive.interrupt();
ErrorNotifiable.Util.alertAll(error, channels.values());
channels.clear();
}
@@ -264,4 +273,9 @@ public class ConnectionImpl
return timeoutMs;
}
@Override
public KeepAlive getKeepAlive() {
return keepAlive;
}
}

View File

@@ -1,77 +0,0 @@
/**
* Copyright 2009 sshj contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.sshj.transport;
import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
final class Heartbeater
extends Thread {
private final Logger log = LoggerFactory.getLogger(getClass());
private final TransportImpl trans;
private int interval;
Heartbeater(TransportImpl trans) {
this.trans = trans;
setName("heartbeater");
}
synchronized void setInterval(int interval) {
this.interval = interval;
if (interval > 0 && getState() == Thread.State.NEW)
start();
notify();
}
synchronized int getInterval() {
return interval;
}
synchronized private int getPositiveInterval()
throws InterruptedException {
while (interval <= 0)
wait();
return interval;
}
@Override
public void run() {
log.debug("Starting");
try {
while (!isInterrupted()) {
final int hi = getPositiveInterval();
if (trans.isRunning()) {
log.debug("Sending heartbeat since {} seconds elapsed", hi);
trans.write(new SSHPacket(Message.IGNORE));
}
Thread.sleep(hi * 1000);
}
} catch (Exception e) {
if (isInterrupted()) {
// We are meant to shut up and draw to a close if interrupted
} else
trans.die(e);
}
log.debug("Stopping");
}
}

View File

@@ -75,10 +75,18 @@ public interface Transport
*/
void setTimeoutMs(int timeout);
/** @return the interval in seconds at which a heartbeat message is sent to the server */
/**
* @return the interval in seconds at which a heartbeat message is sent to the server
* @deprecated Moved to {@link net.schmizz.keepalive.KeepAlive#getKeepAliveInterval()}. This is accessible through the {@link net.schmizz.sshj.connection.Connection}.
*/
@Deprecated
int getHeartbeatInterval();
/** @param interval the interval in seconds, {@code 0} means no hearbeat */
/**
* @param interval the interval in seconds, {@code 0} means no hearbeat
* @deprecated Moved to {@link net.schmizz.keepalive.KeepAlive#getKeepAliveInterval()}. This is accessible through the {@link net.schmizz.sshj.connection.Connection}.
*/
@Deprecated
void setHeartbeatInterval(int interval);
/** @return the hostname to which this transport is connected. */
@@ -211,4 +219,10 @@ public interface Transport
/** @return the current disconnect listener. */
DisconnectListener getDisconnectListener();
/**
* Kill the transport in an exceptional way.
*
* @param e The exception that occurred.
*/
void die(Exception e);
}

View File

@@ -19,6 +19,7 @@ import net.schmizz.concurrent.ErrorDeliveryUtil;
import net.schmizz.concurrent.Event;
import net.schmizz.sshj.AbstractService;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.Service;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.DisconnectReason;
@@ -46,23 +47,23 @@ public final class TransportImpl
NullService(Transport trans) {
super("null-service", trans);
}
}
}
static final class ConnInfo {
final String host;
final int port;
final InputStream in;
final OutputStream out;
public ConnInfo(String host, int port, InputStream in, OutputStream out) {
this.host = host;
this.port = port;
this.in = in;
this.out = out;
}
}
}
private final Logger log = LoggerFactory.getLogger(getClass());
private final Service nullService = new NullService(this);
@@ -80,7 +81,11 @@ public final class TransportImpl
private final Reader reader;
private final Heartbeater heartbeater;
/**
* @deprecated Moved to {@link net.schmizz.sshj.SSHClient}
*/
@Deprecated
private final SSHClient sshClient;
private final Encoder encoder;
@@ -115,13 +120,31 @@ public final class TransportImpl
public TransportImpl(Config config) {
this.config = config;
this.reader = new Reader(this);
this.heartbeater = new Heartbeater(this);
this.encoder = new Encoder(config.getRandomFactory().create(), writeLock);
this.decoder = new Decoder(this);
this.kexer = new KeyExchanger(this);
this.clientID = String.format("SSH-2.0-%s", config.getVersion());
this.sshClient = null;
}
/**
* Temporary constructor until we remove support for the set/get Heartbeat interval from transport.
* @param config
* @param sshClient
*/
@Deprecated
public TransportImpl(Config config, SSHClient sshClient) {
this.config = config;
this.reader = new Reader(this);
this.encoder = new Encoder(config.getRandomFactory().create(), writeLock);
this.decoder = new Decoder(this);
this.kexer = new KeyExchanger(this);
this.clientID = String.format("SSH-2.0-%s", config.getVersion());
this.sshClient = sshClient;
}
@Override
public void init(String remoteHost, int remotePort, InputStream in, OutputStream out)
throws TransportException {
@@ -231,13 +254,17 @@ public final class TransportImpl
}
@Override
@Deprecated
public int getHeartbeatInterval() {
return heartbeater.getInterval();
log.warn("**Deprecated**: Please use: sshClient.getConnection().getKeepAlive().getKeepAliveInterval()");
return sshClient.getConnection().getKeepAlive().getKeepAliveInterval();
}
@Override
@Deprecated
public void setHeartbeatInterval(int interval) {
heartbeater.setInterval(interval);
log.warn("**Deprecated**: Please use: sshClient.getConnection().getKeepAlive().setKeepAliveInterval()");
sshClient.getConnection().getKeepAlive().setKeepAliveInterval(interval);
}
@Override
@@ -542,12 +569,11 @@ public final class TransportImpl
private void finishOff() {
reader.interrupt();
heartbeater.interrupt();
IOUtils.closeQuietly(connInfo.in);
IOUtils.closeQuietly(connInfo.out);
}
void die(Exception ex) {
public void die(Exception ex) {
close.lock();
try {
if (!close.isSet()) {

View File

@@ -139,7 +139,9 @@ public class PKCS8KeyFile
JcePEMDecryptorProviderBuilder decryptorBuilder = new JcePEMDecryptorProviderBuilder();
decryptorBuilder.setProvider("BC");
try {
passphrase = pwdf == null ? null : pwdf.reqPassword(resource);
// Do not return null, as JcePEMDecryptorProviderBuilder$1$1.decrypt would throw an exception
// in that case because it requires a 'password' (i.e. passphrase).
passphrase = pwdf == null ? "".toCharArray() : pwdf.reqPassword(resource);
kp = pemConverter.getKeyPair(encryptedKeyPair.decryptKeyPair(decryptorBuilder.build(passphrase)));
} finally {
PasswordUtils.blankOut(passphrase);

View File

@@ -0,0 +1,21 @@
package net.schmizz.sshj;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
public class SshKeepAlive {
public static void main(String[] args) throws IOException, InterruptedException {
Config config = new DefaultConfig();
config.setKeepAliveProvider(KeepAliveProvider.KEEP_ALIVE);
SSHClient client = new SSHClient(config);
client.conn.getKeepAlive().setKeepAliveInterval(5);
client.addHostKeyVerifier(new PromiscuousVerifier());
client.connect("172.16.37.129", 22);
client.authPassword("jeroen", "jeroen");
new CountDownLatch(1).await();
}
}