mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-06 14:20:55 +03:00
move IP/Network lists to the data provider
this is a backward incompatible change, all previous file based IP/network lists will not work anymore Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
@@ -21,6 +21,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -34,7 +35,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
sqlDatabaseVersion = 26
|
||||
sqlDatabaseVersion = 27
|
||||
defaultSQLQueryTimeout = 10 * time.Second
|
||||
longSQLQueryTimeout = 60 * time.Second
|
||||
)
|
||||
@@ -79,6 +80,7 @@ func sqlReplaceAll(sql string) string {
|
||||
sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
|
||||
sql = strings.ReplaceAll(sql, "{{nodes}}", sqlTableNodes)
|
||||
sql = strings.ReplaceAll(sql, "{{roles}}", sqlTableRoles)
|
||||
sql = strings.ReplaceAll(sql, "{{ip_lists}}", sqlTableIPLists)
|
||||
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
|
||||
return sql
|
||||
}
|
||||
@@ -538,6 +540,241 @@ func sqlCommonDumpAdmins(dbHandle sqlQuerier) ([]Admin, error) {
|
||||
return getAdminsWithGroups(ctx, admins, dbHandle)
|
||||
}
|
||||
|
||||
func sqlCommonGetIPListEntry(ipOrNet string, listType IPListType, dbHandle sqlQuerier) (IPListEntry, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getIPListEntryQuery()
|
||||
row := dbHandle.QueryRowContext(ctx, q, listType, ipOrNet)
|
||||
return getIPListEntryFromDbRow(row)
|
||||
}
|
||||
|
||||
func sqlCommonDumpIPListEntries(dbHandle *sql.DB) ([]IPListEntry, error) {
|
||||
count, err := sqlCommonCountIPListEntries(0, dbHandle)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if count > ipListMemoryLimit {
|
||||
providerLog(logger.LevelInfo, "IP lists excluded from dump, too many entries: %d", count)
|
||||
return nil, nil
|
||||
}
|
||||
entries := make([]IPListEntry, 0, 100)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDumpListEntriesQuery()
|
||||
|
||||
rows, err := dbHandle.QueryContext(ctx, q)
|
||||
if err != nil {
|
||||
return entries, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
entry, err := getIPListEntryFromDbRow(rows)
|
||||
if err != nil {
|
||||
return entries, err
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func sqlCommonCountIPListEntries(listType IPListType, dbHandle *sql.DB) (int64, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
var q string
|
||||
var args []any
|
||||
if listType == 0 {
|
||||
q = getCountAllIPListEntriesQuery()
|
||||
} else {
|
||||
q = getCountIPListEntriesQuery()
|
||||
args = append(args, listType)
|
||||
}
|
||||
var count int64
|
||||
err := dbHandle.QueryRowContext(ctx, q, args...).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func sqlCommonGetIPListEntries(listType IPListType, filter, from, order string, limit int, dbHandle sqlQuerier) ([]IPListEntry, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getIPListEntriesQuery(filter, from, order, limit)
|
||||
args := []any{listType}
|
||||
if from != "" {
|
||||
args = append(args, from)
|
||||
}
|
||||
if filter != "" {
|
||||
args = append(args, filter+"%")
|
||||
}
|
||||
if limit > 0 {
|
||||
args = append(args, limit)
|
||||
}
|
||||
entries := make([]IPListEntry, 0, limit)
|
||||
rows, err := dbHandle.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return entries, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
entry, err := getIPListEntryFromDbRow(rows)
|
||||
if err != nil {
|
||||
return entries, err
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func sqlCommonGetRecentlyUpdatedIPListEntries(after int64, dbHandle sqlQuerier) ([]IPListEntry, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getRecentlyUpdatedIPListQuery()
|
||||
entries := make([]IPListEntry, 0, 5)
|
||||
rows, err := dbHandle.QueryContext(ctx, q, after)
|
||||
if err != nil {
|
||||
return entries, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
entry, err := getIPListEntryFromDbRow(rows)
|
||||
if err != nil {
|
||||
return entries, err
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func sqlCommonGetListEntriesForIP(ip string, listType IPListType, dbHandle sqlQuerier) ([]IPListEntry, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
entries := make([]IPListEntry, 0, 2)
|
||||
if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
|
||||
rows, err = dbHandle.QueryContext(ctx, getIPListEntriesForIPQueryPg(), listType, ip)
|
||||
if err != nil {
|
||||
return entries, err
|
||||
}
|
||||
} else {
|
||||
ipAddr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return entries, fmt.Errorf("invalid ip address %s", ip)
|
||||
}
|
||||
var netType int
|
||||
var ipBytes []byte
|
||||
if ipAddr.Is4() || ipAddr.Is4In6() {
|
||||
netType = ipTypeV4
|
||||
as4 := ipAddr.As4()
|
||||
ipBytes = as4[:]
|
||||
} else {
|
||||
netType = ipTypeV6
|
||||
as16 := ipAddr.As16()
|
||||
ipBytes = as16[:]
|
||||
}
|
||||
rows, err = dbHandle.QueryContext(ctx, getIPListEntriesForIPQueryNoPg(), listType, netType, ipBytes)
|
||||
if err != nil {
|
||||
return entries, err
|
||||
}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
entry, err := getIPListEntryFromDbRow(rows)
|
||||
if err != nil {
|
||||
return entries, err
|
||||
}
|
||||
entries = append(entries, entry)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func sqlCommonAddIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error {
|
||||
if err := entry.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
q := getAddIPListEntryQuery()
|
||||
first := entry.getFirst()
|
||||
last := entry.getLast()
|
||||
var netType int
|
||||
if first.Is4() {
|
||||
netType = ipTypeV4
|
||||
} else {
|
||||
netType = ipTypeV6
|
||||
}
|
||||
if config.IsShared == 1 {
|
||||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, getRemoveSoftDeletedIPListEntryQuery(), entry.Type, entry.IPOrNet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
|
||||
_, err = tx.ExecContext(ctx, q, entry.Type, entry.IPOrNet, first.String(), last.String(),
|
||||
netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()))
|
||||
} else {
|
||||
_, err = tx.ExecContext(ctx, q, entry.Type, entry.IPOrNet, entry.First, entry.Last,
|
||||
netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()))
|
||||
}
|
||||
return err
|
||||
})
|
||||
}
|
||||
if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
|
||||
_, err = dbHandle.ExecContext(ctx, q, entry.Type, entry.IPOrNet, first.String(), last.String(),
|
||||
netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()))
|
||||
} else {
|
||||
_, err = dbHandle.ExecContext(ctx, q, entry.Type, entry.IPOrNet, entry.First, entry.Last,
|
||||
netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()),
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonUpdateIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error {
|
||||
if err := entry.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getUpdateIPListEntryQuery()
|
||||
_, err := dbHandle.ExecContext(ctx, q, entry.Mode, entry.Protocols, entry.Description,
|
||||
util.GetTimeAsMsSinceEpoch(time.Now()), entry.Type, entry.IPOrNet)
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonDeleteIPListEntry(entry IPListEntry, softDelete bool, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDeleteIPListEntryQuery(softDelete)
|
||||
var args []any
|
||||
if softDelete {
|
||||
ts := util.GetTimeAsMsSinceEpoch(time.Now())
|
||||
args = append(args, ts, ts)
|
||||
}
|
||||
args = append(args, entry.Type, entry.IPOrNet)
|
||||
res, err := dbHandle.ExecContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonRequireRowAffected(res)
|
||||
}
|
||||
|
||||
func sqlCommonGetRoleByName(name string, dbHandle sqlQuerier) (Role, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
@@ -1872,6 +2109,24 @@ func getEventRuleFromDbRow(row sqlScanner) (EventRule, error) {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func getIPListEntryFromDbRow(row sqlScanner) (IPListEntry, error) {
|
||||
var entry IPListEntry
|
||||
var description sql.NullString
|
||||
|
||||
err := row.Scan(&entry.Type, &entry.IPOrNet, &entry.Mode, &entry.Protocols, &description,
|
||||
&entry.CreatedAt, &entry.UpdatedAt, &entry.DeletedAt)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return entry, util.NewRecordNotFoundError(err.Error())
|
||||
}
|
||||
return entry, err
|
||||
}
|
||||
if description.Valid {
|
||||
entry.Description = description.String
|
||||
}
|
||||
return entry, err
|
||||
}
|
||||
|
||||
func getRoleFromDbRow(row sqlScanner) (Role, error) {
|
||||
var role Role
|
||||
var description sql.NullString
|
||||
|
||||
Reference in New Issue
Block a user