refactored known_hosts handling especially for writing back

This commit is contained in:
Shikhar Bhushan
2010-03-07 14:52:12 +01:00
parent 0df02a3ff9
commit bc4edb7629
3 changed files with 190 additions and 120 deletions

View File

@@ -46,8 +46,8 @@ public class ConsoleKnownHostsVerifier
response = console.readLine("Please explicitly enter yes/no: "); response = console.readLine("Please explicitly enter yes/no: ");
} }
if (response.equalsIgnoreCase(YES)) { if (response.equalsIgnoreCase(YES)) {
entries().add(new Entry(hostname, key));
try { try {
entries().add(new SimpleEntry(hostname, key));
write(); write();
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);

View File

@@ -37,7 +37,6 @@ package net.schmizz.sshj.transport.verification;
import net.schmizz.sshj.common.Base64; import net.schmizz.sshj.common.Base64;
import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.ByteArrayUtils;
import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.KeyType; import net.schmizz.sshj.common.KeyType;
import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.SSHException;
@@ -57,118 +56,66 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
// TODO: allow modifications to known_hosts e.g. adding entries
/** /**
* A {@link HostKeyVerifier} implementation for a {@code known_hosts} file i.e. in the format used by OpenSSH. * A {@link HostKeyVerifier} implementation for a {@code known_hosts} file i.e. in the format used by OpenSSH.
* <p/> * <p/>
* Hashed hostnames are correctly handled. * Hashed hostnames are correctly handled (but not yet for writing back).
* *
* @see <a href="http://nms.lcs.mit.edu/projects/ssh/README.hashed-hosts">Hashed hostnames spec</a> * @see <a href="http://nms.lcs.mit.edu/projects/ssh/README.hashed-hosts">Hashed hostnames spec</a>
*/ */
public class OpenSSHKnownHosts public class OpenSSHKnownHosts
implements HostKeyVerifier { implements HostKeyVerifier {
private static final String LS = System.getProperty("line.separator"); public static interface Entry {
/** Represents a single line */ KeyType getType();
public static class Entry {
private final MAC sha1 = new HMACSHA1(); boolean appliesTo(String host)
throws IOException;
private final List<String> hosts; PublicKey getKey()
private final KeyType type; throws IOException;
String getLine();
}
public static abstract class BaseEntry
implements Entry {
private KeyType type;
private PublicKey key; private PublicKey key;
private String sKey; private String sKey;
/** Construct an entry from the hostname and public key */ protected void init(PublicKey key)
public Entry(String host, PublicKey key) {
this.key = key;
this.hosts = Arrays.asList(host);
type = KeyType.fromKey(key);
}
/**
* Construct an entry from a string containing the line
*
* @param line the line from a known_hosts file
*
* @throws SSHException if it could not be parsed for any reason
*/
public Entry(String line)
throws SSHException { throws SSHException {
String[] parts = line.split(" "); this.key = key;
if (parts.length != 3) this.type = KeyType.fromKey(key);
throw new SSHException("Line parts not 3: " + line);
hosts = Arrays.asList(parts[0].split(","));
type = KeyType.fromString(parts[1]);
if (type == KeyType.UNKNOWN) if (type == KeyType.UNKNOWN)
throw new SSHException("Unknown key type: " + parts[1]); throw new SSHException("Unknown key type for key: " + key);
sKey = parts[2];
} }
/** Checks whether this entry is applicable to some {@code hostname} */ protected void init(String typeString, String keyString)
public boolean appliesTo(String hostname) throws SSHException {
throws IOException { this.sKey = keyString;
if (!hosts.isEmpty() && hosts.get(0).startsWith("|1|")) { // Hashed hostname this.type = KeyType.fromString(typeString);
final String[] splitted = hosts.get(0).split("\\|"); if (type == KeyType.UNKNOWN)
if (splitted.length != 4) throw new SSHException("Unknown key type: " + typeString);
return false;
final byte[] salt = Base64.decode(splitted[2]);
if (salt.length != 20)
return false;
sha1.init(salt);
final byte[] host = Base64.decode(splitted[3]);
if (ByteArrayUtils.equals(host, sha1.doFinal(hostname.getBytes())))
return true;
} else
// Un-hashed, possibly comma-delimited
for (String host : hosts)
if (host.equals(hostname))
return true;
return false;
}
/**
* Returns the public host key represented in this entry.
* <p/>
* The key is cached so repeated calls to this method may be made without concern.
*
* @return the host key
*/
public PublicKey getKey() {
if (key == null) {
byte[] decoded;
try {
decoded = Base64.decode(sKey);
} catch (IOException e) {
return null;
}
key = new Buffer.PlainBuffer(decoded).readPublicKey();
}
return key;
} }
public KeyType getType() { public KeyType getType() {
return type; return type;
} }
public String getLine() { public PublicKey getKey()
StringBuilder line = new StringBuilder(); throws IOException {
for (String host : hosts) { if (key == null) {
if (line.length() > 0) key = new Buffer.PlainBuffer(Base64.decode(sKey)).readPublicKey();
line.append(",");
line.append(host);
} }
line.append(" ").append(type.toString()); return key;
line.append(" ").append(getKeyString());
return line.toString();
} }
private String getKeyString() { protected String getKeyString() {
if (sKey == null) { if (sKey == null) {
final Buffer.PlainBuffer buf = new Buffer.PlainBuffer().putPublicKey(key); final Buffer.PlainBuffer buf = new Buffer.PlainBuffer().putPublicKey(key);
sKey = Base64.encodeBytes(buf.array(), buf.rpos(), buf.available()); sKey = Base64.encodeBytes(buf.array(), buf.rpos(), buf.available());
@@ -176,14 +123,136 @@ public class OpenSSHKnownHosts
return sKey; return sKey;
} }
public String getLine() {
final StringBuilder line = new StringBuilder();
line.append(getHostPart());
line.append(" ").append(type.toString());
line.append(" ").append(getKeyString());
return line.toString();
}
@Override @Override
public String toString() { public String toString() {
return "Entry{hostnames=" + hosts + "; type=" + type + "; key=" + getKey() + "}"; return "KnownHostsEntry{host=" + getHostPart() + "; type=" + type + "}";
}
protected abstract String getHostPart();
}
public static class SimpleEntry
extends BaseEntry {
private final List<String> hosts;
public SimpleEntry(String host, PublicKey key)
throws SSHException {
this(Arrays.asList(host), key);
}
public SimpleEntry(List<String> hosts, PublicKey key)
throws SSHException {
this.hosts = hosts;
init(key);
}
public SimpleEntry(String line)
throws SSHException {
final String[] parts = line.split(" ");
if (parts.length != 3)
throw new SSHException("Line parts not 3: " + line);
hosts = Arrays.asList(parts[0].split(","));
init(parts[1], parts[2]);
}
public boolean appliesTo(String host) {
for (String h : hosts)
if (host.equals(h))
return true;
return false;
}
protected String getHostPart() {
final StringBuilder sb = new StringBuilder();
for (String host : hosts) {
if (sb.length() > 0) // a host already in there
sb.append(",");
sb.append(host);
}
return sb.toString();
} }
} }
private final Logger log = LoggerFactory.getLogger(getClass()); public static class HashedEntry
extends BaseEntry {
private final MAC sha1 = new HMACSHA1();
private String salt;
private byte[] saltyBytes;
private final String hashedHost;
public HashedEntry(String host, PublicKey key)
throws IOException {
this.hashedHost = hashHost(host);
init(key);
{
saltyBytes = new byte[sha1.getBlockSize()];
new java.util.Random().nextBytes(saltyBytes);
}
}
public HashedEntry(String line)
throws IOException {
final String[] parts = line.split(" ");
if (parts.length != 3)
throw new SSHException("Line parts not 3: " + line);
hashedHost = parts[0];
init(parts[1], parts[2]);
{
final String[] split = hashedHost.split("\\|");
if (split.length != 4)
throw new SSHException("Unrecognized format for hashed hostname");
salt = split[2];
}
}
public boolean appliesTo(String host)
throws IOException {
return hashedHost.equals(hashHost(host));
}
private String hashHost(String host)
throws IOException {
sha1.init(getSaltyBytes());
return "|1|" + getSalt() + "|" + Base64.encodeBytes(sha1.doFinal(host.getBytes()));
}
private byte[] getSaltyBytes()
throws IOException {
if (saltyBytes == null) {
saltyBytes = Base64.decode(salt);
}
return saltyBytes;
}
private String getSalt()
throws IOException {
if (salt == null) {
salt = Base64.encodeBytes(saltyBytes);
}
return salt;
}
protected String getHostPart() {
return hashedHost;
}
}
protected final Logger log = LoggerFactory.getLogger(getClass());
protected final File khFile; protected final File khFile;
protected final List<Entry> entries = new ArrayList<Entry>(); protected final List<Entry> entries = new ArrayList<Entry>();
@@ -199,13 +268,13 @@ public class OpenSSHKnownHosts
throws IOException { throws IOException {
this.khFile = khFile; this.khFile = khFile;
if (khFile.exists()) { if (khFile.exists()) {
BufferedReader br = new BufferedReader(new FileReader(khFile)); final BufferedReader br = new BufferedReader(new FileReader(khFile));
String line;
try { try {
// Read in the file, storing each line as an entry // Read in the file, storing each line as an entry
String line;
while ((line = br.readLine()) != null) while ((line = br.readLine()) != null)
try { try {
entries.add(new Entry(line)); entries.add(isHashed(line) ? new HashedEntry(line) : new SimpleEntry(line));
} catch (SSHException ignore) { } catch (SSHException ignore) {
log.debug("Bad line ({}): {} ", ignore.toString(), line); log.debug("Bad line ({}): {} ", ignore.toString(), line);
} }
@@ -215,13 +284,12 @@ public class OpenSSHKnownHosts
} }
} }
/** public File getFile() {
* Checks whether the specified host is known per the contents of the {@code known_hosts} file. return khFile;
* }
* @return {@code true} on successful verification or {@code false} on failure
*/
public boolean verify(final String hostname, final int port, final PublicKey key) { public boolean verify(final String hostname, final int port, final PublicKey key) {
KeyType type = KeyType.fromKey(key); final KeyType type = KeyType.fromKey(key);
if (type == KeyType.UNKNOWN) if (type == KeyType.UNKNOWN)
return false; return false;
@@ -230,11 +298,7 @@ public class OpenSSHKnownHosts
for (Entry e : entries) for (Entry e : entries)
try { try {
if (e.getType() == type && e.appliesTo(adjustedHostname)) if (e.getType() == type && e.appliesTo(adjustedHostname))
if (key.equals(e.getKey())) return key.equals(e.getKey()) || hostKeyChangedAction(e, adjustedHostname, key);
return true;
else {
return hostKeyChangedAction(e, adjustedHostname, key);
}
} catch (IOException ioe) { } catch (IOException ioe) {
log.error("Error with {}: {}", e, ioe); log.error("Error with {}: {}", e, ioe);
return false; return false;
@@ -256,12 +320,17 @@ public class OpenSSHKnownHosts
return entries; return entries;
} }
private static final String LS = System.getProperty("line.separator");
public void write() public void write()
throws IOException { throws IOException {
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(khFile)); final BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(khFile));
for (Entry entry : entries) try {
bos.write((entry.getLine() + LS).getBytes()); for (Entry entry : entries)
bos.close(); bos.write((entry.getLine() + LS).getBytes());
} finally {
bos.close();
}
} }
public static File detectSSHDir() { public static File detectSSHDir() {
@@ -269,4 +338,8 @@ public class OpenSSHKnownHosts
return sshDir.exists() ? sshDir : null; return sshDir.exists() ? sshDir : null;
} }
} public static boolean isHashed(String line) {
return line.startsWith("|1|");
}
}

View File

@@ -1,20 +1,17 @@
/* /*
* Licensed to the Apache Software Foundation (ASF) under one * Copyright 2010 Shikhar Bhushan
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* *
* Unless required by applicable law or agreed to in writing, * http://www.apache.org/licenses/LICENSE-2.0
* software distributed under the License is distributed on an *
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * Unless required by applicable law or agreed to in writing, software
* KIND, either express or implied. See the License for the * distributed under the License is distributed on an "AS IS" BASIS,
* specific language governing permissions and limitations * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* under the License. * See the License for the specific language governing permissions and
* limitations under the License.
*/ */
package net.schmizz.sshj.transport.verification; package net.schmizz.sshj.transport.verification;