diff --git a/sftpd/server.go b/sftpd/server.go index a11b9fc6..38b7d5aa 100644 --- a/sftpd/server.go +++ b/sftpd/server.go @@ -360,6 +360,9 @@ func canAcceptConnection(ip string) bool { logger.Log(logger.LevelDebug, common.ProtocolSSH, "", "connection refused, configured limit reached") return false } + if err := common.Config.ExecutePostConnectHook(ip, common.ProtocolSSH); err != nil { + return false + } return true } @@ -378,10 +381,7 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve // Before beginning a handshake must be performed on the incoming net.Conn // we'll set a Deadline for handshake to complete, the default is 2 minutes as OpenSSH conn.SetDeadline(time.Now().Add(handshakeTimeout)) //nolint:errcheck - if err := common.Config.ExecutePostConnectHook(ipAddr, common.ProtocolSSH); err != nil { - conn.Close() - return - } + sconn, chans, reqs, err := ssh.NewServerConn(conn, config) if err != nil { logger.Debug(logSender, "", "failed to accept an incoming connection: %v", err) @@ -471,7 +471,9 @@ func (c *Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Serve ok = processSSHCommand(req.Payload, &connection, c.EnabledSSHCommands) } } - req.Reply(ok, nil) //nolint:errcheck + if req.WantReply { + req.Reply(ok, nil) //nolint:errcheck + } } }(requests, channelCounter) }