diff --git a/src/main/java/net/schmizz/sshj/common/Buffer.java b/src/main/java/net/schmizz/sshj/common/Buffer.java index 9e4e3bd2..22ce5a76 100644 --- a/src/main/java/net/schmizz/sshj/common/Buffer.java +++ b/src/main/java/net/schmizz/sshj/common/Buffer.java @@ -74,10 +74,15 @@ public class Buffer> { /** The default size for a {@code Buffer} (256 bytes) */ public static final int DEFAULT_SIZE = 256; + /** The maximum valid size of buffer (i.e. biggest power of two that can be represented as an int - 2^30) */ + public static final int MAX_SIZE = (1 << 30); + protected static int getNextPowerOf2(int i) { int j = 1; - while (j < i) + while (j < i) { j <<= 1; + if (j <= 0) throw new IllegalArgumentException("Cannot get next power of 2; "+i+" is too large"); + } return j; } diff --git a/src/main/java/net/schmizz/sshj/sftp/PacketReader.java b/src/main/java/net/schmizz/sshj/sftp/PacketReader.java index 282000fc..e48e6868 100644 --- a/src/main/java/net/schmizz/sshj/sftp/PacketReader.java +++ b/src/main/java/net/schmizz/sshj/sftp/PacketReader.java @@ -15,15 +15,16 @@ */ package net.schmizz.sshj.sftp; -import net.schmizz.concurrent.Promise; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.IOException; import java.io.InputStream; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import net.schmizz.concurrent.Promise; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + public class PacketReader extends Thread { @@ -65,7 +66,10 @@ public class PacketReader public SFTPPacket readPacket() throws IOException { int len = getPacketLength(); - + if (len > SFTPPacket.MAX_SIZE) { + throw new IllegalStateException("Invalid packet: indicated length "+len+" too large"); + } + packet.rpos(0); packet.wpos(0); diff --git a/src/test/java/net/schmizz/sshj/common/BufferTest.java b/src/test/java/net/schmizz/sshj/common/BufferTest.java new file mode 100644 index 00000000..6fdc918a --- /dev/null +++ b/src/test/java/net/schmizz/sshj/common/BufferTest.java @@ -0,0 +1,33 @@ +package net.schmizz.sshj.common; + +import static org.junit.Assert.fail; + +import net.schmizz.sshj.common.Buffer.PlainBuffer; + +import org.junit.Test; + +public class BufferTest { + + // Issue 72: previously, it entered an infinite loop trying to establish the buffer size + @Test + public void shouldThrowOnTooLargeCapacity() { + PlainBuffer buffer = new PlainBuffer(); + try { + buffer.ensureCapacity(Integer.MAX_VALUE); + fail("Allegedly ensured buffer capacity of size " + Integer.MAX_VALUE); + } catch (IllegalArgumentException e) { + // success + } + } + + // Issue 72: previously, it entered an infinite loop trying to establish the buffer size + @Test + public void shouldThrowOnTooLargeInitialCapacity() { + try { + new PlainBuffer(Integer.MAX_VALUE); + fail("Allegedly created buffer with size " + Integer.MAX_VALUE); + } catch (IllegalArgumentException e) { + // success + } + } +} diff --git a/src/test/java/net/schmizz/sshj/sftp/PacketReaderTest.java b/src/test/java/net/schmizz/sshj/sftp/PacketReaderTest.java new file mode 100644 index 00000000..b638c795 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/sftp/PacketReaderTest.java @@ -0,0 +1,71 @@ +package net.schmizz.sshj.sftp; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.io.DataOutputStream; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.util.Arrays; + +import net.schmizz.sshj.connection.channel.direct.Session.Subsystem; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +public class PacketReaderTest { + + private DataOutputStream dataout; + private PacketReader reader; + private SFTPEngine engine; + private Subsystem subsystem; + + @Before + public void setUp() throws Exception { + PipedOutputStream pipedout = new PipedOutputStream(); + PipedInputStream pipedin = new PipedInputStream(pipedout); + dataout = new DataOutputStream(pipedout); + + engine = Mockito.mock(SFTPEngine.class); + subsystem = Mockito.mock(Subsystem.class); + Mockito.when(engine.getSubsystem()).thenReturn(subsystem); + Mockito.when(subsystem.getInputStream()).thenReturn(pipedin); + + reader = new PacketReader(engine); + } + + // FIXME What is the byte format for the size? Big endian? Little endian? + @Test + public void shouldReadPacket() throws Exception { + byte[] bytes = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; + dataout.writeInt(10); + dataout.write(bytes); + dataout.flush(); + + SFTPPacket packet = reader.readPacket(); + assertEquals(packet.available(), 10); + assertTrue("actual=" + Arrays.toString(packet.array()), Arrays.equals(bytes, subArray(packet.array(), 0, 10))); + } + + @Test + public void shouldFailWhenPacketLengthTooLarge() throws Exception { + dataout.writeInt(Integer.MAX_VALUE); + dataout.flush(); + + try { + reader.readPacket(); + fail("Should have failed to read packet of size " + Integer.MAX_VALUE); + } catch (IllegalStateException e) { + e.printStackTrace(); + // success; indicated packet size was too large + } + } + + private byte[] subArray(byte[] source, int startIndex, int length) { + byte[] result = new byte[length]; + System.arraycopy(source, startIndex, result, 0, length); + return result; + } +}