Add an option to limit the used bandwidth with SCP upload and download features

This commit is contained in:
lguerin
2015-08-12 16:55:08 +02:00
parent 84d15f4cf5
commit 782ff9b83e
6 changed files with 215 additions and 25 deletions

View File

@@ -0,0 +1,16 @@
package net.schmizz.sshj.xfer.scp;
abstract class AbstractSCPClient {
protected final SCPEngine engine;
protected int bandwidthLimit;
AbstractSCPClient(SCPEngine engine) {
this.engine = engine;
}
AbstractSCPClient(SCPEngine engine, int bandwidthLimit) {
this.engine = engine;
this.bandwidthLimit = bandwidthLimit;
}
}

View File

@@ -24,18 +24,21 @@ import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArguments;
/** Support for uploading files over a connected link using SCP. */
public final class SCPDownloadClient {
public final class SCPDownloadClient extends AbstractSCPClient {
private boolean recursiveMode = true;
private final SCPEngine engine;
SCPDownloadClient(SCPEngine engine) {
this.engine = engine;
super(engine);
}
SCPDownloadClient(SCPEngine engine, int bandwidthLimit) {
super(engine, bandwidthLimit);
}
/** Download a file from {@code sourcePath} on the connected host to {@code targetPath} locally. */
@@ -60,12 +63,12 @@ public final class SCPDownloadClient {
void startCopy(String sourcePath, LocalDestFile targetFile)
throws IOException {
List<Arg> args = new LinkedList<Arg>();
args.add(Arg.SOURCE);
args.add(Arg.QUIET);
args.add(Arg.PRESERVE_TIMES);
if (recursiveMode)
args.add(Arg.RECURSIVE);
List<Arg> args = SCPArguments.with(Arg.SOURCE)
.and(Arg.QUIET)
.and(Arg.PRESERVE_TIMES)
.and(Arg.RECURSIVE, recursiveMode)
.and(Arg.LIMIT, String.valueOf(bandwidthLimit), (bandwidthLimit > 0))
.arguments();
engine.execSCPWith(args, sourcePath);
engine.signal("Start status OK");

View File

@@ -28,6 +28,7 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.LinkedList;
import java.util.List;
/** @see <a href="http://blogs.sun.com/janp/entry/how_the_scp_protocol_works">SCP Protocol</a> */
@@ -39,17 +40,31 @@ class SCPEngine {
RECURSIVE('r'),
VERBOSE('v'),
PRESERVE_TIMES('p'),
QUIET('q');
QUIET('q'),
LIMIT('l', "");
private final char a;
private String v;
private Arg(char a) {
this.a = a;
}
private Arg(char a, String v) {
this.a = a;
this.v = v;
}
public void setValue(String v) {
this.v = v;
}
@Override
public String toString() {
return "-" + a;
String arg = "-" + a;
if (v != null && v.length() > 0) {
arg = arg + v;
}
return arg;
}
}
@@ -186,4 +201,63 @@ class SCPEngine {
return listener;
}
public static class SCPArguments {
private static List<Arg> args = null;
private SCPArguments() {
this.args = new LinkedList<Arg>();
}
private static void addArg(Arg arg, String value, boolean accept) {
if (accept) {
if (null != value && value.length() > 0) {
arg.setValue(value);
}
args.add(arg);
}
}
public static SCPArguments with(Arg arg) {
return with(arg, null, true);
}
public static SCPArguments with(Arg arg, String value) {
return with(arg, value, true);
}
public static SCPArguments with(Arg arg, boolean accept) {
return with(arg, null, accept);
}
public static SCPArguments with(Arg arg, String value, boolean accept) {
SCPArguments scpArguments = new SCPArguments();
addArg(arg, value, accept);
return scpArguments;
}
public SCPArguments and(Arg arg) {
addArg(arg, null, true);
return this;
}
public SCPArguments and(Arg arg, String value) {
addArg(arg, value, true);
return this;
}
public SCPArguments and(Arg arg, boolean accept) {
addArg(arg, null, accept);
return this;
}
public SCPArguments and(Arg arg, String value, boolean accept) {
addArg(arg, value, accept);
return this;
}
public List<Arg> arguments() {
return args;
}
}
}

View File

@@ -28,18 +28,23 @@ public class SCPFileTransfer
extends AbstractFileTransfer
implements FileTransfer {
/** Default bandwidth limit for SCP transfert in Kbit/s (-1 means unlimited) */
private static final int DEFAULT_BANDWIDTH_LIMIT = -1;
private final SessionFactory sessionFactory;
private int bandwidthLimit;
public SCPFileTransfer(SessionFactory sessionFactory) {
this.sessionFactory = sessionFactory;
this.bandwidthLimit = DEFAULT_BANDWIDTH_LIMIT;
}
public SCPDownloadClient newSCPDownloadClient() {
return new SCPDownloadClient(newSCPEngine());
return new SCPDownloadClient(newSCPEngine(), bandwidthLimit);
}
public SCPUploadClient newSCPUploadClient() {
return new SCPUploadClient(newSCPEngine());
return new SCPUploadClient(newSCPEngine(), bandwidthLimit);
}
private SCPEngine newSCPEngine() {
@@ -70,4 +75,10 @@ public class SCPFileTransfer
newSCPUploadClient().copy(localFile, remotePath);
}
public SCPFileTransfer bandwidthLimit(int limit) {
if (limit > 0) {
this.bandwidthLimit = limit;
}
return this;
}
}

View File

@@ -24,17 +24,21 @@ import net.schmizz.sshj.xfer.scp.SCPEngine.Arg;
import java.io.IOException;
import java.io.InputStream;
import java.util.LinkedList;
import java.util.List;
/** Support for uploading files over a connected link using SCP. */
public final class SCPUploadClient {
import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArguments;
/** Support for uploading files over a connected link using SCP. */
public final class SCPUploadClient extends AbstractSCPClient {
private final SCPEngine engine;
private LocalFileFilter uploadFilter;
SCPUploadClient(SCPEngine engine) {
this.engine = engine;
super(engine);
}
SCPUploadClient(SCPEngine engine, int bandwidthLimit) {
super(engine, bandwidthLimit);
}
/** Upload a local file from {@code localFile} to {@code targetPath} on the remote host. */
@@ -55,11 +59,11 @@ public final class SCPUploadClient {
private synchronized void startCopy(LocalSourceFile sourceFile, String targetPath)
throws IOException {
List<Arg> args = new LinkedList<Arg>();
args.add(Arg.SINK);
args.add(Arg.RECURSIVE);
if (sourceFile.providesAtimeMtime())
args.add(Arg.PRESERVE_TIMES);
List<Arg> args = SCPArguments.with(Arg.SINK)
.and(Arg.RECURSIVE)
.and(Arg.PRESERVE_TIMES, sourceFile.providesAtimeMtime())
.and(Arg.LIMIT, String.valueOf(bandwidthLimit), (bandwidthLimit > 0))
.arguments();
engine.execSCPWith(args, targetPath);
engine.check("Start status OK");
process(engine.getTransferListener(), sourceFile);

View File

@@ -0,0 +1,82 @@
package net.schmizz.sshj.xfer.scp;
import com.hierynomus.sshj.test.SshFixture;
import com.hierynomus.sshj.test.util.FileUtil;
import net.schmizz.sshj.SSHClient;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import java.io.File;
import java.io.IOException;
import static junit.framework.Assert.assertFalse;
import static junit.framework.Assert.assertTrue;
public class SCPFileTransferTest {
public static final String DEFAULT_FILE_NAME = "my_file.txt";
File targetDir;
File sourceFile;
File targetFile;
SSHClient sshClient;
@Rule
public SshFixture fixture = new SshFixture();
@Rule
public TemporaryFolder tempFolder = new TemporaryFolder();
@Before
public void init() throws IOException {
sourceFile = tempFolder.newFile(DEFAULT_FILE_NAME);
FileUtil.writeToFile(sourceFile, "This is my file");
targetDir = tempFolder.newFolder();
targetFile = new File(targetDir + File.separator + DEFAULT_FILE_NAME);
sshClient = fixture.setupConnectedDefaultClient();
sshClient.authPassword("test", "test");
}
@After
public void cleanup() {
if (targetFile.exists()) {
targetFile.delete();
}
}
@Test
public void should_SCP_Upload_File() throws IOException {
SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer();
assertFalse(targetFile.exists());
scpFileTransfer.upload(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath());
assertTrue(targetFile.exists());
}
@Test
public void should_SCP_Upload_File_With_Bandwidth_Limit() throws IOException {
// Limit upload transfert at 2Mo/s
SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer().bandwidthLimit(16000);
assertFalse(targetFile.exists());
scpFileTransfer.upload(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath());
assertTrue(targetFile.exists());
}
@Test
public void should_SCP_Download_File() throws IOException {
SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer();
assertFalse(targetFile.exists());
scpFileTransfer.download(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath());
assertTrue(targetFile.exists());
}
@Test
public void should_SCP_Download_File_With_Bandwidth_Limit() throws IOException {
// Limit download transfert at 128Ko/s
SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer().bandwidthLimit(1024);
assertFalse(targetFile.exists());
scpFileTransfer.download(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath());
assertTrue(targetFile.exists());
}
}