data provider: add config options for certs validation/authentication

Fixes #682

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2022-01-30 18:04:03 +01:00
parent 1df1225eed
commit fb2d59ec92
6 changed files with 101 additions and 17 deletions

View File

@@ -5,15 +5,16 @@ package dataprovider
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"errors"
"fmt"
"os"
"strings"
"time"
// we import go-sql-driver/mysql here to be able to disable MySQL support using a build tag
_ "github.com/go-sql-driver/mysql"
"github.com/go-sql-driver/mysql"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/version"
@@ -114,10 +115,18 @@ func init() {
func initializeMySQLProvider() error {
var err error
dbHandle, err := sql.Open("mysql", getMySQLConnectionString(false))
connString, err := getMySQLConnectionString(false)
if err != nil {
return err
}
redactedConnString, err := getMySQLConnectionString(true)
if err != nil {
return err
}
dbHandle, err := sql.Open("mysql", connString)
if err == nil {
providerLog(logger.LevelDebug, "mysql database handle created, connection string: %#v, pool size: %v",
getMySQLConnectionString(true), config.PoolSize)
redactedConnString, config.PoolSize)
dbHandle.SetMaxOpenConns(config.PoolSize)
if config.PoolSize > 0 {
dbHandle.SetMaxIdleConns(config.PoolSize)
@@ -128,23 +137,58 @@ func initializeMySQLProvider() error {
provider = &MySQLProvider{dbHandle: dbHandle}
} else {
providerLog(logger.LevelError, "error creating mysql database handler, connection string: %#v, error: %v",
getMySQLConnectionString(true), err)
redactedConnString, err)
}
return err
}
func getMySQLConnectionString(redactedPwd bool) string {
func getMySQLConnectionString(redactedPwd bool) (string, error) {
var connectionString string
if config.ConnectionString == "" {
password := config.Password
if redactedPwd {
password = "[redacted]"
}
sslMode := getSSLMode()
if sslMode == "custom" && !redactedPwd {
tlsConfig := &tls.Config{}
if config.RootCert != "" {
rootCAs, err := x509.SystemCertPool()
if err != nil {
rootCAs = x509.NewCertPool()
}
rootCrt, err := os.ReadFile(config.RootCert)
if err != nil {
return "", fmt.Errorf("unable to load root certificate %#v: %v", config.RootCert, err)
}
if !rootCAs.AppendCertsFromPEM(rootCrt) {
return "", fmt.Errorf("unable to parse root certificate %#v", config.RootCert)
}
tlsConfig.RootCAs = rootCAs
}
if config.ClientCert != "" && config.ClientKey != "" {
clientCert := make([]tls.Certificate, 0, 1)
tlsCert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey)
if err != nil {
return "", fmt.Errorf("unable to load key pair %#v, %#v: %v", config.ClientCert, config.ClientKey, err)
}
clientCert = append(clientCert, tlsCert)
tlsConfig.Certificates = clientCert
}
if config.SSLMode == 2 {
tlsConfig.InsecureSkipVerify = true
}
providerLog(logger.LevelInfo, "registering custom TLS config, root cert %#v, client cert %#v, client key %#v",
config.RootCert, config.ClientCert, config.ClientKey)
if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil {
return "", fmt.Errorf("unable to register tls config: %v", err)
}
}
connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=10s&readTimeout=10s",
config.Username, password, config.Host, config.Port, config.Name, getSSLMode())
config.Username, password, config.Host, config.Port, config.Name, sslMode)
} else {
connectionString = config.ConnectionString
}
return connectionString
return connectionString, nil
}
func (p *MySQLProvider) checkAvailability() error {