diff --git a/src/main/java/net/schmizz/sshj/connection/channel/ChannelOutputStream.java b/src/main/java/net/schmizz/sshj/connection/channel/ChannelOutputStream.java index d79d2c12..be6f7e44 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/ChannelOutputStream.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/ChannelOutputStream.java @@ -35,12 +35,14 @@ */ package net.schmizz.sshj.connection.channel; +import net.schmizz.sshj.common.Buffer; import net.schmizz.sshj.common.ErrorNotifiable; import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.transport.Transport; +import net.schmizz.sshj.transport.TransportException; import java.io.IOException; import java.io.OutputStream; @@ -56,26 +58,92 @@ public final class ChannelOutputStream private final Channel chan; private final Transport trans; private final Window.Remote win; - private final SSHPacket buffer = new SSHPacket(); + + private final DataBuffer buffer = new DataBuffer(); private final byte[] b = new byte[1]; - private int bufferLength; + private boolean closed; private SSHException error; + private final class DataBuffer { + + private final int headerOffset; + private final int dataOffset; + + private final SSHPacket packet = new SSHPacket(Message.CHANNEL_DATA); + private final Buffer.PlainBuffer leftOvers = new Buffer.PlainBuffer(); + + DataBuffer() { + headerOffset = packet.rpos(); + packet.putUInt32(0); // recipient + packet.putUInt32(0); // data length + dataOffset = packet.wpos(); + } + + int write(byte[] data, int off, int len) + throws TransportException, ConnectionException { + final int bufferSize = packet.wpos() - dataOffset; + if (bufferSize >= win.getMaxPacketSize()) { + flush(bufferSize); + return 0; + } else { + final int n = Math.min(len - off, win.getMaxPacketSize() - bufferSize); + packet.putRawBytes(data, off, n); + return n; + } + } + + void flush() + throws TransportException, ConnectionException { + flush(packet.wpos() - dataOffset); + } + + void flush(int bufferSize) + throws TransportException, ConnectionException { + while (bufferSize > 0) { + + int remoteWindowSize = win.getSize(); + if (remoteWindowSize == 0) + remoteWindowSize = win.awaitExpansion(remoteWindowSize); + + // We can only write the min. of + // a) how much data we have + // b) the max packet size + // c) what the current window size will allow + final int writeNow = Math.min(bufferSize, Math.min(win.getMaxPacketSize(), remoteWindowSize)); + + packet.wpos(headerOffset); + packet.putMessageID(Message.CHANNEL_DATA); + packet.putUInt32(chan.getRecipient()); + packet.putUInt32(writeNow); + packet.wpos(dataOffset + writeNow); + + final int leftOverBytes = bufferSize - writeNow; + if (leftOverBytes > 0) { + leftOvers.putRawBytes(packet.array(), packet.wpos(), leftOverBytes); + } + + trans.write(packet); + win.consume(writeNow); + + packet.rpos(headerOffset); + packet.wpos(dataOffset); + + if (leftOverBytes > 0) { + packet.putBuffer(leftOvers); + leftOvers.clear(); + } + + bufferSize = leftOverBytes; + } + } + + } + public ChannelOutputStream(Channel chan, Transport trans, Window.Remote win) { this.chan = chan; this.trans = trans; this.win = win; - prepBuffer(); - } - - private void prepBuffer() { - bufferLength = 0; - buffer.rpos(5); - buffer.wpos(5); - buffer.putMessageID(Message.CHANNEL_DATA); - buffer.putUInt32(0); // meant to be recipient - buffer.putUInt32(0); // meant to be data length } @Override @@ -86,19 +154,13 @@ public final class ChannelOutputStream } @Override - public synchronized void write(byte[] data, int off, int len) + public synchronized void write(final byte[] data, int off, int len) throws IOException { checkClose(); while (len > 0) { - final int x = Math.min(len, win.getMaxPacketSize() - bufferLength); - if (x <= 0) { - flush(); - continue; - } - buffer.putRawBytes(data, off, x); - bufferLength += x; - off += x; - len -= x; + final int n = buffer.write(data, off, len); + off += n; + len -= n; } } @@ -107,55 +169,44 @@ public final class ChannelOutputStream this.error = error; } - private synchronized void checkClose() + private void checkClose() throws SSHException { - if (closed) + if (closed) { if (error != null) throw error; else throw new ConnectionException("Stream closed"); + } } @Override public synchronized void close() throws IOException { - if (!closed) + if (!closed) { try { - flush(); + buffer.flush(); chan.sendEOF(); } finally { setClosed(); } + } } public synchronized void setClosed() { closed = true; } + /** + * Send all data currently buffered. If window space is exhausted in the process, this will block + * until it is expanded by the server. + * + * @throws IOException + */ @Override public synchronized void flush() throws IOException { checkClose(); - - if (bufferLength <= 0) // No data to send - return; - - putRecipientAndLength(); - - try { - win.waitAndConsume(bufferLength); - trans.write(buffer); - } finally { - prepBuffer(); - } - } - - private void putRecipientAndLength() { - final int origPos = buffer.wpos(); - buffer.wpos(6); - buffer.putUInt32(chan.getRecipient()); - buffer.putUInt32(bufferLength); - buffer.wpos(origPos); + buffer.flush(); } @Override diff --git a/src/main/java/net/schmizz/sshj/connection/channel/Window.java b/src/main/java/net/schmizz/sshj/connection/channel/Window.java index a967746f..376ee993 100644 --- a/src/main/java/net/schmizz/sshj/connection/channel/Window.java +++ b/src/main/java/net/schmizz/sshj/connection/channel/Window.java @@ -48,7 +48,9 @@ public abstract class Window { } public int getSize() { - return size; + synchronized (lock) { + return size; + } } public void consume(int dec) @@ -74,18 +76,18 @@ public abstract class Window { super(initialWinSize, maxPacketSize); } - public void waitAndConsume(int howMuch) + public int awaitExpansion(int was) throws ConnectionException { synchronized (lock) { - while (size < howMuch) { - log.debug("Waiting, need window space for {} bytes", howMuch); + while (size <= was) { + log.debug("Waiting, need size to grow from {} bytes", was); try { lock.wait(); } catch (InterruptedException ie) { throw new ConnectionException(ie); } } - consume(howMuch); + return size; } }