check quota usage between ongoing transfers

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
This commit is contained in:
Nicola Murino
2022-01-20 18:19:20 +01:00
parent d73be7aee5
commit d2a4178846
30 changed files with 1228 additions and 158 deletions

View File

@@ -53,9 +53,10 @@ const (
operationMkdir = "mkdir"
operationRmdir = "rmdir"
// SSH command action name
OperationSSHCmd = "ssh_cmd"
chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
idleTimeoutCheckInterval = 3 * time.Minute
OperationSSHCmd = "ssh_cmd"
chtimesFormat = "2006-01-02T15:04:05" // YYYY-MM-DDTHH:MM:SS
idleTimeoutCheckInterval = 3 * time.Minute
periodicTimeoutCheckInterval = 1 * time.Minute
)
// Stat flags
@@ -110,6 +111,7 @@ var (
ErrCrtRevoked = errors.New("your certificate has been revoked")
ErrNoCredentials = errors.New("no credential provided")
ErrInternalFailure = errors.New("internal failure")
ErrTransferAborted = errors.New("transfer aborted")
errNoTransfer = errors.New("requested transfer not found")
errTransferMismatch = errors.New("transfer mismatch")
)
@@ -120,10 +122,11 @@ var (
// Connections is the list of active connections
Connections ActiveConnections
// QuotaScans is the list of active quota scans
QuotaScans ActiveScans
idleTimeoutTicker *time.Ticker
idleTimeoutTickerDone chan bool
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
QuotaScans ActiveScans
transfersChecker TransfersChecker
periodicTimeoutTicker *time.Ticker
periodicTimeoutTickerDone chan bool
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
ProtocolHTTP, ProtocolHTTPShare}
disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
// the map key is the protocol, for each protocol we can have multiple rate limiters
@@ -135,9 +138,7 @@ func Initialize(c Configuration) error {
Config = c
Config.idleLoginTimeout = 2 * time.Minute
Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
if Config.IdleTimeout > 0 {
startIdleTimeoutTicker(idleTimeoutCheckInterval)
}
startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
Config.defender = nil
rateLimiters = make(map[string][]*rateLimiter)
for _, rlCfg := range c.RateLimitersConfig {
@@ -176,6 +177,7 @@ func Initialize(c Configuration) error {
}
vfs.SetTempPath(c.TempPath)
dataprovider.SetTempPath(c.TempPath)
transfersChecker = getTransfersChecker()
return nil
}
@@ -267,41 +269,52 @@ func AddDefenderEvent(ip string, event HostEvent) {
}
// the ticker cannot be started/stopped from multiple goroutines
func startIdleTimeoutTicker(duration time.Duration) {
stopIdleTimeoutTicker()
idleTimeoutTicker = time.NewTicker(duration)
idleTimeoutTickerDone = make(chan bool)
func startPeriodicTimeoutTicker(duration time.Duration) {
stopPeriodicTimeoutTicker()
periodicTimeoutTicker = time.NewTicker(duration)
periodicTimeoutTickerDone = make(chan bool)
go func() {
counter := int64(0)
ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval
for {
select {
case <-idleTimeoutTickerDone:
case <-periodicTimeoutTickerDone:
return
case <-idleTimeoutTicker.C:
Connections.checkIdles()
case <-periodicTimeoutTicker.C:
counter++
if Config.IdleTimeout > 0 && counter >= int64(ratio) {
counter = 0
Connections.checkIdles()
}
go Connections.checkTransfers()
}
}
}()
}
func stopIdleTimeoutTicker() {
if idleTimeoutTicker != nil {
idleTimeoutTicker.Stop()
idleTimeoutTickerDone <- true
idleTimeoutTicker = nil
func stopPeriodicTimeoutTicker() {
if periodicTimeoutTicker != nil {
periodicTimeoutTicker.Stop()
periodicTimeoutTickerDone <- true
periodicTimeoutTicker = nil
}
}
// ActiveTransfer defines the interface for the current active transfers
type ActiveTransfer interface {
GetID() uint64
GetID() int64
GetType() int
GetSize() int64
GetDownloadedSize() int64
GetUploadedSize() int64
GetVirtualPath() string
GetStartTime() time.Time
SignalClose()
SignalClose(err error)
Truncate(fsPath string, size int64) (int64, error)
GetRealFsPath(fsPath string) string
SetTimes(fsPath string, atime time.Time, mtime time.Time) bool
GetTruncatedSize() int64
GetMaxAllowedSize() int64
}
// ActiveConnection defines the interface for the current active connections
@@ -319,6 +332,7 @@ type ActiveConnection interface {
AddTransfer(t ActiveTransfer)
RemoveTransfer(t ActiveTransfer)
GetTransfers() []ConnectionTransfer
SignalTransferClose(transferID int64, err error)
CloseFS() error
}
@@ -335,11 +349,14 @@ type StatAttributes struct {
// ConnectionTransfer defines the trasfer details to expose
type ConnectionTransfer struct {
ID uint64 `json:"-"`
OperationType string `json:"operation_type"`
StartTime int64 `json:"start_time"`
Size int64 `json:"size"`
VirtualPath string `json:"path"`
ID int64 `json:"-"`
OperationType string `json:"operation_type"`
StartTime int64 `json:"start_time"`
Size int64 `json:"size"`
VirtualPath string `json:"path"`
MaxAllowedSize int64 `json:"-"`
ULSize int64 `json:"-"`
DLSize int64 `json:"-"`
}
func (t *ConnectionTransfer) getConnectionTransferAsString() string {
@@ -653,7 +670,8 @@ func (c *SSHConnection) Close() error {
type ActiveConnections struct {
// clients contains both authenticated and estabilished connections and the ones waiting
// for authentication
clients clientsMap
clients clientsMap
transfersCheckStatus int32
sync.RWMutex
connections []ActiveConnection
sshConnections []*SSHConnection
@@ -825,6 +843,59 @@ func (conns *ActiveConnections) checkIdles() {
conns.RUnlock()
}
func (conns *ActiveConnections) checkTransfers() {
if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
return
}
atomic.StoreInt32(&conns.transfersCheckStatus, 1)
defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
var wg sync.WaitGroup
logger.Debug(logSender, "", "start concurrent transfers check")
conns.RLock()
// update the current size for transfers to monitors
for _, c := range conns.connections {
for _, t := range c.GetTransfers() {
if t.MaxAllowedSize > 0 {
wg.Add(1)
go func(transfer ConnectionTransfer, connID string) {
defer wg.Done()
transfersChecker.UpdateTransferCurrentSize(transfer.ULSize, transfer.DLSize, transfer.ID, connID)
}(t, c.GetID())
}
}
}
conns.RUnlock()
logger.Debug(logSender, "", "waiting for the update of the transfers current size")
wg.Wait()
logger.Debug(logSender, "", "getting overquota transfers")
overquotaTransfers := transfersChecker.GetOverquotaTransfers()
logger.Debug(logSender, "", "number of overquota transfers: %v", len(overquotaTransfers))
if len(overquotaTransfers) == 0 {
return
}
conns.RLock()
defer conns.RUnlock()
for _, c := range conns.connections {
for _, overquotaTransfer := range overquotaTransfers {
if c.GetID() == overquotaTransfer.ConnID {
logger.Info(logSender, c.GetID(), "user %#v is overquota, try to close transfer id %v ",
c.GetUsername(), overquotaTransfer.TransferID)
c.SignalTransferClose(overquotaTransfer.TransferID, getQuotaExceededError(c.GetProtocol()))
}
}
}
logger.Debug(logSender, "", "transfers check completed")
}
// AddClientConnection stores a new client connection
func (conns *ActiveConnections) AddClientConnection(ipAddr string) {
conns.clients.add(ipAddr)