diff --git a/.github/workflows/development.yml b/.github/workflows/development.yml index 005550d0..5c9ad07e 100644 --- a/.github/workflows/development.yml +++ b/.github/workflows/development.yml @@ -49,7 +49,7 @@ jobs: shell: bash - name: Run test cases using SQLite provider - run: go test -v -p 1 -timeout 5m ./... -coverprofile=coverage.txt -covermode=atomic + run: go test -v -p 1 -timeout 10m ./... -coverprofile=coverage.txt -covermode=atomic - name: Upload coverage to Codecov if: ${{ matrix.upload-coverage }} @@ -60,10 +60,10 @@ jobs: - name: Run test cases using bolt provider run: | - go test -v -p 1 -timeout 1m ./config -covermode=atomic - go test -v -p 1 -timeout 1m ./common -covermode=atomic - go test -v -p 1 -timeout 2m ./httpd -covermode=atomic - go test -v -p 1 -timeout 5m ./sftpd -covermode=atomic + go test -v -p 1 -timeout 2m ./config -covermode=atomic + go test -v -p 1 -timeout 2m ./common -covermode=atomic + go test -v -p 1 -timeout 3m ./httpd -covermode=atomic + go test -v -p 1 -timeout 8m ./sftpd -covermode=atomic go test -v -p 1 -timeout 2m ./ftpd -covermode=atomic go test -v -p 1 -timeout 2m ./webdavd -covermode=atomic env: @@ -71,7 +71,7 @@ jobs: SFTPGO_DATA_PROVIDER__NAME: 'sftpgo_bolt.db' - name: Run test cases using memory provider - run: go test -v -p 1 -timeout 5m ./... -covermode=atomic + run: go test -v -p 1 -timeout 10m ./... -covermode=atomic env: SFTPGO_DATA_PROVIDER__DRIVER: memory SFTPGO_DATA_PROVIDER__NAME: '' @@ -153,7 +153,7 @@ jobs: - name: Run tests using PostgreSQL provider run: | ./sftpgo initprovider - go test -v -p 1 -timeout 5m ./... -covermode=atomic + go test -v -p 1 -timeout 10m ./... -covermode=atomic env: SFTPGO_DATA_PROVIDER__DRIVER: postgresql SFTPGO_DATA_PROVIDER__NAME: sftpgo @@ -165,7 +165,7 @@ jobs: - name: Run tests using MySQL provider run: | ./sftpgo initprovider - go test -v -p 1 -timeout 5m ./... -covermode=atomic + go test -v -p 1 -timeout 10m ./... -covermode=atomic env: SFTPGO_DATA_PROVIDER__DRIVER: mysql SFTPGO_DATA_PROVIDER__NAME: sftpgo diff --git a/common/common.go b/common/common.go index 436001b4..3db62dbc 100644 --- a/common/common.go +++ b/common/common.go @@ -145,7 +145,7 @@ type ActiveTransfer interface { GetVirtualPath() string GetStartTime() time.Time SignalClose() - Truncate(fsPath string, size int64) error + Truncate(fsPath string, size int64) (int64, error) } // ActiveConnection defines the interface for the current active connections diff --git a/common/common_test.go b/common/common_test.go index b8caa2d0..4e221610 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -19,6 +19,7 @@ import ( "github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/httpclient" "github.com/drakkan/sftpgo/logger" + "github.com/drakkan/sftpgo/vfs" ) const ( @@ -267,13 +268,14 @@ func TestConnectionStatus(t *testing.T) { user := dataprovider.User{ Username: username, } - c1 := NewBaseConnection("id1", ProtocolSFTP, user, nil) + fs := vfs.NewOsFs("", os.TempDir(), nil) + c1 := NewBaseConnection("id1", ProtocolSFTP, user, fs) fakeConn1 := &fakeConnection{ BaseConnection: c1, } - t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/r1", TransferUpload, 0, 0, true) + t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/r1", TransferUpload, 0, 0, 0, true, fs) t1.BytesReceived = 123 - t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/r2", TransferDownload, 0, 0, true) + t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs) t2.BytesSent = 456 c2 := NewBaseConnection("id2", ProtocolSSH, user, nil) fakeConn2 := &fakeConnection{ @@ -285,7 +287,7 @@ func TestConnectionStatus(t *testing.T) { BaseConnection: c3, command: "PROPFIND", } - t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/r2", TransferDownload, 0, 0, true) + t3 := NewBaseTransfer(nil, c3, nil, "/p2", "/r2", TransferDownload, 0, 0, 0, true, fs) Connections.Add(fakeConn1) Connections.Add(fakeConn2) Connections.Add(fakeConn3) diff --git a/common/connection.go b/common/connection.go index 0dca07fe..434d2a75 100644 --- a/common/connection.go +++ b/common/connection.go @@ -172,18 +172,18 @@ func (c *BaseConnection) SignalTransfersAbort() error { return nil } -func (c *BaseConnection) truncateOpenHandle(fsPath string, size int64) error { +func (c *BaseConnection) truncateOpenHandle(fsPath string, size int64) (int64, error) { c.RLock() defer c.RUnlock() for _, t := range c.activeTransfers { - err := t.Truncate(fsPath, size) + initialSize, err := t.Truncate(fsPath, size) if err != errTransferMismatch { - return err + return initialSize, err } } - return errNoTransfer + return 0, errNoTransfer } // ListDir reads the directory named by fsPath and returns a list of directory entries @@ -481,7 +481,7 @@ func (c *BaseConnection) SetStat(fsPath, virtualPath string, attributes *StatAtt return c.GetPermissionDeniedError() } - if err := c.truncateFile(fsPath, attributes.Size); err != nil { + if err := c.truncateFile(fsPath, virtualPath, attributes.Size); err != nil { c.Log(logger.LevelWarn, "failed to truncate path %#v, size: %v, err: %+v", fsPath, attributes.Size, err) return c.GetFsError(err) } @@ -491,17 +491,34 @@ func (c *BaseConnection) SetStat(fsPath, virtualPath string, attributes *StatAtt return nil } -func (c *BaseConnection) truncateFile(fsPath string, size int64) error { +func (c *BaseConnection) truncateFile(fsPath, virtualPath string, size int64) error { // check first if we have an open transfer for the given path and try to truncate the file already opened // if we found no transfer we truncate by path. - // pkg/sftp should expose an optional interface and call truncate directly on the opened handle ... - // If we try to truncate by path an already opened file we get an error on Windows + var initialSize int64 var err error - err = c.truncateOpenHandle(fsPath, size) + initialSize, err = c.truncateOpenHandle(fsPath, size) if err == errNoTransfer { c.Log(logger.LevelDebug, "file path %#v not found in active transfers, execute trucate by path", fsPath) + var info os.FileInfo + info, err = c.Fs.Stat(fsPath) + if err != nil { + return err + } + initialSize = info.Size() err = c.Fs.Truncate(fsPath, size) } + if err == nil && vfs.IsLocalOsFs(c.Fs) { + sizeDiff := initialSize - size + vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath)) + if err == nil { + dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -sizeDiff, false) //nolint:errcheck + if vfolder.IsIncludedInUserQuota() { + dataprovider.UpdateUserQuota(c.User, 0, -sizeDiff, false) //nolint:errcheck + } + } else { + dataprovider.UpdateUserQuota(c.User, 0, -sizeDiff, false) //nolint:errcheck + } + } return err } diff --git a/common/connection_test.go b/common/connection_test.go index a92df599..201ef525 100644 --- a/common/connection_test.go +++ b/common/connection_test.go @@ -554,8 +554,36 @@ func TestSetStat(t *testing.T) { assert.Equal(t, int64(1), fi.Size()) } + vDir := filepath.Join(os.TempDir(), "vdir") + err = os.MkdirAll(vDir, os.ModePerm) + assert.NoError(t, err) + c.User.VirtualFolders = nil + c.User.VirtualFolders = append(c.User.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: vDir, + }, + VirtualPath: "/vpath", + QuotaSize: -1, + QuotaFiles: -1, + }) + + filePath = filepath.Join(vDir, "afile.txt") + err = ioutil.WriteFile(filePath, []byte("hello"), os.ModePerm) + assert.NoError(t, err) + err = c.SetStat(filePath, "/vpath/afile.txt", &StatAttributes{ + Flags: StatAttrSize, + Size: 1, + }) + assert.NoError(t, err) + fi, err = os.Stat(filePath) + if assert.NoError(t, err) { + assert.Equal(t, int64(1), fi.Size()) + } + err = os.RemoveAll(user.GetHomeDir()) assert.NoError(t, err) + err = os.RemoveAll(vDir) + assert.NoError(t, err) } func TestSpaceForCrossRename(t *testing.T) { diff --git a/common/transfer.go b/common/transfer.go index 4a978fc3..252a9d40 100644 --- a/common/transfer.go +++ b/common/transfer.go @@ -11,6 +11,7 @@ import ( "github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/logger" "github.com/drakkan/sftpgo/metrics" + "github.com/drakkan/sftpgo/vfs" ) var ( @@ -21,6 +22,7 @@ var ( // BaseTransfer contains protocols common transfer details for an upload or a download. type BaseTransfer struct { //nolint:maligned ID uint64 + Fs vfs.Fs File *os.File Connection *BaseConnection cancelFn func() @@ -33,6 +35,7 @@ type BaseTransfer struct { //nolint:maligned requestPath string BytesSent int64 BytesReceived int64 + MaxWriteSize int64 AbortTransfer int32 sync.Mutex ErrTransfer error @@ -40,7 +43,7 @@ type BaseTransfer struct { //nolint:maligned // NewBaseTransfer returns a new BaseTransfer and adds it to the given connection func NewBaseTransfer(file *os.File, conn *BaseConnection, cancelFn func(), fsPath, requestPath string, transferType int, - minWriteOffset, initialSize int64, isNewFile bool) *BaseTransfer { + minWriteOffset, initialSize, maxWriteSize int64, isNewFile bool, fs vfs.Fs) *BaseTransfer { t := &BaseTransfer{ ID: conn.GetTransferID(), File: file, @@ -55,7 +58,9 @@ func NewBaseTransfer(file *os.File, conn *BaseConnection, cancelFn func(), fsPat requestPath: requestPath, BytesSent: 0, BytesReceived: 0, + MaxWriteSize: maxWriteSize, AbortTransfer: 0, + Fs: fs, } conn.AddTransfer(t) @@ -110,18 +115,33 @@ func (t *BaseTransfer) SetCancelFn(cancelFn func()) { // Truncate changes the size of the opened file. // Supported for local fs only -func (t *BaseTransfer) Truncate(fsPath string, size int64) error { +func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) { if fsPath == t.GetFsPath() { if t.File != nil { - return t.File.Truncate(size) + initialSize := t.InitialSize + err := t.File.Truncate(size) + if err == nil { + t.Lock() + t.InitialSize = size + if t.MaxWriteSize > 0 { + sizeDiff := initialSize - size + t.MaxWriteSize += sizeDiff + metrics.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer) + atomic.StoreInt64(&t.BytesReceived, 0) + } + t.Unlock() + } + t.Connection.Log(logger.LevelDebug, "file %#v truncated to size %v max write size %v new initial size %v err: %v", + fsPath, size, t.MaxWriteSize, t.InitialSize, err) + return initialSize, err } - if size == 0 { + if size == 0 && atomic.LoadInt64(&t.BytesSent) == 0 { // for cloud providers the file is always truncated to zero, we don't support append/resume for uploads - return nil + return 0, nil } - return ErrOpUnsupported + return 0, ErrOpUnsupported } - return errTransferMismatch + return 0, errTransferMismatch } // TransferError is called if there is an unexpected error. @@ -190,10 +210,17 @@ func (t *BaseTransfer) Close() error { atomic.LoadInt64(&t.BytesSent), t.ErrTransfer) go action.execute() //nolint:errcheck } else { + fileSize := atomic.LoadInt64(&t.BytesReceived) + t.MinWriteOffset + info, err := t.Fs.Stat(t.fsPath) + if err == nil { + fileSize = info.Size() + } + t.Connection.Log(logger.LevelDebug, "upload file size %v stat error %v", fileSize, err) + t.updateQuota(numFiles, fileSize) logger.TransferLog(uploadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesReceived), t.Connection.User.Username, t.Connection.ID, t.Connection.protocol) action := newActionNotification(&t.Connection.User, operationUpload, t.fsPath, "", "", t.Connection.protocol, - atomic.LoadInt64(&t.BytesReceived)+t.MinWriteOffset, t.ErrTransfer) + fileSize, t.ErrTransfer) go action.execute() //nolint:errcheck } if t.ErrTransfer != nil { @@ -202,26 +229,25 @@ func (t *BaseTransfer) Close() error { err = t.ErrTransfer } } - t.updateQuota(numFiles) return err } -func (t *BaseTransfer) updateQuota(numFiles int) bool { +func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool { // S3 uploads are atomic, if there is an error nothing is uploaded if t.File == nil && t.ErrTransfer != nil { return false } - bytesReceived := atomic.LoadInt64(&t.BytesReceived) - if t.transferType == TransferUpload && (numFiles != 0 || bytesReceived > 0) { + sizeDiff := fileSize - t.InitialSize + if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff > 0) { vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath)) if err == nil { dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck - bytesReceived-t.InitialSize, false) + sizeDiff, false) if vfolder.IsIncludedInUserQuota() { - dataprovider.UpdateUserQuota(t.Connection.User, numFiles, bytesReceived-t.InitialSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck } } else { - dataprovider.UpdateUserQuota(t.Connection.User, numFiles, bytesReceived-t.InitialSize, false) //nolint:errcheck + dataprovider.UpdateUserQuota(t.Connection.User, numFiles, sizeDiff, false) //nolint:errcheck } return true } diff --git a/common/transfer_test.go b/common/transfer_test.go index 23b72386..bec86fc1 100644 --- a/common/transfer_test.go +++ b/common/transfer_test.go @@ -20,10 +20,11 @@ func TestTransferUpdateQuota(t *testing.T) { Connection: conn, transferType: TransferUpload, BytesReceived: 123, + Fs: vfs.NewOsFs("", os.TempDir(), nil), } errFake := errors.New("fake error") transfer.TransferError(errFake) - assert.False(t, transfer.updateQuota(1)) + assert.False(t, transfer.updateQuota(1, 0)) err := transfer.Close() if assert.Error(t, err) { assert.EqualError(t, err, errFake.Error()) @@ -41,7 +42,7 @@ func TestTransferUpdateQuota(t *testing.T) { transfer.ErrTransfer = nil transfer.BytesReceived = 1 transfer.requestPath = "/vdir/file" - assert.True(t, transfer.updateQuota(1)) + assert.True(t, transfer.updateQuota(1, 0)) err = transfer.Close() assert.NoError(t, err) } @@ -52,6 +53,7 @@ func TestTransferThrottling(t *testing.T) { UploadBandwidth: 50, DownloadBandwidth: 40, } + fs := vfs.NewOsFs("", os.TempDir(), nil) testFileSize := int64(131072) wantedUploadElapsed := 1000 * (testFileSize / 1000) / u.UploadBandwidth wantedDownloadElapsed := 1000 * (testFileSize / 1000) / u.DownloadBandwidth @@ -59,7 +61,7 @@ func TestTransferThrottling(t *testing.T) { wantedUploadElapsed -= wantedDownloadElapsed / 10 wantedDownloadElapsed -= wantedDownloadElapsed / 10 conn := NewBaseConnection("id", ProtocolSCP, u, nil) - transfer := NewBaseTransfer(nil, conn, nil, "", "", TransferUpload, 0, 0, true) + transfer := NewBaseTransfer(nil, conn, nil, "", "", TransferUpload, 0, 0, 0, true, fs) transfer.BytesReceived = testFileSize transfer.Connection.UpdateLastActivity() startTime := transfer.Connection.GetLastActivity() @@ -69,7 +71,7 @@ func TestTransferThrottling(t *testing.T) { err := transfer.Close() assert.NoError(t, err) - transfer = NewBaseTransfer(nil, conn, nil, "", "", TransferDownload, 0, 0, true) + transfer = NewBaseTransfer(nil, conn, nil, "", "", TransferDownload, 0, 0, 0, true, fs) transfer.BytesSent = testFileSize transfer.Connection.UpdateLastActivity() startTime = transfer.Connection.GetLastActivity() @@ -97,13 +99,14 @@ func TestTruncate(t *testing.T) { _, err = file.Write([]byte("hello")) assert.NoError(t, err) conn := NewBaseConnection(fs.ConnectionID(), ProtocolSFTP, u, fs) - transfer := NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, true) + transfer := NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 100, true, fs) err = conn.SetStat(testFile, "/transfer_test_file", &StatAttributes{ Size: 2, Flags: StatAttrSize, }) assert.NoError(t, err) + assert.Equal(t, int64(98), transfer.MaxWriteSize) err = transfer.Close() assert.NoError(t, err) err = file.Close() @@ -113,12 +116,22 @@ func TestTruncate(t *testing.T) { assert.Equal(t, int64(2), fi.Size()) } - transfer = NewBaseTransfer(nil, conn, nil, testFile, "", TransferUpload, 0, 0, true) - err = transfer.Truncate("mismatch", 0) - assert.EqualError(t, err, errTransferMismatch.Error()) - err = transfer.Truncate(testFile, 0) + transfer = NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 100, true, fs) + // file.Stat will fail on a closed file + err = conn.SetStat(testFile, "/transfer_test_file", &StatAttributes{ + Size: 2, + Flags: StatAttrSize, + }) + assert.Error(t, err) + err = transfer.Close() assert.NoError(t, err) - err = transfer.Truncate(testFile, 1) + + transfer = NewBaseTransfer(nil, conn, nil, testFile, "", TransferUpload, 0, 0, 0, true, fs) + _, err = transfer.Truncate("mismatch", 0) + assert.EqualError(t, err, errTransferMismatch.Error()) + _, err = transfer.Truncate(testFile, 0) + assert.NoError(t, err) + _, err = transfer.Truncate(testFile, 1) assert.EqualError(t, err, ErrOpUnsupported.Error()) err = transfer.Close() @@ -148,7 +161,7 @@ func TestTransferErrors(t *testing.T) { assert.FailNow(t, "unable to open test file") } conn := NewBaseConnection("id", ProtocolSFTP, u, fs) - transfer := NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, true) + transfer := NewBaseTransfer(file, conn, nil, testFile, "/transfer_test_file", TransferUpload, 0, 0, 0, true, fs) assert.Nil(t, transfer.cancelFn) assert.Equal(t, testFile, transfer.GetFsPath()) transfer.SetCancelFn(cancelFn) @@ -174,7 +187,7 @@ func TestTransferErrors(t *testing.T) { assert.FailNow(t, "unable to open test file") } fsPath := filepath.Join(os.TempDir(), "test_file") - transfer = NewBaseTransfer(file, conn, nil, fsPath, "/test_file", TransferUpload, 0, 0, true) + transfer = NewBaseTransfer(file, conn, nil, fsPath, "/test_file", TransferUpload, 0, 0, 0, true, fs) transfer.BytesReceived = 9 transfer.TransferError(errFake) assert.Error(t, transfer.ErrTransfer, errFake.Error()) @@ -193,7 +206,7 @@ func TestTransferErrors(t *testing.T) { if !assert.NoError(t, err) { assert.FailNow(t, "unable to open test file") } - transfer = NewBaseTransfer(file, conn, nil, fsPath, "/test_file", TransferUpload, 0, 0, true) + transfer = NewBaseTransfer(file, conn, nil, fsPath, "/test_file", TransferUpload, 0, 0, 0, true, fs) transfer.BytesReceived = 9 // the file is closed from the embedding struct before to call close err = file.Close() diff --git a/docs/s3.md b/docs/s3.md index 7b99bbb7..3e7f9060 100644 --- a/docs/s3.md +++ b/docs/s3.md @@ -24,6 +24,8 @@ Some SFTP commands don't work over S3: - `symlink` and `chtimes` will fail - `chown` and `chmod` are silently ignored +- `truncate` is not supported +- opening a file for both reading and writing at the same time is not supported - upload resume is not supported - upload mode `atomic` is ignored since S3 uploads are already atomic @@ -33,3 +35,4 @@ Other notes: - We don't support renaming non empty directories since we should rename all the contents too and this could take a long time: think about directories with thousands of files; for each file we should do an AWS API call. - For server side encryption, you have to configure the mapped bucket to automatically encrypt objects. - A local home directory is still required to store temporary files. +- Clients that require advanced filesystem-like features such as `sshfs` are not supported. diff --git a/ftpd/handler.go b/ftpd/handler.go index 9a9011c9..c181388f 100644 --- a/ftpd/handler.go +++ b/ftpd/handler.go @@ -292,8 +292,8 @@ func (c *Connection) downloadFile(fsPath, ftpPath string, offset int64) (ftpserv } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, ftpPath, common.TransferDownload, - 0, 0, false) - t := newTransfer(baseTransfer, nil, r, 0, offset) + 0, 0, 0, false, c.Fs) + t := newTransfer(baseTransfer, nil, r, offset) return t, nil } @@ -353,8 +353,8 @@ func (c *Connection) handleFTPUploadToNewFile(resolvedPath, filePath, requestPat maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, 0, 0, true) - t := newTransfer(baseTransfer, w, nil, maxWriteSize, 0) + common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs) + t := newTransfer(baseTransfer, w, nil, 0) return t, nil } @@ -396,6 +396,7 @@ func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, file if isResume { c.Log(logger.LevelDebug, "upload resume requested, file path: %#v initial size: %v", filePath, fileSize) minWriteOffset = fileSize + initialSize = fileSize } else { if vfs.IsLocalOsFs(c.Fs) { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) @@ -415,8 +416,8 @@ func (c *Connection) handleFTPUploadToExistingFile(flags int, resolvedPath, file vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, minWriteOffset, initialSize, false) - t := newTransfer(baseTransfer, w, nil, maxWriteSize, 0) + common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, c.Fs) + t := newTransfer(baseTransfer, w, nil, 0) return t, nil } diff --git a/ftpd/internal_test.go b/ftpd/internal_test.go index a19fc316..97d00cf4 100644 --- a/ftpd/internal_test.go +++ b/ftpd/internal_test.go @@ -403,8 +403,8 @@ func TestTransferErrors(t *testing.T) { clientContext: mockCC, } baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, file.Name(), testfile, common.TransferDownload, - 0, 0, false) - tr := newTransfer(baseTransfer, nil, nil, 0, 0) + 0, 0, 0, false, fs) + tr := newTransfer(baseTransfer, nil, nil, 0) err = tr.Close() assert.NoError(t, err) _, err = tr.Seek(10, 0) @@ -421,8 +421,8 @@ func TestTransferErrors(t *testing.T) { r, _, err := pipeat.Pipe() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, - common.TransferUpload, 0, 0, false) - tr = newTransfer(baseTransfer, nil, r, 0, 10) + common.TransferUpload, 0, 0, 0, false, fs) + tr = newTransfer(baseTransfer, nil, r, 10) pos, err := tr.Seek(10, 0) assert.NoError(t, err) assert.Equal(t, pos, tr.expectedOffset) @@ -433,8 +433,8 @@ func TestTransferErrors(t *testing.T) { assert.NoError(t, err) pipeWriter := vfs.NewPipeWriter(w) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testfile, testfile, - common.TransferUpload, 0, 0, false) - tr = newTransfer(baseTransfer, pipeWriter, nil, 0, 0) + common.TransferUpload, 0, 0, 0, false, fs) + tr = newTransfer(baseTransfer, pipeWriter, nil, 0) err = r.Close() assert.NoError(t, err) diff --git a/ftpd/transfer.go b/ftpd/transfer.go index 231d1d47..54b7acfd 100644 --- a/ftpd/transfer.go +++ b/ftpd/transfer.go @@ -18,12 +18,11 @@ type transfer struct { writer io.WriteCloser reader io.ReadCloser isFinished bool - maxWriteSize int64 expectedOffset int64 } func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt, - maxWriteSize, expectedOffset int64) *transfer { + expectedOffset int64) *transfer { var writer io.WriteCloser var reader io.ReadCloser if baseTransfer.File != nil { @@ -39,7 +38,6 @@ func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, writer: writer, reader: reader, isFinished: false, - maxWriteSize: maxWriteSize, expectedOffset: expectedOffset, } } @@ -70,7 +68,7 @@ func (t *transfer) Write(p []byte) (n int, err error) { written, e = t.writer.Write(p) atomic.AddInt64(&t.BytesReceived, int64(written)) - if t.maxWriteSize > 0 && e == nil && atomic.LoadInt64(&t.BytesReceived) > t.maxWriteSize { + if t.MaxWriteSize > 0 && e == nil && atomic.LoadInt64(&t.BytesReceived) > t.MaxWriteSize { e = common.ErrQuotaExceeded } if e != nil { diff --git a/go.mod b/go.mod index e1a8ed85..f2e4c966 100644 --- a/go.mod +++ b/go.mod @@ -51,7 +51,7 @@ require ( replace ( github.com/fclairamb/ftpserverlib => github.com/drakkan/ftpserverlib v0.0.0-20200814103339-511fcfd63dfe github.com/jlaffaye/ftp => github.com/drakkan/ftp v0.0.0-20200730125632-b21eac28818c - github.com/pkg/sftp => github.com/drakkan/sftp v0.0.0-20200820090459-de8eb908f763 + github.com/pkg/sftp => github.com/drakkan/sftp v0.0.0-20200822075112-b48593166377 golang.org/x/crypto => github.com/drakkan/crypto v0.0.0-20200731130417-7674a892f9b1 golang.org/x/net => github.com/drakkan/net v0.0.0-20200807161257-daa5cda5ae27 ) diff --git a/go.sum b/go.sum index 78d9f2fe..f10bafa5 100644 --- a/go.sum +++ b/go.sum @@ -111,8 +111,8 @@ github.com/drakkan/ftpserverlib v0.0.0-20200814103339-511fcfd63dfe h1:yiliFcCxu5 github.com/drakkan/ftpserverlib v0.0.0-20200814103339-511fcfd63dfe/go.mod h1:ShLpSOXbtoMDYxTb5eRs9wDBfkQ7VINYghclB4P2z4E= github.com/drakkan/net v0.0.0-20200807161257-daa5cda5ae27 h1:hh14GxmE3PMKL+4nvMmX7O8CUtbD/52IKDjbMTYX7IY= github.com/drakkan/net v0.0.0-20200807161257-daa5cda5ae27/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -github.com/drakkan/sftp v0.0.0-20200820090459-de8eb908f763 h1:2s1IvldI3h8S4FAF7K9uAB05/L0k8TlYF60GIEc2LUY= -github.com/drakkan/sftp v0.0.0-20200820090459-de8eb908f763/go.mod h1:i24A96cQ6ZvWut9G/Uv3LvC4u3VebGsBR5JFvPyChLc= +github.com/drakkan/sftp v0.0.0-20200822075112-b48593166377 h1:Zq2sezrJ0QwCVUpjOjMPZMHx3foI5pciGsNfCj8pHm4= +github.com/drakkan/sftp v0.0.0-20200822075112-b48593166377/go.mod h1:i24A96cQ6ZvWut9G/Uv3LvC4u3VebGsBR5JFvPyChLc= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= diff --git a/sftpd/handler.go b/sftpd/handler.go index 515b5349..f9e2c785 100644 --- a/sftpd/handler.go +++ b/sftpd/handler.go @@ -73,8 +73,8 @@ func (c *Connection) Fileread(request *sftp.Request) (io.ReaderAt, error) { } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, p, request.Filepath, common.TransferDownload, - 0, 0, false) - t := newTransfer(baseTransfer, nil, r, 0) + 0, 0, 0, false, c.Fs) + t := newTransfer(baseTransfer, nil, r) return t, nil } @@ -274,8 +274,8 @@ func (c *Connection) handleSFTPUploadToNewFile(resolvedPath, filePath, requestPa maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, 0, 0, true) - t := newTransfer(baseTransfer, w, nil, maxWriteSize) + common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs) + t := newTransfer(baseTransfer, w, nil) return t, nil } @@ -291,10 +291,12 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r minWriteOffset := int64(0) osFlags := getOSOpenFlags(pflags) - isResume := pflags.Append && osFlags&os.O_TRUNC == 0 + isTruncate := osFlags&os.O_TRUNC != 0 + isResume := pflags.Append && !isTruncate - // if there is a size limit remaining size cannot be 0 here, since quotaResult.HasSpace - // will return false in this case and we deny the upload before + // if there is a size limit the remaining size cannot be 0 here, since quotaResult.HasSpace + // will return false in this case and we deny the upload before. + // For Cloud FS GetMaxWriteSize will return unsupported operation maxWriteSize, err := c.GetMaxWriteSize(quotaResult, isResume, fileSize) if err != nil { c.Log(logger.LevelDebug, "unable to get max write size: %v", err) @@ -320,8 +322,9 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r if isResume { c.Log(logger.LevelDebug, "upload resume requested, file path %#v initial size: %v", filePath, fileSize) minWriteOffset = fileSize + initialSize = fileSize } else { - if vfs.IsLocalOsFs(c.Fs) { + if vfs.IsLocalOsFs(c.Fs) && isTruncate { vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(requestPath)) if err == nil { dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, 0, -fileSize, false) //nolint:errcheck @@ -334,7 +337,7 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r } else { initialSize = fileSize } - if maxWriteSize > 0 { + if maxWriteSize > 0 && isTruncate { maxWriteSize += fileSize } } @@ -342,8 +345,8 @@ func (c *Connection) handleSFTPUploadToExistingFile(pflags sftp.FileOpenFlags, r vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, minWriteOffset, initialSize, false) - t := newTransfer(baseTransfer, w, nil, maxWriteSize) + common.TransferUpload, minWriteOffset, initialSize, maxWriteSize, false, c.Fs) + t := newTransfer(baseTransfer, w, nil) return t, nil } diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index d02b8c03..ff47606e 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -158,8 +158,8 @@ func TestUploadResumeInvalidOffset(t *testing.T) { } fs := vfs.NewOsFs("", os.TempDir(), nil) conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs) - baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferUpload, 10, 0, false) - transfer := newTransfer(baseTransfer, nil, nil, 0) + baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferUpload, 10, 0, 0, false, fs) + transfer := newTransfer(baseTransfer, nil, nil) _, err = transfer.WriteAt([]byte("test"), 0) assert.Error(t, err, "upload with invalid offset must fail") if assert.Error(t, transfer.ErrTransfer) { @@ -186,8 +186,8 @@ func TestReadWriteErrors(t *testing.T) { } fs := vfs.NewOsFs("", os.TempDir(), nil) conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs) - baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, false) - transfer := newTransfer(baseTransfer, nil, nil, 0) + baseTransfer := common.NewBaseTransfer(file, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) + transfer := newTransfer(baseTransfer, nil, nil) err = file.Close() assert.NoError(t, err) _, err = transfer.WriteAt([]byte("test"), 0) @@ -200,8 +200,8 @@ func TestReadWriteErrors(t *testing.T) { r, _, err := pipeat.Pipe() assert.NoError(t, err) - baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, false) - transfer = newTransfer(baseTransfer, nil, r, 0) + baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) + transfer = newTransfer(baseTransfer, nil, r) err = transfer.closeIO() assert.NoError(t, err) _, err = transfer.ReadAt(buf, 0) @@ -210,8 +210,8 @@ func TestReadWriteErrors(t *testing.T) { r, w, err := pipeat.Pipe() assert.NoError(t, err) pipeWriter := vfs.NewPipeWriter(w) - baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, false) - transfer = newTransfer(baseTransfer, pipeWriter, nil, 0) + baseTransfer = common.NewBaseTransfer(nil, conn, nil, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) + transfer = newTransfer(baseTransfer, pipeWriter, nil) err = r.Close() assert.NoError(t, err) @@ -242,8 +242,8 @@ func TestTransferCancelFn(t *testing.T) { } fs := vfs.NewOsFs("", os.TempDir(), nil) conn := common.NewBaseConnection("", common.ProtocolSFTP, user, fs) - baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), testfile, common.TransferDownload, 0, 0, false) - transfer := newTransfer(baseTransfer, nil, nil, 0) + baseTransfer := common.NewBaseTransfer(file, conn, cancelFn, file.Name(), testfile, common.TransferDownload, 0, 0, 0, false, fs) + transfer := newTransfer(baseTransfer, nil, nil) errFake := errors.New("fake error, this will trigger cancelFn") transfer.TransferError(errFake) @@ -978,8 +978,8 @@ func TestSystemCommandErrors(t *testing.T) { } sshCmd.connection.channel = &mockSSHChannel baseTransfer := common.NewBaseTransfer(nil, sshCmd.connection.BaseConnection, nil, "", "", common.TransferDownload, - 0, 0, false) - transfer := newTransfer(baseTransfer, nil, nil, 0) + 0, 0, 0, false, fs) + transfer := newTransfer(baseTransfer, nil, nil) destBuff := make([]byte, 65535) dst := bytes.NewBuffer(destBuff) _, err = transfer.copyFromReaderToWriter(dst, sshCmd.connection.channel) @@ -992,7 +992,7 @@ func TestSystemCommandErrors(t *testing.T) { WriteError: nil, } sshCmd.connection.channel = &mockSSHChannel - transfer.maxWriteSize = 1 + transfer.MaxWriteSize = 1 _, err = transfer.copyFromReaderToWriter(dst, sshCmd.connection.channel) assert.EqualError(t, err, common.ErrQuotaExceeded.Error()) @@ -1006,7 +1006,7 @@ func TestSystemCommandErrors(t *testing.T) { sshCmd.connection.channel = &mockSSHChannel _, err = transfer.copyFromReaderToWriter(sshCmd.connection.channel, dst) assert.EqualError(t, err, io.ErrShortWrite.Error()) - transfer.maxWriteSize = -1 + transfer.MaxWriteSize = -1 _, err = transfer.copyFromReaderToWriter(sshCmd.connection.channel, dst) assert.EqualError(t, err, common.ErrQuotaExceeded.Error()) err = os.RemoveAll(homeDir) @@ -1530,8 +1530,8 @@ func TestSCPUploadFiledata(t *testing.T) { assert.NoError(t, err) baseTransfer := common.NewBaseTransfer(file, scpCommand.connection.BaseConnection, nil, file.Name(), - "/"+testfile, common.TransferDownload, 0, 0, true) - transfer := newTransfer(baseTransfer, nil, nil, 0) + "/"+testfile, common.TransferDownload, 0, 0, 0, true, fs) + transfer := newTransfer(baseTransfer, nil, nil) err = scpCommand.getUploadFileData(2, transfer) assert.Error(t, err, "upload must fail, we send a fake write error message") @@ -1563,7 +1563,7 @@ func TestSCPUploadFiledata(t *testing.T) { file, err = os.Create(testfile) assert.NoError(t, err) baseTransfer.File = file - transfer = newTransfer(baseTransfer, nil, nil, 0) + transfer = newTransfer(baseTransfer, nil, nil) transfer.Connection.AddTransfer(transfer) err = scpCommand.getUploadFileData(2, transfer) assert.Error(t, err, "upload must fail, we have not enough data to read") @@ -1614,8 +1614,8 @@ func TestUploadError(t *testing.T) { file, err := os.Create(fileTempName) assert.NoError(t, err) baseTransfer := common.NewBaseTransfer(file, connection.BaseConnection, nil, testfile, - testfile, common.TransferUpload, 0, 0, true) - transfer := newTransfer(baseTransfer, nil, nil, 0) + testfile, common.TransferUpload, 0, 0, 0, true, fs) + transfer := newTransfer(baseTransfer, nil, nil) errFake := errors.New("fake error") transfer.TransferError(errFake) diff --git a/sftpd/scp.go b/sftpd/scp.go index 4620316f..304cd9f7 100644 --- a/sftpd/scp.go +++ b/sftpd/scp.go @@ -226,8 +226,8 @@ func (c *scpCommand) handleUploadFile(resolvedPath, filePath string, sizeToRead vfs.SetPathPermissions(c.connection.Fs, filePath, c.connection.User.GetUID(), c.connection.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, 0, initialSize, isNewFile) - t := newTransfer(baseTransfer, w, nil, maxWriteSize) + common.TransferUpload, 0, initialSize, maxWriteSize, isNewFile, c.connection.Fs) + t := newTransfer(baseTransfer, w, nil) return c.getUploadFileData(sizeToRead, t) } @@ -484,8 +484,8 @@ func (c *scpCommand) handleDownload(filePath string) error { } baseTransfer := common.NewBaseTransfer(file, c.connection.BaseConnection, cancelFn, p, filePath, - common.TransferDownload, 0, 0, false) - t := newTransfer(baseTransfer, nil, r, 0) + common.TransferDownload, 0, 0, 0, false, c.connection.Fs) + t := newTransfer(baseTransfer, nil, r) err = c.sendDownloadFileData(p, stat, t) // we need to call Close anyway and return close error if any and diff --git a/sftpd/sftpd_test.go b/sftpd/sftpd_test.go index 96e86f54..716c9d8e 100644 --- a/sftpd/sftpd_test.go +++ b/sftpd/sftpd_test.go @@ -653,8 +653,9 @@ func TestStat(t *testing.T) { err = client.Chmod(testFileName, newPerm) assert.NoError(t, err) newFi, err = client.Lstat(testFileName) - assert.NoError(t, err) - assert.Equal(t, newPerm, newFi.Mode().Perm()) + if assert.NoError(t, err) { + assert.Equal(t, newPerm, newFi.Mode().Perm()) + } _, err = client.ReadLink(testFileName) assert.Error(t, err, "readlink is not supported and must fail") newPerm = os.FileMode(0666) @@ -2606,6 +2607,182 @@ func TestVirtualFoldersQuotaLimit(t *testing.T) { assert.NoError(t, err) } +func TestTruncateQuotaLimits(t *testing.T) { + usePubKey := true + u := getTestUser(usePubKey) + u.QuotaSize = 20 + mappedPath := filepath.Join(os.TempDir(), "mapped") + err := os.MkdirAll(mappedPath, os.ModePerm) + assert.NoError(t, err) + vdirPath := "/vmapped" + u.VirtualFolders = append(u.VirtualFolders, vfs.VirtualFolder{ + BaseVirtualFolder: vfs.BaseVirtualFolder{ + MappedPath: mappedPath, + }, + VirtualPath: vdirPath, + QuotaFiles: 10, + }) + user, _, err := httpd.AddUser(u, http.StatusOK) + assert.NoError(t, err) + client, err := getSftpClient(user, usePubKey) + if assert.NoError(t, err) { + defer client.Close() + data := []byte("test data") + f, err := client.OpenFile(testFileName, os.O_WRONLY) + if assert.NoError(t, err) { + n, err := f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(2) + assert.NoError(t, err) + expectedQuotaFiles := 0 + expectedQuotaSize := int64(2) + user, _, err = httpd.GetUserByID(user.ID, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + n, err = f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(5) + assert.NoError(t, err) + expectedQuotaSize = int64(5) + user, _, err = httpd.GetUserByID(user.ID, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + n, err = f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Close() + assert.NoError(t, err) + expectedQuotaFiles = 1 + expectedQuotaSize = int64(5) + int64(len(data)) + user, _, err = httpd.GetUserByID(user.ID, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, expectedQuotaFiles, user.UsedQuotaFiles) + assert.Equal(t, expectedQuotaSize, user.UsedQuotaSize) + } + // now truncate by path + err = client.Truncate(testFileName, 5) + assert.NoError(t, err) + user, _, err = httpd.GetUserByID(user.ID, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + // now open an existing file without truncate it, quota should not change + f, err = client.OpenFile(testFileName, os.O_WRONLY) + if assert.NoError(t, err) { + err = f.Close() + assert.NoError(t, err) + user, _, err = httpd.GetUserByID(user.ID, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + } + // open the file truncating it + f, err = client.OpenFile(testFileName, os.O_WRONLY|os.O_TRUNC) + if assert.NoError(t, err) { + err = f.Close() + assert.NoError(t, err) + user, _, err = httpd.GetUserByID(user.ID, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + } + // now test max write size + f, err = client.OpenFile(testFileName, os.O_WRONLY) + if assert.NoError(t, err) { + n, err := f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(11) + assert.NoError(t, err) + user, _, err = httpd.GetUserByID(user.ID, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(11), user.UsedQuotaSize) + n, err = f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(5) + assert.NoError(t, err) + user, _, err = httpd.GetUserByID(user.ID, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(5), user.UsedQuotaSize) + n, err = f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(12) + assert.NoError(t, err) + user, _, err = httpd.GetUserByID(user.ID, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 1, user.UsedQuotaFiles) + assert.Equal(t, int64(12), user.UsedQuotaSize) + _, err = f.Write(data) + if assert.Error(t, err) { + assert.Contains(t, err.Error(), common.ErrQuotaExceeded.Error()) + } + err = f.Close() + assert.Error(t, err) + // the file is deleted + user, _, err = httpd.GetUserByID(user.ID, http.StatusOK) + assert.NoError(t, err) + assert.Equal(t, 0, user.UsedQuotaFiles) + assert.Equal(t, int64(0), user.UsedQuotaSize) + } + + // basic test inside a virtual folder + vfileName := path.Join(vdirPath, testFileName) + f, err = client.OpenFile(vfileName, os.O_WRONLY) + if assert.NoError(t, err) { + n, err := f.Write(data) + assert.NoError(t, err) + assert.Equal(t, len(data), n) + err = f.Truncate(2) + assert.NoError(t, err) + expectedQuotaFiles := 0 + expectedQuotaSize := int64(2) + folder, _, err := httpd.GetFolders(0, 0, mappedPath, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, folder, 1) { + fold := folder[0] + assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) + assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles) + } + err = f.Close() + assert.NoError(t, err) + expectedQuotaFiles = 1 + folder, _, err = httpd.GetFolders(0, 0, mappedPath, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, folder, 1) { + fold := folder[0] + assert.Equal(t, expectedQuotaSize, fold.UsedQuotaSize) + assert.Equal(t, expectedQuotaFiles, fold.UsedQuotaFiles) + } + } + err = client.Truncate(vfileName, 1) + assert.NoError(t, err) + folder, _, err := httpd.GetFolders(0, 0, mappedPath, http.StatusOK) + assert.NoError(t, err) + if assert.Len(t, folder, 1) { + fold := folder[0] + assert.Equal(t, int64(1), fold.UsedQuotaSize) + assert.Equal(t, 1, fold.UsedQuotaFiles) + } + } + + _, err = httpd.RemoveUser(user, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(user.GetHomeDir()) + assert.NoError(t, err) + _, err = httpd.RemoveFolder(vfs.BaseVirtualFolder{MappedPath: mappedPath}, http.StatusOK) + assert.NoError(t, err) + err = os.RemoveAll(mappedPath) + assert.NoError(t, err) +} + func TestVirtualFoldersQuotaRenameOverwrite(t *testing.T) { usePubKey := true testFileSize := int64(131072) diff --git a/sftpd/ssh_cmd.go b/sftpd/ssh_cmd.go index adf30135..68145df2 100644 --- a/sftpd/ssh_cmd.go +++ b/sftpd/ssh_cmd.go @@ -353,8 +353,8 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { go func() { defer stdin.Close() baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath, - common.TransferUpload, 0, 0, false) - transfer := newTransfer(baseTransfer, nil, nil, remainingQuotaSize) + common.TransferUpload, 0, 0, remainingQuotaSize, false, c.connection.Fs) + transfer := newTransfer(baseTransfer, nil, nil) w, e := transfer.copyFromReaderToWriter(stdin, c.connection.channel) c.connection.Log(logger.LevelDebug, "command: %#v, copy from remote command to sdtin ended, written: %v, "+ @@ -366,8 +366,8 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { go func() { baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath, - common.TransferDownload, 0, 0, false) - transfer := newTransfer(baseTransfer, nil, nil, 0) + common.TransferDownload, 0, 0, 0, false, c.connection.Fs) + transfer := newTransfer(baseTransfer, nil, nil) w, e := transfer.copyFromReaderToWriter(c.connection.channel, stdout) c.connection.Log(logger.LevelDebug, "command: %#v, copy from sdtout to remote command ended, written: %v err: %v", @@ -380,8 +380,8 @@ func (c *sshCommand) executeSystemCommand(command systemCommand) error { go func() { baseTransfer := common.NewBaseTransfer(nil, c.connection.BaseConnection, nil, command.fsPath, sshDestPath, - common.TransferDownload, 0, 0, false) - transfer := newTransfer(baseTransfer, nil, nil, 0) + common.TransferDownload, 0, 0, 0, false, c.connection.Fs) + transfer := newTransfer(baseTransfer, nil, nil) w, e := transfer.copyFromReaderToWriter(c.connection.channel.Stderr(), stderr) c.connection.Log(logger.LevelDebug, "command: %#v, copy from sdterr to remote command ended, written: %v err: %v", diff --git a/sftpd/transfer.go b/sftpd/transfer.go index fadde755..43b80771 100644 --- a/sftpd/transfer.go +++ b/sftpd/transfer.go @@ -26,14 +26,12 @@ type readerAtCloser interface { // It implements the io.ReaderAt and io.WriterAt interfaces to handle SFTP downloads and uploads type transfer struct { *common.BaseTransfer - writerAt writerAtCloser - readerAt readerAtCloser - isFinished bool - maxWriteSize int64 + writerAt writerAtCloser + readerAt readerAtCloser + isFinished bool } -func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt, - maxWriteSize int64) *transfer { +func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt) *transfer { var writer writerAtCloser var reader readerAtCloser if baseTransfer.File != nil { @@ -49,7 +47,6 @@ func newTransfer(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, writerAt: writer, readerAt: reader, isFinished: false, - maxWriteSize: maxWriteSize, } } @@ -86,7 +83,7 @@ func (t *transfer) WriteAt(p []byte, off int64) (n int, err error) { written, e = t.writerAt.WriteAt(p, off) atomic.AddInt64(&t.BytesReceived, int64(written)) - if t.maxWriteSize > 0 && e == nil && atomic.LoadInt64(&t.BytesReceived) > t.maxWriteSize { + if t.MaxWriteSize > 0 && e == nil && atomic.LoadInt64(&t.BytesReceived) > t.MaxWriteSize { e = common.ErrQuotaExceeded } if e != nil { @@ -151,7 +148,7 @@ func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64, var written int64 var err error - if t.maxWriteSize < 0 { + if t.MaxWriteSize < 0 { return 0, common.ErrQuotaExceeded } isDownload := t.GetType() == common.TransferDownload @@ -168,7 +165,7 @@ func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64, } else { atomic.StoreInt64(&t.BytesReceived, written) } - if t.maxWriteSize > 0 && written > t.maxWriteSize { + if t.MaxWriteSize > 0 && written > t.MaxWriteSize { err = common.ErrQuotaExceeded break } diff --git a/webdavd/file.go b/webdavd/file.go index cfd2dbf8..69154687 100644 --- a/webdavd/file.go +++ b/webdavd/file.go @@ -21,17 +21,15 @@ var errTransferAborted = errors.New("transfer aborted") type webDavFile struct { *common.BaseTransfer - writer io.WriteCloser - reader io.ReadCloser - isFinished bool - maxWriteSize int64 - startOffset int64 - info os.FileInfo - fs vfs.Fs + writer io.WriteCloser + reader io.ReadCloser + isFinished bool + startOffset int64 + info os.FileInfo } func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt, - maxWriteSize int64, info os.FileInfo, fs vfs.Fs) *webDavFile { + info os.FileInfo) *webDavFile { var writer io.WriteCloser var reader io.ReadCloser if baseTransfer.File != nil { @@ -47,10 +45,8 @@ func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter writer: writer, reader: reader, isFinished: false, - maxWriteSize: maxWriteSize, startOffset: 0, info: info, - fs: fs, } } @@ -72,7 +68,7 @@ func (fi webDavFileInfo) ContentType(ctx context.Context) (string, error) { if len(contentType) > 0 { return contentType, nil } - if c, ok := fi.file.fs.(vfs.MimeTyper); ok { + if c, ok := fi.file.Fs.(vfs.MimeTyper); ok { contentType, err := c.GetMimeType(fi.file.GetFsPath()) return contentType, err } @@ -107,7 +103,7 @@ func (f *webDavFile) Stat() (os.FileInfo, error) { } return info, nil } - info, err := f.fs.Stat(f.GetFsPath()) + info, err := f.Fs.Stat(f.GetFsPath()) if err != nil { return info, err } @@ -133,7 +129,7 @@ func (f *webDavFile) Read(p []byte) (n int, err error) { f.TransferError(common.ErrOpUnsupported) return 0, common.ErrOpUnsupported } - _, r, cancelFn, err := f.fs.Open(f.GetFsPath(), 0) + _, r, cancelFn, err := f.Fs.Open(f.GetFsPath(), 0) f.Lock() f.reader = r f.ErrTransfer = err @@ -171,7 +167,7 @@ func (f *webDavFile) Write(p []byte) (n int, err error) { written, e = f.writer.Write(p) atomic.AddInt64(&f.BytesReceived, int64(written)) - if f.maxWriteSize > 0 && e == nil && atomic.LoadInt64(&f.BytesReceived) > f.maxWriteSize { + if f.MaxWriteSize > 0 && e == nil && atomic.LoadInt64(&f.BytesReceived) > f.MaxWriteSize { e = common.ErrQuotaExceeded } if e != nil { @@ -228,7 +224,7 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) { } } - _, r, cancelFn, err := f.fs.Open(f.GetFsPath(), startByte) + _, r, cancelFn, err := f.Fs.Open(f.GetFsPath(), startByte) f.Lock() if err == nil { diff --git a/webdavd/handler.go b/webdavd/handler.go index 63015a4a..9076dfa0 100644 --- a/webdavd/handler.go +++ b/webdavd/handler.go @@ -172,9 +172,9 @@ func (c *Connection) getFile(fsPath, virtualPath string, info os.FileInfo) (webd } } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, nil, fsPath, virtualPath, common.TransferDownload, - 0, 0, false) + 0, 0, 0, false, c.Fs) - return newWebDavFile(baseTransfer, nil, nil, 0, info, c.Fs), nil + return newWebDavFile(baseTransfer, nil, nil, info), nil } // we don't know if the file will be downloaded or opened for get properties so we check both permissions @@ -202,9 +202,9 @@ func (c *Connection) getFile(fsPath, virtualPath string, info os.FileInfo) (webd } baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, fsPath, virtualPath, common.TransferDownload, - 0, 0, false) + 0, 0, 0, false, c.Fs) - return newWebDavFile(baseTransfer, nil, r, 0, info, c.Fs), nil + return newWebDavFile(baseTransfer, nil, r, info), nil } func (c *Connection) putFile(fsPath, virtualPath string) (webdav.File, error) { @@ -262,9 +262,9 @@ func (c *Connection) handleUploadToNewFile(resolvedPath, filePath, requestPath s maxWriteSize, _ := c.GetMaxWriteSize(quotaResult, false, 0) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, 0, 0, true) + common.TransferUpload, 0, 0, maxWriteSize, true, c.Fs) - return newWebDavFile(baseTransfer, w, nil, maxWriteSize, nil, c.Fs), nil + return newWebDavFile(baseTransfer, w, nil, nil), nil } func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, fileSize int64, @@ -312,9 +312,9 @@ func (c *Connection) handleUploadToExistingFile(resolvedPath, filePath string, f vfs.SetPathPermissions(c.Fs, filePath, c.User.GetUID(), c.User.GetGID()) baseTransfer := common.NewBaseTransfer(file, c.BaseConnection, cancelFn, resolvedPath, requestPath, - common.TransferUpload, 0, initialSize, false) + common.TransferUpload, 0, initialSize, maxWriteSize, false, c.Fs) - return newWebDavFile(baseTransfer, w, nil, maxWriteSize, nil, c.Fs), nil + return newWebDavFile(baseTransfer, w, nil, nil), nil } type objectMapping struct { diff --git a/webdavd/internal_test.go b/webdavd/internal_test.go index 12238847..522c1691 100644 --- a/webdavd/internal_test.go +++ b/webdavd/internal_test.go @@ -428,9 +428,9 @@ func TestContentType(t *testing.T) { testFilePath := filepath.Join(user.HomeDir, testFile) ctx := context.Background() baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, - common.TransferDownload, 0, 0, false) + common.TransferDownload, 0, 0, 0, false, fs) info := vfs.NewFileInfo(testFilePath, true, 0, time.Now()) - davFile := newWebDavFile(baseTransfer, nil, nil, 0, info, fs) + davFile := newWebDavFile(baseTransfer, nil, nil, info) fi, err := davFile.Stat() if assert.NoError(t, err) { ctype, err := fi.(webDavFileInfo).ContentType(ctx) @@ -444,7 +444,8 @@ func TestContentType(t *testing.T) { assert.NoError(t, err) fi, err = os.Stat(testFilePath) assert.NoError(t, err) - davFile = newWebDavFile(baseTransfer, nil, nil, 0, fi, fs) + davFile = newWebDavFile(baseTransfer, nil, nil, fi) + davFile.Fs = fs fi, err = davFile.Stat() if assert.NoError(t, err) { ctype, err := fi.(webDavFileInfo).ContentType(ctx) @@ -471,8 +472,8 @@ func TestTransferReadWriteErrors(t *testing.T) { } testFilePath := filepath.Join(user.HomeDir, testFile) baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, - common.TransferUpload, 0, 0, false) - davFile := newWebDavFile(baseTransfer, nil, nil, 0, nil, fs) + common.TransferUpload, 0, 0, 0, false, fs) + davFile := newWebDavFile(baseTransfer, nil, nil, nil) assert.False(t, davFile.isDir()) p := make([]byte, 1) _, err := davFile.Read(p) @@ -480,9 +481,9 @@ func TestTransferReadWriteErrors(t *testing.T) { r, w, err := pipeat.Pipe() assert.NoError(t, err) - davFile = newWebDavFile(baseTransfer, nil, r, 0, nil, fs) + davFile = newWebDavFile(baseTransfer, nil, r, nil) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) - davFile = newWebDavFile(baseTransfer, vfs.NewPipeWriter(w), nil, 0, nil, fs) + davFile = newWebDavFile(baseTransfer, vfs.NewPipeWriter(w), nil, nil) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) err = r.Close() assert.NoError(t, err) @@ -490,15 +491,15 @@ func TestTransferReadWriteErrors(t *testing.T) { assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, - common.TransferDownload, 0, 0, false) - davFile = newWebDavFile(baseTransfer, nil, nil, 0, nil, fs) + common.TransferDownload, 0, 0, 0, false, fs) + davFile = newWebDavFile(baseTransfer, nil, nil, nil) _, err = davFile.Read(p) assert.True(t, os.IsNotExist(err)) _, err = davFile.Stat() assert.True(t, os.IsNotExist(err)) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, - common.TransferDownload, 0, 0, false) + common.TransferDownload, 0, 0, 0, false, fs) err = ioutil.WriteFile(testFilePath, []byte(""), os.ModePerm) assert.NoError(t, err) f, err := os.Open(testFilePath) @@ -506,7 +507,7 @@ func TestTransferReadWriteErrors(t *testing.T) { err = f.Close() assert.NoError(t, err) } - davFile = newWebDavFile(baseTransfer, nil, nil, 0, nil, fs) + davFile = newWebDavFile(baseTransfer, nil, nil, nil) davFile.reader = f err = davFile.Close() assert.EqualError(t, err, common.ErrGenericFailure.Error()) @@ -520,8 +521,8 @@ func TestTransferReadWriteErrors(t *testing.T) { } baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, - common.TransferDownload, 0, 0, false) - davFile = newWebDavFile(baseTransfer, nil, nil, 0, nil, fs) + common.TransferDownload, 0, 0, 0, false, fs) + davFile = newWebDavFile(baseTransfer, nil, nil, nil) davFile.writer = f err = davFile.Close() assert.EqualError(t, err, common.ErrGenericFailure.Error()) @@ -542,16 +543,16 @@ func TestTransferSeek(t *testing.T) { } testFilePath := filepath.Join(user.HomeDir, testFile) baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, - common.TransferUpload, 0, 0, false) - davFile := newWebDavFile(baseTransfer, nil, nil, 0, nil, fs) + common.TransferUpload, 0, 0, 0, false, fs) + davFile := newWebDavFile(baseTransfer, nil, nil, nil) _, err := davFile.Seek(0, io.SeekStart) assert.EqualError(t, err, common.ErrOpUnsupported.Error()) err = davFile.Close() assert.NoError(t, err) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, - common.TransferDownload, 0, 0, false) - davFile = newWebDavFile(baseTransfer, nil, nil, 0, nil, fs) + common.TransferDownload, 0, 0, 0, false, fs) + davFile = newWebDavFile(baseTransfer, nil, nil, nil) _, err = davFile.Seek(0, io.SeekCurrent) assert.True(t, os.IsNotExist(err)) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) @@ -564,15 +565,15 @@ func TestTransferSeek(t *testing.T) { assert.NoError(t, err) } baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFile, - common.TransferDownload, 0, 0, false) - davFile = newWebDavFile(baseTransfer, nil, nil, 0, nil, fs) + common.TransferDownload, 0, 0, 0, false, fs) + davFile = newWebDavFile(baseTransfer, nil, nil, nil) _, err = davFile.Seek(0, io.SeekStart) assert.Error(t, err) davFile.Connection.RemoveTransfer(davFile.BaseTransfer) baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile, - common.TransferDownload, 0, 0, false) - davFile = newWebDavFile(baseTransfer, nil, nil, 0, nil, fs) + common.TransferDownload, 0, 0, 0, false, fs) + davFile = newWebDavFile(baseTransfer, nil, nil, nil) davFile.reader = f res, err := davFile.Seek(0, io.SeekStart) assert.NoError(t, err) @@ -581,26 +582,26 @@ func TestTransferSeek(t *testing.T) { info, err := os.Stat(testFilePath) assert.NoError(t, err) - davFile = newWebDavFile(baseTransfer, nil, nil, 0, info, fs) + davFile = newWebDavFile(baseTransfer, nil, nil, info) davFile.reader = f res, err = davFile.Seek(0, io.SeekEnd) assert.NoError(t, err) assert.Equal(t, int64(7), res) - davFile = newWebDavFile(baseTransfer, nil, nil, 0, info, fs) + davFile = newWebDavFile(baseTransfer, nil, nil, info) davFile.reader = f - davFile.fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir()) + davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir()) res, err = davFile.Seek(2, io.SeekStart) assert.NoError(t, err) assert.Equal(t, int64(2), res) - davFile = newWebDavFile(baseTransfer, nil, nil, 0, info, fs) - davFile.fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir()) + davFile = newWebDavFile(baseTransfer, nil, nil, info) + davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir()) res, err = davFile.Seek(2, io.SeekEnd) assert.NoError(t, err) assert.Equal(t, int64(5), res) - davFile = newWebDavFile(baseTransfer, nil, nil, 0, nil, fs) + davFile = newWebDavFile(baseTransfer, nil, nil, nil) res, err = davFile.Seek(2, io.SeekEnd) assert.EqualError(t, err, "unable to get file size, seek from end not possible") assert.Equal(t, int64(0), res)