sftpd: refactor multi-step authentication

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2023-08-04 20:56:23 +02:00
parent c03bcb3a8a
commit af0d7b48ad
9 changed files with 195 additions and 231 deletions

View File

@@ -271,7 +271,7 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
MaxAuthTries: c.MaxAuthTries,
PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
sp, err := c.validatePublicKeyCredentials(conn, pubKey)
if err == ssh.ErrPartialSuccess {
if errors.Is(err, &ssh.PartialSuccessError{}) {
return sp, err
}
if err != nil {
@@ -281,26 +281,12 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
return sp, nil
},
NextAuthMethodsCallback: func(conn ssh.ConnMetadata) []string {
var nextMethods []string
user, err := dataprovider.GetUserWithGroupSettings(conn.User(), "")
if err == nil {
nextMethods = user.GetNextAuthMethods(conn.PartialSuccessMethods(), c.PasswordAuthentication)
}
return nextMethods
},
ServerVersion: fmt.Sprintf("SSH-2.0-%s", c.Banner),
}
if c.PasswordAuthentication {
serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
sp, err := c.validatePasswordCredentials(conn, pass)
if err != nil {
return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err),
dataprovider.SSHLoginMethodPassword)
}
return sp, nil
serverConfig.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
return c.validatePasswordCredentials(conn, password, dataprovider.LoginMethodPassword)
}
serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.LoginMethodPassword)
}
@@ -544,6 +530,7 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
if c.KeyboardInteractiveHook != "" {
if !strings.HasPrefix(c.KeyboardInteractiveHook, "http") {
if !filepath.IsAbs(c.KeyboardInteractiveHook) {
c.KeyboardInteractiveAuthentication = false
logger.WarnToConsole("invalid keyboard interactive authentication program: %q must be an absolute path",
c.KeyboardInteractiveHook)
logger.Warn(logSender, "", "invalid keyboard interactive authentication program: %q must be an absolute path",
@@ -552,6 +539,7 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
}
_, err := os.Stat(c.KeyboardInteractiveHook)
if err != nil {
c.KeyboardInteractiveAuthentication = false
logger.WarnToConsole("invalid keyboard interactive authentication program:: %v", err)
logger.Warn(logSender, "", "invalid keyboard interactive authentication program:: %v", err)
return
@@ -559,13 +547,7 @@ func (c *Configuration) configureKeyboardInteractiveAuth(serverConfig *ssh.Serve
}
}
serverConfig.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
sp, err := c.validateKeyboardInteractiveCredentials(conn, client)
if err != nil {
return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err),
dataprovider.SSHLoginMethodKeyboardInteractive)
}
return sp, nil
return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyboardInteractive)
}
serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive)
@@ -817,7 +799,7 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
return nil, fmt.Errorf("too many open sessions: %v", activeSessions)
}
}
if !user.IsLoginMethodAllowed(loginMethod, common.ProtocolSSH, conn.PartialSuccessMethods()) {
if !user.IsLoginMethodAllowed(loginMethod, common.ProtocolSSH) {
logger.Info(logSender, connectionID, "cannot login user %q, login method %q is not allowed",
user.Username, loginMethod)
return nil, fmt.Errorf("login method %q is not allowed for user %q", loginMethod, user.Username)
@@ -1130,6 +1112,21 @@ func (c *Configuration) initializeCertChecker(configDir string) error {
return revokedCertManager.load()
}
func (c *Configuration) getPartialSuccessError(nextAuthMethods []string) error {
err := &ssh.PartialSuccessError{}
if c.PasswordAuthentication && util.Contains(nextAuthMethods, dataprovider.LoginMethodPassword) {
err.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
return c.validatePasswordCredentials(conn, password, dataprovider.SSHLoginMethodKeyAndPassword)
}
}
if c.KeyboardInteractiveAuthentication && util.Contains(nextAuthMethods, dataprovider.SSHLoginMethodKeyboardInteractive) {
err.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyAndKeyboardInt)
}
}
return err
}
func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
var err error
var user dataprovider.User
@@ -1180,9 +1177,9 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
keyID = fmt.Sprintf("%s: ID: %s, serial: %v, CA %s %s", certFingerprint,
cert.KeyId, cert.Serial, cert.Type(), ssh.FingerprintSHA256(cert.SignatureKey))
}
if user.IsPartialAuth(method) {
if user.IsPartialAuth() {
logger.Debug(logSender, connectionID, "user %q authenticated with partial success", conn.User())
return certPerm, ssh.ErrPartialSuccess
return certPerm, c.getPartialSuccessError(user.GetNextAuthMethods())
}
sshPerm, err = loginUser(&user, method, keyID, conn)
if err == nil && certPerm != nil {
@@ -1201,33 +1198,30 @@ func (c *Configuration) validatePublicKeyCredentials(conn ssh.ConnMetadata, pubK
return sshPerm, err
}
func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
func (c *Configuration) validatePasswordCredentials(conn ssh.ConnMetadata, pass []byte, method string) (*ssh.Permissions, error) {
var err error
var user dataprovider.User
var sshPerm *ssh.Permissions
method := dataprovider.LoginMethodPassword
if len(conn.PartialSuccessMethods()) == 1 {
method = dataprovider.SSHLoginMethodKeyAndPassword
}
ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String())
if user, err = dataprovider.CheckUserAndPass(conn.User(), string(pass), ipAddr, common.ProtocolSSH); err == nil {
sshPerm, err = loginUser(&user, method, "", conn)
}
user.Username = conn.User()
updateLoginMetrics(&user, ipAddr, method, err)
return sshPerm, err
if err != nil {
return nil, newAuthenticationError(fmt.Errorf("could not validate password credentials: %w", err), method)
}
return sshPerm, nil
}
func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge,
method string,
) (*ssh.Permissions, error) {
var err error
var user dataprovider.User
var sshPerm *ssh.Permissions
method := dataprovider.SSHLoginMethodKeyboardInteractive
if len(conn.PartialSuccessMethods()) == 1 {
method = dataprovider.SSHLoginMethodKeyAndKeyboardInt
}
ipAddr := util.GetIPFromRemoteAddress(conn.RemoteAddr().String())
if user, err = dataprovider.CheckKeyboardInteractiveAuth(conn.User(), c.KeyboardInteractiveHook, client,
ipAddr, common.ProtocolSSH); err == nil {
@@ -1235,7 +1229,10 @@ func (c *Configuration) validateKeyboardInteractiveCredentials(conn ssh.ConnMeta
}
user.Username = conn.User()
updateLoginMetrics(&user, ipAddr, method, err)
return sshPerm, err
if err != nil {
return nil, newAuthenticationError(fmt.Errorf("could not validate keyboard interactive credentials: %w", err), method)
}
return sshPerm, nil
}
func updateLoginMetrics(user *dataprovider.User, ip, method string, err error) {