mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 23:00:55 +03:00
sftpd: refactor multi-step authentication
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
@@ -271,7 +271,7 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
|
||||
MaxAuthTries: c.MaxAuthTries,
|
||||
PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
sp, err := c.validatePublicKeyCredentials(conn, pubKey)
|
||||
if err == ssh.ErrPartialSuccess {
|
||||
if errors.Is(err, &ssh.PartialSuccessError{}) {
|
||||
return sp, err
|
||||
}
|
||||
if err != nil {
|
||||
@@ -281,26 +281,12 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
|
||||
|
||||
return sp, nil
|
||||
},
|
||||
NextAuthMethodsCallback: func(conn ssh.ConnMetadata) []string {
|
||||
var nextMethods []string
|
||||
user, err := dataprovider.GetUserWithGroupSettings(conn.User(), "")
|
||||
if err == nil {
|
||||
nextMethods = user.GetNextAuthMethods(conn.PartialSuccessMethods(), c.PasswordAuthentication)
|
||||
}
|
||||
return nextMethods
|
||||
},
|
||||
ServerVersion: fmt.Sprintf("SSH-2.0-%s", c.Banner),
|
||||
}
|
||||
|
||||
if c.PasswordAuthentication {
|
||||
serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
|
||||
sp, err := c.validatePasswordCredentials(conn, pass)
|
||||
if err != nil {
|
||||
return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err),
|
||||
dataprovider.SSHLoginMethodPassword)
|
||||
}
|
||||
|
||||
return sp, nil
|
||||
serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
return c.validatePasswordCredentials(conn, password, dataprovider.LoginMethodPassword)
|
||||
}
|
||||
serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.LoginMethodPassword)
|
||||
}
|
||||
@@ -544,6 +530,7 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
|
||||
if c.KeyboardInteractiveHook != "" {
|
||||
if !strings.HasPrefix(c.KeyboardInteractiveHook, "http") {
|
||||
if !filepath.IsAbs(c.KeyboardInteractiveHook) {
|
||||
c.KeyboardInteractiveAuthentication = false
|
||||
logger.WarnToConsole("invalid keyboard interactive authentication program: %q must be an absolute path",
|
||||
c.KeyboardInteractiveHook)
|
||||
logger.Warn(logSender, "", "invalid keyboard interactive authentication program: %q must be an absolute path",
|
||||
@@ -552,6 +539,7 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
|
||||
}
|
||||
_, err := os.Stat(c.KeyboardInteractiveHook)
|
||||
if err != nil {
|
||||
c.KeyboardInteractiveAuthentication = false
|
||||
logger.WarnToConsole("invalid keyboard interactive authentication program:: %v", err)
|
||||
logger.Warn(logSender, "", "invalid keyboard interactive authentication program:: %v", err)
|
||||
return
|
||||
@@ -559,13 +547,7 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
|
||||
}
|
||||
}
|
||||
serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||
sp, err := c.validateKeyboardInteractiveCredentials(conn, client)
|
||||
if err != nil {
|
||||
return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err),
|
||||
dataprovider.SSHLoginMethodKeyboardInteractive)
|
||||
}
|
||||
|
||||
return sp, nil
|
||||
return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyboardInteractive)
|
||||
}
|
||||
|
||||
serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive)
|
||||
@@ -817,7 +799,7 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
|
||||
return nil, fmt.Errorf("too many open sessions: %v", activeSessions)
|
||||
}
|
||||
}
|
||||
if !user.IsLoginMethodAllowed(loginMethod, common.ProtocolSSH, conn.PartialSuccessMethods()) {
|
||||
if !user.IsLoginMethodAllowed(loginMethod, common.ProtocolSSH) {
|
||||
logger.Info(logSender, connectionID, "cannot login user %q, login method %q is not allowed",
|
||||
user.Username, loginMethod)
|
||||
return nil, fmt.Errorf("login method %q is not allowed for user %q", loginMethod, user.Username)
|
||||
@@ -1130,6 +1112,21 @@ func (c *Configuration) initializeCertChecker(configDir string) error {
|
||||
return revokedCertManager.load()
|
||||
}
|
||||
|
||||
func (c *Configuration) getPartialSuccessError(nextAuthMethods []string) error {
|
||||
err := &ssh.PartialSuccessError{}
|
||||
if c.PasswordAuthentication && util.Contains(nextAuthMethods, dataprovider.LoginMethodPassword) {
|
||||
err.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
return c.validatePasswordCredentials(conn, password, dataprovider.SSHLoginMethodKeyAndPassword)
|
||||
}
|
||||
}
|
||||
if c.KeyboardInteractiveAuthentication && util.Contains(nextAuthMethods, dataprovider.SSHLoginMethodKeyboardInteractive) {
|
||||
err.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||
return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyAndKeyboardInt)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
var err error
|
||||
var user dataprovider.User
|
||||
@@ -1180,9 +1177,9 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
|
||||
keyID = fmt.Sprintf("%s: ID: %s, serial: %v, CA %s %s", certFingerprint,
|
||||
cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey))
|
||||
}
|
||||
if user.IsPartialAuth(method) {
|
||||
if user.IsPartialAuth() {
|
||||
logger.Debug(logSender, connectionID, "user %q authenticated with partial success", conn.User())
|
||||
return certPerm, ssh.ErrPartialSuccess
|
||||
return certPerm, c.getPartialSuccessError(user.GetNextAuthMethods())
|
||||
}
|
||||
sshPerm, err = loginUser(&user, method, keyID, conn)
|
||||
if err == nil && certPerm != nil {
|
||||
@@ -1201,33 +1198,30 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
|
||||
return sshPerm, err
|
||||
}
|
||||
|
||||
func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
|
||||
func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass []byte, method string) (*ssh.Permissions, error) {
|
||||
var err error
|
||||
var user dataprovider.User
|
||||
var sshPerm *ssh.Permissions
|
||||
|
||||
method := dataprovider.LoginMethodPassword
|
||||
if len(conn.PartialSuccessMethods()) == 1 {
|
||||
method = dataprovider.SSHLoginMethodKeyAndPassword
|
||||
}
|
||||
ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String())
|
||||
if user, err = dataprovider.CheckUserAndPass(conn.User(), string(pass), ipAddr, common.ProtocolSSH); err == nil {
|
||||
sshPerm, err = loginUser(&user, method, "", conn)
|
||||
}
|
||||
user.Username = conn.User()
|
||||
updateLoginMetrics(&user, ipAddr, method, err)
|
||||
return sshPerm, err
|
||||
if err != nil {
|
||||
return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err), method)
|
||||
}
|
||||
return sshPerm, nil
|
||||
}
|
||||
|
||||
func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
|
||||
func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge,
|
||||
method string,
|
||||
) (*ssh.Permissions, error) {
|
||||
var err error
|
||||
var user dataprovider.User
|
||||
var sshPerm *ssh.Permissions
|
||||
|
||||
method := dataprovider.SSHLoginMethodKeyboardInteractive
|
||||
if len(conn.PartialSuccessMethods()) == 1 {
|
||||
method = dataprovider.SSHLoginMethodKeyAndKeyboardInt
|
||||
}
|
||||
ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String())
|
||||
if user, err = dataprovider.CheckKeyboardInteractiveAuth(conn.User(), c.KeyboardInteractiveHook, client,
|
||||
ipAddr, common.ProtocolSSH); err == nil {
|
||||
@@ -1235,7 +1229,10 @@ func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMeta
|
||||
}
|
||||
user.Username = conn.User()
|
||||
updateLoginMetrics(&user, ipAddr, method, err)
|
||||
return sshPerm, err
|
||||
if err != nil {
|
||||
return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err), method)
|
||||
}
|
||||
return sshPerm, nil
|
||||
}
|
||||
|
||||
func updateLoginMetrics(user *dataprovider.User, ip, method string, err error) {
|
||||
|
||||
Reference in New Issue
Block a user