From 4230da8e7d9f503210c1116a4685bb3f93975e5b Mon Sep 17 00:00:00 2001
From: Nicola Murino
Date: Wed, 19 Nov 2025 18:53:27 +0100
Subject: [PATCH] s3: implement multipart downloads without using the S3 Manager

Signed-off-by: Nicola Murino
---
 go.mod                   |   1 -
 go.sum                   |   2 -
 internal/vfs/azblobfs.go |  14 +--
 internal/vfs/s3fs.go     | 184 +++++++++++++++++++++++++++------------
 internal/vfs/vfs.go      |  12 +++
 5 files changed, 139 insertions(+), 74 deletions(-)

diff --git a/go.mod b/go.mod
index 1a331b01..8a60a2b8 100644
--- a/go.mod
+++ b/go.mod
@@ -13,7 +13,6 @@ require (
 	github.com/aws/aws-sdk-go-v2 v1.39.4
 	github.com/aws/aws-sdk-go-v2/config v1.31.15
 	github.com/aws/aws-sdk-go-v2/credentials v1.18.19
-	github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.15
 	github.com/aws/aws-sdk-go-v2/service/s3 v1.88.7
 	github.com/aws/aws-sdk-go-v2/service/sts v1.38.9
 	github.com/bmatcuk/doublestar/v4 v4.9.1
diff --git a/go.sum b/go.sum
index d148ffc3..1d7ace42 100644
--- a/go.sum
+++ b/go.sum
@@ -67,8 +67,6 @@ github.com/aws/aws-sdk-go-v2/credentials v1.18.19 h1:Jc1zzwkSY1QbkEcLujwqRTXOdvW
 github.com/aws/aws-sdk-go-v2/credentials v1.18.19/go.mod h1:DIfQ9fAk5H0pGtnqfqkbSIzky82qYnGvh06ASQXXg6A=
 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11 h1:X7X4YKb+c0rkI6d4uJ5tEMxXgCZ+jZ/D6mvkno8c8Uw=
 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11/go.mod h1:EqM6vPZQsZHYvC4Cai35UDg/f5NCEU+vp0WfbVqVcZc=
-github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.15 h1:OsZ2Sk84YUPJfi6BemhyMQyuR8/5tWu37WBMVUl8lJk=
-github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.15/go.mod h1:CYZDjBMY+MyT+U+QmXw81GBiq+lhgM97kIMdDAJk+hg=
 github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11 h1:7AANQZkF3ihM8fbdftpjhken0TP9sBzFbV/Ze/Y4HXA=
 github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11/go.mod h1:NTF4QCGkm6fzVwncpkFQqoquQyOolcyXfbpC98urj+c=
 github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11 h1:ShdtWUZT37LCAA4Mw2kJAJtzaszfSHFb5n25sdcv4YE=
diff --git a/internal/vfs/azblobfs.go b/internal/vfs/azblobfs.go
index 83f55b52..e2129375 100644
--- a/internal/vfs/azblobfs.go
+++ b/internal/vfs/azblobfs.go
@@ -915,7 +915,7 @@ func (fs *AzureBlobFs) downloadPart(ctx context.Context, blockBlob *blockblob.Cl
 		return err
 	}
 
-	_, err = fs.writeAtFull(w, buf, writeOffset, int(count))
+	_, err = writeAtFull(w, buf, writeOffset, int(count))
 	return err
 }
 
@@ -1105,18 +1105,6 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 	return err
 }
 
-func (*AzureBlobFs) writeAtFull(w io.WriterAt, buf []byte, offset int64, count int) (int, error) {
-	written := 0
-	for written < count {
-		n, err := w.WriteAt(buf[written:count], offset+int64(written))
-		written += n
-		if err != nil {
-			return written, err
-		}
-	}
-	return written, nil
-}
-
 func (fs *AzureBlobFs) getCopyOptions(srcInfo os.FileInfo, updateModTime bool) *blob.StartCopyFromURLOptions {
 	copyOptions := &blob.StartCopyFromURLOptions{}
 	if fs.config.AccessTier != "" {
diff --git a/internal/vfs/s3fs.go b/internal/vfs/s3fs.go
index d330133a..ce4a1248 100644
--- a/internal/vfs/s3fs.go
+++ b/internal/vfs/s3fs.go
@@ -45,7 +45,6 @@ import (
 	"github.com/aws/aws-sdk-go-v2/config"
 	"github.com/aws/aws-sdk-go-v2/credentials"
 	"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
-	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
 	"github.com/aws/aws-sdk-go-v2/service/s3"
 	"github.com/aws/aws-sdk-go-v2/service/s3/types"
 	"github.com/aws/aws-sdk-go-v2/service/sts"
@@ -227,53 +226,29 @@ func (fs *S3Fs) Lstat(name string) (os.FileInfo, error) {
 
 // Open opens the named file for reading
 func (fs *S3Fs) Open(name string, offset int64) (File, PipeReader, func(), error) {
+	attrs, err := fs.headObject(name)
+	if err != nil {
+		return nil, nil, nil, err
+	}
 	r, w, err := createPipeFn(fs.localTempDir, fs.config.DownloadPartSize*int64(fs.config.DownloadConcurrency)+1)
 	if err != nil {
 		return nil, nil, nil, err
 	}
 	p := NewPipeReader(r)
 	if readMetadata > 0 {
-		attrs, err := fs.headObject(name)
-		if err != nil {
-			r.Close()
-			w.Close()
-			return nil, nil, nil, err
-		}
 		p.setMetadata(attrs.Metadata)
 	}
 	ctx, cancelFn := context.WithCancel(context.Background())
-	downloader := manager.NewDownloader(fs.svc, func(d *manager.Downloader) {
-		d.Concurrency = fs.config.DownloadConcurrency
-		d.PartSize = fs.config.DownloadPartSize
-		if offset == 0 && fs.config.DownloadPartMaxTime > 0 {
-			d.ClientOptions = append(d.ClientOptions, func(o *s3.Options) {
-				o.HTTPClient = getAWSHTTPClient(fs.config.DownloadPartMaxTime, 100*time.Millisecond,
-					fs.config.SkipTLSVerify)
-			})
-		}
-	})
-
-	var streamRange *string
-	if offset > 0 {
-		streamRange = aws.String(fmt.Sprintf("bytes=%d-", offset))
-	}
 	go func() {
 		defer cancelFn()
-		n, err := downloader.Download(ctx, w, &s3.GetObjectInput{
-			Bucket:               aws.String(fs.config.Bucket),
-			Key:                  aws.String(name),
-			Range:                streamRange,
-			SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
-			SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
-			SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
-		})
+		err := fs.handleDownload(ctx, name, offset, w, attrs)
 		w.CloseWithError(err) //nolint:errcheck
-		fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %+v", name, n, err)
-		metric.S3TransferCompleted(n, 1, err)
+		fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %d, err: %+v", name, w.GetWrittenBytes(), err)
+		metric.S3TransferCompleted(w.GetWrittenBytes(), 1, err)
 	}()
+
 	return nil, p, cancelFn, nil
 }
@@ -673,25 +648,28 @@ func (fs *S3Fs) resolve(name *string, prefix string) (string, bool) {
 }
 
 func (fs *S3Fs) setConfigDefaults() {
+	const defaultPartSize = 1024 * 1024 * 5
+	const defaultConcurrency = 5
+
 	if fs.config.UploadPartSize == 0 {
-		fs.config.UploadPartSize = manager.DefaultUploadPartSize
+		fs.config.UploadPartSize = defaultPartSize
 	} else {
 		if fs.config.UploadPartSize < 1024*1024 {
 			fs.config.UploadPartSize *= 1024 * 1024
 		}
 	}
 	if fs.config.UploadConcurrency == 0 {
-		fs.config.UploadConcurrency = manager.DefaultUploadConcurrency
+		fs.config.UploadConcurrency = defaultConcurrency
 	}
 	if fs.config.DownloadPartSize == 0 {
-		fs.config.DownloadPartSize = manager.DefaultDownloadPartSize
+		fs.config.DownloadPartSize = defaultPartSize
 	} else {
 		if fs.config.DownloadPartSize < 1024*1024 {
 			fs.config.DownloadPartSize *= 1024 * 1024
 		}
 	}
 	if fs.config.DownloadConcurrency == 0 {
-		fs.config.DownloadConcurrency = manager.DefaultDownloadConcurrency
+		fs.config.DownloadConcurrency = defaultConcurrency
 	}
 }
@@ -815,6 +793,109 @@ func (fs *S3Fs) hasContents(name string) (bool, error) {
 	return false, nil
 }
 
+func (fs *S3Fs) downloadPart(ctx context.Context, name string, buf []byte, w io.WriterAt, start, count, writeOffset int64) error {
+	rangeHeader := fmt.Sprintf("bytes=%d-%d", start, start+count-1)
+
+	resp, err := fs.svc.GetObject(ctx, &s3.GetObjectInput{
+		Bucket:               aws.String(fs.config.Bucket),
+		Key:                  aws.String(name),
+		Range:                &rangeHeader,
+		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
+	})
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	_, err = io.ReadAtLeast(resp.Body, buf, int(count))
+	if err != nil {
+		return err
+	}
+
+	_, err = writeAtFull(w, buf, writeOffset, int(count))
+	return err
+}
+
+func (fs *S3Fs) handleDownload(ctx context.Context, name string, offset int64, writer io.WriterAt, attrs *s3.HeadObjectOutput) error {
+	contentLength := util.GetIntFromPointer(attrs.ContentLength)
+	sizeToDownload := contentLength - offset
+	if sizeToDownload < 0 {
+		fsLog(fs, logger.LevelError, "invalid multipart download size or offset, size: %d, offset: %d, size to download: %d",
+			contentLength, offset, sizeToDownload)
+		return errors.New("the requested offset exceeds the file size")
+	}
+	if sizeToDownload == 0 {
+		fsLog(fs, logger.LevelDebug, "nothing to download, offset %d, content length %d", offset, contentLength)
+		return nil
+	}
+	partSize := fs.config.DownloadPartSize
+	guard := make(chan struct{}, fs.config.DownloadConcurrency)
+	var blockCtxTimeout time.Duration
+	if fs.config.DownloadPartMaxTime > 0 {
+		blockCtxTimeout = time.Duration(fs.config.DownloadPartMaxTime) * time.Second
+	} else {
+		blockCtxTimeout = time.Duration(fs.config.DownloadPartSize/(1024*1024)) * time.Minute
+	}
+	pool := newBufferAllocator(int(partSize))
+	defer pool.free()
+
+	finished := false
+	var wg sync.WaitGroup
+	var errOnce sync.Once
+	var hasError atomic.Bool
+	var poolError error
+
+	poolCtx, poolCancel := context.WithCancel(ctx)
+	defer poolCancel()
+
+	for part := 0; !finished; part++ {
+		start := offset
+		end := offset + partSize
+		if end >= contentLength {
+			end = contentLength
+			finished = true
+		}
+		writeOffset := int64(part) * partSize
+		offset = end
+
+		guard <- struct{}{}
+		if hasError.Load() {
+			fsLog(fs, logger.LevelDebug, "pool error, download for part %d not started", part)
+			break
+		}
+
+		buf := pool.getBuffer()
+		wg.Add(1)
+		go func(start, end, writeOffset int64, buf []byte) {
+			defer func() {
+				pool.releaseBuffer(buf)
+				<-guard
+				wg.Done()
+			}()
+
+			innerCtx, cancelFn := context.WithDeadline(poolCtx, time.Now().Add(blockCtxTimeout))
+			defer cancelFn()
+
+			err := fs.downloadPart(innerCtx, name, buf, writer, start, end-start, writeOffset)
+			if err != nil {
+				errOnce.Do(func() {
+					fsLog(fs, logger.LevelError, "multipart download error: %+v", err)
+					hasError.Store(true)
+					poolError = fmt.Errorf("multipart download error: %w", err)
+					poolCancel()
+				})
+			}
+		}(start, end, writeOffset, buf)
+	}
+
+	wg.Wait()
+	close(guard)
+
+	return poolError
+}
+
 func (fs *S3Fs) initiateMultipartUpload(ctx context.Context, name, contentType string) (string, error) {
 	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(fs.ctxTimeout))
 	defer cancelFn()
@@ -1208,31 +1289,18 @@ func (*S3Fs) GetAvailableDiskSize(_ string) (*sftp.StatVFS, error) {
 
 func (fs *S3Fs) downloadToWriter(name string, w PipeWriter) (int64, error) {
 	fsLog(fs, logger.LevelDebug, "starting download before resuming upload, path %q", name)
+	attrs, err := fs.headObject(name)
+	if err != nil {
+		return 0, err
+	}
 	ctx, cancelFn := context.WithTimeout(context.Background(), preResumeTimeout)
 	defer cancelFn()
-	downloader := manager.NewDownloader(fs.svc, func(d *manager.Downloader) {
-		d.Concurrency = fs.config.DownloadConcurrency
-		d.PartSize = fs.config.DownloadPartSize
-		if fs.config.DownloadPartMaxTime > 0 {
-			d.ClientOptions = append(d.ClientOptions, func(o *s3.Options) {
-				o.HTTPClient = getAWSHTTPClient(fs.config.DownloadPartMaxTime, 100*time.Millisecond,
-					fs.config.SkipTLSVerify)
-			})
-		}
-	})
-
-	n, err := downloader.Download(ctx, w, &s3.GetObjectInput{
-		Bucket:               aws.String(fs.config.Bucket),
-		Key:                  aws.String(name),
-		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
-		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
-		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
-	})
+	err = fs.handleDownload(ctx, name, 0, w, attrs)
 	fsLog(fs, logger.LevelDebug, "download before resuming upload completed, path %q size: %d, err: %+v",
-		name, n, err)
-	metric.S3TransferCompleted(n, 1, err)
-	return n, err
+		name, w.GetWrittenBytes(), err)
+	metric.S3TransferCompleted(w.GetWrittenBytes(), 1, err)
+	return w.GetWrittenBytes(), err
 }
 
 type s3DirLister struct {
diff --git a/internal/vfs/vfs.go b/internal/vfs/vfs.go
index add7f41c..3607784f 100644
--- a/internal/vfs/vfs.go
+++ b/internal/vfs/vfs.go
@@ -1282,6 +1282,18 @@ func readFill(r io.Reader, buf []byte) (n int, err error) {
 	return n, err
 }
 
+func writeAtFull(w io.WriterAt, buf []byte, offset int64, count int) (int, error) {
+	written := 0
+	for written < count {
+		n, err := w.WriteAt(buf[written:count], offset+int64(written))
+		written += n
+		if err != nil {
+			return written, err
+		}
+	}
+	return written, nil
+}
+
 type bytesReaderWrapper struct {
 	*bytes.Reader
 }
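
Editor's note (not part of the patch to apply): the heart of this change is handleDownload, which replaces the SDK's manager.Downloader with a bounded pool of ranged GetObject calls, each part written at its own offset through an io.WriterAt, mirroring the existing Azure Blob download path. The Go sketch below is a minimal, self-contained illustration of that pattern; the names downloadRanges and fetchRange are hypothetical and do not exist in SFTPGo, and the fetch callback stands in for the real S3 client.

package main

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"os"
	"sync"
	"sync/atomic"
)

// fetchRange abstracts a ranged read of count bytes starting at start.
// In the patch this role is played by an S3 GetObject call with a Range header.
type fetchRange func(ctx context.Context, start, count int64) (io.ReadCloser, error)

// downloadRanges splits size bytes into partSize chunks, fetches up to
// concurrency chunks at a time, and writes each chunk at its own offset in w.
// The first failing part records the error and cancels the remaining ones.
func downloadRanges(ctx context.Context, size, partSize int64, concurrency int, fetch fetchRange, w io.WriterAt) error {
	guard := make(chan struct{}, concurrency) // bounded-concurrency semaphore
	var wg sync.WaitGroup
	var once sync.Once
	var hasErr atomic.Bool
	var firstErr error

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	for start := int64(0); start < size; start += partSize {
		count := partSize
		if start+count > size {
			count = size - start
		}

		guard <- struct{}{} // wait for a free slot
		if hasErr.Load() {
			break // a part already failed, stop scheduling new ones
		}
		wg.Add(1)
		go func(start, count int64) {
			defer func() { <-guard; wg.Done() }()

			body, err := fetch(ctx, start, count)
			if err == nil {
				buf := make([]byte, count)
				_, err = io.ReadFull(body, buf)
				body.Close()
				if err == nil {
					_, err = w.WriteAt(buf, start) // each part lands at its own offset
				}
			}
			if err != nil {
				once.Do(func() {
					firstErr = fmt.Errorf("part at offset %d: %w", start, err)
					hasErr.Store(true)
					cancel() // stop in-flight parts early
				})
			}
		}(start, count)
	}

	wg.Wait()
	return firstErr
}

func main() {
	src := bytes.Repeat([]byte("0123456789"), 1000) // a 10 KB in-memory "object"
	fetch := func(_ context.Context, start, count int64) (io.ReadCloser, error) {
		return io.NopCloser(bytes.NewReader(src[start : start+count])), nil
	}
	dst, err := os.CreateTemp("", "ranged-download-*")
	if err != nil {
		panic(err)
	}
	defer os.Remove(dst.Name())
	defer dst.Close()

	if err := downloadRanges(context.Background(), int64(len(src)), 4096, 3, fetch, dst); err != nil {
		panic(err)
	}
	fmt.Println("downloaded", len(src), "bytes to", dst.Name())
}

Unlike this sketch, the patch reuses part buffers through its newBufferAllocator pool and bounds each part with a deadline derived from DownloadPartMaxTime (or from the part size when that option is unset), but the sequential scheduling, per-part write offsets, and first-error cancellation are the same idea.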