kms: improve modularity

This commit is contained in:
Nicola Murino
2021-07-13 21:17:21 +02:00
parent e1a2451c22
commit 776dffcf12
22 changed files with 394 additions and 357 deletions

View File

@@ -7,8 +7,8 @@ import (
"os"
"strings"
"sync"
"time"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
@@ -26,8 +26,13 @@ type SecretProvider interface {
SetKey(string)
SetAdditionalData(string)
SetStatus(SecretStatus)
Clone() SecretProvider
}
const (
logSender = "kms"
)
// SecretStatus defines the statuses of a Secret object
type SecretStatus = string
@@ -51,12 +56,16 @@ const (
SecretStatusRedacted SecretStatus = "Redacted"
)
// Scheme defines the supported URL scheme
type Scheme = string
// supported URL schemes
const (
localProviderName = "Local"
builtinProviderName = "Builtin"
awsProviderName = "AWS"
gcpProviderName = "GCP"
vaultProviderName = "VaultTransit"
SchemeLocal Scheme = "local://"
SchemeBuiltin Scheme = "builtin://"
SchemeAWS Scheme = "awskms://"
SchemeGCP Scheme = "gcpkms://"
SchemeVaultTransit Scheme = "hashivault://"
)
// Configuration defines the KMS configuration
@@ -71,16 +80,32 @@ type Secrets struct {
masterKey string
}
type registeredSecretProvider struct {
encryptedStatus SecretStatus
newFn func(base BaseSecret, url, masterKey string) SecretProvider
}
var (
errWrongSecretStatus = errors.New("wrong secret status")
// ErrWrongSecretStatus defines the error to return if the secret status is not appropriate
// for the request operation
ErrWrongSecretStatus = errors.New("wrong secret status")
// ErrInvalidSecret defines the error to return if a secret is not valid
ErrInvalidSecret = errors.New("invalid secret")
errMalformedCiphertext = errors.New("malformed ciphertext")
errInvalidSecret = errors.New("invalid secret")
validSecretStatuses = []string{SecretStatusPlain, SecretStatusAES256GCM, SecretStatusSecretBox,
SecretStatusVaultTransit, SecretStatusAWS, SecretStatusGCP, SecretStatusRedacted}
config Configuration
defaultTimeout = 10 * time.Second
config Configuration
secretProviders = make(map[string]registeredSecretProvider)
)
// RegisterSecretProvider register a new secret provider
func RegisterSecretProvider(scheme string, encryptedStatus SecretStatus, fn func(base BaseSecret, url, masterKey string) SecretProvider) {
secretProviders[scheme] = registeredSecretProvider{
encryptedStatus: encryptedStatus,
newFn: fn,
}
}
// NewSecret builds a new Secret using the provided arguments
func NewSecret(status SecretStatus, payload, key, data string) *Secret {
return config.newSecret(status, payload, key, data)
@@ -115,11 +140,18 @@ func (c *Configuration) Initialize() error {
c.Secrets.masterKey = strings.TrimSpace(string(mKey))
}
config = *c
if config.Secrets.URL == "" {
config.Secrets.URL = "local://"
}
for k, v := range secretProviders {
logger.Debug(logSender, "", "secret provider registered for scheme: %#v, encrypted status: %#v",
k, v.encryptedStatus)
}
return nil
}
func (c *Configuration) newSecret(status SecretStatus, payload, key, data string) *Secret {
base := baseSecret{
base := BaseSecret{
Status: status,
Key: key,
Payload: payload,
@@ -130,17 +162,13 @@ func (c *Configuration) newSecret(status SecretStatus, payload, key, data string
}
}
func (c *Configuration) getSecretProvider(base baseSecret) SecretProvider {
if strings.HasPrefix(c.Secrets.URL, "hashivault://") {
return newVaultSecret(base, c.Secrets.URL, c.Secrets.masterKey)
func (c *Configuration) getSecretProvider(base BaseSecret) SecretProvider {
for k, v := range secretProviders {
if strings.HasPrefix(c.Secrets.URL, k) {
return v.newFn(base, c.Secrets.URL, c.Secrets.masterKey)
}
}
if strings.HasPrefix(c.Secrets.URL, "awskms://") {
return newAWSSecret(base, c.Secrets.URL, c.Secrets.masterKey)
}
if strings.HasPrefix(c.Secrets.URL, "gcpkms://") {
return newGCPSecret(base, c.Secrets.URL, c.Secrets.masterKey)
}
return newLocalSecret(base, c.Secrets.masterKey)
return NewLocalSecret(base, c.Secrets.URL, c.Secrets.masterKey)
}
// Secret defines the struct used to store confidential data
@@ -154,7 +182,7 @@ func (s *Secret) MarshalJSON() ([]byte, error) {
s.RLock()
defer s.RUnlock()
return json.Marshal(&baseSecret{
return json.Marshal(&BaseSecret{
Status: s.provider.GetStatus(),
Payload: s.provider.GetPayload(),
Key: s.provider.GetKey(),
@@ -169,7 +197,7 @@ func (s *Secret) UnmarshalJSON(data []byte) error {
s.Lock()
defer s.Unlock()
baseSecret := baseSecret{}
baseSecret := BaseSecret{}
err := json.Unmarshal(data, &baseSecret)
if err != nil {
return err
@@ -178,23 +206,21 @@ func (s *Secret) UnmarshalJSON(data []byte) error {
s.provider = config.getSecretProvider(baseSecret)
return nil
}
switch baseSecret.Status {
case SecretStatusAES256GCM:
s.provider = newBuiltinSecret(baseSecret)
case SecretStatusSecretBox:
s.provider = newLocalSecret(baseSecret, config.Secrets.masterKey)
case SecretStatusVaultTransit:
s.provider = newVaultSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
case SecretStatusAWS:
s.provider = newAWSSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
case SecretStatusGCP:
s.provider = newGCPSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
case SecretStatusPlain, SecretStatusRedacted:
if baseSecret.Status == SecretStatusPlain || baseSecret.Status == SecretStatusRedacted {
s.provider = config.getSecretProvider(baseSecret)
default:
return errInvalidSecret
return nil
}
return nil
for _, v := range secretProviders {
if v.encryptedStatus == baseSecret.Status {
s.provider = v.newFn(baseSecret, config.Secrets.URL, config.Secrets.masterKey)
return nil
}
}
logger.Debug(logSender, "", "no provider registered for status %#v", baseSecret.Status)
return ErrInvalidSecret
}
// IsEqual returns true if all the secrets fields are equal
@@ -222,36 +248,9 @@ func (s *Secret) Clone() *Secret {
s.RLock()
defer s.RUnlock()
baseSecret := baseSecret{
Status: s.provider.GetStatus(),
Payload: s.provider.GetPayload(),
Key: s.provider.GetKey(),
AdditionalData: s.provider.GetAdditionalData(),
Mode: s.provider.GetMode(),
return &Secret{
provider: s.provider.Clone(),
}
switch s.provider.Name() {
case builtinProviderName:
return &Secret{
provider: newBuiltinSecret(baseSecret),
}
case awsProviderName:
return &Secret{
provider: newAWSSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey),
}
case gcpProviderName:
return &Secret{
provider: newGCPSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey),
}
case localProviderName:
return &Secret{
provider: newLocalSecret(baseSecret, config.Secrets.masterKey),
}
case vaultProviderName:
return &Secret{
provider: newVaultSecret(baseSecret, config.Secrets.URL, config.Secrets.masterKey),
}
}
return NewSecret(s.GetStatus(), s.GetPayload(), s.GetKey(), s.GetAdditionalData())
}
// IsEncrypted returns true if the secret is encrypted