diff --git a/common/actions.go b/common/actions.go index f2ae56ea..014246e0 100644 --- a/common/actions.go +++ b/common/actions.go @@ -34,8 +34,29 @@ type ProtocolActions struct { Hook string `json:"hook" mapstructure:"hook"` } -// actionNotification defines a notification for a Protocol Action -type actionNotification struct { +var actionHandler ActionHandler = defaultActionHandler{} + +// InitializeActionHandler lets the user choose an action handler implementation. +// +// Do NOT call this function after application initialization. +func InitializeActionHandler(handler ActionHandler) { + actionHandler = handler +} + +// SSHCommandActionNotification executes the defined action for the specified SSH command. +func SSHCommandActionNotification(user *dataprovider.User, filePath, target, sshCmd string, err error) { + notification := newActionNotification(user, operationSSHCmd, filePath, target, sshCmd, ProtocolSSH, 0, err) + + go actionHandler.Handle(notification) // nolint:errcheck +} + +// ActionHandler handles a notification for a Protocol Action. +type ActionHandler interface { + Handle(notification ActionNotification) error +} + +// ActionNotification defines a notification for a Protocol Action. +type ActionNotification struct { Action string `json:"action"` Username string `json:"username"` Path string `json:"path"` @@ -49,29 +70,29 @@ type actionNotification struct { Protocol string `json:"protocol"` } -// SSHCommandActionNotification executes the defined action for the specified SSH command -func SSHCommandActionNotification(user *dataprovider.User, filePath, target, sshCmd string, err error) { - action := newActionNotification(user, operationSSHCmd, filePath, target, sshCmd, ProtocolSSH, 0, err) - go action.execute() //nolint:errcheck -} - -func newActionNotification(user *dataprovider.User, operation, filePath, target, sshCmd, protocol string, fileSize int64, - err error) actionNotification { - bucket := "" - endpoint := "" +func newActionNotification( + user *dataprovider.User, + operation, filePath, target, sshCmd, protocol string, + fileSize int64, + err error, +) ActionNotification { + var bucket, endpoint string status := 1 + if user.FsConfig.Provider == dataprovider.S3FilesystemProvider { bucket = user.FsConfig.S3Config.Bucket endpoint = user.FsConfig.S3Config.Endpoint } else if user.FsConfig.Provider == dataprovider.GCSFilesystemProvider { bucket = user.FsConfig.GCSConfig.Bucket } + if err == ErrQuotaExceeded { status = 2 } else if err != nil { status = 0 } - return actionNotification{ + + return ActionNotification{ Action: operation, Username: user.Username, Path: filePath, @@ -86,72 +107,92 @@ func newActionNotification(user *dataprovider.User, operation, filePath, target, } } -func (a *actionNotification) asJSON() []byte { - res, _ := json.Marshal(a) - return res -} +type defaultActionHandler struct{} -func (a *actionNotification) asEnvVars() []string { - return []string{fmt.Sprintf("SFTPGO_ACTION=%v", a.Action), - fmt.Sprintf("SFTPGO_ACTION_USERNAME=%v", a.Username), - fmt.Sprintf("SFTPGO_ACTION_PATH=%v", a.Path), - fmt.Sprintf("SFTPGO_ACTION_TARGET=%v", a.TargetPath), - fmt.Sprintf("SFTPGO_ACTION_SSH_CMD=%v", a.SSHCmd), - fmt.Sprintf("SFTPGO_ACTION_FILE_SIZE=%v", a.FileSize), - fmt.Sprintf("SFTPGO_ACTION_FS_PROVIDER=%v", a.FsProvider), - fmt.Sprintf("SFTPGO_ACTION_BUCKET=%v", a.Bucket), - fmt.Sprintf("SFTPGO_ACTION_ENDPOINT=%v", a.Endpoint), - fmt.Sprintf("SFTPGO_ACTION_STATUS=%v", a.Status), - fmt.Sprintf("SFTPGO_ACTION_PROTOCOL=%v", a.Protocol), +func (h defaultActionHandler) Handle(notification ActionNotification) error { + if !utils.IsStringInSlice(notification.Action, Config.Actions.ExecuteOn) { + return errUnconfiguredAction } + + if Config.Actions.Hook == "" { + logger.Warn(notification.Protocol, "", "Unable to send notification, no hook is defined") + + return errNoHook + } + + if strings.HasPrefix(Config.Actions.Hook, "http") { + return h.handleHTTP(notification) + } + + return h.handleCommand(notification) } -func (a *actionNotification) executeNotificationCommand() error { - if !filepath.IsAbs(Config.Actions.Hook) { - err := fmt.Errorf("invalid notification command %#v", Config.Actions.Hook) - logger.Warn(a.Protocol, "", "unable to execute notification command: %v", err) +func (h defaultActionHandler) handleHTTP(notification ActionNotification) error { + u, err := url.Parse(Config.Actions.Hook) + if err != nil { + logger.Warn(notification.Protocol, "", "Invalid hook %#v for operation %#v: %v", Config.Actions.Hook, notification.Action, err) + return err } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - cmd := exec.CommandContext(ctx, Config.Actions.Hook, a.Action, a.Username, a.Path, a.TargetPath, a.SSHCmd) - cmd.Env = append(os.Environ(), a.asEnvVars()...) + startTime := time.Now() - err := cmd.Run() - logger.Debug(a.Protocol, "", "executed command %#v with arguments: %#v, %#v, %#v, %#v, %#v, elapsed: %v, error: %v", - Config.Actions.Hook, a.Action, a.Username, a.Path, a.TargetPath, a.SSHCmd, time.Since(startTime), err) + respCode := 0 + + httpClient := httpclient.GetHTTPClient() + + var b bytes.Buffer + _ = json.NewEncoder(&b).Encode(notification) + + resp, err := httpClient.Post(u.String(), "application/json", &b) + if err == nil { + respCode = resp.StatusCode + resp.Body.Close() + + if respCode != http.StatusOK { + err = errUnexpectedHTTResponse + } + } + + logger.Debug(notification.Protocol, "", "notified operation %#v to URL: %v status code: %v, elapsed: %v err: %v", notification.Action, u.String(), respCode, time.Since(startTime), err) + return err } -func (a *actionNotification) execute() error { - if !utils.IsStringInSlice(a.Action, Config.Actions.ExecuteOn) { - return errUnconfiguredAction - } - if len(Config.Actions.Hook) == 0 { - logger.Warn(a.Protocol, "", "Unable to send notification, no hook is defined") - return errNoHook - } - if strings.HasPrefix(Config.Actions.Hook, "http") { - var url *url.URL - url, err := url.Parse(Config.Actions.Hook) - if err != nil { - logger.Warn(a.Protocol, "", "Invalid hook %#v for operation %#v: %v", Config.Actions.Hook, a.Action, err) - return err - } - startTime := time.Now() - httpClient := httpclient.GetHTTPClient() - resp, err := httpClient.Post(url.String(), "application/json", bytes.NewBuffer(a.asJSON())) - respCode := 0 - if err == nil { - respCode = resp.StatusCode - resp.Body.Close() - if respCode != http.StatusOK { - err = errUnexpectedHTTResponse - } - } - logger.Debug(a.Protocol, "", "notified operation %#v to URL: %v status code: %v, elapsed: %v err: %v", - a.Action, url.String(), respCode, time.Since(startTime), err) +func (h defaultActionHandler) handleCommand(notification ActionNotification) error { + if !filepath.IsAbs(Config.Actions.Hook) { + err := fmt.Errorf("invalid notification command %#v", Config.Actions.Hook) + logger.Warn(notification.Protocol, "", "unable to execute notification command: %v", err) + return err } - return a.executeNotificationCommand() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, Config.Actions.Hook, notification.Action, notification.Username, notification.Path, notification.TargetPath, notification.SSHCmd) + cmd.Env = append(os.Environ(), notificationAsEnvVars(notification)...) + + startTime := time.Now() + err := cmd.Run() + + logger.Debug(notification.Protocol, "", "executed command %#v with arguments: %#v, %#v, %#v, %#v, %#v, elapsed: %v, error: %v", + Config.Actions.Hook, notification.Action, notification.Username, notification.Path, notification.TargetPath, notification.SSHCmd, time.Since(startTime), err) + + return err +} + +func notificationAsEnvVars(notification ActionNotification) []string { + return []string{ + fmt.Sprintf("SFTPGO_ACTION=%v", notification.Action), + fmt.Sprintf("SFTPGO_ACTION_USERNAME=%v", notification.Username), + fmt.Sprintf("SFTPGO_ACTION_PATH=%v", notification.Path), + fmt.Sprintf("SFTPGO_ACTION_TARGET=%v", notification.TargetPath), + fmt.Sprintf("SFTPGO_ACTION_SSH_CMD=%v", notification.SSHCmd), + fmt.Sprintf("SFTPGO_ACTION_FILE_SIZE=%v", notification.FileSize), + fmt.Sprintf("SFTPGO_ACTION_FS_PROVIDER=%v", notification.FsProvider), + fmt.Sprintf("SFTPGO_ACTION_BUCKET=%v", notification.Bucket), + fmt.Sprintf("SFTPGO_ACTION_ENDPOINT=%v", notification.Endpoint), + fmt.Sprintf("SFTPGO_ACTION_STATUS=%v", notification.Status), + fmt.Sprintf("SFTPGO_ACTION_PROTOCOL=%v", notification.Protocol), + } } diff --git a/common/actions_test.go b/common/actions_test.go index be9c63fb..e65c0e87 100644 --- a/common/actions_test.go +++ b/common/actions_test.go @@ -58,15 +58,15 @@ func TestActionHTTP(t *testing.T) { Username: "username", } a := newActionNotification(user, operationDownload, "path", "target", "", ProtocolSFTP, 123, nil) - err := a.execute() + err := actionHandler.Handle(a) assert.NoError(t, err) Config.Actions.Hook = "http://invalid:1234" - err = a.execute() + err = actionHandler.Handle(a) assert.Error(t, err) Config.Actions.Hook = fmt.Sprintf("http://%v/404", httpAddr) - err = a.execute() + err = actionHandler.Handle(a) if assert.Error(t, err) { assert.EqualError(t, err, errUnexpectedHTTResponse.Error()) } @@ -91,7 +91,7 @@ func TestActionCMD(t *testing.T) { Username: "username", } a := newActionNotification(user, operationDownload, "path", "target", "", ProtocolSFTP, 123, nil) - err = a.execute() + err = actionHandler.Handle(a) assert.NoError(t, err) SSHCommandActionNotification(user, "path", "target", "sha1sum", nil) @@ -115,26 +115,26 @@ func TestWrongActions(t *testing.T) { } a := newActionNotification(user, operationUpload, "", "", "", ProtocolSFTP, 123, nil) - err := a.execute() + err := actionHandler.Handle(a) assert.Error(t, err, "action with bad command must fail") a.Action = operationDelete - err = a.execute() + err = actionHandler.Handle(a) assert.EqualError(t, err, errUnconfiguredAction.Error()) Config.Actions.Hook = "http://foo\x7f.com/" a.Action = operationUpload - err = a.execute() + err = actionHandler.Handle(a) assert.Error(t, err, "action with bad url must fail") Config.Actions.Hook = "" - err = a.execute() + err = actionHandler.Handle(a) if assert.Error(t, err) { assert.EqualError(t, err, errNoHook.Error()) } Config.Actions.Hook = "relative path" - err = a.execute() + err = actionHandler.Handle(a) if assert.Error(t, err) { assert.EqualError(t, err, fmt.Sprintf("invalid notification command %#v", Config.Actions.Hook)) } diff --git a/common/connection.go b/common/connection.go index e0d50b5f..12c2da39 100644 --- a/common/connection.go +++ b/common/connection.go @@ -249,7 +249,7 @@ func (c *BaseConnection) RemoveFile(fsPath, virtualPath string, info os.FileInfo } size := info.Size() action := newActionNotification(&c.User, operationPreDelete, fsPath, "", "", c.protocol, size, nil) - actionErr := action.execute() + actionErr := actionHandler.Handle(action) if actionErr == nil { c.Log(logger.LevelDebug, "remove for path %#v handled by pre-delete action", fsPath) } else { @@ -273,7 +273,7 @@ func (c *BaseConnection) RemoveFile(fsPath, virtualPath string, info os.FileInfo } if actionErr != nil { action := newActionNotification(&c.User, operationDelete, fsPath, "", "", c.protocol, size, nil) - go action.execute() //nolint:errcheck + go actionHandler.Handle(action) // nolint:errcheck } return nil } @@ -392,7 +392,7 @@ func (c *BaseConnection) Rename(fsSourcePath, fsTargetPath, virtualSourcePath, v "", "", "", -1) action := newActionNotification(&c.User, operationRename, fsSourcePath, fsTargetPath, "", c.protocol, 0, nil) // the returned error is used in test cases only, we already log the error inside action.execute - go action.execute() //nolint:errcheck + go actionHandler.Handle(action) // nolint:errcheck return nil } diff --git a/common/transfer.go b/common/transfer.go index 2c5f5c2b..1b172cc4 100644 --- a/common/transfer.go +++ b/common/transfer.go @@ -220,7 +220,7 @@ func (t *BaseTransfer) Close() error { t.Connection.ID, t.Connection.protocol) action := newActionNotification(&t.Connection.User, operationDownload, t.fsPath, "", "", t.Connection.protocol, atomic.LoadInt64(&t.BytesSent), t.ErrTransfer) - go action.execute() //nolint:errcheck + go actionHandler.Handle(action) //nolint:errcheck } else { fileSize := atomic.LoadInt64(&t.BytesReceived) + t.MinWriteOffset info, err := t.Fs.Stat(t.fsPath) @@ -233,7 +233,7 @@ func (t *BaseTransfer) Close() error { t.Connection.ID, t.Connection.protocol) action := newActionNotification(&t.Connection.User, operationUpload, t.fsPath, "", "", t.Connection.protocol, fileSize, t.ErrTransfer) - go action.execute() //nolint:errcheck + go actionHandler.Handle(action) //nolint:errcheck } if t.ErrTransfer != nil { t.Connection.Log(logger.LevelWarn, "transfer error: %v, path: %#v", t.ErrTransfer, t.fsPath)