Buffer underflows as checked exceptions. Should not be a RuntimeException in case we get an invalid SSH packet.

This commit is contained in:
Shikhar Bhushan
2011-05-30 20:34:13 +01:00
parent 17d8e91f05
commit 3695e2a184
18 changed files with 291 additions and 146 deletions

View File

@@ -44,7 +44,8 @@ import java.util.Arrays;
public class Buffer<T extends Buffer<T>> { public class Buffer<T extends Buffer<T>> {
public static class BufferException public static class BufferException
extends SSHRuntimeException { extends SSHException {
public BufferException(String message) { public BufferException(String message) {
super(message); super(message);
} }
@@ -139,7 +140,8 @@ public class Buffer<T extends Buffer<T>> {
this.wpos = wpos; this.wpos = wpos;
} }
protected void ensureAvailable(int a) { protected void ensureAvailable(int a)
throws BufferException {
if (available() < a) if (available() < a)
throw new BufferException("Underflow"); throw new BufferException("Underflow");
} }
@@ -177,7 +179,8 @@ public class Buffer<T extends Buffer<T>> {
* *
* @return the {@code true} or {@code false} value read * @return the {@code true} or {@code false} value read
*/ */
public boolean readBoolean() { public boolean readBoolean()
throws BufferException {
return readByte() != 0; return readByte() != 0;
} }
@@ -197,7 +200,8 @@ public class Buffer<T extends Buffer<T>> {
* *
* @return the byte read * @return the byte read
*/ */
public byte readByte() { public byte readByte()
throws BufferException {
ensureAvailable(1); ensureAvailable(1);
return data[rpos++]; return data[rpos++];
} }
@@ -221,7 +225,8 @@ public class Buffer<T extends Buffer<T>> {
* *
* @return the byte-array read * @return the byte-array read
*/ */
public byte[] readBytes() { public byte[] readBytes()
throws BufferException {
int len = readUInt32AsInt(); int len = readUInt32AsInt();
if (len < 0 || len > 32768) if (len < 0 || len > 32768)
throw new BufferException("Bad item length: " + len); throw new BufferException("Bad item length: " + len);
@@ -254,11 +259,13 @@ public class Buffer<T extends Buffer<T>> {
return putUInt32(len - off).putRawBytes(b, off, len); return putUInt32(len - off).putRawBytes(b, off, len);
} }
public void readRawBytes(byte[] buf) { public void readRawBytes(byte[] buf)
throws BufferException {
readRawBytes(buf, 0, buf.length); readRawBytes(buf, 0, buf.length);
} }
public void readRawBytes(byte[] buf, int off, int len) { public void readRawBytes(byte[] buf, int off, int len)
throws BufferException {
ensureAvailable(len); ensureAvailable(len);
System.arraycopy(data, rpos, buf, off, len); System.arraycopy(data, rpos, buf, off, len);
rpos += len; rpos += len;
@@ -294,16 +301,18 @@ public class Buffer<T extends Buffer<T>> {
return (T) this; return (T) this;
} }
public int readUInt32AsInt() { public int readUInt32AsInt()
throws BufferException {
return (int) readUInt32(); return (int) readUInt32();
} }
public long readUInt32() { public long readUInt32()
throws BufferException {
ensureAvailable(4); ensureAvailable(4);
return data[rpos++] << 24 & 0xff000000L | return data[rpos++] << 24 & 0xff000000L |
data[rpos++] << 16 & 0x00ff0000L | data[rpos++] << 16 & 0x00ff0000L |
data[rpos++] << 8 & 0x0000ff00L | data[rpos++] << 8 & 0x0000ff00L |
data[rpos++] & 0x000000ffL; data[rpos++] & 0x000000ffL;
} }
/** /**
@@ -317,7 +326,7 @@ public class Buffer<T extends Buffer<T>> {
public T putUInt32(long uint32) { public T putUInt32(long uint32) {
ensureCapacity(4); ensureCapacity(4);
if (uint32 < 0 || uint32 > 0xffffffffL) if (uint32 < 0 || uint32 > 0xffffffffL)
throw new BufferException("Invalid value: " + uint32); throw new RuntimeException("Invalid value: " + uint32);
data[wpos++] = (byte) (uint32 >> 24); data[wpos++] = (byte) (uint32 >> 24);
data[wpos++] = (byte) (uint32 >> 16); data[wpos++] = (byte) (uint32 >> 16);
data[wpos++] = (byte) (uint32 >> 8); data[wpos++] = (byte) (uint32 >> 8);
@@ -330,7 +339,8 @@ public class Buffer<T extends Buffer<T>> {
* *
* @return the MP integer as a {@code BigInteger} * @return the MP integer as a {@code BigInteger}
*/ */
public BigInteger readMPInt() { public BigInteger readMPInt()
throws BufferException {
return new BigInteger(readMPIntAsBytes()); return new BigInteger(readMPIntAsBytes());
} }
@@ -363,11 +373,13 @@ public class Buffer<T extends Buffer<T>> {
return putRawBytes(foo); return putRawBytes(foo);
} }
public byte[] readMPIntAsBytes() { public byte[] readMPIntAsBytes()
throws BufferException {
return readBytes(); return readBytes();
} }
public long readUInt64() { public long readUInt64()
throws BufferException {
long uint64 = (readUInt32() << 32) + (readUInt32() & 0xffffffffL); long uint64 = (readUInt32() << 32) + (readUInt32() & 0xffffffffL);
if (uint64 < 0) if (uint64 < 0)
throw new BufferException("Cannot handle values > Long.MAX_VALUE"); throw new BufferException("Cannot handle values > Long.MAX_VALUE");
@@ -377,7 +389,7 @@ public class Buffer<T extends Buffer<T>> {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public T putUInt64(long uint64) { public T putUInt64(long uint64) {
if (uint64 < 0) if (uint64 < 0)
throw new BufferException("Invalid value: " + uint64); throw new RuntimeException("Invalid value: " + uint64);
data[wpos++] = (byte) (uint64 >> 56); data[wpos++] = (byte) (uint64 >> 56);
data[wpos++] = (byte) (uint64 >> 48); data[wpos++] = (byte) (uint64 >> 48);
data[wpos++] = (byte) (uint64 >> 40); data[wpos++] = (byte) (uint64 >> 40);
@@ -394,7 +406,8 @@ public class Buffer<T extends Buffer<T>> {
* *
* @return the string as a Java {@code String} * @return the string as a Java {@code String}
*/ */
public String readString() { public String readString()
throws BufferException {
int len = readUInt32AsInt(); int len = readUInt32AsInt();
if (len < 0 || len > 32768) if (len < 0 || len > 32768)
throw new BufferException("Bad item length: " + len); throw new BufferException("Bad item length: " + len);
@@ -414,7 +427,8 @@ public class Buffer<T extends Buffer<T>> {
* *
* @return the string as a byte-array * @return the string as a byte-array
*/ */
public byte[] readStringAsBytes() { public byte[] readStringAsBytes()
throws BufferException {
return readBytes(); return readBytes();
} }
@@ -452,7 +466,8 @@ public class Buffer<T extends Buffer<T>> {
return (T) this; return (T) this;
} }
public PublicKey readPublicKey() { public PublicKey readPublicKey()
throws BufferException {
try { try {
final String type = readString(); final String type = readString();
return KeyType.fromString(type).readPubKeyFromBuffer(type, this); return KeyType.fromString(type).readPubKeyFromBuffer(type, this);

View File

@@ -36,8 +36,13 @@ public enum KeyType {
@Override @Override
public PublicKey readPubKeyFromBuffer(String type, Buffer<?> buf) public PublicKey readPubKeyFromBuffer(String type, Buffer<?> buf)
throws GeneralSecurityException { throws GeneralSecurityException {
final BigInteger e = buf.readMPInt(); final BigInteger e, n;
final BigInteger n = buf.readMPInt(); try {
e = buf.readMPInt();
n = buf.readMPInt();
} catch (Buffer.BufferException be) {
throw new GeneralSecurityException(be);
}
final KeyFactory keyFactory = SecurityUtils.getKeyFactory("RSA"); final KeyFactory keyFactory = SecurityUtils.getKeyFactory("RSA");
return keyFactory.generatePublic(new RSAPublicKeySpec(n, e)); return keyFactory.generatePublic(new RSAPublicKeySpec(n, e));
} }
@@ -63,10 +68,15 @@ public enum KeyType {
@Override @Override
public PublicKey readPubKeyFromBuffer(String type, Buffer<?> buf) public PublicKey readPubKeyFromBuffer(String type, Buffer<?> buf)
throws GeneralSecurityException { throws GeneralSecurityException {
final BigInteger p = buf.readMPInt(); BigInteger p, q, g, y;
final BigInteger q = buf.readMPInt(); try {
final BigInteger g = buf.readMPInt(); p = buf.readMPInt();
final BigInteger y = buf.readMPInt(); q = buf.readMPInt();
g = buf.readMPInt();
y = buf.readMPInt();
} catch (Buffer.BufferException be) {
throw new GeneralSecurityException(be);
}
final KeyFactory keyFactory = SecurityUtils.getKeyFactory("DSA"); final KeyFactory keyFactory = SecurityUtils.getKeyFactory("DSA");
return keyFactory.generatePublic(new DSAPublicKeySpec(y, p, q, g)); return keyFactory.generatePublic(new DSAPublicKeySpec(y, p, q, g));
} }

View File

@@ -75,7 +75,8 @@ public class SSHPacket
* *
* @return the message identifier * @return the message identifier
*/ */
public Message readMessageID() { public Message readMessageID()
throws BufferException {
return Message.fromByte(readByte()); return Message.fromByte(readByte());
} }

View File

@@ -15,16 +15,16 @@
*/ */
package net.schmizz.sshj.connection; package net.schmizz.sshj.connection;
import net.schmizz.concurrent.Promise;
import net.schmizz.concurrent.ErrorDeliveryUtil; import net.schmizz.concurrent.ErrorDeliveryUtil;
import net.schmizz.concurrent.Promise;
import net.schmizz.sshj.AbstractService; import net.schmizz.sshj.AbstractService;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.DisconnectReason; import net.schmizz.sshj.common.DisconnectReason;
import net.schmizz.sshj.common.ErrorNotifiable; import net.schmizz.sshj.common.ErrorNotifiable;
import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.channel.Channel; import net.schmizz.sshj.connection.channel.Channel;
import net.schmizz.sshj.connection.channel.OpenFailException;
import net.schmizz.sshj.connection.channel.OpenFailException.Reason; import net.schmizz.sshj.connection.channel.OpenFailException.Reason;
import net.schmizz.sshj.connection.channel.forwarded.ForwardedChannelOpener; import net.schmizz.sshj.connection.channel.forwarded.ForwardedChannelOpener;
import net.schmizz.sshj.transport.Transport; import net.schmizz.sshj.transport.Transport;
@@ -103,14 +103,18 @@ public class ConnectionImpl
private Channel getChannel(SSHPacket buffer) private Channel getChannel(SSHPacket buffer)
throws ConnectionException { throws ConnectionException {
int recipient = buffer.readUInt32AsInt(); try {
Channel channel = get(recipient); final int recipient = buffer.readUInt32AsInt();
if (channel != null) final Channel channel = get(recipient);
return channel; if (channel != null)
else { return channel;
buffer.rpos(buffer.rpos() - 5); else {
throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Received " + buffer.readMessageID() buffer.rpos(buffer.rpos() - 5);
+ " on unknown channel #" + recipient); throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR,
"Received " + buffer.readMessageID() + " on unknown channel #" + recipient);
}
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
} }
} }
@@ -180,12 +184,13 @@ public class ConnectionImpl
@Override @Override
public Promise<SSHPacket, ConnectionException> sendGlobalRequest(String name, boolean wantReply, public Promise<SSHPacket, ConnectionException> sendGlobalRequest(String name, boolean wantReply,
byte[] specifics) byte[] specifics)
throws TransportException { throws TransportException {
synchronized (globalReqPromises) { synchronized (globalReqPromises) {
log.info("Making global request for `{}`", name); log.info("Making global request for `{}`", name);
trans.write(new SSHPacket(Message.GLOBAL_REQUEST).putString(name) trans.write(new SSHPacket(Message.GLOBAL_REQUEST).putString(name)
.putBoolean(wantReply).putRawBytes(specifics)); .putBoolean(wantReply)
.putRawBytes(specifics));
Promise<SSHPacket, ConnectionException> promise = null; Promise<SSHPacket, ConnectionException> promise = null;
if (wantReply) { if (wantReply) {
@@ -212,13 +217,17 @@ public class ConnectionImpl
private void gotChannelOpen(SSHPacket buf) private void gotChannelOpen(SSHPacket buf)
throws ConnectionException, TransportException { throws ConnectionException, TransportException {
final String type = buf.readString(); try {
log.debug("Received CHANNEL_OPEN for `{}` channel", type); final String type = buf.readString();
if (openers.containsKey(type)) log.debug("Received CHANNEL_OPEN for `{}` channel", type);
openers.get(type).handleOpen(buf); if (openers.containsKey(type))
else { openers.get(type).handleOpen(buf);
log.warn("No opener found for `{}` CHANNEL_OPEN request -- rejecting", type); else {
sendOpenFailure(buf.readUInt32AsInt(), OpenFailException.Reason.UNKNOWN_CHANNEL_TYPE, ""); log.warn("No opener found for `{}` CHANNEL_OPEN request -- rejecting", type);
sendOpenFailure(buf.readUInt32AsInt(), Reason.UNKNOWN_CHANNEL_TYPE, "");
}
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
} }
} }
@@ -226,9 +235,9 @@ public class ConnectionImpl
public void sendOpenFailure(int recipient, Reason reason, String message) public void sendOpenFailure(int recipient, Reason reason, String message)
throws TransportException { throws TransportException {
trans.write(new SSHPacket(Message.CHANNEL_OPEN_FAILURE) trans.write(new SSHPacket(Message.CHANNEL_OPEN_FAILURE)
.putUInt32(recipient) .putUInt32(recipient)
.putUInt32(reason.getCode()) .putUInt32(reason.getCode())
.putString(message)); .putString(message));
} }
@Override @Override

View File

@@ -183,11 +183,11 @@ public abstract class AbstractChannel
break; break;
case CHANNEL_EXTENDED_DATA: case CHANNEL_EXTENDED_DATA:
gotExtendedData(buf.readUInt32AsInt(), buf); gotExtendedData(buf);
break; break;
case CHANNEL_WINDOW_ADJUST: case CHANNEL_WINDOW_ADJUST:
gotWindowAdjustment(buf.readUInt32AsInt()); gotWindowAdjustment(buf);
break; break;
case CHANNEL_REQUEST: case CHANNEL_REQUEST:
@@ -301,13 +301,24 @@ public abstract class AbstractChannel
private void gotChannelRequest(SSHPacket buf) private void gotChannelRequest(SSHPacket buf)
throws ConnectionException, TransportException { throws ConnectionException, TransportException {
final String reqType = buf.readString(); final String reqType;
buf.readBoolean(); // We don't care about the 'want-reply' value try {
reqType = buf.readString();
buf.readBoolean(); // We don't care about the 'want-reply' value
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
}
log.info("Got chan request for `{}`", reqType); log.info("Got chan request for `{}`", reqType);
handleRequest(reqType, buf); handleRequest(reqType, buf);
} }
private void gotWindowAdjustment(int howMuch) { private void gotWindowAdjustment(SSHPacket buf) throws ConnectionException {
final int howMuch;
try {
howMuch = buf.readUInt32AsInt();
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
}
log.info("Received window adjustment for {} bytes", howMuch); log.info("Received window adjustment for {} bytes", howMuch);
rwin.expand(howMuch); rwin.expand(howMuch);
} }
@@ -317,10 +328,10 @@ public abstract class AbstractChannel
close.set(); close.set();
} }
protected void gotExtendedData(int dataTypeCode, SSHPacket buf) protected void gotExtendedData(SSHPacket buf)
throws ConnectionException, TransportException { throws ConnectionException, TransportException {
throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Extended data not supported on " + type throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR,
+ " channel"); "Extended data not supported on " + type + " channel");
} }
protected void gotUnknown(Message msg, SSHPacket buf) protected void gotUnknown(Message msg, SSHPacket buf)
@@ -338,7 +349,12 @@ public abstract class AbstractChannel
protected void receiveInto(ChannelInputStream stream, SSHPacket buf) protected void receiveInto(ChannelInputStream stream, SSHPacket buf)
throws ConnectionException, TransportException { throws ConnectionException, TransportException {
final int len = buf.readUInt32AsInt(); final int len;
try {
len = buf.readUInt32AsInt();
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
}
if (len < 0 || len > getLocalMaxPacketSize() || len > buf.available()) if (len < 0 || len > getLocalMaxPacketSize() || len > buf.available())
throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Bad item length: " + len); throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Bad item length: " + len);
if (log.isTraceEnabled()) if (log.isTraceEnabled())

View File

@@ -35,6 +35,7 @@
*/ */
package net.schmizz.sshj.connection.channel.direct; package net.schmizz.sshj.connection.channel.direct;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.Connection; import net.schmizz.sshj.connection.Connection;
@@ -67,13 +68,23 @@ public abstract class AbstractDirectChannel
open.await(conn.getTimeout(), TimeUnit.SECONDS); open.await(conn.getTimeout(), TimeUnit.SECONDS);
} }
private void gotOpenConfirmation(SSHPacket buf) { private void gotOpenConfirmation(SSHPacket buf)
init(buf.readUInt32AsInt(), buf.readUInt32AsInt(), buf.readUInt32AsInt()); throws ConnectionException {
try {
init(buf.readUInt32AsInt(), buf.readUInt32AsInt(), buf.readUInt32AsInt());
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
}
open.set(); open.set();
} }
private void gotOpenFailure(SSHPacket buf) { private void gotOpenFailure(SSHPacket buf)
open.deliverError(new OpenFailException(getType(), buf.readUInt32AsInt(), buf.readString())); throws ConnectionException {
try {
open.deliverError(new OpenFailException(getType(), buf.readUInt32AsInt(), buf.readString()));
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
}
finishOff(); finishOff();
} }

View File

@@ -36,6 +36,7 @@
package net.schmizz.sshj.connection.channel.direct; package net.schmizz.sshj.connection.channel.direct;
import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.DisconnectReason;
import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.SSHPacket;
@@ -147,17 +148,21 @@ public class SessionChannel
@Override @Override
public void handleRequest(String req, SSHPacket buf) public void handleRequest(String req, SSHPacket buf)
throws ConnectionException, TransportException { throws ConnectionException, TransportException {
if ("xon-xoff".equals(req)) try {
canDoFlowControl = buf.readBoolean(); if ("xon-xoff".equals(req))
else if ("exit-status".equals(req)) canDoFlowControl = buf.readBoolean();
exitStatus = buf.readUInt32AsInt(); else if ("exit-status".equals(req))
else if ("exit-signal".equals(req)) { exitStatus = buf.readUInt32AsInt();
exitSignal = Signal.fromString(buf.readString()); else if ("exit-signal".equals(req)) {
wasCoreDumped = buf.readBoolean(); // core dumped exitSignal = Signal.fromString(buf.readString());
exitErrMsg = buf.readString(); wasCoreDumped = buf.readBoolean(); // core dumped
sendClose(); exitErrMsg = buf.readString();
} else sendClose();
super.handleRequest(req, buf); } else
super.handleRequest(req, buf);
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
}
} }
@Override @Override
@@ -225,12 +230,18 @@ public class SessionChannel
} }
@Override @Override
protected void gotExtendedData(int dataTypeCode, SSHPacket buf) protected void gotExtendedData(SSHPacket buf)
throws ConnectionException, TransportException { throws ConnectionException, TransportException {
if (dataTypeCode == 1) try {
receiveInto(err, buf); final int dataTypeCode = buf.readUInt32AsInt();
else if (dataTypeCode == 1)
super.gotExtendedData(dataTypeCode, buf); receiveInto(err, buf);
else
throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR,
"Bad extended data type = " + dataTypeCode);
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
}
} }
@Override @Override
@@ -246,13 +257,15 @@ public class SessionChannel
@Override @Override
@Deprecated @Deprecated
public String getOutputAsString() throws IOException { public String getOutputAsString()
throws IOException {
return IOUtils.readFully(getInputStream()).toString(); return IOUtils.readFully(getInputStream()).toString();
} }
@Override @Override
@Deprecated @Deprecated
public String getErrorAsString() throws IOException { public String getErrorAsString()
throws IOException {
return IOUtils.readFully(getErrorStream()).toString(); return IOUtils.readFully(getErrorStream()).toString();
} }

View File

@@ -168,7 +168,11 @@ public class RemotePortForwarder
throws ConnectionException, TransportException { throws ConnectionException, TransportException {
SSHPacket reply = req(PF_REQ, forward); SSHPacket reply = req(PF_REQ, forward);
if (forward.port == 0) if (forward.port == 0)
forward.port = reply.readUInt32AsInt(); try {
forward.port = reply.readUInt32AsInt();
} catch (Buffer.BufferException e) {
throw new ConnectionException(e);
}
log.info("Remote end listening on {}", forward); log.info("Remote end listening on {}", forward);
listeners.put(forward, listener); listeners.put(forward, listener);
return forward; return forward;
@@ -211,9 +215,14 @@ public class RemotePortForwarder
@Override @Override
public void handleOpen(SSHPacket buf) public void handleOpen(SSHPacket buf)
throws ConnectionException, TransportException { throws ConnectionException, TransportException {
final ForwardedTCPIPChannel chan = new ForwardedTCPIPChannel(conn, buf.readUInt32AsInt(), buf.readUInt32AsInt(), buf.readUInt32AsInt(), final ForwardedTCPIPChannel chan;
new Forward(buf.readString(), buf.readUInt32AsInt()), try {
buf.readString(), buf.readUInt32AsInt()); chan = new ForwardedTCPIPChannel(conn, buf.readUInt32AsInt(), buf.readUInt32AsInt(), buf.readUInt32AsInt(),
new Forward(buf.readString(), buf.readUInt32AsInt()),
buf.readString(), buf.readUInt32AsInt());
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
}
if (listeners.containsKey(chan.getParentForward())) if (listeners.containsKey(chan.getParentForward()))
callListener(listeners.get(chan.getParentForward()), chan); callListener(listeners.get(chan.getParentForward()), chan);
else else

View File

@@ -15,6 +15,7 @@
*/ */
package net.schmizz.sshj.connection.channel.forwarded; package net.schmizz.sshj.connection.channel.forwarded;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.Connection; import net.schmizz.sshj.connection.Connection;
import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.ConnectionException;
@@ -55,10 +56,14 @@ public class X11Forwarder
@Override @Override
public void handleOpen(SSHPacket buf) public void handleOpen(SSHPacket buf)
throws ConnectionException, TransportException { throws ConnectionException, TransportException {
callListener(listener, new X11Channel(conn, try {
buf.readUInt32AsInt(), callListener(listener, new X11Channel(conn,
buf.readUInt32AsInt(), buf.readUInt32AsInt(), buf.readUInt32AsInt(),
buf.readString(), buf.readUInt32AsInt())); buf.readUInt32AsInt(), buf.readUInt32AsInt(),
buf.readString(), buf.readUInt32AsInt()));
} catch (Buffer.BufferException be) {
throw new ConnectionException(be);
}
} }
/** Stop handling {@code x11} channel open requests. De-registers itself with connection layer. */ /** Stop handling {@code x11} channel open requests. De-registers itself with connection layer. */

View File

@@ -51,11 +51,15 @@ public class Response
private final PacketType type; private final PacketType type;
private final long reqID; private final long reqID;
public Response(Buffer<Response> pk, int protocolVersion) { public Response(Buffer<Response> pk, int protocolVersion) throws SFTPException {
super(pk); super(pk);
this.protocolVersion = protocolVersion; this.protocolVersion = protocolVersion;
this.type = readType(); this.type = readType();
this.reqID = readUInt32(); try {
this.reqID = readUInt32();
} catch (BufferException be) {
throw new SFTPException(be);
}
} }
public int getProtocolVersion() { public int getProtocolVersion() {
@@ -70,8 +74,12 @@ public class Response
return type; return type;
} }
public StatusCode readStatusCode() { public StatusCode readStatusCode() throws SFTPException {
return StatusCode.fromInt(readUInt32AsInt()); try {
return StatusCode.fromInt(readUInt32AsInt());
} catch (BufferException be) {
throw new SFTPException(be);
}
} }
public Response ensurePacketTypeIs(PacketType pt) public Response ensurePacketTypeIs(PacketType pt)
@@ -99,7 +107,11 @@ public class Response
protected String error(StatusCode sc) protected String error(StatusCode sc)
throws SFTPException { throws SFTPException {
throw new SFTPException(sc, protocolVersion < 3 ? sc.toString() : readString()); try {
throw new SFTPException(sc, protocolVersion < 3 ? sc.toString() : readString());
} catch (BufferException be) {
throw new SFTPException(be);
}
} }
} }

View File

@@ -33,27 +33,37 @@ public class SFTPPacket<T extends SFTPPacket<T>>
putByte(pt.toByte()); putByte(pt.toByte());
} }
public FileAttributes readFileAttributes() { public FileAttributes readFileAttributes()
throws SFTPException {
final FileAttributes.Builder builder = new FileAttributes.Builder(); final FileAttributes.Builder builder = new FileAttributes.Builder();
final int mask = readUInt32AsInt(); try {
if (FileAttributes.Flag.SIZE.isSet(mask)) final int mask = readUInt32AsInt();
builder.withSize(readUInt64()); if (FileAttributes.Flag.SIZE.isSet(mask))
if (FileAttributes.Flag.UIDGID.isSet(mask)) builder.withSize(readUInt64());
builder.withUIDGID(readUInt32AsInt(), readUInt32AsInt()); if (FileAttributes.Flag.UIDGID.isSet(mask))
if (FileAttributes.Flag.MODE.isSet(mask)) builder.withUIDGID(readUInt32AsInt(), readUInt32AsInt());
builder.withPermissions(readUInt32AsInt()); if (FileAttributes.Flag.MODE.isSet(mask))
if (FileAttributes.Flag.ACMODTIME.isSet(mask)) builder.withPermissions(readUInt32AsInt());
builder.withAtimeMtime(readUInt32AsInt(), readUInt32AsInt()); if (FileAttributes.Flag.ACMODTIME.isSet(mask))
if (FileAttributes.Flag.EXTENDED.isSet(mask)) { builder.withAtimeMtime(readUInt32AsInt(), readUInt32AsInt());
final int extCount = readUInt32AsInt(); if (FileAttributes.Flag.EXTENDED.isSet(mask)) {
for (int i = 0; i < extCount; i++) final int extCount = readUInt32AsInt();
builder.withExtended(readString(), readString()); for (int i = 0; i < extCount; i++)
builder.withExtended(readString(), readString());
}
} catch (BufferException be) {
throw new SFTPException(be);
} }
return builder.build(); return builder.build();
} }
public PacketType readType() { public PacketType readType()
return PacketType.fromByte(readByte()); throws SFTPException {
try {
return PacketType.fromByte(readByte());
} catch (BufferException be) {
throw new SFTPException(be);
}
} }
public T putFileAttributes(FileAttributes fa) { public T putFileAttributes(FileAttributes fa) {

View File

@@ -35,6 +35,7 @@
*/ */
package net.schmizz.sshj.transport; package net.schmizz.sshj.transport;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.ByteArrayUtils; import net.schmizz.sshj.common.ByteArrayUtils;
import net.schmizz.sshj.common.DisconnectReason; import net.schmizz.sshj.common.DisconnectReason;
import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.SSHException;
@@ -157,7 +158,12 @@ final class Decoder
throws TransportException { throws TransportException {
cipher.update(inputBuffer.array(), 0, cipherSize); cipher.update(inputBuffer.array(), 0, cipherSize);
final int len = inputBuffer.readUInt32AsInt(); // Read packet length final int len; // Read packet length
try {
len = inputBuffer.readUInt32AsInt();
} catch (Buffer.BufferException be) {
throw new TransportException(be);
}
if (isInvalidPacketLength(len)) { // Check packet length validity if (isInvalidPacketLength(len)) { // Check packet length validity
log.info("Error decoding packet (invalid length) {}", inputBuffer.printHex()); log.info("Error decoding packet (invalid length) {}", inputBuffer.printHex());

View File

@@ -36,6 +36,7 @@
package net.schmizz.sshj.transport; package net.schmizz.sshj.transport;
import net.schmizz.sshj.Config; import net.schmizz.sshj.Config;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.Factory; import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.SSHPacket;
@@ -85,18 +86,22 @@ class Proposal {
packet.putUInt32(0); // "Reserved" for future by spec packet.putUInt32(0); // "Reserved" for future by spec
} }
public Proposal(SSHPacket packet) { public Proposal(SSHPacket packet) throws TransportException {
this.packet = packet; this.packet = packet;
final int savedPos = packet.rpos(); final int savedPos = packet.rpos();
packet.rpos(packet.rpos() + 17); // Skip message ID & cookie packet.rpos(packet.rpos() + 17); // Skip message ID & cookie
kex = fromCommaString(packet.readString()); try {
sig = fromCommaString(packet.readString()); kex = fromCommaString(packet.readString());
c2sCipher = fromCommaString(packet.readString()); sig = fromCommaString(packet.readString());
s2cCipher = fromCommaString(packet.readString()); c2sCipher = fromCommaString(packet.readString());
c2sMAC = fromCommaString(packet.readString()); s2cCipher = fromCommaString(packet.readString());
s2cMAC = fromCommaString(packet.readString()); c2sMAC = fromCommaString(packet.readString());
c2sComp = fromCommaString(packet.readString()); s2cMAC = fromCommaString(packet.readString());
s2cComp = fromCommaString(packet.readString()); c2sComp = fromCommaString(packet.readString());
s2cComp = fromCommaString(packet.readString());
} catch (Buffer.BufferException be) {
throw new TransportException(be);
}
packet.rpos(savedPos); packet.rpos(savedPos);
} }

View File

@@ -501,18 +501,26 @@ public final class TransportImpl
} }
} }
private void gotDebug(SSHPacket buf) { private void gotDebug(SSHPacket buf) throws TransportException {
boolean display = buf.readBoolean(); try {
String message = buf.readString(); final boolean display = buf.readBoolean();
log.info("Received SSH_MSG_DEBUG (display={}) '{}'", display, message); final String message = buf.readString();
log.info("Received SSH_MSG_DEBUG (display={}) '{}'", display, message);
} catch (Buffer.BufferException be) {
throw new TransportException(be);
}
} }
private void gotDisconnect(SSHPacket buf) private void gotDisconnect(SSHPacket buf)
throws TransportException { throws TransportException {
DisconnectReason code = DisconnectReason.fromInt(buf.readUInt32AsInt()); try {
String message = buf.readString(); final DisconnectReason code = DisconnectReason.fromInt(buf.readUInt32AsInt());
log.info("Received SSH_MSG_DISCONNECT (reason={}, msg={})", code, message); final String message = buf.readString();
throw new TransportException(code, "Disconnected; server said: " + message); log.info("Received SSH_MSG_DISCONNECT (reason={}, msg={})", code, message);
throw new TransportException(code, "Disconnected; server said: " + message);
} catch (Buffer.BufferException be) {
throw new TransportException(be);
}
} }
private void gotServiceAccept() private void gotServiceAccept()

View File

@@ -121,14 +121,21 @@ public abstract class AbstractDHG
throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, "Unexpected packet: " + msg); throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, "Unexpected packet: " + msg);
log.info("Received SSH_MSG_KEXDH_REPLY"); log.info("Received SSH_MSG_KEXDH_REPLY");
final byte[] K_S = packet.readBytes(); final byte[] K_S;
final byte[] f = packet.readMPIntAsBytes(); final byte[] f;
final byte[] sig = packet.readBytes(); // signature sent by server final byte[] sig; // signature sent by server
try {
K_S = packet.readBytes();
f = packet.readMPIntAsBytes();
sig = packet.readBytes();
hostKey = new Buffer.PlainBuffer(K_S).readPublicKey();
} catch (Buffer.BufferException be) {
throw new TransportException(be);
}
dh.setF(new BigInteger(f)); dh.setF(new BigInteger(f));
K = dh.getK(); K = dh.getK();
hostKey = new Buffer.PlainBuffer(K_S).readPublicKey();
final Buffer.PlainBuffer buf = new Buffer.PlainBuffer() final Buffer.PlainBuffer buf = new Buffer.PlainBuffer()
.putString(V_C) .putString(V_C)
.putString(V_S) .putString(V_S)

View File

@@ -15,6 +15,7 @@
*/ */
package net.schmizz.sshj.userauth.method; package net.schmizz.sshj.userauth.method;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.transport.TransportException; import net.schmizz.sshj.transport.TransportException;
@@ -63,15 +64,20 @@ public class AuthKeyboardInteractive
if (cmd != Message.USERAUTH_60) { if (cmd != Message.USERAUTH_60) {
super.handle(cmd, buf); super.handle(cmd, buf);
} else { } else {
provider.init(makeAccountResource(), buf.readString(), buf.readString()); final CharArrWrap[] userReplies;
buf.readString(); // lang-tag try {
final int numPrompts = buf.readUInt32AsInt(); provider.init(makeAccountResource(), buf.readString(), buf.readString());
final CharArrWrap[] userReplies = new CharArrWrap[numPrompts]; buf.readString(); // lang-tag
for (int i = 0; i < numPrompts; i++) { final int numPrompts = buf.readUInt32AsInt();
final String prompt = buf.readString(); userReplies = new CharArrWrap[numPrompts];
final boolean echo = buf.readBoolean(); for (int i = 0; i < numPrompts; i++) {
log.info("Requesting response for challenge `{}`; echo={}", prompt, echo); final String prompt = buf.readString();
userReplies[i] = new CharArrWrap(provider.getResponse(prompt, echo)); final boolean echo = buf.readBoolean();
log.info("Requesting response for challenge `{}`; echo={}", prompt, echo);
userReplies[i] = new CharArrWrap(provider.getResponse(prompt, echo));
}
} catch (Buffer.BufferException be) {
throw new UserAuthException(be);
} }
respond(userReplies); respond(userReplies);
} }

View File

@@ -23,7 +23,6 @@ import org.junit.Test;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;

View File

@@ -41,7 +41,8 @@ public class BufferTest {
} }
@Test @Test
public void testDataTypes() { public void testDataTypes()
throws Buffer.BufferException {
// bool // bool
assertEquals(handyBuf.putBoolean(true).readBoolean(), true); assertEquals(handyBuf.putBoolean(true).readBoolean(), true);
@@ -63,7 +64,8 @@ public class BufferTest {
} }
@Test @Test
public void testPassword() { public void testPassword()
throws Buffer.BufferException {
char[] pass = "lolcatz".toCharArray(); char[] pass = "lolcatz".toCharArray();
// test if put correctly as a string // test if put correctly as a string
assertEquals(new Buffer.PlainBuffer().putSensitiveString(pass).readString(), "lolcatz"); assertEquals(new Buffer.PlainBuffer().putSensitiveString(pass).readString(), "lolcatz");
@@ -73,7 +75,7 @@ public class BufferTest {
@Test @Test
public void testPosition() public void testPosition()
throws UnsupportedEncodingException { throws UnsupportedEncodingException, Buffer.BufferException {
assertEquals(5, posBuf.wpos()); assertEquals(5, posBuf.wpos());
assertEquals(0, posBuf.rpos()); assertEquals(0, posBuf.rpos());
assertEquals(5, posBuf.available()); assertEquals(5, posBuf.available());
@@ -95,7 +97,8 @@ public class BufferTest {
} }
@Test(expected = Buffer.BufferException.class) @Test(expected = Buffer.BufferException.class)
public void testUnderflow() { public void testUnderflow()
throws Buffer.BufferException {
// exhaust the buffer // exhaust the buffer
for (int i = 0; i < 5; ++i) for (int i = 0; i < 5; ++i)
posBuf.readByte(); posBuf.readByte();