mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-08 15:28:05 +03:00
sftpd: ensure to always close idle connections
after the last commit this wasn't the case anymore Completly fixes #169
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pires/go-proxyproto"
|
"github.com/pires/go-proxyproto"
|
||||||
@@ -336,10 +337,48 @@ func (c *Configuration) ExecutePostConnectHook(remoteAddr, protocol string) erro
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SSHConnection defines an ssh connection.
|
||||||
|
// Each SSH connection can open several channels for SFTP or SSH commands
|
||||||
|
type SSHConnection struct {
|
||||||
|
id string
|
||||||
|
conn net.Conn
|
||||||
|
lastActivity int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSSHConnection returns a new SSHConnection
|
||||||
|
func NewSSHConnection(id string, conn net.Conn) *SSHConnection {
|
||||||
|
return &SSHConnection{
|
||||||
|
id: id,
|
||||||
|
conn: conn,
|
||||||
|
lastActivity: time.Now().UnixNano(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID returns the ID for this SSHConnection
|
||||||
|
func (c *SSHConnection) GetID() string {
|
||||||
|
return c.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateLastActivity updates last activity for this connection
|
||||||
|
func (c *SSHConnection) UpdateLastActivity() {
|
||||||
|
atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLastActivity returns the last connection activity
|
||||||
|
func (c *SSHConnection) GetLastActivity() time.Time {
|
||||||
|
return time.Unix(0, atomic.LoadInt64(&c.lastActivity))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying network connection
|
||||||
|
func (c *SSHConnection) Close() error {
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// ActiveConnections holds the currect active connections with the associated transfers
|
// ActiveConnections holds the currect active connections with the associated transfers
|
||||||
type ActiveConnections struct {
|
type ActiveConnections struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
connections []ActiveConnection
|
connections []ActiveConnection
|
||||||
|
sshConnections []*SSHConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveSessions returns the number of active sessions for the given username.
|
// GetActiveSessions returns the number of active sessions for the given username.
|
||||||
@@ -431,9 +470,64 @@ func (conns *ActiveConnections) Close(connectionID string) bool {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddSSHConnection adds a new ssh connection to the active ones
|
||||||
|
func (conns *ActiveConnections) AddSSHConnection(c *SSHConnection) {
|
||||||
|
conns.Lock()
|
||||||
|
defer conns.Unlock()
|
||||||
|
|
||||||
|
conns.sshConnections = append(conns.sshConnections, c)
|
||||||
|
logger.Debug(logSender, c.GetID(), "ssh connection added, num open connections: %v", len(conns.sshConnections))
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveSSHConnection removes a connection from the active ones
|
||||||
|
func (conns *ActiveConnections) RemoveSSHConnection(connectionID string) {
|
||||||
|
conns.Lock()
|
||||||
|
defer conns.Unlock()
|
||||||
|
|
||||||
|
var c *SSHConnection
|
||||||
|
indexToRemove := -1
|
||||||
|
for i, conn := range conns.sshConnections {
|
||||||
|
if conn.GetID() == connectionID {
|
||||||
|
indexToRemove = i
|
||||||
|
c = conn
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if indexToRemove >= 0 {
|
||||||
|
conns.sshConnections[indexToRemove] = conns.sshConnections[len(conns.sshConnections)-1]
|
||||||
|
conns.sshConnections[len(conns.sshConnections)-1] = nil
|
||||||
|
conns.sshConnections = conns.sshConnections[:len(conns.sshConnections)-1]
|
||||||
|
logger.Debug(logSender, c.GetID(), "ssh connection removed, num open ssh connections: %v", len(conns.sshConnections))
|
||||||
|
} else {
|
||||||
|
logger.Warn(logSender, "", "ssh connection to remove with id %#v not found!", connectionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (conns *ActiveConnections) checkIdleConnections() {
|
func (conns *ActiveConnections) checkIdleConnections() {
|
||||||
conns.RLock()
|
conns.RLock()
|
||||||
|
|
||||||
|
for _, sshConn := range conns.sshConnections {
|
||||||
|
idleTime := time.Since(sshConn.GetLastActivity())
|
||||||
|
if idleTime > Config.idleTimeoutAsDuration {
|
||||||
|
// we close the an ssh connection if it has no active connections associated
|
||||||
|
idToMatch := fmt.Sprintf("_%v_", sshConn.GetID())
|
||||||
|
toClose := true
|
||||||
|
for _, conn := range conns.connections {
|
||||||
|
if strings.Contains(conn.GetID(), idToMatch) {
|
||||||
|
toClose = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toClose {
|
||||||
|
defer func(c *SSHConnection) {
|
||||||
|
err := c.Close()
|
||||||
|
logger.Debug(logSender, c.GetID(), "close idle SSH connection, idle time: %v, close err: %v",
|
||||||
|
time.Since(c.GetLastActivity()), err)
|
||||||
|
}(sshConn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, c := range conns.connections {
|
for _, c := range conns.connections {
|
||||||
idleTime := time.Since(c.GetLastActivity())
|
idleTime := time.Since(c.GetLastActivity())
|
||||||
isUnauthenticatedFTPUser := (c.GetProtocol() == ProtocolFTP && len(c.GetUsername()) == 0)
|
isUnauthenticatedFTPUser := (c.GetProtocol() == ProtocolFTP && len(c.GetUsername()) == 0)
|
||||||
@@ -442,7 +536,7 @@ func (conns *ActiveConnections) checkIdleConnections() {
|
|||||||
defer func(conn ActiveConnection, isFTPNoAuth bool) {
|
defer func(conn ActiveConnection, isFTPNoAuth bool) {
|
||||||
err := conn.Disconnect()
|
err := conn.Disconnect()
|
||||||
logger.Debug(conn.GetProtocol(), conn.GetID(), "close idle connection, idle time: %v, username: %#v close err: %v",
|
logger.Debug(conn.GetProtocol(), conn.GetID(), "close idle connection, idle time: %v, username: %#v close err: %v",
|
||||||
idleTime, conn.GetUsername(), err)
|
time.Since(conn.GetLastActivity()), conn.GetUsername(), err)
|
||||||
if isFTPNoAuth {
|
if isFTPNoAuth {
|
||||||
ip := utils.GetIPFromRemoteAddress(c.GetRemoteAddress())
|
ip := utils.GetIPFromRemoteAddress(c.GetRemoteAddress())
|
||||||
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, c.GetProtocol(), "client idle")
|
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, c.GetProtocol(), "client idle")
|
||||||
|
|||||||
@@ -68,6 +68,18 @@ func (c *fakeConnection) GetRemoteAddress() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type customNetConn struct {
|
||||||
|
net.Conn
|
||||||
|
id string
|
||||||
|
isClosed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *customNetConn) Close() error {
|
||||||
|
Connections.RemoveSSHConnection(c.id)
|
||||||
|
c.isClosed = true
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
logfilePath := "common_test.log"
|
logfilePath := "common_test.log"
|
||||||
logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)
|
logger.InitLogger(logfilePath, 5, 1, 28, false, zerolog.DebugLevel)
|
||||||
@@ -168,40 +180,112 @@ func closeDataprovider() error {
|
|||||||
return dataprovider.Close()
|
return dataprovider.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSSHConnections(t *testing.T) {
|
||||||
|
conn1, conn2 := net.Pipe()
|
||||||
|
now := time.Now()
|
||||||
|
sshConn1 := NewSSHConnection("id1", conn1)
|
||||||
|
sshConn2 := NewSSHConnection("id2", conn2)
|
||||||
|
assert.Equal(t, "id1", sshConn1.GetID())
|
||||||
|
assert.Equal(t, "id2", sshConn2.GetID())
|
||||||
|
sshConn1.UpdateLastActivity()
|
||||||
|
assert.GreaterOrEqual(t, sshConn1.GetLastActivity().UnixNano(), now.UnixNano())
|
||||||
|
Connections.AddSSHConnection(sshConn1)
|
||||||
|
Connections.AddSSHConnection(sshConn2)
|
||||||
|
Connections.RLock()
|
||||||
|
assert.Len(t, Connections.sshConnections, 2)
|
||||||
|
Connections.RUnlock()
|
||||||
|
Connections.RemoveSSHConnection(sshConn1.id)
|
||||||
|
Connections.RLock()
|
||||||
|
assert.Len(t, Connections.sshConnections, 1)
|
||||||
|
Connections.RUnlock()
|
||||||
|
Connections.RemoveSSHConnection(sshConn1.id)
|
||||||
|
Connections.RLock()
|
||||||
|
assert.Len(t, Connections.sshConnections, 1)
|
||||||
|
Connections.RUnlock()
|
||||||
|
Connections.RemoveSSHConnection(sshConn2.id)
|
||||||
|
Connections.RLock()
|
||||||
|
assert.Len(t, Connections.sshConnections, 0)
|
||||||
|
Connections.RUnlock()
|
||||||
|
assert.NoError(t, sshConn1.Close())
|
||||||
|
assert.NoError(t, sshConn2.Close())
|
||||||
|
}
|
||||||
|
|
||||||
func TestIdleConnections(t *testing.T) {
|
func TestIdleConnections(t *testing.T) {
|
||||||
configCopy := Config
|
configCopy := Config
|
||||||
|
|
||||||
Config.IdleTimeout = 1
|
Config.IdleTimeout = 1
|
||||||
Initialize(Config)
|
Initialize(Config)
|
||||||
|
|
||||||
|
conn1, conn2 := net.Pipe()
|
||||||
|
customConn1 := &customNetConn{
|
||||||
|
Conn: conn1,
|
||||||
|
id: "id1",
|
||||||
|
}
|
||||||
|
customConn2 := &customNetConn{
|
||||||
|
Conn: conn2,
|
||||||
|
id: "id2",
|
||||||
|
}
|
||||||
|
sshConn1 := NewSSHConnection(customConn1.id, customConn1)
|
||||||
|
sshConn2 := NewSSHConnection(customConn2.id, customConn2)
|
||||||
|
|
||||||
username := "test_user"
|
username := "test_user"
|
||||||
user := dataprovider.User{
|
user := dataprovider.User{
|
||||||
Username: username,
|
Username: username,
|
||||||
}
|
}
|
||||||
c := NewBaseConnection("id1", ProtocolSFTP, user, nil)
|
c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, user, nil)
|
||||||
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
|
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
|
||||||
fakeConn := &fakeConnection{
|
fakeConn := &fakeConnection{
|
||||||
BaseConnection: c,
|
BaseConnection: c,
|
||||||
}
|
}
|
||||||
|
// both ssh connections are expired but they should get removed only
|
||||||
|
// if there is no associated connection
|
||||||
|
sshConn1.lastActivity = c.lastActivity
|
||||||
|
sshConn2.lastActivity = c.lastActivity
|
||||||
|
Connections.AddSSHConnection(sshConn1)
|
||||||
Connections.Add(fakeConn)
|
Connections.Add(fakeConn)
|
||||||
assert.Equal(t, Connections.GetActiveSessions(username), 1)
|
assert.Equal(t, Connections.GetActiveSessions(username), 1)
|
||||||
c = NewBaseConnection("id2", ProtocolFTP, dataprovider.User{}, nil)
|
c = NewBaseConnection(sshConn2.id+"_1", ProtocolSSH, user, nil)
|
||||||
c.lastActivity = time.Now().UnixNano()
|
|
||||||
fakeConn = &fakeConnection{
|
fakeConn = &fakeConnection{
|
||||||
BaseConnection: c,
|
BaseConnection: c,
|
||||||
}
|
}
|
||||||
|
Connections.AddSSHConnection(sshConn2)
|
||||||
Connections.Add(fakeConn)
|
Connections.Add(fakeConn)
|
||||||
assert.Equal(t, Connections.GetActiveSessions(username), 1)
|
assert.Equal(t, Connections.GetActiveSessions(username), 2)
|
||||||
assert.Len(t, Connections.GetStats(), 2)
|
|
||||||
|
cFTP := NewBaseConnection("id2", ProtocolFTP, dataprovider.User{}, nil)
|
||||||
|
cFTP.lastActivity = time.Now().UnixNano()
|
||||||
|
fakeConn = &fakeConnection{
|
||||||
|
BaseConnection: cFTP,
|
||||||
|
}
|
||||||
|
Connections.Add(fakeConn)
|
||||||
|
assert.Equal(t, Connections.GetActiveSessions(username), 2)
|
||||||
|
assert.Len(t, Connections.GetStats(), 3)
|
||||||
|
Connections.RLock()
|
||||||
|
assert.Len(t, Connections.sshConnections, 2)
|
||||||
|
Connections.RUnlock()
|
||||||
|
|
||||||
startIdleTimeoutTicker(100 * time.Millisecond)
|
startIdleTimeoutTicker(100 * time.Millisecond)
|
||||||
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 0 }, 1*time.Second, 200*time.Millisecond)
|
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond)
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
Connections.RLock()
|
||||||
|
defer Connections.RUnlock()
|
||||||
|
return len(Connections.sshConnections) == 1
|
||||||
|
}, 1*time.Second, 200*time.Millisecond)
|
||||||
stopIdleTimeoutTicker()
|
stopIdleTimeoutTicker()
|
||||||
assert.Len(t, Connections.GetStats(), 1)
|
assert.Len(t, Connections.GetStats(), 2)
|
||||||
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
|
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
|
||||||
|
cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
|
||||||
|
sshConn2.lastActivity = c.lastActivity
|
||||||
startIdleTimeoutTicker(100 * time.Millisecond)
|
startIdleTimeoutTicker(100 * time.Millisecond)
|
||||||
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
|
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
Connections.RLock()
|
||||||
|
defer Connections.RUnlock()
|
||||||
|
return len(Connections.sshConnections) == 0
|
||||||
|
}, 1*time.Second, 200*time.Millisecond)
|
||||||
stopIdleTimeoutTicker()
|
stopIdleTimeoutTicker()
|
||||||
|
assert.True(t, customConn1.isClosed)
|
||||||
|
assert.True(t, customConn2.isClosed)
|
||||||
|
|
||||||
Config = configCopy
|
Config = configCopy
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -311,9 +311,14 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||||||
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
|
user.ID, loginType, user.Username, user.HomeDir, remoteAddr.String())
|
||||||
dataprovider.UpdateLastLogin(user) //nolint:errcheck
|
dataprovider.UpdateLastLogin(user) //nolint:errcheck
|
||||||
|
|
||||||
|
sshConnection := common.NewSSHConnection(connectionID, conn)
|
||||||
|
common.Connections.AddSSHConnection(sshConnection)
|
||||||
|
|
||||||
|
defer common.Connections.RemoveSSHConnection(connectionID)
|
||||||
|
|
||||||
go ssh.DiscardRequests(reqs)
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
channelCounter := 0
|
channelCounter := int64(0)
|
||||||
for newChannel := range chans {
|
for newChannel := range chans {
|
||||||
// If its not a session channel we just move on because its not something we
|
// If its not a session channel we just move on because its not something we
|
||||||
// know how to handle at this point.
|
// know how to handle at this point.
|
||||||
@@ -331,9 +336,10 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||||||
}
|
}
|
||||||
|
|
||||||
channelCounter++
|
channelCounter++
|
||||||
|
sshConnection.UpdateLastActivity()
|
||||||
// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
|
// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
|
||||||
// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
|
// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
|
||||||
go func(in <-chan *ssh.Request, counter int) {
|
go func(in <-chan *ssh.Request, counter int64) {
|
||||||
for req := range in {
|
for req := range in {
|
||||||
ok := false
|
ok := false
|
||||||
connID := fmt.Sprintf("%v_%v", connectionID, counter)
|
connID := fmt.Sprintf("%v_%v", connectionID, counter)
|
||||||
@@ -353,7 +359,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||||||
case "exec":
|
case "exec":
|
||||||
// protocol will be set later inside processSSHCommand it could be SSH or SCP
|
// protocol will be set later inside processSSHCommand it could be SSH or SCP
|
||||||
connection := Connection{
|
connection := Connection{
|
||||||
BaseConnection: common.NewBaseConnection(connID, "sshd", user, fs),
|
BaseConnection: common.NewBaseConnection(connID, "sshd_exec", user, fs),
|
||||||
ClientVersion: string(sconn.ClientVersion()),
|
ClientVersion: string(sconn.ClientVersion()),
|
||||||
RemoteAddr: remoteAddr,
|
RemoteAddr: remoteAddr,
|
||||||
channel: channel,
|
channel: channel,
|
||||||
|
|||||||
Reference in New Issue
Block a user