diff --git a/src/main/java/com/hierynomus/sshj/common/KeyDecryptionFailedException.java b/src/main/java/com/hierynomus/sshj/common/KeyDecryptionFailedException.java new file mode 100644 index 00000000..a889a73d --- /dev/null +++ b/src/main/java/com/hierynomus/sshj/common/KeyDecryptionFailedException.java @@ -0,0 +1,38 @@ +/* + * Copyright (C)2009 - SSHJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.sshj.common; + +import org.bouncycastle.openssl.EncryptionException; + +import java.io.IOException; + +/** + * Thrown when a key file could not be decrypted correctly, e.g. if its checkInts differed in the case of an OpenSSH + * key file. + */ +public class KeyDecryptionFailedException extends IOException { + + public static final String MESSAGE = "Decryption of the key failed. A supplied passphrase may be incorrect."; + + public KeyDecryptionFailedException() { + super(MESSAGE); + } + + public KeyDecryptionFailedException(EncryptionException cause) { + super(MESSAGE, cause); + } + +} diff --git a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java index b2b8b46e..620fe30a 100644 --- a/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java +++ b/src/main/java/com/hierynomus/sshj/userauth/keyprovider/OpenSSHKeyV1KeyFile.java @@ -15,6 +15,7 @@ */ package com.hierynomus.sshj.userauth.keyprovider; +import com.hierynomus.sshj.common.KeyDecryptionFailedException; import com.hierynomus.sshj.transport.cipher.BlockCiphers; import net.i2p.crypto.eddsa.EdDSAPrivateKey; import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable; @@ -111,8 +112,16 @@ public class OpenSSHKeyV1KeyFile extends BaseFileKeyProvider { return readUnencrypted(privateKeyBuffer, publicKey); } else { logger.info("Keypair is encrypted with: " + cipherName + ", " + kdfName + ", " + Arrays.toString(kdfOptions)); - PlainBuffer decrypted = decryptBuffer(privateKeyBuffer, cipherName, kdfName, kdfOptions); - return readUnencrypted(decrypted, publicKey); + while (true) { + PlainBuffer decryptionBuffer = new PlainBuffer(privateKeyBuffer); + PlainBuffer decrypted = decryptBuffer(decryptionBuffer, cipherName, kdfName, kdfOptions); + try { + return readUnencrypted(decrypted, publicKey); + } catch (KeyDecryptionFailedException e) { + if (pwdf == null || !pwdf.shouldRetry(resource)) + throw e; + } + } // throw new IOException("Cannot read encrypted keypair with " + cipherName + " yet."); } } @@ -184,7 +193,7 @@ public class OpenSSHKeyV1KeyFile extends BaseFileKeyProvider { int checkInt1 = keyBuffer.readUInt32AsInt(); // uint32 checkint1 int checkInt2 = keyBuffer.readUInt32AsInt(); // uint32 checkint2 if (checkInt1 != checkInt2) { - throw new IOException("The checkInts differed, the key was not correctly decoded."); + throw new KeyDecryptionFailedException(); } // The private key section contains both the public key and the private key String keyType = keyBuffer.readString(); // string keytype diff --git a/src/main/java/net/schmizz/sshj/userauth/keyprovider/PKCS8KeyFile.java b/src/main/java/net/schmizz/sshj/userauth/keyprovider/PKCS8KeyFile.java index d7b42af3..0ab4d960 100644 --- a/src/main/java/net/schmizz/sshj/userauth/keyprovider/PKCS8KeyFile.java +++ b/src/main/java/net/schmizz/sshj/userauth/keyprovider/PKCS8KeyFile.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj.userauth.keyprovider; +import com.hierynomus.sshj.common.KeyDecryptionFailedException; import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.common.SecurityUtils; import net.schmizz.sshj.userauth.password.PasswordUtils; @@ -85,7 +86,7 @@ public class PKCS8KeyFile extends BaseFileKeyProvider { if (pwdf != null && pwdf.shouldRetry(resource)) continue; else - throw e; + throw new KeyDecryptionFailedException(e); } finally { IOUtils.closeQuietly(r); } diff --git a/src/test/java/net/schmizz/sshj/keyprovider/OpenSSHKeyFileTest.java b/src/test/java/net/schmizz/sshj/keyprovider/OpenSSHKeyFileTest.java index 9e29bc49..8457521d 100644 --- a/src/test/java/net/schmizz/sshj/keyprovider/OpenSSHKeyFileTest.java +++ b/src/test/java/net/schmizz/sshj/keyprovider/OpenSSHKeyFileTest.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj.keyprovider; +import com.hierynomus.sshj.common.KeyDecryptionFailedException; import com.hierynomus.sshj.userauth.certificate.Certificate; import com.hierynomus.sshj.userauth.keyprovider.OpenSSHKeyV1KeyFile; import net.schmizz.sshj.common.KeyType; @@ -200,12 +201,34 @@ public class OpenSSHKeyFileTest { @Test public void shouldLoadProtectedED25519PrivateKeyAes256CTR() throws IOException { - checkOpenSSHKeyV1("src/test/resources/keytypes/ed25519_protected", "sshjtest"); + checkOpenSSHKeyV1("src/test/resources/keytypes/ed25519_protected", "sshjtest", false); + checkOpenSSHKeyV1("src/test/resources/keytypes/ed25519_protected", "sshjtest", true); } @Test public void shouldLoadProtectedED25519PrivateKeyAes256CBC() throws IOException { - checkOpenSSHKeyV1("src/test/resources/keytypes/ed25519_aes256cbc.pem", "foobar"); + checkOpenSSHKeyV1("src/test/resources/keytypes/ed25519_aes256cbc.pem", "foobar", false); + checkOpenSSHKeyV1("src/test/resources/keytypes/ed25519_aes256cbc.pem", "foobar", true); + } + + @Test(expected = KeyDecryptionFailedException.class) + public void shouldFailOnIncorrectPassphraseAfterRetries() throws IOException { + OpenSSHKeyV1KeyFile keyFile = new OpenSSHKeyV1KeyFile(); + keyFile.init(new File("src/test/resources/keytypes/ed25519_aes256cbc.pem"), new PasswordFinder() { + private int reqCounter = 0; + + @Override + public char[] reqPassword(Resource resource) { + reqCounter++; + return "incorrect".toCharArray(); + } + + @Override + public boolean shouldRetry(Resource resource) { + return reqCounter <= 3; + } + }); + keyFile.getPrivate(); } @Test @@ -224,17 +247,25 @@ public class OpenSSHKeyFileTest { assertThat(aPrivate.getAlgorithm(), equalTo("ECDSA")); } - private void checkOpenSSHKeyV1(String key, final String password) throws IOException { + private void checkOpenSSHKeyV1(String key, final String password, boolean withRetry) throws IOException { OpenSSHKeyV1KeyFile keyFile = new OpenSSHKeyV1KeyFile(); keyFile.init(new File(key), new PasswordFinder() { + private int reqCounter = 0; + @Override public char[] reqPassword(Resource resource) { - return password.toCharArray(); + if (withRetry && reqCounter < 3) { + reqCounter++; + // Return an incorrect password three times before returning the correct one. + return (password + "incorrect").toCharArray(); + } else { + return password.toCharArray(); + } } @Override public boolean shouldRetry(Resource resource) { - return false; + return withRetry && reqCounter <= 3; } }); PrivateKey aPrivate = keyFile.getPrivate();