fix for #47 - should send data down rather than sitting around waiting for an adjustment if there is window space available

This commit is contained in:
Shikhar Bhushan
2011-12-20 10:41:49 +00:00
parent 7a77f85ced
commit 22a5ffe735
2 changed files with 104 additions and 51 deletions

View File

@@ -35,12 +35,14 @@
*/ */
package net.schmizz.sshj.connection.channel; package net.schmizz.sshj.connection.channel;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.ErrorNotifiable; import net.schmizz.sshj.common.ErrorNotifiable;
import net.schmizz.sshj.common.Message; import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHException; import net.schmizz.sshj.common.SSHException;
import net.schmizz.sshj.common.SSHPacket; import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.transport.Transport; import net.schmizz.sshj.transport.Transport;
import net.schmizz.sshj.transport.TransportException;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
@@ -56,26 +58,92 @@ public final class ChannelOutputStream
private final Channel chan; private final Channel chan;
private final Transport trans; private final Transport trans;
private final Window.Remote win; private final Window.Remote win;
private final SSHPacket buffer = new SSHPacket();
private final DataBuffer buffer = new DataBuffer();
private final byte[] b = new byte[1]; private final byte[] b = new byte[1];
private int bufferLength;
private boolean closed; private boolean closed;
private SSHException error; private SSHException error;
private final class DataBuffer {
private final int headerOffset;
private final int dataOffset;
private final SSHPacket packet = new SSHPacket(Message.CHANNEL_DATA);
private final Buffer.PlainBuffer leftOvers = new Buffer.PlainBuffer();
DataBuffer() {
headerOffset = packet.rpos();
packet.putUInt32(0); // recipient
packet.putUInt32(0); // data length
dataOffset = packet.wpos();
}
int write(byte[] data, int off, int len)
throws TransportException, ConnectionException {
final int bufferSize = packet.wpos() - dataOffset;
if (bufferSize >= win.getMaxPacketSize()) {
flush(bufferSize);
return 0;
} else {
final int n = Math.min(len - off, win.getMaxPacketSize() - bufferSize);
packet.putRawBytes(data, off, n);
return n;
}
}
void flush()
throws TransportException, ConnectionException {
flush(packet.wpos() - dataOffset);
}
void flush(int bufferSize)
throws TransportException, ConnectionException {
while (bufferSize > 0) {
int remoteWindowSize = win.getSize();
if (remoteWindowSize == 0)
remoteWindowSize = win.awaitExpansion(remoteWindowSize);
// We can only write the min. of
// a) how much data we have
// b) the max packet size
// c) what the current window size will allow
final int writeNow = Math.min(bufferSize, Math.min(win.getMaxPacketSize(), remoteWindowSize));
packet.wpos(headerOffset);
packet.putMessageID(Message.CHANNEL_DATA);
packet.putUInt32(chan.getRecipient());
packet.putUInt32(writeNow);
packet.wpos(dataOffset + writeNow);
final int leftOverBytes = bufferSize - writeNow;
if (leftOverBytes > 0) {
leftOvers.putRawBytes(packet.array(), packet.wpos(), leftOverBytes);
}
trans.write(packet);
win.consume(writeNow);
packet.rpos(headerOffset);
packet.wpos(dataOffset);
if (leftOverBytes > 0) {
packet.putBuffer(leftOvers);
leftOvers.clear();
}
bufferSize = leftOverBytes;
}
}
}
public ChannelOutputStream(Channel chan, Transport trans, Window.Remote win) { public ChannelOutputStream(Channel chan, Transport trans, Window.Remote win) {
this.chan = chan; this.chan = chan;
this.trans = trans; this.trans = trans;
this.win = win; this.win = win;
prepBuffer();
}
private void prepBuffer() {
bufferLength = 0;
buffer.rpos(5);
buffer.wpos(5);
buffer.putMessageID(Message.CHANNEL_DATA);
buffer.putUInt32(0); // meant to be recipient
buffer.putUInt32(0); // meant to be data length
} }
@Override @Override
@@ -86,19 +154,13 @@ public final class ChannelOutputStream
} }
@Override @Override
public synchronized void write(byte[] data, int off, int len) public synchronized void write(final byte[] data, int off, int len)
throws IOException { throws IOException {
checkClose(); checkClose();
while (len > 0) { while (len > 0) {
final int x = Math.min(len, win.getMaxPacketSize() - bufferLength); final int n = buffer.write(data, off, len);
if (x <= 0) { off += n;
flush(); len -= n;
continue;
}
buffer.putRawBytes(data, off, x);
bufferLength += x;
off += x;
len -= x;
} }
} }
@@ -107,55 +169,44 @@ public final class ChannelOutputStream
this.error = error; this.error = error;
} }
private synchronized void checkClose() private void checkClose()
throws SSHException { throws SSHException {
if (closed) if (closed) {
if (error != null) if (error != null)
throw error; throw error;
else else
throw new ConnectionException("Stream closed"); throw new ConnectionException("Stream closed");
}
} }
@Override @Override
public synchronized void close() public synchronized void close()
throws IOException { throws IOException {
if (!closed) if (!closed) {
try { try {
flush(); buffer.flush();
chan.sendEOF(); chan.sendEOF();
} finally { } finally {
setClosed(); setClosed();
} }
}
} }
public synchronized void setClosed() { public synchronized void setClosed() {
closed = true; closed = true;
} }
/**
* Send all data currently buffered. If window space is exhausted in the process, this will block
* until it is expanded by the server.
*
* @throws IOException
*/
@Override @Override
public synchronized void flush() public synchronized void flush()
throws IOException { throws IOException {
checkClose(); checkClose();
buffer.flush();
if (bufferLength <= 0) // No data to send
return;
putRecipientAndLength();
try {
win.waitAndConsume(bufferLength);
trans.write(buffer);
} finally {
prepBuffer();
}
}
private void putRecipientAndLength() {
final int origPos = buffer.wpos();
buffer.wpos(6);
buffer.putUInt32(chan.getRecipient());
buffer.putUInt32(bufferLength);
buffer.wpos(origPos);
} }
@Override @Override

View File

@@ -48,7 +48,9 @@ public abstract class Window {
} }
public int getSize() { public int getSize() {
return size; synchronized (lock) {
return size;
}
} }
public void consume(int dec) public void consume(int dec)
@@ -74,18 +76,18 @@ public abstract class Window {
super(initialWinSize, maxPacketSize); super(initialWinSize, maxPacketSize);
} }
public void waitAndConsume(int howMuch) public int awaitExpansion(int was)
throws ConnectionException { throws ConnectionException {
synchronized (lock) { synchronized (lock) {
while (size < howMuch) { while (size <= was) {
log.debug("Waiting, need window space for {} bytes", howMuch); log.debug("Waiting, need size to grow from {} bytes", was);
try { try {
lock.wait(); lock.wait();
} catch (InterruptedException ie) { } catch (InterruptedException ie) {
throw new ConnectionException(ie); throw new ConnectionException(ie);
} }
} }
consume(howMuch); return size;
} }
} }