mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 14:50:55 +03:00
make connections lookups constant time
Performance improves if there are many active connections. For a few connections there is a small (unnoticeable) performance degradation Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
@@ -117,6 +117,8 @@ func init() {
|
||||
clients: make(map[string]int),
|
||||
}
|
||||
Connections.perUserConns = make(map[string]int)
|
||||
Connections.mapping = make(map[string]int)
|
||||
Connections.sshMapping = make(map[string]int)
|
||||
}
|
||||
|
||||
// errors definitions
|
||||
@@ -752,7 +754,9 @@ type ActiveConnections struct {
|
||||
transfersCheckStatus atomic.Bool
|
||||
sync.RWMutex
|
||||
connections []ActiveConnection
|
||||
mapping map[string]int
|
||||
sshConnections []*SSHConnection
|
||||
sshMapping map[string]int
|
||||
perUserConns map[string]int
|
||||
}
|
||||
|
||||
@@ -800,9 +804,10 @@ func (conns *ActiveConnections) Add(c ActiveConnection) error {
|
||||
}
|
||||
conns.addUserConnection(username)
|
||||
}
|
||||
conns.mapping[c.GetID()] = len(conns.connections)
|
||||
conns.connections = append(conns.connections, c)
|
||||
metric.UpdateActiveConnectionsSize(len(conns.connections))
|
||||
logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %#v, remote address %#v, num open connections: %v",
|
||||
logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %q, remote address %q, num open connections: %d",
|
||||
c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections))
|
||||
return nil
|
||||
}
|
||||
@@ -815,25 +820,25 @@ func (conns *ActiveConnections) Swap(c ActiveConnection) error {
|
||||
conns.Lock()
|
||||
defer conns.Unlock()
|
||||
|
||||
for idx, conn := range conns.connections {
|
||||
if conn.GetID() == c.GetID() {
|
||||
conns.removeUserConnection(conn.GetUsername())
|
||||
if username := c.GetUsername(); username != "" {
|
||||
if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
|
||||
if val := conns.perUserConns[username]; val >= maxSessions {
|
||||
conns.addUserConnection(conn.GetUsername())
|
||||
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
|
||||
}
|
||||
if idx, ok := conns.mapping[c.GetID()]; ok {
|
||||
conn := conns.connections[idx]
|
||||
conns.removeUserConnection(conn.GetUsername())
|
||||
if username := c.GetUsername(); username != "" {
|
||||
if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
|
||||
if val, ok := conns.perUserConns[username]; ok && val >= maxSessions {
|
||||
conns.addUserConnection(conn.GetUsername())
|
||||
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
|
||||
}
|
||||
conns.addUserConnection(username)
|
||||
}
|
||||
err := conn.CloseFS()
|
||||
conns.connections[idx] = c
|
||||
logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
|
||||
conn = nil
|
||||
return nil
|
||||
conns.addUserConnection(username)
|
||||
}
|
||||
err := conn.CloseFS()
|
||||
conns.connections[idx] = c
|
||||
logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
|
||||
conn = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("connection to swap not found")
|
||||
}
|
||||
|
||||
@@ -842,49 +847,53 @@ func (conns *ActiveConnections) Remove(connectionID string) {
|
||||
conns.Lock()
|
||||
defer conns.Unlock()
|
||||
|
||||
for idx, conn := range conns.connections {
|
||||
if conn.GetID() == connectionID {
|
||||
err := conn.CloseFS()
|
||||
lastIdx := len(conns.connections) - 1
|
||||
conns.connections[idx] = conns.connections[lastIdx]
|
||||
conns.connections[lastIdx] = nil
|
||||
conns.connections = conns.connections[:lastIdx]
|
||||
conns.removeUserConnection(conn.GetUsername())
|
||||
metric.UpdateActiveConnectionsSize(lastIdx)
|
||||
logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
|
||||
conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
|
||||
if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" {
|
||||
ip := util.GetIPFromRemoteAddress(conn.GetRemoteAddress())
|
||||
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, conn.GetProtocol(),
|
||||
dataprovider.ErrNoAuthTryed.Error())
|
||||
metric.AddNoAuthTryed()
|
||||
AddDefenderEvent(ip, HostEventNoLoginTried)
|
||||
dataprovider.ExecutePostLoginHook(&dataprovider.User{}, dataprovider.LoginMethodNoAuthTryed, ip,
|
||||
conn.GetProtocol(), dataprovider.ErrNoAuthTryed)
|
||||
}
|
||||
Config.checkPostDisconnectHook(conn.GetRemoteAddress(), conn.GetProtocol(), conn.GetUsername(),
|
||||
conn.GetID(), conn.GetConnectionTime())
|
||||
return
|
||||
if idx, ok := conns.mapping[connectionID]; ok {
|
||||
conn := conns.connections[idx]
|
||||
err := conn.CloseFS()
|
||||
lastIdx := len(conns.connections) - 1
|
||||
conns.connections[idx] = conns.connections[lastIdx]
|
||||
conns.connections[lastIdx] = nil
|
||||
conns.connections = conns.connections[:lastIdx]
|
||||
delete(conns.mapping, connectionID)
|
||||
if idx != lastIdx {
|
||||
conns.mapping[conns.connections[idx].GetID()] = idx
|
||||
}
|
||||
conns.removeUserConnection(conn.GetUsername())
|
||||
metric.UpdateActiveConnectionsSize(lastIdx)
|
||||
logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
|
||||
conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
|
||||
if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" {
|
||||
ip := util.GetIPFromRemoteAddress(conn.GetRemoteAddress())
|
||||
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, conn.GetProtocol(),
|
||||
dataprovider.ErrNoAuthTryed.Error())
|
||||
metric.AddNoAuthTryed()
|
||||
AddDefenderEvent(ip, HostEventNoLoginTried)
|
||||
dataprovider.ExecutePostLoginHook(&dataprovider.User{}, dataprovider.LoginMethodNoAuthTryed, ip,
|
||||
conn.GetProtocol(), dataprovider.ErrNoAuthTryed)
|
||||
}
|
||||
Config.checkPostDisconnectHook(conn.GetRemoteAddress(), conn.GetProtocol(), conn.GetUsername(),
|
||||
conn.GetID(), conn.GetConnectionTime())
|
||||
return
|
||||
}
|
||||
logger.Warn(logSender, "", "connection id %#v to remove not found!", connectionID)
|
||||
|
||||
logger.Warn(logSender, "", "connection id %q to remove not found!", connectionID)
|
||||
}
|
||||
|
||||
// Close closes an active connection.
|
||||
// It returns true on success
|
||||
func (conns *ActiveConnections) Close(connectionID string) bool {
|
||||
conns.RLock()
|
||||
result := false
|
||||
|
||||
for _, c := range conns.connections {
|
||||
if c.GetID() == connectionID {
|
||||
defer func(conn ActiveConnection) {
|
||||
err := conn.Disconnect()
|
||||
logger.Debug(conn.GetProtocol(), conn.GetID(), "close connection requested, close err: %v", err)
|
||||
}(c)
|
||||
result = true
|
||||
break
|
||||
}
|
||||
var result bool
|
||||
|
||||
if idx, ok := conns.mapping[connectionID]; ok {
|
||||
c := conns.connections[idx]
|
||||
|
||||
defer func(conn ActiveConnection) {
|
||||
err := conn.Disconnect()
|
||||
logger.Debug(conn.GetProtocol(), conn.GetID(), "close connection requested, close err: %v", err)
|
||||
}(c)
|
||||
result = true
|
||||
}
|
||||
|
||||
conns.RUnlock()
|
||||
@@ -896,8 +905,9 @@ func (conns *ActiveConnections) AddSSHConnection(c *SSHConnection) {
|
||||
conns.Lock()
|
||||
defer conns.Unlock()
|
||||
|
||||
conns.sshMapping[c.GetID()] = len(conns.sshConnections)
|
||||
conns.sshConnections = append(conns.sshConnections, c)
|
||||
logger.Debug(logSender, c.GetID(), "ssh connection added, num open connections: %v", len(conns.sshConnections))
|
||||
logger.Debug(logSender, c.GetID(), "ssh connection added, num open connections: %d", len(conns.sshConnections))
|
||||
}
|
||||
|
||||
// RemoveSSHConnection removes a connection from the active ones
|
||||
@@ -905,17 +915,19 @@ func (conns *ActiveConnections) RemoveSSHConnection(connectionID string) {
|
||||
conns.Lock()
|
||||
defer conns.Unlock()
|
||||
|
||||
for idx, conn := range conns.sshConnections {
|
||||
if conn.GetID() == connectionID {
|
||||
lastIdx := len(conns.sshConnections) - 1
|
||||
conns.sshConnections[idx] = conns.sshConnections[lastIdx]
|
||||
conns.sshConnections[lastIdx] = nil
|
||||
conns.sshConnections = conns.sshConnections[:lastIdx]
|
||||
logger.Debug(logSender, conn.GetID(), "ssh connection removed, num open ssh connections: %v", lastIdx)
|
||||
return
|
||||
if idx, ok := conns.sshMapping[connectionID]; ok {
|
||||
lastIdx := len(conns.sshConnections) - 1
|
||||
conns.sshConnections[idx] = conns.sshConnections[lastIdx]
|
||||
conns.sshConnections[lastIdx] = nil
|
||||
conns.sshConnections = conns.sshConnections[:lastIdx]
|
||||
delete(conns.sshMapping, connectionID)
|
||||
if idx != lastIdx {
|
||||
conns.sshMapping[conns.sshConnections[idx].GetID()] = idx
|
||||
}
|
||||
logger.Debug(logSender, connectionID, "ssh connection removed, num open ssh connections: %d", lastIdx)
|
||||
return
|
||||
}
|
||||
logger.Warn(logSender, "", "ssh connection to remove with id %#v not found!", connectionID)
|
||||
logger.Warn(logSender, "", "ssh connection to remove with id %q not found!", connectionID)
|
||||
}
|
||||
|
||||
func (conns *ActiveConnections) checkIdles() {
|
||||
|
||||
Reference in New Issue
Block a user