From 146a696adcc5d8731acc3a9afbfa6e03d19284ab Mon Sep 17 00:00:00 2001
From: Ethan Lin
Date: Thu, 28 May 2026 16:58:29 -0700
Subject: [PATCH] memcache: split TCP-connect / TLS-handshake timeouts + add dial hook
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

A single DialTimeout budgets the TCP connect and the TLS handshake together.
When a server is CPU-bound the handshake slows down and can exceed that
budget, so the client aborts it mid-flight ("closed during TLS handshake") —
which, combined with immediate redial, exhausted FDs on a Verkada ElastiCache
node and required a manual reboot to recover.

Add independent ConnectTimeout and HandshakeTimeout fields. When either is
set, dial() takes a split path: a net.Dialer with its own Timeout bounds the
TCP connect, then tls.Client(...).Handshake() runs under its own deadline.
This lets the connect budget stay tight (fail dead hosts fast) while the
handshake budget is generous enough to ride out a slow-but-healthy server.
When neither is set, the original single-budget path is used unchanged
(verified by existing tests).

Also add an optional OnDial(DialInfo) hook reporting connect/handshake
durations, TLS-resumption status, and outcome, so callers can emit dial
metrics without the library taking a metrics dependency. The hook runs on a
separate goroutine (so it can never add dial latency) with a recover() guard
(so a panicking hook cannot crash the process).

Co-Authored-By: Claude Opus 4.8 (1M context)
---
 memcache/memcache.go      | 186 ++++++++++++++++++++++++++++++++++++--
 memcache/memcache_test.go | 171 +++++++++++++++++++++++++++++++++++
 2 files changed, 351 insertions(+), 6 deletions(-)

