Skip to content

Commit 8c937df

Browse files
VTGate: fix warming reads timeout context (#19674)
Signed-off-by: Tim Vaillancourt <tim@timvaillancourt.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 9e5141a commit 8c937df

5 files changed

Lines changed: 188 additions & 17 deletions

File tree

go/vt/vtgate/engine/fake_vcursor_test.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ func (t *noopVCursor) CloneForReplicaWarming(ctx context.Context) VCursor {
124124
panic("implement me")
125125
}
126126

127+
func (t *noopVCursor) WarmingReadsContext(ctx context.Context) (context.Context, context.CancelFunc) {
128+
panic("implement me")
129+
}
130+
127131
func (t *noopVCursor) CloneForMirroring(ctx context.Context) VCursor {
128132
panic("implement me")
129133
}
@@ -470,10 +474,11 @@ type loggingVCursor struct {
470474

471475
parser *sqlparser.Parser
472476

473-
onMirrorClonesFn func(context.Context) VCursor
474-
onExecuteMultiShardFn func(context.Context, Primitive, []*srvtopo.ResolvedShard, []*querypb.BoundQuery, bool, bool)
475-
onStreamExecuteMultiFn func(context.Context, Primitive, string, []*srvtopo.ResolvedShard, []map[string]*querypb.BindVariable, bool, bool, func(*sqltypes.Result) error)
476-
onRecordMirrorStatsFn func(time.Duration, time.Duration, error)
477+
onMirrorClonesFn func(context.Context) VCursor
478+
onExecuteMultiShardFn func(context.Context, Primitive, []*srvtopo.ResolvedShard, []*querypb.BoundQuery, bool, bool)
479+
onStreamExecuteMultiFn func(context.Context, Primitive, string, []*srvtopo.ResolvedShard, []map[string]*querypb.BindVariable, bool, bool, func(*sqltypes.Result) error)
480+
onRecordMirrorStatsFn func(time.Duration, time.Duration, error)
481+
onResolveDestinationsFn func(context.Context)
477482

478483
metrics *Metrics
479484
}
@@ -599,6 +604,10 @@ func (f *loggingVCursor) CloneForReplicaWarming(ctx context.Context) VCursor {
599604
return f
600605
}
601606

607+
func (f *loggingVCursor) WarmingReadsContext(ctx context.Context) (context.Context, context.CancelFunc) {
608+
return ctx, func() {}
609+
}
610+
602611
func (f *loggingVCursor) CloneForMirroring(ctx context.Context) VCursor {
603612
if f.onMirrorClonesFn != nil {
604613
return f.onMirrorClonesFn(ctx)
@@ -662,6 +671,9 @@ func (f *loggingVCursor) StreamExecuteMulti(ctx context.Context, primitive Primi
662671
}
663672

664673
func (f *loggingVCursor) ResolveDestinations(ctx context.Context, keyspace string, ids []*querypb.Value, destinations []key.ShardDestination) ([]*srvtopo.ResolvedShard, [][]*querypb.Value, error) {
674+
if f.onResolveDestinationsFn != nil {
675+
f.onResolveDestinationsFn(ctx)
676+
}
665677
f.log = append(f.log, fmt.Sprintf("ResolveDestinations %v %v %v", keyspace, ids, key.DestinationsString(destinations)))
666678
if f.shardErr != nil {
667679
return nil, nil, f.shardErr

go/vt/vtgate/engine/primitive.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,16 @@ type (
129129
// GetWarmingReadsChannel returns the channel for executing warming reads against replicas
130130
GetWarmingReadsChannel() chan bool
131131

132-
// CloneForReplicaWarming clones the VCursor for re-use in warming queries to replicas
132+
// CloneForReplicaWarming clones the VCursor for re-use in warming queries to replicas.
133+
// The clone must be created before launching the warming goroutine to avoid
134+
// concurrent access to the original VCursor. Use WarmingReadsContext on the
135+
// clone to obtain a timeout-bounded context for the warming query.
133136
CloneForReplicaWarming(ctx context.Context) VCursor
134137

138+
// WarmingReadsContext returns a timeout-bounded context for a warming query
139+
// and a cancel function that must be called when the warming query completes.
140+
WarmingReadsContext(ctx context.Context) (context.Context, context.CancelFunc)
141+
135142
// CloneForMirroring clones the VCursor for re-use in mirroring queries to other keyspaces
136143
CloneForMirroring(ctx context.Context) VCursor
137144
//

go/vt/vtgate/engine/route.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,22 +525,25 @@ func (route *Route) executeWarmingReplicaRead(ctx context.Context, vcursor VCurs
525525
}
526526
}
527527

528-
replicaVCursor := vcursor.CloneForReplicaWarming(ctx)
529528
warmingReadsChannel := vcursor.GetWarmingReadsChannel()
530529

531530
select {
532531
// if there's no more room in the channel, drop the warming read
533532
case warmingReadsChannel <- true:
533+
replicaVCursor := vcursor.CloneForReplicaWarming(ctx)
534534
go func(replicaVCursor VCursor) {
535+
warmingCtx, cancel := replicaVCursor.WarmingReadsContext(ctx)
536+
// Defers run LIFO: channel slot is released first, then context is canceled.
537+
defer cancel()
535538
defer func() {
536539
<-warmingReadsChannel
537540
}()
538-
rss, _, err := route.findRoute(ctx, replicaVCursor, bindVars)
541+
rss, _, err := route.findRoute(warmingCtx, replicaVCursor, bindVars)
539542
if err != nil {
540543
return
541544
}
542545

543-
_, errs := replicaVCursor.ExecuteMultiShard(ctx, route, rss, warmingQueries, false /*rollbackOnError*/, false /*canAutocommit*/, route.FetchLastInsertID)
546+
_, errs := replicaVCursor.ExecuteMultiShard(warmingCtx, route, rss, warmingQueries, false /*rollbackOnError*/, false /*canAutocommit*/, route.FetchLastInsertID)
544547
if len(errs) > 0 {
545548
log.Warn(fmt.Sprintf("Failed to execute warming replica read: %v", errs))
546549
} else {

go/vt/vtgate/engine/route_warming_reads_test.go

Lines changed: 144 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type warmingReadsVCursor struct {
3737
*loggingVCursor
3838
warmingReadsPercent int
3939
warmingReadsChannel chan bool
40+
warmingReadsTimeout time.Duration
4041
warmingReadsExecuteFunc func(context.Context, Primitive, []*srvtopo.ResolvedShard, []*querypb.BoundQuery, bool, bool)
4142
}
4243

@@ -49,16 +50,26 @@ func (vc *warmingReadsVCursor) GetWarmingReadsChannel() chan bool {
4950
}
5051

5152
func (vc *warmingReadsVCursor) CloneForReplicaWarming(ctx context.Context) VCursor {
53+
clonedLogging := &loggingVCursor{
54+
shards: vc.shards,
55+
results: vc.results,
56+
onResolveDestinationsFn: vc.onResolveDestinationsFn,
57+
}
5258
clone := &warmingReadsVCursor{
53-
loggingVCursor: vc.loggingVCursor,
59+
loggingVCursor: clonedLogging,
5460
warmingReadsPercent: vc.warmingReadsPercent,
5561
warmingReadsChannel: vc.warmingReadsChannel,
62+
warmingReadsTimeout: vc.warmingReadsTimeout,
5663
warmingReadsExecuteFunc: vc.warmingReadsExecuteFunc,
5764
}
5865
clone.onExecuteMultiShardFn = vc.warmingReadsExecuteFunc
5966
return clone
6067
}
6168

69+
func (vc *warmingReadsVCursor) WarmingReadsContext(ctx context.Context) (context.Context, context.CancelFunc) {
70+
return context.WithTimeout(context.Background(), vc.warmingReadsTimeout)
71+
}
72+
6273
func TestWarmingReadsSkipsForUpdate(t *testing.T) {
6374
vindex, _ := vindexes.CreateVindex("hash", "", nil)
6475
testCases := []struct {
@@ -129,29 +140,160 @@ func TestWarmingReadsSkipsForUpdate(t *testing.T) {
129140

130141
var warmingReadExecuted atomic.Bool
131142
var capturedQuery string
143+
var capturedCtxHasDeadline atomic.Bool
144+
var capturedCtxErr atomic.Pointer[error]
145+
var resolveDestCtxHasDeadline atomic.Bool
146+
// done is closed by the test to unblock the warming read goroutine
147+
// after context assertions have been made.
148+
done := make(chan struct{})
132149
vc := &warmingReadsVCursor{
133150
loggingVCursor: &loggingVCursor{
134151
shards: []string{"-20", "20-"},
135152
results: []*sqltypes.Result{defaultSelectResult},
153+
onResolveDestinationsFn: func(ctx context.Context) {
154+
_, hasDeadline := ctx.Deadline()
155+
resolveDestCtxHasDeadline.Store(hasDeadline)
156+
},
136157
},
137158
warmingReadsPercent: 100,
138159
warmingReadsChannel: make(chan bool, 1),
160+
warmingReadsTimeout: 5 * time.Second,
139161
}
140162
vc.warmingReadsExecuteFunc = func(ctx context.Context, primitive Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, rollbackOnError, canAutocommit bool) {
141163
if len(queries) > 0 {
142164
capturedQuery = queries[0].Sql
143165
}
166+
_, hasDeadline := ctx.Deadline()
167+
capturedCtxHasDeadline.Store(hasDeadline)
168+
ctxErr := ctx.Err()
169+
capturedCtxErr.Store(&ctxErr)
144170
warmingReadExecuted.Store(true)
171+
// Block until the test has checked our context assertions,
172+
// preventing defer cancel() from running.
173+
select {
174+
case <-done:
175+
case <-t.Context().Done():
176+
}
145177
}
146178

147-
_, err := route.TryExecute(t.Context(), vc, map[string]*querypb.BindVariable{}, false)
179+
// Use a cancelable parent context to verify the warming read
180+
// context is independent of the parent request context.
181+
parentCtx, parentCancel := context.WithCancel(t.Context())
182+
_, err := route.TryExecute(parentCtx, vc, map[string]*querypb.BindVariable{}, false)
148183
require.NoError(t, err)
149184

185+
// Cancel the parent context to simulate the primary request completing.
186+
parentCancel()
187+
150188
require.Eventually(t, func() bool {
151189
return warmingReadExecuted.Load()
152190
}, time.Second, 10*time.Millisecond, "warming read should be executed")
153191

154192
require.Equal(t, tc.expectedWarmingQuery, capturedQuery, "warming read query should match expected")
193+
require.True(t, capturedCtxHasDeadline.Load(), "warming read context should have a deadline from the timeout")
194+
195+
// The warming read context should still be active even though the
196+
// parent request context was canceled.
197+
require.NoError(t, *capturedCtxErr.Load(), "warming read context should not be canceled when parent context is canceled")
198+
199+
// Verify findRoute received the warming context (with deadline), not the parent context.
200+
require.True(t, resolveDestCtxHasDeadline.Load(), "ResolveDestinations should receive a context with deadline from the warming timeout")
201+
202+
// Unblock the warming read goroutine.
203+
close(done)
155204
})
156205
}
157206
}
207+
208+
func TestWarmingReadsDroppedWhenChannelFull(t *testing.T) {
209+
vindex, _ := vindexes.CreateVindex("hash", "", nil)
210+
route := NewRoute(
211+
EqualUnique,
212+
&vindexes.Keyspace{
213+
Name: "ks",
214+
Sharded: true,
215+
},
216+
"SELECT * FROM users WHERE id = 1",
217+
"dummy_select_field",
218+
)
219+
parser, _ := sqlparser.NewTestParser().Parse("SELECT * FROM users WHERE id = 1")
220+
route.QueryStatement = parser
221+
route.Vindex = vindex.(vindexes.SingleColumn)
222+
route.Values = []evalengine.Expr{
223+
evalengine.NewLiteralInt(1),
224+
}
225+
226+
var warmingReadExecuted atomic.Bool
227+
vc := &warmingReadsVCursor{
228+
loggingVCursor: &loggingVCursor{
229+
shards: []string{"-20", "20-"},
230+
results: []*sqltypes.Result{defaultSelectResult},
231+
},
232+
warmingReadsPercent: 100,
233+
warmingReadsChannel: make(chan bool, 1),
234+
warmingReadsTimeout: 5 * time.Second,
235+
}
236+
vc.warmingReadsExecuteFunc = func(ctx context.Context, primitive Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, rollbackOnError, canAutocommit bool) {
237+
warmingReadExecuted.Store(true)
238+
}
239+
240+
// Pre-fill the channel to simulate a full pool.
241+
vc.warmingReadsChannel <- true
242+
243+
_, err := route.TryExecute(t.Context(), vc, map[string]*querypb.BindVariable{}, false)
244+
require.NoError(t, err)
245+
246+
// Verify over a short window that no warming read is executed while the channel is full.
247+
require.Never(t, func() bool {
248+
return warmingReadExecuted.Load()
249+
}, 100*time.Millisecond, 5*time.Millisecond, "warming read should not execute when the channel is full")
250+
// Drain the channel.
251+
<-vc.warmingReadsChannel
252+
}
253+
254+
func TestWarmingReadsContextTimeout(t *testing.T) {
255+
vindex, _ := vindexes.CreateVindex("hash", "", nil)
256+
route := NewRoute(
257+
EqualUnique,
258+
&vindexes.Keyspace{
259+
Name: "ks",
260+
Sharded: true,
261+
},
262+
"SELECT * FROM users WHERE id = 1",
263+
"dummy_select_field",
264+
)
265+
parser, _ := sqlparser.NewTestParser().Parse("SELECT * FROM users WHERE id = 1")
266+
route.QueryStatement = parser
267+
route.Vindex = vindex.(vindexes.SingleColumn)
268+
route.Values = []evalengine.Expr{
269+
evalengine.NewLiteralInt(1),
270+
}
271+
272+
var capturedCtxErr atomic.Pointer[error]
273+
var warmingReadExecuted atomic.Bool
274+
vc := &warmingReadsVCursor{
275+
loggingVCursor: &loggingVCursor{
276+
shards: []string{"-20", "20-"},
277+
results: []*sqltypes.Result{defaultSelectResult},
278+
},
279+
warmingReadsPercent: 100,
280+
warmingReadsChannel: make(chan bool, 1),
281+
warmingReadsTimeout: 1 * time.Millisecond,
282+
}
283+
vc.warmingReadsExecuteFunc = func(ctx context.Context, primitive Primitive, rss []*srvtopo.ResolvedShard, queries []*querypb.BoundQuery, rollbackOnError, canAutocommit bool) {
284+
// Block until the warming context times out.
285+
<-ctx.Done()
286+
ctxErr := ctx.Err()
287+
capturedCtxErr.Store(&ctxErr)
288+
warmingReadExecuted.Store(true)
289+
}
290+
291+
_, err := route.TryExecute(t.Context(), vc, map[string]*querypb.BindVariable{}, false)
292+
require.NoError(t, err)
293+
294+
require.Eventually(t, func() bool {
295+
return warmingReadExecuted.Load()
296+
}, time.Second, 10*time.Millisecond, "warming read should have been executed and timed out")
297+
298+
require.ErrorIs(t, *capturedCtxErr.Load(), context.DeadlineExceeded, "warming read context should have timed out")
299+
}

go/vt/vtgate/executorcontext/vcursor_impl.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,6 @@ func (vc *VCursorImpl) CloneForMirroring(ctx context.Context) engine.VCursor {
299299
}
300300

301301
func (vc *VCursorImpl) CloneForReplicaWarming(ctx context.Context) engine.VCursor {
302-
callerId := callerid.EffectiveCallerIDFromContext(ctx)
303-
immediateCallerId := callerid.ImmediateCallerIDFromContext(ctx)
304-
305-
timedCtx, _ := context.WithTimeout(context.Background(), vc.config.WarmingReadsTimeout) //nolint
306-
clonedCtx := callerid.NewContext(timedCtx, callerId, immediateCallerId)
307-
308302
v := &VCursorImpl{
309303
config: vc.config,
310304
SafeSession: NewAutocommitSession(vc.SafeSession.Session),
@@ -315,7 +309,7 @@ func (vc *VCursorImpl) CloneForReplicaWarming(ctx context.Context) engine.VCurso
315309
executor: vc.executor,
316310
resolver: vc.resolver,
317311
topoServer: vc.topoServer,
318-
logStats: &logstats.LogStats{Ctx: clonedCtx},
312+
logStats: &logstats.LogStats{},
319313
metrics: vc.metrics,
320314

321315
ignoreMaxMemoryRows: vc.ignoreMaxMemoryRows,
@@ -332,6 +326,19 @@ func (vc *VCursorImpl) CloneForReplicaWarming(ctx context.Context) engine.VCurso
332326
return v
333327
}
334328

329+
func (vc *VCursorImpl) WarmingReadsContext(ctx context.Context) (context.Context, context.CancelFunc) {
330+
callerId := callerid.EffectiveCallerIDFromContext(ctx)
331+
immediateCallerId := callerid.ImmediateCallerIDFromContext(ctx)
332+
333+
baseCtx := context.WithoutCancel(ctx)
334+
timedCtx, cancel := context.WithTimeout(baseCtx, vc.config.WarmingReadsTimeout)
335+
clonedCtx := callerid.NewContext(timedCtx, callerId, immediateCallerId)
336+
337+
vc.logStats = &logstats.LogStats{Ctx: clonedCtx}
338+
339+
return clonedCtx, cancel
340+
}
341+
335342
func (vc *VCursorImpl) cloneWithAutocommitSession() *VCursorImpl {
336343
safeSession := vc.SafeSession.NewAutocommitSession()
337344
return &VCursorImpl{

0 commit comments

Comments
 (0)