Skip to content

Commit b638274

Browse files
authored
Ignore host key update mechanism requests (#5)
Fixes: #4
1 parent 669ff9c commit b638274

File tree

1 file changed

+31
-9
lines changed

1 file changed

+31
-9
lines changed

internal/sshproxy/server.go

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"net"
9+
"slices"
910
"strconv"
1011
"strings"
1112
"sync"
@@ -28,6 +29,20 @@ const (
2829
upstreamDialTimeout = time.Second * 10
2930
)
3031

32+
var blacklistedGlobalRequests = []string{
33+
// Host key update mechanism for SSH: https://www.ietf.org/archive/id/draft-miller-sshm-hostkey-update-02.html
34+
// Reasons to blacklist:
35+
// 1. Signature check always fail as the signed data contains session identifier, which is not the same on client
36+
// and upstream side, since they don't talk directly but through sshproxy (there are two SSH transport sessions
37+
// with their own unique identifiers).
38+
// 2. Even if it worked somehow, we don't want to inflate user's known_hosts file with garbage records,
39+
// since container host keys are ephemeral -- they are generated on dstack-runner startup (= unique for each job).
40+
"hostkeys",
41+
"hostkeys-00@openssh.com",
42+
"hostkeys-prove",
43+
"hostkeys-prove-00@openssh.com",
44+
}
45+
3146
type direction string
3247

3348
var (
@@ -386,15 +401,22 @@ func bridgeGlobalRequests(ctx context.Context, dir direction, inReqs <-chan *ssh
386401
logger := log.GetLogger(ctx).WithField("dir", dir)
387402
for req := range inReqs {
388403
logger := logger.WithField("type", req.Type)
389-
logger.Trace("global request")
390404

391-
reply, payload, err := outConn.SendRequest(req.Type, req.WantReply, req.Payload)
392-
if req.WantReply {
393-
_ = req.Reply(reply, payload)
394-
}
405+
if slices.Contains(blacklistedGlobalRequests, req.Type) {
406+
logger.Trace("blacklisted global request, ignoring")
407+
if req.WantReply {
408+
_ = req.Reply(false, nil)
409+
}
410+
} else {
411+
logger.Trace("global request")
412+
ok, payload, err := outConn.SendRequest(req.Type, req.WantReply, req.Payload)
413+
if req.WantReply {
414+
_ = req.Reply(ok, payload)
415+
}
395416

396-
if err != nil && !isClosedError(err) {
397-
logger.WithError(err).Error("failed to forward global request")
417+
if err != nil && !isClosedError(err) {
418+
logger.WithError(err).Error("failed to forward global request")
419+
}
398420
}
399421
}
400422
}
@@ -488,9 +510,9 @@ func bridgeChannelRequests(ctx context.Context, dir direction, inReqs <-chan *ss
488510
logger := logger.WithField("type", req.Type)
489511
logger.Trace("request")
490512

491-
reply, err := outConn.SendRequest(req.Type, req.WantReply, req.Payload)
513+
ok, err := outConn.SendRequest(req.Type, req.WantReply, req.Payload)
492514
if req.WantReply {
493-
_ = req.Reply(reply, nil)
515+
_ = req.Reply(ok, nil)
494516
}
495517

496518
if err != nil && !isClosedError(err) {

0 commit comments

Comments
 (0)