oidc: allow login if the password method is disabled

isLoggedInWithOIDC returns false before login so we need to add
a specific check

Fixes #1879

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2025-03-29 20:28:17 +01:00
parent cf573fc743
commit d95d773570
7 changed files with 18 additions and 17 deletions

View File

@@ -49,7 +49,7 @@ func getUserConnection(w http.ResponseWriter, r *http.Request) (*Connection, err
connID := xid.New().String()
protocol := getProtocolFromRequest(r)
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return nil, err
}

View File

@@ -732,7 +732,7 @@ func updateLoginMetrics(user *dataprovider.User, loginMethod, ip string, err err
dataprovider.ExecutePostLoginHook(user, loginMethod, ip, protocol, err)
}
func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string, checkSessions bool) error {
func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID string, checkSessions, isOIDCLogin bool) error {
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolHTTP) {
logger.Info(logSender, connectionID, "cannot login user %q, protocol HTTP is not allowed", user.Username)
return util.NewI18nError(
@@ -740,7 +740,7 @@ func checkHTTPClientUser(user *dataprovider.User, r *http.Request, connectionID
util.I18nErrorProtocolForbidden,
)
}
if !isLoggedInWithOIDC(r) && !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP) {
if !isLoggedInWithOIDC(r) && !isOIDCLogin && !user.IsLoginMethodAllowed(dataprovider.LoginMethodPassword, common.ProtocolHTTP) {
logger.Info(logSender, connectionID, "cannot login user %q, password login method is not allowed", user.Username)
return util.NewI18nError(
fmt.Errorf("login method password is not allowed for user %q", user.Username),
@@ -784,7 +784,7 @@ func getActiveUser(username string, r *http.Request) (dataprovider.User, error)
if err := user.CheckLoginConditions(); err != nil {
return user, util.NewRecordNotFoundError(fmt.Sprintf("user %q cannot login: %v", username, err))
}
if err := checkHTTPClientUser(&user, r, xid.New().String(), false); err != nil {
if err := checkHTTPClientUser(&user, r, xid.New().String(), false, false); err != nil {
return user, util.NewRecordNotFoundError(fmt.Sprintf("user %q cannot login: %v", username, err))
}
return user, nil

View File

@@ -551,7 +551,7 @@ func authenticateUserWithAPIKey(username, keyID string, tokenAuth *jwtauth.JWTAu
return err
}
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, true, false); err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r)
return err
}

View File

@@ -396,7 +396,7 @@ func (t *oidcToken) refreshUser(r *http.Request) error {
if err := user.CheckLoginConditions(); err != nil {
return err
}
if err := checkHTTPClientUser(&user, r, xid.New().String(), true); err != nil {
if err := checkHTTPClientUser(&user, r, xid.New().String(), true, false); err != nil {
return err
}
t.Permissions = user.Filters.WebClient
@@ -460,7 +460,7 @@ func (t *oidcToken) getUser(r *http.Request) error {
return err
}
connectionID := fmt.Sprintf("%s_%s", common.ProtocolOIDC, xid.New().String())
if err := checkHTTPClientUser(user, r, connectionID, true); err != nil {
if err := checkHTTPClientUser(user, r, connectionID, true, true); err != nil {
updateLoginMetrics(user, dataprovider.LoginMethodIDP, ipAddr, err, r)
return err
}

View File

@@ -906,7 +906,8 @@ func TestOIDCToken(t *testing.T) {
},
Filters: dataprovider.UserFilters{
BaseUserFilters: sdk.BaseUserFilters{
DeniedProtocols: []string{common.ProtocolHTTP},
DeniedProtocols: []string{common.ProtocolHTTP},
DeniedLoginMethods: []string{dataprovider.LoginMethodPassword},
},
},
}

View File

@@ -273,7 +273,7 @@ func (s *httpdServer) handleWebClientLoginPost(w http.ResponseWriter, r *http.Re
return
}
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, true, false); err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r)
s.renderClientLoginPage(w, r, util.NewI18nError(err, util.I18nError403Message))
return
@@ -312,7 +312,7 @@ func (s *httpdServer) handleWebClientPasswordResetPost(w http.ResponseWriter, r
return
}
connectionID := fmt.Sprintf("%v_%v", getProtocolFromRequest(r), xid.New().String())
if err := checkHTTPClientUser(user, r, connectionID, true); err != nil {
if err := checkHTTPClientUser(user, r, connectionID, true, false); err != nil {
s.renderClientResetPwdPage(w, r, util.NewI18nError(err, util.I18nErrorLoginAfterReset))
return
}
@@ -862,7 +862,7 @@ func (s *httpdServer) getUserToken(w http.ResponseWriter, r *http.Request) {
return
}
connectionID := fmt.Sprintf("%v_%v", protocol, xid.New().String())
if err := checkHTTPClientUser(&user, r, connectionID, true); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, true, false); err != nil {
updateLoginMetrics(&user, dataprovider.LoginMethodPassword, ipAddr, err, r)
sendAPIResponse(w, r, err, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
@@ -1039,7 +1039,7 @@ func (s *httpdServer) refreshClientToken(w http.ResponseWriter, r *http.Request,
logger.Debug(logSender, "", "unable to refresh cookie for user %q: %v", user.Username, err)
return
}
if err := checkHTTPClientUser(&user, r, xid.New().String(), true); err != nil {
if err := checkHTTPClientUser(&user, r, xid.New().String(), true, false); err != nil {
logger.Debug(logSender, "", "unable to refresh cookie for user %q: %v", user.Username, err)
return
}

View File

@@ -902,7 +902,7 @@ func (s *httpdServer) handleWebClientDownloadZip(w http.ResponseWriter, r *http.
connID := xid.New().String()
protocol := getProtocolFromRequest(r)
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
s.renderClientForbiddenPage(w, r, err)
return
}
@@ -1192,7 +1192,7 @@ func (s *httpdServer) handleClientGetDirContents(w http.ResponseWriter, r *http.
connID := xid.New().String()
protocol := getProtocolFromRequest(r)
connectionID := fmt.Sprintf("%s_%s", protocol, connID)
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
sendAPIResponse(w, r, err, getI18NErrorString(err, util.I18nErrorDirList403), http.StatusForbidden)
return
}
@@ -1281,7 +1281,7 @@ func (s *httpdServer) handleClientGetFiles(w http.ResponseWriter, r *http.Reques
connID := xid.New().String()
protocol := getProtocolFromRequest(r)
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
s.renderClientForbiddenPage(w, r, err)
return
}
@@ -1342,7 +1342,7 @@ func (s *httpdServer) handleClientEditFile(w http.ResponseWriter, r *http.Reques
connID := xid.New().String()
protocol := getProtocolFromRequest(r)
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
s.renderClientForbiddenPage(w, r, err)
return
}
@@ -1838,7 +1838,7 @@ func (s *httpdServer) handleClientGetPDF(w http.ResponseWriter, r *http.Request)
connID := xid.New().String()
protocol := getProtocolFromRequest(r)
connectionID := fmt.Sprintf("%v_%v", protocol, connID)
if err := checkHTTPClientUser(&user, r, connectionID, false); err != nil {
if err := checkHTTPClientUser(&user, r, connectionID, false, false); err != nil {
s.renderClientForbiddenPage(w, r, err)
return
}