From 615f118fa65b78ffe9342f9bbfdd3cea534261f3 Mon Sep 17 00:00:00 2001 From: Bill Guowei Yang Date: Wed, 10 Jun 2026 11:44:33 -0400 Subject: [PATCH 1/7] Wait for worker query close before session reuse --- duckdbservice/flight_handler.go | 27 ++++ duckdbservice/flight_handler_test.go | 54 ++++++++ duckdbservice/service.go | 35 +++++ server/flightclient/flight_executor.go | 100 +++++++++++--- server/flightclient/flight_executor_test.go | 145 ++++++++++++++++++++ server/wire/worker_proto.go | 6 + server/worker_control.go | 9 +- tests/e2e-mw-dev/README.md | 7 + 8 files changed, 359 insertions(+), 24 deletions(-) diff --git a/duckdbservice/flight_handler.go b/duckdbservice/flight_handler.go index 75635870..5fb6baa0 100644 --- a/duckdbservice/flight_handler.go +++ b/duckdbservice/flight_handler.go @@ -433,6 +433,33 @@ func (h *FlightSQLHandler) doHealthCheck(body []byte, stream flight.FlightServic return sendActionResult(stream, &flight.Result{Body: resp}) } +func (h *FlightSQLHandler) doWaitSessionIdle(body []byte, stream flight.FlightService_DoActionServer) error { + var req server.WorkerWaitSessionIdlePayload + if err := json.Unmarshal(body, &req); err != nil { + return status.Errorf(codes.InvalidArgument, "invalid WaitSessionIdle request: %v", err) + } + if err := h.pool.validateControlMetadata(req.WorkerControlMetadata); err != nil { + return status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err) + } + session, err := h.sessionFromContext(stream.Context()) + if err != nil { + return err + } + if err := session.waitOperationIdle(stream.Context()); err != nil { + switch { + case errors.Is(err, context.Canceled): + return status.Errorf(codes.Canceled, "wait for session idle: %v", err) + case errors.Is(err, context.DeadlineExceeded): + return status.Errorf(codes.DeadlineExceeded, "wait for session idle: %v", err) + default: + return status.Errorf(codes.Internal, "wait for session idle: %v", err) + } + } + + resp, _ := json.Marshal(map[string]bool{"ok": true}) + return sendActionResult(stream, &flight.Result{Body: resp}) +} + // Flight SQL method implementations func (h *FlightSQLHandler) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, diff --git a/duckdbservice/flight_handler_test.go b/duckdbservice/flight_handler_test.go index 7a54d491..11794a38 100644 --- a/duckdbservice/flight_handler_test.go +++ b/duckdbservice/flight_handler_test.go @@ -473,6 +473,60 @@ func TestCreateSessionRejectsWhileDraining(t *testing.T) { } } +func TestWaitSessionIdleBlocksUntilOperationReleases(t *testing.T) { + session := &Session{ + ID: "session-1", + Username: "alice", + CreatedAt: time.Now(), + queries: make(map[string]*QueryHandle), + txns: make(map[string]*trackedTx), + txnOwner: make(map[string]string), + } + pool := &SessionPool{ + sessions: map[string]*Session{session.ID: session}, + stopRefresh: make(map[string]func()), + warmupDone: make(chan struct{}), + startTime: time.Now(), + } + close(pool.warmupDone) + handler := &FlightSQLHandler{pool: pool, alloc: memory.DefaultAllocator} + + finishOperation, ok := session.beginOperation() + if !ok { + t.Fatal("beginOperation rejected test session") + } + + body, err := json.Marshal(server.WorkerWaitSessionIdlePayload{}) + if err != nil { + t.Fatalf("marshal request: %v", err) + } + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-duckgres-session", session.ID)) + stream := &mockDoActionStream{ctx: ctx} + done := make(chan error, 1) + go func() { + done <- handler.doWaitSessionIdle(body, stream) + }() + + select { + case err := <-done: + t.Fatalf("WaitSessionIdle returned before operation released: %v", err) + case <-time.After(50 * time.Millisecond): + } + + finishOperation() + select { + case err := <-done: + if err != nil { + t.Fatalf("WaitSessionIdle returned error: %v", err) + } + case <-time.After(time.Second): + t.Fatal("WaitSessionIdle did not return after operation released") + } + if len(stream.results) != 1 { + t.Fatalf("expected one action result, got %d", len(stream.results)) + } +} + func TestCreateSessionSendFailureDestroysSession(t *testing.T) { pool := &SessionPool{ sessions: make(map[string]*Session), diff --git a/duckdbservice/service.go b/duckdbservice/service.go index f5bc25f7..bdac8f9d 100644 --- a/duckdbservice/service.go +++ b/duckdbservice/service.go @@ -123,6 +123,7 @@ type Session struct { txnOwner map[string]string closed bool operationOpen bool + operationIdle chan struct{} connWork int connWorkDone *sync.Cond handleCounter atomic.Uint64 @@ -234,6 +235,7 @@ func (s *Session) beginOperation() (func(), bool) { return nil, false } s.operationOpen = true + s.operationIdle = make(chan struct{}) s.mu.Unlock() var once sync.Once @@ -241,6 +243,10 @@ func (s *Session) beginOperation() (func(), bool) { once.Do(func() { s.mu.Lock() s.operationOpen = false + if s.operationIdle != nil { + close(s.operationIdle) + s.operationIdle = nil + } s.mu.Unlock() }) }, true @@ -257,6 +263,7 @@ func (s *Session) beginOperationForTransaction(txnKey string) (func(), bool, boo return nil, true, false } s.operationOpen = true + s.operationIdle = make(chan struct{}) s.mu.Unlock() var once sync.Once @@ -264,11 +271,37 @@ func (s *Session) beginOperationForTransaction(txnKey string) (func(), bool, boo once.Do(func() { s.mu.Lock() s.operationOpen = false + if s.operationIdle != nil { + close(s.operationIdle) + s.operationIdle = nil + } s.mu.Unlock() }) }, true, true } +func (s *Session) waitOperationIdle(ctx context.Context) error { + for { + s.mu.RLock() + if !s.operationOpen { + s.mu.RUnlock() + return nil + } + idle := s.operationIdle + s.mu.RUnlock() + + if idle == nil { + return errors.New("session operation idle signal missing") + } + select { + case <-idle: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } +} + // beginConnWork fences any operation that uses the session connection while a // raw SQL transaction may be open. It intentionally does not mutate queryActive: // conn work includes COPY receive and metadata/planning work, while queryActive @@ -1593,6 +1626,8 @@ func (s *customActionServer) DoAction(cmd *flight.Action, stream flight.FlightSe return s.handler.doDestroySession(cmd.Body, stream) case "HealthCheck": return s.handler.doHealthCheck(cmd.Body, stream) + case "WaitSessionIdle": + return s.handler.doWaitSessionIdle(cmd.Body, stream) default: // Fall through to standard flightsql action router (BeginTransaction, etc.) return s.FlightServer.DoAction(cmd, stream) diff --git a/server/flightclient/flight_executor.go b/server/flightclient/flight_executor.go index fbf6283a..f3bf4f0d 100644 --- a/server/flightclient/flight_executor.go +++ b/server/flightclient/flight_executor.go @@ -3,15 +3,17 @@ package flightclient import ( "context" "encoding/hex" + "encoding/json" "errors" "fmt" + "io" "math/big" "runtime" + "strconv" "strings" "sync" "sync/atomic" "time" - "strconv" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" @@ -21,6 +23,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/memory" "github.com/posthog/duckgres/duckdbservice/arrowmap" "github.com/posthog/duckgres/server/sqlcore" + "github.com/posthog/duckgres/server/wire" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" @@ -30,6 +33,11 @@ import ( // DuckDB query results can easily exceed the default 4MB limit. const MaxGRPCMessageSize = 1 << 30 // 1GB +const ( + waitSessionIdleAction = "WaitSessionIdle" + queryCloseWaitTimeout = 30 * time.Second +) + // ErrWorkerDead is returned when the backing worker process has crashed. var ErrWorkerDead = errors.New("flight worker is dead") @@ -215,9 +223,10 @@ func (e *FlightExecutor) QueryContext(ctx context.Context, query string, args .. success = true return &FlightRowSet{ - reader: reader, - schema: schema, - cancel: cancel, + reader: reader, + schema: schema, + cancel: cancel, + waitForClosed: e.waitForSessionIdle, }, nil } @@ -307,6 +316,52 @@ func (e *FlightExecutor) Close() error { return nil } +func (e *FlightExecutor) waitForSessionIdle() error { + if e.client == nil || e.client.Client == nil { + return nil + } + + payload, err := json.Marshal(wire.WorkerWaitSessionIdlePayload{ + WorkerControlMetadata: wire.WorkerControlMetadata{ + WorkerID: e.workerID, + OwnerEpoch: e.ownerEpoch, + CPInstanceID: e.cpInstanceID, + }, + }) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), queryCloseWaitTimeout) + defer cancel() + if e.ctx != nil { + go func() { + select { + case <-e.ctx.Done(): + cancel() + case <-ctx.Done(): + } + }() + } + + stream, err := e.client.Client.DoAction( + e.withSession(ctx), + &flight.Action{Type: waitSessionIdleAction, Body: payload}, + ) + if err != nil { + return err + } + for { + _, err := stream.Recv() + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return err + } + } +} + func (e *FlightExecutor) LastProfilingOutput() string { v := e.lastProfiling.Load() if v == nil { @@ -331,12 +386,14 @@ type FlightRowSet struct { schema *arrow.Schema // Current batch state - currentBatch arrow.RecordBatch - batchRow int // current row index within currentBatch - done bool - err error - closeOnce sync.Once - cancel context.CancelFunc + currentBatch arrow.RecordBatch + batchRow int // current row index within currentBatch + done bool + err error + closeOnce sync.Once + closeErr error + cancel context.CancelFunc + waitForClosed func() error } func (r *FlightRowSet) Columns() ([]string, error) { @@ -422,8 +479,11 @@ func (r *FlightRowSet) Close() error { r.currentBatch = nil } r.reader.Release() + if r.waitForClosed != nil { + r.closeErr = r.waitForClosed() + } }) - return nil + return r.closeErr } func (r *FlightRowSet) Err() error { @@ -433,12 +493,12 @@ func (r *FlightRowSet) Err() error { // emptyRowSet is returned when a query produces no endpoints and no schema. type emptyRowSet struct{} -func (e *emptyRowSet) Columns() ([]string, error) { return nil, nil } +func (e *emptyRowSet) Columns() ([]string, error) { return nil, nil } func (e *emptyRowSet) ColumnTypes() ([]sqlcore.ColumnTyper, error) { return nil, nil } -func (e *emptyRowSet) Next() bool { return false } -func (e *emptyRowSet) Scan(dest ...any) error { return fmt.Errorf("no rows") } -func (e *emptyRowSet) Close() error { return nil } -func (e *emptyRowSet) Err() error { return nil } +func (e *emptyRowSet) Next() bool { return false } +func (e *emptyRowSet) Scan(dest ...any) error { return fmt.Errorf("no rows") } +func (e *emptyRowSet) Close() error { return nil } +func (e *emptyRowSet) Err() error { return nil } // emptySchemaRowSet is returned when a query produces no data rows but does // have schema information (e.g., SELECT ... LIMIT 0). This preserves column @@ -464,10 +524,10 @@ func (e *emptySchemaRowSet) ColumnTypes() ([]sqlcore.ColumnTyper, error) { return types, nil } -func (e *emptySchemaRowSet) Next() bool { return false } +func (e *emptySchemaRowSet) Next() bool { return false } func (e *emptySchemaRowSet) Scan(...any) error { return fmt.Errorf("no rows") } -func (e *emptySchemaRowSet) Close() error { return nil } -func (e *emptySchemaRowSet) Err() error { return nil } +func (e *emptySchemaRowSet) Close() error { return nil } +func (e *emptySchemaRowSet) Err() error { return nil } // flightExecResult implements ExecResult for Flight SQL updates. type flightExecResult struct { @@ -781,7 +841,7 @@ func interpolateArgs(query string, args []any) string { } // scanQuoted returns the index just past a quoted region starting at start -// (query[start] == quote), treating a doubled quote ('' or "") as an escape. +// (query[start] == quote), treating a doubled quote (” or "") as an escape. func scanQuoted(query string, start int, quote byte) int { for i := start + 1; i < len(query); i++ { if query[i] != quote { diff --git a/server/flightclient/flight_executor_test.go b/server/flightclient/flight_executor_test.go index fc2e1705..e377f2c2 100644 --- a/server/flightclient/flight_executor_test.go +++ b/server/flightclient/flight_executor_test.go @@ -2,8 +2,21 @@ package flightclient import ( "context" + "encoding/json" + "net" "testing" + "time" + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/flight" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql" + pb "github.com/apache/arrow-go/v18/arrow/flight/gen/flight" + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/posthog/duckgres/server/wire" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" ) @@ -30,3 +43,135 @@ func TestFlightExecutorWithSessionAddsOwnerEpochHeader(t *testing.T) { t.Fatalf("unexpected cp instance metadata: %#v", got) } } + +type closeWaitFlightServer struct { + pb.UnimplementedFlightServiceServer + + schema *arrow.Schema + + doGetStarted chan struct{} + doGetContextCanceled chan struct{} + allowDoGetReturn chan struct{} + doGetDone chan struct{} +} + +func newCloseWaitFlightServer() *closeWaitFlightServer { + return &closeWaitFlightServer{ + schema: arrow.NewSchema([]arrow.Field{{Name: "x", Type: arrow.PrimitiveTypes.Int64}}, nil), + doGetStarted: make(chan struct{}), + doGetContextCanceled: make(chan struct{}), + allowDoGetReturn: make(chan struct{}), + doGetDone: make(chan struct{}), + } +} + +func (s *closeWaitFlightServer) GetFlightInfo(context.Context, *flight.FlightDescriptor) (*flight.FlightInfo, error) { + return &flight.FlightInfo{ + Schema: flight.SerializeSchema(s.schema, memory.DefaultAllocator), + Endpoint: []*flight.FlightEndpoint{{ + Ticket: &flight.Ticket{Ticket: []byte("query-1")}, + }}, + }, nil +} + +func (s *closeWaitFlightServer) DoGet(_ *flight.Ticket, stream pb.FlightService_DoGetServer) error { + close(s.doGetStarted) + + builder := array.NewInt64Builder(memory.DefaultAllocator) + builder.Append(1) + arr := builder.NewArray() + builder.Release() + defer arr.Release() + + record := array.NewRecordBatch(s.schema, []arrow.Array{arr}, 1) + defer record.Release() + + writer := flight.NewRecordWriter(stream, ipc.WithSchema(s.schema), ipc.WithAllocator(memory.DefaultAllocator)) + if err := writer.Write(record); err != nil { + return err + } + + <-stream.Context().Done() + close(s.doGetContextCanceled) + <-s.allowDoGetReturn + close(s.doGetDone) + return stream.Context().Err() +} + +func (s *closeWaitFlightServer) DoAction(action *flight.Action, stream pb.FlightService_DoActionServer) error { + if action.Type != waitSessionIdleAction { + return s.UnimplementedFlightServiceServer.DoAction(action, stream) + } + var payload wire.WorkerWaitSessionIdlePayload + if err := json.Unmarshal(action.Body, &payload); err != nil { + return err + } + <-s.doGetDone + return stream.Send(&flight.Result{Body: []byte(`{"ok":true}`)}) +} + +func TestFlightRowSetCloseWaitsForWorkerDoGetCleanup(t *testing.T) { + srv := newCloseWaitFlightServer() + + grpcSrv := grpc.NewServer() + pb.RegisterFlightServiceServer(grpcSrv, srv) + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + go func() { _ = grpcSrv.Serve(lis) }() + defer grpcSrv.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + flightCli, err := flight.NewClientWithMiddlewareCtx( + ctx, lis.Addr().String(), nil, nil, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("flight client: %v", err) + } + defer func() { _ = flightCli.Close() }() + + exec := NewFlightExecutorFromClient(&flightsql.Client{Client: flightCli}, "session-1") + defer func() { _ = exec.Close() }() + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err != nil { + t.Fatalf("QueryContext: %v", err) + } + select { + case <-srv.doGetStarted: + case <-time.After(time.Second): + t.Fatal("DoGet did not start") + } + + closeReturned := make(chan error, 1) + go func() { + closeReturned <- rows.Close() + }() + + select { + case err := <-closeReturned: + t.Fatalf("Close returned before worker DoGet cleanup completed: %v", err) + case <-srv.doGetContextCanceled: + case <-time.After(time.Second): + t.Fatal("worker DoGet did not observe cancellation") + } + + select { + case err := <-closeReturned: + t.Fatalf("Close returned while worker DoGet cleanup was still blocked: %v", err) + case <-time.After(50 * time.Millisecond): + } + + close(srv.allowDoGetReturn) + select { + case err := <-closeReturned: + if err != nil { + t.Fatalf("Close returned error: %v", err) + } + case <-time.After(time.Second): + t.Fatal("Close did not return after worker DoGet cleanup completed") + } +} diff --git a/server/wire/worker_proto.go b/server/wire/worker_proto.go index 164e0c2e..27fe55e8 100644 --- a/server/wire/worker_proto.go +++ b/server/wire/worker_proto.go @@ -45,3 +45,9 @@ type WorkerDestroySessionPayload struct { type WorkerHealthCheckPayload struct { WorkerControlMetadata } + +// WorkerWaitSessionIdlePayload asks a worker to acknowledge that the session's +// current Flight SQL operation has released its worker-side lifecycle state. +type WorkerWaitSessionIdlePayload struct { + WorkerControlMetadata +} diff --git a/server/worker_control.go b/server/worker_control.go index 38543837..dde89cb1 100644 --- a/server/worker_control.go +++ b/server/worker_control.go @@ -9,8 +9,9 @@ import "github.com/posthog/duckgres/server/wire" // wire.X directly. type ( - WorkerControlMetadata = wire.WorkerControlMetadata - WorkerCreateSessionPayload = wire.WorkerCreateSessionPayload - WorkerDestroySessionPayload = wire.WorkerDestroySessionPayload - WorkerHealthCheckPayload = wire.WorkerHealthCheckPayload + WorkerControlMetadata = wire.WorkerControlMetadata + WorkerCreateSessionPayload = wire.WorkerCreateSessionPayload + WorkerDestroySessionPayload = wire.WorkerDestroySessionPayload + WorkerHealthCheckPayload = wire.WorkerHealthCheckPayload + WorkerWaitSessionIdlePayload = wire.WorkerWaitSessionIdlePayload ) diff --git a/tests/e2e-mw-dev/README.md b/tests/e2e-mw-dev/README.md index 48fd1b1f..9dcd1f40 100644 --- a/tests/e2e-mw-dev/README.md +++ b/tests/e2e-mw-dev/README.md @@ -141,6 +141,13 @@ normal `go test ./...` lane. GetFlightInfo-to-DoGet handoffs, and abandoned-continuation cleanup are covered by `duckdbservice` unit tests. The harness still asserts the cluster invariant this protects: one active session owns one worker. +- **Worker DoGet close acknowledgement** — asserting that an early client + close waits for worker-side DoGet cleanup requires pausing the worker exactly + after it observes gRPC cancellation but before it releases the session + operation token. The in-cluster harness only has normal clients, so it cannot + deterministically hold that cleanup window without a bespoke worker/Flight + fault-injection client. Covered by `server/flightclient` and `duckdbservice` + unit tests instead. - **Malformed Bind message validation (#720)** — negative count/length fields in a Bind message must return a clean `08P01` instead of panicking. Every real client (psql, lib/pq, ...) only emits well-formed Bind messages, so From 42582188263ff652599b93549e8e934367e08c61 Mon Sep 17 00:00:00 2001 From: Bill Guowei Yang Date: Wed, 10 Jun 2026 12:39:09 -0400 Subject: [PATCH 2/7] Trim worker close wait path --- server/flightclient/flight_executor.go | 8 +- server/flightclient/flight_executor_test.go | 121 ++++++++++++++++++++ 2 files changed, 127 insertions(+), 2 deletions(-) diff --git a/server/flightclient/flight_executor.go b/server/flightclient/flight_executor.go index f3bf4f0d..660ee728 100644 --- a/server/flightclient/flight_executor.go +++ b/server/flightclient/flight_executor.go @@ -316,10 +316,14 @@ func (e *FlightExecutor) Close() error { return nil } -func (e *FlightExecutor) waitForSessionIdle() error { +func (e *FlightExecutor) waitForSessionIdle() (err error) { + if e.dead.Load() { + return nil + } if e.client == nil || e.client.Client == nil { return nil } + defer recoverClientPanic(&err) payload, err := json.Marshal(wire.WorkerWaitSessionIdlePayload{ WorkerControlMetadata: wire.WorkerControlMetadata{ @@ -479,7 +483,7 @@ func (r *FlightRowSet) Close() error { r.currentBatch = nil } r.reader.Release() - if r.waitForClosed != nil { + if r.waitForClosed != nil && (!r.done || r.err != nil) { r.closeErr = r.waitForClosed() } }) diff --git a/server/flightclient/flight_executor_test.go b/server/flightclient/flight_executor_test.go index e377f2c2..32186258 100644 --- a/server/flightclient/flight_executor_test.go +++ b/server/flightclient/flight_executor_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "net" + "sync" "testing" "time" @@ -53,6 +54,9 @@ type closeWaitFlightServer struct { doGetContextCanceled chan struct{} allowDoGetReturn chan struct{} doGetDone chan struct{} + doActionCalled chan struct{} + doActionOnce sync.Once + closeAfterFirstBatch bool } func newCloseWaitFlightServer() *closeWaitFlightServer { @@ -62,6 +66,7 @@ func newCloseWaitFlightServer() *closeWaitFlightServer { doGetContextCanceled: make(chan struct{}), allowDoGetReturn: make(chan struct{}), doGetDone: make(chan struct{}), + doActionCalled: make(chan struct{}), } } @@ -90,6 +95,10 @@ func (s *closeWaitFlightServer) DoGet(_ *flight.Ticket, stream pb.FlightService_ if err := writer.Write(record); err != nil { return err } + if s.closeAfterFirstBatch { + close(s.doGetDone) + return nil + } <-stream.Context().Done() close(s.doGetContextCanceled) @@ -102,6 +111,9 @@ func (s *closeWaitFlightServer) DoAction(action *flight.Action, stream pb.Flight if action.Type != waitSessionIdleAction { return s.UnimplementedFlightServiceServer.DoAction(action, stream) } + s.doActionOnce.Do(func() { + close(s.doActionCalled) + }) var payload wire.WorkerWaitSessionIdlePayload if err := json.Unmarshal(action.Body, &payload); err != nil { return err @@ -175,3 +187,112 @@ func TestFlightRowSetCloseWaitsForWorkerDoGetCleanup(t *testing.T) { t.Fatal("Close did not return after worker DoGet cleanup completed") } } + +func TestFlightRowSetCloseSkipsWaitAfterCleanEOF(t *testing.T) { + srv := newCloseWaitFlightServer() + srv.closeAfterFirstBatch = true + + grpcSrv := grpc.NewServer() + pb.RegisterFlightServiceServer(grpcSrv, srv) + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + go func() { _ = grpcSrv.Serve(lis) }() + defer grpcSrv.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + flightCli, err := flight.NewClientWithMiddlewareCtx( + ctx, lis.Addr().String(), nil, nil, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("flight client: %v", err) + } + defer func() { _ = flightCli.Close() }() + + exec := NewFlightExecutorFromClient(&flightsql.Client{Client: flightCli}, "session-1") + defer func() { _ = exec.Close() }() + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err != nil { + t.Fatalf("QueryContext: %v", err) + } + if !rows.Next() { + t.Fatalf("expected first row, err=%v", rows.Err()) + } + if rows.Next() { + t.Fatal("expected EOF after first row") + } + if err := rows.Err(); err != nil { + t.Fatalf("rowset error: %v", err) + } + if err := rows.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + select { + case <-srv.doActionCalled: + t.Fatal("Close called WaitSessionIdle after a clean EOF") + case <-time.After(50 * time.Millisecond): + } +} + +func TestFlightRowSetCloseSkipsWaitAfterExecutorMarkedDead(t *testing.T) { + srv := newCloseWaitFlightServer() + + grpcSrv := grpc.NewServer() + pb.RegisterFlightServiceServer(grpcSrv, srv) + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + go func() { _ = grpcSrv.Serve(lis) }() + defer grpcSrv.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + flightCli, err := flight.NewClientWithMiddlewareCtx( + ctx, lis.Addr().String(), nil, nil, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("flight client: %v", err) + } + defer func() { _ = flightCli.Close() }() + + exec := NewFlightExecutorFromClient(&flightsql.Client{Client: flightCli}, "session-1") + defer func() { _ = exec.Close() }() + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err != nil { + t.Fatalf("QueryContext: %v", err) + } + select { + case <-srv.doGetStarted: + case <-time.After(time.Second): + t.Fatal("DoGet did not start") + } + + exec.MarkDead() + closeReturned := make(chan error, 1) + go func() { + closeReturned <- rows.Close() + }() + + select { + case err := <-closeReturned: + if err != nil { + t.Fatalf("Close returned error: %v", err) + } + case <-time.After(100 * time.Millisecond): + close(srv.allowDoGetReturn) + t.Fatal("Close waited for worker idle after executor was marked dead") + } + select { + case <-srv.doActionCalled: + t.Fatal("Close called WaitSessionIdle after executor was marked dead") + default: + } +} From f0655926c485caf201b73124fe20790157450259 Mon Sep 17 00:00:00 2001 From: Bill Guowei Yang Date: Wed, 10 Jun 2026 12:48:07 -0400 Subject: [PATCH 3/7] Reduce close wait test setup duplication --- server/flightclient/flight_executor_test.go | 67 +++++---------------- 1 file changed, 14 insertions(+), 53 deletions(-) diff --git a/server/flightclient/flight_executor_test.go b/server/flightclient/flight_executor_test.go index 32186258..e5f4aef1 100644 --- a/server/flightclient/flight_executor_test.go +++ b/server/flightclient/flight_executor_test.go @@ -122,9 +122,8 @@ func (s *closeWaitFlightServer) DoAction(action *flight.Action, stream pb.Flight return stream.Send(&flight.Result{Body: []byte(`{"ok":true}`)}) } -func TestFlightRowSetCloseWaitsForWorkerDoGetCleanup(t *testing.T) { - srv := newCloseWaitFlightServer() - +func newCloseWaitExecutor(t *testing.T, srv *closeWaitFlightServer) (*FlightExecutor, context.Context) { + t.Helper() grpcSrv := grpc.NewServer() pb.RegisterFlightServiceServer(grpcSrv, srv) lis, err := net.Listen("tcp", "127.0.0.1:0") @@ -132,10 +131,10 @@ func TestFlightRowSetCloseWaitsForWorkerDoGetCleanup(t *testing.T) { t.Fatalf("listen: %v", err) } go func() { _ = grpcSrv.Serve(lis) }() - defer grpcSrv.Stop() + t.Cleanup(grpcSrv.Stop) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + t.Cleanup(cancel) flightCli, err := flight.NewClientWithMiddlewareCtx( ctx, lis.Addr().String(), nil, nil, grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -143,10 +142,16 @@ func TestFlightRowSetCloseWaitsForWorkerDoGetCleanup(t *testing.T) { if err != nil { t.Fatalf("flight client: %v", err) } - defer func() { _ = flightCli.Close() }() + t.Cleanup(func() { _ = flightCli.Close() }) exec := NewFlightExecutorFromClient(&flightsql.Client{Client: flightCli}, "session-1") - defer func() { _ = exec.Close() }() + t.Cleanup(func() { _ = exec.Close() }) + return exec, ctx +} + +func TestFlightRowSetCloseWaitsForWorkerDoGetCleanup(t *testing.T) { + srv := newCloseWaitFlightServer() + exec, ctx := newCloseWaitExecutor(t, srv) rows, err := exec.QueryContext(ctx, "SELECT 1") if err != nil { @@ -191,29 +196,7 @@ func TestFlightRowSetCloseWaitsForWorkerDoGetCleanup(t *testing.T) { func TestFlightRowSetCloseSkipsWaitAfterCleanEOF(t *testing.T) { srv := newCloseWaitFlightServer() srv.closeAfterFirstBatch = true - - grpcSrv := grpc.NewServer() - pb.RegisterFlightServiceServer(grpcSrv, srv) - lis, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen: %v", err) - } - go func() { _ = grpcSrv.Serve(lis) }() - defer grpcSrv.Stop() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - flightCli, err := flight.NewClientWithMiddlewareCtx( - ctx, lis.Addr().String(), nil, nil, - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - if err != nil { - t.Fatalf("flight client: %v", err) - } - defer func() { _ = flightCli.Close() }() - - exec := NewFlightExecutorFromClient(&flightsql.Client{Client: flightCli}, "session-1") - defer func() { _ = exec.Close() }() + exec, ctx := newCloseWaitExecutor(t, srv) rows, err := exec.QueryContext(ctx, "SELECT 1") if err != nil { @@ -241,29 +224,7 @@ func TestFlightRowSetCloseSkipsWaitAfterCleanEOF(t *testing.T) { func TestFlightRowSetCloseSkipsWaitAfterExecutorMarkedDead(t *testing.T) { srv := newCloseWaitFlightServer() - - grpcSrv := grpc.NewServer() - pb.RegisterFlightServiceServer(grpcSrv, srv) - lis, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen: %v", err) - } - go func() { _ = grpcSrv.Serve(lis) }() - defer grpcSrv.Stop() - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - flightCli, err := flight.NewClientWithMiddlewareCtx( - ctx, lis.Addr().String(), nil, nil, - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - if err != nil { - t.Fatalf("flight client: %v", err) - } - defer func() { _ = flightCli.Close() }() - - exec := NewFlightExecutorFromClient(&flightsql.Client{Client: flightCli}, "session-1") - defer func() { _ = exec.Close() }() + exec, ctx := newCloseWaitExecutor(t, srv) rows, err := exec.QueryContext(ctx, "SELECT 1") if err != nil { From d0effb2927eac5ad95d970e25155aa47f83d82ca Mon Sep 17 00:00:00 2001 From: Bill Guowei Yang Date: Wed, 10 Jun 2026 13:18:12 -0400 Subject: [PATCH 4/7] Cover cancel reuse close wait edges --- server/flightclient/flight_executor.go | 22 ++++++ server/flightclient/flight_executor_test.go | 38 +++++++++++ tests/e2e-mw-dev/README.md | 16 ++--- tests/e2e-mw-dev/harness.sh | 74 ++++++++++++++++++++- 4 files changed, 140 insertions(+), 10 deletions(-) diff --git a/server/flightclient/flight_executor.go b/server/flightclient/flight_executor.go index 660ee728..73388544 100644 --- a/server/flightclient/flight_executor.go +++ b/server/flightclient/flight_executor.go @@ -25,8 +25,10 @@ import ( "github.com/posthog/duckgres/server/sqlcore" "github.com/posthog/duckgres/server/wire" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) // MaxGRPCMessageSize is the max gRPC message size for Flight SQL communication. @@ -212,6 +214,11 @@ func (e *FlightExecutor) QueryContext(ctx context.Context, query string, args .. reader, err := e.client.DoGet(reqCtx, info.Endpoint[0].Ticket) if err != nil { + // If cancellation lands after Execute has registered a worker-side + // handle but before DoGet succeeds, no RowSet exists to Close and + // acknowledge. That pre-existing split-phase gap is bounded by the + // worker's abandoned-handle reaper; the close wait below covers + // cancellations after DoGet starts streaming. return nil, fmt.Errorf("flight doget: %w", err) } @@ -316,6 +323,15 @@ func (e *FlightExecutor) Close() error { return nil } +func isTerminalSessionIdleWaitError(err error) bool { + switch status.Code(err) { + case codes.Canceled, codes.Unavailable, codes.FailedPrecondition, codes.NotFound: + return true + default: + return false + } +} + func (e *FlightExecutor) waitForSessionIdle() (err error) { if e.dead.Load() { return nil @@ -353,6 +369,9 @@ func (e *FlightExecutor) waitForSessionIdle() (err error) { &flight.Action{Type: waitSessionIdleAction, Body: payload}, ) if err != nil { + if isTerminalSessionIdleWaitError(err) { + return nil + } return err } for { @@ -361,6 +380,9 @@ func (e *FlightExecutor) waitForSessionIdle() (err error) { return nil } if err != nil { + if isTerminalSessionIdleWaitError(err) { + return nil + } return err } } diff --git a/server/flightclient/flight_executor_test.go b/server/flightclient/flight_executor_test.go index e5f4aef1..f0c15e69 100644 --- a/server/flightclient/flight_executor_test.go +++ b/server/flightclient/flight_executor_test.go @@ -17,8 +17,10 @@ import ( "github.com/apache/arrow-go/v18/arrow/memory" "github.com/posthog/duckgres/server/wire" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) func TestFlightExecutorWithSessionAddsOwnerEpochHeader(t *testing.T) { @@ -57,6 +59,7 @@ type closeWaitFlightServer struct { doActionCalled chan struct{} doActionOnce sync.Once closeAfterFirstBatch bool + doActionErr error } func newCloseWaitFlightServer() *closeWaitFlightServer { @@ -118,6 +121,9 @@ func (s *closeWaitFlightServer) DoAction(action *flight.Action, stream pb.Flight if err := json.Unmarshal(action.Body, &payload); err != nil { return err } + if s.doActionErr != nil { + return s.doActionErr + } <-s.doGetDone return stream.Send(&flight.Result{Body: []byte(`{"ok":true}`)}) } @@ -256,4 +262,36 @@ func TestFlightRowSetCloseSkipsWaitAfterExecutorMarkedDead(t *testing.T) { t.Fatal("Close called WaitSessionIdle after executor was marked dead") default: } + close(srv.allowDoGetReturn) +} + +func TestFlightRowSetCloseTreatsTerminalWaitFailureAsBestEffort(t *testing.T) { + srv := newCloseWaitFlightServer() + srv.doActionErr = status.Error(codes.Unavailable, "worker is gone") + exec, ctx := newCloseWaitExecutor(t, srv) + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err != nil { + t.Fatalf("QueryContext: %v", err) + } + select { + case <-srv.doGetStarted: + case <-time.After(time.Second): + t.Fatal("DoGet did not start") + } + + closeReturned := make(chan error, 1) + go func() { + closeReturned <- rows.Close() + }() + select { + case err := <-closeReturned: + if err != nil { + t.Fatalf("Close returned terminal wait failure: %v", err) + } + case <-time.After(time.Second): + close(srv.allowDoGetReturn) + t.Fatal("Close did not return after terminal WaitSessionIdle failure") + } + close(srv.allowDoGetReturn) } diff --git a/tests/e2e-mw-dev/README.md b/tests/e2e-mw-dev/README.md index 9dcd1f40..fdfe05ba 100644 --- a/tests/e2e-mw-dev/README.md +++ b/tests/e2e-mw-dev/README.md @@ -49,7 +49,8 @@ client-go: statements **discarded until Sync** (the queued INSERT must not execute); the statement after `\syncpipeline` must execute normally. This is why the harness Job image is `postgres:18-alpine` (pipeline meta-commands are - psql 18+). + psql 18+). The same wire lane also asserts that a pgwire CancelRequest leaves + the same session immediately reusable. - **cold-burst absorption** — there is no warm pool, so a burst of cold sessions spawns workers on demand; if it outruns the org/global cap the surplus gets a graceful client-visible hint (`no Duckgres worker … retry in about 45 seconds` @@ -141,13 +142,12 @@ normal `go test ./...` lane. GetFlightInfo-to-DoGet handoffs, and abandoned-continuation cleanup are covered by `duckdbservice` unit tests. The harness still asserts the cluster invariant this protects: one active session owns one worker. -- **Worker DoGet close acknowledgement** — asserting that an early client - close waits for worker-side DoGet cleanup requires pausing the worker exactly - after it observes gRPC cancellation but before it releases the session - operation token. The in-cluster harness only has normal clients, so it cannot - deterministically hold that cleanup window without a bespoke worker/Flight - fault-injection client. Covered by `server/flightclient` and `duckdbservice` - unit tests instead. +- **Worker DoGet close acknowledgement internals** — the harness covers the + black-box pgwire behavior (CancelRequest then immediate same-session reuse), + but not the exact internal wait point. Pausing the worker exactly after it + observes gRPC cancellation but before it releases the session operation token + would need a bespoke worker/Flight fault-injection client. Covered by + `server/flightclient` and `duckdbservice` unit tests instead. - **Malformed Bind message validation (#720)** — negative count/length fields in a Bind message must return a clean `08P01` instead of panicking. Every real client (psql, lib/pq, ...) only emits well-formed Bind messages, so diff --git a/tests/e2e-mw-dev/harness.sh b/tests/e2e-mw-dev/harness.sh index c70babeb..9e872741 100755 --- a/tests/e2e-mw-dev/harness.sh +++ b/tests/e2e-mw-dev/harness.sh @@ -24,7 +24,8 @@ # demand; any surplus gets a graceful retry hint) then served. # jsonb || keeps Postgres concat semantics through # transpilation (#716), and a pipelined extended-query error -# discards queued messages until Sync (#718). +# discards queued messages until Sync (#718). A same pgwire +# session remains usable immediately after CancelRequest. # activation : DuckLake + Iceberg catalogs attach and read/write. The FIRST # session on a cold Iceberg worker must not fail the pg_catalog # compat-view bind (the CP primes the REST catalog's schema list @@ -403,6 +404,74 @@ EOF pg "$1" "$2" ducklake "DROP TABLE $t;" } +# Black-box regression for async cancel cleanup: psql handles SIGINT by sending +# a PostgreSQL CancelRequest and keeping the same connection open. The next query +# is written immediately to that same stdin stream; if the control plane reuses +# the session before the worker has released its cancelled DoGet operation, this +# can fail with "session already has an active operation". +cancel_then_reuse_same_session() { # org password + log "cancel then immediate same-session reuse on $1" + in="$(mktemp)"; out="$(mktemp)" + rm -f "$in" + mkfifo "$in" + + ( trap - INT + export PGPASSWORD="$2" + exec psql \ + "sslmode=require host=$1$SNI_SUFFIX hostaddr=$CP_IP port=5432 user=root dbname=ducklake" \ + -v ON_ERROR_STOP=0 -q -tA <"$in" >"$out" 2>&1 + ) & + psql_pid=$! + + exec 9>"$in" + set +e + printf '%s\n' "$HEAVY_Q" >&9 + write_rc=$? + set -e + [ "$write_rc" = 0 ] || { + exec 9>&- + wait "$psql_pid" 2>/dev/null || true + text="$(tr '\n' ' ' <"$out" | tail -c 400)" + rm -f "$in" "$out" + fail "cancel_then_reuse_same_session: failed to send heavy query to psql: $text" + } + sleep 3 + kill -0 "$psql_pid" 2>/dev/null || { + exec 9>&- + fail "cancel_then_reuse_same_session: psql exited before cancel: $(tr '\n' ' ' <"$out" | tail -c 200)" + } + kill -INT "$psql_pid" 2>/dev/null || true + # Same connection proof: a TEMP table survives only if the post-cancel query + # runs on the original session. + set +e + printf '%s\n' "CREATE TEMP TABLE cancel_reuse_marker(v INT); INSERT INTO cancel_reuse_marker VALUES (42); SELECT v FROM cancel_reuse_marker;" >&9 + reuse_write_rc=$? + printf '\\q\n' >&9 + quit_write_rc=$? + exec 9>&- + set -e + + rc=0 + wait "$psql_pid" || rc=$? + text="$(tr '\n' ' ' <"$out" | tail -c 400)" + rm -f "$in" "$out" + [ "$reuse_write_rc" = 0 ] || fail "cancel_then_reuse_same_session: failed to send reuse query after cancel: $text" + [ "$quit_write_rc" = 0 ] || fail "cancel_then_reuse_same_session: failed to send quit after cancel: $text" + [ "$rc" = 0 ] || fail "cancel_then_reuse_same_session: psql rc=$rc output=$text" + case "$text" in + *"session already has an active operation"*) + fail "cancel_then_reuse_same_session: session reused before worker cancel cleanup completed: $text" ;; + esac + case "$text" in + *"canceling statement"*|*"Query was cancelled"*|*"context canceled"*|*"Cancel request sent"*) ;; + *) fail "cancel_then_reuse_same_session: heavy query did not appear to be cancelled: $text" ;; + esac + case "$text" in + *"42"*) ;; + *) fail "cancel_then_reuse_same_session: same-session marker query did not succeed: $text" ;; + esac +} + # ---- bundled extension forks ---------------------------------------------- # Ported from TestK8sDucklakeExtensionIsBundledFork / TestK8sHttpfsExtensionIsBundledFork. assert_fork_extensions() { # org password @@ -1053,6 +1122,7 @@ lane_cnpg() { # full wire/catalog/concurrency/sizing coverage on the cnpg org cold_burst_absorption "$CNPG" "$cnpg_pw" # early, while this org is mostly cold rw_ducklake "$CNPG" "$cnpg_pw" pipeline_error_recovery "$CNPG" "$cnpg_pw" # after rw_ducklake (table writes proven) + cancel_then_reuse_same_session "$CNPG" "$cnpg_pw" assert_fork_extensions "$CNPG" "$cnpg_pw" # after a DuckLake R/W (httpfs loaded) iceberg_cold_first_connect "$CNPG" "$cnpg_pw" # cold first session must not fail the metadata-init bind rw_iceberg "$CNPG" "$cnpg_pw" @@ -1159,7 +1229,7 @@ main() { # mid-run image bump); it stays covered by the controlplane/ unit tests. log "SKIP version-reaper (needs an in-run image bump; see README)" - log "PASS: admin-no-query-token + wire + malformed-startup-resilience + jsonb-concat + cold-burst-absorption + pipeline-error-recovery + activation(DuckLake/Iceberg) + iceberg-cold-first-connect + ext-forks + worker-pod + concurrency + durability + crash-recovery + graceful-drain + one-session-per-worker + parallel-cold-burst-ramp + worker-sizing(cnpg DuckLake+Iceberg) + org-default-profile(ext) + isolation + lifecycle-teardown, on cnpg & ext (4 parallel lanes)" + log "PASS: admin-no-query-token + wire + malformed-startup-resilience + jsonb-concat + cold-burst-absorption + pipeline-error-recovery + cancel-reuse + activation(DuckLake/Iceberg) + iceberg-cold-first-connect + ext-forks + worker-pod + concurrency + durability + crash-recovery + graceful-drain + one-session-per-worker + parallel-cold-burst-ramp + worker-sizing(cnpg DuckLake+Iceberg) + org-default-profile(ext) + isolation + lifecycle-teardown, on cnpg & ext (4 parallel lanes)" } main "$@" From 4cd9c1fd24a3afa2187dfb511d892f6253bf8a25 Mon Sep 17 00:00:00 2001 From: Bill Guowei Yang Date: Wed, 10 Jun 2026 13:38:36 -0400 Subject: [PATCH 5/7] Fix cancel reuse e2e client --- server/flightclient/flight_executor.go | 12 ++- tests/e2e-mw-dev/harness.sh | 129 +++++++++++++------------ 2 files changed, 78 insertions(+), 63 deletions(-) diff --git a/server/flightclient/flight_executor.go b/server/flightclient/flight_executor.go index 73388544..553f3cc2 100644 --- a/server/flightclient/flight_executor.go +++ b/server/flightclient/flight_executor.go @@ -324,6 +324,9 @@ func (e *FlightExecutor) Close() error { } func isTerminalSessionIdleWaitError(err error) bool { + if strings.Contains(err.Error(), "flight client panic") { + return true + } switch status.Code(err) { case codes.Canceled, codes.Unavailable, codes.FailedPrecondition, codes.NotFound: return true @@ -339,7 +342,12 @@ func (e *FlightExecutor) waitForSessionIdle() (err error) { if e.client == nil || e.client.Client == nil { return nil } - defer recoverClientPanic(&err) + defer func() { + recoverClientPanic(&err) + if err != nil && isTerminalSessionIdleWaitError(err) { + err = nil + } + }() payload, err := json.Marshal(wire.WorkerWaitSessionIdlePayload{ WorkerControlMetadata: wire.WorkerControlMetadata{ @@ -867,7 +875,7 @@ func interpolateArgs(query string, args []any) string { } // scanQuoted returns the index just past a quoted region starting at start -// (query[start] == quote), treating a doubled quote (” or "") as an escape. +// (query[start] == quote), treating a doubled quote ('' or "") as an escape. func scanQuoted(query string, start int, quote byte) int { for i := start + 1; i < len(query); i++ { if query[i] != quote { diff --git a/tests/e2e-mw-dev/harness.sh b/tests/e2e-mw-dev/harness.sh index 9e872741..5f018613 100755 --- a/tests/e2e-mw-dev/harness.sh +++ b/tests/e2e-mw-dev/harness.sh @@ -105,7 +105,7 @@ HEAVY_EXPECT=1500000000 # stdout would land in that capture and make jq choke ("Invalid numeric literal"). log() { echo ">>> $*" >&2; } -apk add --no-cache curl jq postgresql-client openssl >/dev/null 2>&1 || true +apk add --no-cache curl jq postgresql-client openssl python3 py3-psycopg2 >/dev/null 2>&1 || true # kubectl, for the pod-level assertions the Go suite used to make via client-go. # The Job runs as the `duckgres` SA (pods get/list/delete/patch + pods/exec + @@ -404,71 +404,78 @@ EOF pg "$1" "$2" ducklake "DROP TABLE $t;" } -# Black-box regression for async cancel cleanup: psql handles SIGINT by sending -# a PostgreSQL CancelRequest and keeping the same connection open. The next query -# is written immediately to that same stdin stream; if the control plane reuses -# the session before the worker has released its cancelled DoGet operation, this -# can fail with "session already has an active operation". +# Black-box regression for async cancel cleanup: libpq's cancel path sends a +# PostgreSQL CancelRequest while keeping the same connection open. The next +# query runs immediately on that same session; if the control plane reuses the +# session before the worker has released its cancelled DoGet operation, this can +# fail with "session already has an active operation". cancel_then_reuse_same_session() { # org password log "cancel then immediate same-session reuse on $1" - in="$(mktemp)"; out="$(mktemp)" - rm -f "$in" - mkfifo "$in" - - ( trap - INT - export PGPASSWORD="$2" - exec psql \ - "sslmode=require host=$1$SNI_SUFFIX hostaddr=$CP_IP port=5432 user=root dbname=ducklake" \ - -v ON_ERROR_STOP=0 -q -tA <"$in" >"$out" 2>&1 - ) & - psql_pid=$! - - exec 9>"$in" - set +e - printf '%s\n' "$HEAVY_Q" >&9 - write_rc=$? - set -e - [ "$write_rc" = 0 ] || { - exec 9>&- - wait "$psql_pid" 2>/dev/null || true + out="$(mktemp)" + if ! python3 - "$1" "$2" "$CP_IP" "$SNI_SUFFIX" "$HEAVY_Q" >"$out" 2>&1 <<'PY' +import sys +import threading +import time + +import psycopg2 + +org, password, hostaddr, sni_suffix, heavy_q = sys.argv[1:] +conn = psycopg2.connect( + dbname="ducklake", + user="root", + password=password, + host=org + sni_suffix, + hostaddr=hostaddr, + port=5432, + sslmode="require", + connect_timeout=30, +) +conn.autocommit = True +try: + cur = conn.cursor() + timer = threading.Timer(3.0, conn.cancel) + timer.start() + cancelled = False + try: + cur.execute(heavy_q) + print("heavy query completed before cancel") + except Exception as exc: + cancelled = True + msg = str(exc) + if "cancel" not in msg.lower(): + raise SystemExit(f"heavy query failed without cancellation: {msg}") + print(f"cancel_error={msg}") + finally: + timer.cancel() + if not cancelled: + raise SystemExit("heavy query did not get cancelled") + + cur.execute( + "CREATE TEMP TABLE cancel_reuse_marker(v INT); " + "INSERT INTO cancel_reuse_marker VALUES (42); " + "SELECT v FROM cancel_reuse_marker;" + ) + marker = cur.fetchone()[0] + if marker != 42: + raise SystemExit(f"marker query returned {marker}, want 42") + print("marker=42") +finally: + conn.close() +PY + then text="$(tr '\n' ' ' <"$out" | tail -c 400)" - rm -f "$in" "$out" - fail "cancel_then_reuse_same_session: failed to send heavy query to psql: $text" - } - sleep 3 - kill -0 "$psql_pid" 2>/dev/null || { - exec 9>&- - fail "cancel_then_reuse_same_session: psql exited before cancel: $(tr '\n' ' ' <"$out" | tail -c 200)" - } - kill -INT "$psql_pid" 2>/dev/null || true - # Same connection proof: a TEMP table survives only if the post-cancel query - # runs on the original session. - set +e - printf '%s\n' "CREATE TEMP TABLE cancel_reuse_marker(v INT); INSERT INTO cancel_reuse_marker VALUES (42); SELECT v FROM cancel_reuse_marker;" >&9 - reuse_write_rc=$? - printf '\\q\n' >&9 - quit_write_rc=$? - exec 9>&- - set -e - - rc=0 - wait "$psql_pid" || rc=$? + rm -f "$out" + case "$text" in + *"session already has an active operation"*) + fail "cancel_then_reuse_same_session: session reused before worker cancel cleanup completed: $text" ;; + *) fail "cancel_then_reuse_same_session: cancel/reuse client failed: $text" ;; + esac + fi text="$(tr '\n' ' ' <"$out" | tail -c 400)" - rm -f "$in" "$out" - [ "$reuse_write_rc" = 0 ] || fail "cancel_then_reuse_same_session: failed to send reuse query after cancel: $text" - [ "$quit_write_rc" = 0 ] || fail "cancel_then_reuse_same_session: failed to send quit after cancel: $text" - [ "$rc" = 0 ] || fail "cancel_then_reuse_same_session: psql rc=$rc output=$text" - case "$text" in - *"session already has an active operation"*) - fail "cancel_then_reuse_same_session: session reused before worker cancel cleanup completed: $text" ;; - esac - case "$text" in - *"canceling statement"*|*"Query was cancelled"*|*"context canceled"*|*"Cancel request sent"*) ;; - *) fail "cancel_then_reuse_same_session: heavy query did not appear to be cancelled: $text" ;; - esac + rm -f "$out" case "$text" in - *"42"*) ;; - *) fail "cancel_then_reuse_same_session: same-session marker query did not succeed: $text" ;; + *"cancel_error="*"marker=42"*) ;; + *) fail "cancel_then_reuse_same_session: expected cancel and same-session marker success, got: $text" ;; esac } From a9608efba6078e89c61fe23512146271fb605622 Mon Sep 17 00:00:00 2001 From: Bill Guowei Yang Date: Wed, 10 Jun 2026 14:25:58 -0400 Subject: [PATCH 6/7] Release worker query handle when DoGet fails --- duckdbservice/flight_handler.go | 25 ++++++ duckdbservice/flight_handler_test.go | 64 +++++++++++++++ duckdbservice/service.go | 2 + server/flightclient/flight_executor.go | 86 +++++++++++++++++++-- server/flightclient/flight_executor_test.go | 62 ++++++++++++++- server/wire/worker_proto.go | 8 ++ server/worker_control.go | 19 +++-- 7 files changed, 247 insertions(+), 19 deletions(-) diff --git a/duckdbservice/flight_handler.go b/duckdbservice/flight_handler.go index 5fb6baa0..fd09685a 100644 --- a/duckdbservice/flight_handler.go +++ b/duckdbservice/flight_handler.go @@ -460,6 +460,31 @@ func (h *FlightSQLHandler) doWaitSessionIdle(body []byte, stream flight.FlightSe return sendActionResult(stream, &flight.Result{Body: resp}) } +func (h *FlightSQLHandler) doReleaseQueryHandle(body []byte, stream flight.FlightService_DoActionServer) error { + var req server.WorkerReleaseQueryHandlePayload + if err := json.Unmarshal(body, &req); err != nil { + return status.Errorf(codes.InvalidArgument, "invalid ReleaseQueryHandle request: %v", err) + } + if err := h.pool.validateControlMetadata(req.WorkerControlMetadata); err != nil { + return status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err) + } + session, err := h.sessionFromContext(stream.Context()) + if err != nil { + return err + } + ticket, err := flightsql.GetStatementQueryTicket(&flight.Ticket{Ticket: req.Ticket}) + if err != nil { + return status.Errorf(codes.InvalidArgument, "invalid statement ticket: %v", err) + } + handleID := string(ticket.GetStatementHandle()) + if handle, ok := popQueryHandle(session, handleID); ok { + releaseQueryHandleValue(handle) + } + + resp, _ := json.Marshal(map[string]bool{"ok": true}) + return sendActionResult(stream, &flight.Result{Body: resp}) +} + // Flight SQL method implementations func (h *FlightSQLHandler) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, diff --git a/duckdbservice/flight_handler_test.go b/duckdbservice/flight_handler_test.go index 11794a38..7171c7e5 100644 --- a/duckdbservice/flight_handler_test.go +++ b/duckdbservice/flight_handler_test.go @@ -527,6 +527,70 @@ func TestWaitSessionIdleBlocksUntilOperationReleases(t *testing.T) { } } +func TestReleaseQueryHandleReleasesAbandonedOperation(t *testing.T) { + pool := &SessionPool{ + sessions: make(map[string]*Session), + stopRefresh: make(map[string]func()), + warmupDone: make(chan struct{}), + startTime: time.Now(), + } + close(pool.warmupDone) + session := &Session{ + ID: "session-1", + Username: "alice", + CreatedAt: time.Now(), + queries: make(map[string]*QueryHandle), + txns: make(map[string]*trackedTx), + txnOwner: make(map[string]string), + } + pool.sessions[session.ID] = session + handler := &FlightSQLHandler{pool: pool, alloc: memory.DefaultAllocator} + + finishOperation, ok := session.beginOperation() + if !ok { + t.Fatal("beginOperation rejected test session") + } + finishDrain, err := pool.beginDrainWork(false) + if err != nil { + t.Fatalf("beginDrainWork: %v", err) + } + session.queries["query-1"] = &QueryHandle{ + Query: "SELECT 1", + createdAt: time.Now(), + finishDrain: finishDrain, + finishOperation: finishOperation, + } + + ticket, err := flightsql.CreateStatementQueryTicket([]byte("query-1")) + if err != nil { + t.Fatalf("create ticket: %v", err) + } + body, err := json.Marshal(server.WorkerReleaseQueryHandlePayload{Ticket: ticket}) + if err != nil { + t.Fatalf("marshal request: %v", err) + } + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-duckgres-session", session.ID)) + stream := &mockDoActionStream{ctx: ctx} + + if err := handler.doReleaseQueryHandle(body, stream); err != nil { + t.Fatalf("ReleaseQueryHandle: %v", err) + } + if _, ok := session.queries["query-1"]; ok { + t.Fatal("query handle was not removed") + } + if got := pool.ActiveDrainWork(); got != 0 { + t.Fatalf("active drain work=%d, want 0", got) + } + finishOperation2, ok := session.beginOperation() + if !ok { + t.Fatal("operation gate was not released") + } + finishOperation2() + if len(stream.results) != 1 { + t.Fatalf("expected one action result, got %d", len(stream.results)) + } +} + func TestCreateSessionSendFailureDestroysSession(t *testing.T) { pool := &SessionPool{ sessions: make(map[string]*Session), diff --git a/duckdbservice/service.go b/duckdbservice/service.go index bdac8f9d..9047d847 100644 --- a/duckdbservice/service.go +++ b/duckdbservice/service.go @@ -1628,6 +1628,8 @@ func (s *customActionServer) DoAction(cmd *flight.Action, stream flight.FlightSe return s.handler.doHealthCheck(cmd.Body, stream) case "WaitSessionIdle": return s.handler.doWaitSessionIdle(cmd.Body, stream) + case "ReleaseQueryHandle": + return s.handler.doReleaseQueryHandle(cmd.Body, stream) default: // Fall through to standard flightsql action router (BeginTransaction, etc.) return s.FlightServer.DoAction(cmd, stream) diff --git a/server/flightclient/flight_executor.go b/server/flightclient/flight_executor.go index 553f3cc2..cfb91ac8 100644 --- a/server/flightclient/flight_executor.go +++ b/server/flightclient/flight_executor.go @@ -36,8 +36,9 @@ import ( const MaxGRPCMessageSize = 1 << 30 // 1GB const ( - waitSessionIdleAction = "WaitSessionIdle" - queryCloseWaitTimeout = 30 * time.Second + waitSessionIdleAction = "WaitSessionIdle" + releaseQueryHandleAction = "ReleaseQueryHandle" + queryCloseWaitTimeout = 30 * time.Second ) // ErrWorkerDead is returned when the backing worker process has crashed. @@ -212,19 +213,28 @@ func (e *FlightExecutor) QueryContext(ctx context.Context, query string, args .. return &emptyRowSet{}, nil } - reader, err := e.client.DoGet(reqCtx, info.Endpoint[0].Ticket) + ticket := info.Endpoint[0].Ticket + if err := reqCtx.Err(); err != nil { + _ = e.releaseQueryHandle(ticket) + _ = e.waitForSessionIdle() + return nil, err + } + + reader, err := e.client.DoGet(reqCtx, ticket) if err != nil { // If cancellation lands after Execute has registered a worker-side - // handle but before DoGet succeeds, no RowSet exists to Close and - // acknowledge. That pre-existing split-phase gap is bounded by the - // worker's abandoned-handle reaper; the close wait below covers - // cancellations after DoGet starts streaming. + // handle but before DoGet returns a RowSet, there is no Close call to + // acknowledge the abandoned split-phase operation. Release the handle + // explicitly, then wait in case DoGet consumed it and is still unwinding. + _ = e.releaseQueryHandle(ticket) + _ = e.waitForSessionIdle() return nil, fmt.Errorf("flight doget: %w", err) } schema, err := flight.DeserializeSchema(info.Schema, e.alloc) if err != nil { reader.Release() + _ = e.waitForSessionIdle() return nil, fmt.Errorf("flight deserialize schema: %w", err) } @@ -396,6 +406,68 @@ func (e *FlightExecutor) waitForSessionIdle() (err error) { } } +func (e *FlightExecutor) releaseQueryHandle(ticket *flight.Ticket) (err error) { + if e.dead.Load() { + return nil + } + if e.client == nil || e.client.Client == nil || ticket == nil || len(ticket.Ticket) == 0 { + return nil + } + defer func() { + recoverClientPanic(&err) + if err != nil && isTerminalSessionIdleWaitError(err) { + err = nil + } + }() + + payload, err := json.Marshal(wire.WorkerReleaseQueryHandlePayload{ + WorkerControlMetadata: wire.WorkerControlMetadata{ + WorkerID: e.workerID, + OwnerEpoch: e.ownerEpoch, + CPInstanceID: e.cpInstanceID, + }, + Ticket: ticket.Ticket, + }) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), queryCloseWaitTimeout) + defer cancel() + if e.ctx != nil { + go func() { + select { + case <-e.ctx.Done(): + cancel() + case <-ctx.Done(): + } + }() + } + + stream, err := e.client.Client.DoAction( + e.withSession(ctx), + &flight.Action{Type: releaseQueryHandleAction, Body: payload}, + ) + if err != nil { + if isTerminalSessionIdleWaitError(err) { + return nil + } + return err + } + for { + _, err := stream.Recv() + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + if isTerminalSessionIdleWaitError(err) { + return nil + } + return err + } + } +} + func (e *FlightExecutor) LastProfilingOutput() string { v := e.lastProfiling.Load() if v == nil { diff --git a/server/flightclient/flight_executor_test.go b/server/flightclient/flight_executor_test.go index f0c15e69..db4da807 100644 --- a/server/flightclient/flight_executor_test.go +++ b/server/flightclient/flight_executor_test.go @@ -58,7 +58,11 @@ type closeWaitFlightServer struct { doGetDone chan struct{} doActionCalled chan struct{} doActionOnce sync.Once + releaseActionCalled chan struct{} + releaseActionOnce sync.Once + releaseTicket []byte closeAfterFirstBatch bool + doGetErr error doActionErr error } @@ -70,20 +74,29 @@ func newCloseWaitFlightServer() *closeWaitFlightServer { allowDoGetReturn: make(chan struct{}), doGetDone: make(chan struct{}), doActionCalled: make(chan struct{}), + releaseActionCalled: make(chan struct{}), } } func (s *closeWaitFlightServer) GetFlightInfo(context.Context, *flight.FlightDescriptor) (*flight.FlightInfo, error) { + ticket, err := flightsql.CreateStatementQueryTicket([]byte("query-1")) + if err != nil { + return nil, err + } return &flight.FlightInfo{ Schema: flight.SerializeSchema(s.schema, memory.DefaultAllocator), Endpoint: []*flight.FlightEndpoint{{ - Ticket: &flight.Ticket{Ticket: []byte("query-1")}, + Ticket: &flight.Ticket{Ticket: ticket}, }}, }, nil } func (s *closeWaitFlightServer) DoGet(_ *flight.Ticket, stream pb.FlightService_DoGetServer) error { close(s.doGetStarted) + if s.doGetErr != nil { + close(s.doGetDone) + return s.doGetErr + } builder := array.NewInt64Builder(memory.DefaultAllocator) builder.Append(1) @@ -111,7 +124,19 @@ func (s *closeWaitFlightServer) DoGet(_ *flight.Ticket, stream pb.FlightService_ } func (s *closeWaitFlightServer) DoAction(action *flight.Action, stream pb.FlightService_DoActionServer) error { - if action.Type != waitSessionIdleAction { + switch action.Type { + case releaseQueryHandleAction: + s.releaseActionOnce.Do(func() { + close(s.releaseActionCalled) + }) + var payload wire.WorkerReleaseQueryHandlePayload + if err := json.Unmarshal(action.Body, &payload); err != nil { + return err + } + s.releaseTicket = payload.Ticket + return stream.Send(&flight.Result{Body: []byte(`{"ok":true}`)}) + case waitSessionIdleAction: + default: return s.UnimplementedFlightServiceServer.DoAction(action, stream) } s.doActionOnce.Do(func() { @@ -295,3 +320,36 @@ func TestFlightRowSetCloseTreatsTerminalWaitFailureAsBestEffort(t *testing.T) { } close(srv.allowDoGetReturn) } + +func TestQueryContextReleasesHandleWhenDoGetFails(t *testing.T) { + srv := newCloseWaitFlightServer() + srv.doGetErr = status.Error(codes.Canceled, "context canceled") + exec, ctx := newCloseWaitExecutor(t, srv) + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err == nil { + if rows != nil { + _ = rows.Close() + } + t.Fatal("expected QueryContext to return DoGet error") + } + + select { + case <-srv.releaseActionCalled: + case <-time.After(time.Second): + t.Fatal("QueryContext did not release the abandoned query handle") + } + ticket, err := flightsql.GetStatementQueryTicket(&flight.Ticket{Ticket: srv.releaseTicket}) + if err != nil { + t.Fatalf("release ticket did not decode: %v", err) + } + if got := string(ticket.GetStatementHandle()); got != "query-1" { + t.Fatalf("released handle %q, want query-1", got) + } + + select { + case <-srv.doActionCalled: + case <-time.After(time.Second): + t.Fatal("QueryContext did not wait for the session to become idle after handle release") + } +} diff --git a/server/wire/worker_proto.go b/server/wire/worker_proto.go index 27fe55e8..961f0062 100644 --- a/server/wire/worker_proto.go +++ b/server/wire/worker_proto.go @@ -51,3 +51,11 @@ type WorkerHealthCheckPayload struct { type WorkerWaitSessionIdlePayload struct { WorkerControlMetadata } + +// WorkerReleaseQueryHandlePayload asks a worker to release a statement query +// handle that was created by GetFlightInfo but abandoned before DoGet could +// consume it. +type WorkerReleaseQueryHandlePayload struct { + WorkerControlMetadata + Ticket []byte `json:"ticket"` +} diff --git a/server/worker_control.go b/server/worker_control.go index dde89cb1..449c16f1 100644 --- a/server/worker_control.go +++ b/server/worker_control.go @@ -2,16 +2,15 @@ package server import "github.com/posthog/duckgres/server/wire" -// Aliases retained so existing references to server.WorkerControlMetadata, -// server.WorkerCreateSessionPayload, server.WorkerDestroySessionPayload and -// server.WorkerHealthCheckPayload continue to compile after the types -// moved to server/wire. New code should import server/wire and use -// wire.X directly. +// Aliases retained so existing references to server.Worker* payloads continue +// to compile after the types moved to server/wire. New code should import +// server/wire and use wire.X directly. type ( - WorkerControlMetadata = wire.WorkerControlMetadata - WorkerCreateSessionPayload = wire.WorkerCreateSessionPayload - WorkerDestroySessionPayload = wire.WorkerDestroySessionPayload - WorkerHealthCheckPayload = wire.WorkerHealthCheckPayload - WorkerWaitSessionIdlePayload = wire.WorkerWaitSessionIdlePayload + WorkerControlMetadata = wire.WorkerControlMetadata + WorkerCreateSessionPayload = wire.WorkerCreateSessionPayload + WorkerDestroySessionPayload = wire.WorkerDestroySessionPayload + WorkerHealthCheckPayload = wire.WorkerHealthCheckPayload + WorkerWaitSessionIdlePayload = wire.WorkerWaitSessionIdlePayload + WorkerReleaseQueryHandlePayload = wire.WorkerReleaseQueryHandlePayload ) From 646dc3686a556f72abff493b79d4d86ed94ad5e1 Mon Sep 17 00:00:00 2001 From: Bill Guowei Yang Date: Wed, 10 Jun 2026 15:15:42 -0400 Subject: [PATCH 7/7] Cancel DoGet before schema-error wait --- server/flightclient/flight_executor.go | 1 + server/flightclient/flight_executor_test.go | 75 +++++++++++++++++---- 2 files changed, 63 insertions(+), 13 deletions(-) diff --git a/server/flightclient/flight_executor.go b/server/flightclient/flight_executor.go index cfb91ac8..f1625942 100644 --- a/server/flightclient/flight_executor.go +++ b/server/flightclient/flight_executor.go @@ -233,6 +233,7 @@ func (e *FlightExecutor) QueryContext(ctx context.Context, query string, args .. schema, err := flight.DeserializeSchema(info.Schema, e.alloc) if err != nil { + cancel() reader.Release() _ = e.waitForSessionIdle() return nil, fmt.Errorf("flight deserialize schema: %w", err) diff --git a/server/flightclient/flight_executor_test.go b/server/flightclient/flight_executor_test.go index db4da807..e90dd6c4 100644 --- a/server/flightclient/flight_executor_test.go +++ b/server/flightclient/flight_executor_test.go @@ -52,18 +52,21 @@ type closeWaitFlightServer struct { schema *arrow.Schema - doGetStarted chan struct{} - doGetContextCanceled chan struct{} - allowDoGetReturn chan struct{} - doGetDone chan struct{} - doActionCalled chan struct{} - doActionOnce sync.Once - releaseActionCalled chan struct{} - releaseActionOnce sync.Once - releaseTicket []byte - closeAfterFirstBatch bool - doGetErr error - doActionErr error + doGetStarted chan struct{} + doGetContextCanceled chan struct{} + allowDoGetReturn chan struct{} + doGetDone chan struct{} + doActionCalled chan struct{} + doActionOnce sync.Once + releaseActionCalled chan struct{} + releaseActionOnce sync.Once + releaseTicket []byte + waitBeforeDoGetCancel chan struct{} + waitBeforeOnce sync.Once + closeAfterFirstBatch bool + invalidInfoSchema bool + doGetErr error + doActionErr error } func newCloseWaitFlightServer() *closeWaitFlightServer { @@ -83,8 +86,12 @@ func (s *closeWaitFlightServer) GetFlightInfo(context.Context, *flight.FlightDes if err != nil { return nil, err } + schema := flight.SerializeSchema(s.schema, memory.DefaultAllocator) + if s.invalidInfoSchema { + schema = []byte("not an arrow schema") + } return &flight.FlightInfo{ - Schema: flight.SerializeSchema(s.schema, memory.DefaultAllocator), + Schema: schema, Endpoint: []*flight.FlightEndpoint{{ Ticket: &flight.Ticket{Ticket: ticket}, }}, @@ -149,6 +156,16 @@ func (s *closeWaitFlightServer) DoAction(action *flight.Action, stream pb.Flight if s.doActionErr != nil { return s.doActionErr } + if s.waitBeforeDoGetCancel != nil { + select { + case <-s.doGetContextCanceled: + case <-time.After(200 * time.Millisecond): + s.waitBeforeOnce.Do(func() { + close(s.waitBeforeDoGetCancel) + }) + return stream.Send(&flight.Result{Body: []byte(`{"ok":true}`)}) + } + } <-s.doGetDone return stream.Send(&flight.Result{Body: []byte(`{"ok":true}`)}) } @@ -353,3 +370,35 @@ func TestQueryContextReleasesHandleWhenDoGetFails(t *testing.T) { t.Fatal("QueryContext did not wait for the session to become idle after handle release") } } + +func TestQueryContextCancelsDoGetBeforeWaitingAfterSchemaError(t *testing.T) { + srv := newCloseWaitFlightServer() + srv.invalidInfoSchema = true + srv.waitBeforeDoGetCancel = make(chan struct{}) + close(srv.allowDoGetReturn) + exec, ctx := newCloseWaitExecutor(t, srv) + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err == nil { + if rows != nil { + _ = rows.Close() + } + t.Fatal("expected QueryContext to return schema error") + } + + select { + case <-srv.doGetContextCanceled: + case <-time.After(time.Second): + t.Fatal("DoGet did not observe cancellation") + } + select { + case <-srv.doActionCalled: + case <-time.After(time.Second): + t.Fatal("QueryContext did not wait for the session to become idle") + } + select { + case <-srv.waitBeforeDoGetCancel: + t.Fatal("QueryContext waited for session idle before canceling DoGet") + default: + } +}