diff --git a/internal/vfs/azblobfs.go b/internal/vfs/azblobfs.go
index 122d6c2a..83f55b52 100644
--- a/internal/vfs/azblobfs.go
+++ b/internal/vfs/azblobfs.go
@@ -946,6 +946,8 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *b
 	guard := make(chan struct{}, fs.config.DownloadConcurrency)
 	blockCtxTimeout := time.Duration(fs.config.DownloadPartSize/(1024*1024)) * time.Minute
 	pool := newBufferAllocator(int(partSize))
+	defer pool.free()
+
 	finished := false
 	var wg sync.WaitGroup
 	var errOnce sync.Once
@@ -999,7 +1001,6 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *b
 
 	wg.Wait()
 	close(guard)
-	pool.free()
 
 	return poolError
 }
@@ -1014,6 +1015,8 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 	// sync.Pool seems to use a lot of memory so prefer our own, very simple, allocator
 	// we only need to recycle few byte slices
 	pool := newBufferAllocator(int(partSize))
+	defer pool.free()
+
 	finished := false
 	var blocks []string
 	var wg sync.WaitGroup
@@ -1027,7 +1030,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 	for part := 0; !finished; part++ {
 		buf := pool.getBuffer()
 
-		n, err := fs.readFill(reader, buf)
+		n, err := readFill(reader, buf)
 		if err == io.EOF {
 			// read finished, if n > 0 we need to process the last data chunck
 			if n == 0 {
@@ -1037,7 +1040,6 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 			finished = true
 		} else if err != nil {
 			pool.releaseBuffer(buf)
-			pool.free()
 			return err
 		}
 
@@ -1046,7 +1048,6 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 		generatedUUID, err := uuid.NewRandom()
 		if err != nil {
 			pool.releaseBuffer(buf)
-			pool.free()
 			return fmt.Errorf("unable to generate block ID: %w", err)
 		}
 		blockID := base64.StdEncoding.EncodeToString([]byte(generatedUUID.String()))
@@ -1087,7 +1088,6 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read
 
 	wg.Wait()
 	close(guard)
-	pool.free()
 
 	if poolError != nil {
 		return poolError
@@ -1117,16 +1117,6 @@ func (*AzureBlobFs) writeAtFull(w io.WriterAt, buf []byte, offset int64, count i
 	return written, nil
 }
 
-// copied from rclone
-func (*AzureBlobFs) readFill(r io.Reader, buf []byte) (n int, err error) {
-	var nn int
-	for n < len(buf) && err == nil {
-		nn, err = r.Read(buf[n:])
-		n += nn
-	}
-	return n, err
-}
-
 func (fs *AzureBlobFs) getCopyOptions(srcInfo os.FileInfo, updateModTime bool) *blob.StartCopyFromURLOptions {
 	copyOptions := &blob.StartCopyFromURLOptions{}
 	if fs.config.AccessTier != "" {
@@ -1187,66 +1177,6 @@ func getAzContainerClientOptions() *container.ClientOptions {
 	}
 }
 
-type bytesReaderWrapper struct {
-	*bytes.Reader
-}
-
-func (b *bytesReaderWrapper) Close() error {
-	return nil
-}
-
-type bufferAllocator struct {
-	sync.Mutex
-	available  [][]byte
-	bufferSize int
-	finalized  bool
-}
-
-func newBufferAllocator(size int) *bufferAllocator {
-	return &bufferAllocator{
-		bufferSize: size,
-		finalized:  false,
-	}
-}
-
-func (b *bufferAllocator) getBuffer() []byte {
-	b.Lock()
-	defer b.Unlock()
-
-	if len(b.available) > 0 {
-		var result []byte
-
-		truncLength := len(b.available) - 1
-		result = b.available[truncLength]
-
-		b.available[truncLength] = nil
-		b.available = b.available[:truncLength]
-
-		return result
-	}
-
-	return make([]byte, b.bufferSize)
-}
-
-func (b *bufferAllocator) releaseBuffer(buf []byte) {
-	b.Lock()
-	defer b.Unlock()
-
-	if b.finalized || len(buf) != b.bufferSize {
-		return
-	}
-
-	b.available = append(b.available, buf)
-}
-
-func (b *bufferAllocator) free() {
-	b.Lock()
-	defer b.Unlock()
-
-	b.available = nil
-	b.finalized = true
-}
-
 type azureBlobDirLister struct {
 	baseDirLister
 	paginator *runtime.Pager[container.ListBlobsHierarchyResponse]
diff --git a/internal/vfs/s3fs.go b/internal/vfs/s3fs.go
index 639266c9..dc7118e3 100644
--- a/internal/vfs/s3fs.go
+++ b/internal/vfs/s3fs.go
@@ -17,6 +17,7 @@ package vfs
 
 import (
+	"bytes"
 	"context"
 	"crypto/md5"
 	"crypto/sha256"
@@ -255,7 +256,7 @@ func (fs *S3Fs) Open(name string, offset int64) (File, PipeReader, func(), error
 
 	var streamRange *string
 	if offset > 0 {
-		streamRange = aws.String(fmt.Sprintf("bytes=%v-", offset))
+		streamRange = aws.String(fmt.Sprintf("bytes=%d-", offset))
 	}
 
 	go func() {
@@ -295,16 +296,6 @@ func (fs *S3Fs) Create(name string, flag, checks int) (File, PipeWriter, func(),
 		p = NewPipeWriter(w)
 	}
 	ctx, cancelFn := context.WithCancel(context.Background())
-	uploader := manager.NewUploader(fs.svc, func(u *manager.Uploader) {
-		u.Concurrency = fs.config.UploadConcurrency
-		u.PartSize = fs.config.UploadPartSize
-		if fs.config.UploadPartMaxTime > 0 {
-			u.ClientOptions = append(u.ClientOptions, func(o *s3.Options) {
-				o.HTTPClient = getAWSHTTPClient(fs.config.UploadPartMaxTime, 100*time.Millisecond,
-					fs.config.SkipTLSVerify)
-			})
-		}
-	})
 
 	go func() {
 		defer cancelFn()
@@ -315,17 +306,7 @@ func (fs *S3Fs) Create(name string, flag, checks int) (File, PipeWriter, func(),
 		} else {
 			contentType = mime.TypeByExtension(path.Ext(name))
 		}
-		_, err := uploader.Upload(ctx, &s3.PutObjectInput{
-			Bucket:               aws.String(fs.config.Bucket),
-			Key:                  aws.String(name),
-			Body:                 r,
-			ACL:                  types.ObjectCannedACL(fs.config.ACL),
-			StorageClass:         types.StorageClass(fs.config.StorageClass),
-			ContentType:          util.NilIfEmpty(contentType),
-			SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
-			SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
-			SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
-		})
+		err := fs.handleMultipartUpload(ctx, r, name, contentType)
 		r.CloseWithError(err) //nolint:errcheck
 		p.Done(err)
 		fsLog(fs, logger.LevelDebug, "upload completed, path: %q, acl: %q, readed bytes: %d, err: %+v",
@@ -834,6 +815,181 @@ func (fs *S3Fs) hasContents(name string) (bool, error) {
 	return false, nil
 }
 
+func (fs *S3Fs) initiateMultipartUpload(ctx context.Context, name, contentType string) (string, error) {
+	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(fs.ctxTimeout))
+	defer cancelFn()
+
+	res, err := fs.svc.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
+		Bucket:               aws.String(fs.config.Bucket),
+		Key:                  aws.String(name),
+		StorageClass:         types.StorageClass(fs.config.StorageClass),
+		ACL:                  types.ObjectCannedACL(fs.config.ACL),
+		ContentType:          util.NilIfEmpty(contentType),
+		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
+	})
+	if err != nil {
+		return "", fmt.Errorf("unable to create multipart upload request: %w", err)
+	}
+	uploadID := util.GetStringFromPointer(res.UploadId)
+	if uploadID == "" {
+		return "", errors.New("unable to get multipart upload ID")
+	}
+	return uploadID, nil
+}
+
+func (fs *S3Fs) uploadPart(ctx context.Context, name, uploadID string, partNumber int32, data []byte) (*string, error) {
+	timeout := time.Duration(fs.config.UploadPartSize/(1024*1024)) * time.Minute
+	if fs.config.UploadPartMaxTime > 0 {
+		timeout = time.Duration(fs.config.UploadPartMaxTime) * time.Second
+	}
+	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(timeout))
+	defer cancelFn()
+
+	resp, err := fs.svc.UploadPart(ctx, &s3.UploadPartInput{
+		Bucket:               aws.String(fs.config.Bucket),
+		Key:                  aws.String(name),
+		PartNumber:           &partNumber,
+		UploadId:             aws.String(uploadID),
+		Body:                 bytes.NewReader(data),
+		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
+	})
+	if err != nil {
+		return nil, fmt.Errorf("unable to upload part number %d: %w", partNumber, err)
+	}
+	return resp.ETag, nil
+}
+
+func (fs *S3Fs) completeMultipartUpload(ctx context.Context, name, uploadID string, completedParts []types.CompletedPart) error {
+	ctx, cancelFn := context.WithDeadline(ctx, time.Now().Add(fs.ctxTimeout))
+	defer cancelFn()
+
+	_, err := fs.svc.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
+		Bucket:   aws.String(fs.config.Bucket),
+		Key:      aws.String(name),
+		UploadId: aws.String(uploadID),
+		MultipartUpload: &types.CompletedMultipartUpload{
+			Parts: completedParts,
+		},
+	})
+	return err
+}
+
+func (fs *S3Fs) abortMultipartUpload(name, uploadID string) error {
+	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
+	defer cancelFn()
+
+	_, err := fs.svc.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
+		Bucket:   aws.String(fs.config.Bucket),
+		Key:      aws.String(name),
+		UploadId: aws.String(uploadID),
+	})
+	return err
+}
+
+func (fs *S3Fs) handleMultipartUpload(ctx context.Context, reader io.Reader, name, contentType string) error {
+	uploadID, err := fs.initiateMultipartUpload(ctx, name, contentType)
+	if err != nil {
+		return err
+	}
+	guard := make(chan struct{}, fs.config.UploadConcurrency)
+	finished := false
+	var partMutex sync.Mutex
+	var completedParts []types.CompletedPart
+	var wg sync.WaitGroup
+	var hasError atomic.Bool
+	var poolErr error
+	var errOnce sync.Once
+	var partNumber int32
+
+	pool := newBufferAllocator(int(fs.config.UploadPartSize))
+	defer pool.free()
+
+	poolCtx, poolCancel := context.WithCancel(ctx)
+	defer poolCancel()
+
+	finalizeFailedUpload := func(err error) {
+		fsLog(fs, logger.LevelError, "finalize failed multipart upload after error: %v", err)
+		hasError.Store(true)
+		poolErr = err
+		poolCancel()
+		if abortErr := fs.abortMultipartUpload(name, uploadID); abortErr != nil {
+			fsLog(fs, logger.LevelError, "unable to abort multipart upload: %+v", abortErr)
+		}
+	}
+
+	for partNumber = 1; !finished; partNumber++ {
+		buf := pool.getBuffer()
+
+		n, err := readFill(reader, buf)
+		if err == io.EOF {
+			if n == 0 && partNumber > 1 {
+				pool.releaseBuffer(buf)
+				break
+			}
+			finished = true
+		} else if err != nil {
+			pool.releaseBuffer(buf)
+			errOnce.Do(func() {
+				finalizeFailedUpload(err)
+			})
+			return err
+		}
+		guard <- struct{}{}
+		if hasError.Load() {
+			fsLog(fs, logger.LevelError, "pool error, upload for part %d not started", partNumber)
+			pool.releaseBuffer(buf)
+			break
+		}
+
+		wg.Add(1)
+		go func(partNum int32, buf []byte, bytesRead int) {
+			defer func() {
+				pool.releaseBuffer(buf)
+				<-guard
+				wg.Done()
+			}()
+
+			etag, err := fs.uploadPart(poolCtx, name, uploadID, partNum, buf[:bytesRead])
+			if err != nil {
+				errOnce.Do(func() {
+					finalizeFailedUpload(err)
+				})
+				return
+			}
+			partMutex.Lock()
+			completedParts = append(completedParts, types.CompletedPart{
+				PartNumber: &partNum,
+				ETag:       etag,
+			})
+			partMutex.Unlock()
+		}(partNumber, buf, n)
+	}
+
+	wg.Wait()
+	close(guard)
+
+	if poolErr != nil {
+		return poolErr
+	}
+
+	sort.Slice(completedParts, func(i, j int) bool {
+		getPartNumber := func(number *int32) int32 {
+			if number == nil {
+				return 0
+			}
+			return *number
+		}
+
+		return getPartNumber(completedParts[i].PartNumber) < getPartNumber(completedParts[j].PartNumber)
+	})
+
+	return fs.completeMultipartUpload(ctx, name, uploadID, completedParts)
+}
+
 func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int64) error {
 	ctx, cancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
 	defer cancelFn()
@@ -921,15 +1077,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
 					copyError = fmt.Errorf("error copying part number %d: %w", partNum, err)
 					opCancel()
 
-					abortCtx, abortCancelFn := context.WithDeadline(context.Background(), time.Now().Add(fs.ctxTimeout))
-					defer abortCancelFn()
-
-					_, errAbort := fs.svc.AbortMultipartUpload(abortCtx, &s3.AbortMultipartUploadInput{
-						Bucket:   aws.String(fs.config.Bucket),
-						Key:      aws.String(target),
-						UploadId: aws.String(uploadID),
-					})
-					if errAbort != nil {
+					if errAbort := fs.abortMultipartUpload(target, uploadID); errAbort != nil {
 						fsLog(fs, logger.LevelError, "unable to abort multipart copy: %+v", errAbort)
 					}
 				})
diff --git a/internal/vfs/vfs.go b/internal/vfs/vfs.go
index 1fcb01be..add7f41c 100644
--- a/internal/vfs/vfs.go
+++ b/internal/vfs/vfs.go
@@ -16,6 +16,7 @@ package vfs
 
 import (
+	"bytes"
 	"errors"
 	"fmt"
 	"io"
@@ -1271,6 +1272,76 @@ func doRecursiveRename(fs Fs, source, target string,
 	}
 }
 
+// copied from rclone
+func readFill(r io.Reader, buf []byte) (n int, err error) {
+	var nn int
+	for n < len(buf) && err == nil {
+		nn, err = r.Read(buf[n:])
+		n += nn
+	}
+	return n, err
+}
+
+type bytesReaderWrapper struct {
+	*bytes.Reader
+}
+
+func (b *bytesReaderWrapper) Close() error {
+	return nil
+}
+
+type bufferAllocator struct {
+	sync.Mutex
+	available  [][]byte
+	bufferSize int
+	finalized  bool
+}
+
+func newBufferAllocator(size int) *bufferAllocator {
+	return &bufferAllocator{
+		bufferSize: size,
+		finalized:  false,
+	}
+}
+
+func (b *bufferAllocator) getBuffer() []byte {
+	b.Lock()
+	defer b.Unlock()
+
+	if len(b.available) > 0 {
+		var result []byte
+
+		truncLength := len(b.available) - 1
+		result = b.available[truncLength]
+
+		b.available[truncLength] = nil
+		b.available = b.available[:truncLength]
+
+		return result
+	}
+
+	return make([]byte, b.bufferSize)
+}
+
+func (b *bufferAllocator) releaseBuffer(buf []byte) {
+	b.Lock()
+	defer b.Unlock()
+
+	if b.finalized || len(buf) != b.bufferSize {
+		return
+	}
+
+	b.available = append(b.available, buf)
+}
+
+func (b *bufferAllocator) free() {
+	b.Lock()
+	defer b.Unlock()
+
+	b.available = nil
+	b.finalized = true
+}
+
 func fsLog(fs Fs, level logger.LogLevel, format string, v ...any) {
 	logger.Log(level, fs.Name(), fs.ConnectionID(), format, v...)
 }
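Reviewer note, not part of the patch: with readFill and bufferAllocator moved into vfs.go, the Azure and S3 multipart uploads now share the same read/recycle loop. The sketch below shows that loop in isolation. The forEachPart name is hypothetical and assumes the snippet lives in the vfs package next to those helpers; the real handleMultipartUpload functions additionally upload each chunk concurrently behind a guard channel and abort the remote upload on error.

// Hypothetical helper (not in this diff): a minimal sketch of the shared
// read/recycle pattern, assuming it sits in the vfs package alongside
// readFill and bufferAllocator.
package vfs

import "io"

// forEachPart reads r in partSize chunks using readFill, recycles buffers
// through bufferAllocator, and calls fn for every non-empty chunk, including
// a short final one. readFill only reports io.EOF once r is exhausted, so the
// last partial chunk is still delivered before the loop stops.
func forEachPart(r io.Reader, partSize int, fn func(part int, data []byte) error) error {
	pool := newBufferAllocator(partSize)
	defer pool.free()

	for part := 1; ; part++ {
		buf := pool.getBuffer()
		n, err := readFill(r, buf)
		if err != nil && err != io.EOF {
			pool.releaseBuffer(buf)
			return err
		}
		if n > 0 {
			if cbErr := fn(part, buf[:n]); cbErr != nil {
				pool.releaseBuffer(buf)
				return cbErr
			}
		}
		pool.releaseBuffer(buf)
		if err == io.EOF {
			return nil
		}
	}
}

As the comment in handleMultipartUpload notes, the custom allocator is preferred over sync.Pool here because only a handful of identically sized slices need to be recycled.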