Implemented diffie-hellman-group-exchange Kex methods (Fixes #167)

This commit is contained in:
Jeroen van Erp
2015-10-29 12:30:58 +01:00
parent e24ed6ee7b
commit 47df71c836
10 changed files with 300 additions and 20 deletions

View File

@@ -33,6 +33,8 @@ import net.schmizz.sshj.transport.cipher.TripleDESCBC;
import net.schmizz.sshj.transport.compression.NoneCompression;
import net.schmizz.sshj.transport.kex.DHG1;
import net.schmizz.sshj.transport.kex.DHG14;
import net.schmizz.sshj.transport.kex.DHGexSHA1;
import net.schmizz.sshj.transport.kex.DHGexSHA256;
import net.schmizz.sshj.transport.mac.HMACMD5;
import net.schmizz.sshj.transport.mac.HMACMD596;
import net.schmizz.sshj.transport.mac.HMACSHA1;
@@ -98,9 +100,9 @@ public class DefaultConfig
protected void initKeyExchangeFactories(boolean bouncyCastleRegistered) {
if (bouncyCastleRegistered)
setKeyExchangeFactories(new DHG14.Factory(), new DHG1.Factory());
setKeyExchangeFactories(new DHG14.Factory(), new DHG1.Factory(), new DHGexSHA1.Factory(), new DHGexSHA256.Factory());
else
setKeyExchangeFactories(new DHG1.Factory());
setKeyExchangeFactories(new DHG1.Factory(), new DHGexSHA1.Factory());
}
protected void initRandomFactory(boolean bouncyCastleRegistered) {

View File

@@ -0,0 +1,26 @@
package net.schmizz.sshj.transport.digest;
/** SHA256 Digest. */
public class SHA256 extends BaseDigest {
/** Named factory for SHA256 digest */
public static class Factory
implements net.schmizz.sshj.common.Factory.Named<Digest> {
@Override
public Digest create() {
return new SHA256();
}
@Override
public String getName() {
return "sha256";
}
}
/** Create a new instance of a SHA256 digest */
public SHA256() {
super("SHA-256", 32);
}
}

View File

@@ -38,21 +38,14 @@ import java.util.Arrays;
* Base class for DHG key exchange algorithms. Implementations will only have to configure the required data on the
* {@link DH} class in the
*/
public abstract class AbstractDHG
public abstract class AbstractDHG extends KeyExchangeBase
implements KeyExchange {
private final Logger log = LoggerFactory.getLogger(getClass());
private Transport trans;
private final Digest sha1 = new SHA1();
private final DH dh = new DH();
private String V_S;
private String V_C;
private byte[] I_S;
private byte[] I_C;
private byte[] H;
private PublicKey hostKey;
@@ -79,11 +72,7 @@ public abstract class AbstractDHG
@Override
public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C)
throws GeneralSecurityException, TransportException {
this.trans = trans;
this.V_S = V_S;
this.V_C = V_C;
this.I_S = Arrays.copyOf(I_S, I_S.length);
this.I_C = Arrays.copyOf(I_C, I_C.length);
super.init(trans, V_S, V_C, I_S, I_C);
sha1.init();
initDH(dh);
@@ -112,11 +101,7 @@ public abstract class AbstractDHG
dh.computeK(f);
final Buffer.PlainBuffer buf = new Buffer.PlainBuffer()
.putString(V_C)
.putString(V_S)
.putString(I_C)
.putString(I_S)
final Buffer.PlainBuffer buf = initializedBuffer()
.putString(K_S)
.putMPInt(dh.getE())
.putMPInt(f)

View File

@@ -0,0 +1,124 @@
package net.schmizz.sshj.transport.kex;
import net.schmizz.sshj.common.*;
import net.schmizz.sshj.signature.Signature;
import net.schmizz.sshj.transport.Transport;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.transport.digest.Digest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.PublicKey;
import java.util.Arrays;
public abstract class AbstractDHGex extends KeyExchangeBase {
private final Logger log = LoggerFactory.getLogger(getClass());
private Digest digest;
private int minBits = 1024;
private int maxBits = 8192;
private int preferredBits = 2048;
private DH dh;
private PublicKey hostKey;
private byte[] H;
public AbstractDHGex(Digest digest) {
this.digest = digest;
}
@Override
public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C) throws GeneralSecurityException, TransportException {
super.init(trans, V_S, V_C, I_S, I_C);
dh = new DH();
digest.init();
log.debug("Sending {}", Message.KEX_DH_GEX_REQUEST);
trans.write(new SSHPacket(Message.KEX_DH_GEX_REQUEST).putUInt32(minBits).putUInt32(preferredBits).putUInt32(maxBits));
}
@Override
public byte[] getH() {
return Arrays.copyOf(H, H.length);
}
@Override
public BigInteger getK() {
return dh.getK();
}
@Override
public Digest getHash() {
return digest;
}
@Override
public PublicKey getHostKey() {
return hostKey;
}
@Override
public boolean next(Message msg, SSHPacket buffer) throws GeneralSecurityException, TransportException {
log.debug("Got message {}", msg);
try {
switch (msg) {
case KEXDH_31:
return parseGexGroup(buffer);
case KEX_DH_GEX_REPLY:
return parseGexReply(buffer);
}
} catch (Buffer.BufferException be) {
throw new TransportException(be);
}
throw new TransportException("Unexpected message " + msg);
}
private boolean parseGexReply(SSHPacket buffer) throws Buffer.BufferException, GeneralSecurityException, TransportException {
byte[] K_S = buffer.readBytes();
BigInteger f = buffer.readMPInt();
byte[] sig = buffer.readBytes();
hostKey = new Buffer.PlainBuffer(K_S).readPublicKey();
dh.computeK(f);
BigInteger k = dh.getK();
final Buffer.PlainBuffer buf = initializedBuffer()
.putString(K_S)
.putUInt32(minBits)
.putUInt32(preferredBits)
.putUInt32(maxBits)
.putMPInt(dh.getP())
.putMPInt(dh.getG())
.putMPInt(dh.getE())
.putMPInt(f)
.putMPInt(k);
digest.update(buf.array(), buf.rpos(), buf.available());
H = digest.digest();
Signature signature = Factory.Named.Util.create(trans.getConfig().getSignatureFactories(),
KeyType.fromKey(hostKey).toString());
signature.init(hostKey, null);
signature.update(H, 0, H.length);
if (!signature.verify(sig))
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED,
"KeyExchange signature verification failed");
return true;
}
private boolean parseGexGroup(SSHPacket buffer) throws Buffer.BufferException, GeneralSecurityException, TransportException {
BigInteger p = buffer.readMPInt();
BigInteger g = buffer.readMPInt();
int bitLength = p.bitLength();
if (bitLength < minBits || bitLength > maxBits) {
throw new GeneralSecurityException("Server generated gex p is out of range (" + bitLength + " bits)");
}
log.debug("Received server p bitlength {}", bitLength);
dh.init(p, g);
log.debug("Sending {}", Message.KEX_DH_GEX_INIT);
trans.write(new SSHPacket(Message.KEX_DH_GEX_INIT).putMPInt(dh.getE()));
return false;
}
}

View File

@@ -73,4 +73,11 @@ public class DH {
return K;
}
public BigInteger getP() {
return p;
}
public BigInteger getG() {
return g;
}
}

