mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 23:00:55 +03:00
refactoring of user session counters
Fixes #792 Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
@@ -98,6 +98,7 @@ func init() {
|
||||
Connections.clients = clientsMap{
|
||||
clients: make(map[string]int),
|
||||
}
|
||||
Connections.perUserConns = make(map[string]int)
|
||||
}
|
||||
|
||||
// errors definitions
|
||||
@@ -345,6 +346,7 @@ type ActiveTransfer interface {
|
||||
type ActiveConnection interface {
|
||||
GetID() string
|
||||
GetUsername() string
|
||||
GetMaxSessions() int
|
||||
GetLocalAddress() string
|
||||
GetRemoteAddress() string
|
||||
GetClientVersion() string
|
||||
@@ -733,6 +735,29 @@ type ActiveConnections struct {
|
||||
sync.RWMutex
|
||||
connections []ActiveConnection
|
||||
sshConnections []*SSHConnection
|
||||
perUserConns map[string]int
|
||||
}
|
||||
|
||||
// internal method, must be called within a locked block
|
||||
func (conns *ActiveConnections) addUserConnection(username string) {
|
||||
if username == "" {
|
||||
return
|
||||
}
|
||||
conns.perUserConns[username]++
|
||||
}
|
||||
|
||||
// internal method, must be called within a locked block
|
||||
func (conns *ActiveConnections) removeUserConnection(username string) {
|
||||
if username == "" {
|
||||
return
|
||||
}
|
||||
if val, ok := conns.perUserConns[username]; ok {
|
||||
conns.perUserConns[username]--
|
||||
if val > 1 {
|
||||
return
|
||||
}
|
||||
delete(conns.perUserConns, username)
|
||||
}
|
||||
}
|
||||
|
||||
// GetActiveSessions returns the number of active sessions for the given username.
|
||||
@@ -741,24 +766,27 @@ func (conns *ActiveConnections) GetActiveSessions(username string) int {
|
||||
conns.RLock()
|
||||
defer conns.RUnlock()
|
||||
|
||||
numSessions := 0
|
||||
for _, c := range conns.connections {
|
||||
if c.GetUsername() == username {
|
||||
numSessions++
|
||||
}
|
||||
}
|
||||
return numSessions
|
||||
return conns.perUserConns[username]
|
||||
}
|
||||
|
||||
// Add adds a new connection to the active ones
|
||||
func (conns *ActiveConnections) Add(c ActiveConnection) {
|
||||
func (conns *ActiveConnections) Add(c ActiveConnection) error {
|
||||
conns.Lock()
|
||||
defer conns.Unlock()
|
||||
|
||||
if username := c.GetUsername(); username != "" {
|
||||
if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
|
||||
if val := conns.perUserConns[username]; val >= maxSessions {
|
||||
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
|
||||
}
|
||||
}
|
||||
conns.addUserConnection(username)
|
||||
}
|
||||
conns.connections = append(conns.connections, c)
|
||||
metric.UpdateActiveConnectionsSize(len(conns.connections))
|
||||
logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %#v, remote address %#v, num open connections: %v",
|
||||
c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Swap replaces an existing connection with the given one.
|
||||
@@ -771,6 +799,16 @@ func (conns *ActiveConnections) Swap(c ActiveConnection) error {
|
||||
|
||||
for idx, conn := range conns.connections {
|
||||
if conn.GetID() == c.GetID() {
|
||||
conns.removeUserConnection(conn.GetUsername())
|
||||
if username := c.GetUsername(); username != "" {
|
||||
if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
|
||||
if val := conns.perUserConns[username]; val >= maxSessions {
|
||||
conns.addUserConnection(conn.GetUsername())
|
||||
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
|
||||
}
|
||||
}
|
||||
conns.addUserConnection(username)
|
||||
}
|
||||
err := conn.CloseFS()
|
||||
conns.connections[idx] = c
|
||||
logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
|
||||
@@ -793,6 +831,7 @@ func (conns *ActiveConnections) Remove(connectionID string) {
|
||||
conns.connections[idx] = conns.connections[lastIdx]
|
||||
conns.connections[lastIdx] = nil
|
||||
conns.connections = conns.connections[:lastIdx]
|
||||
conns.removeUserConnection(conn.GetUsername())
|
||||
metric.UpdateActiveConnectionsSize(lastIdx)
|
||||
logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
|
||||
conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
|
||||
|
||||
Reference in New Issue
Block a user