Mirror of https://github.com/drakkan/sftpgo.git
refactoring: add common package
The common package defines the interfaces that a protocol must implement and contains the code that can be shared among the supported protocols. This should make it easier to add support for new protocols.
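As a rough illustration of the pattern this commit introduces, the protocol package embeds a shared base type that owns the transfer bookkeeping and only adds protocol-specific I/O on top, as the new transfer struct in the diff below does by embedding *common.BaseTransfer. The sketch that follows is illustrative only: the type and method names are made up for the example, not the actual common package API; only the embedding idea matches the real code.

package main

import (
	"fmt"
	"sync/atomic"
)

// baseTransfer is a stand-in for the shared base type the common package
// provides (hypothetical name); it owns the counters every protocol needs.
type baseTransfer struct {
	bytesSent     int64
	bytesReceived int64
}

func (t *baseTransfer) addBytesReceived(n int64) {
	atomic.AddInt64(&t.bytesReceived, n)
}

// sftpTransfer stands in for a protocol-specific wrapper, like the sftpd
// transfer struct in the diff below: it embeds the base and reuses its
// accounting instead of duplicating it.
type sftpTransfer struct {
	*baseTransfer
}

func main() {
	t := &sftpTransfer{baseTransfer: &baseTransfer{}}
	t.addBytesReceived(1024)
	fmt.Println("bytes received:", atomic.LoadInt64(&t.bytesReceived))
}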
@@ -1,123 +1,99 @@
 package sftpd

 import (
-	"errors"
 	"fmt"
 	"io"
-	"os"
-	"path"
-	"sync"
-	"time"
 	"sync/atomic"

 	"github.com/eikenb/pipeat"

-	"github.com/drakkan/sftpgo/dataprovider"
-	"github.com/drakkan/sftpgo/logger"
+	"github.com/drakkan/sftpgo/common"
 	"github.com/drakkan/sftpgo/metrics"
 	"github.com/drakkan/sftpgo/vfs"
 )

-const (
-	transferUpload = iota
-	transferDownload
-)
-
-var (
-	errTransferClosed = errors.New("transfer already closed")
-)
-
-// Transfer contains the transfer details for an upload or a download.
-// It implements the io Reader and Writer interface to handle files downloads and uploads
-type Transfer struct {
-	file *os.File
-	writerAt *vfs.PipeWriter
-	readerAt *pipeat.PipeReaderAt
-	cancelFn func()
-	path string
-	start time.Time
-	bytesSent int64
-	bytesReceived int64
-	user dataprovider.User
-	connectionID string
-	transferType int
-	lastActivity time.Time
-	protocol string
-	transferError error
-	minWriteOffset int64
-	initialSize int64
-	lock *sync.Mutex
-	isNewFile bool
-	isFinished bool
-	requestPath string
-	maxWriteSize int64
+type writerAtCloser interface {
+	io.WriterAt
+	io.Closer
 }

-// TransferError is called if there is an unexpected error.
-// For example network or client issues
-func (t *Transfer) TransferError(err error) {
-	t.lock.Lock()
-	defer t.lock.Unlock()
-	if t.transferError != nil {
-		return
+type readerAtCloser interface {
+	io.ReaderAt
+	io.Closer
 }

+// transfer defines the transfer details.
+// It implements the io.ReaderAt and io.WriterAt interfaces to handle SFTP downloads and uploads
+type transfer struct {
+	*common.BaseTransfer
+	writerAt writerAtCloser
+	readerAt readerAtCloser
+	isFinished bool
+	maxWriteSize int64
+}

+func newTranfer(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt,
+	maxWriteSize int64) *transfer {
+	var writer writerAtCloser
+	var reader readerAtCloser
+	if baseTransfer.File != nil {
+		writer = baseTransfer.File
+		reader = baseTransfer.File
+	} else if pipeWriter != nil {
+		writer = pipeWriter
+	} else if pipeReader != nil {
+		reader = pipeReader
+	}
-	t.transferError = err
-	if t.cancelFn != nil {
-		t.cancelFn()
+	return &transfer{
+		BaseTransfer: baseTransfer,
+		writerAt: writer,
+		readerAt: reader,
+		isFinished: false,
+		maxWriteSize: maxWriteSize,
 	}
-	elapsed := time.Since(t.start).Nanoseconds() / 1000000
-	logger.Warn(logSender, t.connectionID, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+
-		"bytes received: %v transfer running since %v ms", t.path, t.transferError, t.bytesSent, t.bytesReceived, elapsed)
 }

 // ReadAt reads len(p) bytes from the File to download starting at byte offset off and updates the bytes sent.
 // It handles download bandwidth throttling too
-func (t *Transfer) ReadAt(p []byte, off int64) (n int, err error) {
-	t.lastActivity = time.Now()
+func (t *transfer) ReadAt(p []byte, off int64) (n int, err error) {
+	t.Connection.UpdateLastActivity()
 	var readed int
 	var e error
-	if t.readerAt != nil {
-		readed, e = t.readerAt.ReadAt(p, off)
-	} else {
-		readed, e = t.file.ReadAt(p, off)
-	}
-	t.lock.Lock()
-	t.bytesSent += int64(readed)
-	t.lock.Unlock()
+
+	readed, e = t.readerAt.ReadAt(p, off)
+	atomic.AddInt64(&t.BytesSent, int64(readed))
+
 	if e != nil && e != io.EOF {
 		t.TransferError(e)
 		return readed, e
 	}
-	t.handleThrottle()
+	t.HandleThrottle()
 	return readed, e
 }

 // WriteAt writes len(p) bytes to the uploaded file starting at byte offset off and updates the bytes received.
 // It handles upload bandwidth throttling too
-func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) {
-	t.lastActivity = time.Now()
-	if off < t.minWriteOffset {
-		err := fmt.Errorf("Invalid write offset: %v minimum valid value: %v", off, t.minWriteOffset)
+func (t *transfer) WriteAt(p []byte, off int64) (n int, err error) {
+	t.Connection.UpdateLastActivity()
+	if off < t.MinWriteOffset {
+		err := fmt.Errorf("Invalid write offset: %v minimum valid value: %v", off, t.MinWriteOffset)
 		t.TransferError(err)
 		return 0, err
 	}
 	var written int
 	var e error
-	if t.writerAt != nil {
-		written, e = t.writerAt.WriteAt(p, off)
-	} else {
-		written, e = t.file.WriteAt(p, off)
+
+	written, e = t.writerAt.WriteAt(p, off)
+	atomic.AddInt64(&t.BytesReceived, int64(written))
+
+	if t.maxWriteSize > 0 && e == nil && atomic.LoadInt64(&t.BytesReceived) > t.maxWriteSize {
+		e = common.ErrQuotaExceeded
 	}
-	t.lock.Lock()
-	t.bytesReceived += int64(written)
-	if e == nil && t.maxWriteSize > 0 && t.bytesReceived > t.maxWriteSize {
-		e = errQuotaExceeded
-	}
-	t.lock.Unlock()
 	if e != nil {
 		t.TransferError(e)
 		return written, e
 	}
-	t.handleThrottle()
+	t.HandleThrottle()
 	return written, e
 }

@@ -126,147 +102,74 @@ func (t *Transfer) WriteAt(p []byte, off int64) (n int, err error) {
 // and executes any defined action.
 // If there is an error no action will be executed and, in atomic mode, we try to delete
 // the temporary file
-func (t *Transfer) Close() error {
-	t.lock.Lock()
-	defer t.lock.Unlock()
-	if t.isFinished {
-		return errTransferClosed
+func (t *transfer) Close() error {
+	if err := t.setFinished(); err != nil {
+		return err
 	}
 	err := t.closeIO()
-	defer removeTransfer(t) //nolint:errcheck
-	t.isFinished = true
-	numFiles := 0
-	if t.isNewFile {
-		numFiles = 1
+	errBaseClose := t.BaseTransfer.Close()
+	if errBaseClose != nil {
+		err = errBaseClose
 	}
-	metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
-	if t.transferError == errQuotaExceeded && t.file != nil {
-		// if quota is exceeded we try to remove the partial file for uploads to local filesystem
-		err = os.Remove(t.file.Name())
-		if err == nil {
-			numFiles--
-			t.bytesReceived = 0
-			t.minWriteOffset = 0
-		}
-		logger.Warn(logSender, t.connectionID, "upload denied due to space limit, delete temporary file: %#v, deletion error: %v",
-			t.file.Name(), err)
-	} else if t.transferType == transferUpload && t.file != nil && t.file.Name() != t.path {
-		if t.transferError == nil || uploadMode == uploadModeAtomicWithResume {
-			err = os.Rename(t.file.Name(), t.path)
-			logger.Debug(logSender, t.connectionID, "atomic upload completed, rename: %#v -> %#v, error: %v",
-				t.file.Name(), t.path, err)
-		} else {
-			err = os.Remove(t.file.Name())
-			logger.Warn(logSender, t.connectionID, "atomic upload completed with error: \"%v\", delete temporary file: %#v, "+
-				"deletion error: %v", t.transferError, t.file.Name(), err)
-			if err == nil {
-				numFiles--
-				t.bytesReceived = 0
-				t.minWriteOffset = 0
-			}
-		}
-	}
-	elapsed := time.Since(t.start).Nanoseconds() / 1000000
-	if t.transferType == transferDownload {
-		logger.TransferLog(downloadLogSender, t.path, elapsed, t.bytesSent, t.user.Username, t.connectionID, t.protocol)
-		go executeAction(newActionNotification(t.user, operationDownload, t.path, "", "", t.bytesSent, t.transferError)) //nolint:errcheck
-	} else {
-		logger.TransferLog(uploadLogSender, t.path, elapsed, t.bytesReceived, t.user.Username, t.connectionID, t.protocol)
-		go executeAction(newActionNotification(t.user, operationUpload, t.path, "", "", t.bytesReceived+t.minWriteOffset, //nolint:errcheck
-			t.transferError))
-	}
-	if t.transferError != nil {
-		logger.Warn(logSender, t.connectionID, "transfer error: %v, path: %#v", t.transferError, t.path)
-		if err == nil {
-			err = t.transferError
-		}
-	}
-	t.updateQuota(numFiles)
-	return err
+	return t.Connection.GetFsError(err)
 }

-func (t *Transfer) closeIO() error {
+func (t *transfer) closeIO() error {
 	var err error
-	if t.writerAt != nil {
+	if t.File != nil {
+		err = t.File.Close()
+	} else if t.writerAt != nil {
 		err = t.writerAt.Close()
-		if err != nil {
-			t.transferError = err
+		t.Lock()
+		// we set ErrTransfer here so quota is not updated, in this case the uploads are atomic
+		if err != nil && t.ErrTransfer == nil {
+			t.ErrTransfer = err
 		}
+		t.Unlock()
 	} else if t.readerAt != nil {
 		err = t.readerAt.Close()
-	} else {
-		err = t.file.Close()
 	}
 	return err
 }

-func (t *Transfer) updateQuota(numFiles int) bool {
-	// S3 uploads are atomic, if there is an error nothing is uploaded
-	if t.file == nil && t.transferError != nil {
-		return false
-	}
-	if t.transferType == transferUpload && (numFiles != 0 || t.bytesReceived > 0) {
-		vfolder, err := t.user.GetVirtualFolderForPath(path.Dir(t.requestPath))
-		if err == nil {
-			dataprovider.UpdateVirtualFolderQuota(vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck
-				t.bytesReceived-t.initialSize, false)
-			if vfolder.IsIncludedInUserQuota() {
-				dataprovider.UpdateUserQuota(t.user, numFiles, t.bytesReceived-t.initialSize, false) //nolint:errcheck
-			}
-		} else {
-			dataprovider.UpdateUserQuota(t.user, numFiles, t.bytesReceived-t.initialSize, false) //nolint:errcheck
-		}
-		return true
-	}
-	return false
-}
-
-func (t *Transfer) handleThrottle() {
-	var wantedBandwidth int64
-	var trasferredBytes int64
-	if t.transferType == transferDownload {
-		wantedBandwidth = t.user.DownloadBandwidth
-		trasferredBytes = t.bytesSent
-	} else {
-		wantedBandwidth = t.user.UploadBandwidth
-		trasferredBytes = t.bytesReceived
-	}
-	if wantedBandwidth > 0 {
-		// real and wanted elapsed as milliseconds, bytes as kilobytes
-		realElapsed := time.Since(t.start).Nanoseconds() / 1000000
-		// trasferredBytes / 1000 = KB/s, we multiply for 1000 to get milliseconds
-		wantedElapsed := 1000 * (trasferredBytes / 1000) / wantedBandwidth
-		if wantedElapsed > realElapsed {
-			toSleep := time.Duration(wantedElapsed - realElapsed)
-			time.Sleep(toSleep * time.Millisecond)
-		}
+func (t *transfer) setFinished() error {
+	t.Lock()
+	defer t.Unlock()
+	if t.isFinished {
+		return common.ErrTransferClosed
 	}
+	t.isFinished = true
+	return nil
 }

 // used for ssh commands.
 // It reads from src until EOF so it does not treat an EOF from Read as an error to be reported.
 // EOF from Write is reported as error
-func (t *Transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64, error) {
+func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64, error) {
+	defer t.Connection.RemoveTransfer(t)
+
 	var written int64
 	var err error
+
 	if t.maxWriteSize < 0 {
-		return 0, errQuotaExceeded
+		return 0, common.ErrQuotaExceeded
 	}
+	isDownload := t.GetType() == common.TransferDownload
 	buf := make([]byte, 32768)
 	for {
-		t.lastActivity = time.Now()
+		t.Connection.UpdateLastActivity()
 		nr, er := src.Read(buf)
 		if nr > 0 {
 			nw, ew := dst.Write(buf[0:nr])
 			if nw > 0 {
 				written += int64(nw)
-				if t.transferType == transferDownload {
-					t.bytesSent = written
+				if isDownload {
+					atomic.StoreInt64(&t.BytesSent, written)
 				} else {
-					t.bytesReceived = written
+					atomic.StoreInt64(&t.BytesReceived, written)
 				}
 				if t.maxWriteSize > 0 && written > t.maxWriteSize {
-					err = errQuotaExceeded
+					err = common.ErrQuotaExceeded
 					break
 				}
 			}
@@ -285,11 +188,11 @@ func (t *Transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64,
 			}
 			break
 		}
-		t.handleThrottle()
+		t.HandleThrottle()
 	}
-	t.transferError = err
-	if t.bytesSent > 0 || t.bytesReceived > 0 || err != nil {
-		metrics.TransferCompleted(t.bytesSent, t.bytesReceived, t.transferType, t.transferError)
+	t.ErrTransfer = err
+	if written > 0 || err != nil {
+		metrics.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.GetType(), t.ErrTransfer)
 	}
 	return written, err
 }
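For reference, the bandwidth throttling removed above is now reached through t.HandleThrottle on the shared base, presumably with the same idea as the old handleThrottle: compare the real elapsed time with the time the transfer should have taken at the configured limit, and sleep for the difference. The sketch below is an illustrative standalone helper, not part of sftpgo, assuming the same KB/s units and millisecond arithmetic as the old code.

package main

import (
	"fmt"
	"time"
)

// throttleSleep mirrors the arithmetic of the old handleThrottle: bandwidth is
// in KB/s, elapsed times in milliseconds. It returns how long the transfer
// should pause so the average rate stays at or below the limit.
func throttleSleep(transferredBytes, bandwidthKBs int64, realElapsed time.Duration) time.Duration {
	if bandwidthKBs <= 0 {
		return 0
	}
	// transferredBytes/1000 is KB; dividing by KB/s gives seconds, *1000 gives ms.
	wantedElapsedMs := 1000 * (transferredBytes / 1000) / bandwidthKBs
	realElapsedMs := realElapsed.Milliseconds()
	if wantedElapsedMs > realElapsedMs {
		return time.Duration(wantedElapsedMs-realElapsedMs) * time.Millisecond
	}
	return 0
}

func main() {
	// 5 MB transferred in 2s with a 1000 KB/s limit should take ~5s: sleep ~3s.
	fmt.Println(throttleSleep(5000000, 1000, 2*time.Second))
}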