diff --git a/memcache/memcache.go b/memcache/memcache.go
index 0b855dc..a99bdac 100644
--- a/memcache/memcache.go
+++ b/memcache/memcache.go
@@ -139,8 +139,38 @@ type Client struct {
 	// DialTimeout specifies the timeout for establishing new connections,
 	// including the TLS handshake if TLS is enabled.
 	// If zero, Timeout is used (for backward compatibility).
+	//
+	// DialTimeout is a single budget covering both the TCP connection and the
+	// TLS handshake. Because the TLS handshake can be slow when the server is
+	// under CPU pressure, a single tight budget can abort handshakes mid-flight.
+	// Prefer ConnectTimeout + HandshakeTimeout (below) to budget them
+	// separately; DialTimeout is retained for backward compatibility and stands
+	// in for whichever of the split timeouts is left at zero.
 	DialTimeout time.Duration
 
+	// ConnectTimeout, when non-zero, bounds only the TCP connection
+	// establishment (not the TLS handshake). If zero while HandshakeTimeout is
+	// set, the TCP connect is bounded by DialTimeout instead; if both are zero,
+	// the legacy path applies DialTimeout to the whole dial.
+	ConnectTimeout time.Duration
+
+	// HandshakeTimeout, when non-zero, bounds only the TLS handshake. It has no
+	// effect for non-TLS clients. If zero while ConnectTimeout is set, the
+	// handshake gets its own DialTimeout-sized budget after the TCP connect.
+	HandshakeTimeout time.Duration
+
+	// OnDial, when non-nil, is invoked after every connection attempt (success
+	// or failure) with timing and outcome details. It enables callers to emit
+	// dial-count, connect/handshake-duration, and TLS-resumption metrics without
+	// the library taking a metrics dependency.
+	//
+	// OnDial is invoked on a separate goroutine so it can never add latency to
+	// the dial path; consequently it may run after the connection is already in
+	// use, and the relative ordering of callbacks is not guaranteed. A panic in
+	// the hook is recovered and discarded. The DialInfo is passed by value, so
+	// the hook owns its copy.
+	OnDial func(DialInfo)
+
 	// MaxIdleConns specifies the maximum number of idle connections that will
 	// be maintained per address. If less than one, DefaultMaxIdleConns will be
 	// used.
@@ -156,6 +186,27 @@ type Client struct {
 	TlsConfig *tls.Config
 }
 
+// DialInfo reports the outcome of a single connection attempt to OnDial.
+type DialInfo struct {
+	// Addr is the server address that was dialed.
+	Addr net.Addr
+	// TLS is true if the attempt included a TLS handshake.
+	TLS bool
+	// ConnectDuration is the time spent establishing the TCP connection. On
+	// the legacy single-budget path (ConnectTimeout and HandshakeTimeout both
+	// zero) it covers the whole dial, including any TLS handshake.
+	ConnectDuration time.Duration
+	// HandshakeDuration is the time spent on the TLS handshake. Zero for
+	// non-TLS attempts, when the connect failed before the handshake began,
+	// and on the legacy single-budget path (folded into ConnectDuration).
+	HandshakeDuration time.Duration
+	// Resumed is true if the TLS handshake resumed a previous session
+	// (abbreviated handshake). Only meaningful when TLS is true and Err is nil.
+	Resumed bool
+	// Err is the failure, or nil on success.
+	Err error
+}
+
 // Item is an item to be got or stored in a memcached server.
 type Item struct {
 	// Key is the Item's key (250 bytes maximum).
@@ -268,12 +319,42 @@ func (cte *ConnectTimeoutError) Error() string {
 	return "memcache: connect timeout to " + cte.Addr.String()
 }
 
+// useSplitTimeouts reports whether the caller configured independent connect
+// and handshake budgets. When false, dial uses the legacy single-budget path.
+func (c *Client) useSplitTimeouts() bool {
+	return c.ConnectTimeout != 0 || c.HandshakeTimeout != 0
+}
+
+// connectTimeout is the TCP-connect budget for the split-timeout path. It falls
+// back to dialTimeout() when ConnectTimeout is unset so a caller can configure
+// only HandshakeTimeout and still get a sane connect bound.
+func (c *Client) connectTimeout() time.Duration {
+	if c.ConnectTimeout != 0 {
+		return c.ConnectTimeout
+	}
+	return c.dialTimeout()
+}
+
+// handshakeTimeout is the TLS-handshake budget for the split-timeout path. It
+// falls back to dialTimeout() when HandshakeTimeout is unset.
+func (c *Client) handshakeTimeout() time.Duration {
+	if c.HandshakeTimeout != 0 {
+		return c.HandshakeTimeout
+	}
+	return c.dialTimeout()
+}
+
 func (c *Client) dial(addr net.Addr) (net.Conn, error) {
-	type connError struct {
-		cn  net.Conn
-		err error
+	if c.useSplitTimeouts() {
+		return c.dialSplit(addr)
 	}
+	return c.dialLegacy(addr)
+}
 
+// dialLegacy preserves the original behavior: a single dialer Timeout covers
+// both the TCP connection and (for TLS) the handshake.
+func (c *Client) dialLegacy(addr net.Addr) (net.Conn, error) {
+	start := time.Now()
 	var (
 		nc  net.Conn
 		err error
@@ -284,13 +365,122 @@ func (c *Client) dial(addr net.Addr) (net.Conn, error) {
 	} else {
 		nc, err = nd.Dial(addr.Network(), addr.String())
 	}
+	err = c.normalizeDialErr(addr, err)
+	c.reportDial(DialInfo{
+		Addr:            addr,
+		TLS:             c.TlsConfig != nil,
+		ConnectDuration: time.Since(start),
+		Resumed:         resumedFrom(nc, err),
+		Err:             err,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return nc, nil
+}
+
+// dialSplit budgets the TCP connect and the TLS handshake independently, so a
+// slow handshake (e.g. a CPU-bound server) does not consume the connect budget
+// and a tight connect budget can fail a dead host fast without aborting healthy
+// but slow handshakes.
+func (c *Client) dialSplit(addr net.Addr) (net.Conn, error) {
+	connectStart := time.Now()
+	nd := net.Dialer{Timeout: c.connectTimeout()}
+	rawConn, err := nd.Dial(addr.Network(), addr.String())
+	connectDur := time.Since(connectStart)
+	if err != nil {
+		err = c.normalizeDialErr(addr, err)
+		c.reportDial(DialInfo{Addr: addr, TLS: c.TlsConfig != nil, ConnectDuration: connectDur, Err: err})
+		return nil, err
+	}
+
+	// Non-TLS: nothing more to do.
+	if c.TlsConfig == nil {
+		c.reportDial(DialInfo{Addr: addr, ConnectDuration: connectDur})
+		return rawConn, nil
+	}
+
+	// Unlike tls.DialWithDialer on the legacy path, tls.Client does not infer
+	// ServerName from the dialed address, and the handshake fails outright when
+	// neither ServerName nor InsecureSkipVerify is set. Apply the same
+	// inference on a copy so the caller's config is left untouched.
+	tlsCfg := c.TlsConfig
+	if tlsCfg.ServerName == "" {
+		host, _, splitErr := net.SplitHostPort(addr.String())
+		if splitErr != nil {
+			host = addr.String() // no port (e.g. a unix socket path)
+		}
+		tlsCfg = tlsCfg.Clone()
+		tlsCfg.ServerName = host
+	}
+
+	// TLS: run the handshake under its own deadline.
+	handshakeStart := time.Now()
+	tlsConn := tls.Client(rawConn, tlsCfg)
+	if dl := c.handshakeTimeout(); dl > 0 {
+		_ = tlsConn.SetDeadline(time.Now().Add(dl))
+	}
+	herr := tlsConn.Handshake()
+	handshakeDur := time.Since(handshakeStart)
+	// Clear the handshake deadline; per-op deadlines are managed elsewhere.
+	_ = tlsConn.SetDeadline(time.Time{})
+	if herr != nil {
+		_ = tlsConn.Close()
+		herr = c.normalizeDialErr(addr, herr)
+		c.reportDial(DialInfo{
+			Addr:              addr,
+			TLS:               true,
+			ConnectDuration:   connectDur,
+			HandshakeDuration: handshakeDur,
+			Err:               herr,
+		})
+		return nil, herr
+	}
+	c.reportDial(DialInfo{
+		Addr:              addr,
+		TLS:               true,
+		ConnectDuration:   connectDur,
+		HandshakeDuration: handshakeDur,
+		Resumed:           tlsConn.ConnectionState().DidResume,
+	})
+	return tlsConn, nil
+}
+
+// normalizeDialErr maps a timeout to ConnectTimeoutError, matching legacy
+// behavior, and passes other errors through unchanged.
+func (c *Client) normalizeDialErr(addr net.Addr, err error) error {
 	if err == nil {
-		return nc, nil
+		return nil
 	}
 	if ne, ok := err.(net.Error); ok && ne.Timeout() {
-		return nil, &ConnectTimeoutError{addr}
+		return &ConnectTimeoutError{addr}
 	}
-	return nil, err
+	return err
+}
+
+// reportDial invokes the OnDial hook (if configured) on a separate goroutine so
+// a slow hook cannot add latency to the dial path. A panic in the hook is
+// recovered so it cannot crash the spawned goroutine (and thus the process).
+func (c *Client) reportDial(info DialInfo) {
+	hook := c.OnDial
+	if hook == nil {
+		return
+	}
+	go func() {
+		defer func() { _ = recover() }()
+		hook(info)
+	}()
+}
+
+// resumedFrom reports whether a successful TLS connection resumed a session.
+// Returns false for failures or non-TLS connections.
+func resumedFrom(nc net.Conn, err error) bool {
+	if err != nil || nc == nil {
+		return false
+	}
+	if tc, ok := nc.(*tls.Conn); ok {
+		return tc.ConnectionState().DidResume
+	}
+	return false
 }
 
 func (c *Client) getConn(addr net.Addr) (*conn, error) {
diff --git a/memcache/memcache_test.go b/memcache/memcache_test.go
index d8b48f0..320565c 100644
--- a/memcache/memcache_test.go
+++ b/memcache/memcache_test.go
@@ -469,3 +469,174 @@ func TestSeparateTimeouts(t *testing.T) {
 		t.Error("DialTimeout and Timeout should be independent")
 	}
 }
+
+// TestUseSplitTimeouts verifies the dialer takes the split-timeout path only
+// when ConnectTimeout or HandshakeTimeout is configured.
+func TestUseSplitTimeouts(t *testing.T) {
+	tests := []struct {
+		name     string
+		connect  time.Duration
+		handshk  time.Duration
+		expected bool
+	}{
+		{"neither set -> legacy", 0, 0, false},
+		{"connect set -> split", 50 * time.Millisecond, 0, true},
+		{"handshake set -> split", 0, 500 * time.Millisecond, true},
+		{"both set -> split", 50 * time.Millisecond, 500 * time.Millisecond, true},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			c := &Client{ConnectTimeout: tt.connect, HandshakeTimeout: tt.handshk}
+			if got := c.useSplitTimeouts(); got != tt.expected {
+				t.Errorf("useSplitTimeouts() = %v, want %v", got, tt.expected)
+			}
+		})
+	}
+}
+
+// TestSplitTimeoutFallbacks verifies connectTimeout()/handshakeTimeout() fall
+// back to dialTimeout() (and thence Timeout/DefaultTimeout) when unset.
+func TestSplitTimeoutFallbacks(t *testing.T) {
+	// Only handshake configured: connect should fall back to dialTimeout().
+	c := &Client{HandshakeTimeout: 800 * time.Millisecond, DialTimeout: 175 * time.Millisecond}
+	if got := c.connectTimeout(); got != 175*time.Millisecond {
+		t.Errorf("connectTimeout() fallback = %v, want 175ms", got)
+	}
+	if got := c.handshakeTimeout(); got != 800*time.Millisecond {
+		t.Errorf("handshakeTimeout() = %v, want 800ms", got)
+	}
+
+	// Only connect configured: handshake should fall back to dialTimeout().
+	c2 := &Client{ConnectTimeout: 50 * time.Millisecond, Timeout: 90 * time.Millisecond}
+	if got := c2.connectTimeout(); got != 50*time.Millisecond {
+		t.Errorf("connectTimeout() = %v, want 50ms", got)
+	}
+	// dialTimeout() with no DialTimeout falls back to netTimeout() == Timeout.
+	if got := c2.handshakeTimeout(); got != 90*time.Millisecond {
+		t.Errorf("handshakeTimeout() fallback = %v, want 90ms", got)
+	}
+}
+
+// TestOnDialHookSuccess verifies OnDial fires with timing on a successful
+// (non-TLS) connect via the split-timeout path.
+func TestOnDialHookSuccess(t *testing.T) {
+	fakeServer, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatal("Could not open fake server: ", err)
+	}
+	defer fakeServer.Close()
+	go func() {
+		for {
+			conn, err := fakeServer.Accept()
+			if err != nil {
+				return
+			}
+			go func() { io.Copy(ioutil.Discard, conn) }()
+		}
+	}()
+
+	addr := fakeServer.Addr()
+	// OnDial fires on a separate goroutine, so receive its result over a channel.
+	dialCh := make(chan DialInfo, 1)
+	c := New(addr.String())
+	c.ConnectTimeout = 500 * time.Millisecond // force split path
+	c.OnDial = func(info DialInfo) { dialCh <- info }
+
+	if _, err := c.getConn(addr); err != nil {
+		t.Fatalf("failed to connect to fake server: %v", err)
+	}
+
+	var got DialInfo
+	select {
+	case got = <-dialCh:
+	case <-time.After(2 * time.Second):
+		t.Fatal("OnDial was not invoked within 2s")
+	}
+	if got.Err != nil {
+		t.Errorf("OnDial reported error on success: %v", got.Err)
+	}
+	if got.TLS {
+		t.Error("OnDial reported TLS for a non-TLS connection")
+	}
+	if got.ConnectDuration <= 0 {
+		t.Error("OnDial ConnectDuration should be positive")
+	}
+}
+
+// TestOnDialHookConnectFailure verifies OnDial fires with a ConnectTimeoutError
+// when the TCP connect fails, on the split-timeout path.
+func TestOnDialHookConnectFailure(t *testing.T) {
+	// 192.0.2.0/24 (TEST-NET-1) is reserved and non-routable: connect will time out.
+	const blackhole = "192.0.2.1:11211"
+	dialCh := make(chan DialInfo, 1)
+	c := New(blackhole)
+	c.ConnectTimeout = 50 * time.Millisecond // short, forces split path + fast timeout
+	c.OnDial = func(info DialInfo) { dialCh <- info }
+
+	addr, err := c.selector.PickServer("anykey")
+	if err != nil {
+		t.Fatalf("PickServer: %v", err)
+	}
+	_, derr := c.dial(addr)
+	if derr == nil {
+		t.Fatal("expected a connect failure, got nil")
+	}
+	if _, ok := derr.(*ConnectTimeoutError); !ok {
+		// Non-timeout connect errors are also acceptable on some networks, but a
+		// timeout is expected against a blackhole address.
+		t.Logf("got non-timeout dial error (acceptable): %v", derr)
+	}
+
+	var got DialInfo
+	select {
+	case got = <-dialCh:
+	case <-time.After(2 * time.Second):
+		t.Fatal("OnDial was not invoked within 2s")
+	}
+	if got.Err == nil {
+		t.Error("OnDial should have reported the failure error")
+	}
+	if got.ConnectDuration <= 0 {
+		t.Error("OnDial ConnectDuration should be positive even on failure")
+	}
+}
+
+// TestOnDialHookPanicRecovered verifies a panicking hook does not crash the
+// caller (the panic is recovered on the hook's goroutine).
+func TestOnDialHookPanicRecovered(t *testing.T) {
+	fakeServer, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatal("Could not open fake server: ", err)
+	}
+	defer fakeServer.Close()
+	go func() {
+		for {
+			conn, err := fakeServer.Accept()
+			if err != nil {
+				return
+			}
+			go func() { io.Copy(ioutil.Discard, conn) }()
+		}
+	}()
+
+	addr := fakeServer.Addr()
+	done := make(chan struct{})
+	c := New(addr.String())
+	c.ConnectTimeout = 500 * time.Millisecond
+	c.OnDial = func(info DialInfo) {
+		defer close(done)
+		panic("boom")
+	}
+
+	// The dial must succeed even though the hook panics.
+	if _, err := c.getConn(addr); err != nil {
+		t.Fatalf("dial failed: %v", err)
+	}
+	select {
+	case <-done:
+	case <-time.After(2 * time.Second):
+		t.Fatal("hook was not invoked within 2s")
+	}
+	// NOTE: done is closed while the hook is still unwinding its panic, so this
+	// shows the dial path is unaffected; it does not wait for the recover().
+}