Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 180 additions & 6 deletions memcache/memcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,38 @@ type Client struct {
// DialTimeout specifies the timeout for establishing new connections,
// including the TLS handshake if TLS is enabled.
// If zero, Timeout is used (for backward compatibility).
//
// DialTimeout is a single budget covering both the TCP connection and the
// TLS handshake. Because the TLS handshake can be slow when the server is
// under CPU pressure, a single tight budget can abort handshakes mid-flight.
// Prefer ConnectTimeout + HandshakeTimeout (below) to budget them
// separately; DialTimeout is retained for backward compatibility and is used
// only when the split timeouts are zero.
DialTimeout time.Duration

// ConnectTimeout, when non-zero, bounds only the TCP connection
// establishment (not the TLS handshake). Pair with HandshakeTimeout to
// budget connect and handshake independently. If zero, the dialer falls back
// to DialTimeout for the whole operation.
ConnectTimeout time.Duration

// HandshakeTimeout, when non-zero, bounds only the TLS handshake (separate
// from the TCP connect). Has no effect for non-TLS clients. If zero, the TLS
// handshake shares the DialTimeout/ConnectTimeout budget (legacy behavior).
HandshakeTimeout time.Duration

// OnDial, when non-nil, is invoked after every connection attempt (success
// or failure) with timing and outcome details. It enables callers to emit
// dial-count, connect/handshake-duration, and TLS-resumption metrics without
// the library taking a metrics dependency.
//
// OnDial is invoked on a separate goroutine so it can never add latency to
// the dial path; consequently it may run after the connection is already in
// use, and the relative ordering of callbacks is not guaranteed. A panic in
// the hook is recovered and discarded. The DialInfo is passed by value, so
// the hook owns its copy.
OnDial func(DialInfo)

// MaxIdleConns specifies the maximum number of idle connections that will
// be maintained per address. If less than one, DefaultMaxIdleConns will be
// used.
Expand All @@ -156,6 +186,24 @@ type Client struct {
TlsConfig *tls.Config
}

// DialInfo reports the outcome of a single connection attempt to OnDial.
type DialInfo struct {
// Addr is the server address that was dialed.
Addr net.Addr
// TLS is true if the attempt included a TLS handshake.
TLS bool
// ConnectDuration is the time spent establishing the TCP connection.
ConnectDuration time.Duration
// HandshakeDuration is the time spent on the TLS handshake. Zero for
// non-TLS attempts or when the connect failed before the handshake began.
HandshakeDuration time.Duration
// Resumed is true if the TLS handshake resumed a previous session
// (abbreviated handshake). Only meaningful when TLS is true and Err is nil.
Resumed bool
// Err is the failure, or nil on success.
Err error
}

// Item is an item to be got or stored in a memcached server.
type Item struct {
// Key is the Item's key (250 bytes maximum).
Expand Down Expand Up @@ -268,12 +316,42 @@ func (cte *ConnectTimeoutError) Error() string {
return "memcache: connect timeout to " + cte.Addr.String()
}

// useSplitTimeouts reports whether the caller configured independent connect
// and handshake budgets. When false, dial uses the legacy single-budget path.
func (c *Client) useSplitTimeouts() bool {
return c.ConnectTimeout != 0 || c.HandshakeTimeout != 0
}

// connectTimeout is the TCP-connect budget for the split-timeout path. It falls
// back to dialTimeout() when ConnectTimeout is unset so a caller can configure
// only HandshakeTimeout and still get a sane connect bound.
func (c *Client) connectTimeout() time.Duration {
if c.ConnectTimeout != 0 {
return c.ConnectTimeout
}
return c.dialTimeout()
}

// handshakeTimeout is the TLS-handshake budget for the split-timeout path. It
// falls back to dialTimeout() when HandshakeTimeout is unset.
func (c *Client) handshakeTimeout() time.Duration {
if c.HandshakeTimeout != 0 {
return c.HandshakeTimeout
}
return c.dialTimeout()
}

func (c *Client) dial(addr net.Addr) (net.Conn, error) {
type connError struct {
cn net.Conn
err error
if c.useSplitTimeouts() {
return c.dialSplit(addr)
}
return c.dialLegacy(addr)
}

// dialLegacy preserves the original behavior: a single dialer Timeout covers
// both the TCP connection and (for TLS) the handshake.
func (c *Client) dialLegacy(addr net.Addr) (net.Conn, error) {
start := time.Now()
var (
nc net.Conn
err error
Expand All @@ -284,13 +362,109 @@ func (c *Client) dial(addr net.Addr) (net.Conn, error) {
} else {
nc, err = nd.Dial(addr.Network(), addr.String())
}
err = c.normalizeDialErr(addr, err)
c.reportDial(DialInfo{
Addr: addr,
TLS: c.TlsConfig != nil,
ConnectDuration: time.Since(start),
Resumed: resumedFrom(nc, err),
Err: err,
})
if err != nil {
return nil, err
}
return nc, nil
}

// dialSplit budgets the TCP connect and the TLS handshake independently, so a
// slow handshake (e.g. a CPU-bound server) does not consume the connect budget
// and a tight connect budget can fail a dead host fast without aborting healthy
// but slow handshakes.
func (c *Client) dialSplit(addr net.Addr) (net.Conn, error) {
connectStart := time.Now()
nd := net.Dialer{Timeout: c.connectTimeout()}
rawConn, err := nd.Dial(addr.Network(), addr.String())
connectDur := time.Since(connectStart)
if err != nil {
err = c.normalizeDialErr(addr, err)
c.reportDial(DialInfo{Addr: addr, TLS: c.TlsConfig != nil, ConnectDuration: connectDur, Err: err})
return nil, err
}

// Non-TLS: nothing more to do.
if c.TlsConfig == nil {
c.reportDial(DialInfo{Addr: addr, ConnectDuration: connectDur})
return rawConn, nil
}

// TLS: run the handshake under its own deadline.
handshakeStart := time.Now()
tlsConn := tls.Client(rawConn, c.TlsConfig)
if dl := c.handshakeTimeout(); dl > 0 {
_ = tlsConn.SetDeadline(time.Now().Add(dl))
}
herr := tlsConn.Handshake()
handshakeDur := time.Since(handshakeStart)
// Clear the handshake deadline; per-op deadlines are managed elsewhere.
_ = tlsConn.SetDeadline(time.Time{})
if herr != nil {
_ = tlsConn.Close()
herr = c.normalizeDialErr(addr, herr)
c.reportDial(DialInfo{
Addr: addr,
TLS: true,
ConnectDuration: connectDur,
HandshakeDuration: handshakeDur,
Err: herr,
})
return nil, herr
}
c.reportDial(DialInfo{
Addr: addr,
TLS: true,
ConnectDuration: connectDur,
HandshakeDuration: handshakeDur,
Resumed: tlsConn.ConnectionState().DidResume,
})
return tlsConn, nil
}

// normalizeDialErr maps a timeout to ConnectTimeoutError, matching legacy
// behavior, and passes other errors through unchanged.
func (c *Client) normalizeDialErr(addr net.Addr, err error) error {
if err == nil {
return nc, nil
return nil
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {
return nil, &ConnectTimeoutError{addr}
return &ConnectTimeoutError{addr}
}
return nil, err
return err
}

// reportDial invokes the OnDial hook (if configured) on a separate goroutine so
// a slow hook cannot add latency to the dial path. A panic in the hook is
// recovered so it cannot crash the spawned goroutine (and thus the process).
func (c *Client) reportDial(info DialInfo) {
hook := c.OnDial
if hook == nil {
return
}
go func() {
defer func() { _ = recover() }()
hook(info)
}()
}

// resumedFrom reports whether a successful TLS connection resumed a session.
// Returns false for failures or non-TLS connections.
func resumedFrom(nc net.Conn, err error) bool {
if err != nil || nc == nil {
return false
}
if tc, ok := nc.(*tls.Conn); ok {
return tc.ConnectionState().DidResume
}
return false
}

func (c *Client) getConn(addr net.Addr) (*conn, error) {
Expand Down
171 changes: 171 additions & 0 deletions memcache/memcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,174 @@ func TestSeparateTimeouts(t *testing.T) {
t.Error("DialTimeout and Timeout should be independent")
}
}

// TestUseSplitTimeouts verifies the dialer takes the split-timeout path only
// when ConnectTimeout or HandshakeTimeout is configured.
func TestUseSplitTimeouts(t *testing.T) {
tests := []struct {
name string
connect time.Duration
handshk time.Duration
expected bool
}{
{"neither set -> legacy", 0, 0, false},
{"connect set -> split", 50 * time.Millisecond, 0, true},
{"handshake set -> split", 0, 500 * time.Millisecond, true},
{"both set -> split", 50 * time.Millisecond, 500 * time.Millisecond, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Client{ConnectTimeout: tt.connect, HandshakeTimeout: tt.handshk}
if got := c.useSplitTimeouts(); got != tt.expected {
t.Errorf("useSplitTimeouts() = %v, want %v", got, tt.expected)
}
})
}
}

// TestSplitTimeoutFallbacks verifies connectTimeout()/handshakeTimeout() fall
// back to dialTimeout() (and thence Timeout/DefaultTimeout) when unset.
func TestSplitTimeoutFallbacks(t *testing.T) {
// Only handshake configured: connect should fall back to dialTimeout().
c := &Client{HandshakeTimeout: 800 * time.Millisecond, DialTimeout: 175 * time.Millisecond}
if got := c.connectTimeout(); got != 175*time.Millisecond {
t.Errorf("connectTimeout() fallback = %v, want 175ms", got)
}
if got := c.handshakeTimeout(); got != 800*time.Millisecond {
t.Errorf("handshakeTimeout() = %v, want 800ms", got)
}

// Only connect configured: handshake should fall back to dialTimeout().
c2 := &Client{ConnectTimeout: 50 * time.Millisecond, Timeout: 90 * time.Millisecond}
if got := c2.connectTimeout(); got != 50*time.Millisecond {
t.Errorf("connectTimeout() = %v, want 50ms", got)
}
// dialTimeout() with no DialTimeout falls back to netTimeout() == Timeout.
if got := c2.handshakeTimeout(); got != 90*time.Millisecond {
t.Errorf("handshakeTimeout() fallback = %v, want 90ms", got)
}
}

// TestOnDialHookSuccess verifies OnDial fires with timing on a successful
// (non-TLS) connect via the split-timeout path.
func TestOnDialHookSuccess(t *testing.T) {
fakeServer, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal("Could not open fake server: ", err)
}
defer fakeServer.Close()
go func() {
for {
conn, err := fakeServer.Accept()
if err != nil {
return
}
go func() { io.Copy(ioutil.Discard, conn) }()
}
}()

addr := fakeServer.Addr()
// OnDial fires on a separate goroutine, so receive its result over a channel.
dialCh := make(chan DialInfo, 1)
c := New(addr.String())
c.ConnectTimeout = 500 * time.Millisecond // force split path
c.OnDial = func(info DialInfo) { dialCh <- info }

if _, err := c.getConn(addr); err != nil {
t.Fatalf("failed to connect to fake server: %v", err)
}

var got DialInfo
select {
case got = <-dialCh:
case <-time.After(2 * time.Second):
t.Fatal("OnDial was not invoked within 2s")
}
if got.Err != nil {
t.Errorf("OnDial reported error on success: %v", got.Err)
}
if got.TLS {
t.Error("OnDial reported TLS for a non-TLS connection")
}
if got.ConnectDuration <= 0 {
t.Error("OnDial ConnectDuration should be positive")
}
}

// TestOnDialHookConnectFailure verifies OnDial fires with a ConnectTimeoutError
// when the TCP connect fails, on the split-timeout path.
func TestOnDialHookConnectFailure(t *testing.T) {
// 192.0.2.0/24 (TEST-NET-1) is reserved and non-routable: connect will time out.
const blackhole = "192.0.2.1:11211"
dialCh := make(chan DialInfo, 1)
c := New(blackhole)
c.ConnectTimeout = 50 * time.Millisecond // short, forces split path + fast timeout
c.OnDial = func(info DialInfo) { dialCh <- info }

addr, err := c.selector.PickServer("anykey")
if err != nil {
t.Fatalf("PickServer: %v", err)
}
_, derr := c.dial(addr)
if derr == nil {
t.Fatal("expected a connect failure, got nil")
}
if _, ok := derr.(*ConnectTimeoutError); !ok {
// Non-timeout connect errors are also acceptable on some networks, but a
// timeout is expected against a blackhole address.
t.Logf("got non-timeout dial error (acceptable): %v", derr)
}

var got DialInfo
select {
case got = <-dialCh:
case <-time.After(2 * time.Second):
t.Fatal("OnDial was not invoked within 2s")
}
if got.Err == nil {
t.Error("OnDial should have reported the failure error")
}
if got.ConnectDuration <= 0 {
t.Error("OnDial ConnectDuration should be positive even on failure")
}
}

// TestOnDialHookPanicRecovered verifies a panicking hook does not crash the
// caller (the panic is recovered on the hook's goroutine).
func TestOnDialHookPanicRecovered(t *testing.T) {
fakeServer, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal("Could not open fake server: ", err)
}
defer fakeServer.Close()
go func() {
for {
conn, err := fakeServer.Accept()
if err != nil {
return
}
go func() { io.Copy(ioutil.Discard, conn) }()
}
}()

addr := fakeServer.Addr()
done := make(chan struct{})
c := New(addr.String())
c.ConnectTimeout = 500 * time.Millisecond
c.OnDial = func(info DialInfo) {
defer close(done)
panic("boom")
}

// The dial must succeed even though the hook panics.
if _, err := c.getConn(addr); err != nil {
t.Fatalf("dial failed: %v", err)
}
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("hook was not invoked within 2s")
}
// Give the recover() a moment; if the panic were not recovered, the test
// binary would have crashed by now.
}