add support for different bandwidth limits based on client IP

This commit is contained in:
Nicola Murino
2021-12-10 18:43:26 +01:00
parent c153330ab8
commit 0bb141960f
18 changed files with 575 additions and 56 deletions

View File

@@ -732,6 +732,41 @@ func getUserPermissionsFromPostFields(r *http.Request) map[string][]string {
return permissions
}
func getBandwidthLimitsFromPostFields(r *http.Request) ([]sdk.BandwidthLimit, error) {
var result []sdk.BandwidthLimit
for k := range r.Form {
if strings.HasPrefix(k, "bandwidth_limit_sources") {
sources := getSliceFromDelimitedValues(r.Form.Get(k), ",")
if len(sources) > 0 {
bwLimit := sdk.BandwidthLimit{
Sources: sources,
}
idx := strings.TrimPrefix(k, "bandwidth_limit_sources")
ul := r.Form.Get(fmt.Sprintf("upload_bandwidth_source%v", idx))
dl := r.Form.Get(fmt.Sprintf("download_bandwidth_source%v", idx))
if ul != "" {
bandwidthUL, err := strconv.ParseInt(ul, 10, 64)
if err != nil {
return result, fmt.Errorf("invalid upload_bandwidth_source%v %#v: %w", idx, ul, err)
}
bwLimit.UploadBandwidth = bandwidthUL
}
if dl != "" {
bandwidthDL, err := strconv.ParseInt(dl, 10, 64)
if err != nil {
return result, fmt.Errorf("invalid download_bandwidth_source%v %#v: %w", idx, ul, err)
}
bwLimit.DownloadBandwidth = bandwidthDL
}
result = append(result, bwLimit)
}
}
}
return result, nil
}
func getFilePatternsFromPostField(r *http.Request) []sdk.PatternsFilter {
var result []sdk.PatternsFilter
@@ -786,8 +821,13 @@ func getFilePatternsFromPostField(r *http.Request) []sdk.PatternsFilter {
return result
}
func getFiltersFromUserPostFields(r *http.Request) sdk.UserFilters {
func getFiltersFromUserPostFields(r *http.Request) (sdk.UserFilters, error) {
var filters sdk.UserFilters
bwLimits, err := getBandwidthLimitsFromPostFields(r)
if err != nil {
return filters, err
}
filters.BandwidthLimits = bwLimits
filters.AllowedIP = getSliceFromDelimitedValues(r.Form.Get("allowed_ip"), ",")
filters.DeniedIP = getSliceFromDelimitedValues(r.Form.Get("denied_ip"), ",")
filters.DeniedLoginMethods = r.Form["ssh_login_methods"]
@@ -807,7 +847,7 @@ func getFiltersFromUserPostFields(r *http.Request) sdk.UserFilters {
}
filters.DisableFsChecks = len(r.Form.Get("disable_fs_checks")) > 0
filters.AllowAPIKeyAuth = len(r.Form.Get("allow_api_key_auth")) > 0
return filters
return filters, nil
}
func getSecretFromFormField(r *http.Request, field string) *kms.Secret {
@@ -1143,6 +1183,10 @@ func getUserFromPostFields(r *http.Request) (dataprovider.User, error) {
if err != nil {
return user, err
}
filters, err := getFiltersFromUserPostFields(r)
if err != nil {
return user, err
}
user = dataprovider.User{
BaseUser: sdk.BaseUser{
Username: r.Form.Get("username"),
@@ -1160,7 +1204,7 @@ func getUserFromPostFields(r *http.Request) (dataprovider.User, error) {
DownloadBandwidth: bandwidthDL,
Status: status,
ExpirationDate: expirationDateMillis,
Filters: getFiltersFromUserPostFields(r),
Filters: filters,
AdditionalInfo: r.Form.Get("additional_info"),
Description: r.Form.Get("description"),
},