replace utils.Contains with slices.Contains

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2024-07-24 18:27:13 +02:00
parent bd5eb03d9c
commit d94f80c8da
51 changed files with 353 additions and 322 deletions

View File

@@ -24,6 +24,7 @@ import (
"os"
"path/filepath"
"runtime"
"slices"
"testing"
"time"
@@ -418,7 +419,7 @@ func TestSupportedSSHCommands(t *testing.T) {
assert.Equal(t, len(supportedSSHCommands), len(cmds))
for _, c := range cmds {
assert.True(t, util.Contains(supportedSSHCommands, c))
assert.True(t, slices.Contains(supportedSSHCommands, c))
}
}
@@ -842,7 +843,7 @@ func TestRsyncOptions(t *testing.T) {
}
cmd, err := sshCmd.getSystemCommand()
assert.NoError(t, err)
assert.True(t, util.Contains(cmd.cmd.Args, "--safe-links"),
assert.True(t, slices.Contains(cmd.cmd.Args, "--safe-links"),
"--safe-links must be added if the user has the create symlinks permission")
permissions["/"] = []string{dataprovider.PermDownload, dataprovider.PermUpload, dataprovider.PermCreateDirs,
@@ -859,7 +860,7 @@ func TestRsyncOptions(t *testing.T) {
}
cmd, err = sshCmd.getSystemCommand()
assert.NoError(t, err)
assert.True(t, util.Contains(cmd.cmd.Args, "--munge-links"),
assert.True(t, slices.Contains(cmd.cmd.Args, "--munge-links"),
"--munge-links must be added if the user has the create symlinks permission")
sshCmd.connection.User.VirtualFolders = append(sshCmd.connection.User.VirtualFolders, vfs.VirtualFolder{

View File

@@ -26,6 +26,7 @@ import (
"os"
"path/filepath"
"runtime/debug"
"slices"
"strings"
"sync"
"time"
@@ -263,13 +264,13 @@ func (c *Configuration) getServerConfig() *ssh.ServerConfig {
func (c *Configuration) updateSupportedAuthentications() {
serviceStatus.Authentications = util.RemoveDuplicates(serviceStatus.Authentications, false)
if util.Contains(serviceStatus.Authentications, dataprovider.LoginMethodPassword) &&
util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
if slices.Contains(serviceStatus.Authentications, dataprovider.LoginMethodPassword) &&
slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndPassword)
}
if util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) &&
util.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
if slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyboardInteractive) &&
slices.Contains(serviceStatus.Authentications, dataprovider.SSHLoginMethodPublicKey) {
serviceStatus.Authentications = append(serviceStatus.Authentications, dataprovider.SSHLoginMethodKeyAndKeyboardInt)
}
}
@@ -422,7 +423,7 @@ func (c *Configuration) configureKeyAlgos(serverConfig *ssh.ServerConfig) error
c.HostKeyAlgorithms = util.RemoveDuplicates(c.HostKeyAlgorithms, true)
}
for _, hostKeyAlgo := range c.HostKeyAlgorithms {
if !util.Contains(supportedHostKeyAlgos, hostKeyAlgo) {
if !slices.Contains(supportedHostKeyAlgos, hostKeyAlgo) {
return fmt.Errorf("unsupported host key algorithm %q", hostKeyAlgo)
}
}
@@ -430,7 +431,7 @@ func (c *Configuration) configureKeyAlgos(serverConfig *ssh.ServerConfig) error
if len(c.PublicKeyAlgorithms) > 0 {
c.PublicKeyAlgorithms = util.RemoveDuplicates(c.PublicKeyAlgorithms, true)
for _, algo := range c.PublicKeyAlgorithms {
if !util.Contains(supportedPublicKeyAlgos, algo) {
if !slices.Contains(supportedPublicKeyAlgos, algo) {
return fmt.Errorf("unsupported public key authentication algorithm %q", algo)
}
}
@@ -472,7 +473,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig)
if kex == keyExchangeCurve25519SHA256LibSSH {
continue
}
if !util.Contains(supportedKexAlgos, kex) {
if !slices.Contains(supportedKexAlgos, kex) {
return fmt.Errorf("unsupported key-exchange algorithm %q", kex)
}
}
@@ -486,7 +487,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig)
if len(c.Ciphers) > 0 {
c.Ciphers = util.RemoveDuplicates(c.Ciphers, true)
for _, cipher := range c.Ciphers {
if !util.Contains(supportedCiphers, cipher) {
if !slices.Contains(supportedCiphers, cipher) {
return fmt.Errorf("unsupported cipher %q", cipher)
}
}
@@ -499,7 +500,7 @@ func (c *Configuration) configureSecurityOptions(serverConfig *ssh.ServerConfig)
if len(c.MACs) > 0 {
c.MACs = util.RemoveDuplicates(c.MACs, true)
for _, mac := range c.MACs {
if !util.Contains(supportedMACs, mac) {
if !slices.Contains(supportedMACs, mac) {
return fmt.Errorf("unsupported MAC algorithm %q", mac)
}
}
@@ -785,7 +786,7 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
user.Username, user.HomeDir)
return nil, fmt.Errorf("cannot login user with invalid home dir: %q", user.HomeDir)
}
if util.Contains(user.Filters.DeniedProtocols, common.ProtocolSSH) {
if slices.Contains(user.Filters.DeniedProtocols, common.ProtocolSSH) {
logger.Info(logSender, connectionID, "cannot login user %q, protocol SSH is not allowed", user.Username)
return nil, fmt.Errorf("protocol SSH is not allowed for user %q", user.Username)
}
@@ -830,14 +831,14 @@ func loginUser(user *dataprovider.User, loginMethod, publicKey string, conn ssh.
}
func (c *Configuration) checkSSHCommands() {
if util.Contains(c.EnabledSSHCommands, "*") {
if slices.Contains(c.EnabledSSHCommands, "*") {
c.EnabledSSHCommands = GetSupportedSSHCommands()
return
}
sshCommands := []string{}
for _, command := range c.EnabledSSHCommands {
command = strings.TrimSpace(command)
if util.Contains(supportedSSHCommands, command) {
if slices.Contains(supportedSSHCommands, command) {
sshCommands = append(sshCommands, command)
} else {
logger.Warn(logSender, "", "unsupported ssh command: %q ignored", command)
@@ -927,7 +928,7 @@ func (c *Configuration) checkHostKeyAutoGeneration(configDir string) error {
func (c *Configuration) getHostKeyAlgorithms(keyFormat string) []string {
var algos []string
for _, algo := range algorithmsForKeyFormat(keyFormat) {
if util.Contains(c.HostKeyAlgorithms, algo) {
if slices.Contains(c.HostKeyAlgorithms, algo) {
algos = append(algos, algo)
}
}
@@ -986,7 +987,7 @@ func (c *Configuration) checkAndLoadHostKeys(configDir string, serverConfig *ssh
var algos []string
for _, algo := range algorithmsForKeyFormat(signer.PublicKey().Type()) {
if underlyingAlgo, ok := certKeyAlgoNames[algo]; ok {
if util.Contains(mas.Algorithms(), underlyingAlgo) {
if slices.Contains(mas.Algorithms(), underlyingAlgo) {
algos = append(algos, algo)
}
}
@@ -1098,12 +1099,12 @@ func (c *Configuration) initializeCertChecker(configDir string) error {
func (c *Configuration) getPartialSuccessError(nextAuthMethods []string) error {
err := &ssh.PartialSuccessError{}
if c.PasswordAuthentication && util.Contains(nextAuthMethods, dataprovider.LoginMethodPassword) {
if c.PasswordAuthentication && slices.Contains(nextAuthMethods, dataprovider.LoginMethodPassword) {
err.Next.PasswordCallback = func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
return c.validatePasswordCredentials(conn, password, dataprovider.SSHLoginMethodKeyAndPassword)
}
}
if c.KeyboardInteractiveAuthentication && util.Contains(nextAuthMethods, dataprovider.SSHLoginMethodKeyboardInteractive) {
if c.KeyboardInteractiveAuthentication && slices.Contains(nextAuthMethods, dataprovider.SSHLoginMethodKeyboardInteractive) {
err.Next.KeyboardInteractiveCallback = func(conn ssh.ConnMetadata, client ssh.KeyboardInteractiveChallenge) (*ssh.Permissions, error) {
return c.validateKeyboardInteractiveCredentials(conn, client, dataprovider.SSHLoginMethodKeyAndKeyboardInt, true)
}

View File

@@ -38,6 +38,7 @@ import (
"path"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync"
@@ -8639,8 +8640,8 @@ func TestUserAllowedLoginMethods(t *testing.T) {
allowedMethods = user.GetAllowedLoginMethods()
assert.Equal(t, 4, len(allowedMethods))
assert.True(t, util.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndKeyboardInt))
assert.True(t, util.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndPassword))
assert.True(t, slices.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndKeyboardInt))
assert.True(t, slices.Contains(allowedMethods, dataprovider.SSHLoginMethodKeyAndPassword))
}
func TestUserPartialAuth(t *testing.T) {

View File

@@ -27,6 +27,7 @@ import (
"os/exec"
"path"
"runtime/debug"
"slices"
"strings"
"sync"
"time"
@@ -91,7 +92,7 @@ func processSSHCommand(payload []byte, connection *Connection, enabledSSHCommand
name, args, err := parseCommandPayload(msg.Command)
connection.Log(logger.LevelDebug, "new ssh command: %q args: %v num args: %d user: %s, error: %v",
name, args, len(args), connection.User.Username, err)
if err == nil && util.Contains(enabledSSHCommands, name) {
if err == nil && slices.Contains(enabledSSHCommands, name) {
connection.command = msg.Command
if name == scpCmdName && len(args) >= 2 {
connection.SetProtocol(common.ProtocolSCP)
@@ -139,9 +140,9 @@ func (c *sshCommand) handle() (err error) {
defer common.Connections.Remove(c.connection.GetID())
c.connection.UpdateLastActivity()
if util.Contains(sshHashCommands, c.command) {
if slices.Contains(sshHashCommands, c.command) {
return c.handleHashCommands()
} else if util.Contains(systemCommands, c.command) {
} else if slices.Contains(systemCommands, c.command) {
command, err := c.getSystemCommand()
if err != nil {
return c.sendErrorResponse(err)
@@ -429,11 +430,11 @@ func (c *sshCommand) getSystemCommand() (systemCommand, error) {
// If the user cannot create symlinks we add the option --munge-links, if it is not
// already set. This should make symlinks unusable (but manually recoverable)
if c.connection.User.HasPerm(dataprovider.PermCreateSymlinks, c.getDestPath()) {
if !util.Contains(args, "--safe-links") {
if !slices.Contains(args, "--safe-links") {
args = append([]string{"--safe-links"}, args...)
}
} else {
if !util.Contains(args, "--munge-links") {
if !slices.Contains(args, "--munge-links") {
args = append([]string{"--munge-links"}, args...)
}
}