View File

@@ -0,0 +1,25 @@
package net.schmizz.sshj.transport.kex;
import net.schmizz.sshj.transport.digest.SHA1;
public class DHGexSHA1 extends AbstractDHGex {
/** Named factory for DHGexSHA1 key exchange */
public static class Factory
implements net.schmizz.sshj.common.Factory.Named<KeyExchange> {
@Override
public KeyExchange create() {
return new DHGexSHA1();
}
@Override
public String getName() {
return "diffie-hellman-group-exchange-sha1";
}
}
public DHGexSHA1() {
super(new SHA1());
}
}

View File

@@ -0,0 +1,25 @@
package net.schmizz.sshj.transport.kex;
import net.schmizz.sshj.transport.digest.SHA256;
public class DHGexSHA256 extends AbstractDHGex {
/** Named factory for DHGexSHA256 key exchange */
public static class Factory
implements net.schmizz.sshj.common.Factory.Named<KeyExchange> {
@Override
public KeyExchange create() {
return new DHGexSHA256();
}
@Override
public String getName() {
return "diffie-hellman-group-exchange-sha256";
}
}
public DHGexSHA256() {
super(new SHA256());
}
}

View File

@@ -0,0 +1,37 @@
package net.schmizz.sshj.transport.kex;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.transport.Transport;
import net.schmizz.sshj.transport.TransportException;
import java.security.GeneralSecurityException;
import java.util.Arrays;
/**
* Created by ajvanerp on 29/10/15.
*/
public abstract class KeyExchangeBase implements KeyExchange {
protected Transport trans;
private String V_S;
private String V_C;
private byte[] I_S;
private byte[] I_C;
@Override
public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C) throws GeneralSecurityException, TransportException {
this.trans = trans;
this.V_S = V_S;
this.V_C = V_C;
this.I_S = Arrays.copyOf(I_S, I_S.length);
this.I_C = Arrays.copyOf(I_C, I_C.length);
}
protected Buffer.PlainBuffer initializedBuffer() {
return new Buffer.PlainBuffer()
.putString(V_C)
.putString(V_S)
.putString(I_C)
.putString(I_S);
}
}

View File

@@ -158,4 +158,8 @@ public class SshFixture extends ExternalResource {
}
}
}
public SshServer getServer() {
return server;
}
}

View File

@@ -0,0 +1,45 @@
package com.hierynomus.sshj.transport.kex;
import com.hierynomus.sshj.test.SshFixture;
import net.schmizz.sshj.SSHClient;
import org.apache.sshd.common.KeyExchange;
import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.server.kex.DHGEX;
import org.apache.sshd.server.kex.DHGEX256;
import org.junit.After;
import org.junit.Rule;
import org.junit.Test;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.MatcherAssert.assertThat;
public class DiffieHellmanGroupExchangeTest {
@Rule
public SshFixture fixture = new SshFixture(false);
@After
public void stopServer() {
fixture.stopServer();
}
@Test
public void shouldKexWithGroupExchangeSha1() throws IOException {
setupAndCheckKex(new DHGEX.Factory());
}
@Test
public void shouldKexWithGroupExchangeSha256() throws IOException {
setupAndCheckKex(new DHGEX256.Factory());
}
private void setupAndCheckKex(NamedFactory<KeyExchange> factory) throws IOException {
fixture.getServer().setKeyExchangeFactories(Collections.singletonList(factory));
fixture.start();
SSHClient sshClient = fixture.setupConnectedDefaultClient();
assertThat("should be connected", sshClient.isConnected());
sshClient.disconnect();
}
}