allow to store temporary sessions within the data provider

so we can persist password reset codes, OIDC auth sessions and tokens.
These features will also work in multi-node setups without sicky
sessions now

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2022-05-19 19:49:51 +02:00
parent a87aa9b98e
commit 796ea1dde9
68 changed files with 1501 additions and 730 deletions

View File

@@ -94,7 +94,7 @@ func buildURLRelativeToBase(paths ...string) string {
}
// GetToken tries to return a JWT token
func GetToken(username, password string) (string, map[string]interface{}, error) {
func GetToken(username, password string) (string, map[string]any, error) {
req, err := http.NewRequest(http.MethodGet, buildURLRelativeToBase(tokenPath), nil)
if err != nil {
return "", nil, err
@@ -110,7 +110,7 @@ func GetToken(username, password string) (string, map[string]interface{}, error)
if err != nil {
return "", nil, err
}
responseHolder := make(map[string]interface{})
responseHolder := make(map[string]any)
err = render.DecodeJSON(resp.Body, &responseHolder)
if err != nil {
return "", nil, err
@@ -985,8 +985,8 @@ func RemoveDefenderHostByIP(ip string, expectedStatusCode int) ([]byte, error) {
}
// GetBanTime returns the ban time for the given IP address
func GetBanTime(ip string, expectedStatusCode int) (map[string]interface{}, []byte, error) {
var response map[string]interface{}
func GetBanTime(ip string, expectedStatusCode int) (map[string]any, []byte, error) {
var response map[string]any
var body []byte
url, err := url.Parse(buildURLRelativeToBase(defenderBanTime))
if err != nil {
@@ -1010,8 +1010,8 @@ func GetBanTime(ip string, expectedStatusCode int) (map[string]interface{}, []by
}
// GetScore returns the score for the given IP address
func GetScore(ip string, expectedStatusCode int) (map[string]interface{}, []byte, error) {
var response map[string]interface{}
func GetScore(ip string, expectedStatusCode int) (map[string]any, []byte, error) {
var response map[string]any
var body []byte
url, err := url.Parse(buildURLRelativeToBase(defenderScore))
if err != nil {
@@ -1050,8 +1050,8 @@ func UnbanIP(ip string, expectedStatusCode int) error {
// Dumpdata requests a backup to outputFile.
// outputFile is relative to the configured backups_path
func Dumpdata(outputFile, outputData, indent string, expectedStatusCode int) (map[string]interface{}, []byte, error) {
var response map[string]interface{}
func Dumpdata(outputFile, outputData, indent string, expectedStatusCode int) (map[string]any, []byte, error) {
var response map[string]any
var body []byte
url, err := url.Parse(buildURLRelativeToBase(dumpDataPath))
if err != nil {
@@ -1083,8 +1083,8 @@ func Dumpdata(outputFile, outputData, indent string, expectedStatusCode int) (ma
}
// Loaddata restores a backup.
func Loaddata(inputFile, scanQuota, mode string, expectedStatusCode int) (map[string]interface{}, []byte, error) {
var response map[string]interface{}
func Loaddata(inputFile, scanQuota, mode string, expectedStatusCode int) (map[string]any, []byte, error) {
var response map[string]any
var body []byte
url, err := url.Parse(buildURLRelativeToBase(loadDataPath))
if err != nil {
@@ -1114,8 +1114,8 @@ func Loaddata(inputFile, scanQuota, mode string, expectedStatusCode int) (map[st
}
// LoaddataFromPostBody restores a backup
func LoaddataFromPostBody(data []byte, scanQuota, mode string, expectedStatusCode int) (map[string]interface{}, []byte, error) {
var response map[string]interface{}
func LoaddataFromPostBody(data []byte, scanQuota, mode string, expectedStatusCode int) (map[string]any, []byte, error) {
var response map[string]any
var body []byte
url, err := url.Parse(buildURLRelativeToBase(loadDataPath))
if err != nil {
@@ -1265,7 +1265,7 @@ func checkAdmin(expected, actual *dataprovider.Admin) error {
return errors.New("permissions mismatch")
}
for _, p := range expected.Permissions {
if !util.IsStringInSlice(p, actual.Permissions) {
if !util.Contains(actual.Permissions, p) {
return errors.New("permissions content mismatch")
}
}
@@ -1276,7 +1276,7 @@ func checkAdmin(expected, actual *dataprovider.Admin) error {
return errors.New("allow_api_key_auth mismatch")
}
for _, v := range expected.Filters.AllowList {
if !util.IsStringInSlice(v, actual.Filters.AllowList) {
if !util.Contains(actual.Filters.AllowList, v) {
return errors.New("allow list content mismatch")
}
}
@@ -1350,7 +1350,7 @@ func compareUserPermissions(expected map[string][]string, actual map[string][]st
for dir, perms := range expected {
if actualPerms, ok := actual[dir]; ok {
for _, v := range actualPerms {
if !util.IsStringInSlice(v, perms) {
if !util.Contains(perms, v) {
return errors.New("permissions contents mismatch")
}
}
@@ -1530,7 +1530,7 @@ func compareSFTPFsConfig(expected *vfs.Filesystem, actual *vfs.Filesystem) error
return errors.New("SFTPFs fingerprints mismatch")
}
for _, value := range actual.SFTPConfig.Fingerprints {
if !util.IsStringInSlice(value, expected.SFTPConfig.Fingerprints) {
if !util.Contains(expected.SFTPConfig.Fingerprints, value) {
return errors.New("SFTPFs fingerprints mismatch")
}
}
@@ -1621,27 +1621,27 @@ func checkEncryptedSecret(expected, actual *kms.Secret) error {
func compareUserFilterSubStructs(expected sdk.BaseUserFilters, actual sdk.BaseUserFilters) error {
for _, IPMask := range expected.AllowedIP {
if !util.IsStringInSlice(IPMask, actual.AllowedIP) {
if !util.Contains(actual.AllowedIP, IPMask) {
return errors.New("allowed IP contents mismatch")
}
}
for _, IPMask := range expected.DeniedIP {
if !util.IsStringInSlice(IPMask, actual.DeniedIP) {
if !util.Contains(actual.DeniedIP, IPMask) {
return errors.New("denied IP contents mismatch")
}
}
for _, method := range expected.DeniedLoginMethods {
if !util.IsStringInSlice(method, actual.DeniedLoginMethods) {
if !util.Contains(actual.DeniedLoginMethods, method) {
return errors.New("denied login methods contents mismatch")
}
}
for _, protocol := range expected.DeniedProtocols {
if !util.IsStringInSlice(protocol, actual.DeniedProtocols) {
if !util.Contains(actual.DeniedProtocols, protocol) {
return errors.New("denied protocols contents mismatch")
}
}
for _, options := range expected.WebClient {
if !util.IsStringInSlice(options, actual.WebClient) {
if !util.Contains(actual.WebClient, options) {
return errors.New("web client options contents mismatch")
}
}
@@ -1712,7 +1712,7 @@ func checkFilterMatch(expected []string, actual []string) bool {
return false
}
for _, e := range expected {
if !util.IsStringInSlice(strings.ToLower(e), actual) {
if !util.Contains(actual, strings.ToLower(e)) {
return false
}
}
@@ -1734,7 +1734,7 @@ func compareUserDataTransferLimitFilters(expected sdk.BaseUserFilters, actual sd
return errors.New("data transfer limit total_data_transfer mismatch")
}
for _, source := range actual.DataTransferLimits[idx].Sources {
if !util.IsStringInSlice(source, l.Sources) {
if !util.Contains(l.Sources, source) {
return errors.New("data transfer limit source mismatch")
}
}
@@ -1759,7 +1759,7 @@ func compareUserBandwidthLimitFilters(expected sdk.BaseUserFilters, actual sdk.B
return errors.New("bandwidth filters sources mismatch")
}
for _, source := range actual.BandwidthLimits[idx].Sources {
if !util.IsStringInSlice(source, l.Sources) {
if !util.Contains(l.Sources, source) {
return errors.New("bandwidth filters source mismatch")
}
}