diff --git a/internal/common/clientsmap.go b/internal/common/clientsmap.go index 8955c873..c042d2e7 100644 --- a/internal/common/clientsmap.go +++ b/internal/common/clientsmap.go @@ -23,13 +23,13 @@ import ( // clienstMap is a struct containing the map of the connected clients type clientsMap struct { - totalConnections int32 + totalConnections atomic.Int32 mu sync.RWMutex clients map[string]int } func (c *clientsMap) add(source string) { - atomic.AddInt32(&c.totalConnections, 1) + c.totalConnections.Add(1) c.mu.Lock() defer c.mu.Unlock() @@ -42,7 +42,7 @@ func (c *clientsMap) remove(source string) { defer c.mu.Unlock() if val, ok := c.clients[source]; ok { - atomic.AddInt32(&c.totalConnections, -1) + c.totalConnections.Add(-1) c.clients[source]-- if val > 1 { return @@ -54,7 +54,7 @@ func (c *clientsMap) remove(source string) { } func (c *clientsMap) getTotal() int32 { - return atomic.LoadInt32(&c.totalConnections) + return c.totalConnections.Load() } func (c *clientsMap) getTotalFrom(source string) int { diff --git a/internal/common/common.go b/internal/common/common.go index 78216bca..acdd1375 100644 --- a/internal/common/common.go +++ b/internal/common/common.go @@ -704,16 +704,17 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error { type SSHConnection struct { id string conn net.Conn - lastActivity int64 + lastActivity atomic.Int64 } // NewSSHConnection returns a new SSHConnection func NewSSHConnection(id string, conn net.Conn) *SSHConnection { - return &SSHConnection{ - id: id, - conn: conn, - lastActivity: time.Now().UnixNano(), + c := &SSHConnection{ + id: id, + conn: conn, } + c.lastActivity.Store(time.Now().UnixNano()) + return c } // GetID returns the ID for this SSHConnection @@ -723,12 +724,12 @@ func (c *SSHConnection) GetID() string { // UpdateLastActivity updates last activity for this connection func (c *SSHConnection) UpdateLastActivity() { - atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano()) + c.lastActivity.Store(time.Now().UnixNano()) } // GetLastActivity returns the last connection activity func (c *SSHConnection) GetLastActivity() time.Time { - return time.Unix(0, atomic.LoadInt64(&c.lastActivity)) + return time.Unix(0, c.lastActivity.Load()) } // Close closes the underlying network connection @@ -741,7 +742,7 @@ type ActiveConnections struct { // clients contains both authenticated and estabilished connections and the ones waiting // for authentication clients clientsMap - transfersCheckStatus int32 + transfersCheckStatus atomic.Bool sync.RWMutex connections []ActiveConnection sshConnections []*SSHConnection @@ -953,12 +954,12 @@ func (conns *ActiveConnections) checkIdles() { } func (conns *ActiveConnections) checkTransfers() { - if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 { + if conns.transfersCheckStatus.Load() { logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution") return } - atomic.StoreInt32(&conns.transfersCheckStatus, 1) - defer atomic.StoreInt32(&conns.transfersCheckStatus, 0) + conns.transfersCheckStatus.Store(true) + defer conns.transfersCheckStatus.Store(false) conns.RLock() diff --git a/internal/common/common_test.go b/internal/common/common_test.go index 743e3cce..839b41df 100644 --- a/internal/common/common_test.go +++ b/internal/common/common_test.go @@ -24,7 +24,6 @@ import ( "path/filepath" "runtime" "strings" - "sync/atomic" "testing" "time" @@ -498,14 +497,14 @@ func TestIdleConnections(t *testing.T) { }, } c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, "", "", user) - c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() + c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano()) fakeConn := &fakeConnection{ BaseConnection: c, } // both ssh connections are expired but they should get removed only // if there is no associated connection - sshConn1.lastActivity = c.lastActivity - sshConn2.lastActivity = c.lastActivity + sshConn1.lastActivity.Store(c.lastActivity.Load()) + sshConn2.lastActivity.Store(c.lastActivity.Load()) Connections.AddSSHConnection(sshConn1) err = Connections.Add(fakeConn) assert.NoError(t, err) @@ -520,7 +519,7 @@ func TestIdleConnections(t *testing.T) { assert.Equal(t, Connections.GetActiveSessions(username), 2) cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{}) - cFTP.lastActivity = time.Now().UnixNano() + cFTP.lastActivity.Store(time.Now().UnixNano()) fakeConn = &fakeConnection{ BaseConnection: cFTP, } @@ -541,9 +540,9 @@ func TestIdleConnections(t *testing.T) { }, 1*time.Second, 200*time.Millisecond) stopEventScheduler() assert.Len(t, Connections.GetStats(), 2) - c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() - cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano() - sshConn2.lastActivity = c.lastActivity + c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano()) + cFTP.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano()) + sshConn2.lastActivity.Store(c.lastActivity.Load()) startPeriodicChecks(100 * time.Millisecond) assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 2*time.Second, 200*time.Millisecond) assert.Eventually(t, func() bool { @@ -646,9 +645,9 @@ func TestConnectionStatus(t *testing.T) { BaseConnection: c1, } t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) - t1.BytesReceived = 123 + t1.BytesReceived.Store(123) t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) - t2.BytesSent = 456 + t2.BytesSent.Store(456) c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user) fakeConn2 := &fakeConnection{ BaseConnection: c2, @@ -698,7 +697,7 @@ func TestConnectionStatus(t *testing.T) { err = fakeConn3.SignalTransfersAbort() assert.NoError(t, err) - assert.Equal(t, int32(1), atomic.LoadInt32(&t3.AbortTransfer)) + assert.True(t, t3.AbortTransfer.Load()) err = t3.Close() assert.NoError(t, err) err = fakeConn3.SignalTransfersAbort() diff --git a/internal/common/connection.go b/internal/common/connection.go index 2b13cce9..41610780 100644 --- a/internal/common/connection.go +++ b/internal/common/connection.go @@ -38,12 +38,12 @@ import ( type BaseConnection struct { // last activity for this connection. // Since this field is accessed atomically we put it as first element of the struct to achieve 64 bit alignment - lastActivity int64 + lastActivity atomic.Int64 uploadDone atomic.Bool downloadDone atomic.Bool // unique ID for a transfer. // This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment - transferID int64 + transferID atomic.Int64 // Unique identifier for the connection ID string // user associated with this connection if any @@ -64,16 +64,18 @@ func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprov connID = fmt.Sprintf("%s_%s", protocol, id) } user.UploadBandwidth, user.DownloadBandwidth = user.GetBandwidthForIP(util.GetIPFromRemoteAddress(remoteAddr), connID) - return &BaseConnection{ - ID: connID, - User: user, - startTime: time.Now(), - protocol: protocol, - localAddr: localAddr, - remoteAddr: remoteAddr, - lastActivity: time.Now().UnixNano(), - transferID: 0, + c := &BaseConnection{ + ID: connID, + User: user, + startTime: time.Now(), + protocol: protocol, + localAddr: localAddr, + remoteAddr: remoteAddr, } + c.transferID.Store(0) + c.lastActivity.Store(time.Now().UnixNano()) + + return c } // Log outputs a log entry to the configured logger @@ -83,7 +85,7 @@ func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...any) { // GetTransferID returns an unique transfer ID for this connection func (c *BaseConnection) GetTransferID() int64 { - return atomic.AddInt64(&c.transferID, 1) + return c.transferID.Add(1) } // GetID returns the connection ID @@ -126,12 +128,12 @@ func (c *BaseConnection) GetConnectionTime() time.Time { // UpdateLastActivity updates last activity for this connection func (c *BaseConnection) UpdateLastActivity() { - atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano()) + c.lastActivity.Store(time.Now().UnixNano()) } // GetLastActivity returns the last connection activity func (c *BaseConnection) GetLastActivity() time.Time { - return time.Unix(0, atomic.LoadInt64(&c.lastActivity)) + return time.Unix(0, c.lastActivity.Load()) } // CloseFS closes the underlying fs diff --git a/internal/common/eventmanager.go b/internal/common/eventmanager.go index 50b434d0..e13200ea 100644 --- a/internal/common/eventmanager.go +++ b/internal/common/eventmanager.go @@ -82,7 +82,7 @@ func HandleCertificateEvent(params EventParams) { // eventRulesContainer stores event rules by trigger type eventRulesContainer struct { sync.RWMutex - lastLoad int64 + lastLoad atomic.Int64 FsEvents []dataprovider.EventRule ProviderEvents []dataprovider.EventRule Schedules []dataprovider.EventRule @@ -101,11 +101,11 @@ func (r *eventRulesContainer) removeAsyncTask() { } func (r *eventRulesContainer) getLastLoadTime() int64 { - return atomic.LoadInt64(&r.lastLoad) + return r.lastLoad.Load() } func (r *eventRulesContainer) setLastLoadTime(modTime int64) { - atomic.StoreInt64(&r.lastLoad, modTime) + r.lastLoad.Store(modTime) } // RemoveRule deletes the rule with the specified name diff --git a/internal/common/ratelimiter.go b/internal/common/ratelimiter.go index 2dd74f80..78912f37 100644 --- a/internal/common/ratelimiter.go +++ b/internal/common/ratelimiter.go @@ -186,16 +186,16 @@ func (rl *rateLimiter) Wait(source string) (time.Duration, error) { } type sourceRateLimiter struct { - lastActivity int64 + lastActivity *atomic.Int64 bucket *rate.Limiter } func (s *sourceRateLimiter) updateLastActivity() { - atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano()) + s.lastActivity.Store(time.Now().UnixNano()) } func (s *sourceRateLimiter) getLastActivity() int64 { - return atomic.LoadInt64(&s.lastActivity) + return s.lastActivity.Load() } type sourceBuckets struct { @@ -224,7 +224,8 @@ func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Rese b.cleanup() src := sourceRateLimiter{ - bucket: r, + lastActivity: new(atomic.Int64), + bucket: r, } src.updateLastActivity() b.buckets[source] = src diff --git a/internal/common/transfer.go b/internal/common/transfer.go index af66e993..aca8e34d 100644 --- a/internal/common/transfer.go +++ b/internal/common/transfer.go @@ -35,8 +35,8 @@ var ( // BaseTransfer contains protocols common transfer details for an upload or a download. type BaseTransfer struct { //nolint:maligned ID int64 - BytesSent int64 - BytesReceived int64 + BytesSent atomic.Int64 + BytesReceived atomic.Int64 Fs vfs.Fs File vfs.File Connection *BaseConnection @@ -52,7 +52,7 @@ type BaseTransfer struct { //nolint:maligned truncatedSize int64 isNewFile bool transferType int - AbortTransfer int32 + AbortTransfer atomic.Bool aTime time.Time mTime time.Time transferQuota dataprovider.TransferQuota @@ -79,14 +79,14 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat InitialSize: initialSize, isNewFile: isNewFile, requestPath: requestPath, - BytesSent: 0, - BytesReceived: 0, MaxWriteSize: maxWriteSize, - AbortTransfer: 0, truncatedSize: truncatedSize, transferQuota: transferQuota, Fs: fs, } + t.AbortTransfer.Store(false) + t.BytesSent.Store(0) + t.BytesReceived.Store(0) conn.AddTransfer(t) return t @@ -115,19 +115,19 @@ func (t *BaseTransfer) GetType() int { // GetSize returns the transferred size func (t *BaseTransfer) GetSize() int64 { if t.transferType == TransferDownload { - return atomic.LoadInt64(&t.BytesSent) + return t.BytesSent.Load() } - return atomic.LoadInt64(&t.BytesReceived) + return t.BytesReceived.Load() } // GetDownloadedSize returns the transferred size func (t *BaseTransfer) GetDownloadedSize() int64 { - return atomic.LoadInt64(&t.BytesSent) + return t.BytesSent.Load() } // GetUploadedSize returns the transferred size func (t *BaseTransfer) GetUploadedSize() int64 { - return atomic.LoadInt64(&t.BytesReceived) + return t.BytesReceived.Load() } // GetStartTime returns the start time @@ -153,7 +153,7 @@ func (t *BaseTransfer) SignalClose(err error) { t.Lock() t.errAbort = err t.Unlock() - atomic.StoreInt32(&(t.AbortTransfer), 1) + t.AbortTransfer.Store(true) } // GetTruncatedSize returns the truncated sized if this is an upload overwriting @@ -217,11 +217,11 @@ func (t *BaseTransfer) CheckRead() error { return nil } if t.transferQuota.AllowedTotalSize > 0 { - if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize { + if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize { return t.Connection.GetReadQuotaExceededError() } } else if t.transferQuota.AllowedDLSize > 0 { - if atomic.LoadInt64(&t.BytesSent) > t.transferQuota.AllowedDLSize { + if t.BytesSent.Load() > t.transferQuota.AllowedDLSize { return t.Connection.GetReadQuotaExceededError() } } @@ -230,18 +230,18 @@ func (t *BaseTransfer) CheckRead() error { // CheckWrite returns an error if write if not allowed func (t *BaseTransfer) CheckWrite() error { - if t.MaxWriteSize > 0 && atomic.LoadInt64(&t.BytesReceived) > t.MaxWriteSize { + if t.MaxWriteSize > 0 && t.BytesReceived.Load() > t.MaxWriteSize { return t.Connection.GetQuotaExceededError() } if t.transferQuota.AllowedULSize == 0 && t.transferQuota.AllowedTotalSize == 0 { return nil } if t.transferQuota.AllowedTotalSize > 0 { - if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize { + if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize { return t.Connection.GetQuotaExceededError() } } else if t.transferQuota.AllowedULSize > 0 { - if atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedULSize { + if t.BytesReceived.Load() > t.transferQuota.AllowedULSize { return t.Connection.GetQuotaExceededError() } } @@ -261,14 +261,14 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) { if t.MaxWriteSize > 0 { sizeDiff := initialSize - size t.MaxWriteSize += sizeDiff - metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), + metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(), t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs)) if t.transferQuota.HasSizeLimits() { go func(ulSize, dlSize int64, user dataprovider.User) { dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck - }(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User) + }(t.BytesReceived.Load(), t.BytesSent.Load(), t.Connection.User) } - atomic.StoreInt64(&t.BytesReceived, 0) + t.BytesReceived.Store(0) } t.Unlock() } @@ -276,7 +276,7 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) { fsPath, size, t.MaxWriteSize, t.InitialSize, err) return initialSize, err } - if size == 0 && atomic.LoadInt64(&t.BytesSent) == 0 { + if size == 0 && t.BytesSent.Load() == 0 { // for cloud providers the file is always truncated to zero, we don't support append/resume for uploads // for buffered SFTP we can have buffered bytes so we returns an error if !vfs.IsBufferedSFTPFs(t.Fs) { @@ -302,8 +302,8 @@ func (t *BaseTransfer) TransferError(err error) { } elapsed := time.Since(t.start).Nanoseconds() / 1000000 t.Connection.Log(logger.LevelError, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+ - "bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, atomic.LoadInt64(&t.BytesSent), - atomic.LoadInt64(&t.BytesReceived), elapsed) + "bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, t.BytesSent.Load(), + t.BytesReceived.Load(), elapsed) } func (t *BaseTransfer) getUploadFileSize() (int64, error) { @@ -333,7 +333,7 @@ func (t *BaseTransfer) checkUploadOutsideHomeDir(err error) int { t.Connection.Log(logger.LevelWarn, "upload in temp path cannot be renamed, delete temporary file: %#v, deletion error: %v", t.effectiveFsPath, err) // the file is outside the home dir so don't update the quota - atomic.StoreInt64(&t.BytesReceived, 0) + t.BytesReceived.Store(0) t.MinWriteOffset = 0 return 1 } @@ -351,18 +351,18 @@ func (t *BaseTransfer) Close() error { if t.isNewFile { numFiles = 1 } - metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), + metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(), t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs)) if t.transferQuota.HasSizeLimits() { - dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck - atomic.LoadInt64(&t.BytesSent), false) + dataprovider.UpdateUserTransferQuota(&t.Connection.User, t.BytesReceived.Load(), //nolint:errcheck + t.BytesSent.Load(), false) } if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) { // if quota is exceeded we try to remove the partial file for uploads to local filesystem err = t.Fs.Remove(t.File.Name(), false) if err == nil { numFiles-- - atomic.StoreInt64(&t.BytesReceived, 0) + t.BytesReceived.Store(0) t.MinWriteOffset = 0 } t.Connection.Log(logger.LevelWarn, "upload denied due to space limit, delete temporary file: %#v, deletion error: %v", @@ -380,7 +380,7 @@ func (t *BaseTransfer) Close() error { t.ErrTransfer, t.effectiveFsPath, err) if err == nil { numFiles-- - atomic.StoreInt64(&t.BytesReceived, 0) + t.BytesReceived.Store(0) t.MinWriteOffset = 0 } } @@ -388,12 +388,12 @@ func (t *BaseTransfer) Close() error { elapsed := time.Since(t.start).Nanoseconds() / 1000000 var uploadFileSize int64 if t.transferType == TransferDownload { - logger.TransferLog(downloadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesSent), t.Connection.User.Username, + logger.TransferLog(downloadLogSender, t.fsPath, elapsed, t.BytesSent.Load(), t.Connection.User.Username, t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode) ExecuteActionNotification(t.Connection, operationDownload, t.fsPath, t.requestPath, "", "", "", //nolint:errcheck - atomic.LoadInt64(&t.BytesSent), t.ErrTransfer) + t.BytesSent.Load(), t.ErrTransfer) } else { - uploadFileSize = atomic.LoadInt64(&t.BytesReceived) + t.MinWriteOffset + uploadFileSize = t.BytesReceived.Load() + t.MinWriteOffset if statSize, errStat := t.getUploadFileSize(); errStat == nil { uploadFileSize = statSize } @@ -401,7 +401,7 @@ func (t *BaseTransfer) Close() error { numFiles, uploadFileSize = t.executeUploadHook(numFiles, uploadFileSize) t.updateQuota(numFiles, uploadFileSize) t.updateTimes() - logger.TransferLog(uploadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesReceived), t.Connection.User.Username, + logger.TransferLog(uploadLogSender, t.fsPath, elapsed, t.BytesReceived.Load(), t.Connection.User.Username, t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode) } if t.ErrTransfer != nil { @@ -428,11 +428,11 @@ func (t *BaseTransfer) updateTransferTimestamps(uploadFileSize int64) { } return } - if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && atomic.LoadInt64(&t.BytesSent) > 0 { + if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && t.BytesSent.Load() > 0 { if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, false); err == nil { t.Connection.downloadDone.Store(true) ExecuteActionNotification(t.Connection, operationFirstDownload, t.fsPath, t.requestPath, "", //nolint:errcheck - "", "", atomic.LoadInt64(&t.BytesSent), t.ErrTransfer) + "", "", t.BytesSent.Load(), t.ErrTransfer) } } } @@ -449,7 +449,7 @@ func (t *BaseTransfer) executeUploadHook(numFiles int, fileSize int64) (int, int if err == nil { numFiles-- fileSize = 0 - atomic.StoreInt64(&t.BytesReceived, 0) + t.BytesReceived.Store(0) t.MinWriteOffset = 0 } else { t.Connection.Log(logger.LevelWarn, "unable to remove path %q after upload hook failure: %v", t.fsPath, err) @@ -494,10 +494,10 @@ func (t *BaseTransfer) HandleThrottle() { var trasferredBytes int64 if t.transferType == TransferDownload { wantedBandwidth = t.Connection.User.DownloadBandwidth - trasferredBytes = atomic.LoadInt64(&t.BytesSent) + trasferredBytes = t.BytesSent.Load() } else { wantedBandwidth = t.Connection.User.UploadBandwidth - trasferredBytes = atomic.LoadInt64(&t.BytesReceived) + trasferredBytes = t.BytesReceived.Load() } if wantedBandwidth > 0 { // real and wanted elapsed as milliseconds, bytes as kilobytes diff --git a/internal/common/transfer_test.go b/internal/common/transfer_test.go index 1106133f..91a0e347 100644 --- a/internal/common/transfer_test.go +++ b/internal/common/transfer_test.go @@ -33,11 +33,11 @@ import ( func TestTransferUpdateQuota(t *testing.T) { conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{}) transfer := BaseTransfer{ - Connection: conn, - transferType: TransferUpload, - BytesReceived: 123, - Fs: vfs.NewOsFs("", os.TempDir(), ""), + Connection: conn, + transferType: TransferUpload, + Fs: vfs.NewOsFs("", os.TempDir(), ""), } + transfer.BytesReceived.Store(123) errFake := errors.New("fake error") transfer.TransferError(errFake) assert.False(t, transfer.updateQuota(1, 0)) @@ -56,7 +56,7 @@ func TestTransferUpdateQuota(t *testing.T) { QuotaSize: -1, }) transfer.ErrTransfer = nil - transfer.BytesReceived = 1 + transfer.BytesReceived.Store(1) transfer.requestPath = "/vdir/file" assert.True(t, transfer.updateQuota(1, 0)) err = transfer.Close() @@ -80,7 +80,7 @@ func TestTransferThrottling(t *testing.T) { wantedDownloadElapsed -= wantedDownloadElapsed / 10 conn := NewBaseConnection("id", ProtocolSCP, "", "", u) transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) - transfer.BytesReceived = testFileSize + transfer.BytesReceived.Store(testFileSize) transfer.Connection.UpdateLastActivity() startTime := transfer.Connection.GetLastActivity() transfer.HandleThrottle() @@ -90,7 +90,7 @@ func TestTransferThrottling(t *testing.T) { assert.NoError(t, err) transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) - transfer.BytesSent = testFileSize + transfer.BytesSent.Store(testFileSize) transfer.Connection.UpdateLastActivity() startTime = transfer.Connection.GetLastActivity() @@ -226,7 +226,7 @@ func TestTransferErrors(t *testing.T) { assert.Equal(t, testFile, transfer.GetFsPath()) transfer.SetCancelFn(cancelFn) errFake := errors.New("err fake") - transfer.BytesReceived = 9 + transfer.BytesReceived.Store(9) transfer.TransferError(ErrQuotaExceeded) assert.True(t, isCancelled) transfer.TransferError(errFake) @@ -249,7 +249,7 @@ func TestTransferErrors(t *testing.T) { fsPath := filepath.Join(os.TempDir(), "test_file") transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) - transfer.BytesReceived = 9 + transfer.BytesReceived.Store(9) transfer.TransferError(errFake) assert.Error(t, transfer.ErrTransfer, errFake.Error()) // the file is closed from the embedding struct before to call close @@ -269,7 +269,7 @@ func TestTransferErrors(t *testing.T) { } transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) - transfer.BytesReceived = 9 + transfer.BytesReceived.Store(9) // the file is closed from the embedding struct before to call close err = file.Close() assert.NoError(t, err) @@ -310,11 +310,11 @@ func TestRemovePartialCryptoFile(t *testing.T) { func TestFTPMode(t *testing.T) { conn := NewBaseConnection("", ProtocolFTP, "", "", dataprovider.User{}) transfer := BaseTransfer{ - Connection: conn, - transferType: TransferUpload, - BytesReceived: 123, - Fs: vfs.NewOsFs("", os.TempDir(), ""), + Connection: conn, + transferType: TransferUpload, + Fs: vfs.NewOsFs("", os.TempDir(), ""), } + transfer.BytesReceived.Store(123) assert.Empty(t, transfer.ftpMode) transfer.SetFtpMode("active") assert.Equal(t, "active", transfer.ftpMode) @@ -399,14 +399,14 @@ func TestTransferQuota(t *testing.T) { transfer.transferQuota = dataprovider.TransferQuota{ AllowedTotalSize: 10, } - transfer.BytesReceived = 5 - transfer.BytesSent = 4 + transfer.BytesReceived.Store(5) + transfer.BytesSent.Store(4) err = transfer.CheckRead() assert.NoError(t, err) err = transfer.CheckWrite() assert.NoError(t, err) - transfer.BytesSent = 6 + transfer.BytesSent.Store(6) err = transfer.CheckRead() if assert.Error(t, err) { assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) @@ -428,7 +428,7 @@ func TestTransferQuota(t *testing.T) { err = transfer.CheckWrite() assert.NoError(t, err) - transfer.BytesReceived = 11 + transfer.BytesReceived.Store(11) err = transfer.CheckRead() if assert.Error(t, err) { assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) @@ -442,11 +442,11 @@ func TestUploadOutsideHomeRenameError(t *testing.T) { conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{}) transfer := BaseTransfer{ - Connection: conn, - transferType: TransferUpload, - BytesReceived: 123, - Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""), + Connection: conn, + transferType: TransferUpload, + Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""), } + transfer.BytesReceived.Store(123) fileName := filepath.Join(os.TempDir(), "_temp") err := os.WriteFile(fileName, []byte(`data`), 0644) @@ -459,10 +459,10 @@ func TestUploadOutsideHomeRenameError(t *testing.T) { Config.TempPath = filepath.Clean(os.TempDir()) res = transfer.checkUploadOutsideHomeDir(nil) assert.Equal(t, 0, res) - assert.Greater(t, transfer.BytesReceived, int64(0)) + assert.Greater(t, transfer.BytesReceived.Load(), int64(0)) res = transfer.checkUploadOutsideHomeDir(os.ErrPermission) assert.Equal(t, 1, res) - assert.Equal(t, int64(0), transfer.BytesReceived) + assert.Equal(t, int64(0), transfer.BytesReceived.Load()) assert.NoFileExists(t, fileName) Config.TempPath = oldTempPath diff --git a/internal/common/transferschecker_test.go b/internal/common/transferschecker_test.go index 2f9c236d..30656d30 100644 --- a/internal/common/transferschecker_test.go +++ b/internal/common/transferschecker_test.go @@ -21,7 +21,6 @@ import ( "path/filepath" "strconv" "strings" - "sync/atomic" "testing" "time" @@ -96,7 +95,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { } transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), "/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) - transfer1.BytesReceived = 150 + transfer1.BytesReceived.Store(150) err = Connections.Add(fakeConn1) assert.NoError(t, err) // the transferschecker will do nothing if there is only one ongoing transfer @@ -110,8 +109,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { } transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), "/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{}) - transfer1.BytesReceived = 50 - transfer2.BytesReceived = 60 + transfer1.BytesReceived.Store(50) + transfer2.BytesReceived.Store(60) err = Connections.Add(fakeConn2) assert.NoError(t, err) @@ -122,7 +121,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { } transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"), "/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) - transfer3.BytesReceived = 60 // this value will be ignored, this is a download + transfer3.BytesReceived.Store(60) // this value will be ignored, this is a download err = Connections.Add(fakeConn3) assert.NoError(t, err) @@ -132,20 +131,20 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer3.errAbort) - transfer1.BytesReceived = 80 // truncated size will be subtracted, we are not overquota + transfer1.BytesReceived.Store(80) // truncated size will be subtracted, we are not overquota Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer3.errAbort) - transfer1.BytesReceived = 120 + transfer1.BytesReceived.Store(120) // we are now overquota // if another check is in progress nothing is done - atomic.StoreInt32(&Connections.transfersCheckStatus, 1) + Connections.transfersCheckStatus.Store(true) Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer3.errAbort) - atomic.StoreInt32(&Connections.transfersCheckStatus, 0) + Connections.transfersCheckStatus.Store(false) Connections.checkTransfers() assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort), transfer1.errAbort) @@ -172,8 +171,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer3.errAbort) // now check a public folder - transfer1.BytesReceived = 0 - transfer2.BytesReceived = 0 + transfer1.BytesReceived.Store(0) + transfer2.BytesReceived.Store(0) connID4 := xid.New().String() fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4) assert.NoError(t, err) @@ -197,12 +196,12 @@ func TestTransfersCheckerDiskQuota(t *testing.T) { err = Connections.Add(fakeConn5) assert.NoError(t, err) - transfer4.BytesReceived = 50 - transfer5.BytesReceived = 40 + transfer4.BytesReceived.Store(50) + transfer5.BytesReceived.Store(40) Connections.checkTransfers() assert.Nil(t, transfer4.errAbort) assert.Nil(t, transfer5.errAbort) - transfer5.BytesReceived = 60 + transfer5.BytesReceived.Store(60) Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer2.errAbort) @@ -286,7 +285,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) { } transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), "/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100}) - transfer1.BytesReceived = 150 + transfer1.BytesReceived.Store(150) err = Connections.Add(fakeConn1) assert.NoError(t, err) // the transferschecker will do nothing if there is only one ongoing transfer @@ -300,26 +299,26 @@ func TestTransferCheckerTransferQuota(t *testing.T) { } transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), "/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100}) - transfer2.BytesReceived = 150 + transfer2.BytesReceived.Store(150) err = Connections.Add(fakeConn2) assert.NoError(t, err) Connections.checkTransfers() assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer2.errAbort) // now test overquota - transfer1.BytesReceived = 1024*1024 + 1 - transfer2.BytesReceived = 0 + transfer1.BytesReceived.Store(1024*1024 + 1) + transfer2.BytesReceived.Store(0) Connections.checkTransfers() assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort)) assert.Nil(t, transfer2.errAbort) transfer1.errAbort = nil - transfer1.BytesReceived = 1024*1024 + 1 - transfer2.BytesReceived = 1024 + transfer1.BytesReceived.Store(1024*1024 + 1) + transfer2.BytesReceived.Store(1024) Connections.checkTransfers() assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort)) assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort)) - transfer1.BytesReceived = 0 - transfer2.BytesReceived = 0 + transfer1.BytesReceived.Store(0) + transfer2.BytesReceived.Store(0) transfer1.errAbort = nil transfer2.errAbort = nil @@ -337,7 +336,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) { } transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), "/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100}) - transfer3.BytesSent = 150 + transfer3.BytesSent.Store(150) err = Connections.Add(fakeConn3) assert.NoError(t, err) @@ -348,15 +347,15 @@ func TestTransferCheckerTransferQuota(t *testing.T) { } transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), "/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100}) - transfer4.BytesSent = 150 + transfer4.BytesSent.Store(150) err = Connections.Add(fakeConn4) assert.NoError(t, err) Connections.checkTransfers() assert.Nil(t, transfer3.errAbort) assert.Nil(t, transfer4.errAbort) - transfer3.BytesSent = 512 * 1024 - transfer4.BytesSent = 512*1024 + 1 + transfer3.BytesSent.Store(512 * 1024) + transfer4.BytesSent.Store(512*1024 + 1) Connections.checkTransfers() if assert.Error(t, transfer3.errAbort) { assert.Contains(t, transfer3.errAbort.Error(), ErrReadQuotaExceeded.Error()) diff --git a/internal/dataprovider/dataprovider.go b/internal/dataprovider/dataprovider.go index ef81a86d..53d2f13a 100644 --- a/internal/dataprovider/dataprovider.go +++ b/internal/dataprovider/dataprovider.go @@ -155,7 +155,7 @@ var ( ErrInvalidCredentials = errors.New("invalid credentials") // ErrLoginNotAllowedFromIP defines the error to return if login is denied from the current IP ErrLoginNotAllowedFromIP = errors.New("login is not allowed from this IP") - isAdminCreated = int32(0) + isAdminCreated atomic.Bool validTLSUsernames = []string{string(sdk.TLSUsernameNone), string(sdk.TLSUsernameCN)} config Config provider Provider @@ -844,7 +844,7 @@ func Initialize(cnf Config, basePath string, checkAdmins bool) error { if err != nil { return err } - atomic.StoreInt32(&isAdminCreated, int32(len(admins))) + isAdminCreated.Store(len(admins) > 0) delayedQuotaUpdater.start() return startScheduler() } @@ -1722,7 +1722,7 @@ func UpdateTaskTimestamp(name string) error { // HasAdmin returns true if the first admin has been created // and so SFTPGo is ready to be used func HasAdmin() bool { - return atomic.LoadInt32(&isAdminCreated) > 0 + return isAdminCreated.Load() } // AddAdmin adds a new SFTPGo admin @@ -1734,7 +1734,7 @@ func AddAdmin(admin *Admin, executor, ipAddress string) error { admin.Username = config.convertName(admin.Username) err := provider.addAdmin(admin) if err == nil { - atomic.StoreInt32(&isAdminCreated, 1) + isAdminCreated.Store(true) executeAction(operationAdd, executor, ipAddress, actionObjectAdmin, admin.Username, admin) } return err diff --git a/internal/dataprovider/scheduler.go b/internal/dataprovider/scheduler.go index 2b9d7a94..e690dd2d 100644 --- a/internal/dataprovider/scheduler.go +++ b/internal/dataprovider/scheduler.go @@ -28,11 +28,11 @@ import ( var ( scheduler *cron.Cron - lastUserCacheUpdate int64 + lastUserCacheUpdate atomic.Int64 // used for bolt and memory providers, so we avoid iterating all users/rules // to find recently modified ones - lastUserUpdate int64 - lastRuleUpdate int64 + lastUserUpdate atomic.Int64 + lastRuleUpdate atomic.Int64 ) func stopScheduler() { @@ -62,7 +62,7 @@ func startScheduler() error { } func addScheduledCacheUpdates() error { - lastUserCacheUpdate = util.GetTimeAsMsSinceEpoch(time.Now()) + lastUserCacheUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now())) _, err := scheduler.AddFunc("@every 10m", checkCacheUpdates) if err != nil { return fmt.Errorf("unable to schedule cache updates: %w", err) @@ -79,9 +79,9 @@ func checkDataprovider() { } func checkCacheUpdates() { - providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate)) + providerLog(logger.LevelDebug, "start user cache check, update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate.Load())) checkTime := util.GetTimeAsMsSinceEpoch(time.Now()) - users, err := provider.getRecentlyUpdatedUsers(lastUserCacheUpdate) + users, err := provider.getRecentlyUpdatedUsers(lastUserCacheUpdate.Load()) if err != nil { providerLog(logger.LevelError, "unable to get recently updated users: %v", err) return @@ -102,22 +102,22 @@ func checkCacheUpdates() { cachedPasswords.Remove(user.Username) } - lastUserCacheUpdate = checkTime - providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate)) + lastUserCacheUpdate.Store(checkTime) + providerLog(logger.LevelDebug, "end user cache check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastUserCacheUpdate.Load())) } func setLastUserUpdate() { - atomic.StoreInt64(&lastUserUpdate, util.GetTimeAsMsSinceEpoch(time.Now())) + lastUserUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now())) } func getLastUserUpdate() int64 { - return atomic.LoadInt64(&lastUserUpdate) + return lastUserUpdate.Load() } func setLastRuleUpdate() { - atomic.StoreInt64(&lastRuleUpdate, util.GetTimeAsMsSinceEpoch(time.Now())) + lastRuleUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now())) } func getLastRuleUpdate() int64 { - return atomic.LoadInt64(&lastRuleUpdate) + return lastRuleUpdate.Load() } diff --git a/internal/ftpd/transfer.go b/internal/ftpd/transfer.go index b6bfa289..8bb76a60 100644 --- a/internal/ftpd/transfer.go +++ b/internal/ftpd/transfer.go @@ -17,7 +17,6 @@ package ftpd import ( "errors" "io" - "sync/atomic" "github.com/eikenb/pipeat" @@ -61,7 +60,7 @@ func (t *transfer) Read(p []byte) (n int, err error) { t.Connection.UpdateLastActivity() n, err = t.reader.Read(p) - atomic.AddInt64(&t.BytesSent, int64(n)) + t.BytesSent.Add(int64(n)) if err == nil { err = t.CheckRead() @@ -79,7 +78,7 @@ func (t *transfer) Write(p []byte) (n int, err error) { t.Connection.UpdateLastActivity() n, err = t.writer.Write(p) - atomic.AddInt64(&t.BytesReceived, int64(n)) + t.BytesReceived.Add(int64(n)) if err == nil { err = t.CheckWrite() diff --git a/internal/httpd/file.go b/internal/httpd/file.go index f79abf37..245ddc1a 100644 --- a/internal/httpd/file.go +++ b/internal/httpd/file.go @@ -16,7 +16,6 @@ package httpd import ( "io" - "sync/atomic" "github.com/eikenb/pipeat" @@ -52,7 +51,7 @@ func newHTTPDFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, // Read reads the contents to downloads. func (f *httpdFile) Read(p []byte) (n int, err error) { - if atomic.LoadInt32(&f.AbortTransfer) == 1 { + if f.AbortTransfer.Load() { err := f.GetAbortError() f.TransferError(err) return 0, err @@ -61,7 +60,7 @@ func (f *httpdFile) Read(p []byte) (n int, err error) { f.Connection.UpdateLastActivity() n, err = f.reader.Read(p) - atomic.AddInt64(&f.BytesSent, int64(n)) + f.BytesSent.Add(int64(n)) if err == nil { err = f.CheckRead() @@ -76,7 +75,7 @@ func (f *httpdFile) Read(p []byte) (n int, err error) { // Write writes the contents to upload func (f *httpdFile) Write(p []byte) (n int, err error) { - if atomic.LoadInt32(&f.AbortTransfer) == 1 { + if f.AbortTransfer.Load() { err := f.GetAbortError() f.TransferError(err) return 0, err @@ -85,7 +84,7 @@ func (f *httpdFile) Write(p []byte) (n int, err error) { f.Connection.UpdateLastActivity() n, err = f.writer.Write(p) - atomic.AddInt64(&f.BytesReceived, int64(n)) + f.BytesReceived.Add(int64(n)) if err == nil { err = f.CheckWrite() diff --git a/internal/httpd/handler.go b/internal/httpd/handler.go index 85bdda8d..48fd27e1 100644 --- a/internal/httpd/handler.go +++ b/internal/httpd/handler.go @@ -238,24 +238,24 @@ func (c *Connection) handleUploadFile(fs vfs.Fs, resolvedPath, filePath, request func newThrottledReader(r io.ReadCloser, limit int64, conn *Connection) *throttledReader { t := &throttledReader{ - bytesRead: 0, - id: conn.GetTransferID(), - limit: limit, - r: r, - abortTransfer: 0, - start: time.Now(), - conn: conn, + id: conn.GetTransferID(), + limit: limit, + r: r, + start: time.Now(), + conn: conn, } + t.bytesRead.Store(0) + t.abortTransfer.Store(false) conn.AddTransfer(t) return t } type throttledReader struct { - bytesRead int64 + bytesRead atomic.Int64 id int64 limit int64 r io.ReadCloser - abortTransfer int32 + abortTransfer atomic.Bool start time.Time conn *Connection mu sync.Mutex @@ -271,7 +271,7 @@ func (t *throttledReader) GetType() int { } func (t *throttledReader) GetSize() int64 { - return atomic.LoadInt64(&t.bytesRead) + return t.bytesRead.Load() } func (t *throttledReader) GetDownloadedSize() int64 { @@ -279,7 +279,7 @@ func (t *throttledReader) GetDownloadedSize() int64 { } func (t *throttledReader) GetUploadedSize() int64 { - return atomic.LoadInt64(&t.bytesRead) + return t.bytesRead.Load() } func (t *throttledReader) GetVirtualPath() string { @@ -304,7 +304,7 @@ func (t *throttledReader) SignalClose(err error) { t.mu.Lock() t.errAbort = err t.mu.Unlock() - atomic.StoreInt32(&(t.abortTransfer), 1) + t.abortTransfer.Store(true) } func (t *throttledReader) GetTruncatedSize() int64 { @@ -328,15 +328,15 @@ func (t *throttledReader) SetTimes(fsPath string, atime time.Time, mtime time.Ti } func (t *throttledReader) Read(p []byte) (n int, err error) { - if atomic.LoadInt32(&t.abortTransfer) == 1 { + if t.abortTransfer.Load() { return 0, t.GetAbortError() } t.conn.UpdateLastActivity() n, err = t.r.Read(p) if t.limit > 0 { - atomic.AddInt64(&t.bytesRead, int64(n)) - trasferredBytes := atomic.LoadInt64(&t.bytesRead) + t.bytesRead.Add(int64(n)) + trasferredBytes := t.bytesRead.Load() elapsed := time.Since(t.start).Nanoseconds() / 1000000 wantedElapsed := 1000 * (trasferredBytes / 1024) / t.limit if wantedElapsed > elapsed { diff --git a/internal/plugin/plugin.go b/internal/plugin/plugin.go index 323979fb..6fd1b529 100644 --- a/internal/plugin/plugin.go +++ b/internal/plugin/plugin.go @@ -93,7 +93,7 @@ func (c *Config) newKMSPluginSecretProvider(base kms.BaseSecret, url, masterKey // Manager handles enabled plugins type Manager struct { - closed int32 + closed atomic.Bool done chan bool // List of configured plugins Configs []Config `json:"plugins" mapstructure:"plugins"` @@ -124,10 +124,10 @@ func Initialize(configs []Config, logLevel string) error { Handler = Manager{ Configs: configs, done: make(chan bool), - closed: 0, authScopes: -1, concurrencyGuard: make(chan struct{}, 250), } + Handler.closed.Store(false) setLogLevel(logLevel) if len(configs) == 0 { return nil @@ -604,7 +604,7 @@ func (m *Manager) checkCrashedPlugins() { } func (m *Manager) restartNotifierPlugin(config Config, idx int) { - if atomic.LoadInt32(&m.closed) == 1 { + if m.closed.Load() { return } logger.Info(logSender, "", "try to restart crashed notifier plugin %#v, idx: %v", config.Cmd, idx) @@ -622,7 +622,7 @@ func (m *Manager) restartNotifierPlugin(config Config, idx int) { } func (m *Manager) restartKMSPlugin(config Config, idx int) { - if atomic.LoadInt32(&m.closed) == 1 { + if m.closed.Load() { return } logger.Info(logSender, "", "try to restart crashed kms plugin %#v, idx: %v", config.Cmd, idx) @@ -638,7 +638,7 @@ func (m *Manager) restartKMSPlugin(config Config, idx int) { } func (m *Manager) restartAuthPlugin(config Config, idx int) { - if atomic.LoadInt32(&m.closed) == 1 { + if m.closed.Load() { return } logger.Info(logSender, "", "try to restart crashed auth plugin %#v, idx: %v", config.Cmd, idx) @@ -654,7 +654,7 @@ func (m *Manager) restartAuthPlugin(config Config, idx int) { } func (m *Manager) restartSearcherPlugin(config Config) { - if atomic.LoadInt32(&m.closed) == 1 { + if m.closed.Load() { return } logger.Info(logSender, "", "try to restart crashed searcher plugin %#v", config.Cmd) @@ -670,7 +670,7 @@ func (m *Manager) restartSearcherPlugin(config Config) { } func (m *Manager) restartMetadaterPlugin(config Config) { - if atomic.LoadInt32(&m.closed) == 1 { + if m.closed.Load() { return } logger.Info(logSender, "", "try to restart crashed metadater plugin %#v", config.Cmd) @@ -686,7 +686,7 @@ func (m *Manager) restartMetadaterPlugin(config Config) { } func (m *Manager) restartIPFilterPlugin(config Config) { - if atomic.LoadInt32(&m.closed) == 1 { + if m.closed.Load() { return } logger.Info(logSender, "", "try to restart crashed IP filter plugin %#v", config.Cmd) @@ -712,7 +712,7 @@ func (m *Manager) removeTask() { // Cleanup releases all the active plugins func (m *Manager) Cleanup() { logger.Debug(logSender, "", "cleanup") - atomic.StoreInt32(&m.closed, 1) + m.closed.Store(true) close(m.done) m.notifLock.Lock() for _, n := range m.notifiers { diff --git a/internal/sftpd/internal_test.go b/internal/sftpd/internal_test.go index ae973c44..80a54f28 100644 --- a/internal/sftpd/internal_test.go +++ b/internal/sftpd/internal_test.go @@ -1785,7 +1785,7 @@ func TestUploadError(t *testing.T) { if assert.Error(t, transfer.ErrTransfer) { assert.EqualError(t, transfer.ErrTransfer, errFake.Error()) } - assert.Equal(t, int64(0), transfer.BytesReceived) + assert.Equal(t, int64(0), transfer.BytesReceived.Load()) assert.NoFileExists(t, testfile) assert.NoFileExists(t, fileTempName) diff --git a/internal/sftpd/sftpd_test.go b/internal/sftpd/sftpd_test.go index 6aa82ce1..1987967d 100644 --- a/internal/sftpd/sftpd_test.go +++ b/internal/sftpd/sftpd_test.go @@ -1114,12 +1114,12 @@ func TestConcurrency(t *testing.T) { err = createTestFile(testFilePath, testFileSize) assert.NoError(t, err) - closedConns := int32(0) + var closedConns atomic.Int32 for i := 0; i < numLogins; i++ { wg.Add(1) go func(counter int) { defer wg.Done() - defer atomic.AddInt32(&closedConns, 1) + defer closedConns.Add(1) conn, client, err := getSftpClient(user, usePubKey) if assert.NoError(t, err) { @@ -1139,7 +1139,7 @@ func TestConcurrency(t *testing.T) { maxConns := 0 maxSessions := 0 for { - servedReqs := atomic.LoadInt32(&closedConns) + servedReqs := closedConns.Load() if servedReqs > 0 { stats := common.Connections.GetStats() if len(stats) > maxConns { diff --git a/internal/sftpd/transfer.go b/internal/sftpd/transfer.go index ce0e0178..710f3d00 100644 --- a/internal/sftpd/transfer.go +++ b/internal/sftpd/transfer.go @@ -17,7 +17,6 @@ package sftpd import ( "fmt" "io" - "sync/atomic" "github.com/eikenb/pipeat" @@ -107,7 +106,7 @@ func (t *transfer) ReadAt(p []byte, off int64) (n int, err error) { t.Connection.UpdateLastActivity() n, err = t.readerAt.ReadAt(p, off) - atomic.AddInt64(&t.BytesSent, int64(n)) + t.BytesSent.Add(int64(n)) if err == nil { err = t.CheckRead() @@ -133,7 +132,7 @@ func (t *transfer) WriteAt(p []byte, off int64) (n int, err error) { } n, err = t.writerAt.WriteAt(p, off) - atomic.AddInt64(&t.BytesReceived, int64(n)) + t.BytesReceived.Add(int64(n)) if err == nil { err = t.CheckWrite() @@ -213,13 +212,13 @@ func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64, if nw > 0 { written += int64(nw) if isDownload { - atomic.StoreInt64(&t.BytesSent, written) + t.BytesSent.Store(written) if errCheck := t.CheckRead(); errCheck != nil { err = errCheck break } } else { - atomic.StoreInt64(&t.BytesReceived, written) + t.BytesReceived.Store(written) if errCheck := t.CheckWrite(); errCheck != nil { err = errCheck break @@ -245,7 +244,7 @@ func (t *transfer) copyFromReaderToWriter(dst io.Writer, src io.Reader) (int64, } t.ErrTransfer = err if written > 0 || err != nil { - metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.GetType(), + metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(), t.GetType(), t.ErrTransfer, vfs.IsSFTPFs(t.Fs)) } return written, err diff --git a/internal/util/timeoutlistener.go b/internal/util/timeoutlistener.go index afb67220..81d17929 100644 --- a/internal/util/timeoutlistener.go +++ b/internal/util/timeoutlistener.go @@ -32,14 +32,14 @@ func (l *listener) Accept() (net.Conn, error) { return nil, err } tc := &Conn{ - Conn: c, - ReadTimeout: l.ReadTimeout, - WriteTimeout: l.WriteTimeout, - ReadThreshold: int32((l.ReadTimeout * 1024) / time.Second), - WriteThreshold: int32((l.WriteTimeout * 1024) / time.Second), - BytesReadFromDeadline: 0, - BytesWrittenFromDeadline: 0, + Conn: c, + ReadTimeout: l.ReadTimeout, + WriteTimeout: l.WriteTimeout, + ReadThreshold: int32((l.ReadTimeout * 1024) / time.Second), + WriteThreshold: int32((l.WriteTimeout * 1024) / time.Second), } + tc.BytesReadFromDeadline.Store(0) + tc.BytesWrittenFromDeadline.Store(0) return tc, nil } @@ -51,13 +51,13 @@ type Conn struct { WriteTimeout time.Duration ReadThreshold int32 WriteThreshold int32 - BytesReadFromDeadline int32 - BytesWrittenFromDeadline int32 + BytesReadFromDeadline atomic.Int32 + BytesWrittenFromDeadline atomic.Int32 } func (c *Conn) Read(b []byte) (n int, err error) { - if atomic.LoadInt32(&c.BytesReadFromDeadline) > c.ReadThreshold { - atomic.StoreInt32(&c.BytesReadFromDeadline, 0) + if c.BytesReadFromDeadline.Load() > c.ReadThreshold { + c.BytesReadFromDeadline.Store(0) // we set both read and write deadlines here otherwise after the request // is read writing the response fails with an i/o timeout error err = c.Conn.SetDeadline(time.Now().Add(c.ReadTimeout)) @@ -66,13 +66,13 @@ func (c *Conn) Read(b []byte) (n int, err error) { } } n, err = c.Conn.Read(b) - atomic.AddInt32(&c.BytesReadFromDeadline, int32(n)) + c.BytesReadFromDeadline.Add(int32(n)) return } func (c *Conn) Write(b []byte) (n int, err error) { - if atomic.LoadInt32(&c.BytesWrittenFromDeadline) > c.WriteThreshold { - atomic.StoreInt32(&c.BytesWrittenFromDeadline, 0) + if c.BytesWrittenFromDeadline.Load() > c.WriteThreshold { + c.BytesWrittenFromDeadline.Store(0) // we extend the read deadline too, not sure it's necessary, // but it doesn't hurt err = c.Conn.SetDeadline(time.Now().Add(c.WriteTimeout)) @@ -81,7 +81,7 @@ func (c *Conn) Write(b []byte) (n int, err error) { } } n, err = c.Conn.Write(b) - atomic.AddInt32(&c.BytesWrittenFromDeadline, int32(n)) + c.BytesWrittenFromDeadline.Add(int32(n)) return } diff --git a/internal/vfs/azblobfs.go b/internal/vfs/azblobfs.go index 6051701f..03223787 100644 --- a/internal/vfs/azblobfs.go +++ b/internal/vfs/azblobfs.go @@ -902,7 +902,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a finished := false var wg sync.WaitGroup var errOnce sync.Once - var hasError int32 + var hasError atomic.Bool var poolError error poolCtx, poolCancel := context.WithCancel(ctx) @@ -919,7 +919,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a offset = end guard <- struct{}{} - if atomic.LoadInt32(&hasError) == 1 { + if hasError.Load() { fsLog(fs, logger.LevelDebug, "pool error, download for part %v not started", part) break } @@ -941,7 +941,7 @@ func (fs *AzureBlobFs) handleMultipartDownload(ctx context.Context, blockBlob *a if err != nil { errOnce.Do(func() { fsLog(fs, logger.LevelError, "multipart download error: %+v", err) - atomic.StoreInt32(&hasError, 1) + hasError.Store(true) poolError = fmt.Errorf("multipart download error: %w", err) poolCancel() }) @@ -971,7 +971,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read var blocks []string var wg sync.WaitGroup var errOnce sync.Once - var hasError int32 + var hasError atomic.Bool var poolError error poolCtx, poolCancel := context.WithCancel(ctx) @@ -999,7 +999,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read blocks = append(blocks, blockID) guard <- struct{}{} - if atomic.LoadInt32(&hasError) == 1 { + if hasError.Load() { fsLog(fs, logger.LevelError, "pool error, upload for part %v not started", part) pool.releaseBuffer(buf) break @@ -1023,7 +1023,7 @@ func (fs *AzureBlobFs) handleMultipartUpload(ctx context.Context, reader io.Read if err != nil { errOnce.Do(func() { fsLog(fs, logger.LevelDebug, "multipart upload error: %+v", err) - atomic.StoreInt32(&hasError, 1) + hasError.Store(true) poolError = fmt.Errorf("multipart upload error: %w", err) poolCancel() }) diff --git a/internal/vfs/s3fs.go b/internal/vfs/s3fs.go index 7a306c41..2d3b7eb0 100644 --- a/internal/vfs/s3fs.go +++ b/internal/vfs/s3fs.go @@ -835,7 +835,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int var completedParts []types.CompletedPart var partMutex sync.Mutex var wg sync.WaitGroup - var hasError int32 + var hasError atomic.Bool var errOnce sync.Once var copyError error var partNumber int32 @@ -854,7 +854,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int offset = end guard <- struct{}{} - if atomic.LoadInt32(&hasError) == 1 { + if hasError.Load() { fsLog(fs, logger.LevelDebug, "previous multipart copy error, copy for part %d not started", partNumber) break } @@ -880,7 +880,7 @@ func (fs *S3Fs) doMultipartCopy(source, target, contentType string, fileSize int if err != nil { errOnce.Do(func() { fsLog(fs, logger.LevelError, "unable to copy part number %d: %+v", partNum, err) - atomic.StoreInt32(&hasError, 1) + hasError.Store(true) copyError = fmt.Errorf("error copying part number %d: %w", partNum, err) opCancel() diff --git a/internal/webdavd/file.go b/internal/webdavd/file.go index bd619773..50af68b0 100644 --- a/internal/webdavd/file.go +++ b/internal/webdavd/file.go @@ -42,7 +42,7 @@ type webDavFile struct { info os.FileInfo startOffset int64 isFinished bool - readTryed int32 + readTryed atomic.Bool } func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter, pipeReader *pipeat.PipeReaderAt) *webDavFile { @@ -56,15 +56,16 @@ func newWebDavFile(baseTransfer *common.BaseTransfer, pipeWriter *vfs.PipeWriter } else if pipeReader != nil { reader = pipeReader } - return &webDavFile{ + f := &webDavFile{ BaseTransfer: baseTransfer, writer: writer, reader: reader, isFinished: false, startOffset: 0, info: nil, - readTryed: 0, } + f.readTryed.Store(false) + return f } type webDavFileInfo struct { @@ -124,7 +125,7 @@ func (f *webDavFile) Stat() (os.FileInfo, error) { f.Unlock() if f.GetType() == common.TransferUpload && errUpload == nil { info := &webDavFileInfo{ - FileInfo: vfs.NewFileInfo(f.GetFsPath(), false, atomic.LoadInt64(&f.BytesReceived), time.Unix(0, 0), false), + FileInfo: vfs.NewFileInfo(f.GetFsPath(), false, f.BytesReceived.Load(), time.Unix(0, 0), false), Fs: f.Fs, virtualPath: f.GetVirtualPath(), fsPath: f.GetFsPath(), @@ -149,10 +150,10 @@ func (f *webDavFile) Stat() (os.FileInfo, error) { // Read reads the contents to downloads. func (f *webDavFile) Read(p []byte) (n int, err error) { - if atomic.LoadInt32(&f.AbortTransfer) == 1 { + if f.AbortTransfer.Load() { return 0, errTransferAborted } - if atomic.LoadInt32(&f.readTryed) == 0 { + if !f.readTryed.Load() { if !f.Connection.User.HasPerm(dataprovider.PermDownload, path.Dir(f.GetVirtualPath())) { return 0, f.Connection.GetPermissionDeniedError() } @@ -171,7 +172,7 @@ func (f *webDavFile) Read(p []byte) (n int, err error) { f.Connection.Log(logger.LevelDebug, "download for file %#v denied by pre action: %v", f.GetVirtualPath(), err) return 0, f.Connection.GetPermissionDeniedError() } - atomic.StoreInt32(&f.readTryed, 1) + f.readTryed.Store(true) } f.Connection.UpdateLastActivity() @@ -198,7 +199,7 @@ func (f *webDavFile) Read(p []byte) (n int, err error) { } n, err = f.reader.Read(p) - atomic.AddInt64(&f.BytesSent, int64(n)) + f.BytesSent.Add(int64(n)) if err == nil { err = f.CheckRead() } @@ -212,14 +213,14 @@ func (f *webDavFile) Read(p []byte) (n int, err error) { // Write writes the uploaded contents. func (f *webDavFile) Write(p []byte) (n int, err error) { - if atomic.LoadInt32(&f.AbortTransfer) == 1 { + if f.AbortTransfer.Load() { return 0, errTransferAborted } f.Connection.UpdateLastActivity() n, err = f.writer.Write(p) - atomic.AddInt64(&f.BytesReceived, int64(n)) + f.BytesReceived.Add(int64(n)) if err == nil { err = f.CheckWrite() @@ -252,7 +253,7 @@ func (f *webDavFile) updateTransferQuotaOnSeek() { if transferQuota.HasSizeLimits() { go func(ulSize, dlSize int64, user dataprovider.User) { dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck - }(atomic.LoadInt64(&f.BytesReceived), atomic.LoadInt64(&f.BytesSent), f.Connection.User) + }(f.BytesReceived.Load(), f.BytesSent.Load(), f.Connection.User) } } @@ -270,7 +271,7 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) { return ret, err } if f.GetType() == common.TransferDownload { - readOffset := f.startOffset + atomic.LoadInt64(&f.BytesSent) + readOffset := f.startOffset + f.BytesSent.Load() if offset == 0 && readOffset == 0 { if whence == io.SeekStart { return 0, nil @@ -288,8 +289,8 @@ func (f *webDavFile) Seek(offset int64, whence int) (int64, error) { f.reader = nil } startByte := int64(0) - atomic.StoreInt64(&f.BytesReceived, 0) - atomic.StoreInt64(&f.BytesSent, 0) + f.BytesReceived.Store(0) + f.BytesSent.Store(0) f.updateTransferQuotaOnSeek() switch whence { @@ -369,7 +370,7 @@ func (f *webDavFile) setFinished() error { func (f *webDavFile) isTransfer() bool { if f.GetType() == common.TransferDownload { - return atomic.LoadInt32(&f.readTryed) > 0 + return f.readTryed.Load() } return true }