mirror of
https://github.com/drakkan/sftpgo.git
synced 2025-12-07 14:50:55 +03:00
connections: close the ssh channel before the network connection
This way if pkg/sftp is stuck in Serve() method should be unlocked.
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/drakkan/sftpgo/utils"
|
"github.com/drakkan/sftpgo/utils"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/drakkan/sftpgo/dataprovider"
|
"github.com/drakkan/sftpgo/dataprovider"
|
||||||
"github.com/drakkan/sftpgo/logger"
|
"github.com/drakkan/sftpgo/logger"
|
||||||
@@ -37,6 +38,7 @@ type Connection struct {
|
|||||||
protocol string
|
protocol string
|
||||||
lock *sync.Mutex
|
lock *sync.Mutex
|
||||||
netConn net.Conn
|
netConn net.Conn
|
||||||
|
channel ssh.Channel
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log outputs a log entry to the configured logger
|
// Log outputs a log entry to the configured logger
|
||||||
@@ -580,6 +582,10 @@ func (c Connection) createMissingDirs(filePath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c Connection) close() error {
|
func (c Connection) close() error {
|
||||||
|
if c.channel != nil {
|
||||||
|
err := c.channel.Close()
|
||||||
|
c.Log(logger.LevelInfo, logSender, "channel close, err: %v", err)
|
||||||
|
}
|
||||||
return c.netConn.Close()
|
return c.netConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -252,7 +252,6 @@ func TestSCPGetNonExistingDirContent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSCPParseUploadMessage(t *testing.T) {
|
func TestSCPParseUploadMessage(t *testing.T) {
|
||||||
connection := Connection{}
|
|
||||||
buf := make([]byte, 65535)
|
buf := make([]byte, 65535)
|
||||||
stdErrBuf := make([]byte, 65535)
|
stdErrBuf := make([]byte, 65535)
|
||||||
mockSSHChannel := MockChannel{
|
mockSSHChannel := MockChannel{
|
||||||
@@ -260,10 +259,12 @@ func TestSCPParseUploadMessage(t *testing.T) {
|
|||||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||||
ReadError: nil,
|
ReadError: nil,
|
||||||
}
|
}
|
||||||
|
connection := Connection{
|
||||||
|
channel: &mockSSHChannel,
|
||||||
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
args: []string{"-t", "/tmp"},
|
args: []string{"-t", "/tmp"},
|
||||||
channel: &mockSSHChannel,
|
|
||||||
}
|
}
|
||||||
_, _, err := scpCommand.parseUploadMessage("invalid")
|
_, _, err := scpCommand.parseUploadMessage("invalid")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -284,7 +285,6 @@ func TestSCPParseUploadMessage(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSCPProtocolMessages(t *testing.T) {
|
func TestSCPProtocolMessages(t *testing.T) {
|
||||||
connection := Connection{}
|
|
||||||
buf := make([]byte, 65535)
|
buf := make([]byte, 65535)
|
||||||
stdErrBuf := make([]byte, 65535)
|
stdErrBuf := make([]byte, 65535)
|
||||||
readErr := fmt.Errorf("test read error")
|
readErr := fmt.Errorf("test read error")
|
||||||
@@ -295,10 +295,12 @@ func TestSCPProtocolMessages(t *testing.T) {
|
|||||||
ReadError: readErr,
|
ReadError: readErr,
|
||||||
WriteError: writeErr,
|
WriteError: writeErr,
|
||||||
}
|
}
|
||||||
|
connection := Connection{
|
||||||
|
channel: &mockSSHChannel,
|
||||||
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
args: []string{"-t", "/tmp"},
|
args: []string{"-t", "/tmp"},
|
||||||
channel: &mockSSHChannel,
|
|
||||||
}
|
}
|
||||||
_, err := scpCommand.readProtocolMessage()
|
_, err := scpCommand.readProtocolMessage()
|
||||||
if err == nil || err != readErr {
|
if err == nil || err != readErr {
|
||||||
@@ -322,7 +324,7 @@ func TestSCPProtocolMessages(t *testing.T) {
|
|||||||
ReadError: nil,
|
ReadError: nil,
|
||||||
WriteError: writeErr,
|
WriteError: writeErr,
|
||||||
}
|
}
|
||||||
scpCommand.channel = &mockSSHChannel
|
scpCommand.connection.channel = &mockSSHChannel
|
||||||
_, err = scpCommand.getNextUploadProtocolMessage()
|
_, err = scpCommand.getNextUploadProtocolMessage()
|
||||||
if err == nil || err != writeErr {
|
if err == nil || err != writeErr {
|
||||||
t.Errorf("read next upload protocol message must fail, we are sending a fake write error")
|
t.Errorf("read next upload protocol message must fail, we are sending a fake write error")
|
||||||
@@ -337,7 +339,7 @@ func TestSCPProtocolMessages(t *testing.T) {
|
|||||||
ReadError: nil,
|
ReadError: nil,
|
||||||
WriteError: nil,
|
WriteError: nil,
|
||||||
}
|
}
|
||||||
scpCommand.channel = &mockSSHChannel
|
scpCommand.connection.channel = &mockSSHChannel
|
||||||
err = scpCommand.readConfirmationMessage()
|
err = scpCommand.readConfirmationMessage()
|
||||||
if err == nil || err.Error() != protocolErrorMsg {
|
if err == nil || err.Error() != protocolErrorMsg {
|
||||||
t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err)
|
t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err)
|
||||||
@@ -345,7 +347,6 @@ func TestSCPProtocolMessages(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSCPTestDownloadProtocolMessages(t *testing.T) {
|
func TestSCPTestDownloadProtocolMessages(t *testing.T) {
|
||||||
connection := Connection{}
|
|
||||||
buf := make([]byte, 65535)
|
buf := make([]byte, 65535)
|
||||||
stdErrBuf := make([]byte, 65535)
|
stdErrBuf := make([]byte, 65535)
|
||||||
readErr := fmt.Errorf("test read error")
|
readErr := fmt.Errorf("test read error")
|
||||||
@@ -356,10 +357,12 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
|
|||||||
ReadError: readErr,
|
ReadError: readErr,
|
||||||
WriteError: writeErr,
|
WriteError: writeErr,
|
||||||
}
|
}
|
||||||
|
connection := Connection{
|
||||||
|
channel: &mockSSHChannel,
|
||||||
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
args: []string{"-f", "-p", "/tmp"},
|
args: []string{"-f", "-p", "/tmp"},
|
||||||
channel: &mockSSHChannel,
|
|
||||||
}
|
}
|
||||||
path := "testDir"
|
path := "testDir"
|
||||||
os.Mkdir(path, 0777)
|
os.Mkdir(path, 0777)
|
||||||
@@ -388,7 +391,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
|
|||||||
WriteError: writeErr,
|
WriteError: writeErr,
|
||||||
}
|
}
|
||||||
scpCommand.args = []string{"-f", "/tmp"}
|
scpCommand.args = []string{"-f", "/tmp"}
|
||||||
scpCommand.channel = &mockSSHChannel
|
scpCommand.connection.channel = &mockSSHChannel
|
||||||
err = scpCommand.sendDownloadProtocolMessages(path, stat)
|
err = scpCommand.sendDownloadProtocolMessages(path, stat)
|
||||||
if err != writeErr {
|
if err != writeErr {
|
||||||
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
|
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
|
||||||
@@ -400,7 +403,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
|
|||||||
ReadError: readErr,
|
ReadError: readErr,
|
||||||
WriteError: nil,
|
WriteError: nil,
|
||||||
}
|
}
|
||||||
scpCommand.channel = &mockSSHChannel
|
scpCommand.connection.channel = &mockSSHChannel
|
||||||
err = scpCommand.sendDownloadProtocolMessages(path, stat)
|
err = scpCommand.sendDownloadProtocolMessages(path, stat)
|
||||||
if err != readErr {
|
if err != readErr {
|
||||||
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
|
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
|
||||||
@@ -409,7 +412,6 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSCPCommandHandleErrors(t *testing.T) {
|
func TestSCPCommandHandleErrors(t *testing.T) {
|
||||||
connection := Connection{}
|
|
||||||
buf := make([]byte, 65535)
|
buf := make([]byte, 65535)
|
||||||
stdErrBuf := make([]byte, 65535)
|
stdErrBuf := make([]byte, 65535)
|
||||||
readErr := fmt.Errorf("test read error")
|
readErr := fmt.Errorf("test read error")
|
||||||
@@ -420,10 +422,12 @@ func TestSCPCommandHandleErrors(t *testing.T) {
|
|||||||
ReadError: readErr,
|
ReadError: readErr,
|
||||||
WriteError: writeErr,
|
WriteError: writeErr,
|
||||||
}
|
}
|
||||||
|
connection := Connection{
|
||||||
|
channel: &mockSSHChannel,
|
||||||
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
args: []string{"-f", "/tmp"},
|
args: []string{"-f", "/tmp"},
|
||||||
channel: &mockSSHChannel,
|
|
||||||
}
|
}
|
||||||
err := scpCommand.handle()
|
err := scpCommand.handle()
|
||||||
if err == nil || err != readErr {
|
if err == nil || err != readErr {
|
||||||
@@ -437,7 +441,6 @@ func TestSCPCommandHandleErrors(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSCPRecursiveDownloadErrors(t *testing.T) {
|
func TestSCPRecursiveDownloadErrors(t *testing.T) {
|
||||||
connection := Connection{}
|
|
||||||
buf := make([]byte, 65535)
|
buf := make([]byte, 65535)
|
||||||
stdErrBuf := make([]byte, 65535)
|
stdErrBuf := make([]byte, 65535)
|
||||||
readErr := fmt.Errorf("test read error")
|
readErr := fmt.Errorf("test read error")
|
||||||
@@ -448,10 +451,12 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
|
|||||||
ReadError: readErr,
|
ReadError: readErr,
|
||||||
WriteError: writeErr,
|
WriteError: writeErr,
|
||||||
}
|
}
|
||||||
|
connection := Connection{
|
||||||
|
channel: &mockSSHChannel,
|
||||||
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
args: []string{"-r", "-f", "/tmp"},
|
args: []string{"-r", "-f", "/tmp"},
|
||||||
channel: &mockSSHChannel,
|
|
||||||
}
|
}
|
||||||
path := "testDir"
|
path := "testDir"
|
||||||
os.Mkdir(path, 0777)
|
os.Mkdir(path, 0777)
|
||||||
@@ -466,7 +471,7 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
|
|||||||
ReadError: nil,
|
ReadError: nil,
|
||||||
WriteError: nil,
|
WriteError: nil,
|
||||||
}
|
}
|
||||||
scpCommand.channel = &mockSSHChannel
|
scpCommand.connection.channel = &mockSSHChannel
|
||||||
err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
|
err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("recursive upload download must fail for a non existing dir")
|
t.Errorf("recursive upload download must fail for a non existing dir")
|
||||||
@@ -476,7 +481,6 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSCPRecursiveUploadErrors(t *testing.T) {
|
func TestSCPRecursiveUploadErrors(t *testing.T) {
|
||||||
connection := Connection{}
|
|
||||||
buf := make([]byte, 65535)
|
buf := make([]byte, 65535)
|
||||||
stdErrBuf := make([]byte, 65535)
|
stdErrBuf := make([]byte, 65535)
|
||||||
readErr := fmt.Errorf("test read error")
|
readErr := fmt.Errorf("test read error")
|
||||||
@@ -487,10 +491,12 @@ func TestSCPRecursiveUploadErrors(t *testing.T) {
|
|||||||
ReadError: readErr,
|
ReadError: readErr,
|
||||||
WriteError: writeErr,
|
WriteError: writeErr,
|
||||||
}
|
}
|
||||||
|
connection := Connection{
|
||||||
|
channel: &mockSSHChannel,
|
||||||
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
args: []string{"-r", "-t", "/tmp"},
|
args: []string{"-r", "-t", "/tmp"},
|
||||||
channel: &mockSSHChannel,
|
|
||||||
}
|
}
|
||||||
err := scpCommand.handleRecursiveUpload()
|
err := scpCommand.handleRecursiveUpload()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -502,7 +508,7 @@ func TestSCPRecursiveUploadErrors(t *testing.T) {
|
|||||||
ReadError: readErr,
|
ReadError: readErr,
|
||||||
WriteError: nil,
|
WriteError: nil,
|
||||||
}
|
}
|
||||||
scpCommand.channel = &mockSSHChannel
|
scpCommand.connection.channel = &mockSSHChannel
|
||||||
err = scpCommand.handleRecursiveUpload()
|
err = scpCommand.handleRecursiveUpload()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("recursive upload must fail, we send a fake error message")
|
t.Errorf("recursive upload must fail, we send a fake error message")
|
||||||
@@ -516,19 +522,19 @@ func TestSCPCreateDirs(t *testing.T) {
|
|||||||
u.HomeDir = "home_rel_path"
|
u.HomeDir = "home_rel_path"
|
||||||
u.Username = "test"
|
u.Username = "test"
|
||||||
u.Permissions = []string{"*"}
|
u.Permissions = []string{"*"}
|
||||||
connection := Connection{
|
|
||||||
User: u,
|
|
||||||
}
|
|
||||||
mockSSHChannel := MockChannel{
|
mockSSHChannel := MockChannel{
|
||||||
Buffer: bytes.NewBuffer(buf),
|
Buffer: bytes.NewBuffer(buf),
|
||||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||||
ReadError: nil,
|
ReadError: nil,
|
||||||
WriteError: nil,
|
WriteError: nil,
|
||||||
}
|
}
|
||||||
|
connection := Connection{
|
||||||
|
User: u,
|
||||||
|
channel: &mockSSHChannel,
|
||||||
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
args: []string{"-r", "-t", "/tmp"},
|
args: []string{"-r", "-t", "/tmp"},
|
||||||
channel: &mockSSHChannel,
|
|
||||||
}
|
}
|
||||||
err := scpCommand.handleCreateDir("invalid_dir")
|
err := scpCommand.handleCreateDir("invalid_dir")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -542,7 +548,6 @@ func TestSCPDownloadFileData(t *testing.T) {
|
|||||||
readErr := fmt.Errorf("test read error")
|
readErr := fmt.Errorf("test read error")
|
||||||
writeErr := fmt.Errorf("test write error")
|
writeErr := fmt.Errorf("test write error")
|
||||||
stdErrBuf := make([]byte, 65535)
|
stdErrBuf := make([]byte, 65535)
|
||||||
connection := Connection{}
|
|
||||||
mockSSHChannelReadErr := MockChannel{
|
mockSSHChannelReadErr := MockChannel{
|
||||||
Buffer: bytes.NewBuffer(buf),
|
Buffer: bytes.NewBuffer(buf),
|
||||||
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
|
||||||
@@ -555,10 +560,12 @@ func TestSCPDownloadFileData(t *testing.T) {
|
|||||||
ReadError: nil,
|
ReadError: nil,
|
||||||
WriteError: writeErr,
|
WriteError: writeErr,
|
||||||
}
|
}
|
||||||
|
connection := Connection{
|
||||||
|
channel: &mockSSHChannelReadErr,
|
||||||
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
args: []string{"-r", "-f", "/tmp"},
|
args: []string{"-r", "-f", "/tmp"},
|
||||||
channel: &mockSSHChannelReadErr,
|
|
||||||
}
|
}
|
||||||
ioutil.WriteFile(testfile, []byte("test"), 0666)
|
ioutil.WriteFile(testfile, []byte("test"), 0666)
|
||||||
stat, _ := os.Stat(testfile)
|
stat, _ := os.Stat(testfile)
|
||||||
@@ -566,7 +573,7 @@ func TestSCPDownloadFileData(t *testing.T) {
|
|||||||
if err != readErr {
|
if err != readErr {
|
||||||
t.Errorf("send download file data must fail with the expected error: %v", err)
|
t.Errorf("send download file data must fail with the expected error: %v", err)
|
||||||
}
|
}
|
||||||
scpCommand.channel = &mockSSHChannelWriteErr
|
scpCommand.connection.channel = &mockSSHChannelWriteErr
|
||||||
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
|
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
|
||||||
if err != writeErr {
|
if err != writeErr {
|
||||||
t.Errorf("send download file data must fail with the expected error: %v", err)
|
t.Errorf("send download file data must fail with the expected error: %v", err)
|
||||||
@@ -576,7 +583,7 @@ func TestSCPDownloadFileData(t *testing.T) {
|
|||||||
if err != writeErr {
|
if err != writeErr {
|
||||||
t.Errorf("send download file data must fail with the expected error: %v", err)
|
t.Errorf("send download file data must fail with the expected error: %v", err)
|
||||||
}
|
}
|
||||||
scpCommand.channel = &mockSSHChannelReadErr
|
scpCommand.connection.channel = &mockSSHChannelReadErr
|
||||||
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
|
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
|
||||||
if err != readErr {
|
if err != readErr {
|
||||||
t.Errorf("send download file data must fail with the expected error: %v", err)
|
t.Errorf("send download file data must fail with the expected error: %v", err)
|
||||||
@@ -586,12 +593,6 @@ func TestSCPDownloadFileData(t *testing.T) {
|
|||||||
|
|
||||||
func TestSCPUploadFiledata(t *testing.T) {
|
func TestSCPUploadFiledata(t *testing.T) {
|
||||||
testfile := "testfile"
|
testfile := "testfile"
|
||||||
connection := Connection{
|
|
||||||
User: dataprovider.User{
|
|
||||||
Username: "testuser",
|
|
||||||
},
|
|
||||||
protocol: protocolSCP,
|
|
||||||
}
|
|
||||||
buf := make([]byte, 65535)
|
buf := make([]byte, 65535)
|
||||||
stdErrBuf := make([]byte, 65535)
|
stdErrBuf := make([]byte, 65535)
|
||||||
readErr := fmt.Errorf("test read error")
|
readErr := fmt.Errorf("test read error")
|
||||||
@@ -602,10 +603,16 @@ func TestSCPUploadFiledata(t *testing.T) {
|
|||||||
ReadError: readErr,
|
ReadError: readErr,
|
||||||
WriteError: writeErr,
|
WriteError: writeErr,
|
||||||
}
|
}
|
||||||
|
connection := Connection{
|
||||||
|
User: dataprovider.User{
|
||||||
|
Username: "testuser",
|
||||||
|
},
|
||||||
|
protocol: protocolSCP,
|
||||||
|
channel: &mockSSHChannel,
|
||||||
|
}
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
args: []string{"-r", "-t", "/tmp"},
|
args: []string{"-r", "-t", "/tmp"},
|
||||||
channel: &mockSSHChannel,
|
|
||||||
}
|
}
|
||||||
file, _ := os.Create(testfile)
|
file, _ := os.Create(testfile)
|
||||||
transfer := Transfer{
|
transfer := Transfer{
|
||||||
@@ -634,7 +641,7 @@ func TestSCPUploadFiledata(t *testing.T) {
|
|||||||
ReadError: readErr,
|
ReadError: readErr,
|
||||||
WriteError: nil,
|
WriteError: nil,
|
||||||
}
|
}
|
||||||
scpCommand.channel = &mockSSHChannel
|
scpCommand.connection.channel = &mockSSHChannel
|
||||||
file, _ = os.Create(testfile)
|
file, _ = os.Create(testfile)
|
||||||
transfer.file = file
|
transfer.file = file
|
||||||
addTransfer(&transfer)
|
addTransfer(&transfer)
|
||||||
@@ -651,7 +658,7 @@ func TestSCPUploadFiledata(t *testing.T) {
|
|||||||
ReadError: nil,
|
ReadError: nil,
|
||||||
WriteError: nil,
|
WriteError: nil,
|
||||||
}
|
}
|
||||||
scpCommand.channel = &mockSSHChannel
|
scpCommand.connection.channel = &mockSSHChannel
|
||||||
file, _ = os.Create(testfile)
|
file, _ = os.Create(testfile)
|
||||||
transfer.file = file
|
transfer.file = file
|
||||||
addTransfer(&transfer)
|
addTransfer(&transfer)
|
||||||
|
|||||||
37
sftpd/scp.go
37
sftpd/scp.go
@@ -35,7 +35,6 @@ type exitStatusMsg struct {
|
|||||||
type scpCommand struct {
|
type scpCommand struct {
|
||||||
connection Connection
|
connection Connection
|
||||||
args []string
|
args []string
|
||||||
channel ssh.Channel
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *scpCommand) handle() error {
|
func (c *scpCommand) handle() error {
|
||||||
@@ -160,7 +159,7 @@ func (c *scpCommand) getUploadFileData(sizeToRead int64, transfer *Transfer) err
|
|||||||
remaining := sizeToRead
|
remaining := sizeToRead
|
||||||
buf := make([]byte, int64(math.Min(32768, float64(sizeToRead))))
|
buf := make([]byte, int64(math.Min(32768, float64(sizeToRead))))
|
||||||
for {
|
for {
|
||||||
n, err := c.channel.Read(buf)
|
n, err := c.connection.channel.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.sendErrorMessage(err.Error())
|
c.sendErrorMessage(err.Error())
|
||||||
transfer.TransferError(err)
|
transfer.TransferError(err)
|
||||||
@@ -403,7 +402,7 @@ func (c *scpCommand) sendDownloadFileData(filePath string, stat os.FileInfo, tra
|
|||||||
n, err := transfer.ReadAt(buf, readed)
|
n, err := transfer.ReadAt(buf, readed)
|
||||||
if err == nil || err == io.EOF {
|
if err == nil || err == io.EOF {
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
_, err = c.channel.Write(buf[:n])
|
_, err = c.connection.channel.Write(buf[:n])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
readed += int64(n)
|
readed += int64(n)
|
||||||
@@ -517,15 +516,15 @@ func (c *scpCommand) isRecursive() bool {
|
|||||||
func (c *scpCommand) readConfirmationMessage() error {
|
func (c *scpCommand) readConfirmationMessage() error {
|
||||||
var msg strings.Builder
|
var msg strings.Builder
|
||||||
buf := make([]byte, 1)
|
buf := make([]byte, 1)
|
||||||
n, err := c.channel.Read(buf)
|
n, err := c.connection.channel.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.channel.Close()
|
c.connection.channel.Close()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) {
|
if n == 1 && (buf[0] == warnMsg[0] || buf[0] == errMsg[0]) {
|
||||||
isError := buf[0] == errMsg[0]
|
isError := buf[0] == errMsg[0]
|
||||||
for {
|
for {
|
||||||
n, err = c.channel.Read(buf)
|
n, err = c.connection.channel.Read(buf)
|
||||||
readed := buf[:n]
|
readed := buf[:n]
|
||||||
if err != nil || (n == 1 && readed[0] == newLine[0]) {
|
if err != nil || (n == 1 && readed[0] == newLine[0]) {
|
||||||
break
|
break
|
||||||
@@ -536,7 +535,7 @@ func (c *scpCommand) readConfirmationMessage() error {
|
|||||||
}
|
}
|
||||||
c.connection.Log(logger.LevelInfo, logSenderSCP, "scp error message received: %v is error: %v", msg.String(), isError)
|
c.connection.Log(logger.LevelInfo, logSenderSCP, "scp error message received: %v is error: %v", msg.String(), isError)
|
||||||
err = fmt.Errorf("%v", msg.String())
|
err = fmt.Errorf("%v", msg.String())
|
||||||
c.channel.Close()
|
c.connection.channel.Close()
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -548,7 +547,7 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
|
|||||||
buf := make([]byte, 1)
|
buf := make([]byte, 1)
|
||||||
for {
|
for {
|
||||||
var n int
|
var n int
|
||||||
n, err = c.channel.Read(buf)
|
n, err = c.connection.channel.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -561,34 +560,34 @@ func (c *scpCommand) readProtocolMessage() (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.channel.Close()
|
c.connection.channel.Close()
|
||||||
}
|
}
|
||||||
return command.String(), err
|
return command.String(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// send an error message and close the channel
|
// send an error message and close the channel
|
||||||
func (c *scpCommand) sendErrorMessage(error string) {
|
func (c *scpCommand) sendErrorMessage(error string) {
|
||||||
c.channel.Write(errMsg)
|
c.connection.channel.Write(errMsg)
|
||||||
c.channel.Write([]byte(error))
|
c.connection.channel.Write([]byte(error))
|
||||||
c.channel.Write(newLine)
|
c.connection.channel.Write(newLine)
|
||||||
c.channel.Close()
|
c.connection.channel.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// send scp confirmation message and close the channel if an error happen
|
// send scp confirmation message and close the channel if an error happen
|
||||||
func (c *scpCommand) sendConfirmationMessage() error {
|
func (c *scpCommand) sendConfirmationMessage() error {
|
||||||
_, err := c.channel.Write(okMsg)
|
_, err := c.connection.channel.Write(okMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.channel.Close()
|
c.connection.channel.Close()
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// sends a protocol message and close the channel on error
|
// sends a protocol message and close the channel on error
|
||||||
func (c *scpCommand) sendProtocolMessage(message string) error {
|
func (c *scpCommand) sendProtocolMessage(message string) error {
|
||||||
_, err := c.channel.Write([]byte(message))
|
_, err := c.connection.channel.Write([]byte(message))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.connection.Log(logger.LevelWarn, logSenderSCP, "error sending protocol message: %v, err: %v", message, err)
|
c.connection.Log(logger.LevelWarn, logSenderSCP, "error sending protocol message: %v, err: %v", message, err)
|
||||||
c.channel.Close()
|
c.connection.channel.Close()
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -604,8 +603,8 @@ func (c *scpCommand) sendExitStatus(err error) {
|
|||||||
}
|
}
|
||||||
c.connection.Log(logger.LevelDebug, logSenderSCP, "send exit status for command with args: %v user: %v err: %v",
|
c.connection.Log(logger.LevelDebug, logSenderSCP, "send exit status for command with args: %v user: %v err: %v",
|
||||||
c.args, c.connection.User.Username, err)
|
c.args, c.connection.User.Username, err)
|
||||||
c.channel.SendRequest("exit-status", false, ssh.Marshal(&ex))
|
c.connection.channel.SendRequest("exit-status", false, ssh.Marshal(&ex))
|
||||||
c.channel.Close()
|
c.connection.channel.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the next upload protocol message ignoring T command if any
|
// get the next upload protocol message ignoring T command if any
|
||||||
|
|||||||
@@ -229,6 +229,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||||||
lastActivity: time.Now(),
|
lastActivity: time.Now(),
|
||||||
lock: new(sync.Mutex),
|
lock: new(sync.Mutex),
|
||||||
netConn: conn,
|
netConn: conn,
|
||||||
|
channel: nil,
|
||||||
}
|
}
|
||||||
connection.Log(logger.LevelInfo, logSender, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v",
|
connection.Log(logger.LevelInfo, logSender, "User id: %d, logged in with: %#v, username: %#v, home_dir: %#v",
|
||||||
user.ID, loginType, user.Username, user.HomeDir)
|
user.ID, loginType, user.Username, user.HomeDir)
|
||||||
@@ -261,6 +262,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||||||
if string(req.Payload[4:]) == "sftp" {
|
if string(req.Payload[4:]) == "sftp" {
|
||||||
ok = true
|
ok = true
|
||||||
connection.protocol = protocolSFTP
|
connection.protocol = protocolSFTP
|
||||||
|
connection.channel = channel
|
||||||
go c.handleSftpConnection(channel, connection)
|
go c.handleSftpConnection(channel, connection)
|
||||||
}
|
}
|
||||||
case "exec":
|
case "exec":
|
||||||
@@ -274,10 +276,10 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||||||
if err == nil && name == "scp" && len(scpArgs) >= 2 {
|
if err == nil && name == "scp" && len(scpArgs) >= 2 {
|
||||||
ok = true
|
ok = true
|
||||||
connection.protocol = protocolSCP
|
connection.protocol = protocolSCP
|
||||||
|
connection.channel = channel
|
||||||
scpCommand := scpCommand{
|
scpCommand := scpCommand{
|
||||||
connection: connection,
|
connection: connection,
|
||||||
args: scpArgs,
|
args: scpArgs,
|
||||||
channel: channel,
|
|
||||||
}
|
}
|
||||||
go scpCommand.handle()
|
go scpCommand.handle()
|
||||||
}
|
}
|
||||||
@@ -290,7 +292,7 @@ func (c Configuration) AcceptInboundConnection(conn net.Conn, config *ssh.Server
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Configuration) handleSftpConnection(channel io.ReadWriteCloser, connection Connection) {
|
func (c Configuration) handleSftpConnection(channel ssh.Channel, connection Connection) {
|
||||||
addConnection(connection.ID, connection)
|
addConnection(connection.ID, connection)
|
||||||
// Create a new handler for the currently logged in user's server.
|
// Create a new handler for the currently logged in user's server.
|
||||||
handler := c.createHandler(connection)
|
handler := c.createHandler(connection)
|
||||||
|
|||||||
Reference in New Issue
Block a user