diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/AbstractSCPClient.java b/src/main/java/net/schmizz/sshj/xfer/scp/AbstractSCPClient.java new file mode 100644 index 00000000..af455da0 --- /dev/null +++ b/src/main/java/net/schmizz/sshj/xfer/scp/AbstractSCPClient.java @@ -0,0 +1,16 @@ +package net.schmizz.sshj.xfer.scp; + +abstract class AbstractSCPClient { + + protected final SCPEngine engine; + protected int bandwidthLimit; + + AbstractSCPClient(SCPEngine engine) { + this.engine = engine; + } + + AbstractSCPClient(SCPEngine engine, int bandwidthLimit) { + this.engine = engine; + this.bandwidthLimit = bandwidthLimit; + } +} diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java index c2ad4258..7d80c5e2 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPDownloadClient.java @@ -24,18 +24,21 @@ import java.io.IOException; import java.io.OutputStream; import java.util.ArrayList; import java.util.Arrays; -import java.util.LinkedList; import java.util.List; +import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArguments; + /** Support for uploading files over a connected link using SCP. */ -public final class SCPDownloadClient { +public final class SCPDownloadClient extends AbstractSCPClient { private boolean recursiveMode = true; - private final SCPEngine engine; - SCPDownloadClient(SCPEngine engine) { - this.engine = engine; + super(engine); + } + + SCPDownloadClient(SCPEngine engine, int bandwidthLimit) { + super(engine, bandwidthLimit); } /** Download a file from {@code sourcePath} on the connected host to {@code targetPath} locally. */ @@ -60,12 +63,12 @@ public final class SCPDownloadClient { void startCopy(String sourcePath, LocalDestFile targetFile) throws IOException { - List args = new LinkedList(); - args.add(Arg.SOURCE); - args.add(Arg.QUIET); - args.add(Arg.PRESERVE_TIMES); - if (recursiveMode) - args.add(Arg.RECURSIVE); + List args = SCPArguments.with(Arg.SOURCE) + .and(Arg.QUIET) + .and(Arg.PRESERVE_TIMES) + .and(Arg.RECURSIVE, recursiveMode) + .and(Arg.LIMIT, String.valueOf(bandwidthLimit), (bandwidthLimit > 0)) + .arguments(); engine.execSCPWith(args, sourcePath); engine.signal("Start status OK"); diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java index 66bac99d..78d7eff6 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPEngine.java @@ -28,6 +28,7 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.LinkedList; import java.util.List; /** @see SCP Protocol */ @@ -39,17 +40,31 @@ class SCPEngine { RECURSIVE('r'), VERBOSE('v'), PRESERVE_TIMES('p'), - QUIET('q'); + QUIET('q'), + LIMIT('l', ""); private final char a; + private String v; private Arg(char a) { this.a = a; } + private Arg(char a, String v) { + this.a = a; + this.v = v; + } + + public void setValue(String v) { + this.v = v; + } @Override public String toString() { - return "-" + a; + String arg = "-" + a; + if (v != null && v.length() > 0) { + arg = arg + v; + } + return arg; } } @@ -186,4 +201,63 @@ class SCPEngine { return listener; } + public static class SCPArguments { + + private static List args = null; + + private SCPArguments() { + this.args = new LinkedList(); + } + + private static void addArg(Arg arg, String value, boolean accept) { + if (accept) { + if (null != value && value.length() > 0) { + arg.setValue(value); + } + args.add(arg); + } + } + + public static SCPArguments with(Arg arg) { + return with(arg, null, true); + } + + public static SCPArguments with(Arg arg, String value) { + return with(arg, value, true); + } + + public static SCPArguments with(Arg arg, boolean accept) { + return with(arg, null, accept); + } + + public static SCPArguments with(Arg arg, String value, boolean accept) { + SCPArguments scpArguments = new SCPArguments(); + addArg(arg, value, accept); + return scpArguments; + } + + public SCPArguments and(Arg arg) { + addArg(arg, null, true); + return this; + } + + public SCPArguments and(Arg arg, String value) { + addArg(arg, value, true); + return this; + } + + public SCPArguments and(Arg arg, boolean accept) { + addArg(arg, null, accept); + return this; + } + + public SCPArguments and(Arg arg, String value, boolean accept) { + addArg(arg, value, accept); + return this; + } + + public List arguments() { + return args; + } + } } diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPFileTransfer.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPFileTransfer.java index 7e055e3b..a71371c8 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPFileTransfer.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPFileTransfer.java @@ -28,18 +28,23 @@ public class SCPFileTransfer extends AbstractFileTransfer implements FileTransfer { + /** Default bandwidth limit for SCP transfert in Kbit/s (-1 means unlimited) */ + private static final int DEFAULT_BANDWIDTH_LIMIT = -1; + private final SessionFactory sessionFactory; + private int bandwidthLimit; public SCPFileTransfer(SessionFactory sessionFactory) { this.sessionFactory = sessionFactory; + this.bandwidthLimit = DEFAULT_BANDWIDTH_LIMIT; } public SCPDownloadClient newSCPDownloadClient() { - return new SCPDownloadClient(newSCPEngine()); + return new SCPDownloadClient(newSCPEngine(), bandwidthLimit); } public SCPUploadClient newSCPUploadClient() { - return new SCPUploadClient(newSCPEngine()); + return new SCPUploadClient(newSCPEngine(), bandwidthLimit); } private SCPEngine newSCPEngine() { @@ -70,4 +75,10 @@ public class SCPFileTransfer newSCPUploadClient().copy(localFile, remotePath); } + public SCPFileTransfer bandwidthLimit(int limit) { + if (limit > 0) { + this.bandwidthLimit = limit; + } + return this; + } } diff --git a/src/main/java/net/schmizz/sshj/xfer/scp/SCPUploadClient.java b/src/main/java/net/schmizz/sshj/xfer/scp/SCPUploadClient.java index 474ce1fc..ba5c8073 100644 --- a/src/main/java/net/schmizz/sshj/xfer/scp/SCPUploadClient.java +++ b/src/main/java/net/schmizz/sshj/xfer/scp/SCPUploadClient.java @@ -24,17 +24,21 @@ import net.schmizz.sshj.xfer.scp.SCPEngine.Arg; import java.io.IOException; import java.io.InputStream; -import java.util.LinkedList; import java.util.List; -/** Support for uploading files over a connected link using SCP. */ -public final class SCPUploadClient { +import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArguments; + +/** Support for uploading files over a connected link using SCP. */ +public final class SCPUploadClient extends AbstractSCPClient { - private final SCPEngine engine; private LocalFileFilter uploadFilter; SCPUploadClient(SCPEngine engine) { - this.engine = engine; + super(engine); + } + + SCPUploadClient(SCPEngine engine, int bandwidthLimit) { + super(engine, bandwidthLimit); } /** Upload a local file from {@code localFile} to {@code targetPath} on the remote host. */ @@ -55,11 +59,11 @@ public final class SCPUploadClient { private synchronized void startCopy(LocalSourceFile sourceFile, String targetPath) throws IOException { - List args = new LinkedList(); - args.add(Arg.SINK); - args.add(Arg.RECURSIVE); - if (sourceFile.providesAtimeMtime()) - args.add(Arg.PRESERVE_TIMES); + List args = SCPArguments.with(Arg.SINK) + .and(Arg.RECURSIVE) + .and(Arg.PRESERVE_TIMES, sourceFile.providesAtimeMtime()) + .and(Arg.LIMIT, String.valueOf(bandwidthLimit), (bandwidthLimit > 0)) + .arguments(); engine.execSCPWith(args, targetPath); engine.check("Start status OK"); process(engine.getTransferListener(), sourceFile); diff --git a/src/test/java/net/schmizz/sshj/xfer/scp/SCPFileTransferTest.java b/src/test/java/net/schmizz/sshj/xfer/scp/SCPFileTransferTest.java new file mode 100644 index 00000000..759266e6 --- /dev/null +++ b/src/test/java/net/schmizz/sshj/xfer/scp/SCPFileTransferTest.java @@ -0,0 +1,82 @@ +package net.schmizz.sshj.xfer.scp; + +import com.hierynomus.sshj.test.SshFixture; +import com.hierynomus.sshj.test.util.FileUtil; +import net.schmizz.sshj.SSHClient; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; + +import static junit.framework.Assert.assertFalse; +import static junit.framework.Assert.assertTrue; + +public class SCPFileTransferTest { + + public static final String DEFAULT_FILE_NAME = "my_file.txt"; + File targetDir; + File sourceFile; + File targetFile; + SSHClient sshClient; + + @Rule + public SshFixture fixture = new SshFixture(); + + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Before + public void init() throws IOException { + sourceFile = tempFolder.newFile(DEFAULT_FILE_NAME); + FileUtil.writeToFile(sourceFile, "This is my file"); + targetDir = tempFolder.newFolder(); + targetFile = new File(targetDir + File.separator + DEFAULT_FILE_NAME); + sshClient = fixture.setupConnectedDefaultClient(); + sshClient.authPassword("test", "test"); + } + + @After + public void cleanup() { + if (targetFile.exists()) { + targetFile.delete(); + } + } + + @Test + public void should_SCP_Upload_File() throws IOException { + SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer(); + assertFalse(targetFile.exists()); + scpFileTransfer.upload(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath()); + assertTrue(targetFile.exists()); + } + + @Test + public void should_SCP_Upload_File_With_Bandwidth_Limit() throws IOException { + // Limit upload transfert at 2Mo/s + SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer().bandwidthLimit(16000); + assertFalse(targetFile.exists()); + scpFileTransfer.upload(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath()); + assertTrue(targetFile.exists()); + } + + @Test + public void should_SCP_Download_File() throws IOException { + SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer(); + assertFalse(targetFile.exists()); + scpFileTransfer.download(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath()); + assertTrue(targetFile.exists()); + } + + @Test + public void should_SCP_Download_File_With_Bandwidth_Limit() throws IOException { + // Limit download transfert at 128Ko/s + SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer().bandwidthLimit(1024); + assertFalse(targetFile.exists()); + scpFileTransfer.download(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath()); + assertTrue(targetFile.exists()); + } +}