diff --git a/src/main/java/net/schmizz/sshj/DefaultConfig.java b/src/main/java/net/schmizz/sshj/DefaultConfig.java index f847c02d..e8240148 100644 --- a/src/main/java/net/schmizz/sshj/DefaultConfig.java +++ b/src/main/java/net/schmizz/sshj/DefaultConfig.java @@ -33,6 +33,8 @@ import net.schmizz.sshj.transport.cipher.TripleDESCBC; import net.schmizz.sshj.transport.compression.NoneCompression; import net.schmizz.sshj.transport.kex.DHG1; import net.schmizz.sshj.transport.kex.DHG14; +import net.schmizz.sshj.transport.kex.DHGexSHA1; +import net.schmizz.sshj.transport.kex.DHGexSHA256; import net.schmizz.sshj.transport.mac.HMACMD5; import net.schmizz.sshj.transport.mac.HMACMD596; import net.schmizz.sshj.transport.mac.HMACSHA1; @@ -98,9 +100,9 @@ public class DefaultConfig protected void initKeyExchangeFactories(boolean bouncyCastleRegistered) { if (bouncyCastleRegistered) - setKeyExchangeFactories(new DHG14.Factory(), new DHG1.Factory()); + setKeyExchangeFactories(new DHG14.Factory(), new DHG1.Factory(), new DHGexSHA1.Factory(), new DHGexSHA256.Factory()); else - setKeyExchangeFactories(new DHG1.Factory()); + setKeyExchangeFactories(new DHG1.Factory(), new DHGexSHA1.Factory()); } protected void initRandomFactory(boolean bouncyCastleRegistered) { diff --git a/src/main/java/net/schmizz/sshj/transport/digest/SHA256.java b/src/main/java/net/schmizz/sshj/transport/digest/SHA256.java new file mode 100644 index 00000000..94f7c413 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/transport/digest/SHA256.java @@ -0,0 +1,26 @@ +package net.schmizz.sshj.transport.digest; + +/** SHA256 Digest. */ +public class SHA256 extends BaseDigest { + + /** Named factory for SHA256 digest */ + public static class Factory + implements net.schmizz.sshj.common.Factory.Named { + + @Override + public Digest create() { + return new SHA256(); + } + + @Override + public String getName() { + return "sha256"; + } + } + + /** Create a new instance of a SHA256 digest */ + public SHA256() { + super("SHA-256", 32); + } + +} diff --git a/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java b/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java index 4ab304e9..48005ed6 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java @@ -38,21 +38,14 @@ import java.util.Arrays; * Base class for DHG key exchange algorithms. Implementations will only have to configure the required data on the * {@link DH} class in the */ -public abstract class AbstractDHG +public abstract class AbstractDHG extends KeyExchangeBase implements KeyExchange { private final Logger log = LoggerFactory.getLogger(getClass()); - private Transport trans; - private final Digest sha1 = new SHA1(); private final DH dh = new DH(); - private String V_S; - private String V_C; - private byte[] I_S; - private byte[] I_C; - private byte[] H; private PublicKey hostKey; @@ -79,11 +72,7 @@ public abstract class AbstractDHG @Override public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C) throws GeneralSecurityException, TransportException { - this.trans = trans; - this.V_S = V_S; - this.V_C = V_C; - this.I_S = Arrays.copyOf(I_S, I_S.length); - this.I_C = Arrays.copyOf(I_C, I_C.length); + super.init(trans, V_S, V_C, I_S, I_C); sha1.init(); initDH(dh); @@ -112,11 +101,7 @@ public abstract class AbstractDHG dh.computeK(f); - final Buffer.PlainBuffer buf = new Buffer.PlainBuffer() - .putString(V_C) - .putString(V_S) - .putString(I_C) - .putString(I_S) + final Buffer.PlainBuffer buf = initializedBuffer() .putString(K_S) .putMPInt(dh.getE()) .putMPInt(f) diff --git a/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHGex.java b/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHGex.java new file mode 100644 index 00000000..d4864f65 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHGex.java @@ -0,0 +1,124 @@ +package net.schmizz.sshj.transport.kex; + +import net.schmizz.sshj.common.*; +import net.schmizz.sshj.signature.Signature; +import net.schmizz.sshj.transport.Transport; +import net.schmizz.sshj.transport.TransportException; +import net.schmizz.sshj.transport.digest.Digest; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.math.BigInteger; +import java.security.GeneralSecurityException; +import java.security.PublicKey; +import java.util.Arrays; + +public abstract class AbstractDHGex extends KeyExchangeBase { + private final Logger log = LoggerFactory.getLogger(getClass()); + + private Digest digest; + + private int minBits = 1024; + private int maxBits = 8192; + private int preferredBits = 2048; + + private DH dh; + private PublicKey hostKey; + private byte[] H; + + public AbstractDHGex(Digest digest) { + this.digest = digest; + } + + @Override + public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C) throws GeneralSecurityException, TransportException { + super.init(trans, V_S, V_C, I_S, I_C); + dh = new DH(); + digest.init(); + + log.debug("Sending {}", Message.KEX_DH_GEX_REQUEST); + trans.write(new SSHPacket(Message.KEX_DH_GEX_REQUEST).putUInt32(minBits).putUInt32(preferredBits).putUInt32(maxBits)); + } + + @Override + public byte[] getH() { + return Arrays.copyOf(H, H.length); + } + + @Override + public BigInteger getK() { + return dh.getK(); + } + + @Override + public Digest getHash() { + return digest; + } + + @Override + public PublicKey getHostKey() { + return hostKey; + } + + @Override + public boolean next(Message msg, SSHPacket buffer) throws GeneralSecurityException, TransportException { + log.debug("Got message {}", msg); + try { + switch (msg) { + case KEXDH_31: + return parseGexGroup(buffer); + case KEX_DH_GEX_REPLY: + return parseGexReply(buffer); + } + } catch (Buffer.BufferException be) { + throw new TransportException(be); + } + throw new TransportException("Unexpected message " + msg); + } + + private boolean parseGexReply(SSHPacket buffer) throws Buffer.BufferException, GeneralSecurityException, TransportException { + byte[] K_S = buffer.readBytes(); + BigInteger f = buffer.readMPInt(); + byte[] sig = buffer.readBytes(); + hostKey = new Buffer.PlainBuffer(K_S).readPublicKey(); + + dh.computeK(f); + BigInteger k = dh.getK(); + + final Buffer.PlainBuffer buf = initializedBuffer() + .putString(K_S) + .putUInt32(minBits) + .putUInt32(preferredBits) + .putUInt32(maxBits) + .putMPInt(dh.getP()) + .putMPInt(dh.getG()) + .putMPInt(dh.getE()) + .putMPInt(f) + .putMPInt(k); + digest.update(buf.array(), buf.rpos(), buf.available()); + H = digest.digest(); + Signature signature = Factory.Named.Util.create(trans.getConfig().getSignatureFactories(), + KeyType.fromKey(hostKey).toString()); + signature.init(hostKey, null); + signature.update(H, 0, H.length); + if (!signature.verify(sig)) + throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, + "KeyExchange signature verification failed"); + return true; + + } + + private boolean parseGexGroup(SSHPacket buffer) throws Buffer.BufferException, GeneralSecurityException, TransportException { + BigInteger p = buffer.readMPInt(); + BigInteger g = buffer.readMPInt(); + int bitLength = p.bitLength(); + if (bitLength < minBits || bitLength > maxBits) { + throw new GeneralSecurityException("Server generated gex p is out of range (" + bitLength + " bits)"); + } + log.debug("Received server p bitlength {}", bitLength); + dh.init(p, g); + log.debug("Sending {}", Message.KEX_DH_GEX_INIT); + trans.write(new SSHPacket(Message.KEX_DH_GEX_INIT).putMPInt(dh.getE())); + return false; + } +} diff --git a/src/main/java/net/schmizz/sshj/transport/kex/DH.java b/src/main/java/net/schmizz/sshj/transport/kex/DH.java index 3bdea282..2a4acede 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/DH.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/DH.java @@ -73,4 +73,11 @@ public class DH { return K; } + public BigInteger getP() { + return p; + } + + public BigInteger getG() { + return g; + } } diff --git a/src/main/java/net/schmizz/sshj/transport/kex/DHGexSHA1.java b/src/main/java/net/schmizz/sshj/transport/kex/DHGexSHA1.java new file mode 100644 index 00000000..61f66e20 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/transport/kex/DHGexSHA1.java @@ -0,0 +1,25 @@ +package net.schmizz.sshj.transport.kex; + +import net.schmizz.sshj.transport.digest.SHA1; + +public class DHGexSHA1 extends AbstractDHGex { + + /** Named factory for DHGexSHA1 key exchange */ + public static class Factory + implements net.schmizz.sshj.common.Factory.Named { + + @Override + public KeyExchange create() { + return new DHGexSHA1(); + } + + @Override + public String getName() { + return "diffie-hellman-group-exchange-sha1"; + } + } + + public DHGexSHA1() { + super(new SHA1()); + } +} diff --git a/src/main/java/net/schmizz/sshj/transport/kex/DHGexSHA256.java b/src/main/java/net/schmizz/sshj/transport/kex/DHGexSHA256.java new file mode 100644 index 00000000..250443b3 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/transport/kex/DHGexSHA256.java @@ -0,0 +1,25 @@ +package net.schmizz.sshj.transport.kex; + +import net.schmizz.sshj.transport.digest.SHA256; + +public class DHGexSHA256 extends AbstractDHGex { + + /** Named factory for DHGexSHA256 key exchange */ + public static class Factory + implements net.schmizz.sshj.common.Factory.Named { + + @Override + public KeyExchange create() { + return new DHGexSHA256(); + } + + @Override + public String getName() { + return "diffie-hellman-group-exchange-sha256"; + } + } + + public DHGexSHA256() { + super(new SHA256()); + } +} diff --git a/src/main/java/net/schmizz/sshj/transport/kex/KeyExchangeBase.java b/src/main/java/net/schmizz/sshj/transport/kex/KeyExchangeBase.java new file mode 100644 index 00000000..bad31a49 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/transport/kex/KeyExchangeBase.java @@ -0,0 +1,37 @@ +package net.schmizz.sshj.transport.kex; + +import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.transport.Transport; +import net.schmizz.sshj.transport.TransportException; + +import java.security.GeneralSecurityException; +import java.util.Arrays; + +/** + * Created by ajvanerp on 29/10/15. + */ +public abstract class KeyExchangeBase implements KeyExchange { + protected Transport trans; + + private String V_S; + private String V_C; + private byte[] I_S; + private byte[] I_C; + + @Override + public void init(Transport trans, String V_S, String V_C, byte[] I_S, byte[] I_C) throws GeneralSecurityException, TransportException { + this.trans = trans; + this.V_S = V_S; + this.V_C = V_C; + this.I_S = Arrays.copyOf(I_S, I_S.length); + this.I_C = Arrays.copyOf(I_C, I_C.length); + } + + protected Buffer.PlainBuffer initializedBuffer() { + return new Buffer.PlainBuffer() + .putString(V_C) + .putString(V_S) + .putString(I_C) + .putString(I_S); + } +} diff --git a/src/test/java/com/hierynomus/sshj/test/SshFixture.java b/src/test/java/com/hierynomus/sshj/test/SshFixture.java index 4f18b970..3e655c20 100644 --- a/src/test/java/com/hierynomus/sshj/test/SshFixture.java +++ b/src/test/java/com/hierynomus/sshj/test/SshFixture.java @@ -158,4 +158,8 @@ public class SshFixture extends ExternalResource { } } } + + public SshServer getServer() { + return server; + } } diff --git a/src/test/java/com/hierynomus/sshj/transport/kex/DiffieHellmanGroupExchangeTest.java b/src/test/java/com/hierynomus/sshj/transport/kex/DiffieHellmanGroupExchangeTest.java new file mode 100644 index 00000000..9266b1bb --- /dev/null +++ b/src/test/java/com/hierynomus/sshj/transport/kex/DiffieHellmanGroupExchangeTest.java @@ -0,0 +1,45 @@ +package com.hierynomus.sshj.transport.kex; + +import com.hierynomus.sshj.test.SshFixture; +import net.schmizz.sshj.SSHClient; +import org.apache.sshd.common.KeyExchange; +import org.apache.sshd.common.NamedFactory; +import org.apache.sshd.server.kex.DHGEX; +import org.apache.sshd.server.kex.DHGEX256; +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + +import static org.hamcrest.MatcherAssert.assertThat; + +public class DiffieHellmanGroupExchangeTest { + @Rule + public SshFixture fixture = new SshFixture(false); + + @After + public void stopServer() { + fixture.stopServer(); + } + + @Test + public void shouldKexWithGroupExchangeSha1() throws IOException { + setupAndCheckKex(new DHGEX.Factory()); + } + + @Test + public void shouldKexWithGroupExchangeSha256() throws IOException { + setupAndCheckKex(new DHGEX256.Factory()); + } + + private void setupAndCheckKex(NamedFactory factory) throws IOException { + fixture.getServer().setKeyExchangeFactories(Collections.singletonList(factory)); + fixture.start(); + SSHClient sshClient = fixture.setupConnectedDefaultClient(); + assertThat("should be connected", sshClient.isConnected()); + sshClient.disconnect(); + } +}