diff --git a/src/main/java/net/schmizz/sshj/transport/verification/ConsoleKnownHostsVerifier.java b/src/main/java/net/schmizz/sshj/transport/verification/ConsoleKnownHostsVerifier.java
index ca96eb59..1ed2396e 100644
--- a/src/main/java/net/schmizz/sshj/transport/verification/ConsoleKnownHostsVerifier.java
+++ b/src/main/java/net/schmizz/sshj/transport/verification/ConsoleKnownHostsVerifier.java
@@ -46,8 +46,8 @@ public class ConsoleKnownHostsVerifier
response = console.readLine("Please explicitly enter yes/no: ");
}
if (response.equalsIgnoreCase(YES)) {
- entries().add(new Entry(hostname, key));
try {
+ entries().add(new SimpleEntry(hostname, key));
write();
} catch (IOException e) {
throw new RuntimeException(e);
diff --git a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java
index 8fec41b8..dd30a782 100644
--- a/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java
+++ b/src/main/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHosts.java
@@ -37,7 +37,6 @@ package net.schmizz.sshj.transport.verification;
import net.schmizz.sshj.common.Base64;
import net.schmizz.sshj.common.Buffer;
-import net.schmizz.sshj.common.ByteArrayUtils;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.KeyType;
import net.schmizz.sshj.common.SSHException;
@@ -57,118 +56,66 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-// TODO: allow modifications to known_hosts e.g. adding entries
-
/**
* A {@link HostKeyVerifier} implementation for a {@code known_hosts} file i.e. in the format used by OpenSSH.
*
- * Hashed hostnames are correctly handled.
+ * Hashed hostnames are correctly handled (but not yet for writing back).
*
* @see Hashed hostnames spec
*/
public class OpenSSHKnownHosts
implements HostKeyVerifier {
- private static final String LS = System.getProperty("line.separator");
+ public static interface Entry {
- /** Represents a single line */
- public static class Entry {
+ KeyType getType();
- private final MAC sha1 = new HMACSHA1();
+ boolean appliesTo(String host)
+ throws IOException;
- private final List hosts;
- private final KeyType type;
+ PublicKey getKey()
+ throws IOException;
+ String getLine();
+
+ }
+
+ public static abstract class BaseEntry
+ implements Entry {
+
+ private KeyType type;
private PublicKey key;
private String sKey;
- /** Construct an entry from the hostname and public key */
- public Entry(String host, PublicKey key) {
- this.key = key;
- this.hosts = Arrays.asList(host);
- type = KeyType.fromKey(key);
- }
-
- /**
- * Construct an entry from a string containing the line
- *
- * @param line the line from a known_hosts file
- *
- * @throws SSHException if it could not be parsed for any reason
- */
- public Entry(String line)
+ protected void init(PublicKey key)
throws SSHException {
- String[] parts = line.split(" ");
- if (parts.length != 3)
- throw new SSHException("Line parts not 3: " + line);
- hosts = Arrays.asList(parts[0].split(","));
- type = KeyType.fromString(parts[1]);
+ this.key = key;
+ this.type = KeyType.fromKey(key);
if (type == KeyType.UNKNOWN)
- throw new SSHException("Unknown key type: " + parts[1]);
- sKey = parts[2];
+ throw new SSHException("Unknown key type for key: " + key);
}
- /** Checks whether this entry is applicable to some {@code hostname} */
- public boolean appliesTo(String hostname)
- throws IOException {
- if (!hosts.isEmpty() && hosts.get(0).startsWith("|1|")) { // Hashed hostname
- final String[] splitted = hosts.get(0).split("\\|");
- if (splitted.length != 4)
- return false;
-
- final byte[] salt = Base64.decode(splitted[2]);
- if (salt.length != 20)
- return false;
- sha1.init(salt);
-
- final byte[] host = Base64.decode(splitted[3]);
- if (ByteArrayUtils.equals(host, sha1.doFinal(hostname.getBytes())))
- return true;
- } else
- // Un-hashed, possibly comma-delimited
- for (String host : hosts)
- if (host.equals(hostname))
- return true;
- return false;
- }
-
- /**
- * Returns the public host key represented in this entry.
- *
- * The key is cached so repeated calls to this method may be made without concern.
- *
- * @return the host key
- */
- public PublicKey getKey() {
- if (key == null) {
- byte[] decoded;
- try {
- decoded = Base64.decode(sKey);
- } catch (IOException e) {
- return null;
- }
- key = new Buffer.PlainBuffer(decoded).readPublicKey();
- }
- return key;
+ protected void init(String typeString, String keyString)
+ throws SSHException {
+ this.sKey = keyString;
+ this.type = KeyType.fromString(typeString);
+ if (type == KeyType.UNKNOWN)
+ throw new SSHException("Unknown key type: " + typeString);
}
public KeyType getType() {
return type;
}
- public String getLine() {
- StringBuilder line = new StringBuilder();
- for (String host : hosts) {
- if (line.length() > 0)
- line.append(",");
- line.append(host);
+ public PublicKey getKey()
+ throws IOException {
+ if (key == null) {
+ key = new Buffer.PlainBuffer(Base64.decode(sKey)).readPublicKey();
}
- line.append(" ").append(type.toString());
- line.append(" ").append(getKeyString());
- return line.toString();
+ return key;
}
- private String getKeyString() {
+ protected String getKeyString() {
if (sKey == null) {
final Buffer.PlainBuffer buf = new Buffer.PlainBuffer().putPublicKey(key);
sKey = Base64.encodeBytes(buf.array(), buf.rpos(), buf.available());
@@ -176,14 +123,136 @@ public class OpenSSHKnownHosts
return sKey;
}
+ public String getLine() {
+ final StringBuilder line = new StringBuilder();
+ line.append(getHostPart());
+ line.append(" ").append(type.toString());
+ line.append(" ").append(getKeyString());
+ return line.toString();
+ }
+
@Override
public String toString() {
- return "Entry{hostnames=" + hosts + "; type=" + type + "; key=" + getKey() + "}";
+ return "KnownHostsEntry{host=" + getHostPart() + "; type=" + type + "}";
+ }
+
+ protected abstract String getHostPart();
+
+ }
+
+ public static class SimpleEntry
+ extends BaseEntry {
+
+ private final List hosts;
+
+ public SimpleEntry(String host, PublicKey key)
+ throws SSHException {
+ this(Arrays.asList(host), key);
+ }
+
+ public SimpleEntry(List hosts, PublicKey key)
+ throws SSHException {
+ this.hosts = hosts;
+ init(key);
+ }
+
+ public SimpleEntry(String line)
+ throws SSHException {
+ final String[] parts = line.split(" ");
+ if (parts.length != 3)
+ throw new SSHException("Line parts not 3: " + line);
+ hosts = Arrays.asList(parts[0].split(","));
+ init(parts[1], parts[2]);
+ }
+
+ public boolean appliesTo(String host) {
+ for (String h : hosts)
+ if (host.equals(h))
+ return true;
+ return false;
+ }
+
+ protected String getHostPart() {
+ final StringBuilder sb = new StringBuilder();
+ for (String host : hosts) {
+ if (sb.length() > 0) // a host already in there
+ sb.append(",");
+ sb.append(host);
+ }
+ return sb.toString();
}
}
- private final Logger log = LoggerFactory.getLogger(getClass());
+ public static class HashedEntry
+ extends BaseEntry {
+
+ private final MAC sha1 = new HMACSHA1();
+
+ private String salt;
+ private byte[] saltyBytes;
+
+ private final String hashedHost;
+
+ public HashedEntry(String host, PublicKey key)
+ throws IOException {
+ this.hashedHost = hashHost(host);
+ init(key);
+ {
+ saltyBytes = new byte[sha1.getBlockSize()];
+ new java.util.Random().nextBytes(saltyBytes);
+ }
+ }
+
+ public HashedEntry(String line)
+ throws IOException {
+ final String[] parts = line.split(" ");
+ if (parts.length != 3)
+ throw new SSHException("Line parts not 3: " + line);
+ hashedHost = parts[0];
+ init(parts[1], parts[2]);
+ {
+ final String[] split = hashedHost.split("\\|");
+ if (split.length != 4)
+ throw new SSHException("Unrecognized format for hashed hostname");
+ salt = split[2];
+ }
+ }
+
+ public boolean appliesTo(String host)
+ throws IOException {
+ return hashedHost.equals(hashHost(host));
+ }
+
+ private String hashHost(String host)
+ throws IOException {
+ sha1.init(getSaltyBytes());
+ return "|1|" + getSalt() + "|" + Base64.encodeBytes(sha1.doFinal(host.getBytes()));
+ }
+
+ private byte[] getSaltyBytes()
+ throws IOException {
+ if (saltyBytes == null) {
+ saltyBytes = Base64.decode(salt);
+ }
+ return saltyBytes;
+ }
+
+ private String getSalt()
+ throws IOException {
+ if (salt == null) {
+ salt = Base64.encodeBytes(saltyBytes);
+ }
+ return salt;
+ }
+
+ protected String getHostPart() {
+ return hashedHost;
+ }
+
+ }
+
+ protected final Logger log = LoggerFactory.getLogger(getClass());
protected final File khFile;
protected final List entries = new ArrayList();
@@ -199,13 +268,13 @@ public class OpenSSHKnownHosts
throws IOException {
this.khFile = khFile;
if (khFile.exists()) {
- BufferedReader br = new BufferedReader(new FileReader(khFile));
- String line;
+ final BufferedReader br = new BufferedReader(new FileReader(khFile));
try {
// Read in the file, storing each line as an entry
+ String line;
while ((line = br.readLine()) != null)
try {
- entries.add(new Entry(line));
+ entries.add(isHashed(line) ? new HashedEntry(line) : new SimpleEntry(line));
} catch (SSHException ignore) {
log.debug("Bad line ({}): {} ", ignore.toString(), line);
}
@@ -215,13 +284,12 @@ public class OpenSSHKnownHosts
}
}
- /**
- * Checks whether the specified host is known per the contents of the {@code known_hosts} file.
- *
- * @return {@code true} on successful verification or {@code false} on failure
- */
+ public File getFile() {
+ return khFile;
+ }
+
public boolean verify(final String hostname, final int port, final PublicKey key) {
- KeyType type = KeyType.fromKey(key);
+ final KeyType type = KeyType.fromKey(key);
if (type == KeyType.UNKNOWN)
return false;
@@ -230,11 +298,7 @@ public class OpenSSHKnownHosts
for (Entry e : entries)
try {
if (e.getType() == type && e.appliesTo(adjustedHostname))
- if (key.equals(e.getKey()))
- return true;
- else {
- return hostKeyChangedAction(e, adjustedHostname, key);
- }
+ return key.equals(e.getKey()) || hostKeyChangedAction(e, adjustedHostname, key);
} catch (IOException ioe) {
log.error("Error with {}: {}", e, ioe);
return false;
@@ -256,12 +320,17 @@ public class OpenSSHKnownHosts
return entries;
}
+ private static final String LS = System.getProperty("line.separator");
+
public void write()
throws IOException {
- BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(khFile));
- for (Entry entry : entries)
- bos.write((entry.getLine() + LS).getBytes());
- bos.close();
+ final BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(khFile));
+ try {
+ for (Entry entry : entries)
+ bos.write((entry.getLine() + LS).getBytes());
+ } finally {
+ bos.close();
+ }
}
public static File detectSSHDir() {
@@ -269,4 +338,8 @@ public class OpenSSHKnownHosts
return sshDir.exists() ? sshDir : null;
}
-}
+ public static boolean isHashed(String line) {
+ return line.startsWith("|1|");
+ }
+
+}
\ No newline at end of file
diff --git a/src/test/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHostsTest.java b/src/test/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHostsTest.java
index 5f04c9ae..b4662515 100644
--- a/src/test/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHostsTest.java
+++ b/src/test/java/net/schmizz/sshj/transport/verification/OpenSSHKnownHostsTest.java
@@ -1,20 +1,17 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
+ * Copyright 2010 Shikhar Bhushan
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
*
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
*/
package net.schmizz.sshj.transport.verification;