Skip to content

Commit 84a92b1

Browse files
authored
feat: implement context-based timeout handling (#95)
- Add context package import - Replace timeout channel with context-based timeout - Improve error message to include context timeout error - Update test to match new error message format - Add new test for command timeout functionality Signed-off-by: Bo-Yi Wu <appleboy.tw@gmail.com>
1 parent dc56456 commit 84a92b1

2 files changed

Lines changed: 23 additions & 5 deletions

File tree

easyssh.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package easyssh
66

77
import (
88
"bufio"
9+
"context"
910
"errors"
1011
"fmt"
1112
"io"
@@ -357,7 +358,8 @@ func (ssh_conf *MakeConfig) Stream(command string, timeout ...time.Duration) (<-
357358
if len(timeout) > 0 {
358359
executeTimeout = timeout[0]
359360
}
360-
timeoutChan := time.After(executeTimeout)
361+
ctxTimeout, cancel := context.WithTimeout(context.Background(), executeTimeout)
362+
defer cancel()
361363
res := make(chan struct{}, 1)
362364
var resWg sync.WaitGroup
363365
resWg.Add(2)
@@ -398,8 +400,8 @@ func (ssh_conf *MakeConfig) Stream(command string, timeout ...time.Duration) (<-
398400
case <-res:
399401
errChan <- session.Wait()
400402
doneChan <- true
401-
case <-timeoutChan:
402-
errChan <- fmt.Errorf("Run Command Timeout")
403+
case <-ctxTimeout.Done():
404+
errChan <- fmt.Errorf("Run Command Timeout: %v", ctxTimeout.Err())
403405
doneChan <- false
404406
}
405407
}(stdoutScanner, stderrScanner, stdoutChan, stderrChan, doneChan, errChan)

easyssh_test.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package easyssh
22

33
import (
4+
"context"
45
"os"
56
"os/user"
67
"path"
@@ -20,7 +21,6 @@ func getHostPublicKeyFile(keypath string) (ssh.PublicKey, error) {
2021
}
2122

2223
pubkey, _, _, _, err = ssh.ParseAuthorizedKey(buf)
23-
2424
if err != nil {
2525
return nil, err
2626
}
@@ -169,7 +169,7 @@ func TestRunCommand(t *testing.T) {
169169
assert.Equal(t, "", errStr)
170170
assert.False(t, isTimeout)
171171
assert.Error(t, err)
172-
assert.Equal(t, "Run Command Timeout", err.Error())
172+
assert.Equal(t, "Run Command Timeout: "+context.DeadlineExceeded.Error(), err.Error())
173173

174174
// test exit code
175175
outStr, errStr, isTimeout, err = ssh.Run("exit 1")
@@ -496,3 +496,19 @@ func TestSudoCommand(t *testing.T) {
496496
assert.True(t, isTimeout)
497497
assert.NoError(t, err)
498498
}
499+
500+
func TestCommandTimeout(t *testing.T) {
501+
ssh := &MakeConfig{
502+
Server: "localhost",
503+
User: "root",
504+
Port: "22",
505+
KeyPath: "./tests/.ssh/id_rsa",
506+
}
507+
508+
outStr, errStr, isTimeout, err := ssh.Run("whoami; sleep 2", 1*time.Second)
509+
assert.Equal(t, "root\n", outStr)
510+
assert.Equal(t, "", errStr)
511+
assert.False(t, isTimeout)
512+
assert.NotNil(t, err)
513+
assert.Equal(t, "Run Command Timeout: "+context.DeadlineExceeded.Error(), err.Error())
514+
}

0 commit comments

Comments
 (0)