postgres provider: add support for load balancing

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2023-03-25 09:29:13 +01:00
parent 354fc9b3d6
commit e17068a76f
8 changed files with 90 additions and 51 deletions

View File

@@ -23,11 +23,13 @@ import (
"database/sql"
"errors"
"fmt"
"net"
"strconv"
"strings"
"time"
// we import pgx here to be able to disable PostgreSQL support using a build tag
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/stdlib"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/version"
@@ -233,25 +235,61 @@ func init() {
}
func initializePGSQLProvider() error {
var err error
dbHandle, err := sql.Open("pgx", getPGSQLConnectionString(false))
if err == nil {
providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d",
getPGSQLConnectionString(true), config.PoolSize)
dbHandle.SetMaxOpenConns(config.PoolSize)
if config.PoolSize > 0 {
dbHandle.SetMaxIdleConns(config.PoolSize)
} else {
dbHandle.SetMaxIdleConns(2)
var dbHandle *sql.DB
if config.TargetSessionAttrs == "any" {
pgxConfig, err := pgx.ParseConfig(getPGSQLConnectionString(false))
if err != nil {
providerLog(logger.LevelError, "error parsing postgres configuration, connection string: %q, error: %v",
getPGSQLConnectionString(true), err)
return err
}
dbHandle.SetConnMaxLifetime(240 * time.Second)
dbHandle.SetConnMaxIdleTime(120 * time.Second)
provider = &PGSQLProvider{dbHandle: dbHandle}
dbHandle = stdlib.OpenDB(*pgxConfig, stdlib.OptionBeforeConnect(stdlib.RandomizeHostOrderFunc))
} else {
providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v",
getPGSQLConnectionString(true), err)
var err error
dbHandle, err = sql.Open("pgx", getPGSQLConnectionString(false))
if err != nil {
providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v",
getPGSQLConnectionString(true), err)
return err
}
}
return err
providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d",
getPGSQLConnectionString(true), config.PoolSize)
dbHandle.SetMaxOpenConns(config.PoolSize)
if config.PoolSize > 0 {
dbHandle.SetMaxIdleConns(config.PoolSize)
} else {
dbHandle.SetMaxIdleConns(2)
}
dbHandle.SetConnMaxLifetime(240 * time.Second)
dbHandle.SetConnMaxIdleTime(120 * time.Second)
provider = &PGSQLProvider{dbHandle: dbHandle}
return nil
}
func getPGSQLHostsAndPorts(configHost string, configPort int) (string, string) {
var hosts, ports []string
defaultPort := strconv.Itoa(configPort)
if defaultPort == "0" {
defaultPort = "5432"
}
for _, hostport := range strings.Split(configHost, ",") {
hostport = strings.TrimSpace(hostport)
if hostport == "" {
continue
}
host, port, err := net.SplitHostPort(hostport)
if err == nil {
hosts = append(hosts, host)
ports = append(ports, port)
} else {
hosts = append(hosts, hostport)
ports = append(ports, defaultPort)
}
}
return strings.Join(hosts, ","), strings.Join(ports, ",")
}
func getPGSQLConnectionString(redactedPwd bool) string {
@@ -261,8 +299,9 @@ func getPGSQLConnectionString(redactedPwd bool) string {
if redactedPwd && password != "" {
password = "[redacted]"
}
connectionString = fmt.Sprintf("host='%s' port=%d dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10",
config.Host, config.Port, config.Name, config.Username, password, getSSLMode())
host, port := getPGSQLHostsAndPorts(config.Host, config.Port)
connectionString = fmt.Sprintf("host='%s' port='%s' dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10",
host, port, config.Name, config.Username, password, getSSLMode())
if config.RootCert != "" {
connectionString += fmt.Sprintf(" sslrootcert='%s'", config.RootCert)
}