Implement AES-GCM cipher support (#630)

* Implement AES-GCM cipher support

Fixes #217.

A port of AES-GCM cipher support from Apache MINA-SSHD, based on https://github.com/apache/mina-sshd/pull/132.

Included tests for decoding SSH packets sent from Apache MINA-SSHD and OpenSSH (Version 7.9p1 as used by Debian 10).

Manual tests also done on OpenSSH server 7.9p1 running Debian 10 with its available ciphers, including 3des-cbc, aes128-cbc, aes192-cbc, aes256-cbc, aes128-ctr, aes192-ctr, aes256-ctr, aes128-gcm@openssh.com and aes256-gcm@openssh.com.

* Changes per PR feedback

- Fixed variable/statement whitespaces and add back missing braces per coding standard requirement
- Moved Buffer.putLong() and Buffer.getLong() into GcmCipher.CounterGCMParameterSpec since it's the only user
- Moved BaseCipher.authSize into GcmCipher since it is the only cipher that would return a non-zero. BaseCipher will keep return 0 instead
- Made BaseCipher.cipher protected instead of making it publicly accessible
- Combined the three decoding modes in Decoder.decode() into one single method, to reduce code duplication
- Added integration test for the ciphers, along with the newly implemented AES-GCM ciphers
This commit is contained in:
Raymond Lai
2020-09-09 15:51:17 +08:00
committed by GitHub
parent 4458332cbf
commit 143069e3e0
33 changed files with 722 additions and 82 deletions

View File

@@ -18,6 +18,7 @@ package net.schmizz.sshj;
import com.hierynomus.sshj.key.KeyAlgorithm;
import com.hierynomus.sshj.key.KeyAlgorithms;
import com.hierynomus.sshj.transport.cipher.BlockCiphers;
import com.hierynomus.sshj.transport.cipher.GcmCiphers;
import com.hierynomus.sshj.transport.cipher.StreamCiphers;
import com.hierynomus.sshj.transport.kex.DHGroups;
import com.hierynomus.sshj.transport.kex.ExtInfoClientFactory;
@@ -171,6 +172,8 @@ public class DefaultConfig
BlockCiphers.AES192CTR(),
BlockCiphers.AES256CBC(),
BlockCiphers.AES256CTR(),
GcmCiphers.AES128GCM(),
GcmCiphers.AES256GCM(),
BlockCiphers.BlowfishCBC(),
BlockCiphers.BlowfishCTR(),
BlockCiphers.Cast128CBC(),

View File

@@ -45,6 +45,7 @@ abstract class Converter {
protected long seq = -1;
protected boolean authed;
protected boolean etm;
protected boolean authMode;
long getSequenceNumber() {
return seq;
@@ -57,7 +58,11 @@ abstract class Converter {
if (compression != null)
compression.init(getCompressionType());
this.cipherSize = cipher.getIVSize();
this.etm = mac.isEtm();
this.etm = this.mac != null && mac.isEtm();
if(cipher.getAuthenticationTagSize() > 0) {
this.cipherSize = cipher.getAuthenticationTagSize();
this.authMode = true;
}
}
void setAuthenticated() {

View File

@@ -70,87 +70,41 @@ final class Decoder
*
* @return number of bytes needed before further decoding possible
*/
private int decode()
throws SSHException {
if (etm) {
return decodeEtm();
} else {
return decodeMte();
}
}
/**
* Decode an Encrypt-Then-Mac packet.
*/
private int decodeEtm() throws SSHException {
int bytesNeeded;
while (true) {
if (packetLength == -1) {
assert inputBuffer.rpos() == 0 : "buffer cleared";
bytesNeeded = 4 - inputBuffer.available();
if (bytesNeeded <= 0) {
// In Encrypt-Then-Mac, the packetlength is sent unencrypted.
packetLength = inputBuffer.readUInt32AsInt();
checkPacketLength(packetLength);
} else {
// Needs more data
break;
}
} else {
assert inputBuffer.rpos() == 4 : "packet length read";
bytesNeeded = packetLength + mac.getBlockSize() - inputBuffer.available();
if (bytesNeeded <= 0) {
seq = seq + 1 & 0xffffffffL;
checkMAC(inputBuffer.array());
decryptBuffer(4, packetLength);
inputBuffer.wpos(packetLength + 4 - inputBuffer.readByte());
final SSHPacket plain = usingCompression() ? decompressed() : inputBuffer;
if (log.isTraceEnabled()) {
log.trace("Received packet #{}: {}", seq, plain.printHex());
}
packetHandler.handle(plain.readMessageID(), plain); // Process the decoded packet
inputBuffer.clear();
packetLength = -1;
} else {
// Needs more data
break;
}
}
}
return bytesNeeded;
}
/**
* Decode a Mac-Then-Encrypt packet
* @return
* @throws SSHException
*/
private int decodeMte() throws SSHException {
private int decode() throws SSHException {
int need;
/* Decoding loop */
for (; ; )
for(;;) {
if (packetLength == -1) { // Waiting for beginning of packet
assert inputBuffer.rpos() == 0 : "buffer cleared";
need = cipherSize - inputBuffer.available();
if (need <= 0) {
packetLength = decryptLength();
if (authMode) {
packetLength = decryptLengthAAD();
} else if (etm) {
packetLength = inputBuffer.readUInt32AsInt();
checkPacketLength(packetLength);
} else {
packetLength = decryptLength();
}
} else {
// Need more data
break;
}
} else {
assert inputBuffer.rpos() == 4 : "packet length read";
need = packetLength + (mac != null ? mac.getBlockSize() : 0) - inputBuffer.available();
need = (authMode) ? packetLength + cipherSize - inputBuffer.available() : packetLength + (mac != null ? mac.getBlockSize() : 0) - inputBuffer.available();
if (need <= 0) {
decryptBuffer(cipherSize, packetLength + 4 - cipherSize); // Decrypt the rest of the payload
seq = seq + 1 & 0xffffffffL;
if (mac != null) {
if (authMode) {
cipher.update(inputBuffer.array(), 4, packetLength);
} else if (etm) {
checkMAC(inputBuffer.array());
decryptBuffer(4, packetLength);
} else {
decryptBuffer(cipherSize, packetLength + 4 - cipherSize); // Decrypt the rest of the payload
if (mac != null) {
checkMAC(inputBuffer.array());
}
}
// Exclude the padding & MAC
inputBuffer.wpos(packetLength + 4 - inputBuffer.readByte());
final SSHPacket plain = usingCompression() ? decompressed() : inputBuffer;
if (log.isTraceEnabled()) {
@@ -160,16 +114,20 @@ final class Decoder
inputBuffer.clear();
packetLength = -1;
} else {
// Need more data
// Needs more data
break;
}
}
}
return need;
}
private void checkMAC(final byte[] data)
throws TransportException {
if (mac == null) {
return;
}
mac.update(seq); // seq num
mac.update(data, 0, packetLength + 4); // packetLength+4 = entire packet w/o mac
mac.doFinal(macResult, 0); // compute
@@ -186,6 +144,20 @@ final class Decoder
return uncompressBuffer;
}
private int decryptLengthAAD() throws TransportException {
cipher.updateAAD(inputBuffer.array(), 0, 4);
final int len;
try {
len = inputBuffer.readUInt32AsInt();
} catch (Buffer.BufferException be) {
throw new TransportException(be);
}
checkPacketLength(len);
return len;
}
private int decryptLength()
throws TransportException {
decryptBuffer(0, cipherSize);
@@ -237,7 +209,9 @@ final class Decoder
@Override
void setAlgorithms(Cipher cipher, MAC mac, Compression compression) {
super.setAlgorithms(cipher, mac, compression);
macResult = new byte[mac.getBlockSize()];
if (mac != null) {
macResult = new byte[mac.getBlockSize()];
}
}
@Override

View File

@@ -15,6 +15,7 @@
*/
package net.schmizz.sshj.transport;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.LoggerFactory;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.transport.cipher.Cipher;
@@ -83,7 +84,7 @@ final class Encoder
// Compute padding length
int padLen = cipherSize - (lengthWithoutPadding % cipherSize);
if (padLen < 4) {
if (padLen < 4 || (authMode && padLen < cipherSize)) {
padLen += cipherSize;
}
@@ -94,6 +95,14 @@ final class Encoder
padLen += cipherSize;
packetLen = 1 + payloadSize + padLen;
}
/*
* In AES-GCM ciphers, they require packets must be a multiple of 16 bytes (which is also block size of AES)
* as mentioned in RFC5647 Section 7.2. So we are calculating the extra padding as necessary here
*/
if (authMode && packetLen % cipherSize != 0) {
padLen += cipherSize - (packetLen % cipherSize);
packetLen = 1 + payloadSize + padLen;
}
final int endOfPadding = startOfPacket + 4 + packetLen;
@@ -101,6 +110,7 @@ final class Encoder
buffer.wpos(startOfPacket);
buffer.putUInt32(packetLen);
buffer.putByte((byte) padLen);
// Now wpos will mark end of padding
buffer.wpos(endOfPadding);
@@ -109,14 +119,17 @@ final class Encoder
seq = seq + 1 & 0xffffffffL;
if (etm) {
if (authMode) {
int wpos = buffer.wpos();
buffer.wpos(wpos + cipherSize);
aeadOutgoingBuffer(buffer, startOfPacket, packetLen);
} else if (etm) {
cipher.update(buffer.array(), startOfPacket + 4, packetLen);
putMAC(buffer, startOfPacket, endOfPadding);
} else {
if (mac != null) {
putMAC(buffer, startOfPacket, endOfPadding);
}
cipher.update(buffer.array(), startOfPacket, 4 + packetLen);
}
buffer.rpos(startOfPacket); // Make ready-to-read
@@ -127,6 +140,14 @@ final class Encoder
}
}
protected void aeadOutgoingBuffer(Buffer buf, int offset, int len) {
if (cipher == null || cipher.getAuthenticationTagSize() == 0) {
throw new IllegalArgumentException("AEAD mode requires an AEAD cipher");
}
byte[] data = buf.array();
cipher.updateWithAAD(data, offset, 4, len);
}
@Override
void setAlgorithms(Cipher cipher, MAC mac, Compression compression) {
encodeLock.lock();

View File

@@ -15,7 +15,6 @@
*/
package net.schmizz.sshj.transport;
import com.hierynomus.sshj.key.KeyAlgorithm;
import net.schmizz.concurrent.ErrorDeliveryUtil;
import net.schmizz.concurrent.Event;
import net.schmizz.sshj.common.*;
@@ -323,13 +322,25 @@ final class KeyExchanger
resizedKey(encryptionKey_S2C, cipher_S2C.getBlockSize(), hash, kex.getK(), kex.getH()),
initialIV_S2C);
final MAC mac_C2S = Factory.Named.Util.create(transport.getConfig().getMACFactories(), negotiatedAlgs
.getClient2ServerMACAlgorithm());
mac_C2S.init(resizedKey(integrityKey_C2S, mac_C2S.getBlockSize(), hash, kex.getK(), kex.getH()));
/*
* For AES-GCM ciphers, MAC will also be AES-GCM, so it is handled by the cipher itself.
* In that case, both s2c and c2s MACs are ignored.
*
* Refer to RFC5647 Section 5.1
*/
MAC mac_C2S = null;
if(cipher_C2S.getAuthenticationTagSize() == 0) {
mac_C2S = Factory.Named.Util.create(transport.getConfig().getMACFactories(), negotiatedAlgs
.getClient2ServerMACAlgorithm());
mac_C2S.init(resizedKey(integrityKey_C2S, mac_C2S.getBlockSize(), hash, kex.getK(), kex.getH()));
}
final MAC mac_S2C = Factory.Named.Util.create(transport.getConfig().getMACFactories(),
negotiatedAlgs.getServer2ClientMACAlgorithm());
mac_S2C.init(resizedKey(integrityKey_S2C, mac_S2C.getBlockSize(), hash, kex.getK(), kex.getH()));
MAC mac_S2C = null;
if(cipher_S2C.getAuthenticationTagSize() == 0) {
mac_S2C = Factory.Named.Util.create(transport.getConfig().getMACFactories(),
negotiatedAlgs.getServer2ClientMACAlgorithm());
mac_S2C.init(resizedKey(integrityKey_S2C, mac_S2C.getBlockSize(), hash, kex.getK(), kex.getH()));
}
final Compression compression_S2C =
Factory.Named.Util.create(transport.getConfig().getCompressionFactories(),

View File

@@ -42,7 +42,7 @@ public abstract class BaseCipher
private final String algorithm;
private final String transformation;
private javax.crypto.Cipher cipher;
protected javax.crypto.Cipher cipher;
public BaseCipher(int ivsize, int bsize, String algorithm, String transformation) {
this.ivsize = ivsize;
@@ -61,6 +61,11 @@ public abstract class BaseCipher
return ivsize;
}
@Override
public int getAuthenticationTagSize() {
return 0;
}
@Override
public void init(Mode mode, byte[] key, byte[] iv) {
key = BaseCipher.resize(key, bsize);
@@ -75,6 +80,7 @@ public abstract class BaseCipher
}
protected abstract void initCipher(javax.crypto.Cipher cipher, Mode mode, byte[] key, byte[] iv) throws InvalidKeyException, InvalidAlgorithmParameterException;
protected SecretKeySpec getKeySpec(byte[] key) {
return new SecretKeySpec(key, algorithm);
}
@@ -92,4 +98,19 @@ public abstract class BaseCipher
}
}
@Override
public void updateAAD(byte[] data, int offset, int length) {
throw new UnsupportedOperationException(getClass() + " does not support AAD operations");
}
@Override
public void updateAAD(byte[] data) {
updateAAD(data, 0, data.length);
}
@Override
public void updateWithAAD(byte[] input, int offset, int aadLen, int inputLen) {
updateAAD(input, offset, aadLen);
update(input, offset + aadLen, inputLen);
}
}

View File

@@ -29,6 +29,9 @@ public interface Cipher {
/** @return the size of the initialization vector */
int getIVSize();
/** @return Size of the authentication tag (AT) in bytes or 0 if this cipher does not support authentication */
int getAuthenticationTagSize();
/**
* Initialize the cipher for encryption or decryption with the given private key and initialization vector
*
@@ -47,4 +50,32 @@ public interface Cipher {
*/
void update(byte[] input, int inputOffset, int inputLen);
/**
* Adds the provided input data as additional authenticated data during encryption or decryption.
*
* @param data The additional data to authenticate
* @param offset The offset of the additional data in the buffer
* @param length The number of bytes in the buffer to use for authentication
*/
void updateAAD(byte[] data, int offset, int length);
/**
* Adds the provided input data as additional authenticated data during encryption or decryption.
*
* @param data The data to authenticate
*/
void updateAAD(byte[] data);
/**
* Performs in-place authenticated encryption or decryption with additional data (AEAD). Authentication tags are
* implicitly appended after the output ciphertext or implicitly verified after the input ciphertext. Header data
* indicated by the {@code aadLen} parameter are authenticated but not encrypted/decrypted, while payload data
* indicated by the {@code inputLen} parameter are authenticated and encrypted/decrypted.
*
* @param input The input/output bytes
* @param offset The offset of the data in the input buffer
* @param aadLen The number of bytes to use as additional authenticated data - starting at offset
* @param inputLen The number of bytes to update - starting at offset + aadLen
*/
void updateWithAAD(byte[] input, int offset, int aadLen, int inputLen);
}

View File

@@ -44,6 +44,11 @@ public class NoneCipher
return 8;
}
@Override
public int getAuthenticationTagSize() {
return 0;
}
@Override
public void init(Mode mode, byte[] bytes, byte[] bytes1) {
// Nothing to do
@@ -54,4 +59,18 @@ public class NoneCipher
// Nothing to do
}
@Override
public void updateAAD(byte[] data, int offset, int length) {
}
@Override
public void updateAAD(byte[] data) {
}
@Override
public void updateWithAAD(byte[] input, int offset, int aadLen, int inputLen) {
}
}