Mirror of https://github.com/drakkan/sftpgo.git
s3: implement multipart downloads without using the S3 Manager
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
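
The change drops the AWS SDK's manager.Downloader and implements multipart downloads directly: the object size is read once with HeadObject, the requested byte range is split into DownloadPartSize chunks, and up to DownloadConcurrency ranged GetObject requests run concurrently, each writing its chunk at the proper offset through io.WriterAt. As a minimal standalone sketch of the primitive the new code builds on, here is a single ranged GetObject with aws-sdk-go-v2 (bucket, key and range are placeholders, not values from this commit):

package main

import (
	"context"
	"fmt"
	"io"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/s3"
)

func main() {
	cfg, err := config.LoadDefaultConfig(context.TODO())
	if err != nil {
		panic(err)
	}
	client := s3.NewFromConfig(cfg)
	// Ask S3 for the first 5 MiB part only; the Range end is inclusive.
	out, err := client.GetObject(context.TODO(), &s3.GetObjectInput{
		Bucket: aws.String("my-bucket"),
		Key:    aws.String("my-key"),
		Range:  aws.String("bytes=0-5242879"),
	})
	if err != nil {
		panic(err)
	}
	defer out.Body.Close()
	n, _ := io.Copy(io.Discard, out.Body)
	fmt.Println("part bytes:", n) // 5242880 for objects at least that large
}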
go.mod (1 changed line)
@@ -13,7 +13,6 @@ require (
 	github.com/aws/aws-sdk-go-v2 v1.39.4
 	github.com/aws/aws-sdk-go-v2/config v1.31.15
 	github.com/aws/aws-sdk-go-v2/credentials v1.18.19
-	github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.15
 	github.com/aws/aws-sdk-go-v2/service/s3 v1.88.7
 	github.com/aws/aws-sdk-go-v2/service/sts v1.38.9
 	github.com/bmatcuk/doublestar/v4 v4.9.1
go.sum (2 changed lines)
@@ -67,8 +67,6 @@ github.com/aws/aws-sdk-go-v2/credentials v1.18.19 h1:Jc1zzwkSY1QbkEcLujwqRTXOdvW
 github.com/aws/aws-sdk-go-v2/credentials v1.18.19/go.mod h1:DIfQ9fAk5H0pGtnqfqkbSIzky82qYnGvh06ASQXXg6A=
 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11 h1:X7X4YKb+c0rkI6d4uJ5tEMxXgCZ+jZ/D6mvkno8c8Uw=
 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11/go.mod h1:EqM6vPZQsZHYvC4Cai35UDg/f5NCEU+vp0WfbVqVcZc=
-github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.15 h1:OsZ2Sk84YUPJfi6BemhyMQyuR8/5tWu37WBMVUl8lJk=
-github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.19.15/go.mod h1:CYZDjBMY+MyT+U+QmXw81GBiq+lhgM97kIMdDAJk+hg=
 github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11 h1:7AANQZkF3ihM8fbdftpjhken0TP9sBzFbV/Ze/Y4HXA=
 github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11/go.mod h1:NTF4QCGkm6fzVwncpkFQqoquQyOolcyXfbpC98urj+c=
 github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11 h1:ShdtWUZT37LCAA4Mw2kJAJtzaszfSHFb5n25sdcv4YE=
@@ -915,7 +915,7 @@ func (fs *AzureBlobFs) downloadPart(ctx context.Context, blockBlob *blockblob.Cl
 		return err
 	}
 
-	_, err = fs.writeAtFull(w, buf, writeOffset, int(count))
+	_, err = writeAtFull(w, buf, writeOffset, int(count))
 	return err
 }
 
@@ -1105,18 +1105,6 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 	return err
 }
 
-func (*AzureBlobFs) writeAtFull(w io.WriterAt, buf []byte, offset int64, count int) (int, error) {
-	written := 0
-	for written < count {
-		n, err := w.WriteAt(buf[written:count], offset+int64(written))
-		written += n
-		if err != nil {
-			return written, err
-		}
-	}
-	return written, nil
-}
-
 func (fs *AzureBlobFs) getCopyOptions(srcInfo os.FileInfo, updateModTime bool) *blob.StartCopyFromURLOptions {
 	copyOptions := &blob.StartCopyFromURLOptions{}
 	if fs.config.AccessTier != "" {
@@ -45,7 +45,6 @@ import (
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sts"
|
||||
@@ -227,53 +226,29 @@ func (fs *S3Fs) Lstat(name string) (os.FileInfo, error) {
 
 // Open opens the named file for reading
 func (fs *S3Fs) Open(name string, offset int64) (File, PipeReader, func(), error) {
+	attrs, err := fs.headObject(name)
+	if err != nil {
+		return nil, nil, nil, err
+	}
 	r, w, err := createPipeFn(fs.localTempDir, fs.config.DownloadPartSize*int64(fs.config.DownloadConcurrency)+1)
 	if err != nil {
 		return nil, nil, nil, err
 	}
 	p := NewPipeReader(r)
 	if readMetadata > 0 {
-		attrs, err := fs.headObject(name)
-		if err != nil {
-			r.Close()
-			w.Close()
-			return nil, nil, nil, err
-		}
 		p.setMetadata(attrs.Metadata)
 	}
 
 	ctx, cancelFn := context.WithCancel(context.Background())
-	downloader := manager.NewDownloader(fs.svc, func(d *manager.Downloader) {
-		d.Concurrency = fs.config.DownloadConcurrency
-		d.PartSize = fs.config.DownloadPartSize
-		if offset == 0 && fs.config.DownloadPartMaxTime > 0 {
-			d.ClientOptions = append(d.ClientOptions, func(o *s3.Options) {
-				o.HTTPClient = getAWSHTTPClient(fs.config.DownloadPartMaxTime, 100*time.Millisecond,
-					fs.config.SkipTLSVerify)
-			})
-		}
-	})
-
-	var streamRange *string
-	if offset > 0 {
-		streamRange = aws.String(fmt.Sprintf("bytes=%d-", offset))
-	}
 
 	go func() {
 		defer cancelFn()
 
-		n, err := downloader.Download(ctx, w, &s3.GetObjectInput{
-			Bucket:               aws.String(fs.config.Bucket),
-			Key:                  aws.String(name),
-			Range:                streamRange,
-			SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
-			SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
-			SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
-		})
+		err := fs.handleDownload(ctx, name, offset, w, attrs)
 		w.CloseWithError(err) //nolint:errcheck
-		fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %+v", name, n, err)
-		metric.S3TransferCompleted(n, 1, err)
+		fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %d, err: %+v", name, w.GetWrittenBytes(), err)
+		metric.S3TransferCompleted(w.GetWrittenBytes(), 1, err)
 	}()
 
 	return nil, p, cancelFn, nil
 }
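
As before, Open returns immediately: the caller gets a pipe reader while a goroutine fills the pipe, and CloseWithError hands the first download error to the reader. A rough sketch of that shape, where download stands in for fs.handleDownload (the real pipe from createPipeFn also implements io.WriterAt so parts finishing out of order can land at the right position, which io.Pipe cannot do):

package main

import (
	"context"
	"fmt"
	"io"
)

// openStream mimics Open's structure: hand back a reader right away and
// stream into the pipe from a goroutine.
func openStream(download func(ctx context.Context, w io.Writer) error) (io.ReadCloser, func()) {
	r, w := io.Pipe()
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		defer cancel()
		err := download(ctx, w)
		w.CloseWithError(err) // the reader sees err, or io.EOF when err is nil
	}()
	return r, cancel
}

func main() {
	r, cancel := openStream(func(_ context.Context, w io.Writer) error {
		_, err := io.WriteString(w, "object contents")
		return err
	})
	defer cancel()
	b, _ := io.ReadAll(r)
	fmt.Println(string(b))
}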
@@ -673,25 +648,28 @@ func (fs *S3Fs) resolve(name *string, prefix string) (string, bool) {
 }
 
 func (fs *S3Fs) setConfigDefaults() {
+	const defaultPartSize = 1024 * 1024 * 5
+	const defaultConcurrency = 5
+
 	if fs.config.UploadPartSize == 0 {
-		fs.config.UploadPartSize = manager.DefaultUploadPartSize
+		fs.config.UploadPartSize = defaultPartSize
 	} else {
 		if fs.config.UploadPartSize < 1024*1024 {
 			fs.config.UploadPartSize *= 1024 * 1024
 		}
 	}
 	if fs.config.UploadConcurrency == 0 {
-		fs.config.UploadConcurrency = manager.DefaultUploadConcurrency
+		fs.config.UploadConcurrency = defaultConcurrency
 	}
 	if fs.config.DownloadPartSize == 0 {
-		fs.config.DownloadPartSize = manager.DefaultDownloadPartSize
+		fs.config.DownloadPartSize = defaultPartSize
 	} else {
 		if fs.config.DownloadPartSize < 1024*1024 {
 			fs.config.DownloadPartSize *= 1024 * 1024
 		}
 	}
 	if fs.config.DownloadConcurrency == 0 {
-		fs.config.DownloadConcurrency = manager.DefaultDownloadConcurrency
+		fs.config.DownloadConcurrency = defaultConcurrency
 	}
 }
 
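
With the manager gone, the defaults become local constants that match what the SDK manager used: 5 MiB parts and a concurrency of 5. The unit handling is unchanged: a configured value below 1024*1024 is read as a number of MiB and scaled up, anything larger is taken as raw bytes. A small worked example of that normalization rule:

package main

import "fmt"

// normalizePartSize mirrors the rule in setConfigDefaults: zero selects the
// 5 MiB default, values below 1 MiB are read as a MiB count, larger values
// are already bytes.
func normalizePartSize(v int64) int64 {
	const defaultPartSize = 1024 * 1024 * 5
	if v == 0 {
		return defaultPartSize
	}
	if v < 1024*1024 {
		return v * 1024 * 1024
	}
	return v
}

func main() {
	for _, v := range []int64{0, 10, 8388608} {
		fmt.Printf("%d -> %d\n", v, normalizePartSize(v))
	}
	// 0 -> 5242880, 10 -> 10485760, 8388608 -> 8388608
}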
@@ -815,6 +793,109 @@ func (fs *S3Fs) hasContents(name string) (bool, error) {
 	return false, nil
 }
 
+func (fs *S3Fs) downloadPart(ctx context.Context, name string, buf []byte, w io.WriterAt, start, count, writeOffset int64) error {
+	rangeHeader := fmt.Sprintf("bytes=%d-%d", start, start+count)
+
+	resp, err := fs.svc.GetObject(ctx, &s3.GetObjectInput{
+		Bucket:               aws.String(fs.config.Bucket),
+		Key:                  aws.String(name),
+		Range:                &rangeHeader,
+		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
+	})
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	_, err = io.ReadAtLeast(resp.Body, buf, int(count))
+	if err != nil {
+		return err
+	}
+
+	_, err = writeAtFull(w, buf, writeOffset, int(count))
+	return err
+}
+
+func (fs *S3Fs) handleDownload(ctx context.Context, name string, offset int64, writer io.WriterAt, attrs *s3.HeadObjectOutput) error {
+	contentLength := util.GetIntFromPointer(attrs.ContentLength)
+	sizeToDownload := contentLength - offset
+	if sizeToDownload < 0 {
+		fsLog(fs, logger.LevelError, "invalid multipart download size or offset, size: %d, offset: %d, size to download: %d",
+			contentLength, offset, sizeToDownload)
+		return errors.New("the requested offset exceeds the file size")
+	}
+	if sizeToDownload == 0 {
+		fsLog(fs, logger.LevelDebug, "nothing to download, offset %d, content length %d", offset, contentLength)
+		return nil
+	}
+	partSize := fs.config.DownloadPartSize
+	guard := make(chan struct{}, fs.config.DownloadConcurrency)
+	var blockCtxTimeout time.Duration
+	if fs.config.DownloadPartMaxTime > 0 {
+		blockCtxTimeout = time.Duration(fs.config.DownloadPartMaxTime) * time.Second
+	} else {
+		blockCtxTimeout = time.Duration(fs.config.DownloadPartSize/(1024*1024)) * time.Minute
+	}
+	pool := newBufferAllocator(int(partSize))
+	defer pool.free()
+
+	finished := false
+	var wg sync.WaitGroup
+	var errOnce sync.Once
+	var hasError atomic.Bool
+	var poolError error
+
+	poolCtx, poolCancel := context.WithCancel(ctx)
+	defer poolCancel()
+
+	for part := 0; !finished; part++ {
+		start := offset
+		end := offset + partSize
+		if end >= contentLength {
+			end = contentLength
+			finished = true
+		}
+		writeOffset := int64(part) * partSize
+		offset = end
+
+		guard <- struct{}{}
+		if hasError.Load() {
+			fsLog(fs, logger.LevelDebug, "pool error, download for part %d not started", part)
+			break
+		}
+
+		buf := pool.getBuffer()
+		wg.Add(1)
+		go func(start, end, writeOffset int64, buf []byte) {
+			defer func() {
+				pool.releaseBuffer(buf)
+				<-guard
+				wg.Done()
+			}()
+
+			innerCtx, cancelFn := context.WithDeadline(poolCtx, time.Now().Add(blockCtxTimeout))
+			defer cancelFn()
+
+			err := fs.downloadPart(innerCtx, name, buf, writer, start, end-start, writeOffset)
+			if err != nil {
+				errOnce.Do(func() {
+					fsLog(fs, logger.LevelError, "multipart download error: %+v", err)
+					hasError.Store(true)
+					poolError = fmt.Errorf("multipart download error: %w", err)
+					poolCancel()
+				})
+			}
+		}(start, end, writeOffset, buf)
+	}
+
+	wg.Wait()
+	close(guard)
+
+	return poolError
+}
+
 func (fs *S3Fs) initiateMultipartUpload(ctx context.Context, name, contentType string) (string, error) {
 	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(fs.ctxTimeout))
 	defer cancelFn()
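
handleDownload is a bounded worker pool: the guard channel is a semaphore capped at DownloadConcurrency, each part borrows a pooled buffer and runs under its own deadline, and the first failure is recorded exactly once via sync.Once, sets hasError so no further parts start, and cancels the shared context to abort parts already in flight. The same skeleton, stripped of the S3 details (doWork is a placeholder):

package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
)

// runPool reproduces handleDownload's concurrency skeleton: a channel used
// as a semaphore, sync.Once to keep only the first error, an atomic flag to
// stop launching new parts, and a shared cancelable context.
func runPool(ctx context.Context, parts, concurrency int, doWork func(ctx context.Context, part int) error) error {
	guard := make(chan struct{}, concurrency)
	var wg sync.WaitGroup
	var errOnce sync.Once
	var hasError atomic.Bool
	var poolError error

	poolCtx, poolCancel := context.WithCancel(ctx)
	defer poolCancel()

	for part := 0; part < parts; part++ {
		guard <- struct{}{} // blocks while `concurrency` parts are in flight
		if hasError.Load() {
			break // a part failed; don't start new ones
		}
		wg.Add(1)
		go func(part int) {
			defer func() {
				<-guard
				wg.Done()
			}()
			if err := doWork(poolCtx, part); err != nil {
				errOnce.Do(func() {
					hasError.Store(true)
					poolError = fmt.Errorf("part %d: %w", part, err)
					poolCancel() // abort the parts still running
				})
			}
		}(part)
	}
	wg.Wait()
	return poolError
}

func main() {
	err := runPool(context.Background(), 10, 3, func(_ context.Context, part int) error {
		return nil // placeholder work
	})
	fmt.Println(err)
}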
@@ -1208,31 +1289,18 @@ func (*S3Fs) GetAvailableDiskSize(_ string) (*sftp.StatVFS, error) {
 
 func (fs *S3Fs) downloadToWriter(name string, w PipeWriter) (int64, error) {
 	fsLog(fs, logger.LevelDebug, "starting download before resuming upload, path %q", name)
+	attrs, err := fs.headObject(name)
+	if err != nil {
+		return 0, err
+	}
 	ctx, cancelFn := context.WithTimeout(context.Background(), preResumeTimeout)
 	defer cancelFn()
 
-	downloader := manager.NewDownloader(fs.svc, func(d *manager.Downloader) {
-		d.Concurrency = fs.config.DownloadConcurrency
-		d.PartSize = fs.config.DownloadPartSize
-		if fs.config.DownloadPartMaxTime > 0 {
-			d.ClientOptions = append(d.ClientOptions, func(o *s3.Options) {
-				o.HTTPClient = getAWSHTTPClient(fs.config.DownloadPartMaxTime, 100*time.Millisecond,
-					fs.config.SkipTLSVerify)
-			})
-		}
-	})
-
-	n, err := downloader.Download(ctx, w, &s3.GetObjectInput{
-		Bucket:               aws.String(fs.config.Bucket),
-		Key:                  aws.String(name),
-		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
-		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
-		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
-	})
+	err = fs.handleDownload(ctx, name, 0, w, attrs)
 	fsLog(fs, logger.LevelDebug, "download before resuming upload completed, path %q size: %d, err: %+v",
-		name, n, err)
-	metric.S3TransferCompleted(n, 1, err)
-	return n, err
+		name, w.GetWrittenBytes(), err)
+	metric.S3TransferCompleted(w.GetWrittenBytes(), 1, err)
+	return w.GetWrittenBytes(), err
 }
 
 type s3DirLister struct {
@@ -1282,6 +1282,18 @@ func readFill(r io.Reader, buf []byte) (n int, err error) {
 	return n, err
 }
 
+func writeAtFull(w io.WriterAt, buf []byte, offset int64, count int) (int, error) {
+	written := 0
+	for written < count {
+		n, err := w.WriteAt(buf[written:count], offset+int64(written))
+		written += n
+		if err != nil {
+			return written, err
+		}
+	}
+	return written, nil
+}
+
 type bytesReaderWrapper struct {
 	*bytes.Reader
 }
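
writeAtFull loops until count bytes are written; the io.WriterAt contract already requires an error when fewer than len(p) bytes are written, so the loop is a defensive retry that resumes a short write where it stopped. Offset-based writes are what let parts complete in any order, as this usage sketch with a temp file shows (the helper is copied from the diff):

package main

import (
	"fmt"
	"io"
	"os"
)

// Same helper as in the diff, exercised with *os.File as the io.WriterAt.
func writeAtFull(w io.WriterAt, buf []byte, offset int64, count int) (int, error) {
	written := 0
	for written < count {
		n, err := w.WriteAt(buf[written:count], offset+int64(written))
		written += n
		if err != nil {
			return written, err
		}
	}
	return written, nil
}

func main() {
	f, err := os.CreateTemp("", "writeatfull")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())
	defer f.Close()

	// Write part 1 before part 0; WriteAt makes the order irrelevant.
	if _, err := writeAtFull(f, []byte("world"), 5, 5); err != nil {
		panic(err)
	}
	if _, err := writeAtFull(f, []byte("hello"), 0, 5); err != nil {
		panic(err)
	}
	b, _ := os.ReadFile(f.Name())
	fmt.Println(string(b)) // helloworld
}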