mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 14:50:55 +03:00
postgres provider: add support for load balancing
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
@@ -23,11 +23,13 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
// we import pgx here to be able to disable PostgreSQL support using a build tag
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
|
||||
"github.com/drakkan/sftpgo/v2/internal/logger"
|
||||
"github.com/drakkan/sftpgo/v2/internal/version"
|
||||
@@ -233,25 +235,61 @@ func init() {
|
||||
}
|
||||
|
||||
func initializePGSQLProvider() error {
|
||||
var err error
|
||||
dbHandle, err := sql.Open("pgx", getPGSQLConnectionString(false))
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d",
|
||||
getPGSQLConnectionString(true), config.PoolSize)
|
||||
dbHandle.SetMaxOpenConns(config.PoolSize)
|
||||
if config.PoolSize > 0 {
|
||||
dbHandle.SetMaxIdleConns(config.PoolSize)
|
||||
} else {
|
||||
dbHandle.SetMaxIdleConns(2)
|
||||
var dbHandle *sql.DB
|
||||
if config.TargetSessionAttrs == "any" {
|
||||
pgxConfig, err := pgx.ParseConfig(getPGSQLConnectionString(false))
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error parsing postgres configuration, connection string: %q, error: %v",
|
||||
getPGSQLConnectionString(true), err)
|
||||
return err
|
||||
}
|
||||
dbHandle.SetConnMaxLifetime(240 * time.Second)
|
||||
dbHandle.SetConnMaxIdleTime(120 * time.Second)
|
||||
provider = &PGSQLProvider{dbHandle: dbHandle}
|
||||
dbHandle = stdlib.OpenDB(*pgxConfig, stdlib.OptionBeforeConnect(stdlib.RandomizeHostOrderFunc))
|
||||
} else {
|
||||
providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v",
|
||||
getPGSQLConnectionString(true), err)
|
||||
var err error
|
||||
dbHandle, err = sql.Open("pgx", getPGSQLConnectionString(false))
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error creating postgres database handler, connection string: %q, error: %v",
|
||||
getPGSQLConnectionString(true), err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return err
|
||||
providerLog(logger.LevelDebug, "postgres database handle created, connection string: %q, pool size: %d",
|
||||
getPGSQLConnectionString(true), config.PoolSize)
|
||||
dbHandle.SetMaxOpenConns(config.PoolSize)
|
||||
if config.PoolSize > 0 {
|
||||
dbHandle.SetMaxIdleConns(config.PoolSize)
|
||||
} else {
|
||||
dbHandle.SetMaxIdleConns(2)
|
||||
}
|
||||
dbHandle.SetConnMaxLifetime(240 * time.Second)
|
||||
dbHandle.SetConnMaxIdleTime(120 * time.Second)
|
||||
provider = &PGSQLProvider{dbHandle: dbHandle}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPGSQLHostsAndPorts(configHost string, configPort int) (string, string) {
|
||||
var hosts, ports []string
|
||||
defaultPort := strconv.Itoa(configPort)
|
||||
if defaultPort == "0" {
|
||||
defaultPort = "5432"
|
||||
}
|
||||
|
||||
for _, hostport := range strings.Split(configHost, ",") {
|
||||
hostport = strings.TrimSpace(hostport)
|
||||
if hostport == "" {
|
||||
continue
|
||||
}
|
||||
host, port, err := net.SplitHostPort(hostport)
|
||||
if err == nil {
|
||||
hosts = append(hosts, host)
|
||||
ports = append(ports, port)
|
||||
} else {
|
||||
hosts = append(hosts, hostport)
|
||||
ports = append(ports, defaultPort)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(hosts, ","), strings.Join(ports, ",")
|
||||
}
|
||||
|
||||
func getPGSQLConnectionString(redactedPwd bool) string {
|
||||
@@ -261,8 +299,9 @@ func getPGSQLConnectionString(redactedPwd bool) string {
|
||||
if redactedPwd && password != "" {
|
||||
password = "[redacted]"
|
||||
}
|
||||
connectionString = fmt.Sprintf("host='%s' port=%d dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10",
|
||||
config.Host, config.Port, config.Name, config.Username, password, getSSLMode())
|
||||
host, port := getPGSQLHostsAndPorts(config.Host, config.Port)
|
||||
connectionString = fmt.Sprintf("host='%s' port='%s' dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10",
|
||||
host, port, config.Name, config.Username, password, getSSLMode())
|
||||
if config.RootCert != "" {
|
||||
connectionString += fmt.Sprintf(" sslrootcert='%s'", config.RootCert)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user