mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 23:00:55 +03:00
@@ -19,7 +19,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
sqlDatabaseVersion = 14
|
||||
sqlDatabaseVersion = 15
|
||||
defaultSQLQueryTimeout = 10 * time.Second
|
||||
longSQLQueryTimeout = 60 * time.Second
|
||||
)
|
||||
@@ -971,6 +971,267 @@ func sqlCommonGetUsers(limit int, offset int, order string, dbHandle sqlQuerier)
|
||||
return getUsersWithVirtualFolders(ctx, users, dbHandle)
|
||||
}
|
||||
|
||||
func sqlCommonGetDefenderHosts(from int64, limit int, dbHandle sqlQuerier) ([]*DefenderEntry, error) {
|
||||
hosts := make([]*DefenderEntry, 0, 100)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDefenderHostsQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
rows, err := stmt.QueryContext(ctx, from, limit)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to get defender hosts: %v", err)
|
||||
return hosts, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var idForScores []int64
|
||||
|
||||
for rows.Next() {
|
||||
var banTime sql.NullInt64
|
||||
host := DefenderEntry{}
|
||||
err = rows.Scan(&host.ID, &host.IP, &banTime)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to scan defender host row: %v", err)
|
||||
return hosts, err
|
||||
}
|
||||
var hostBanTime time.Time
|
||||
if banTime.Valid && banTime.Int64 > 0 {
|
||||
hostBanTime = util.GetTimeFromMsecSinceEpoch(banTime.Int64)
|
||||
}
|
||||
if hostBanTime.IsZero() || hostBanTime.Before(time.Now()) {
|
||||
idForScores = append(idForScores, host.ID)
|
||||
} else {
|
||||
host.BanTime = hostBanTime
|
||||
}
|
||||
hosts = append(hosts, &host)
|
||||
}
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to iterate over defender host rows: %v", err)
|
||||
return hosts, err
|
||||
}
|
||||
|
||||
return getDefenderHostsWithScores(ctx, hosts, from, idForScores, dbHandle)
|
||||
}
|
||||
|
||||
func sqlCommonIsDefenderHostBanned(ip string, dbHandle sqlQuerier) (*DefenderEntry, error) {
|
||||
var host DefenderEntry
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDefenderIsHostBannedQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
row := stmt.QueryRowContext(ctx, ip, util.GetTimeAsMsSinceEpoch(time.Now()))
|
||||
err = row.Scan(&host.ID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, util.NewRecordNotFoundError("host not found")
|
||||
}
|
||||
providerLog(logger.LevelError, "unable to check ban status for host %#v: %v", ip, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &host, nil
|
||||
}
|
||||
|
||||
func sqlCommonGetDefenderHostByIP(ip string, from int64, dbHandle sqlQuerier) (*DefenderEntry, error) {
|
||||
var host DefenderEntry
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDefenderHostQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
row := stmt.QueryRowContext(ctx, ip, from)
|
||||
var banTime sql.NullInt64
|
||||
err = row.Scan(&host.ID, &host.IP, &banTime)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, util.NewRecordNotFoundError("host not found")
|
||||
}
|
||||
providerLog(logger.LevelError, "unable to get host for ip %#v: %v", ip, err)
|
||||
return nil, err
|
||||
}
|
||||
if banTime.Valid && banTime.Int64 > 0 {
|
||||
hostBanTime := util.GetTimeFromMsecSinceEpoch(banTime.Int64)
|
||||
if !hostBanTime.IsZero() && hostBanTime.After(time.Now()) {
|
||||
host.BanTime = hostBanTime
|
||||
return &host, nil
|
||||
}
|
||||
}
|
||||
|
||||
hosts, err := getDefenderHostsWithScores(ctx, []*DefenderEntry{&host}, from, []int64{host.ID}, dbHandle)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(hosts) == 0 {
|
||||
return nil, util.NewRecordNotFoundError("host not found")
|
||||
}
|
||||
|
||||
return hosts[0], nil
|
||||
}
|
||||
|
||||
func sqlCommonDefenderIncrementBanTime(ip string, minutesToAdd int, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getDefenderIncrementBanTimeQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, minutesToAdd*60000, ip)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "ban time updated for ip %#v, increment (minutes): %v",
|
||||
ip, minutesToAdd)
|
||||
} else {
|
||||
providerLog(logger.LevelError, "error updating ban time for ip %#v: %v", ip, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonSetDefenderBanTime(ip string, banTime int64, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getDefenderSetBanTimeQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, banTime, ip)
|
||||
if err == nil {
|
||||
providerLog(logger.LevelDebug, "ip %#v banned until %v", ip, util.GetTimeFromMsecSinceEpoch(banTime))
|
||||
} else {
|
||||
providerLog(logger.LevelError, "error setting ban time for ip %#v: %v", ip, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonDeleteDefenderHost(ip string, dbHandle sqlQuerier) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
q := getDeleteDefenderHostQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, ip)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to delete defender host %#v: %v", ip, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonAddDefenderHostAndEvent(ip string, score int, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
|
||||
if err := sqlCommonAddDefenderHost(ctx, ip, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonAddDefenderEvent(ctx, ip, score, tx)
|
||||
})
|
||||
}
|
||||
|
||||
func sqlCommonDefenderCleanup(from int64, dbHandler *sql.DB) error {
|
||||
if err := sqlCommonCleanupDefenderEvents(from, dbHandler); err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlCommonCleanupDefenderHosts(from, dbHandler)
|
||||
}
|
||||
|
||||
func sqlCommonAddDefenderHost(ctx context.Context, ip string, tx *sql.Tx) error {
|
||||
q := getAddDefenderHostQuery()
|
||||
stmt, err := tx.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, ip, util.GetTimeAsMsSinceEpoch(time.Now()))
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to add defender host %#v: %v", ip, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonAddDefenderEvent(ctx context.Context, ip string, score int, tx *sql.Tx) error {
|
||||
q := getAddDefenderEventQuery()
|
||||
stmt, err := tx.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(time.Now()), score, ip)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to add defender event for %#v: %v", ip, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonCleanupDefenderHosts(from int64, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDefenderHostsCleanupQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, util.GetTimeAsMsSinceEpoch(time.Now()), from)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to cleanup defender hosts: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func sqlCommonCleanupDefenderEvents(from int64, dbHandle *sql.DB) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
|
||||
defer cancel()
|
||||
|
||||
q := getDefenderEventsCleanupQuery()
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
_, err = stmt.ExecContext(ctx, from)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to cleanup defender events: %v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func getShareFromDbRow(row sqlScanner) (Share, error) {
|
||||
var share Share
|
||||
var description, password, allowFrom, paths sql.NullString
|
||||
@@ -1449,6 +1710,61 @@ func getUserWithVirtualFolders(ctx context.Context, user User, dbHandle sqlQueri
|
||||
return users[0], err
|
||||
}
|
||||
|
||||
func getDefenderHostsWithScores(ctx context.Context, hosts []*DefenderEntry, from int64, idForScores []int64,
|
||||
dbHandle sqlQuerier) (
|
||||
[]*DefenderEntry,
|
||||
error,
|
||||
) {
|
||||
if len(idForScores) == 0 {
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
hostsWithScores := make(map[int64]int)
|
||||
q := getDefenderEventsQuery(idForScores)
|
||||
stmt, err := dbHandle.PrepareContext(ctx, q)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error preparing database query %#v: %v", q, err)
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, from)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "unable to get score for hosts with id %+v: %v", idForScores, err)
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var hostID int64
|
||||
var score int
|
||||
err = rows.Scan(&hostID, &score)
|
||||
if err != nil {
|
||||
providerLog(logger.LevelError, "error scanning host score row: %v", err)
|
||||
return hosts, err
|
||||
}
|
||||
if score > 0 {
|
||||
hostsWithScores[hostID] = score
|
||||
}
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return hosts, err
|
||||
}
|
||||
|
||||
result := make([]*DefenderEntry, 0, len(hosts))
|
||||
|
||||
for idx := range hosts {
|
||||
hosts[idx].Score = hostsWithScores[hosts[idx].ID]
|
||||
if hosts[idx].Score > 0 || !hosts[idx].BanTime.IsZero() {
|
||||
result = append(result, hosts[idx])
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func getUsersWithVirtualFolders(ctx context.Context, users []User, dbHandle sqlQuerier) ([]User, error) {
|
||||
if len(users) == 0 {
|
||||
return users, nil
|
||||
|
||||
Reference in New Issue
Block a user