sftpd: handle read and write from the same handle (#158)

Fixes #155
This commit is contained in:
Nicola Murino
2020-08-31 06:45:22 +02:00
committed by GitHub
parent 91a4c64390
commit 4748e6f54d
8 changed files with 209 additions and 58 deletions

View File

@@ -74,13 +74,22 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) {
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, request.Filepath, common.TransferDownload,
0, 0, 0, false, c.Fs)
t := newTransfer(baseTransfer, nil, r)
t := newTransfer(baseTransfer, nil, r, nil)
return t, nil
}
// OpenFile implements OpenFileWriter interface
func (c *Connection) OpenFile(request *sftp.Request) (sftp.WriterAtReaderAt, error) {
return c.handleFilewrite(request)
}
// Filewrite handles the write actions for a file on the system.
func (c *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
return c.handleFilewrite(request)
}
func (c *Connection) handleFilewrite(request *sftp.Request) (sftp.WriterAtReaderAt, error) {
c.UpdateLastActivity()
if !c.User.IsFileAllowed(request.Filepath) {
@@ -98,12 +107,24 @@ func (c *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
filePath = c.Fs.GetAtomicUploadPath(p)
}
var errForRead error
if !vfs.IsLocalOsFs(c.Fs) && request.Pflags().Read {
// read and write mode is only supported for local filesystem
errForRead = sftp.ErrSSHFxOpUnsupported
}
if !c.User.HasPerm(dataprovider.PermDownload, path.Dir(request.Filepath)) {
// we can try to read only for local fs here, see above.
// os.ErrPermission will become sftp.ErrSSHFxPermissionDenied when sent to
// the client
errForRead = os.ErrPermission
}
stat, statErr := c.Fs.Lstat(p)
if (statErr == nil && stat.Mode()&os.ModeSymlink == os.ModeSymlink) || c.Fs.IsNotExist(statErr) {
if !c.User.HasPerm(dataprovider.PermUpload, path.Dir(request.Filepath)) {
return nil, sftp.ErrSSHFxPermissionDenied
}
return c.handleSFTPUploadToNewFile(p, filePath, request.Filepath)
return c.handleSFTPUploadToNewFile(p, filePath, request.Filepath, errForRead)
}
if statErr != nil {
@@ -121,7 +142,7 @@ func (c *Connection) Filewrite(request *sftp.Request) (io.WriterAt, error) {
return nil, sftp.ErrSSHFxPermissionDenied
}
return c.handleSFTPUploadToExistingFile(request.Pflags(), p, filePath, stat.Size(), request.Filepath)
return c.handleSFTPUploadToExistingFile(request.Pflags(), p, filePath, stat.Size(), request.Filepath, errForRead)
}
// Filecmd hander for basic SFTP system calls related to files, but not anything to do with reading
@@ -272,7 +293,7 @@ func (c *Connection) handleSFTPRemove(filePath string, request *sftp.Request) er
return c.RemoveFile(filePath, request.Filepath, fi)
}
func (c *Connection) handleSFTPUploadToNewFile(resolvedPath, filePath, requestPath string) (io.WriterAt, error) {
func (c *Connection) handleSFTPUploadToNewFile(resolvedPath, filePath, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) {
quotaResult := c.HasSpace(true, requestPath)
if !quotaResult.HasSpace {
c.Log(logger.LevelInfo, "denying file write due to quota limits")
@@ -292,13 +313,13 @@ func (c *Connection) handleSFTPUploadToNewFile(resolvedPath, filePath, requestPa
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs)
t := newTransfer(baseTransfer, w, nil)
t := newTransfer(baseTransfer, w, nil, errForRead)
return t, nil
}
func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, resolvedPath, filePath string,
fileSize int64, requestPath string) (io.WriterAt, error) {
fileSize int64, requestPath string, errForRead error) (sftp.WriterAtReaderAt, error) {
var err error
quotaResult := c.HasSpace(false, requestPath)
if !quotaResult.HasSpace {
@@ -363,7 +384,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r
baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, c.Fs)
t := newTransfer(baseTransfer, w, nil)
t := newTransfer(baseTransfer, w, nil, errForRead)
return t, nil
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/eikenb/pipeat"
"github.com/pkg/sftp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"github.com/drakkan/sftpgo/common"
@@ -159,7 +160,7 @@ func TestUploadResumeInvalidOffset(t *testing.T) {
fs := vfs.NewOsFs("", os.TempDir(), nil)
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferUpload, 10, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil)
transfer := newTransfer(baseTransfer, nil, nil, nil)
_, err = transfer.WriteAt([]byte("test"), 0)
assert.Error(t, err, "upload with invalid offset must fail")
if assert.Error(t, transfer.ErrTransfer) {
@@ -187,7 +188,7 @@ func TestReadWriteErrors(t *testing.T) {
fs := vfs.NewOsFs("", os.TempDir(), nil)
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil)
transfer := newTransfer(baseTransfer, nil, nil, nil)
err = file.Close()
assert.NoError(t, err)
_, err = transfer.WriteAt([]byte("test"), 0)
@@ -201,8 +202,8 @@ func TestReadWriteErrors(t *testing.T) {
r, _, err := pipeat.Pipe()
assert.NoError(t, err)
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
transfer = newTransfer(baseTransfer, nil, r)
err = transfer.closeIO()
transfer = newTransfer(baseTransfer, nil, r, nil)
err = transfer.Close()
assert.NoError(t, err)
_, err = transfer.ReadAt(buf, 0)
assert.Error(t, err, "reading from a closed pipe must fail")
@@ -211,7 +212,7 @@ func TestReadWriteErrors(t *testing.T) {
assert.NoError(t, err)
pipeWriter := vfs.NewPipeWriter(w)
baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
transfer = newTransfer(baseTransfer, pipeWriter, nil)
transfer = newTransfer(baseTransfer, pipeWriter, nil, nil)
err = r.Close()
assert.NoError(t, err)
@@ -224,9 +225,12 @@ func TestReadWriteErrors(t *testing.T) {
assert.EqualError(t, err, errFake.Error())
_, err = transfer.WriteAt([]byte("test"), 0)
assert.Error(t, err, "writing to closed pipe must fail")
err = transfer.BaseTransfer.Close()
assert.EqualError(t, err, errFake.Error())
err = os.Remove(testfile)
assert.NoError(t, err)
assert.Len(t, conn.GetTransfers(), 0)
}
func TestUnsupportedListOP(t *testing.T) {
@@ -254,7 +258,7 @@ func TestTransferCancelFn(t *testing.T) {
fs := vfs.NewOsFs("", os.TempDir(), nil)
conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs)
baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil)
transfer := newTransfer(baseTransfer, nil, nil, nil)
errFake := errors.New("fake error, this will trigger cancelFn")
transfer.TransferError(errFake)
@@ -293,7 +297,7 @@ func TestMockFsErrors(t *testing.T) {
flags.Write = true
flags.Trunc = false
flags.Append = true
_, err = c.handleSFTPUploadToExistingFile(flags, testfile, testfile, 0, "/testfile")
_, err = c.handleSFTPUploadToExistingFile(flags, testfile, testfile, 0, "/testfile", nil)
assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error())
fs = newMockOsFs(errFake, nil, false, "123", os.TempDir())
@@ -321,18 +325,18 @@ func TestUploadFiles(t *testing.T) {
var flags sftp.FileOpenFlags
flags.Write = true
flags.Trunc = true
_, err := c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path")
_, err := c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path", nil)
assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid")
common.Config.UploadMode = common.UploadModeStandard
_, err = c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path")
_, err = c.handleSFTPUploadToExistingFile(flags, "missing_path", "other_missing_path", 0, "/missing_path", nil)
assert.Error(t, err, "upload to existing file must fail if one or both paths are invalid")
missingFile := "missing/relative/file.txt"
if runtime.GOOS == osWindows {
missingFile = "missing\\relative\\file.txt"
}
_, err = c.handleSFTPUploadToNewFile(".", missingFile, "/missing")
_, err = c.handleSFTPUploadToNewFile(".", missingFile, "/missing", nil)
assert.Error(t, err, "upload new file in missing path must fail")
c.BaseConnection.Fs = newMockOsFs(nil, nil, false, "123", os.TempDir())
@@ -341,7 +345,7 @@ func TestUploadFiles(t *testing.T) {
err = f.Close()
assert.NoError(t, err)
tr, err := c.handleSFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123, f.Name())
tr, err := c.handleSFTPUploadToExistingFile(flags, f.Name(), f.Name(), 123, f.Name(), nil)
if assert.NoError(t, err) {
transfer := tr.(*transfer)
transfers := c.GetTransfers()
@@ -990,7 +994,7 @@ func TestSystemCommandErrors(t *testing.T) {
sshCmd.connection.channel = &mockSSHChannel
baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", common.TransferDownload,
0, 0, 0, false, fs)
transfer := newTransfer(baseTransfer, nil, nil)
transfer := newTransfer(baseTransfer, nil, nil, nil)
destBuff := make([]byte, 65535)
dst := bytes.NewBuffer(destBuff)
_, err = transfer.copyFromReaderToWriter(dst, sshCmd.connection.channel)
@@ -1542,7 +1546,7 @@ func TestSCPUploadFiledata(t *testing.T) {
baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(),
"/"+testfile, common.TransferDownload, 0, 0, 0, true, fs)
transfer := newTransfer(baseTransfer, nil, nil)
transfer := newTransfer(baseTransfer, nil, nil, nil)
err = scpCommand.getUploadFileData(2, transfer)
assert.Error(t, err, "upload must fail, we send a fake write error message")
@@ -1574,7 +1578,7 @@ func TestSCPUploadFiledata(t *testing.T) {
file, err = os.Create(testfile)
assert.NoError(t, err)
baseTransfer.File = file
transfer = newTransfer(baseTransfer, nil, nil)
transfer = newTransfer(baseTransfer, nil, nil, nil)
transfer.Connection.AddTransfer(transfer)
err = scpCommand.getUploadFileData(2, transfer)
assert.Error(t, err, "upload must fail, we have not enough data to read")
@@ -1626,7 +1630,7 @@ func TestUploadError(t *testing.T) {
assert.NoError(t, err)
baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile,
testfile, common.TransferUpload, 0, 0, 0, true, fs)
transfer := newTransfer(baseTransfer, nil, nil)
transfer := newTransfer(baseTransfer, nil, nil, nil)
errFake := errors.New("fake error")
transfer.TransferError(errFake)
@@ -1645,6 +1649,49 @@ func TestUploadError(t *testing.T) {
common.Config.UploadMode = oldUploadMode
}
func TestTransferFailingReader(t *testing.T) {
user := dataprovider.User{
Username: "testuser",
}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := newMockOsFs(nil, nil, true, "", os.TempDir())
connection := &Connection{
BaseConnection: common.NewBaseConnection("", common.ProtocolSFTP, user, fs),
}
request := sftp.NewRequest("Open", "afile.txt")
request.Flags = 27 // read,write,create,truncate
transfer, err := connection.handleFilewrite(request)
require.NoError(t, err)
buf := make([]byte, 32)
_, err = transfer.ReadAt(buf, 0)
assert.EqualError(t, err, sftp.ErrSSHFxOpUnsupported.Error())
if c, ok := transfer.(io.Closer); ok {
err = c.Close()
assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error())
}
fsPath := filepath.Join(os.TempDir(), "afile.txt")
r, _, err := pipeat.Pipe()
assert.NoError(t, err)
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, fsPath, filepath.Base(fsPath), common.TransferUpload, 0, 0, 0, false, fs)
errRead := errors.New("read is not allowed")
tr := newTransfer(baseTransfer, nil, r, errRead)
_, err = tr.ReadAt(buf, 0)
assert.EqualError(t, err, errRead.Error())
err = tr.Close()
assert.EqualError(t, err, sftp.ErrSSHFxFailure.Error())
err = os.Remove(fsPath)
assert.NoError(t, err)
assert.Len(t, connection.GetTransfers(), 0)
}
func TestConnectionStatusStruct(t *testing.T) {
var transfers []common.ConnectionTransfer
transferUL := common.ConnectionTransfer{

View File

@@ -227,7 +227,7 @@ func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead
baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, requestPath,
common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, c.connection.Fs)
t := newTransfer(baseTransfer, w, nil)
t := newTransfer(baseTransfer, w, nil, nil)
return c.getUploadFileData(sizeToRead, t)
}
@@ -485,7 +485,7 @@ func (c *scpCommand) handleDownload(filePath string) error {
baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, filePath,
common.TransferDownload, 0, 0, 0, false, c.connection.Fs)
t := newTransfer(baseTransfer, nil, r)
t := newTransfer(baseTransfer, nil, r, nil)
err = c.sendDownloadFileData(p, stat, t)
// we need to call Close anyway and return close error if any and

View File

@@ -359,6 +359,51 @@ func TestBasicSFTPHandling(t *testing.T) {
assert.NoError(t, err)
}
func TestOpenReadWrite(t *testing.T) {
usePubKey := false
u := getTestUser(usePubKey)
u.QuotaSize = 6553600
user, _, err := httpd.AddUser(u, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
client, err := getSftpClient(user, usePubKey)
if assert.NoError(t, err) {
defer client.Close()
sftpFile, err := client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
if assert.NoError(t, err) {
testData := []byte("test data")
n, err := sftpFile.Write(testData)
assert.NoError(t, err)
assert.Equal(t, len(testData), n)
buffer := make([]byte, 128)
n, err = sftpFile.ReadAt(buffer, 1)
assert.EqualError(t, err, io.EOF.Error())
assert.Equal(t, len(testData)-1, n)
assert.Equal(t, testData[1:], buffer[:n])
sftpFile.Close()
}
sftpFile, err = client.OpenFile(testFileName, os.O_RDWR|os.O_CREATE|os.O_TRUNC)
if assert.NoError(t, err) {
testData := []byte("new test data")
n, err := sftpFile.Write(testData)
assert.NoError(t, err)
assert.Equal(t, len(testData), n)
buffer := make([]byte, 128)
n, err = sftpFile.ReadAt(buffer, 1)
assert.EqualError(t, err, io.EOF.Error())
assert.Equal(t, len(testData)-1, n)
assert.Equal(t, testData[1:], buffer[:n])
sftpFile.Close()
sftpFile.Close()
}
}
_, err = httpd.RemoveUser(user, http.StatusOK)
assert.NoError(t, err)
err = os.RemoveAll(user.GetHomeDir())
assert.NoError(t, err)
}
func TestConcurrency(t *testing.T) {
usePubKey := true
numLogins := 50

View File

@@ -354,7 +354,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
defer stdin.Close()
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath,
common.TransferUpload, 0, 0, remainingQuotaSize, false, c.connection.Fs)
transfer := newTransfer(baseTransfer, nil, nil)
transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel)
c.connection.Log(logger.LevelDebug, "command: %#v, copy from remote command to sdtin ended, written: %v, "+
@@ -367,7 +367,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
go func() {
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath,
common.TransferDownload, 0, 0, 0, false, c.connection.Fs)
transfer := newTransfer(baseTransfer, nil, nil)
transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout)
c.connection.Log(logger.LevelDebug, "command: %#v, copy from sdtout to remote command ended, written: %v err: %v",
@@ -381,7 +381,7 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error {
go func() {
baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath,
common.TransferDownload, 0, 0, 0, false, c.connection.Fs)
transfer := newTransfer(baseTransfer, nil, nil)
transfer := newTransfer(baseTransfer, nil, nil, nil)
w, e := transfer.copyFromReaderToWriter(c.connection.channel.Stderr(), stderr)
c.connection.Log(logger.LevelDebug, "command: %#v, copy from sdterr to remote command ended, written: %v err: %v",

View File

@@ -22,6 +22,19 @@ type readerAtCloser interface {
io.Closer
}
type failingReader struct {
innerReader readerAtCloser
errRead error
}
func (r *failingReader) ReadAt(p []byte, off int64) (n int, err error) {
return 0, r.errRead
}
func (r *failingReader) Close() error {
return r.innerReader.Close()
}
// transfer defines the transfer details.
// It implements the io.ReaderAt and io.WriterAt interfaces to handle SFTP downloads and uploads
type transfer struct {
@@ -31,16 +44,31 @@ type transfer struct {
isFinished bool
}
func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt) *transfer {
func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt,
errForRead error) *transfer {
var writer writerAtCloser
var reader readerAtCloser
if baseTransfer.File != nil {
writer = baseTransfer.File
reader = baseTransfer.File
if errForRead == nil {
reader = baseTransfer.File
} else {
reader = &failingReader{
innerReader: baseTransfer.File,
errRead: errForRead,
}
}
} else if pipeWriter != nil {
writer = pipeWriter
} else if pipeReader != nil {
reader = pipeReader
if errForRead == nil {
reader = pipeReader
} else {
reader = &failingReader{
innerReader: pipeReader,
errRead: errForRead,
}
}
}
return &transfer{
BaseTransfer: baseTransfer,