move IP/Network lists to the data provider

this is a backward incompatible change, all previous file based IP/network
lists will not work anymore

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2023-02-09 09:33:33 +01:00
parent 2412a0a369
commit 1b1745b7f7
103 changed files with 4958 additions and 1284 deletions

View File

@@ -21,6 +21,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/netip"
"runtime/debug"
"strings"
"time"
@@ -34,7 +35,7 @@ import (
)
const (
sqlDatabaseVersion = 26
sqlDatabaseVersion = 27
defaultSQLQueryTimeout = 10 * time.Second
longSQLQueryTimeout = 60 * time.Second
)
@@ -79,6 +80,7 @@ func sqlReplaceAll(sql string) string {
sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
sql = strings.ReplaceAll(sql, "{{nodes}}", sqlTableNodes)
sql = strings.ReplaceAll(sql, "{{roles}}", sqlTableRoles)
sql = strings.ReplaceAll(sql, "{{ip_lists}}", sqlTableIPLists)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sql
}
@@ -538,6 +540,241 @@ func sqlCommonDumpAdmins(dbHandle sqlQuerier) ([]Admin, error) {
return getAdminsWithGroups(ctx, admins, dbHandle)
}
func sqlCommonGetIPListEntry(ipOrNet string, listType IPListType, dbHandle sqlQuerier) (IPListEntry, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getIPListEntryQuery()
row := dbHandle.QueryRowContext(ctx, q, listType, ipOrNet)
return getIPListEntryFromDbRow(row)
}
func sqlCommonDumpIPListEntries(dbHandle *sql.DB) ([]IPListEntry, error) {
count, err := sqlCommonCountIPListEntries(0, dbHandle)
if err != nil {
return nil, err
}
if count > ipListMemoryLimit {
providerLog(logger.LevelInfo, "IP lists excluded from dump, too many entries: %d", count)
return nil, nil
}
entries := make([]IPListEntry, 0, 100)
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel()
q := getDumpListEntriesQuery()
rows, err := dbHandle.QueryContext(ctx, q)
if err != nil {
return entries, err
}
defer rows.Close()
for rows.Next() {
entry, err := getIPListEntryFromDbRow(rows)
if err != nil {
return entries, err
}
entries = append(entries, entry)
}
return entries, rows.Err()
}
func sqlCommonCountIPListEntries(listType IPListType, dbHandle *sql.DB) (int64, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
var q string
var args []any
if listType == 0 {
q = getCountAllIPListEntriesQuery()
} else {
q = getCountIPListEntriesQuery()
args = append(args, listType)
}
var count int64
err := dbHandle.QueryRowContext(ctx, q, args...).Scan(&count)
return count, err
}
func sqlCommonGetIPListEntries(listType IPListType, filter, from, order string, limit int, dbHandle sqlQuerier) ([]IPListEntry, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getIPListEntriesQuery(filter, from, order, limit)
args := []any{listType}
if from != "" {
args = append(args, from)
}
if filter != "" {
args = append(args, filter+"%")
}
if limit > 0 {
args = append(args, limit)
}
entries := make([]IPListEntry, 0, limit)
rows, err := dbHandle.QueryContext(ctx, q, args...)
if err != nil {
return entries, err
}
defer rows.Close()
for rows.Next() {
entry, err := getIPListEntryFromDbRow(rows)
if err != nil {
return entries, err
}
entries = append(entries, entry)
}
return entries, rows.Err()
}
func sqlCommonGetRecentlyUpdatedIPListEntries(after int64, dbHandle sqlQuerier) ([]IPListEntry, error) {
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel()
q := getRecentlyUpdatedIPListQuery()
entries := make([]IPListEntry, 0, 5)
rows, err := dbHandle.QueryContext(ctx, q, after)
if err != nil {
return entries, err
}
defer rows.Close()
for rows.Next() {
entry, err := getIPListEntryFromDbRow(rows)
if err != nil {
return entries, err
}
entries = append(entries, entry)
}
return entries, rows.Err()
}
func sqlCommonGetListEntriesForIP(ip string, listType IPListType, dbHandle sqlQuerier) ([]IPListEntry, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
var rows *sql.Rows
var err error
entries := make([]IPListEntry, 0, 2)
if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
rows, err = dbHandle.QueryContext(ctx, getIPListEntriesForIPQueryPg(), listType, ip)
if err != nil {
return entries, err
}
} else {
ipAddr, err := netip.ParseAddr(ip)
if err != nil {
return entries, fmt.Errorf("invalid ip address %s", ip)
}
var netType int
var ipBytes []byte
if ipAddr.Is4() || ipAddr.Is4In6() {
netType = ipTypeV4
as4 := ipAddr.As4()
ipBytes = as4[:]
} else {
netType = ipTypeV6
as16 := ipAddr.As16()
ipBytes = as16[:]
}
rows, err = dbHandle.QueryContext(ctx, getIPListEntriesForIPQueryNoPg(), listType, netType, ipBytes)
if err != nil {
return entries, err
}
}
defer rows.Close()
for rows.Next() {
entry, err := getIPListEntryFromDbRow(rows)
if err != nil {
return entries, err
}
entries = append(entries, entry)
}
return entries, rows.Err()
}
func sqlCommonAddIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error {
if err := entry.validate(); err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
var err error
q := getAddIPListEntryQuery()
first := entry.getFirst()
last := entry.getLast()
var netType int
if first.Is4() {
netType = ipTypeV4
} else {
netType = ipTypeV6
}
if config.IsShared == 1 {
return sqlCommonExecuteTx(ctx, dbHandle, func(tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, getRemoveSoftDeletedIPListEntryQuery(), entry.Type, entry.IPOrNet)
if err != nil {
return err
}
if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
_, err = tx.ExecContext(ctx, q, entry.Type, entry.IPOrNet, first.String(), last.String(),
netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()),
util.GetTimeAsMsSinceEpoch(time.Now()))
} else {
_, err = tx.ExecContext(ctx, q, entry.Type, entry.IPOrNet, entry.First, entry.Last,
netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()),
util.GetTimeAsMsSinceEpoch(time.Now()))
}
return err
})
}
if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
_, err = dbHandle.ExecContext(ctx, q, entry.Type, entry.IPOrNet, first.String(), last.String(),
netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()),
util.GetTimeAsMsSinceEpoch(time.Now()))
} else {
_, err = dbHandle.ExecContext(ctx, q, entry.Type, entry.IPOrNet, entry.First, entry.Last,
netType, entry.Protocols, entry.Description, entry.Mode, util.GetTimeAsMsSinceEpoch(time.Now()),
util.GetTimeAsMsSinceEpoch(time.Now()))
}
return err
}
func sqlCommonUpdateIPListEntry(entry *IPListEntry, dbHandle *sql.DB) error {
if err := entry.validate(); err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getUpdateIPListEntryQuery()
_, err := dbHandle.ExecContext(ctx, q, entry.Mode, entry.Protocols, entry.Description,
util.GetTimeAsMsSinceEpoch(time.Now()), entry.Type, entry.IPOrNet)
return err
}
func sqlCommonDeleteIPListEntry(entry IPListEntry, softDelete bool, dbHandle *sql.DB) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
q := getDeleteIPListEntryQuery(softDelete)
var args []any
if softDelete {
ts := util.GetTimeAsMsSinceEpoch(time.Now())
args = append(args, ts, ts)
}
args = append(args, entry.Type, entry.IPOrNet)
res, err := dbHandle.ExecContext(ctx, q, args...)
if err != nil {
return err
}
return sqlCommonRequireRowAffected(res)
}
func sqlCommonGetRoleByName(name string, dbHandle sqlQuerier) (Role, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultSQLQueryTimeout)
defer cancel()
@@ -1872,6 +2109,24 @@ func getEventRuleFromDbRow(row sqlScanner) (EventRule, error) {
return rule, nil
}
func getIPListEntryFromDbRow(row sqlScanner) (IPListEntry, error) {
var entry IPListEntry
var description sql.NullString
err := row.Scan(&entry.Type, &entry.IPOrNet, &entry.Mode, &entry.Protocols, &description,
&entry.CreatedAt, &entry.UpdatedAt, &entry.DeletedAt)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return entry, util.NewRecordNotFoundError(err.Error())
}
return entry, err
}
if description.Valid {
entry.Description = description.String
}
return entry, err
}
func getRoleFromDbRow(row sqlScanner) (Role, error) {
var role Role
var description sql.NullString