SCP remote path escaping is now configurable (Fixes #212, #184, #152)

This commit is contained in:
Jeroen van Erp
2015-09-21 14:51:57 +02:00
parent 28a11b0b45
commit d520585a09
6 changed files with 191 additions and 150 deletions

View File

@@ -18,7 +18,6 @@ package net.schmizz.sshj.xfer.scp;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.xfer.LocalDestFile;
import net.schmizz.sshj.xfer.TransferListener;
import net.schmizz.sshj.xfer.scp.SCPEngine.Arg;
import java.io.IOException;
import java.io.OutputStream;
@@ -26,9 +25,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArgument;
import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArguments;
/** Support for uploading files over a connected link using SCP. */
public final class SCPDownloadClient extends AbstractSCPClient {
@@ -43,11 +39,15 @@ public final class SCPDownloadClient extends AbstractSCPClient {
}
/** Download a file from {@code sourcePath} on the connected host to {@code targetPath} locally. */
public synchronized int copy(String sourcePath, LocalDestFile targetFile)
public synchronized int copy(String sourcePath, LocalDestFile targetFile) throws IOException {
return copy(sourcePath, targetFile, ScpCommandLine.EscapeMode.NoEscape);
}
public synchronized int copy(String sourcePath, LocalDestFile targetFile, ScpCommandLine.EscapeMode escapeMode)
throws IOException {
engine.cleanSlate();
try {
startCopy(sourcePath, targetFile);
startCopy(sourcePath, targetFile, escapeMode);
} finally {
engine.exit();
}
@@ -62,15 +62,15 @@ public final class SCPDownloadClient extends AbstractSCPClient {
this.recursiveMode = recursive;
}
void startCopy(String sourcePath, LocalDestFile targetFile)
private void startCopy(String sourcePath, LocalDestFile targetFile, ScpCommandLine.EscapeMode escapeMode)
throws IOException {
List<SCPArgument> args = SCPArguments.with(Arg.SOURCE)
.and(Arg.QUIET)
.and(Arg.PRESERVE_TIMES)
.and(Arg.RECURSIVE, recursiveMode)
.and(Arg.LIMIT, String.valueOf(bandwidthLimit), (bandwidthLimit > 0))
.arguments();
engine.execSCPWith(args, sourcePath);
ScpCommandLine commandLine = ScpCommandLine.with(ScpCommandLine.Arg.SOURCE)
.and(ScpCommandLine.Arg.QUIET)
.and(ScpCommandLine.Arg.PRESERVE_TIMES)
.and(ScpCommandLine.Arg.RECURSIVE, recursiveMode)
.and(ScpCommandLine.Arg.LIMIT, String.valueOf(bandwidthLimit), (bandwidthLimit > 0));
commandLine.withPath(sourcePath, escapeMode);
engine.execSCPWith(commandLine);
engine.signal("Start status OK");

View File

@@ -28,34 +28,14 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
/** @see <a href="http://blogs.sun.com/janp/entry/how_the_scp_protocol_works">SCP Protocol</a> */
class SCPEngine {
enum Arg {
SOURCE('f'),
SINK('t'),
RECURSIVE('r'),
VERBOSE('v'),
PRESERVE_TIMES('p'),
QUIET('q'),
LIMIT('l');
private final char a;
private Arg(char a) {
this.a = a;
}
@Override
public String toString() {
return "-" + a;
}
}
private static final String SCP_COMMAND = "scp";
private static final char LF = '\n';
private final Logger log = LoggerFactory.getLogger(getClass());
@@ -99,19 +79,9 @@ class SCPEngine {
exitStatus = -1;
}
void execSCPWith(List<SCPArgument> args, String path)
void execSCPWith(ScpCommandLine commandLine)
throws SSHException {
final StringBuilder cmd = new StringBuilder(SCP_COMMAND);
for (SCPArgument arg : args) {
cmd.append(" ").append(arg);
}
cmd.append(" ");
if (path == null || path.isEmpty()) {
cmd.append(".");
} else {
cmd.append("'").append(path.replaceAll("'", "\\'")).append("'");
}
scp = host.startSession().exec(cmd.toString());
scp = host.startSession().exec(commandLine.toCommandLine());
}
void exit() {
@@ -187,85 +157,4 @@ class SCPEngine {
TransferListener getTransferListener() {
return listener;
}
public static class SCPArgument {
private Arg name;
private String value;
private SCPArgument(Arg name, String value) {
this.name = name;
this.value = value;
}
public static SCPArgument addArgument(Arg name, String value) {
return new SCPArgument(name, value);
}
@Override
public String toString() {
String option = name.toString();
if (value != null) {
option = option + value;
}
return option;
}
}
public static class SCPArguments {
private static List<SCPArgument> args = null;
private SCPArguments() {
this.args = new LinkedList<SCPArgument>();
}
private static void addArgument(Arg name, String value, boolean accept) {
if (accept) {
args.add(SCPArgument.addArgument(name, value));
}
}
public static SCPArguments with(Arg name) {
return with(name, null, true);
}
public static SCPArguments with(Arg name, String value) {
return with(name, value, true);
}
public static SCPArguments with(Arg name, boolean accept) {
return with(name, null, accept);
}
public static SCPArguments with(Arg name, String value, boolean accept) {
SCPArguments scpArguments = new SCPArguments();
addArgument(name, value, accept);
return scpArguments;
}
public SCPArguments and(Arg name) {
addArgument(name, null, true);
return this;
}
public SCPArguments and(Arg name, String value) {
addArgument(name, value, true);
return this;
}
public SCPArguments and(Arg name, boolean accept) {
addArgument(name, null, accept);
return this;
}
public SCPArguments and(Arg name, String value, boolean accept) {
addArgument(name, value, accept);
return this;
}
public List<SCPArgument> arguments() {
return args;
}
}
}

View File

@@ -1,12 +1,12 @@
/**
* Copyright 2009 sshj contributors
*
* <p/>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* <p/>
* http://www.apache.org/licenses/LICENSE-2.0
*
* <p/>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -20,14 +20,9 @@ import net.schmizz.sshj.common.StreamCopier;
import net.schmizz.sshj.xfer.LocalFileFilter;
import net.schmizz.sshj.xfer.LocalSourceFile;
import net.schmizz.sshj.xfer.TransferListener;
import net.schmizz.sshj.xfer.scp.SCPEngine.Arg;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArgument;
import static net.schmizz.sshj.xfer.scp.SCPEngine.SCPArguments;
/** Support for uploading files over a connected link using SCP. */
public final class SCPUploadClient extends AbstractSCPClient {
@@ -45,9 +40,14 @@ public final class SCPUploadClient extends AbstractSCPClient {
/** Upload a local file from {@code localFile} to {@code targetPath} on the remote host. */
public synchronized int copy(LocalSourceFile sourceFile, String remotePath)
throws IOException {
return copy(sourceFile, remotePath, ScpCommandLine.EscapeMode.SingleQuote);
}
public synchronized int copy(LocalSourceFile sourceFile, String remotePath, ScpCommandLine.EscapeMode escapeMode)
throws IOException {
engine.cleanSlate();
try {
startCopy(sourceFile, remotePath);
startCopy(sourceFile, remotePath, escapeMode);
} finally {
engine.exit();
}
@@ -58,14 +58,14 @@ public final class SCPUploadClient extends AbstractSCPClient {
this.uploadFilter = uploadFilter;
}
private synchronized void startCopy(LocalSourceFile sourceFile, String targetPath)
private void startCopy(LocalSourceFile sourceFile, String targetPath, ScpCommandLine.EscapeMode escapeMode)
throws IOException {
List<SCPArgument> args = SCPArguments.with(Arg.SINK)
.and(Arg.RECURSIVE)
.and(Arg.PRESERVE_TIMES, sourceFile.providesAtimeMtime())
.and(Arg.LIMIT, String.valueOf(bandwidthLimit), (bandwidthLimit > 0))
.arguments();
engine.execSCPWith(args, targetPath);
ScpCommandLine commandLine = ScpCommandLine.with(ScpCommandLine.Arg.SINK)
.and(ScpCommandLine.Arg.RECURSIVE)
.and(ScpCommandLine.Arg.PRESERVE_TIMES, sourceFile.providesAtimeMtime())
.and(ScpCommandLine.Arg.LIMIT, String.valueOf(bandwidthLimit), (bandwidthLimit > 0));
commandLine.withPath(targetPath, escapeMode);
engine.execSCPWith(commandLine);
engine.check("Start status OK");
process(engine.getTransferListener(), sourceFile);
}

View File

@@ -0,0 +1,132 @@
package net.schmizz.sshj.xfer.scp;
import java.util.LinkedHashMap;
/**
* Command line to be sent to the remote SSH process to setup an SCP process in the correct mode.
*/
public class ScpCommandLine {
private static final String SCP_COMMAND = "scp";
private EscapeMode mode;
enum Arg {
SOURCE('f'),
SINK('t'),
RECURSIVE('r'),
VERBOSE('v'),
PRESERVE_TIMES('p'),
QUIET('q'),
LIMIT('l');
private final char a;
private Arg(char a) {
this.a = a;
}
@Override
public String toString() {
return "-" + a;
}
}
public enum EscapeMode {
NoEscape,
Space {
@Override
String escapedPath(String path) {
return path.replace(" ", "\\ ");
}
},
DoubleQuote {
@Override
String escapedPath(String path) {
return "\"" + path.replace("\"", "\\\"") + "\"";
}
},
SingleQuote {
@Override
String escapedPath(String path) {
return "\'" + path.replace("'", "\\'") + "'";
}
};
String escapedPath(String path) {
return path;
}
}
private LinkedHashMap<Arg, String> arguments = new LinkedHashMap<Arg, String>();
private String path;
ScpCommandLine() {
}
static ScpCommandLine with(Arg name) {
return with(name, null, true);
}
static ScpCommandLine with(Arg name, String value) {
return with(name, value, true);
}
static ScpCommandLine with(Arg name, boolean accept) {
return with(name, null, accept);
}
static ScpCommandLine with(Arg name, String value, boolean accept) {
ScpCommandLine commandLine = new ScpCommandLine();
commandLine.addArgument(name, value, accept);
return commandLine;
}
private void addArgument(Arg name, String value, boolean accept) {
if (accept) {
arguments.put(name, value);
}
}
ScpCommandLine and(Arg name) {
addArgument(name, null, true);
return this;
}
ScpCommandLine and(Arg name, String value) {
addArgument(name, value, true);
return this;
}
ScpCommandLine and(Arg name, boolean accept) {
addArgument(name, null, accept);
return this;
}
ScpCommandLine and(Arg name, String value, boolean accept) {
addArgument(name, value, accept);
return this;
}
ScpCommandLine withPath(String path, EscapeMode mode) {
this.path = path;
this.mode = mode;
return this;
}
String toCommandLine() {
final StringBuilder cmd = new StringBuilder(SCP_COMMAND);
for (Arg arg : arguments.keySet()) {
cmd.append(" ").append(arg);
String s = arguments.get(arg);
if (s != null && !s.trim().isEmpty()) {
cmd.append(s);
}
}
cmd.append(" ");
if (path == null || path.trim().isEmpty()) {
cmd.append(".");
} else {
cmd.append(mode.escapedPath(path));
}
return cmd.toString();
}
}

View File

@@ -2,10 +2,7 @@ package com.hierynomus.sshj.test.util;
import net.schmizz.sshj.common.IOUtils;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.io.*;
public class FileUtil {
@@ -17,4 +14,14 @@ public class FileUtil {
IOUtils.closeQuietly(w);
}
}
public static String readFromFile(File f) throws IOException {
FileInputStream fileInputStream = new FileInputStream(f);
try {
ByteArrayOutputStream byteArrayOutputStream = IOUtils.readFully(fileInputStream);
return byteArrayOutputStream.toString("UTF-8");
} finally {
IOUtils.closeQuietly(fileInputStream);
}
}
}

View File

@@ -3,6 +3,7 @@ package net.schmizz.sshj.xfer.scp;
import com.hierynomus.sshj.test.SshFixture;
import com.hierynomus.sshj.test.util.FileUtil;
import net.schmizz.sshj.SSHClient;
import org.hamcrest.CoreMatchers;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
@@ -14,6 +15,7 @@ import java.io.IOException;
import static junit.framework.Assert.assertFalse;
import static junit.framework.Assert.assertTrue;
import static org.hamcrest.MatcherAssert.assertThat;
public class SCPFileTransferTest {
@@ -34,7 +36,7 @@ public class SCPFileTransferTest {
sourceFile = tempFolder.newFile(DEFAULT_FILE_NAME);
FileUtil.writeToFile(sourceFile, "This is my file");
targetDir = tempFolder.newFolder();
targetFile = new File(targetDir + File.separator + DEFAULT_FILE_NAME);
targetFile = new File(targetDir, DEFAULT_FILE_NAME);
sshClient = fixture.setupConnectedDefaultClient();
sshClient.authPassword("test", "test");
}
@@ -79,4 +81,15 @@ public class SCPFileTransferTest {
scpFileTransfer.download(sourceFile.getAbsolutePath(), targetDir.getAbsolutePath());
assertTrue(targetFile.exists());
}
@Test
public void shouldSCPDownloadFileWithoutPathEscaping() throws IOException {
SCPFileTransfer scpFileTransfer = sshClient.newSCPFileTransfer();
assertFalse(targetFile.exists());
File file = tempFolder.newFile("new file.txt");
FileUtil.writeToFile(file, "Some content");
scpFileTransfer.download(tempFolder.getRoot().getAbsolutePath() + "/new file.txt", targetFile.getAbsolutePath());
assertTrue(targetFile.exists());
assertThat(FileUtil.readFromFile(targetFile), CoreMatchers.containsString("Some content"));
}
}