mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-06 22:30:56 +03:00
kms: improve modularity
This commit is contained in:
133
kms/kms.go
133
kms/kms.go
@@ -7,8 +7,8 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/logger"
|
||||
"github.com/drakkan/sftpgo/v2/util"
|
||||
)
|
||||
|
||||
@@ -26,8 +26,13 @@ type SecretProvider interface {
|
||||
SetKey(string)
|
||||
SetAdditionalData(string)
|
||||
SetStatus(SecretStatus)
|
||||
Clone() SecretProvider
|
||||
}
|
||||
|
||||
const (
|
||||
logSender = "kms"
|
||||
)
|
||||
|
||||
// SecretStatus defines the statuses of a Secret object
|
||||
type SecretStatus = string
|
||||
|
||||
@@ -51,12 +56,16 @@ const (
|
||||
SecretStatusRedacted SecretStatus = "Redacted"
|
||||
)
|
||||
|
||||
// Scheme defines the supported URL scheme
|
||||
type Scheme = string
|
||||
|
||||
// supported URL schemes
|
||||
const (
|
||||
localProviderName = "Local"
|
||||
builtinProviderName = "Builtin"
|
||||
awsProviderName = "AWS"
|
||||
gcpProviderName = "GCP"
|
||||
vaultProviderName = "VaultTransit"
|
||||
SchemeLocal Scheme = "local://"
|
||||
SchemeBuiltin Scheme = "builtin://"
|
||||
SchemeAWS Scheme = "awskms://"
|
||||
SchemeGCP Scheme = "gcpkms://"
|
||||
SchemeVaultTransit Scheme = "hashivault://"
|
||||
)
|
||||
|
||||
// Configuration defines the KMS configuration
|
||||
@@ -71,16 +80,32 @@ type Secrets struct {
|
||||
masterKey string
|
||||
}
|
||||
|
||||
type registeredSecretProvider struct {
|
||||
encryptedStatus SecretStatus
|
||||
newFn func(base BaseSecret, url, masterKey string) SecretProvider
|
||||
}
|
||||
|
||||
var (
|
||||
errWrongSecretStatus = errors.New("wrong secret status")
|
||||
// ErrWrongSecretStatus defines the error to return if the secret status is not appropriate
|
||||
// for the request operation
|
||||
ErrWrongSecretStatus = errors.New("wrong secret status")
|
||||
// ErrInvalidSecret defines the error to return if a secret is not valid
|
||||
ErrInvalidSecret = errors.New("invalid secret")
|
||||
errMalformedCiphertext = errors.New("malformed ciphertext")
|
||||
errInvalidSecret = errors.New("invalid secret")
|
||||
validSecretStatuses = []string{SecretStatusPlain, SecretStatusAES256GCM, SecretStatusSecretBox,
|
||||
SecretStatusVaultTransit, SecretStatusAWS, SecretStatusGCP, SecretStatusRedacted}
|
||||
config Configuration
|
||||
defaultTimeout = 10 * time.Second
|
||||
config Configuration
|
||||
secretProviders = make(map[string]registeredSecretProvider)
|
||||
)
|
||||
|
||||
// RegisterSecretProvider register a new secret provider
|
||||
func RegisterSecretProvider(scheme string, encryptedStatus SecretStatus, fn func(base BaseSecret, url, masterKey string) SecretProvider) {
|
||||
secretProviders[scheme] = registeredSecretProvider{
|
||||
encryptedStatus: encryptedStatus,
|
||||
newFn: fn,
|
||||
}
|
||||
}
|
||||
|
||||
// NewSecret builds a new Secret using the provided arguments
|
||||
func NewSecret(status SecretStatus, payload, key, data string) *Secret {
|
||||
return config.newSecret(status, payload, key, data)
|
||||
@@ -115,11 +140,18 @@ func (c *Configuration) Initialize() error {
|
||||
c.Secrets.masterKey = strings.TrimSpace(string(mKey))
|
||||
}
|
||||
config = *c
|
||||
if config.Secrets.URL == "" {
|
||||
config.Secrets.URL = "local://"
|
||||
}
|
||||
for k, v := range secretProviders {
|
||||
logger.Debug(logSender, "", "secret provider registered for scheme: %#v, encrypted status: %#v",
|
||||
k, v.encryptedStatus)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Configuration) newSecret(status SecretStatus, payload, key, data string) *Secret {
|
||||
base := baseSecret{
|
||||
base := BaseSecret{
|
||||
Status: status,
|
||||
Key: key,
|
||||
Payload: payload,
|
||||
@@ -130,17 +162,13 @@ func (c *Configuration) newSecret(status SecretStatus, payload, key, data string
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Configuration) getSecretProvider(base baseSecret) SecretProvider {
|
||||
if strings.HasPrefix(c.Secrets.URL, "hashivault://") {
|
||||
return newVaultSecret(base, c.Secrets.URL, c.Secrets.masterKey)
|
||||
func (c *Configuration) getSecretProvider(base BaseSecret) SecretProvider {
|
||||
for k, v := range secretProviders {
|
||||
if strings.HasPrefix(c.Secrets.URL, k) {
|
||||
return v.newFn(base, c.Secrets.URL, c.Secrets.masterKey)
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Secrets.URL, "awskms://") {
|
||||
return newAWSSecret(base, c.Secrets.URL, c.Secrets.masterKey)
|
||||
}
|
||||
if strings.HasPrefix(c.Secrets.URL, "gcpkms://") {
|
||||
return newGCPSecret(base, c.Secrets.URL, c.Secrets.masterKey)
|
||||
}
|
||||
return newLocalSecret(base, c.Secrets.masterKey)
|
||||
return NewLocalSecret(base, c.Secrets.URL, c.Secrets.masterKey)
|
||||
}
|
||||
|
||||
// Secret defines the struct used to store confidential data
|
||||
@@ -154,7 +182,7 @@ func (s *Secret) MarshalJSON() ([]byte, error) {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
return json.Marshal(&baseSecret{
|
||||
return json.Marshal(&BaseSecret{
|
||||
Status: s.provider.GetStatus(),
|
||||
Payload: s.provider.GetPayload(),
|
||||
Key: s.provider.GetKey(),
|
||||
@@ -169,7 +197,7 @@ func (s *Secret) UnmarshalJSON(data []byte) error {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
baseSecret := baseSecret{}
|
||||
baseSecret := BaseSecret{}
|
||||
err := json.Unmarshal(data, &baseSecret)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -178,23 +206,21 @@ func (s *Secret) UnmarshalJSON(data []byte) error {
|
||||
s.provider = config.getSecretProvider(baseSecret)
|
||||
return nil
|
||||
}
|
||||
switch baseSecret.Status {
|
||||
case SecretStatusAES256GCM:
|
||||
s.provider = newBuiltinSecret(baseSecret)
|
||||
case SecretStatusSecretBox:
|
||||
s.provider = newLocalSecret(baseSecret, config.Secrets.masterKey)
|
||||
case SecretStatusVaultTransit:
|
||||
s.provider = newVaultSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
|
||||
case SecretStatusAWS:
|
||||
s.provider = newAWSSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
|
||||
case SecretStatusGCP:
|
||||
s.provider = newGCPSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
|
||||
case SecretStatusPlain, SecretStatusRedacted:
|
||||
|
||||
if baseSecret.Status == SecretStatusPlain || baseSecret.Status == SecretStatusRedacted {
|
||||
s.provider = config.getSecretProvider(baseSecret)
|
||||
default:
|
||||
return errInvalidSecret
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
for _, v := range secretProviders {
|
||||
if v.encryptedStatus == baseSecret.Status {
|
||||
s.provider = v.newFn(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
logger.Debug(logSender, "", "no provider registered for status %#v", baseSecret.Status)
|
||||
|
||||
return ErrInvalidSecret
|
||||
}
|
||||
|
||||
// IsEqual returns true if all the secrets fields are equal
|
||||
@@ -222,36 +248,9 @@ func (s *Secret) Clone() *Secret {
|
||||
s.RLock()
|
||||
defer s.RUnlock()
|
||||
|
||||
baseSecret := baseSecret{
|
||||
Status: s.provider.GetStatus(),
|
||||
Payload: s.provider.GetPayload(),
|
||||
Key: s.provider.GetKey(),
|
||||
AdditionalData: s.provider.GetAdditionalData(),
|
||||
Mode: s.provider.GetMode(),
|
||||
return &Secret{
|
||||
provider: s.provider.Clone(),
|
||||
}
|
||||
switch s.provider.Name() {
|
||||
case builtinProviderName:
|
||||
return &Secret{
|
||||
provider: newBuiltinSecret(baseSecret),
|
||||
}
|
||||
case awsProviderName:
|
||||
return &Secret{
|
||||
provider: newAWSSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey),
|
||||
}
|
||||
case gcpProviderName:
|
||||
return &Secret{
|
||||
provider: newGCPSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey),
|
||||
}
|
||||
case localProviderName:
|
||||
return &Secret{
|
||||
provider: newLocalSecret(baseSecret, config.Secrets.masterKey),
|
||||
}
|
||||
case vaultProviderName:
|
||||
return &Secret{
|
||||
provider: newVaultSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey),
|
||||
}
|
||||
}
|
||||
return NewSecret(s.GetStatus(), s.GetPayload(), s.GetKey(), s.GetAdditionalData())
|
||||
}
|
||||
|
||||
// IsEncrypted returns true if the secret is encrypted
|
||||
|
||||
Reference in New Issue
Block a user