diff --git a/sftpd/internal_test.go b/sftpd/internal_test.go index 3a603dd7..1ef0a9e8 100644 --- a/sftpd/internal_test.go +++ b/sftpd/internal_test.go @@ -316,7 +316,7 @@ func TestReadWriteErrors(t *testing.T) { assert.NoError(t, err) transfer = Transfer{ readerAt: nil, - writerAt: w, + writerAt: vfs.NewPipeWriter(w), start: time.Now(), bytesSent: 0, bytesReceived: 0, @@ -334,8 +334,13 @@ func TestReadWriteErrors(t *testing.T) { } err = r.Close() assert.NoError(t, err) + errFake := fmt.Errorf("fake upload error") + go func() { + time.Sleep(100 * time.Millisecond) + transfer.writerAt.Done(errFake) + }() err = transfer.closeIO() - assert.NoError(t, err) + assert.EqualError(t, err, errFake.Error()) _, err = transfer.WriteAt([]byte("test"), 0) assert.Error(t, err, "writing to closed pipe must fail") diff --git a/sftpd/transfer.go b/sftpd/transfer.go index 80e6708f..3467c6d5 100644 --- a/sftpd/transfer.go +++ b/sftpd/transfer.go @@ -13,6 +13,7 @@ import ( "github.com/drakkan/sftpgo/dataprovider" "github.com/drakkan/sftpgo/logger" "github.com/drakkan/sftpgo/metrics" + "github.com/drakkan/sftpgo/vfs" ) const ( @@ -28,7 +29,7 @@ var ( // It implements the io Reader and Writer interface to handle files downloads and uploads type Transfer struct { file *os.File - writerAt *pipeat.PipeWriterAt + writerAt *vfs.PipeWriter readerAt *pipeat.PipeReaderAt cancelFn func() path string @@ -173,6 +174,9 @@ func (t *Transfer) closeIO() error { var err error if t.writerAt != nil { err = t.writerAt.Close() + if err != nil { + t.transferError = err + } } else if t.readerAt != nil { err = t.readerAt.Close() } else { diff --git a/vfs/gcsfs.go b/vfs/gcsfs.go index 1f0cbf4e..cd4d22ad 100644 --- a/vfs/gcsfs.go +++ b/vfs/gcsfs.go @@ -168,11 +168,12 @@ func (fs GCSFs) Open(name string) (*os.File, *pipeat.PipeReaderAt, func(), error } // Create creates or opens the named file for writing -func (fs GCSFs) Create(name string, flag int) (*os.File, *pipeat.PipeWriterAt, func(), error) { +func (fs GCSFs) Create(name string, flag int) (*os.File, *PipeWriter, func(), error) { r, w, err := pipeat.PipeInDir(fs.localTempDir) if err != nil { return nil, nil, nil, err } + p := NewPipeWriter(w) bkt := fs.svc.Bucket(fs.config.Bucket) obj := bkt.Object(name) ctx, cancelFn := context.WithCancel(context.Background()) @@ -185,10 +186,11 @@ func (fs GCSFs) Create(name string, flag int) (*os.File, *pipeat.PipeWriterAt, f defer objectWriter.Close() n, err := io.Copy(objectWriter, r) r.CloseWithError(err) //nolint:errcheck // the returned error is always null + p.Done(GetSFTPError(fs, err)) fsLog(fs, logger.LevelDebug, "upload completed, path: %#v, readed bytes: %v, err: %v", name, n, err) metrics.GCSTransferCompleted(n, 0, err) }() - return nil, w, cancelFn, nil + return nil, p, cancelFn, nil } // Rename renames (moves) source to target. diff --git a/vfs/osfs.go b/vfs/osfs.go index be6834ad..6eed9862 100644 --- a/vfs/osfs.go +++ b/vfs/osfs.go @@ -65,7 +65,7 @@ func (OsFs) Open(name string) (*os.File, *pipeat.PipeReaderAt, func(), error) { } // Create creates or opens the named file for writing -func (OsFs) Create(name string, flag int) (*os.File, *pipeat.PipeWriterAt, func(), error) { +func (OsFs) Create(name string, flag int) (*os.File, *PipeWriter, func(), error) { var err error var f *os.File if flag == 0 { diff --git a/vfs/s3fs.go b/vfs/s3fs.go index f63eea82..63cf2bb1 100644 --- a/vfs/s3fs.go +++ b/vfs/s3fs.go @@ -204,11 +204,12 @@ func (fs S3Fs) Open(name string) (*os.File, *pipeat.PipeReaderAt, func(), error) } // Create creates or opens the named file for writing -func (fs S3Fs) Create(name string, flag int) (*os.File, *pipeat.PipeWriterAt, func(), error) { +func (fs S3Fs) Create(name string, flag int) (*os.File, *PipeWriter, func(), error) { r, w, err := pipeat.PipeInDir(fs.localTempDir) if err != nil { return nil, nil, nil, err } + p := NewPipeWriter(w) ctx, cancelFn := context.WithCancel(context.Background()) uploader := s3manager.NewUploaderWithClient(fs.svc) go func() { @@ -224,11 +225,12 @@ func (fs S3Fs) Create(name string, flag int) (*os.File, *pipeat.PipeWriterAt, fu u.PartSize = fs.config.UploadPartSize }) r.CloseWithError(err) //nolint:errcheck // the returned error is always null + p.Done(GetSFTPError(fs, err)) fsLog(fs, logger.LevelDebug, "upload completed, path: %#v, response: %v, readed bytes: %v, err: %+v", name, response, r.GetReadedBytes(), err) metrics.S3TransferCompleted(r.GetReadedBytes(), 0, err) }() - return nil, w, cancelFn, nil + return nil, p, cancelFn, nil } // Rename renames (moves) source to target. diff --git a/vfs/vfs.go b/vfs/vfs.go index 59f62f75..96a76b02 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -23,7 +23,7 @@ type Fs interface { Stat(name string) (os.FileInfo, error) Lstat(name string) (os.FileInfo, error) Open(name string) (*os.File, *pipeat.PipeReaderAt, func(), error) - Create(name string, flag int) (*os.File, *pipeat.PipeWriterAt, func(), error) + Create(name string, flag int) (*os.File, *PipeWriter, func(), error) Rename(source, target string) error Remove(name string, isDir bool) error Mkdir(name string) error @@ -44,6 +44,41 @@ type Fs interface { Join(elem ...string) string } +// PipeWriter defines a wrapper for pipeat.PipeWriterAt. +type PipeWriter struct { + writer *pipeat.PipeWriterAt + err error + done chan bool +} + +// NewPipeWriter initializes a new PipeWriter +func NewPipeWriter(w *pipeat.PipeWriterAt) *PipeWriter { + return &PipeWriter{ + writer: w, + err: nil, + done: make(chan bool), + } +} + +// Close waits for the upload to end, closes the pipeat.PipeWriterAt and returns an error if any. +func (p *PipeWriter) Close() error { + p.writer.Close() //nolint:errcheck // the returned error is always null + <-p.done + return p.err +} + +// Done unlocks other goroutines waiting on Close(). +// It must be called when the upload ends +func (p *PipeWriter) Done(err error) { + p.err = err + p.done <- true +} + +// WriteAt is a wrapper for pipeat WriteAt +func (p *PipeWriter) WriteAt(data []byte, off int64) (int, error) { + return p.writer.WriteAt(data, off) +} + // VirtualFolder defines a mapping between a SFTP/SCP virtual path and a // filesystem path outside the user home directory. // The specified paths must be absolute and the virtual path cannot be "/",