S3: add SSE customer key

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
commit 2fbf608895
parent d783ffc13f
Author: Nicola Murino
Date:   2024-08-15 10:09:06 +02:00

14 changed files with 264 additions and 75 deletions

internal/vfs/s3fs.go

@@ -19,7 +19,10 @@ package vfs
 import (
 	"context"
+	"crypto/md5"
+	"crypto/sha256"
 	"crypto/tls"
+	"encoding/base64"
 	"errors"
 	"fmt"
 	"io"
@@ -72,10 +75,13 @@ type S3Fs struct {
 	connectionID string
 	localTempDir string
 	// if not empty this fs is mounted as virtual folder in the specified path
-	mountPath  string
-	config     *S3FsConfig
-	svc        *s3.Client
-	ctxTimeout time.Duration
+	mountPath         string
+	config            *S3FsConfig
+	svc               *s3.Client
+	ctxTimeout        time.Duration
+	sseCustomerKey    string
+	sseCustomerKeyMD5 string
+	sseCustomerAlgo   string
 }

 func init() {
@@ -121,6 +127,23 @@ func NewS3Fs(connectionID, localTempDir, mountPath string, s3Config S3FsConfig)
 				fs.config.SessionToken),
 		)
 	}
+	if !fs.config.SSECustomerKey.IsEmpty() {
+		if err := fs.config.SSECustomerKey.TryDecrypt(); err != nil {
+			return fs, err
+		}
+		key := fs.config.SSECustomerKey.GetPayload()
+		if len(key) == 32 {
+			md5sumBinary := md5.Sum([]byte(key))
+			fs.sseCustomerKey = base64.StdEncoding.EncodeToString([]byte(key))
+			fs.sseCustomerKeyMD5 = base64.StdEncoding.EncodeToString(md5sumBinary[:])
+		} else {
+			keyHash := sha256.Sum256([]byte(key))
+			md5sumBinary := md5.Sum(keyHash[:])
+			fs.sseCustomerKey = base64.StdEncoding.EncodeToString(keyHash[:])
+			fs.sseCustomerKeyMD5 = base64.StdEncoding.EncodeToString(md5sumBinary[:])
+		}
+		fs.sseCustomerAlgo = "AES256"
+	}
 	fs.setConfigDefaults()
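The derivation above always ends up with a 256-bit key: a secret that is exactly 32 bytes is used as the AES-256 key directly, anything else is first hashed with SHA-256, and both the key and its MD5 digest are base64-encoded because that is the form the SSE-C request parameters expect. A standalone sketch of the same logic (deriveSSECParams is a hypothetical helper for illustration, not part of this commit):

package main

import (
	"crypto/md5"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
)

// deriveSSECParams mirrors the derivation in NewS3Fs: a 32-byte secret is used
// as the AES-256 key as-is, anything else is hashed with SHA-256 so the key is
// always 256 bits. Both the key and its MD5 digest are base64-encoded.
func deriveSSECParams(secret string) (key, keyMD5 string) {
	raw := []byte(secret)
	if len(raw) != 32 {
		sum := sha256.Sum256(raw)
		raw = sum[:]
	}
	digest := md5.Sum(raw)
	return base64.StdEncoding.EncodeToString(raw), base64.StdEncoding.EncodeToString(digest[:])
}

func main() {
	k, m := deriveSSECParams("correct horse battery staple")
	fmt.Println("SSECustomerKey:   ", k)
	fmt.Println("SSECustomerKeyMD5:", m)
}

The MD5 digest is only a transport integrity check on the transmitted key; it adds no secrecy.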
@@ -242,9 +265,12 @@ func (fs *S3Fs) Open(name string, offset int64) (File, PipeReader, func(), error
 		defer cancelFn()
 		n, err := downloader.Download(ctx, w, &s3.GetObjectInput{
-			Bucket: aws.String(fs.config.Bucket),
-			Key:    aws.String(name),
-			Range:  streamRange,
+			Bucket:               aws.String(fs.config.Bucket),
+			Key:                  aws.String(name),
+			Range:                streamRange,
+			SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+			SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+			SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
 		})
 		w.CloseWithError(err) //nolint:errcheck
 		fsLog(fs, logger.LevelDebug, "download completed, path: %q size: %v, err: %+v", name, n, err)
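The SSE-C fields of the SDK input structs are *string, so they are only populated when a customer key is actually configured; util.NilIfEmpty is SFTPGo's own helper and presumably behaves along these lines (a sketch under that assumption, not the real implementation):

// NilIfEmpty, sketched from its name: return nil for an empty string so the
// optional SSE-C parameters stay unset and the corresponding
// x-amz-server-side-encryption-customer-* headers are simply not sent.
func NilIfEmpty(s string) *string {
	if s == "" {
		return nil
	}
	return &s
}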
@@ -293,12 +319,15 @@ func (fs *S3Fs) Create(name string, flag, checks int) (File, PipeWriter, func(),
 			contentType = mime.TypeByExtension(path.Ext(name))
 		}
 		_, err := uploader.Upload(ctx, &s3.PutObjectInput{
-			Bucket:       aws.String(fs.config.Bucket),
-			Key:          aws.String(name),
-			Body:         r,
-			ACL:          types.ObjectCannedACL(fs.config.ACL),
-			StorageClass: types.StorageClass(fs.config.StorageClass),
-			ContentType:  util.NilIfEmpty(contentType),
+			Bucket:               aws.String(fs.config.Bucket),
+			Key:                  aws.String(name),
+			Body:                 r,
+			ACL:                  types.ObjectCannedACL(fs.config.ACL),
+			StorageClass:         types.StorageClass(fs.config.StorageClass),
+			ContentType:          util.NilIfEmpty(contentType),
+			SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+			SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+			SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
 		})
 		r.CloseWithError(err) //nolint:errcheck
 		p.Done(err)
@@ -703,12 +732,18 @@ func (fs *S3Fs) copyFileInternal(source, target string, srcInfo os.FileInfo) err
 	defer cancelFn()
 	copyObject := &s3.CopyObjectInput{
-		Bucket:       aws.String(fs.config.Bucket),
-		CopySource:   aws.String(copySource),
-		Key:          aws.String(target),
-		StorageClass: types.StorageClass(fs.config.StorageClass),
-		ACL:          types.ObjectCannedACL(fs.config.ACL),
-		ContentType:  util.NilIfEmpty(contentType),
+		Bucket:                         aws.String(fs.config.Bucket),
+		CopySource:                     aws.String(copySource),
+		Key:                            aws.String(target),
+		StorageClass:                   types.StorageClass(fs.config.StorageClass),
+		ACL:                            types.ObjectCannedACL(fs.config.ACL),
+		ContentType:                    util.NilIfEmpty(contentType),
+		CopySourceSSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+		CopySourceSSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+		CopySourceSSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
+		SSECustomerKey:                 util.NilIfEmpty(fs.sseCustomerKey),
+		SSECustomerAlgorithm:           util.NilIfEmpty(fs.sseCustomerAlgo),
+		SSECustomerKeyMD5:              util.NilIfEmpty(fs.sseCustomerKeyMD5),
 	}
 	metadata := getMetadata(srcInfo)
@@ -812,11 +847,14 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
 	defer cancelFn()
 	res, err := fs.svc.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
-		Bucket:       aws.String(fs.config.Bucket),
-		Key:          aws.String(target),
-		StorageClass: types.StorageClass(fs.config.StorageClass),
-		ACL:          types.ObjectCannedACL(fs.config.ACL),
-		ContentType:  util.NilIfEmpty(contentType),
+		Bucket:               aws.String(fs.config.Bucket),
+		Key:                  aws.String(target),
+		StorageClass:         types.StorageClass(fs.config.StorageClass),
+		ACL:                  types.ObjectCannedACL(fs.config.ACL),
+		ContentType:          util.NilIfEmpty(contentType),
+		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
 	})
 	if err != nil {
 		return fmt.Errorf("unable to create multipart copy request: %w", err)
@@ -871,12 +909,18 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int
 			defer innerCancelFn()
 			partResp, err := fs.svc.UploadPartCopy(innerCtx, &s3.UploadPartCopyInput{
-				Bucket:          aws.String(fs.config.Bucket),
-				CopySource:      aws.String(source),
-				Key:             aws.String(target),
-				PartNumber:      &partNum,
-				UploadId:        aws.String(uploadID),
-				CopySourceRange: aws.String(fmt.Sprintf("bytes=%d-%d", partStart, partEnd-1)),
+				Bucket:                         aws.String(fs.config.Bucket),
+				CopySource:                     aws.String(source),
+				Key:                            aws.String(target),
+				PartNumber:                     &partNum,
+				UploadId:                       aws.String(uploadID),
+				CopySourceRange:                aws.String(fmt.Sprintf("bytes=%d-%d", partStart, partEnd-1)),
+				CopySourceSSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+				CopySourceSSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+				CopySourceSSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
+				SSECustomerKey:                 util.NilIfEmpty(fs.sseCustomerKey),
+				SSECustomerAlgorithm:           util.NilIfEmpty(fs.sseCustomerAlgo),
+				SSECustomerKeyMD5:              util.NilIfEmpty(fs.sseCustomerKeyMD5),
 			})
 			if err != nil {
 				errOnce.Do(func() {
@@ -959,8 +1003,11 @@ func (fs *S3Fs) headObject(name string) (*s3.HeadObjectOutput, error) {
 	defer cancelFn()
 	obj, err := fs.svc.HeadObject(ctx, &s3.HeadObjectInput{
-		Bucket: aws.String(fs.config.Bucket),
-		Key:    aws.String(name),
+		Bucket:               aws.String(fs.config.Bucket),
+		Key:                  aws.String(name),
+		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
 	})
 	metric.S3HeadObjectCompleted(err)
 	return obj, err
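HeadObject needs the key too: S3 rejects HEAD and GET requests for an object encrypted with a customer-provided key unless the matching algorithm, key and key MD5 accompany the request, so even plain metadata lookups would start failing for SSE-C objects without this change. A hypothetical check, not part of the patch, that relies on exactly that behaviour:

// headWithoutKey issues a HeadObject without SSE-C parameters. For an object
// uploaded with a customer-provided key, S3 answers with an error (HTTP 400),
// so a non-nil error here is the expected outcome.
func headWithoutKey(ctx context.Context, client *s3.Client, bucket, key string) error {
	_, err := client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
	})
	return err
}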
@@ -1002,8 +1049,11 @@ func (fs *S3Fs) downloadToWriter(name string, w PipeWriter) (int64, error) {
 	})
 	n, err := downloader.Download(ctx, w, &s3.GetObjectInput{
-		Bucket: aws.String(fs.config.Bucket),
-		Key:    aws.String(name),
+		Bucket:               aws.String(fs.config.Bucket),
+		Key:                  aws.String(name),
+		SSECustomerKey:       util.NilIfEmpty(fs.sseCustomerKey),
+		SSECustomerAlgorithm: util.NilIfEmpty(fs.sseCustomerAlgo),
+		SSECustomerKeyMD5:    util.NilIfEmpty(fs.sseCustomerKeyMD5),
 	})
 	fsLog(fs, logger.LevelDebug, "download before resuming upload completed, path %q size: %d, err: %+v",
 		name, n, err)