mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-08 07:10:56 +03:00
check quota usage between ongoing transfers
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
131
common/common.go
131
common/common.go
@@ -53,9 +53,10 @@ const (
|
||||
operationMkdir = "mkdir"
|
||||
operationRmdir = "rmdir"
|
||||
// SSH command action name
|
||||
OperationSSHCmd = "ssh_cmd"
|
||||
chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
|
||||
idleTimeoutCheckInterval = 3 * time.Minute
|
||||
OperationSSHCmd = "ssh_cmd"
|
||||
chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
|
||||
idleTimeoutCheckInterval = 3 * time.Minute
|
||||
periodicTimeoutCheckInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
// Stat flags
|
||||
@@ -110,6 +111,7 @@ var (
|
||||
ErrCrtRevoked = errors.New("your certificate has been revoked")
|
||||
ErrNoCredentials = errors.New("no credential provided")
|
||||
ErrInternalFailure = errors.New("internal failure")
|
||||
ErrTransferAborted = errors.New("transfer aborted")
|
||||
errNoTransfer = errors.New("requested transfer not found")
|
||||
errTransferMismatch = errors.New("transfer mismatch")
|
||||
)
|
||||
@@ -120,10 +122,11 @@ var (
|
||||
// Connections is the list of active connections
|
||||
Connections ActiveConnections
|
||||
// QuotaScans is the list of active quota scans
|
||||
QuotaScans ActiveScans
|
||||
idleTimeoutTicker *time.Ticker
|
||||
idleTimeoutTickerDone chan bool
|
||||
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
|
||||
QuotaScans ActiveScans
|
||||
transfersChecker TransfersChecker
|
||||
periodicTimeoutTicker *time.Ticker
|
||||
periodicTimeoutTickerDone chan bool
|
||||
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
|
||||
ProtocolHTTP, ProtocolHTTPShare}
|
||||
disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
|
||||
// the map key is the protocol, for each protocol we can have multiple rate limiters
|
||||
@@ -135,9 +138,7 @@ func Initialize(c Configuration) error {
|
||||
Config = c
|
||||
Config.idleLoginTimeout = 2 * time.Minute
|
||||
Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
|
||||
if Config.IdleTimeout > 0 {
|
||||
startIdleTimeoutTicker(idleTimeoutCheckInterval)
|
||||
}
|
||||
startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
|
||||
Config.defender = nil
|
||||
rateLimiters = make(map[string][]*rateLimiter)
|
||||
for _, rlCfg := range c.RateLimitersConfig {
|
||||
@@ -176,6 +177,7 @@ func Initialize(c Configuration) error {
|
||||
}
|
||||
vfs.SetTempPath(c.TempPath)
|
||||
dataprovider.SetTempPath(c.TempPath)
|
||||
transfersChecker = getTransfersChecker()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -267,41 +269,52 @@ func AddDefenderEvent(ip string, event HostEvent) {
|
||||
}
|
||||
|
||||
// the ticker cannot be started/stopped from multiple goroutines
|
||||
func startIdleTimeoutTicker(duration time.Duration) {
|
||||
stopIdleTimeoutTicker()
|
||||
idleTimeoutTicker = time.NewTicker(duration)
|
||||
idleTimeoutTickerDone = make(chan bool)
|
||||
func startPeriodicTimeoutTicker(duration time.Duration) {
|
||||
stopPeriodicTimeoutTicker()
|
||||
periodicTimeoutTicker = time.NewTicker(duration)
|
||||
periodicTimeoutTickerDone = make(chan bool)
|
||||
go func() {
|
||||
counter := int64(0)
|
||||
ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval
|
||||
for {
|
||||
select {
|
||||
case <-idleTimeoutTickerDone:
|
||||
case <-periodicTimeoutTickerDone:
|
||||
return
|
||||
case <-idleTimeoutTicker.C:
|
||||
Connections.checkIdles()
|
||||
case <-periodicTimeoutTicker.C:
|
||||
counter++
|
||||
if Config.IdleTimeout > 0 && counter >= int64(ratio) {
|
||||
counter = 0
|
||||
Connections.checkIdles()
|
||||
}
|
||||
go Connections.checkTransfers()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func stopIdleTimeoutTicker() {
|
||||
if idleTimeoutTicker != nil {
|
||||
idleTimeoutTicker.Stop()
|
||||
idleTimeoutTickerDone <- true
|
||||
idleTimeoutTicker = nil
|
||||
func stopPeriodicTimeoutTicker() {
|
||||
if periodicTimeoutTicker != nil {
|
||||
periodicTimeoutTicker.Stop()
|
||||
periodicTimeoutTickerDone <- true
|
||||
periodicTimeoutTicker = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ActiveTransfer defines the interface for the current active transfers
|
||||
type ActiveTransfer interface {
|
||||
GetID() uint64
|
||||
GetID() int64
|
||||
GetType() int
|
||||
GetSize() int64
|
||||
GetDownloadedSize() int64
|
||||
GetUploadedSize() int64
|
||||
GetVirtualPath() string
|
||||
GetStartTime() time.Time
|
||||
SignalClose()
|
||||
SignalClose(err error)
|
||||
Truncate(fsPath string, size int64) (int64, error)
|
||||
GetRealFsPath(fsPath string) string
|
||||
SetTimes(fsPath string, atime time.Time, mtime time.Time) bool
|
||||
GetTruncatedSize() int64
|
||||
GetMaxAllowedSize() int64
|
||||
}
|
||||
|
||||
// ActiveConnection defines the interface for the current active connections
|
||||
@@ -319,6 +332,7 @@ type ActiveConnection interface {
|
||||
AddTransfer(t ActiveTransfer)
|
||||
RemoveTransfer(t ActiveTransfer)
|
||||
GetTransfers() []ConnectionTransfer
|
||||
SignalTransferClose(transferID int64, err error)
|
||||
CloseFS() error
|
||||
}
|
||||
|
||||
@@ -335,11 +349,14 @@ type StatAttributes struct {
|
||||
|
||||
// ConnectionTransfer defines the trasfer details to expose
|
||||
type ConnectionTransfer struct {
|
||||
ID uint64 `json:"-"`
|
||||
OperationType string `json:"operation_type"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
Size int64 `json:"size"`
|
||||
VirtualPath string `json:"path"`
|
||||
ID int64 `json:"-"`
|
||||
OperationType string `json:"operation_type"`
|
||||
StartTime int64 `json:"start_time"`
|
||||
Size int64 `json:"size"`
|
||||
VirtualPath string `json:"path"`
|
||||
MaxAllowedSize int64 `json:"-"`
|
||||
ULSize int64 `json:"-"`
|
||||
DLSize int64 `json:"-"`
|
||||
}
|
||||
|
||||
func (t *ConnectionTransfer) getConnectionTransferAsString() string {
|
||||
@@ -653,7 +670,8 @@ func (c *SSHConnection) Close() error {
|
||||
type ActiveConnections struct {
|
||||
// clients contains both authenticated and estabilished connections and the ones waiting
|
||||
// for authentication
|
||||
clients clientsMap
|
||||
clients clientsMap
|
||||
transfersCheckStatus int32
|
||||
sync.RWMutex
|
||||
connections []ActiveConnection
|
||||
sshConnections []*SSHConnection
|
||||
@@ -825,6 +843,59 @@ func (conns *ActiveConnections) checkIdles() {
|
||||
conns.RUnlock()
|
||||
}
|
||||
|
||||
func (conns *ActiveConnections) checkTransfers() {
|
||||
if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
|
||||
logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
|
||||
return
|
||||
}
|
||||
atomic.StoreInt32(&conns.transfersCheckStatus, 1)
|
||||
defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
logger.Debug(logSender, "", "start concurrent transfers check")
|
||||
conns.RLock()
|
||||
|
||||
// update the current size for transfers to monitors
|
||||
for _, c := range conns.connections {
|
||||
for _, t := range c.GetTransfers() {
|
||||
if t.MaxAllowedSize > 0 {
|
||||
wg.Add(1)
|
||||
|
||||
go func(transfer ConnectionTransfer, connID string) {
|
||||
defer wg.Done()
|
||||
transfersChecker.UpdateTransferCurrentSize(transfer.ULSize, transfer.DLSize, transfer.ID, connID)
|
||||
}(t, c.GetID())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conns.RUnlock()
|
||||
logger.Debug(logSender, "", "waiting for the update of the transfers current size")
|
||||
wg.Wait()
|
||||
|
||||
logger.Debug(logSender, "", "getting overquota transfers")
|
||||
overquotaTransfers := transfersChecker.GetOverquotaTransfers()
|
||||
logger.Debug(logSender, "", "number of overquota transfers: %v", len(overquotaTransfers))
|
||||
if len(overquotaTransfers) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
conns.RLock()
|
||||
defer conns.RUnlock()
|
||||
|
||||
for _, c := range conns.connections {
|
||||
for _, overquotaTransfer := range overquotaTransfers {
|
||||
if c.GetID() == overquotaTransfer.ConnID {
|
||||
logger.Info(logSender, c.GetID(), "user %#v is overquota, try to close transfer id %v ",
|
||||
c.GetUsername(), overquotaTransfer.TransferID)
|
||||
c.SignalTransferClose(overquotaTransfer.TransferID, getQuotaExceededError(c.GetProtocol()))
|
||||
}
|
||||
}
|
||||
}
|
||||
logger.Debug(logSender, "", "transfers check completed")
|
||||
}
|
||||
|
||||
// AddClientConnection stores a new client connection
|
||||
func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
|
||||
conns.clients.add(ipAddr)
|
||||
|
||||
Reference in New Issue
Block a user