diff --git a/duckdbservice/flight_handler.go b/duckdbservice/flight_handler.go index 2eda0a68..957227a8 100644 --- a/duckdbservice/flight_handler.go +++ b/duckdbservice/flight_handler.go @@ -441,6 +441,58 @@ func (h *FlightSQLHandler) doHealthCheck(body []byte, stream flight.FlightServic return sendActionResult(stream, &flight.Result{Body: resp}) } +func (h *FlightSQLHandler) doWaitSessionIdle(body []byte, stream flight.FlightService_DoActionServer) error { + var req server.WorkerWaitSessionIdlePayload + if err := json.Unmarshal(body, &req); err != nil { + return status.Errorf(codes.InvalidArgument, "invalid WaitSessionIdle request: %v", err) + } + if err := h.pool.validateControlMetadata(req.WorkerControlMetadata); err != nil { + return status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err) + } + session, err := h.sessionFromContext(stream.Context()) + if err != nil { + return err + } + if err := session.waitOperationIdle(stream.Context()); err != nil { + switch { + case errors.Is(err, context.Canceled): + return status.Errorf(codes.Canceled, "wait for session idle: %v", err) + case errors.Is(err, context.DeadlineExceeded): + return status.Errorf(codes.DeadlineExceeded, "wait for session idle: %v", err) + default: + return status.Errorf(codes.Internal, "wait for session idle: %v", err) + } + } + + resp, _ := json.Marshal(map[string]bool{"ok": true}) + return sendActionResult(stream, &flight.Result{Body: resp}) +} + +func (h *FlightSQLHandler) doReleaseQueryHandle(body []byte, stream flight.FlightService_DoActionServer) error { + var req server.WorkerReleaseQueryHandlePayload + if err := json.Unmarshal(body, &req); err != nil { + return status.Errorf(codes.InvalidArgument, "invalid ReleaseQueryHandle request: %v", err) + } + if err := h.pool.validateControlMetadata(req.WorkerControlMetadata); err != nil { + return status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err) + } + session, err := h.sessionFromContext(stream.Context()) + if err != nil { + return err + } + ticket, err := flightsql.GetStatementQueryTicket(&flight.Ticket{Ticket: req.Ticket}) + if err != nil { + return status.Errorf(codes.InvalidArgument, "invalid statement ticket: %v", err) + } + handleID := string(ticket.GetStatementHandle()) + if handle, ok := popQueryHandle(session, handleID); ok { + releaseQueryHandleValue(handle) + } + + resp, _ := json.Marshal(map[string]bool{"ok": true}) + return sendActionResult(stream, &flight.Result{Body: resp}) +} + // Flight SQL method implementations func (h *FlightSQLHandler) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery, diff --git a/duckdbservice/flight_handler_test.go b/duckdbservice/flight_handler_test.go index 7a54d491..7171c7e5 100644 --- a/duckdbservice/flight_handler_test.go +++ b/duckdbservice/flight_handler_test.go @@ -473,6 +473,124 @@ func TestCreateSessionRejectsWhileDraining(t *testing.T) { } } +func TestWaitSessionIdleBlocksUntilOperationReleases(t *testing.T) { + session := &Session{ + ID: "session-1", + Username: "alice", + CreatedAt: time.Now(), + queries: make(map[string]*QueryHandle), + txns: make(map[string]*trackedTx), + txnOwner: make(map[string]string), + } + pool := &SessionPool{ + sessions: map[string]*Session{session.ID: session}, + stopRefresh: make(map[string]func()), + warmupDone: make(chan struct{}), + startTime: time.Now(), + } + close(pool.warmupDone) + handler := &FlightSQLHandler{pool: pool, alloc: memory.DefaultAllocator} + + finishOperation, ok := session.beginOperation() + if !ok { + t.Fatal("beginOperation rejected test session") + } + + body, err := json.Marshal(server.WorkerWaitSessionIdlePayload{}) + if err != nil { + t.Fatalf("marshal request: %v", err) + } + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-duckgres-session", session.ID)) + stream := &mockDoActionStream{ctx: ctx} + done := make(chan error, 1) + go func() { + done <- handler.doWaitSessionIdle(body, stream) + }() + + select { + case err := <-done: + t.Fatalf("WaitSessionIdle returned before operation released: %v", err) + case <-time.After(50 * time.Millisecond): + } + + finishOperation() + select { + case err := <-done: + if err != nil { + t.Fatalf("WaitSessionIdle returned error: %v", err) + } + case <-time.After(time.Second): + t.Fatal("WaitSessionIdle did not return after operation released") + } + if len(stream.results) != 1 { + t.Fatalf("expected one action result, got %d", len(stream.results)) + } +} + +func TestReleaseQueryHandleReleasesAbandonedOperation(t *testing.T) { + pool := &SessionPool{ + sessions: make(map[string]*Session), + stopRefresh: make(map[string]func()), + warmupDone: make(chan struct{}), + startTime: time.Now(), + } + close(pool.warmupDone) + session := &Session{ + ID: "session-1", + Username: "alice", + CreatedAt: time.Now(), + queries: make(map[string]*QueryHandle), + txns: make(map[string]*trackedTx), + txnOwner: make(map[string]string), + } + pool.sessions[session.ID] = session + handler := &FlightSQLHandler{pool: pool, alloc: memory.DefaultAllocator} + + finishOperation, ok := session.beginOperation() + if !ok { + t.Fatal("beginOperation rejected test session") + } + finishDrain, err := pool.beginDrainWork(false) + if err != nil { + t.Fatalf("beginDrainWork: %v", err) + } + session.queries["query-1"] = &QueryHandle{ + Query: "SELECT 1", + createdAt: time.Now(), + finishDrain: finishDrain, + finishOperation: finishOperation, + } + + ticket, err := flightsql.CreateStatementQueryTicket([]byte("query-1")) + if err != nil { + t.Fatalf("create ticket: %v", err) + } + body, err := json.Marshal(server.WorkerReleaseQueryHandlePayload{Ticket: ticket}) + if err != nil { + t.Fatalf("marshal request: %v", err) + } + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-duckgres-session", session.ID)) + stream := &mockDoActionStream{ctx: ctx} + + if err := handler.doReleaseQueryHandle(body, stream); err != nil { + t.Fatalf("ReleaseQueryHandle: %v", err) + } + if _, ok := session.queries["query-1"]; ok { + t.Fatal("query handle was not removed") + } + if got := pool.ActiveDrainWork(); got != 0 { + t.Fatalf("active drain work=%d, want 0", got) + } + finishOperation2, ok := session.beginOperation() + if !ok { + t.Fatal("operation gate was not released") + } + finishOperation2() + if len(stream.results) != 1 { + t.Fatalf("expected one action result, got %d", len(stream.results)) + } +} + func TestCreateSessionSendFailureDestroysSession(t *testing.T) { pool := &SessionPool{ sessions: make(map[string]*Session), diff --git a/duckdbservice/service.go b/duckdbservice/service.go index c7ec6eca..5a45d8d9 100644 --- a/duckdbservice/service.go +++ b/duckdbservice/service.go @@ -123,6 +123,7 @@ type Session struct { txnOwner map[string]string closed bool operationOpen bool + operationIdle chan struct{} connWork int connWorkDone *sync.Cond handleCounter atomic.Uint64 @@ -234,6 +235,7 @@ func (s *Session) beginOperation() (func(), bool) { return nil, false } s.operationOpen = true + s.operationIdle = make(chan struct{}) s.mu.Unlock() var once sync.Once @@ -241,6 +243,10 @@ func (s *Session) beginOperation() (func(), bool) { once.Do(func() { s.mu.Lock() s.operationOpen = false + if s.operationIdle != nil { + close(s.operationIdle) + s.operationIdle = nil + } s.mu.Unlock() }) }, true @@ -257,6 +263,7 @@ func (s *Session) beginOperationForTransaction(txnKey string) (func(), bool, boo return nil, true, false } s.operationOpen = true + s.operationIdle = make(chan struct{}) s.mu.Unlock() var once sync.Once @@ -264,11 +271,37 @@ func (s *Session) beginOperationForTransaction(txnKey string) (func(), bool, boo once.Do(func() { s.mu.Lock() s.operationOpen = false + if s.operationIdle != nil { + close(s.operationIdle) + s.operationIdle = nil + } s.mu.Unlock() }) }, true, true } +func (s *Session) waitOperationIdle(ctx context.Context) error { + for { + s.mu.RLock() + if !s.operationOpen { + s.mu.RUnlock() + return nil + } + idle := s.operationIdle + s.mu.RUnlock() + + if idle == nil { + return errors.New("session operation idle signal missing") + } + select { + case <-idle: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } +} + // beginConnWork fences any operation that uses the session connection while a // raw SQL transaction may be open. It intentionally does not mutate queryActive: // conn work includes COPY receive and metadata/planning work, while queryActive @@ -1651,6 +1684,10 @@ func (s *customActionServer) DoAction(cmd *flight.Action, stream flight.FlightSe return s.handler.doDestroySession(cmd.Body, stream) case "HealthCheck": return s.handler.doHealthCheck(cmd.Body, stream) + case "WaitSessionIdle": + return s.handler.doWaitSessionIdle(cmd.Body, stream) + case "ReleaseQueryHandle": + return s.handler.doReleaseQueryHandle(cmd.Body, stream) default: // Fall through to standard flightsql action router (BeginTransaction, etc.) return s.FlightServer.DoAction(cmd, stream) diff --git a/server/flightclient/flight_executor.go b/server/flightclient/flight_executor.go index fbf6283a..f1625942 100644 --- a/server/flightclient/flight_executor.go +++ b/server/flightclient/flight_executor.go @@ -3,15 +3,17 @@ package flightclient import ( "context" "encoding/hex" + "encoding/json" "errors" "fmt" + "io" "math/big" "runtime" + "strconv" "strings" "sync" "sync/atomic" "time" - "strconv" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" @@ -21,15 +23,24 @@ import ( "github.com/apache/arrow-go/v18/arrow/memory" "github.com/posthog/duckgres/duckdbservice/arrowmap" "github.com/posthog/duckgres/server/sqlcore" + "github.com/posthog/duckgres/server/wire" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) // MaxGRPCMessageSize is the max gRPC message size for Flight SQL communication. // DuckDB query results can easily exceed the default 4MB limit. const MaxGRPCMessageSize = 1 << 30 // 1GB +const ( + waitSessionIdleAction = "WaitSessionIdle" + releaseQueryHandleAction = "ReleaseQueryHandle" + queryCloseWaitTimeout = 30 * time.Second +) + // ErrWorkerDead is returned when the backing worker process has crashed. var ErrWorkerDead = errors.New("flight worker is dead") @@ -202,22 +213,38 @@ func (e *FlightExecutor) QueryContext(ctx context.Context, query string, args .. return &emptyRowSet{}, nil } - reader, err := e.client.DoGet(reqCtx, info.Endpoint[0].Ticket) + ticket := info.Endpoint[0].Ticket + if err := reqCtx.Err(); err != nil { + _ = e.releaseQueryHandle(ticket) + _ = e.waitForSessionIdle() + return nil, err + } + + reader, err := e.client.DoGet(reqCtx, ticket) if err != nil { + // If cancellation lands after Execute has registered a worker-side + // handle but before DoGet returns a RowSet, there is no Close call to + // acknowledge the abandoned split-phase operation. Release the handle + // explicitly, then wait in case DoGet consumed it and is still unwinding. + _ = e.releaseQueryHandle(ticket) + _ = e.waitForSessionIdle() return nil, fmt.Errorf("flight doget: %w", err) } schema, err := flight.DeserializeSchema(info.Schema, e.alloc) if err != nil { + cancel() reader.Release() + _ = e.waitForSessionIdle() return nil, fmt.Errorf("flight deserialize schema: %w", err) } success = true return &FlightRowSet{ - reader: reader, - schema: schema, - cancel: cancel, + reader: reader, + schema: schema, + cancel: cancel, + waitForClosed: e.waitForSessionIdle, }, nil } @@ -307,6 +334,141 @@ func (e *FlightExecutor) Close() error { return nil } +func isTerminalSessionIdleWaitError(err error) bool { + if strings.Contains(err.Error(), "flight client panic") { + return true + } + switch status.Code(err) { + case codes.Canceled, codes.Unavailable, codes.FailedPrecondition, codes.NotFound: + return true + default: + return false + } +} + +func (e *FlightExecutor) waitForSessionIdle() (err error) { + if e.dead.Load() { + return nil + } + if e.client == nil || e.client.Client == nil { + return nil + } + defer func() { + recoverClientPanic(&err) + if err != nil && isTerminalSessionIdleWaitError(err) { + err = nil + } + }() + + payload, err := json.Marshal(wire.WorkerWaitSessionIdlePayload{ + WorkerControlMetadata: wire.WorkerControlMetadata{ + WorkerID: e.workerID, + OwnerEpoch: e.ownerEpoch, + CPInstanceID: e.cpInstanceID, + }, + }) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), queryCloseWaitTimeout) + defer cancel() + if e.ctx != nil { + go func() { + select { + case <-e.ctx.Done(): + cancel() + case <-ctx.Done(): + } + }() + } + + stream, err := e.client.Client.DoAction( + e.withSession(ctx), + &flight.Action{Type: waitSessionIdleAction, Body: payload}, + ) + if err != nil { + if isTerminalSessionIdleWaitError(err) { + return nil + } + return err + } + for { + _, err := stream.Recv() + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + if isTerminalSessionIdleWaitError(err) { + return nil + } + return err + } + } +} + +func (e *FlightExecutor) releaseQueryHandle(ticket *flight.Ticket) (err error) { + if e.dead.Load() { + return nil + } + if e.client == nil || e.client.Client == nil || ticket == nil || len(ticket.Ticket) == 0 { + return nil + } + defer func() { + recoverClientPanic(&err) + if err != nil && isTerminalSessionIdleWaitError(err) { + err = nil + } + }() + + payload, err := json.Marshal(wire.WorkerReleaseQueryHandlePayload{ + WorkerControlMetadata: wire.WorkerControlMetadata{ + WorkerID: e.workerID, + OwnerEpoch: e.ownerEpoch, + CPInstanceID: e.cpInstanceID, + }, + Ticket: ticket.Ticket, + }) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), queryCloseWaitTimeout) + defer cancel() + if e.ctx != nil { + go func() { + select { + case <-e.ctx.Done(): + cancel() + case <-ctx.Done(): + } + }() + } + + stream, err := e.client.Client.DoAction( + e.withSession(ctx), + &flight.Action{Type: releaseQueryHandleAction, Body: payload}, + ) + if err != nil { + if isTerminalSessionIdleWaitError(err) { + return nil + } + return err + } + for { + _, err := stream.Recv() + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + if isTerminalSessionIdleWaitError(err) { + return nil + } + return err + } + } +} + func (e *FlightExecutor) LastProfilingOutput() string { v := e.lastProfiling.Load() if v == nil { @@ -331,12 +493,14 @@ type FlightRowSet struct { schema *arrow.Schema // Current batch state - currentBatch arrow.RecordBatch - batchRow int // current row index within currentBatch - done bool - err error - closeOnce sync.Once - cancel context.CancelFunc + currentBatch arrow.RecordBatch + batchRow int // current row index within currentBatch + done bool + err error + closeOnce sync.Once + closeErr error + cancel context.CancelFunc + waitForClosed func() error } func (r *FlightRowSet) Columns() ([]string, error) { @@ -422,8 +586,11 @@ func (r *FlightRowSet) Close() error { r.currentBatch = nil } r.reader.Release() + if r.waitForClosed != nil && (!r.done || r.err != nil) { + r.closeErr = r.waitForClosed() + } }) - return nil + return r.closeErr } func (r *FlightRowSet) Err() error { @@ -433,12 +600,12 @@ func (r *FlightRowSet) Err() error { // emptyRowSet is returned when a query produces no endpoints and no schema. type emptyRowSet struct{} -func (e *emptyRowSet) Columns() ([]string, error) { return nil, nil } +func (e *emptyRowSet) Columns() ([]string, error) { return nil, nil } func (e *emptyRowSet) ColumnTypes() ([]sqlcore.ColumnTyper, error) { return nil, nil } -func (e *emptyRowSet) Next() bool { return false } -func (e *emptyRowSet) Scan(dest ...any) error { return fmt.Errorf("no rows") } -func (e *emptyRowSet) Close() error { return nil } -func (e *emptyRowSet) Err() error { return nil } +func (e *emptyRowSet) Next() bool { return false } +func (e *emptyRowSet) Scan(dest ...any) error { return fmt.Errorf("no rows") } +func (e *emptyRowSet) Close() error { return nil } +func (e *emptyRowSet) Err() error { return nil } // emptySchemaRowSet is returned when a query produces no data rows but does // have schema information (e.g., SELECT ... LIMIT 0). This preserves column @@ -464,10 +631,10 @@ func (e *emptySchemaRowSet) ColumnTypes() ([]sqlcore.ColumnTyper, error) { return types, nil } -func (e *emptySchemaRowSet) Next() bool { return false } +func (e *emptySchemaRowSet) Next() bool { return false } func (e *emptySchemaRowSet) Scan(...any) error { return fmt.Errorf("no rows") } -func (e *emptySchemaRowSet) Close() error { return nil } -func (e *emptySchemaRowSet) Err() error { return nil } +func (e *emptySchemaRowSet) Close() error { return nil } +func (e *emptySchemaRowSet) Err() error { return nil } // flightExecResult implements ExecResult for Flight SQL updates. type flightExecResult struct { diff --git a/server/flightclient/flight_executor_test.go b/server/flightclient/flight_executor_test.go index fc2e1705..e90dd6c4 100644 --- a/server/flightclient/flight_executor_test.go +++ b/server/flightclient/flight_executor_test.go @@ -2,9 +2,25 @@ package flightclient import ( "context" + "encoding/json" + "net" + "sync" "testing" + "time" + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/flight" + "github.com/apache/arrow-go/v18/arrow/flight/flightsql" + pb "github.com/apache/arrow-go/v18/arrow/flight/gen/flight" + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/posthog/duckgres/server/wire" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) func TestFlightExecutorWithSessionAddsOwnerEpochHeader(t *testing.T) { @@ -30,3 +46,359 @@ func TestFlightExecutorWithSessionAddsOwnerEpochHeader(t *testing.T) { t.Fatalf("unexpected cp instance metadata: %#v", got) } } + +type closeWaitFlightServer struct { + pb.UnimplementedFlightServiceServer + + schema *arrow.Schema + + doGetStarted chan struct{} + doGetContextCanceled chan struct{} + allowDoGetReturn chan struct{} + doGetDone chan struct{} + doActionCalled chan struct{} + doActionOnce sync.Once + releaseActionCalled chan struct{} + releaseActionOnce sync.Once + releaseTicket []byte + waitBeforeDoGetCancel chan struct{} + waitBeforeOnce sync.Once + closeAfterFirstBatch bool + invalidInfoSchema bool + doGetErr error + doActionErr error +} + +func newCloseWaitFlightServer() *closeWaitFlightServer { + return &closeWaitFlightServer{ + schema: arrow.NewSchema([]arrow.Field{{Name: "x", Type: arrow.PrimitiveTypes.Int64}}, nil), + doGetStarted: make(chan struct{}), + doGetContextCanceled: make(chan struct{}), + allowDoGetReturn: make(chan struct{}), + doGetDone: make(chan struct{}), + doActionCalled: make(chan struct{}), + releaseActionCalled: make(chan struct{}), + } +} + +func (s *closeWaitFlightServer) GetFlightInfo(context.Context, *flight.FlightDescriptor) (*flight.FlightInfo, error) { + ticket, err := flightsql.CreateStatementQueryTicket([]byte("query-1")) + if err != nil { + return nil, err + } + schema := flight.SerializeSchema(s.schema, memory.DefaultAllocator) + if s.invalidInfoSchema { + schema = []byte("not an arrow schema") + } + return &flight.FlightInfo{ + Schema: schema, + Endpoint: []*flight.FlightEndpoint{{ + Ticket: &flight.Ticket{Ticket: ticket}, + }}, + }, nil +} + +func (s *closeWaitFlightServer) DoGet(_ *flight.Ticket, stream pb.FlightService_DoGetServer) error { + close(s.doGetStarted) + if s.doGetErr != nil { + close(s.doGetDone) + return s.doGetErr + } + + builder := array.NewInt64Builder(memory.DefaultAllocator) + builder.Append(1) + arr := builder.NewArray() + builder.Release() + defer arr.Release() + + record := array.NewRecordBatch(s.schema, []arrow.Array{arr}, 1) + defer record.Release() + + writer := flight.NewRecordWriter(stream, ipc.WithSchema(s.schema), ipc.WithAllocator(memory.DefaultAllocator)) + if err := writer.Write(record); err != nil { + return err + } + if s.closeAfterFirstBatch { + close(s.doGetDone) + return nil + } + + <-stream.Context().Done() + close(s.doGetContextCanceled) + <-s.allowDoGetReturn + close(s.doGetDone) + return stream.Context().Err() +} + +func (s *closeWaitFlightServer) DoAction(action *flight.Action, stream pb.FlightService_DoActionServer) error { + switch action.Type { + case releaseQueryHandleAction: + s.releaseActionOnce.Do(func() { + close(s.releaseActionCalled) + }) + var payload wire.WorkerReleaseQueryHandlePayload + if err := json.Unmarshal(action.Body, &payload); err != nil { + return err + } + s.releaseTicket = payload.Ticket + return stream.Send(&flight.Result{Body: []byte(`{"ok":true}`)}) + case waitSessionIdleAction: + default: + return s.UnimplementedFlightServiceServer.DoAction(action, stream) + } + s.doActionOnce.Do(func() { + close(s.doActionCalled) + }) + var payload wire.WorkerWaitSessionIdlePayload + if err := json.Unmarshal(action.Body, &payload); err != nil { + return err + } + if s.doActionErr != nil { + return s.doActionErr + } + if s.waitBeforeDoGetCancel != nil { + select { + case <-s.doGetContextCanceled: + case <-time.After(200 * time.Millisecond): + s.waitBeforeOnce.Do(func() { + close(s.waitBeforeDoGetCancel) + }) + return stream.Send(&flight.Result{Body: []byte(`{"ok":true}`)}) + } + } + <-s.doGetDone + return stream.Send(&flight.Result{Body: []byte(`{"ok":true}`)}) +} + +func newCloseWaitExecutor(t *testing.T, srv *closeWaitFlightServer) (*FlightExecutor, context.Context) { + t.Helper() + grpcSrv := grpc.NewServer() + pb.RegisterFlightServiceServer(grpcSrv, srv) + lis, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + go func() { _ = grpcSrv.Serve(lis) }() + t.Cleanup(grpcSrv.Stop) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + flightCli, err := flight.NewClientWithMiddlewareCtx( + ctx, lis.Addr().String(), nil, nil, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("flight client: %v", err) + } + t.Cleanup(func() { _ = flightCli.Close() }) + + exec := NewFlightExecutorFromClient(&flightsql.Client{Client: flightCli}, "session-1") + t.Cleanup(func() { _ = exec.Close() }) + return exec, ctx +} + +func TestFlightRowSetCloseWaitsForWorkerDoGetCleanup(t *testing.T) { + srv := newCloseWaitFlightServer() + exec, ctx := newCloseWaitExecutor(t, srv) + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err != nil { + t.Fatalf("QueryContext: %v", err) + } + select { + case <-srv.doGetStarted: + case <-time.After(time.Second): + t.Fatal("DoGet did not start") + } + + closeReturned := make(chan error, 1) + go func() { + closeReturned <- rows.Close() + }() + + select { + case err := <-closeReturned: + t.Fatalf("Close returned before worker DoGet cleanup completed: %v", err) + case <-srv.doGetContextCanceled: + case <-time.After(time.Second): + t.Fatal("worker DoGet did not observe cancellation") + } + + select { + case err := <-closeReturned: + t.Fatalf("Close returned while worker DoGet cleanup was still blocked: %v", err) + case <-time.After(50 * time.Millisecond): + } + + close(srv.allowDoGetReturn) + select { + case err := <-closeReturned: + if err != nil { + t.Fatalf("Close returned error: %v", err) + } + case <-time.After(time.Second): + t.Fatal("Close did not return after worker DoGet cleanup completed") + } +} + +func TestFlightRowSetCloseSkipsWaitAfterCleanEOF(t *testing.T) { + srv := newCloseWaitFlightServer() + srv.closeAfterFirstBatch = true + exec, ctx := newCloseWaitExecutor(t, srv) + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err != nil { + t.Fatalf("QueryContext: %v", err) + } + if !rows.Next() { + t.Fatalf("expected first row, err=%v", rows.Err()) + } + if rows.Next() { + t.Fatal("expected EOF after first row") + } + if err := rows.Err(); err != nil { + t.Fatalf("rowset error: %v", err) + } + if err := rows.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + select { + case <-srv.doActionCalled: + t.Fatal("Close called WaitSessionIdle after a clean EOF") + case <-time.After(50 * time.Millisecond): + } +} + +func TestFlightRowSetCloseSkipsWaitAfterExecutorMarkedDead(t *testing.T) { + srv := newCloseWaitFlightServer() + exec, ctx := newCloseWaitExecutor(t, srv) + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err != nil { + t.Fatalf("QueryContext: %v", err) + } + select { + case <-srv.doGetStarted: + case <-time.After(time.Second): + t.Fatal("DoGet did not start") + } + + exec.MarkDead() + closeReturned := make(chan error, 1) + go func() { + closeReturned <- rows.Close() + }() + + select { + case err := <-closeReturned: + if err != nil { + t.Fatalf("Close returned error: %v", err) + } + case <-time.After(100 * time.Millisecond): + close(srv.allowDoGetReturn) + t.Fatal("Close waited for worker idle after executor was marked dead") + } + select { + case <-srv.doActionCalled: + t.Fatal("Close called WaitSessionIdle after executor was marked dead") + default: + } + close(srv.allowDoGetReturn) +} + +func TestFlightRowSetCloseTreatsTerminalWaitFailureAsBestEffort(t *testing.T) { + srv := newCloseWaitFlightServer() + srv.doActionErr = status.Error(codes.Unavailable, "worker is gone") + exec, ctx := newCloseWaitExecutor(t, srv) + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err != nil { + t.Fatalf("QueryContext: %v", err) + } + select { + case <-srv.doGetStarted: + case <-time.After(time.Second): + t.Fatal("DoGet did not start") + } + + closeReturned := make(chan error, 1) + go func() { + closeReturned <- rows.Close() + }() + select { + case err := <-closeReturned: + if err != nil { + t.Fatalf("Close returned terminal wait failure: %v", err) + } + case <-time.After(time.Second): + close(srv.allowDoGetReturn) + t.Fatal("Close did not return after terminal WaitSessionIdle failure") + } + close(srv.allowDoGetReturn) +} + +func TestQueryContextReleasesHandleWhenDoGetFails(t *testing.T) { + srv := newCloseWaitFlightServer() + srv.doGetErr = status.Error(codes.Canceled, "context canceled") + exec, ctx := newCloseWaitExecutor(t, srv) + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err == nil { + if rows != nil { + _ = rows.Close() + } + t.Fatal("expected QueryContext to return DoGet error") + } + + select { + case <-srv.releaseActionCalled: + case <-time.After(time.Second): + t.Fatal("QueryContext did not release the abandoned query handle") + } + ticket, err := flightsql.GetStatementQueryTicket(&flight.Ticket{Ticket: srv.releaseTicket}) + if err != nil { + t.Fatalf("release ticket did not decode: %v", err) + } + if got := string(ticket.GetStatementHandle()); got != "query-1" { + t.Fatalf("released handle %q, want query-1", got) + } + + select { + case <-srv.doActionCalled: + case <-time.After(time.Second): + t.Fatal("QueryContext did not wait for the session to become idle after handle release") + } +} + +func TestQueryContextCancelsDoGetBeforeWaitingAfterSchemaError(t *testing.T) { + srv := newCloseWaitFlightServer() + srv.invalidInfoSchema = true + srv.waitBeforeDoGetCancel = make(chan struct{}) + close(srv.allowDoGetReturn) + exec, ctx := newCloseWaitExecutor(t, srv) + + rows, err := exec.QueryContext(ctx, "SELECT 1") + if err == nil { + if rows != nil { + _ = rows.Close() + } + t.Fatal("expected QueryContext to return schema error") + } + + select { + case <-srv.doGetContextCanceled: + case <-time.After(time.Second): + t.Fatal("DoGet did not observe cancellation") + } + select { + case <-srv.doActionCalled: + case <-time.After(time.Second): + t.Fatal("QueryContext did not wait for the session to become idle") + } + select { + case <-srv.waitBeforeDoGetCancel: + t.Fatal("QueryContext waited for session idle before canceling DoGet") + default: + } +} diff --git a/server/wire/worker_proto.go b/server/wire/worker_proto.go index 9ba1918a..829334b3 100644 --- a/server/wire/worker_proto.go +++ b/server/wire/worker_proto.go @@ -51,3 +51,17 @@ type WorkerDestroySessionPayload struct { type WorkerHealthCheckPayload struct { WorkerControlMetadata } + +// WorkerWaitSessionIdlePayload asks a worker to acknowledge that the session's +// current Flight SQL operation has released its worker-side lifecycle state. +type WorkerWaitSessionIdlePayload struct { + WorkerControlMetadata +} + +// WorkerReleaseQueryHandlePayload asks a worker to release a statement query +// handle that was created by GetFlightInfo but abandoned before DoGet could +// consume it. +type WorkerReleaseQueryHandlePayload struct { + WorkerControlMetadata + Ticket []byte `json:"ticket"` +} diff --git a/server/worker_control.go b/server/worker_control.go index 38543837..449c16f1 100644 --- a/server/worker_control.go +++ b/server/worker_control.go @@ -2,15 +2,15 @@ package server import "github.com/posthog/duckgres/server/wire" -// Aliases retained so existing references to server.WorkerControlMetadata, -// server.WorkerCreateSessionPayload, server.WorkerDestroySessionPayload and -// server.WorkerHealthCheckPayload continue to compile after the types -// moved to server/wire. New code should import server/wire and use -// wire.X directly. +// Aliases retained so existing references to server.Worker* payloads continue +// to compile after the types moved to server/wire. New code should import +// server/wire and use wire.X directly. type ( - WorkerControlMetadata = wire.WorkerControlMetadata - WorkerCreateSessionPayload = wire.WorkerCreateSessionPayload - WorkerDestroySessionPayload = wire.WorkerDestroySessionPayload - WorkerHealthCheckPayload = wire.WorkerHealthCheckPayload + WorkerControlMetadata = wire.WorkerControlMetadata + WorkerCreateSessionPayload = wire.WorkerCreateSessionPayload + WorkerDestroySessionPayload = wire.WorkerDestroySessionPayload + WorkerHealthCheckPayload = wire.WorkerHealthCheckPayload + WorkerWaitSessionIdlePayload = wire.WorkerWaitSessionIdlePayload + WorkerReleaseQueryHandlePayload = wire.WorkerReleaseQueryHandlePayload ) diff --git a/tests/e2e-mw-dev/README.md b/tests/e2e-mw-dev/README.md index d59b6e8e..f29ff8aa 100644 --- a/tests/e2e-mw-dev/README.md +++ b/tests/e2e-mw-dev/README.md @@ -49,7 +49,8 @@ client-go: statements **discarded until Sync** (the queued INSERT must not execute); the statement after `\syncpipeline` must execute normally. This is why the harness Job image is `postgres:18-alpine` (pipeline meta-commands are - psql 18+). + psql 18+). The same wire lane also asserts that a pgwire CancelRequest leaves + the same session immediately reusable. - **cold-burst absorption** — there is no warm pool, so a burst of cold sessions spawns workers on demand; if it outruns the org/global cap the surplus gets a graceful client-visible hint (`no Duckgres worker … retry in about 45 seconds` @@ -148,6 +149,12 @@ normal `go test ./...` lane. GetFlightInfo-to-DoGet handoffs, and abandoned-continuation cleanup are covered by `duckdbservice` unit tests. The harness still asserts the cluster invariant this protects: one active session owns one worker. +- **Worker DoGet close acknowledgement internals** — the harness covers the + black-box pgwire behavior (CancelRequest then immediate same-session reuse), + but not the exact internal wait point. Pausing the worker exactly after it + observes gRPC cancellation but before it releases the session operation token + would need a bespoke worker/Flight fault-injection client. Covered by + `server/flightclient` and `duckdbservice` unit tests instead. - **Malformed Bind message validation (#720)** — negative count/length fields in a Bind message must return a clean `08P01` instead of panicking. Every real client (psql, lib/pq, ...) only emits well-formed Bind messages, so diff --git a/tests/e2e-mw-dev/harness.sh b/tests/e2e-mw-dev/harness.sh index d38dea5f..972b9428 100755 --- a/tests/e2e-mw-dev/harness.sh +++ b/tests/e2e-mw-dev/harness.sh @@ -24,7 +24,8 @@ # demand; any surplus gets a graceful retry hint) then served. # jsonb || keeps Postgres concat semantics through # transpilation (#716), and a pipelined extended-query error -# discards queued messages until Sync (#718). +# discards queued messages until Sync (#718). A same pgwire +# session remains usable immediately after CancelRequest. # activation : DuckLake + Iceberg catalogs attach and read/write. The FIRST # session on a cold Iceberg worker must not fail the pg_catalog # compat-view bind (the CP primes the REST catalog's schema list @@ -104,7 +105,7 @@ HEAVY_EXPECT=1500000000 # stdout would land in that capture and make jq choke ("Invalid numeric literal"). log() { echo ">>> $*" >&2; } -apk add --no-cache curl jq postgresql-client openssl >/dev/null 2>&1 || true +apk add --no-cache curl jq postgresql-client openssl python3 py3-psycopg2 >/dev/null 2>&1 || true # kubectl, for the pod-level assertions the Go suite used to make via client-go. # The Job runs as the `duckgres` SA (pods get/list/delete/patch + pods/exec + @@ -403,6 +404,81 @@ EOF pg "$1" "$2" ducklake "DROP TABLE $t;" } +# Black-box regression for async cancel cleanup: libpq's cancel path sends a +# PostgreSQL CancelRequest while keeping the same connection open. The next +# query runs immediately on that same session; if the control plane reuses the +# session before the worker has released its cancelled DoGet operation, this can +# fail with "session already has an active operation". +cancel_then_reuse_same_session() { # org password + log "cancel then immediate same-session reuse on $1" + out="$(mktemp)" + if ! python3 - "$1" "$2" "$CP_IP" "$SNI_SUFFIX" "$HEAVY_Q" >"$out" 2>&1 <<'PY' +import sys +import threading +import time + +import psycopg2 + +org, password, hostaddr, sni_suffix, heavy_q = sys.argv[1:] +conn = psycopg2.connect( + dbname="ducklake", + user="root", + password=password, + host=org + sni_suffix, + hostaddr=hostaddr, + port=5432, + sslmode="require", + connect_timeout=30, +) +conn.autocommit = True +try: + cur = conn.cursor() + timer = threading.Timer(3.0, conn.cancel) + timer.start() + cancelled = False + try: + cur.execute(heavy_q) + print("heavy query completed before cancel") + except Exception as exc: + cancelled = True + msg = str(exc) + if "cancel" not in msg.lower(): + raise SystemExit(f"heavy query failed without cancellation: {msg}") + print(f"cancel_error={msg}") + finally: + timer.cancel() + if not cancelled: + raise SystemExit("heavy query did not get cancelled") + + cur.execute( + "CREATE TEMP TABLE cancel_reuse_marker(v INT); " + "INSERT INTO cancel_reuse_marker VALUES (42); " + "SELECT v FROM cancel_reuse_marker;" + ) + marker = cur.fetchone()[0] + if marker != 42: + raise SystemExit(f"marker query returned {marker}, want 42") + print("marker=42") +finally: + conn.close() +PY + then + text="$(tr '\n' ' ' <"$out" | tail -c 400)" + rm -f "$out" + case "$text" in + *"session already has an active operation"*) + fail "cancel_then_reuse_same_session: session reused before worker cancel cleanup completed: $text" ;; + *) fail "cancel_then_reuse_same_session: cancel/reuse client failed: $text" ;; + esac + fi + text="$(tr '\n' ' ' <"$out" | tail -c 400)" + rm -f "$out" + case "$text" in + *"cancel_error="*"marker=42"*) ;; + *) fail "cancel_then_reuse_same_session: expected cancel and same-session marker success, got: $text" ;; + esac +} + # ---- bundled extension forks ---------------------------------------------- # Ported from TestK8sDucklakeExtensionIsBundledFork / TestK8sHttpfsExtensionIsBundledFork. assert_fork_extensions() { # org password @@ -1125,6 +1201,7 @@ lane_cnpg() { # full wire/catalog/concurrency/sizing coverage on the cnpg org persistent_user_secret "$CNPG" "$cnpg_pw" # after rw_ducklake (org worker hot) persistent_user_secret_isolation "$CNPG" "$cnpg_pw" pipeline_error_recovery "$CNPG" "$cnpg_pw" # after rw_ducklake (table writes proven) + cancel_then_reuse_same_session "$CNPG" "$cnpg_pw" assert_fork_extensions "$CNPG" "$cnpg_pw" # after a DuckLake R/W (httpfs loaded) iceberg_cold_first_connect "$CNPG" "$cnpg_pw" # cold first session must not fail the metadata-init bind rw_iceberg "$CNPG" "$cnpg_pw" @@ -1232,7 +1309,7 @@ main() { # mid-run image bump); it stays covered by the controlplane/ unit tests. log "SKIP version-reaper (needs an in-run image bump; see README)" - log "PASS: admin-no-query-token + wire + malformed-startup-resilience + jsonb-concat + cold-burst-absorption + pipeline-error-recovery + activation(DuckLake/Iceberg) + iceberg-cold-first-connect + ext-forks + worker-pod + concurrency + durability + crash-recovery + graceful-drain + one-session-per-worker + parallel-cold-burst-ramp + worker-sizing(cnpg DuckLake+Iceberg) + org-default-profile(ext) + persistent-user-secrets(cnpg+ext, cross-user isolation) + isolation + lifecycle-teardown, on cnpg & ext (4 parallel lanes)" + log "PASS: admin-no-query-token + wire + malformed-startup-resilience + jsonb-concat + cold-burst-absorption + pipeline-error-recovery + cancel-reuse + activation(DuckLake/Iceberg) + iceberg-cold-first-connect + ext-forks + worker-pod + concurrency + durability + crash-recovery + graceful-drain + one-session-per-worker + parallel-cold-burst-ramp + worker-sizing(cnpg DuckLake+Iceberg) + org-default-profile(ext) + persistent-user-secrets(cnpg+ext, cross-user isolation) + isolation + lifecycle-teardown, on cnpg & ext (4 parallel lanes)" } main "$@"