diff --git a/memcache/memcache.go b/memcache/memcache.go index 0b855dc..a99bdac 100644 --- a/memcache/memcache.go +++ b/memcache/memcache.go @@ -139,8 +139,38 @@ type Client struct { // DialTimeout specifies the timeout for establishing new connections, // including the TLS handshake if TLS is enabled. // If zero, Timeout is used (for backward compatibility). + // + // DialTimeout is a single budget covering both the TCP connection and the + // TLS handshake. Because the TLS handshake can be slow when the server is + // under CPU pressure, a single tight budget can abort handshakes mid-flight. + // Prefer ConnectTimeout + HandshakeTimeout (below) to budget them + // separately; DialTimeout is retained for backward compatibility and is used + // as the fallback for whichever split timeout is left at zero. DialTimeout time.Duration + // ConnectTimeout, when non-zero, bounds only the TCP connection + // establishment (not the TLS handshake). Pair with HandshakeTimeout to + // budget connect and handshake independently. If zero, DialTimeout bounds the + // connect (and the whole dial when HandshakeTimeout is zero as well). + ConnectTimeout time.Duration + + // HandshakeTimeout, when non-zero, bounds only the TLS handshake (separate + // from the TCP connect). Has no effect for non-TLS clients. If zero, the + // handshake gets a DialTimeout budget of its own (shared if ConnectTimeout is zero). + HandshakeTimeout time.Duration + + // OnDial, when non-nil, is invoked after every connection attempt (success + // or failure) with timing and outcome details. It enables callers to emit + // dial-count, connect/handshake-duration, and TLS-resumption metrics without + // the library taking a metrics dependency. + // + // OnDial is invoked on a separate goroutine so it can never add latency to + // the dial path; consequently it may run after the connection is already in + // use, and the relative ordering of callbacks is not guaranteed. A panic in + // the hook is recovered and discarded. 
The DialInfo is passed by value, so + // the hook owns its copy. + OnDial func(DialInfo) + // MaxIdleConns specifies the maximum number of idle connections that will // be maintained per address. If less than one, DefaultMaxIdleConns will be // used. @@ -156,6 +186,24 @@ type Client struct { TlsConfig *tls.Config } +// DialInfo reports the outcome of a single connection attempt to OnDial. +type DialInfo struct { + // Addr is the server address that was dialed. + Addr net.Addr + // TLS is true if the attempt included a TLS handshake. + TLS bool + // ConnectDuration is the time spent establishing the TCP connection. + ConnectDuration time.Duration + // HandshakeDuration is the time spent on the TLS handshake. Zero for + // non-TLS attempts or when the connect failed before the handshake began. + HandshakeDuration time.Duration + // Resumed is true if the TLS handshake resumed a previous session + // (abbreviated handshake). Only meaningful when TLS is true and Err is nil. + Resumed bool + // Err is the failure, or nil on success. + Err error +} + // Item is an item to be got or stored in a memcached server. type Item struct { // Key is the Item's key (250 bytes maximum). @@ -268,12 +316,42 @@ func (cte *ConnectTimeoutError) Error() string { return "memcache: connect timeout to " + cte.Addr.String() } +// useSplitTimeouts reports whether the caller configured independent connect +// and handshake budgets. When false, dial uses the legacy single-budget path. +func (c *Client) useSplitTimeouts() bool { + return c.ConnectTimeout != 0 || c.HandshakeTimeout != 0 +} + +// connectTimeout is the TCP-connect budget for the split-timeout path. It falls +// back to dialTimeout() when ConnectTimeout is unset so a caller can configure +// only HandshakeTimeout and still get a sane connect bound. 
+func (c *Client) connectTimeout() time.Duration { + if c.ConnectTimeout != 0 { + return c.ConnectTimeout + } + return c.dialTimeout() +} + +// handshakeTimeout is the TLS-handshake budget for the split-timeout path. It +// falls back to dialTimeout() when HandshakeTimeout is unset. +func (c *Client) handshakeTimeout() time.Duration { + if c.HandshakeTimeout != 0 { + return c.HandshakeTimeout + } + return c.dialTimeout() +} + func (c *Client) dial(addr net.Addr) (net.Conn, error) { - type connError struct { - cn net.Conn - err error + if c.useSplitTimeouts() { + return c.dialSplit(addr) } + return c.dialLegacy(addr) +} +// dialLegacy preserves the original behavior: a single dialer Timeout covers +// both the TCP connection and (for TLS) the handshake. +func (c *Client) dialLegacy(addr net.Addr) (net.Conn, error) { + start := time.Now() var ( nc net.Conn err error @@ -284,13 +362,109 @@ func (c *Client) dial(addr net.Addr) (net.Conn, error) { } else { nc, err = nd.Dial(addr.Network(), addr.String()) } + err = c.normalizeDialErr(addr, err) + c.reportDial(DialInfo{ + Addr: addr, + TLS: c.TlsConfig != nil, + ConnectDuration: time.Since(start), + Resumed: resumedFrom(nc, err), + Err: err, + }) + if err != nil { + return nil, err + } + return nc, nil +} + +// dialSplit budgets the TCP connect and the TLS handshake independently, so a +// slow handshake (e.g. a CPU-bound server) does not consume the connect budget +// and a tight connect budget can fail a dead host fast without aborting healthy +// but slow handshakes. +func (c *Client) dialSplit(addr net.Addr) (net.Conn, error) { + connectStart := time.Now() + nd := net.Dialer{Timeout: c.connectTimeout()} + rawConn, err := nd.Dial(addr.Network(), addr.String()) + connectDur := time.Since(connectStart) + if err != nil { + err = c.normalizeDialErr(addr, err) + c.reportDial(DialInfo{Addr: addr, TLS: c.TlsConfig != nil, ConnectDuration: connectDur, Err: err}) + return nil, err + } + + // Non-TLS: nothing more to do. 
+ if c.TlsConfig == nil { + c.reportDial(DialInfo{Addr: addr, ConnectDuration: connectDur}) + return rawConn, nil + } + + // TLS: run the handshake under its own deadline. + handshakeStart := time.Now() + tlsConn := tls.Client(rawConn, c.TlsConfig) + if dl := c.handshakeTimeout(); dl > 0 { + _ = tlsConn.SetDeadline(time.Now().Add(dl)) + } + herr := tlsConn.Handshake() + handshakeDur := time.Since(handshakeStart) + // Clear the handshake deadline; per-op deadlines are managed elsewhere. + _ = tlsConn.SetDeadline(time.Time{}) + if herr != nil { + _ = tlsConn.Close() + herr = c.normalizeDialErr(addr, herr) + c.reportDial(DialInfo{ + Addr: addr, + TLS: true, + ConnectDuration: connectDur, + HandshakeDuration: handshakeDur, + Err: herr, + }) + return nil, herr + } + c.reportDial(DialInfo{ + Addr: addr, + TLS: true, + ConnectDuration: connectDur, + HandshakeDuration: handshakeDur, + Resumed: tlsConn.ConnectionState().DidResume, + }) + return tlsConn, nil +} + +// normalizeDialErr maps a timeout to ConnectTimeoutError, matching legacy +// behavior, and passes other errors through unchanged. +func (c *Client) normalizeDialErr(addr net.Addr, err error) error { if err == nil { - return nc, nil + return nil } if ne, ok := err.(net.Error); ok && ne.Timeout() { - return nil, &ConnectTimeoutError{addr} + return &ConnectTimeoutError{addr} } - return nil, err + return err +} + +// reportDial invokes the OnDial hook (if configured) on a separate goroutine so +// a slow hook cannot add latency to the dial path. A panic in the hook is +// recovered so it cannot crash the spawned goroutine (and thus the process). +func (c *Client) reportDial(info DialInfo) { + hook := c.OnDial + if hook == nil { + return + } + go func() { + defer func() { _ = recover() }() + hook(info) + }() +} + +// resumedFrom reports whether a successful TLS connection resumed a session. +// Returns false for failures or non-TLS connections. 
+func resumedFrom(nc net.Conn, err error) bool { + if err != nil || nc == nil { + return false + } + if tc, ok := nc.(*tls.Conn); ok { + return tc.ConnectionState().DidResume + } + return false } func (c *Client) getConn(addr net.Addr) (*conn, error) { diff --git a/memcache/memcache_test.go b/memcache/memcache_test.go index d8b48f0..320565c 100644 --- a/memcache/memcache_test.go +++ b/memcache/memcache_test.go @@ -469,3 +469,174 @@ func TestSeparateTimeouts(t *testing.T) { t.Error("DialTimeout and Timeout should be independent") } } + +// TestUseSplitTimeouts verifies the dialer takes the split-timeout path only +// when ConnectTimeout or HandshakeTimeout is configured. +func TestUseSplitTimeouts(t *testing.T) { + tests := []struct { + name string + connect time.Duration + handshk time.Duration + expected bool + }{ + {"neither set -> legacy", 0, 0, false}, + {"connect set -> split", 50 * time.Millisecond, 0, true}, + {"handshake set -> split", 0, 500 * time.Millisecond, true}, + {"both set -> split", 50 * time.Millisecond, 500 * time.Millisecond, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ConnectTimeout: tt.connect, HandshakeTimeout: tt.handshk} + if got := c.useSplitTimeouts(); got != tt.expected { + t.Errorf("useSplitTimeouts() = %v, want %v", got, tt.expected) + } + }) + } +} + +// TestSplitTimeoutFallbacks verifies connectTimeout()/handshakeTimeout() fall +// back to dialTimeout() (and thence Timeout/DefaultTimeout) when unset. +func TestSplitTimeoutFallbacks(t *testing.T) { + // Only handshake configured: connect should fall back to dialTimeout(). 
+ c := &Client{HandshakeTimeout: 800 * time.Millisecond, DialTimeout: 175 * time.Millisecond} + if got := c.connectTimeout(); got != 175*time.Millisecond { + t.Errorf("connectTimeout() fallback = %v, want 175ms", got) + } + if got := c.handshakeTimeout(); got != 800*time.Millisecond { + t.Errorf("handshakeTimeout() = %v, want 800ms", got) + } + + // Only connect configured: handshake should fall back to dialTimeout(). + c2 := &Client{ConnectTimeout: 50 * time.Millisecond, Timeout: 90 * time.Millisecond} + if got := c2.connectTimeout(); got != 50*time.Millisecond { + t.Errorf("connectTimeout() = %v, want 50ms", got) + } + // dialTimeout() with no DialTimeout falls back to netTimeout() == Timeout. + if got := c2.handshakeTimeout(); got != 90*time.Millisecond { + t.Errorf("handshakeTimeout() fallback = %v, want 90ms", got) + } +} + +// TestOnDialHookSuccess verifies OnDial fires with timing on a successful +// (non-TLS) connect via the split-timeout path. +func TestOnDialHookSuccess(t *testing.T) { + fakeServer, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal("Could not open fake server: ", err) + } + defer fakeServer.Close() + go func() { + for { + conn, err := fakeServer.Accept() + if err != nil { + return + } + go func() { io.Copy(ioutil.Discard, conn) }() + } + }() + + addr := fakeServer.Addr() + // OnDial fires on a separate goroutine, so receive its result over a channel. 
+ dialCh := make(chan DialInfo, 1) + c := New(addr.String()) + c.ConnectTimeout = 500 * time.Millisecond // force split path + c.OnDial = func(info DialInfo) { dialCh <- info } + + if _, err := c.getConn(addr); err != nil { + t.Fatalf("failed to connect to fake server: %v", err) + } + + var got DialInfo + select { + case got = <-dialCh: + case <-time.After(2 * time.Second): + t.Fatal("OnDial was not invoked within 2s") + } + if got.Err != nil { + t.Errorf("OnDial reported error on success: %v", got.Err) + } + if got.TLS { + t.Error("OnDial reported TLS for a non-TLS connection") + } + if got.ConnectDuration <= 0 { + t.Error("OnDial ConnectDuration should be positive") + } +} + +// TestOnDialHookConnectFailure verifies OnDial fires with a ConnectTimeoutError +// when the TCP connect fails, on the split-timeout path. +func TestOnDialHookConnectFailure(t *testing.T) { + // 192.0.2.0/24 (TEST-NET-1) is reserved and non-routable: connect will time out. + const blackhole = "192.0.2.1:11211" + dialCh := make(chan DialInfo, 1) + c := New(blackhole) + c.ConnectTimeout = 50 * time.Millisecond // short, forces split path + fast timeout + c.OnDial = func(info DialInfo) { dialCh <- info } + + addr, err := c.selector.PickServer("anykey") + if err != nil { + t.Fatalf("PickServer: %v", err) + } + _, derr := c.dial(addr) + if derr == nil { + t.Fatal("expected a connect failure, got nil") + } + if _, ok := derr.(*ConnectTimeoutError); !ok { + // Non-timeout connect errors are also acceptable on some networks, but a + // timeout is expected against a blackhole address. 
+ t.Logf("got non-timeout dial error (acceptable): %v", derr) + } + + var got DialInfo + select { + case got = <-dialCh: + case <-time.After(2 * time.Second): + t.Fatal("OnDial was not invoked within 2s") + } + if got.Err == nil { + t.Error("OnDial should have reported the failure error") + } + if got.ConnectDuration <= 0 { + t.Error("OnDial ConnectDuration should be positive even on failure") + } +} + +// TestOnDialHookPanicRecovered verifies a panicking hook does not crash the +// caller (the panic is recovered on the hook's goroutine). +func TestOnDialHookPanicRecovered(t *testing.T) { + fakeServer, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal("Could not open fake server: ", err) + } + defer fakeServer.Close() + go func() { + for { + conn, err := fakeServer.Accept() + if err != nil { + return + } + go func() { io.Copy(ioutil.Discard, conn) }() + } + }() + + addr := fakeServer.Addr() + done := make(chan struct{}) + c := New(addr.String()) + c.ConnectTimeout = 500 * time.Millisecond + c.OnDial = func(info DialInfo) { + defer close(done) + panic("boom") + } + + // The dial must succeed even though the hook panics. + if _, err := c.getConn(addr); err != nil { + t.Fatalf("dial failed: %v", err) + } + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("hook was not invoked within 2s") + } + // Give the recover() a moment; if the panic were not recovered, the test + // binary would have crashed by now. +}