connections: close the ssh channel before the network connection

This way if pkg/sftp is stuck in Serve() method should be unlocked.
This commit is contained in:
Nicola Murino
2019-09-11 16:29:56 +02:00
parent 9794ca7ee0
commit 3d13fe15c3
4 changed files with 70 additions and 56 deletions

View File

@@ -252,7 +252,6 @@ func TestSCPGetNonExistingDirContent(t *testing.T) {
}
func TestSCPParseUploadMessage(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
mockSSHChannel := MockChannel{
@@ -260,10 +259,12 @@ func TestSCPParseUploadMessage(t *testing.T) {
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
}
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-t", "/tmp"},
channel: &mockSSHChannel,
}
_, _, err := scpCommand.parseUploadMessage("invalid")
if err == nil {
@@ -284,7 +285,6 @@ func TestSCPParseUploadMessage(t *testing.T) {
}
func TestSCPProtocolMessages(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
@@ -295,10 +295,12 @@ func TestSCPProtocolMessages(t *testing.T) {
ReadError: readErr,
WriteError: writeErr,
}
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-t", "/tmp"},
channel: &mockSSHChannel,
}
_, err := scpCommand.readProtocolMessage()
if err == nil || err != readErr {
@@ -322,7 +324,7 @@ func TestSCPProtocolMessages(t *testing.T) {
ReadError: nil,
WriteError: writeErr,
}
scpCommand.channel = &mockSSHChannel
scpCommand.connection.channel = &mockSSHChannel
_, err = scpCommand.getNextUploadProtocolMessage()
if err == nil || err != writeErr {
t.Errorf("read next upload protocol message must fail, we are sending a fake write error")
@@ -337,7 +339,7 @@ func TestSCPProtocolMessages(t *testing.T) {
ReadError: nil,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.readConfirmationMessage()
if err == nil || err.Error() != protocolErrorMsg {
t.Errorf("read confirmation message must return the expected protocol error, actual err: %v", err)
@@ -345,7 +347,6 @@ func TestSCPProtocolMessages(t *testing.T) {
}
func TestSCPTestDownloadProtocolMessages(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
@@ -356,10 +357,12 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
ReadError: readErr,
WriteError: writeErr,
}
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-f", "-p", "/tmp"},
channel: &mockSSHChannel,
}
path := "testDir"
os.Mkdir(path, 0777)
@@ -388,7 +391,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
WriteError: writeErr,
}
scpCommand.args = []string{"-f", "/tmp"}
scpCommand.channel = &mockSSHChannel
scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.sendDownloadProtocolMessages(path, stat)
if err != writeErr {
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
@@ -400,7 +403,7 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
ReadError: readErr,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.sendDownloadProtocolMessages(path, stat)
if err != readErr {
t.Errorf("sendDownloadProtocolMessages must return the expected error: %v", err)
@@ -409,7 +412,6 @@ func TestSCPTestDownloadProtocolMessages(t *testing.T) {
}
func TestSCPCommandHandleErrors(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
@@ -420,10 +422,12 @@ func TestSCPCommandHandleErrors(t *testing.T) {
ReadError: readErr,
WriteError: writeErr,
}
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-f", "/tmp"},
channel: &mockSSHChannel,
}
err := scpCommand.handle()
if err == nil || err != readErr {
@@ -437,7 +441,6 @@ func TestSCPCommandHandleErrors(t *testing.T) {
}
func TestSCPRecursiveDownloadErrors(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
@@ -448,10 +451,12 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
ReadError: readErr,
WriteError: writeErr,
}
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-r", "-f", "/tmp"},
channel: &mockSSHChannel,
}
path := "testDir"
os.Mkdir(path, 0777)
@@ -466,7 +471,7 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
ReadError: nil,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.handleRecursiveDownload("invalid_dir", stat)
if err == nil {
t.Errorf("recursive upload download must fail for a non existing dir")
@@ -476,7 +481,6 @@ func TestSCPRecursiveDownloadErrors(t *testing.T) {
}
func TestSCPRecursiveUploadErrors(t *testing.T) {
connection := Connection{}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
@@ -487,10 +491,12 @@ func TestSCPRecursiveUploadErrors(t *testing.T) {
ReadError: readErr,
WriteError: writeErr,
}
connection := Connection{
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-r", "-t", "/tmp"},
channel: &mockSSHChannel,
}
err := scpCommand.handleRecursiveUpload()
if err == nil {
@@ -502,7 +508,7 @@ func TestSCPRecursiveUploadErrors(t *testing.T) {
ReadError: readErr,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
scpCommand.connection.channel = &mockSSHChannel
err = scpCommand.handleRecursiveUpload()
if err == nil {
t.Errorf("recursive upload must fail, we send a fake error message")
@@ -516,19 +522,19 @@ func TestSCPCreateDirs(t *testing.T) {
u.HomeDir = "home_rel_path"
u.Username = "test"
u.Permissions = []string{"*"}
connection := Connection{
User: u,
}
mockSSHChannel := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
ReadError: nil,
WriteError: nil,
}
connection := Connection{
User: u,
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-r", "-t", "/tmp"},
channel: &mockSSHChannel,
}
err := scpCommand.handleCreateDir("invalid_dir")
if err == nil {
@@ -542,7 +548,6 @@ func TestSCPDownloadFileData(t *testing.T) {
readErr := fmt.Errorf("test read error")
writeErr := fmt.Errorf("test write error")
stdErrBuf := make([]byte, 65535)
connection := Connection{}
mockSSHChannelReadErr := MockChannel{
Buffer: bytes.NewBuffer(buf),
StdErrBuffer: bytes.NewBuffer(stdErrBuf),
@@ -555,10 +560,12 @@ func TestSCPDownloadFileData(t *testing.T) {
ReadError: nil,
WriteError: writeErr,
}
connection := Connection{
channel: &mockSSHChannelReadErr,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-r", "-f", "/tmp"},
channel: &mockSSHChannelReadErr,
}
ioutil.WriteFile(testfile, []byte("test"), 0666)
stat, _ := os.Stat(testfile)
@@ -566,7 +573,7 @@ func TestSCPDownloadFileData(t *testing.T) {
if err != readErr {
t.Errorf("send download file data must fail with the expected error: %v", err)
}
scpCommand.channel = &mockSSHChannelWriteErr
scpCommand.connection.channel = &mockSSHChannelWriteErr
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
if err != writeErr {
t.Errorf("send download file data must fail with the expected error: %v", err)
@@ -576,7 +583,7 @@ func TestSCPDownloadFileData(t *testing.T) {
if err != writeErr {
t.Errorf("send download file data must fail with the expected error: %v", err)
}
scpCommand.channel = &mockSSHChannelReadErr
scpCommand.connection.channel = &mockSSHChannelReadErr
err = scpCommand.sendDownloadFileData(testfile, stat, nil)
if err != readErr {
t.Errorf("send download file data must fail with the expected error: %v", err)
@@ -586,12 +593,6 @@ func TestSCPDownloadFileData(t *testing.T) {
func TestSCPUploadFiledata(t *testing.T) {
testfile := "testfile"
connection := Connection{
User: dataprovider.User{
Username: "testuser",
},
protocol: protocolSCP,
}
buf := make([]byte, 65535)
stdErrBuf := make([]byte, 65535)
readErr := fmt.Errorf("test read error")
@@ -602,10 +603,16 @@ func TestSCPUploadFiledata(t *testing.T) {
ReadError: readErr,
WriteError: writeErr,
}
connection := Connection{
User: dataprovider.User{
Username: "testuser",
},
protocol: protocolSCP,
channel: &mockSSHChannel,
}
scpCommand := scpCommand{
connection: connection,
args: []string{"-r", "-t", "/tmp"},
channel: &mockSSHChannel,
}
file, _ := os.Create(testfile)
transfer := Transfer{
@@ -634,7 +641,7 @@ func TestSCPUploadFiledata(t *testing.T) {
ReadError: readErr,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
scpCommand.connection.channel = &mockSSHChannel
file, _ = os.Create(testfile)
transfer.file = file
addTransfer(&transfer)
@@ -651,7 +658,7 @@ func TestSCPUploadFiledata(t *testing.T) {
ReadError: nil,
WriteError: nil,
}
scpCommand.channel = &mockSSHChannel
scpCommand.connection.channel = &mockSSHChannel
file, _ = os.Create(testfile)
transfer.file = file
addTransfer(&transfer)