diff --git a/src/main/java/net/schmizz/sshj/common/Buffer.java b/src/main/java/net/schmizz/sshj/common/Buffer.java index d5012d54..7e3ccb0a 100644 --- a/src/main/java/net/schmizz/sshj/common/Buffer.java +++ b/src/main/java/net/schmizz/sshj/common/Buffer.java @@ -44,7 +44,8 @@ import java.util.Arrays; public class Buffer> { public static class BufferException - extends SSHRuntimeException { + extends SSHException { + public BufferException(String message) { super(message); } @@ -139,7 +140,8 @@ public class Buffer> { this.wpos = wpos; } - protected void ensureAvailable(int a) { + protected void ensureAvailable(int a) + throws BufferException { if (available() < a) throw new BufferException("Underflow"); } @@ -177,7 +179,8 @@ public class Buffer> { * * @return the {@code true} or {@code false} value read */ - public boolean readBoolean() { + public boolean readBoolean() + throws BufferException { return readByte() != 0; } @@ -197,7 +200,8 @@ public class Buffer> { * * @return the byte read */ - public byte readByte() { + public byte readByte() + throws BufferException { ensureAvailable(1); return data[rpos++]; } @@ -221,7 +225,8 @@ public class Buffer> { * * @return the byte-array read */ - public byte[] readBytes() { + public byte[] readBytes() + throws BufferException { int len = readUInt32AsInt(); if (len < 0 || len > 32768) throw new BufferException("Bad item length: " + len); @@ -254,11 +259,13 @@ public class Buffer> { return putUInt32(len - off).putRawBytes(b, off, len); } - public void readRawBytes(byte[] buf) { + public void readRawBytes(byte[] buf) + throws BufferException { readRawBytes(buf, 0, buf.length); } - public void readRawBytes(byte[] buf, int off, int len) { + public void readRawBytes(byte[] buf, int off, int len) + throws BufferException { ensureAvailable(len); System.arraycopy(data, rpos, buf, off, len); rpos += len; @@ -294,16 +301,18 @@ public class Buffer> { return (T) this; } - public int readUInt32AsInt() { + public int readUInt32AsInt() + throws BufferException { return (int) readUInt32(); } - public long readUInt32() { + public long readUInt32() + throws BufferException { ensureAvailable(4); return data[rpos++] << 24 & 0xff000000L | - data[rpos++] << 16 & 0x00ff0000L | - data[rpos++] << 8 & 0x0000ff00L | - data[rpos++] & 0x000000ffL; + data[rpos++] << 16 & 0x00ff0000L | + data[rpos++] << 8 & 0x0000ff00L | + data[rpos++] & 0x000000ffL; } /** @@ -317,7 +326,7 @@ public class Buffer> { public T putUInt32(long uint32) { ensureCapacity(4); if (uint32 < 0 || uint32 > 0xffffffffL) - throw new BufferException("Invalid value: " + uint32); + throw new RuntimeException("Invalid value: " + uint32); data[wpos++] = (byte) (uint32 >> 24); data[wpos++] = (byte) (uint32 >> 16); data[wpos++] = (byte) (uint32 >> 8); @@ -330,7 +339,8 @@ public class Buffer> { * * @return the MP integer as a {@code BigInteger} */ - public BigInteger readMPInt() { + public BigInteger readMPInt() + throws BufferException { return new BigInteger(readMPIntAsBytes()); } @@ -363,11 +373,13 @@ public class Buffer> { return putRawBytes(foo); } - public byte[] readMPIntAsBytes() { + public byte[] readMPIntAsBytes() + throws BufferException { return readBytes(); } - public long readUInt64() { + public long readUInt64() + throws BufferException { long uint64 = (readUInt32() << 32) + (readUInt32() & 0xffffffffL); if (uint64 < 0) throw new BufferException("Cannot handle values > Long.MAX_VALUE"); @@ -377,7 +389,7 @@ public class Buffer> { @SuppressWarnings("unchecked") public T putUInt64(long uint64) { if (uint64 < 0) - throw new BufferException("Invalid value: " + uint64); + throw new RuntimeException("Invalid value: " + uint64); data[wpos++] = (byte) (uint64 >> 56); data[wpos++] = (byte) (uint64 >> 48); data[wpos++] = (byte) (uint64 >> 40); @@ -394,7 +406,8 @@ public class Buffer> { * * @return the string as a Java {@code String} */ - public String readString() { + public String readString() + throws BufferException { int len = readUInt32AsInt(); if (len < 0 || len > 32768) throw new BufferException("Bad item length: " + len); @@ -414,7 +427,8 @@ public class Buffer> { * * @return the string as a byte-array */ - public byte[] readStringAsBytes() { + public byte[] readStringAsBytes() + throws BufferException { return readBytes(); } @@ -452,7 +466,8 @@ public class Buffer> { return (T) this; } - public PublicKey readPublicKey() { + public PublicKey readPublicKey() + throws BufferException { try { final String type = readString(); return KeyType.fromString(type).readPubKeyFromBuffer(type, this); diff --git a/src/main/java/net/schmizz/sshj/common/KeyType.java b/src/main/java/net/schmizz/sshj/common/KeyType.java index bd550345..4d955f11 100644 --- a/src/main/java/net/schmizz/sshj/common/KeyType.java +++ b/src/main/java/net/schmizz/sshj/common/KeyType.java @@ -36,8 +36,13 @@ public enum KeyType { @Override public PublicKey readPubKeyFromBuffer(String type, Buffer buf) throws GeneralSecurityException { - final BigInteger e = buf.readMPInt(); - final BigInteger n = buf.readMPInt(); + final BigInteger e, n; + try { + e = buf.readMPInt(); + n = buf.readMPInt(); + } catch (Buffer.BufferException be) { + throw new GeneralSecurityException(be); + } final KeyFactory keyFactory = SecurityUtils.getKeyFactory("RSA"); return keyFactory.generatePublic(new RSAPublicKeySpec(n, e)); } @@ -63,10 +68,15 @@ public enum KeyType { @Override public PublicKey readPubKeyFromBuffer(String type, Buffer buf) throws GeneralSecurityException { - final BigInteger p = buf.readMPInt(); - final BigInteger q = buf.readMPInt(); - final BigInteger g = buf.readMPInt(); - final BigInteger y = buf.readMPInt(); + BigInteger p, q, g, y; + try { + p = buf.readMPInt(); + q = buf.readMPInt(); + g = buf.readMPInt(); + y = buf.readMPInt(); + } catch (Buffer.BufferException be) { + throw new GeneralSecurityException(be); + } final KeyFactory keyFactory = SecurityUtils.getKeyFactory("DSA"); return keyFactory.generatePublic(new DSAPublicKeySpec(y, p, q, g)); } diff --git a/src/main/java/net/schmizz/sshj/common/SSHPacket.java b/src/main/java/net/schmizz/sshj/common/SSHPacket.java index 936a230a..f0374e42 100644 --- a/src/main/java/net/schmizz/sshj/common/SSHPacket.java +++ b/src/main/java/net/schmizz/sshj/common/SSHPacket.java @@ -75,7 +75,8 @@ public class SSHPacket * * @return the message identifier */ - public Message readMessageID() { + public Message readMessageID() + throws BufferException { return Message.fromByte(readByte()); } diff --git a/src/main/java/net/schmizz/sshj/connection/ConnectionImpl.java b/src/main/java/net/schmizz/sshj/connection/ConnectionImpl.java index 4971356f..9adf3e0b 100644 --- a/src/main/java/net/schmizz/sshj/connection/ConnectionImpl.java +++ b/src/main/java/net/schmizz/sshj/connection/ConnectionImpl.java @@ -15,16 +15,16 @@ */ package net.schmizz.sshj.connection; -import net.schmizz.concurrent.Promise; import net.schmizz.concurrent.ErrorDeliveryUtil; +import net.schmizz.concurrent.Promise; import net.schmizz.sshj.AbstractService; +import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.DisconnectReason; import net.schmizz.sshj.common.ErrorNotifiable; import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.connection.channel.Channel; -import net.schmizz.sshj.connection.channel.OpenFailException; import net.schmizz.sshj.connection.channel.OpenFailException.Reason; import net.schmizz.sshj.connection.channel.forwarded.ForwardedChannelOpener; import net.schmizz.sshj.transport.Transport; @@ -103,14 +103,18 @@ public class ConnectionImpl private Channel getChannel(SSHPacket buffer) throws ConnectionException { - int recipient = buffer.readUInt32AsInt(); - Channel channel = get(recipient); - if (channel != null) - return channel; - else { - buffer.rpos(buffer.rpos() - 5); - throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Received " + buffer.readMessageID() - + " on unknown channel #" + recipient); + try { + final int recipient = buffer.readUInt32AsInt(); + final Channel channel = get(recipient); + if (channel != null) + return channel; + else { + buffer.rpos(buffer.rpos() - 5); + throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, + "Received " + buffer.readMessageID() + " on unknown channel #" + recipient); + } + } catch (Buffer.BufferException be) { + throw new ConnectionException(be); } } @@ -180,12 +184,13 @@ public class ConnectionImpl @Override public Promise sendGlobalRequest(String name, boolean wantReply, - byte[] specifics) + byte[] specifics) throws TransportException { synchronized (globalReqPromises) { log.info("Making global request for `{}`", name); trans.write(new SSHPacket(Message.GLOBAL_REQUEST).putString(name) - .putBoolean(wantReply).putRawBytes(specifics)); + .putBoolean(wantReply) + .putRawBytes(specifics)); Promise promise = null; if (wantReply) { @@ -212,13 +217,17 @@ public class ConnectionImpl private void gotChannelOpen(SSHPacket buf) throws ConnectionException, TransportException { - final String type = buf.readString(); - log.debug("Received CHANNEL_OPEN for `{}` channel", type); - if (openers.containsKey(type)) - openers.get(type).handleOpen(buf); - else { - log.warn("No opener found for `{}` CHANNEL_OPEN request -- rejecting", type); - sendOpenFailure(buf.readUInt32AsInt(), OpenFailException.Reason.UNKNOWN_CHANNEL_TYPE, ""); + try { + final String type = buf.readString(); + log.debug("Received CHANNEL_OPEN for `{}` channel", type); + if (openers.containsKey(type)) + openers.get(type).handleOpen(buf); + else { + log.warn("No opener found for `{}` CHANNEL_OPEN request -- rejecting", type); + sendOpenFailure(buf.readUInt32AsInt(), Reason.UNKNOWN_CHANNEL_TYPE, ""); + } + } catch (Buffer.BufferException be) { + throw new ConnectionException(be); } } @@ -226,9 +235,9 @@ public class ConnectionImpl public void sendOpenFailure(int recipient, Reason reason, String message) throws TransportException { trans.write(new SSHPacket(Message.CHANNEL_OPEN_FAILURE) - .putUInt32(recipient) - .putUInt32(reason.getCode()) - .putString(message)); + .putUInt32(recipient) + .putUInt32(reason.getCode()) + .putString(message)); } @Override diff --git a/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java b/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java index bcb153ae..ba7c54a6 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/AbstractChannel.java @@ -183,11 +183,11 @@ public abstract class AbstractChannel break; case CHANNEL_EXTENDED_DATA: - gotExtendedData(buf.readUInt32AsInt(), buf); + gotExtendedData(buf); break; case CHANNEL_WINDOW_ADJUST: - gotWindowAdjustment(buf.readUInt32AsInt()); + gotWindowAdjustment(buf); break; case CHANNEL_REQUEST: @@ -301,13 +301,24 @@ public abstract class AbstractChannel private void gotChannelRequest(SSHPacket buf) throws ConnectionException, TransportException { - final String reqType = buf.readString(); - buf.readBoolean(); // We don't care about the 'want-reply' value + final String reqType; + try { + reqType = buf.readString(); + buf.readBoolean(); // We don't care about the 'want-reply' value + } catch (Buffer.BufferException be) { + throw new ConnectionException(be); + } log.info("Got chan request for `{}`", reqType); handleRequest(reqType, buf); } - private void gotWindowAdjustment(int howMuch) { + private void gotWindowAdjustment(SSHPacket buf) throws ConnectionException { + final int howMuch; + try { + howMuch = buf.readUInt32AsInt(); + } catch (Buffer.BufferException be) { + throw new ConnectionException(be); + } log.info("Received window adjustment for {} bytes", howMuch); rwin.expand(howMuch); } @@ -317,10 +328,10 @@ public abstract class AbstractChannel close.set(); } - protected void gotExtendedData(int dataTypeCode, SSHPacket buf) + protected void gotExtendedData(SSHPacket buf) throws ConnectionException, TransportException { - throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Extended data not supported on " + type - + " channel"); + throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, + "Extended data not supported on " + type + " channel"); } protected void gotUnknown(Message msg, SSHPacket buf) @@ -338,7 +349,12 @@ public abstract class AbstractChannel protected void receiveInto(ChannelInputStream stream, SSHPacket buf) throws ConnectionException, TransportException { - final int len = buf.readUInt32AsInt(); + final int len; + try { + len = buf.readUInt32AsInt(); + } catch (Buffer.BufferException be) { + throw new ConnectionException(be); + } if (len < 0 || len > getLocalMaxPacketSize() || len > buf.available()) throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, "Bad item length: " + len); if (log.isTraceEnabled()) diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/AbstractDirectChannel.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/AbstractDirectChannel.java index c3ada0ea..d70f4a99 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/AbstractDirectChannel.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/AbstractDirectChannel.java @@ -35,6 +35,7 @@ */ package net.schmizz.sshj.connection.channel.direct; +import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.connection.Connection; @@ -67,13 +68,23 @@ public abstract class AbstractDirectChannel open.await(conn.getTimeout(), TimeUnit.SECONDS); } - private void gotOpenConfirmation(SSHPacket buf) { - init(buf.readUInt32AsInt(), buf.readUInt32AsInt(), buf.readUInt32AsInt()); + private void gotOpenConfirmation(SSHPacket buf) + throws ConnectionException { + try { + init(buf.readUInt32AsInt(), buf.readUInt32AsInt(), buf.readUInt32AsInt()); + } catch (Buffer.BufferException be) { + throw new ConnectionException(be); + } open.set(); } - private void gotOpenFailure(SSHPacket buf) { - open.deliverError(new OpenFailException(getType(), buf.readUInt32AsInt(), buf.readString())); + private void gotOpenFailure(SSHPacket buf) + throws ConnectionException { + try { + open.deliverError(new OpenFailException(getType(), buf.readUInt32AsInt(), buf.readString())); + } catch (Buffer.BufferException be) { + throw new ConnectionException(be); + } finishOff(); } diff --git a/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java b/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java index 88bcae12..b7b7a44d 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/direct/SessionChannel.java @@ -36,6 +36,7 @@ package net.schmizz.sshj.connection.channel.direct; import net.schmizz.sshj.common.Buffer; +import net.schmizz.sshj.common.DisconnectReason; import net.schmizz.sshj.common.IOUtils; import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.SSHPacket; @@ -147,17 +148,21 @@ public class SessionChannel @Override public void handleRequest(String req, SSHPacket buf) throws ConnectionException, TransportException { - if ("xon-xoff".equals(req)) - canDoFlowControl = buf.readBoolean(); - else if ("exit-status".equals(req)) - exitStatus = buf.readUInt32AsInt(); - else if ("exit-signal".equals(req)) { - exitSignal = Signal.fromString(buf.readString()); - wasCoreDumped = buf.readBoolean(); // core dumped - exitErrMsg = buf.readString(); - sendClose(); - } else - super.handleRequest(req, buf); + try { + if ("xon-xoff".equals(req)) + canDoFlowControl = buf.readBoolean(); + else if ("exit-status".equals(req)) + exitStatus = buf.readUInt32AsInt(); + else if ("exit-signal".equals(req)) { + exitSignal = Signal.fromString(buf.readString()); + wasCoreDumped = buf.readBoolean(); // core dumped + exitErrMsg = buf.readString(); + sendClose(); + } else + super.handleRequest(req, buf); + } catch (Buffer.BufferException be) { + throw new ConnectionException(be); + } } @Override @@ -225,12 +230,18 @@ public class SessionChannel } @Override - protected void gotExtendedData(int dataTypeCode, SSHPacket buf) + protected void gotExtendedData(SSHPacket buf) throws ConnectionException, TransportException { - if (dataTypeCode == 1) - receiveInto(err, buf); - else - super.gotExtendedData(dataTypeCode, buf); + try { + final int dataTypeCode = buf.readUInt32AsInt(); + if (dataTypeCode == 1) + receiveInto(err, buf); + else + throw new ConnectionException(DisconnectReason.PROTOCOL_ERROR, + "Bad extended data type = " + dataTypeCode); + } catch (Buffer.BufferException be) { + throw new ConnectionException(be); + } } @Override @@ -246,13 +257,15 @@ public class SessionChannel @Override @Deprecated - public String getOutputAsString() throws IOException { + public String getOutputAsString() + throws IOException { return IOUtils.readFully(getInputStream()).toString(); } @Override @Deprecated - public String getErrorAsString() throws IOException { + public String getErrorAsString() + throws IOException { return IOUtils.readFully(getErrorStream()).toString(); } diff --git a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java index 8160f3c5..a9dac9de 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/RemotePortForwarder.java @@ -168,7 +168,11 @@ public class RemotePortForwarder throws ConnectionException, TransportException { SSHPacket reply = req(PF_REQ, forward); if (forward.port == 0) - forward.port = reply.readUInt32AsInt(); + try { + forward.port = reply.readUInt32AsInt(); + } catch (Buffer.BufferException e) { + throw new ConnectionException(e); + } log.info("Remote end listening on {}", forward); listeners.put(forward, listener); return forward; @@ -211,9 +215,14 @@ public class RemotePortForwarder @Override public void handleOpen(SSHPacket buf) throws ConnectionException, TransportException { - final ForwardedTCPIPChannel chan = new ForwardedTCPIPChannel(conn, buf.readUInt32AsInt(), buf.readUInt32AsInt(), buf.readUInt32AsInt(), - new Forward(buf.readString(), buf.readUInt32AsInt()), - buf.readString(), buf.readUInt32AsInt()); + final ForwardedTCPIPChannel chan; + try { + chan = new ForwardedTCPIPChannel(conn, buf.readUInt32AsInt(), buf.readUInt32AsInt(), buf.readUInt32AsInt(), + new Forward(buf.readString(), buf.readUInt32AsInt()), + buf.readString(), buf.readUInt32AsInt()); + } catch (Buffer.BufferException be) { + throw new ConnectionException(be); + } if (listeners.containsKey(chan.getParentForward())) callListener(listeners.get(chan.getParentForward()), chan); else diff --git a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/X11Forwarder.java b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/X11Forwarder.java index 935ee88f..f6eef032 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/forwarded/X11Forwarder.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/forwarded/X11Forwarder.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj.connection.channel.forwarded; +import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.connection.Connection; import net.schmizz.sshj.connection.ConnectionException; @@ -55,10 +56,14 @@ public class X11Forwarder @Override public void handleOpen(SSHPacket buf) throws ConnectionException, TransportException { - callListener(listener, new X11Channel(conn, - buf.readUInt32AsInt(), - buf.readUInt32AsInt(), buf.readUInt32AsInt(), - buf.readString(), buf.readUInt32AsInt())); + try { + callListener(listener, new X11Channel(conn, + buf.readUInt32AsInt(), + buf.readUInt32AsInt(), buf.readUInt32AsInt(), + buf.readString(), buf.readUInt32AsInt())); + } catch (Buffer.BufferException be) { + throw new ConnectionException(be); + } } /** Stop handling {@code x11} channel open requests. De-registers itself with connection layer. */ diff --git a/src/main/java/net/schmizz/sshj/sftp/Response.java b/src/main/java/net/schmizz/sshj/sftp/Response.java index d666ad9c..bdd626aa 100644 --- a/src/main/java/net/schmizz/sshj/sftp/Response.java +++ b/src/main/java/net/schmizz/sshj/sftp/Response.java @@ -51,11 +51,15 @@ public class Response private final PacketType type; private final long reqID; - public Response(Buffer pk, int protocolVersion) { + public Response(Buffer pk, int protocolVersion) throws SFTPException { super(pk); this.protocolVersion = protocolVersion; this.type = readType(); - this.reqID = readUInt32(); + try { + this.reqID = readUInt32(); + } catch (BufferException be) { + throw new SFTPException(be); + } } public int getProtocolVersion() { @@ -70,8 +74,12 @@ public class Response return type; } - public StatusCode readStatusCode() { - return StatusCode.fromInt(readUInt32AsInt()); + public StatusCode readStatusCode() throws SFTPException { + try { + return StatusCode.fromInt(readUInt32AsInt()); + } catch (BufferException be) { + throw new SFTPException(be); + } } public Response ensurePacketTypeIs(PacketType pt) @@ -99,7 +107,11 @@ public class Response protected String error(StatusCode sc) throws SFTPException { - throw new SFTPException(sc, protocolVersion < 3 ? sc.toString() : readString()); + try { + throw new SFTPException(sc, protocolVersion < 3 ? sc.toString() : readString()); + } catch (BufferException be) { + throw new SFTPException(be); + } } } diff --git a/src/main/java/net/schmizz/sshj/sftp/SFTPPacket.java b/src/main/java/net/schmizz/sshj/sftp/SFTPPacket.java index 7e90f99d..a299dc8b 100644 --- a/src/main/java/net/schmizz/sshj/sftp/SFTPPacket.java +++ b/src/main/java/net/schmizz/sshj/sftp/SFTPPacket.java @@ -33,27 +33,37 @@ public class SFTPPacket> putByte(pt.toByte()); } - public FileAttributes readFileAttributes() { + public FileAttributes readFileAttributes() + throws SFTPException { final FileAttributes.Builder builder = new FileAttributes.Builder(); - final int mask = readUInt32AsInt(); - if (FileAttributes.Flag.SIZE.isSet(mask)) - builder.withSize(readUInt64()); - if (FileAttributes.Flag.UIDGID.isSet(mask)) - builder.withUIDGID(readUInt32AsInt(), readUInt32AsInt()); - if (FileAttributes.Flag.MODE.isSet(mask)) - builder.withPermissions(readUInt32AsInt()); - if (FileAttributes.Flag.ACMODTIME.isSet(mask)) - builder.withAtimeMtime(readUInt32AsInt(), readUInt32AsInt()); - if (FileAttributes.Flag.EXTENDED.isSet(mask)) { - final int extCount = readUInt32AsInt(); - for (int i = 0; i < extCount; i++) - builder.withExtended(readString(), readString()); + try { + final int mask = readUInt32AsInt(); + if (FileAttributes.Flag.SIZE.isSet(mask)) + builder.withSize(readUInt64()); + if (FileAttributes.Flag.UIDGID.isSet(mask)) + builder.withUIDGID(readUInt32AsInt(), readUInt32AsInt()); + if (FileAttributes.Flag.MODE.isSet(mask)) + builder.withPermissions(readUInt32AsInt()); + if (FileAttributes.Flag.ACMODTIME.isSet(mask)) + builder.withAtimeMtime(readUInt32AsInt(), readUInt32AsInt()); + if (FileAttributes.Flag.EXTENDED.isSet(mask)) { + final int extCount = readUInt32AsInt(); + for (int i = 0; i < extCount; i++) + builder.withExtended(readString(), readString()); + } + } catch (BufferException be) { + throw new SFTPException(be); } return builder.build(); } - public PacketType readType() { - return PacketType.fromByte(readByte()); + public PacketType readType() + throws SFTPException { + try { + return PacketType.fromByte(readByte()); + } catch (BufferException be) { + throw new SFTPException(be); + } } public T putFileAttributes(FileAttributes fa) { diff --git a/src/main/java/net/schmizz/sshj/transport/Decoder.java b/src/main/java/net/schmizz/sshj/transport/Decoder.java index 0cd3fdc3..074583a1 100644 --- a/src/main/java/net/schmizz/sshj/transport/Decoder.java +++ b/src/main/java/net/schmizz/sshj/transport/Decoder.java @@ -35,6 +35,7 @@ */ package net.schmizz.sshj.transport; +import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.ByteArrayUtils; import net.schmizz.sshj.common.DisconnectReason; import net.schmizz.sshj.common.SSHException; @@ -157,7 +158,12 @@ final class Decoder throws TransportException { cipher.update(inputBuffer.array(), 0, cipherSize); - final int len = inputBuffer.readUInt32AsInt(); // Read packet length + final int len; // Read packet length + try { + len = inputBuffer.readUInt32AsInt(); + } catch (Buffer.BufferException be) { + throw new TransportException(be); + } if (isInvalidPacketLength(len)) { // Check packet length validity log.info("Error decoding packet (invalid length) {}", inputBuffer.printHex()); diff --git a/src/main/java/net/schmizz/sshj/transport/Proposal.java b/src/main/java/net/schmizz/sshj/transport/Proposal.java index 7d96791a..dd33c2b3 100644 --- a/src/main/java/net/schmizz/sshj/transport/Proposal.java +++ b/src/main/java/net/schmizz/sshj/transport/Proposal.java @@ -36,6 +36,7 @@ package net.schmizz.sshj.transport; import net.schmizz.sshj.Config; +import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.Factory; import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.SSHPacket; @@ -85,18 +86,22 @@ class Proposal { packet.putUInt32(0); // "Reserved" for future by spec } - public Proposal(SSHPacket packet) { + public Proposal(SSHPacket packet) throws TransportException { this.packet = packet; final int savedPos = packet.rpos(); packet.rpos(packet.rpos() + 17); // Skip message ID & cookie - kex = fromCommaString(packet.readString()); - sig = fromCommaString(packet.readString()); - c2sCipher = fromCommaString(packet.readString()); - s2cCipher = fromCommaString(packet.readString()); - c2sMAC = fromCommaString(packet.readString()); - s2cMAC = fromCommaString(packet.readString()); - c2sComp = fromCommaString(packet.readString()); - s2cComp = fromCommaString(packet.readString()); + try { + kex = fromCommaString(packet.readString()); + sig = fromCommaString(packet.readString()); + c2sCipher = fromCommaString(packet.readString()); + s2cCipher = fromCommaString(packet.readString()); + c2sMAC = fromCommaString(packet.readString()); + s2cMAC = fromCommaString(packet.readString()); + c2sComp = fromCommaString(packet.readString()); + s2cComp = fromCommaString(packet.readString()); + } catch (Buffer.BufferException be) { + throw new TransportException(be); + } packet.rpos(savedPos); } diff --git a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java index 8524c16b..67d2f0be 100644 --- a/src/main/java/net/schmizz/sshj/transport/TransportImpl.java +++ b/src/main/java/net/schmizz/sshj/transport/TransportImpl.java @@ -501,18 +501,26 @@ public final class TransportImpl } } - private void gotDebug(SSHPacket buf) { - boolean display = buf.readBoolean(); - String message = buf.readString(); - log.info("Received SSH_MSG_DEBUG (display={}) '{}'", display, message); + private void gotDebug(SSHPacket buf) throws TransportException { + try { + final boolean display = buf.readBoolean(); + final String message = buf.readString(); + log.info("Received SSH_MSG_DEBUG (display={}) '{}'", display, message); + } catch (Buffer.BufferException be) { + throw new TransportException(be); + } } private void gotDisconnect(SSHPacket buf) throws TransportException { - DisconnectReason code = DisconnectReason.fromInt(buf.readUInt32AsInt()); - String message = buf.readString(); - log.info("Received SSH_MSG_DISCONNECT (reason={}, msg={})", code, message); - throw new TransportException(code, "Disconnected; server said: " + message); + try { + final DisconnectReason code = DisconnectReason.fromInt(buf.readUInt32AsInt()); + final String message = buf.readString(); + log.info("Received SSH_MSG_DISCONNECT (reason={}, msg={})", code, message); + throw new TransportException(code, "Disconnected; server said: " + message); + } catch (Buffer.BufferException be) { + throw new TransportException(be); + } } private void gotServiceAccept() diff --git a/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java b/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java index b0224f72..a87e2498 100644 --- a/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java +++ b/src/main/java/net/schmizz/sshj/transport/kex/AbstractDHG.java @@ -121,14 +121,21 @@ public abstract class AbstractDHG throw new TransportException(DisconnectReason.KEY_EXCHANGE_FAILED, "Unexpected packet: " + msg); log.info("Received SSH_MSG_KEXDH_REPLY"); - final byte[] K_S = packet.readBytes(); - final byte[] f = packet.readMPIntAsBytes(); - final byte[] sig = packet.readBytes(); // signature sent by server + final byte[] K_S; + final byte[] f; + final byte[] sig; // signature sent by server + try { + K_S = packet.readBytes(); + f = packet.readMPIntAsBytes(); + sig = packet.readBytes(); + hostKey = new Buffer.PlainBuffer(K_S).readPublicKey(); + } catch (Buffer.BufferException be) { + throw new TransportException(be); + } + dh.setF(new BigInteger(f)); K = dh.getK(); - hostKey = new Buffer.PlainBuffer(K_S).readPublicKey(); - final Buffer.PlainBuffer buf = new Buffer.PlainBuffer() .putString(V_C) .putString(V_S) diff --git a/src/main/java/net/schmizz/sshj/userauth/method/AuthKeyboardInteractive.java b/src/main/java/net/schmizz/sshj/userauth/method/AuthKeyboardInteractive.java index 512067cd..4a0e5b74 100644 --- a/src/main/java/net/schmizz/sshj/userauth/method/AuthKeyboardInteractive.java +++ b/src/main/java/net/schmizz/sshj/userauth/method/AuthKeyboardInteractive.java @@ -15,6 +15,7 @@ */ package net.schmizz.sshj.userauth.method; +import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.transport.TransportException; @@ -63,15 +64,20 @@ public class AuthKeyboardInteractive if (cmd != Message.USERAUTH_60) { super.handle(cmd, buf); } else { - provider.init(makeAccountResource(), buf.readString(), buf.readString()); - buf.readString(); // lang-tag - final int numPrompts = buf.readUInt32AsInt(); - final CharArrWrap[] userReplies = new CharArrWrap[numPrompts]; - for (int i = 0; i < numPrompts; i++) { - final String prompt = buf.readString(); - final boolean echo = buf.readBoolean(); - log.info("Requesting response for challenge `{}`; echo={}", prompt, echo); - userReplies[i] = new CharArrWrap(provider.getResponse(prompt, echo)); + final CharArrWrap[] userReplies; + try { + provider.init(makeAccountResource(), buf.readString(), buf.readString()); + buf.readString(); // lang-tag + final int numPrompts = buf.readUInt32AsInt(); + userReplies = new CharArrWrap[numPrompts]; + for (int i = 0; i < numPrompts; i++) { + final String prompt = buf.readString(); + final boolean echo = buf.readBoolean(); + log.info("Requesting response for challenge `{}`; echo={}", prompt, echo); + userReplies[i] = new CharArrWrap(provider.getResponse(prompt, echo)); + } + } catch (Buffer.BufferException be) { + throw new UserAuthException(be); } respond(userReplies); } diff --git a/src/test/java/net/schmizz/sshj/transport/Disconnection.java b/src/test/java/net/schmizz/sshj/transport/Disconnection.java index cdb44ea5..1b5586e5 100644 --- a/src/test/java/net/schmizz/sshj/transport/Disconnection.java +++ b/src/test/java/net/schmizz/sshj/transport/Disconnection.java @@ -23,7 +23,6 @@ import org.junit.Test; import java.io.IOException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; diff --git a/src/test/java/net/schmizz/sshj/util/BufferTest.java b/src/test/java/net/schmizz/sshj/util/BufferTest.java index ba30539e..bb12b83a 100644 --- a/src/test/java/net/schmizz/sshj/util/BufferTest.java +++ b/src/test/java/net/schmizz/sshj/util/BufferTest.java @@ -41,7 +41,8 @@ public class BufferTest { } @Test - public void testDataTypes() { + public void testDataTypes() + throws Buffer.BufferException { // bool assertEquals(handyBuf.putBoolean(true).readBoolean(), true); @@ -63,7 +64,8 @@ public class BufferTest { } @Test - public void testPassword() { + public void testPassword() + throws Buffer.BufferException { char[] pass = "lolcatz".toCharArray(); // test if put correctly as a string assertEquals(new Buffer.PlainBuffer().putSensitiveString(pass).readString(), "lolcatz"); @@ -73,7 +75,7 @@ public class BufferTest { @Test public void testPosition() - throws UnsupportedEncodingException { + throws UnsupportedEncodingException, Buffer.BufferException { assertEquals(5, posBuf.wpos()); assertEquals(0, posBuf.rpos()); assertEquals(5, posBuf.available()); @@ -95,7 +97,8 @@ public class BufferTest { } @Test(expected = Buffer.BufferException.class) - public void testUnderflow() { + public void testUnderflow() + throws Buffer.BufferException { // exhaust the buffer for (int i = 0; i < 5; ++i) posBuf.readByte();