s3: implement multipart downloads without using the S3 Manager

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
Author: Nicola Murino
Date: 2025-11-19 18:53:27 +01:00
parent 22c875c0a1
commit 4230da8e7d
5 changed files with 139 additions and 74 deletions


@@ -915,7 +915,7 @@ func (fs *AzureBlobFs) downloadPart(ctx context.Context, blockBlob *blockblob.Cl
 		return err
 	}
-	_, err = fs.writeAtFull(w, buf, writeOffset, int(count))
+	_, err = writeAtFull(w, buf, writeOffset, int(count))
 	return err
 }
@@ -1105,18 +1105,6 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 	return err
 }
 
-func (*AzureBlobFs) writeAtFull(w io.WriterAt, buf []byte, offset int64, count int) (int, error) {
-	written := 0
-	for written < count {
-		n, err := w.WriteAt(buf[written:count], offset+int64(written))
-		written += n
-		if err != nil {
-			return written, err
-		}
-	}
-	return written, nil
-}
-
 func (fs *AzureBlobFs) getCopyOptions(srcInfo os.FileInfo, updateModTime bool) *blob.StartCopyFromURLOptions {
 	copyOptions := &blob.StartCopyFromURLOptions{}
 	if fs.config.AccessTier != "" {


@@ -45,7 +45,6 @@ import (
 	"github.com/aws/aws-sdk-go-v2/config"
 	"github.com/aws/aws-sdk-go-v2/credentials"
 	"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
-	"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
 	"github.com/aws/aws-sdk-go-v2/service/s3"
 	"github.com/aws/aws-sdk-go-v2/service/s3/types"
 	"github.com/aws/aws-sdk-go-v2/service/sts"
@@ -227,53 +226,29 @@ func (fs *S3Fs) Lstat(name string) (os.FileInfo, error) {
 // Open opens the named file for reading
 func (fs *S3Fs) Open(name string, offset int64) (File, PipeReader, func(), error) {
+	attrs, err := fs.headObject(name)
+	if err != nil {
+		return nil, nil, nil, err
+	}
 	r, w, err := createPipeFn(fs.localTempDir, fs.config.DownloadPartSize*int64(fs.config.DownloadConcurrency)+1)
 	if err != nil {
 		return nil, nil, nil, err
 	}
 	p := NewPipeReader(r)
 	if readMetadata > 0 {
-		attrs, err := fs.headObject(name)
-		if err != nil {
-			r.Close()
-			w.Close()
-			return nil, nil, nil, err
-		}
 		p.setMetadata(attrs.Metadata)
 	}
 	ctx, cancelFn := context.WithCancel(context.Background())
-	downloader := manager.NewDownloader(fs.svc, func(d *manager.Downloader) {
-		d.Concurrency = fs.config.DownloadConcurrency
-		d.PartSize = fs.config.DownloadPartSize
-		if offset == 0 && fs.config.DownloadPartMaxTime > 0 {
-			d.ClientOptions = append(d.ClientOptions, func(o *s3.Options) {
-				o.HTTPClient = getAWSHTTPClient(fs.config.DownloadPartMaxTime, 100*time.Millisecond,
-					fs.config.SkipTLSVerify)
-			})
-		}
-	})
-	var streamRange *string
-	if offset > 0 {
-		streamRange = aws.String(fmt.Sprintf("bytes=%d-", offset))
-	}
 	go func() {
 		defer cancelFn()
-		n, err := downloader.Download(ctx, w, &s3.GetObjectInput{
-			Bucket:               aws.String(fs.config.Bucket),
-			Key:                  aws.String(name),
-			Range:                streamRange,
-			SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
-			SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
-			SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
-		})
+		err := fs.handleDownload(ctx, name, offset, w, attrs)
 		w.CloseWithError(err) //nolint:errcheck
-		fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %+v", name, n, err)
-		metric.S3TransferCompleted(n, 1, err)
+		fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %d, err: %+v", name, w.GetWrittenBytes(), err)
+		metric.S3TransferCompleted(w.GetWrittenBytes(), 1, err)
 	}()
 	return nil, p, cancelFn, nil
 }
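
Note: the rewritten Open returns the pipe's read end immediately; a goroutine feeds the write end through the new handleDownload and hands the final error to readers via CloseWithError. A minimal runnable sketch of this return-the-reader-early shape, with a stubbed download helper standing in for the real S3 fetch (all names here are illustrative, not SFTPGo's):

package main

import (
	"context"
	"fmt"
	"io"
	"strings"
)

// download stands in for the real object fetch; it just copies a fixed
// payload so the sketch runs as-is.
func download(_ context.Context, w io.Writer) error {
	_, err := io.Copy(w, strings.NewReader("object payload"))
	return err
}

// open returns the read end of a pipe immediately; a goroutine fills the
// write end and propagates the final error via CloseWithError, so readers
// see either EOF or the download error.
func open() (io.ReadCloser, context.CancelFunc) {
	r, w := io.Pipe()
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer cancel()
		w.CloseWithError(download(ctx, w))
	}()
	return r, cancel
}

func main() {
	r, cancel := open()
	defer cancel()
	data, err := io.ReadAll(r)
	fmt.Println(string(data), err)
}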
@@ -673,25 +648,28 @@ func (fs *S3Fs) resolve(name *string, prefix string) (string, bool) {
 }
 
 func (fs *S3Fs) setConfigDefaults() {
+	const defaultPartSize = 1024 * 1024 * 5
+	const defaultConcurrency = 5
+
 	if fs.config.UploadPartSize == 0 {
-		fs.config.UploadPartSize = manager.DefaultUploadPartSize
+		fs.config.UploadPartSize = defaultPartSize
 	} else {
 		if fs.config.UploadPartSize < 1024*1024 {
 			fs.config.UploadPartSize *= 1024 * 1024
 		}
 	}
 	if fs.config.UploadConcurrency == 0 {
-		fs.config.UploadConcurrency = manager.DefaultUploadConcurrency
+		fs.config.UploadConcurrency = defaultConcurrency
 	}
 	if fs.config.DownloadPartSize == 0 {
-		fs.config.DownloadPartSize = manager.DefaultDownloadPartSize
+		fs.config.DownloadPartSize = defaultPartSize
 	} else {
 		if fs.config.DownloadPartSize < 1024*1024 {
 			fs.config.DownloadPartSize *= 1024 * 1024
 		}
 	}
 	if fs.config.DownloadConcurrency == 0 {
-		fs.config.DownloadConcurrency = manager.DefaultDownloadConcurrency
+		fs.config.DownloadConcurrency = defaultConcurrency
	}
 }
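
The inlined constants keep behavior identical to the removed manager defaults: in aws-sdk-go-v2, manager.DefaultUploadPartSize and manager.DefaultDownloadPartSize are both 1024 * 1024 * 5 (5 MiB), and both default concurrencies are 5. Configured sizes below 1 MiB are still treated as MiB counts and scaled up, as the else branches show. A standalone sketch of that normalization rule (illustrative helper, not part of the commit):

package main

import "fmt"

// normalizePartSize mirrors the rule above: 0 means "use the 5 MiB
// default"; values below 1 MiB are treated as MiB counts and scaled.
func normalizePartSize(v int64) int64 {
	const defaultPartSize = 1024 * 1024 * 5
	if v == 0 {
		return defaultPartSize
	}
	if v < 1024*1024 {
		return v * 1024 * 1024 // e.g. 16 means 16 MiB
	}
	return v
}

func main() {
	fmt.Println(normalizePartSize(0))  // 5242880 (5 MiB default)
	fmt.Println(normalizePartSize(16)) // 16777216 (16 MiB)
}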
@@ -815,6 +793,109 @@ func (fs *S3Fs) hasContents(name string) (bool, error) {
 	return false, nil
 }
 
+func (fs *S3Fs) downloadPart(ctx context.Context, name string, buf []byte, w io.WriterAt, start, count, writeOffset int64) error {
+	rangeHeader := fmt.Sprintf("bytes=%d-%d", start, start+count)
+	resp, err := fs.svc.GetObject(ctx, &s3.GetObjectInput{
+		Bucket:               aws.String(fs.config.Bucket),
+		Key:                  aws.String(name),
+		Range:                &rangeHeader,
+		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
+	})
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	_, err = io.ReadAtLeast(resp.Body, buf, int(count))
+	if err != nil {
+		return err
+	}
+	_, err = writeAtFull(w, buf, writeOffset, int(count))
+	return err
+}
+
+func (fs *S3Fs) handleDownload(ctx context.Context, name string, offset int64, writer io.WriterAt, attrs *s3.HeadObjectOutput) error {
+	contentLength := util.GetIntFromPointer(attrs.ContentLength)
+	sizeToDownload := contentLength - offset
+	if sizeToDownload < 0 {
+		fsLog(fs, logger.LevelError, "invalid multipart download size or offset, size: %d, offset: %d, size to download: %d",
+			contentLength, offset, sizeToDownload)
+		return errors.New("the requested offset exceeds the file size")
+	}
+	if sizeToDownload == 0 {
+		fsLog(fs, logger.LevelDebug, "nothing to download, offset %d, content length %d", offset, contentLength)
+		return nil
+	}
+	partSize := fs.config.DownloadPartSize
+	guard := make(chan struct{}, fs.config.DownloadConcurrency)
+	var blockCtxTimeout time.Duration
+	if fs.config.DownloadPartMaxTime > 0 {
+		blockCtxTimeout = time.Duration(fs.config.DownloadPartMaxTime) * time.Second
+	} else {
+		blockCtxTimeout = time.Duration(fs.config.DownloadPartSize/(1024*1024)) * time.Minute
+	}
+	pool := newBufferAllocator(int(partSize))
+	defer pool.free()
+
+	finished := false
+	var wg sync.WaitGroup
+	var errOnce sync.Once
+	var hasError atomic.Bool
+	var poolError error
+
+	poolCtx, poolCancel := context.WithCancel(ctx)
+	defer poolCancel()
+
+	for part := 0; !finished; part++ {
+		start := offset
+		end := offset + partSize
+		if end >= contentLength {
+			end = contentLength
+			finished = true
+		}
+		writeOffset := int64(part) * partSize
+		offset = end
+
+		guard <- struct{}{}
+		if hasError.Load() {
+			fsLog(fs, logger.LevelDebug, "pool error, download for part %d not started", part)
+			break
+		}
+		buf := pool.getBuffer()
+		wg.Add(1)
+
+		go func(start, end, writeOffset int64, buf []byte) {
+			defer func() {
+				pool.releaseBuffer(buf)
+				<-guard
+				wg.Done()
+			}()
+
+			innerCtx, cancelFn := context.WithDeadline(poolCtx, time.Now().Add(blockCtxTimeout))
+			defer cancelFn()
+
+			err := fs.downloadPart(innerCtx, name, buf, writer, start, end-start, writeOffset)
+			if err != nil {
+				errOnce.Do(func() {
+					fsLog(fs, logger.LevelError, "multipart download error: %+v", err)
+					hasError.Store(true)
+					poolError = fmt.Errorf("multipart download error: %w", err)
+					poolCancel()
+				})
+			}
+		}(start, end, writeOffset, buf)
+	}
+
+	wg.Wait()
+	close(guard)
+
+	return poolError
+}
+
 func (fs *S3Fs) initiateMultipartUpload(ctx context.Context, name, contentType string) (string, error) {
 	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(fs.ctxTimeout))
 	defer cancelFn()
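
handleDownload replaces the Transfer Manager with a hand-rolled bounded fan-out: a buffered channel caps in-flight parts at DownloadConcurrency, a WaitGroup joins the workers, sync.Once latches the first error, and cancelling the shared context aborts parts still running. A stripped-down runnable sketch of the same pattern, with a stand-in fetchPart instead of the real ranged GET (illustrative only, not the SFTPGo code):

package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
)

// fetchPart stands in for downloadPart; it fails on part 3 to show how
// the first error stops the remaining work.
func fetchPart(ctx context.Context, part int) error {
	if part == 3 {
		return fmt.Errorf("part %d failed", part)
	}
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
		return nil
	}
}

func main() {
	const concurrency = 2
	guard := make(chan struct{}, concurrency) // semaphore: at most 2 parts in flight
	var (
		wg       sync.WaitGroup
		errOnce  sync.Once
		hasError atomic.Bool
		firstErr error
	)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	for part := 0; part < 10; part++ {
		guard <- struct{}{} // blocks until a slot frees up
		if hasError.Load() {
			break // don't start new parts after a failure
		}
		wg.Add(1)
		go func(part int) {
			defer func() { <-guard; wg.Done() }()
			if err := fetchPart(ctx, part); err != nil {
				errOnce.Do(func() {
					hasError.Store(true)
					firstErr = err
					cancel() // abort parts still in flight
				})
			}
		}(part)
	}
	wg.Wait()
	fmt.Println("first error:", firstErr)
}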
@@ -1208,31 +1289,18 @@ func (*S3Fs) GetAvailableDiskSize(_ string) (*sftp.StatVFS, error) {
 func (fs *S3Fs) downloadToWriter(name string, w PipeWriter) (int64, error) {
 	fsLog(fs, logger.LevelDebug, "starting download before resuming upload, path %q", name)
+	attrs, err := fs.headObject(name)
+	if err != nil {
+		return 0, err
+	}
 	ctx, cancelFn := context.WithTimeout(context.Background(), preResumeTimeout)
 	defer cancelFn()
-	downloader := manager.NewDownloader(fs.svc, func(d *manager.Downloader) {
-		d.Concurrency = fs.config.DownloadConcurrency
-		d.PartSize = fs.config.DownloadPartSize
-		if fs.config.DownloadPartMaxTime > 0 {
-			d.ClientOptions = append(d.ClientOptions, func(o *s3.Options) {
-				o.HTTPClient = getAWSHTTPClient(fs.config.DownloadPartMaxTime, 100*time.Millisecond,
-					fs.config.SkipTLSVerify)
-			})
-		}
-	})
-	n, err := downloader.Download(ctx, w, &s3.GetObjectInput{
-		Bucket:               aws.String(fs.config.Bucket),
-		Key:                  aws.String(name),
-		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
-		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
-		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
-	})
+	err = fs.handleDownload(ctx, name, 0, w, attrs)
 	fsLog(fs, logger.LevelDebug, "download before resuming upload completed, path %q size: %d, err: %+v",
-		name, n, err)
-	metric.S3TransferCompleted(n, 1, err)
-	return n, err
+		name, w.GetWrittenBytes(), err)
+	metric.S3TransferCompleted(w.GetWrittenBytes(), 1, err)
+	return w.GetWrittenBytes(), err
 }
 
 type s3DirLister struct {


@@ -1282,6 +1282,18 @@ func readFill(r io.Reader, buf []byte) (n int, err error) {
 	return n, err
 }
 
+func writeAtFull(w io.WriterAt, buf []byte, offset int64, count int) (int, error) {
+	written := 0
+	for written < count {
+		n, err := w.WriteAt(buf[written:count], offset+int64(written))
+		written += n
+		if err != nil {
+			return written, err
+		}
+	}
+	return written, nil
+}
+
 type bytesReaderWrapper struct {
 	*bytes.Reader
 }
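
writeAtFull is now a package-level helper shared by the Azure and S3 backends (the AzureBlobFs method removed above had the same body). The io.WriterAt contract requires an error whenever a call writes fewer than len(p) bytes, so the retry loop is mostly defensive; it keeps the helper correct with writers that report partial progress. A small runnable demo (not from the commit) showing how parts can land at their final offsets in any order:

package main

import (
	"fmt"
	"io"
	"os"
)

// writeAtFull mirrors the helper added above: retry until count bytes
// starting at offset have been written.
func writeAtFull(w io.WriterAt, buf []byte, offset int64, count int) (int, error) {
	written := 0
	for written < count {
		n, err := w.WriteAt(buf[written:count], offset+int64(written))
		written += n
		if err != nil {
			return written, err
		}
	}
	return written, nil
}

func main() {
	f, err := os.CreateTemp("", "writeat")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())
	defer f.Close()
	// Write the second part at its final position before the first exists.
	if _, err := writeAtFull(f, []byte("world"), 5, 5); err != nil {
		panic(err)
	}
	if _, err := writeAtFull(f, []byte("hello"), 0, 5); err != nil {
		panic(err)
	}
	data, _ := os.ReadFile(f.Name())
	fmt.Printf("%s\n", data) // helloworld
}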