mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-06 22:30:56 +03:00
add SCP support
SCP is an experimental feature, we have our own SCP implementation since we can't rely on scp system command to proper handle permissions, quota and user's home dir restrictions. The SCP protocol is quite simple but there is no official docs about it, so we need more testing and feedbacks before enabling it by default. We may not handle some borderline cases or have sneaky bugs. This commit contains some breaking changes to the REST API. SFTPGo API should be stable now and I hope no more breaking changes before the first stable release.
This commit is contained in:
@@ -8,6 +8,8 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
@@ -77,8 +79,11 @@ iixITGvaNZh/tjAAAACW5pY29sYUBwMQE=
|
||||
)
|
||||
|
||||
var (
|
||||
allPerms = []string{dataprovider.PermAny}
|
||||
homeBasePath string
|
||||
allPerms = []string{dataprovider.PermAny}
|
||||
homeBasePath string
|
||||
scpPath string
|
||||
pubKeyPath string
|
||||
privateKeyPath string
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -97,6 +102,8 @@ func TestMain(m *testing.M) {
|
||||
httpdConf := config.GetHTTPDConfig()
|
||||
router := api.GetHTTPRouter()
|
||||
sftpdConf.BindPort = 2022
|
||||
// we need to test SCP support
|
||||
sftpdConf.IsSCPEnabled = true
|
||||
// we run the test cases with UploadMode atomic. The non atomic code path
|
||||
// simply does not execute some code so if it works in atomic mode will
|
||||
// work in non atomic mode too
|
||||
@@ -109,10 +116,27 @@ func TestMain(m *testing.M) {
|
||||
sftpdConf.Actions.Command = "/usr/bin/true"
|
||||
sftpdConf.Actions.HTTPNotificationURL = "http://127.0.0.1:8080/"
|
||||
}
|
||||
pubKeyPath = filepath.Join(homeBasePath, "ssh_key.pub")
|
||||
privateKeyPath = filepath.Join(homeBasePath, "ssh_key")
|
||||
err = ioutil.WriteFile(pubKeyPath, []byte(testPubKey+"\n"), 0600)
|
||||
if err != nil {
|
||||
logger.WarnToConsole("unable to save public key to file: %v", err)
|
||||
}
|
||||
err = ioutil.WriteFile(privateKeyPath, []byte(testPrivateKey+"\n"), 0600)
|
||||
if err != nil {
|
||||
logger.WarnToConsole("unable to save private key to file: %v", err)
|
||||
}
|
||||
|
||||
sftpd.SetDataProvider(dataProvider)
|
||||
api.SetDataProvider(dataProvider)
|
||||
|
||||
scpPath, err = exec.LookPath("scp")
|
||||
if err != nil {
|
||||
logger.Warn(logSender, "unable to get scp command. SCP tests will be skipped, err: %v", err)
|
||||
logger.WarnToConsole("unable to get scp command. SCP tests will be skipped, err: %v", err)
|
||||
scpPath = ""
|
||||
}
|
||||
|
||||
go func() {
|
||||
logger.Debug(logSender, "initializing SFTP server with config %+v", sftpdConf)
|
||||
if err := sftpdConf.Initialize(configDir); err != nil {
|
||||
@@ -1399,6 +1423,503 @@ func TestSSHConnection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Start SCP tests
|
||||
func TestSCPBasicHandling(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
u.QuotaSize = 6553600
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(131074)
|
||||
expectedQuotaSize := user.UsedQuotaSize + testFileSize
|
||||
expectedQuotaFiles := user.UsedQuotaFiles + 1
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
|
||||
localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
||||
// test to download a missing file
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err == nil {
|
||||
t.Errorf("downloading a missing file via scp must fail")
|
||||
}
|
||||
err = scpUpload(testFilePath, remoteUpPath, false)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading file via scp: %v", err)
|
||||
}
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err != nil {
|
||||
t.Errorf("error downloading file via scp: %v", err)
|
||||
}
|
||||
fi, err := os.Stat(localPath)
|
||||
if err != nil {
|
||||
t.Errorf("stat for the downloaded file must succeed")
|
||||
} else {
|
||||
if fi.Size() != testFileSize {
|
||||
t.Errorf("size of the file downloaded via SCP does not match the expected one")
|
||||
}
|
||||
}
|
||||
os.Remove(localPath)
|
||||
user, _, err = api.GetUserByID(user.ID, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("error getting user: %v", err)
|
||||
}
|
||||
if expectedQuotaFiles != user.UsedQuotaFiles {
|
||||
t.Errorf("quota files does not match, expected: %v, actual: %v", expectedQuotaFiles, user.UsedQuotaFiles)
|
||||
}
|
||||
if expectedQuotaSize != user.UsedQuotaSize {
|
||||
t.Errorf("quota size does not match, expected: %v, actual: %v", expectedQuotaSize, user.UsedQuotaSize)
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPUploadFileOverwrite(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(32760)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, filepath.Join("/", testFileName))
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading file via scp: %v", err)
|
||||
}
|
||||
// test a new upload that must overwrite the existing file
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading existing file via scp: %v", err)
|
||||
}
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
|
||||
localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err != nil {
|
||||
t.Errorf("error downloading file via scp: %v", err)
|
||||
}
|
||||
fi, err := os.Stat(localPath)
|
||||
if err != nil {
|
||||
t.Errorf("stat for the downloaded file must succeed")
|
||||
} else {
|
||||
if fi.Size() != testFileSize {
|
||||
t.Errorf("size of the file downloaded via SCP does not match the expected one")
|
||||
}
|
||||
}
|
||||
os.Remove(localPath)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPRecursive(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testBaseDirName := "test_dir"
|
||||
testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
|
||||
testBaseDirDownName := "test_dir_down"
|
||||
testBaseDirDownPath := filepath.Join(homeBasePath, testBaseDirDownName)
|
||||
testFilePath := filepath.Join(homeBasePath, testBaseDirName, testFileName)
|
||||
testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testBaseDirName, testFileName)
|
||||
testFileSize := int64(131074)
|
||||
createTestFile(testFilePath, testFileSize)
|
||||
createTestFile(testFilePath1, testFileSize)
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testBaseDirName))
|
||||
// test to download a missing dir
|
||||
err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
|
||||
if err == nil {
|
||||
t.Errorf("downloading a missing dir via scp must fail")
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
|
||||
err = scpUpload(testBaseDirPath, remoteUpPath, true)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading dir via scp: %v", err)
|
||||
}
|
||||
err = scpDownload(testBaseDirDownPath, remoteDownPath, true, true)
|
||||
if err != nil {
|
||||
t.Errorf("error downloading dir via scp: %v", err)
|
||||
}
|
||||
// test download without passing -r
|
||||
err = scpDownload(testBaseDirDownPath, remoteDownPath, true, false)
|
||||
if err == nil {
|
||||
t.Errorf("recursive download without -r must fail")
|
||||
}
|
||||
fi, err := os.Stat(filepath.Join(testBaseDirDownPath, testFileName))
|
||||
if err != nil {
|
||||
t.Errorf("error downloading file using scp recursive: %v", err)
|
||||
} else {
|
||||
if fi.Size() != testFileSize {
|
||||
t.Errorf("size for file downloaded using recursive scp does not match, actual: %v, expected: %v", fi.Size(), testFileSize)
|
||||
}
|
||||
}
|
||||
fi, err = os.Stat(filepath.Join(testBaseDirDownPath, testBaseDirName, testFileName))
|
||||
if err != nil {
|
||||
t.Errorf("error downloading file using scp recursive: %v", err)
|
||||
} else {
|
||||
if fi.Size() != testFileSize {
|
||||
t.Errorf("size for file downloaded using recursive scp does not match, actual: %v, expected: %v", fi.Size(), testFileSize)
|
||||
}
|
||||
}
|
||||
// upload to a non existent dir
|
||||
remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/non_existent_dir")
|
||||
err = scpUpload(testBaseDirPath, remoteUpPath, true)
|
||||
if err == nil {
|
||||
t.Errorf("uploading via scp to a non existent dir must fail")
|
||||
}
|
||||
os.RemoveAll(testBaseDirPath)
|
||||
os.RemoveAll(testBaseDirDownPath)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPPermCreateDirs(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
u.Permissions = []string{dataprovider.PermDownload, dataprovider.PermUpload}
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(32760)
|
||||
testBaseDirName := "test_dir"
|
||||
testBaseDirPath := filepath.Join(homeBasePath, testBaseDirName)
|
||||
testFilePath1 := filepath.Join(homeBasePath, testBaseDirName, testFileName)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
err = createTestFile(testFilePath1, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp/")
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err == nil {
|
||||
t.Errorf("scp upload must fail, the user cannot create new dirs")
|
||||
}
|
||||
err = scpUpload(testBaseDirPath, remoteUpPath, true)
|
||||
if err == nil {
|
||||
t.Errorf("scp upload must fail, the user cannot create new dirs")
|
||||
}
|
||||
err = os.Remove(testFilePath)
|
||||
if err != nil {
|
||||
t.Errorf("error removing test file")
|
||||
}
|
||||
os.RemoveAll(testBaseDirPath)
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPPermUpload(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
u.Permissions = []string{dataprovider.PermDownload, dataprovider.PermCreateDirs}
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65536)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/tmp")
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err == nil {
|
||||
t.Errorf("scp upload must fail, the user cannot upload")
|
||||
}
|
||||
err = os.Remove(testFilePath)
|
||||
if err != nil {
|
||||
t.Errorf("error removing test file")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPPermDownload(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
u := getTestUser(usePubKey)
|
||||
u.Permissions = []string{dataprovider.PermUpload, dataprovider.PermCreateDirs}
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65537)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "tmp")
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading existing file via scp: %v", err)
|
||||
}
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/tmp", testFileName))
|
||||
localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err == nil {
|
||||
t.Errorf("scp download must fail, the user cannot download")
|
||||
}
|
||||
err = os.Remove(testFilePath)
|
||||
if err != nil {
|
||||
t.Errorf("error removing test file")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPQuotaSize(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
testFileSize := int64(65535)
|
||||
u := getTestUser(usePubKey)
|
||||
u.QuotaFiles = 1
|
||||
u.QuotaSize = testFileSize - 1
|
||||
user, _, err := api.AddUser(u, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testFileName))
|
||||
err = scpUpload(testFilePath, remoteUpPath, true)
|
||||
if err != nil {
|
||||
t.Errorf("error uploading existing file via scp: %v", err)
|
||||
}
|
||||
err = scpUpload(testFilePath, remoteUpPath+".quota", true)
|
||||
if err == nil {
|
||||
t.Errorf("user is over quota scp upload must fail")
|
||||
}
|
||||
err = os.Remove(testFilePath)
|
||||
if err != nil {
|
||||
t.Errorf("error removing test file")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPEscapeHomeDir(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
os.MkdirAll(user.GetHomeDir(), 0777)
|
||||
testDir := "testDir"
|
||||
linkPath := filepath.Join(homeBasePath, defaultUsername, testDir)
|
||||
err = os.Symlink(homeBasePath, linkPath)
|
||||
if err != nil {
|
||||
t.Errorf("error making local symlink: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65535)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDir, testDir))
|
||||
err = scpUpload(testFilePath, remoteUpPath, false)
|
||||
if err == nil {
|
||||
t.Errorf("uploading to a dir with a symlink outside home dir must fail")
|
||||
}
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testDir, testFileName))
|
||||
localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err == nil {
|
||||
t.Errorf("scp download must fail, the requested file has a symlink outside user home")
|
||||
}
|
||||
remoteDownPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join("/", testDir))
|
||||
err = scpDownload(homeBasePath, remoteDownPath, false, true)
|
||||
if err == nil {
|
||||
t.Errorf("scp download must fail, the requested dir is a symlink outside user home")
|
||||
}
|
||||
err = os.Remove(testFilePath)
|
||||
if err != nil {
|
||||
t.Errorf("error removing test file")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPUploadPaths(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65535)
|
||||
testDirName := "testDir"
|
||||
testDirPath := filepath.Join(user.GetHomeDir(), testDirName)
|
||||
os.MkdirAll(testDirPath, 0777)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, testDirName)
|
||||
remoteDownPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testFileName))
|
||||
localPath := filepath.Join(homeBasePath, "scp_download.dat")
|
||||
err = scpUpload(testFilePath, remoteUpPath, false)
|
||||
if err != nil {
|
||||
t.Errorf("scp upload error: %v", err)
|
||||
}
|
||||
err = scpDownload(localPath, remoteDownPath, false, false)
|
||||
if err != nil {
|
||||
t.Errorf("scp download error: %v", err)
|
||||
}
|
||||
// upload a file to a missing dir
|
||||
remoteUpPath = fmt.Sprintf("%v@127.0.0.1:%v", user.Username, path.Join(testDirName, testDirName, testFileName))
|
||||
err = scpUpload(testFilePath, remoteUpPath, false)
|
||||
if err == nil {
|
||||
t.Errorf("scp upload to a missing dir must fail")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSCPOverwriteDirWithFile(t *testing.T) {
|
||||
if len(scpPath) == 0 {
|
||||
t.Skip("scp command not found, unable to execute this test")
|
||||
}
|
||||
usePubKey := true
|
||||
user, _, err := api.AddUser(getTestUser(usePubKey), http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to add user: %v", err)
|
||||
}
|
||||
testFileName := "test_file.dat"
|
||||
testFilePath := filepath.Join(homeBasePath, testFileName)
|
||||
testFileSize := int64(65535)
|
||||
testDirPath := filepath.Join(user.GetHomeDir(), testFileName)
|
||||
os.MkdirAll(testDirPath, 0777)
|
||||
err = createTestFile(testFilePath, testFileSize)
|
||||
if err != nil {
|
||||
t.Errorf("unable to create test file: %v", err)
|
||||
}
|
||||
remoteUpPath := fmt.Sprintf("%v@127.0.0.1:%v", user.Username, "/")
|
||||
err = scpUpload(testFilePath, remoteUpPath, false)
|
||||
if err == nil {
|
||||
t.Errorf("copying a file over an existing dir must fail")
|
||||
}
|
||||
err = os.RemoveAll(user.GetHomeDir())
|
||||
if err != nil {
|
||||
t.Errorf("error removing uploaded files")
|
||||
}
|
||||
_, err = api.RemoveUser(user, http.StatusOK)
|
||||
if err != nil {
|
||||
t.Errorf("unable to remove user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// End SCP tests
|
||||
|
||||
func waitTCPListening(address string) {
|
||||
for {
|
||||
conn, err := net.Dial("tcp", address)
|
||||
@@ -1487,6 +2008,10 @@ func getSftpClient(user dataprovider.User, usePubKey bool) (*sftp.Client, error)
|
||||
}
|
||||
|
||||
func createTestFile(path string, size int64) error {
|
||||
baseDir := filepath.Dir(path)
|
||||
if _, err := os.Stat(baseDir); os.IsNotExist(err) {
|
||||
os.MkdirAll(baseDir, 0777)
|
||||
}
|
||||
content := make([]byte, size)
|
||||
_, err := rand.Read(content)
|
||||
if err != nil {
|
||||
@@ -1572,6 +2097,49 @@ func sftpDownloadNonBlocking(remoteSourcePath string, localDestPath string, expe
|
||||
return c
|
||||
}
|
||||
|
||||
func scpUpload(localPath, remotePath string, preserveTime bool) error {
|
||||
var args []string
|
||||
if preserveTime {
|
||||
args = append(args, "-p")
|
||||
}
|
||||
fi, err := os.Stat(localPath)
|
||||
if err == nil {
|
||||
if fi.IsDir() {
|
||||
args = append(args, "-r")
|
||||
}
|
||||
}
|
||||
args = append(args, "-P")
|
||||
args = append(args, "2022")
|
||||
args = append(args, "-o")
|
||||
args = append(args, "StrictHostKeyChecking=no")
|
||||
args = append(args, "-i")
|
||||
args = append(args, privateKeyPath)
|
||||
args = append(args, localPath)
|
||||
args = append(args, remotePath)
|
||||
cmd := exec.Command(scpPath, args...)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func scpDownload(localPath, remotePath string, preserveTime, recursive bool) error {
|
||||
var args []string
|
||||
if preserveTime {
|
||||
args = append(args, "-p")
|
||||
}
|
||||
if recursive {
|
||||
args = append(args, "-r")
|
||||
}
|
||||
args = append(args, "-P")
|
||||
args = append(args, "2022")
|
||||
args = append(args, "-o")
|
||||
args = append(args, "StrictHostKeyChecking=no")
|
||||
args = append(args, "-i")
|
||||
args = append(args, privateKeyPath)
|
||||
args = append(args, remotePath)
|
||||
args = append(args, localPath)
|
||||
cmd := exec.Command(scpPath, args...)
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
func waitForActiveTransfer() {
|
||||
stats := sftpd.GetConnectionsStats()
|
||||
for len(stats) < 1 {
|
||||
|
||||
Reference in New Issue
Block a user