Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions duckdbservice/flight_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,58 @@ func (h *FlightSQLHandler) doHealthCheck(body []byte, stream flight.FlightServic
return sendActionResult(stream, &flight.Result{Body: resp})
}

func (h *FlightSQLHandler) doWaitSessionIdle(body []byte, stream flight.FlightService_DoActionServer) error {
var req server.WorkerWaitSessionIdlePayload
if err := json.Unmarshal(body, &req); err != nil {
return status.Errorf(codes.InvalidArgument, "invalid WaitSessionIdle request: %v", err)
}
if err := h.pool.validateControlMetadata(req.WorkerControlMetadata); err != nil {
return status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err)
}
session, err := h.sessionFromContext(stream.Context())
if err != nil {
return err
}
if err := session.waitOperationIdle(stream.Context()); err != nil {
switch {
case errors.Is(err, context.Canceled):
return status.Errorf(codes.Canceled, "wait for session idle: %v", err)
case errors.Is(err, context.DeadlineExceeded):
return status.Errorf(codes.DeadlineExceeded, "wait for session idle: %v", err)
default:
return status.Errorf(codes.Internal, "wait for session idle: %v", err)
}
}

resp, _ := json.Marshal(map[string]bool{"ok": true})
return sendActionResult(stream, &flight.Result{Body: resp})
}

func (h *FlightSQLHandler) doReleaseQueryHandle(body []byte, stream flight.FlightService_DoActionServer) error {
var req server.WorkerReleaseQueryHandlePayload
if err := json.Unmarshal(body, &req); err != nil {
return status.Errorf(codes.InvalidArgument, "invalid ReleaseQueryHandle request: %v", err)
}
if err := h.pool.validateControlMetadata(req.WorkerControlMetadata); err != nil {
return status.Errorf(codes.FailedPrecondition, "stale worker owner: %v", err)
}
session, err := h.sessionFromContext(stream.Context())
if err != nil {
return err
}
ticket, err := flightsql.GetStatementQueryTicket(&flight.Ticket{Ticket: req.Ticket})
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid statement ticket: %v", err)
}
handleID := string(ticket.GetStatementHandle())
if handle, ok := popQueryHandle(session, handleID); ok {
releaseQueryHandleValue(handle)
}

resp, _ := json.Marshal(map[string]bool{"ok": true})
return sendActionResult(stream, &flight.Result{Body: resp})
}

// Flight SQL method implementations

func (h *FlightSQLHandler) GetFlightInfoStatement(ctx context.Context, cmd flightsql.StatementQuery,
Expand Down
118 changes: 118 additions & 0 deletions duckdbservice/flight_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,124 @@ func TestCreateSessionRejectsWhileDraining(t *testing.T) {
}
}

func TestWaitSessionIdleBlocksUntilOperationReleases(t *testing.T) {
session := &Session{
ID: "session-1",
Username: "alice",
CreatedAt: time.Now(),
queries: make(map[string]*QueryHandle),
txns: make(map[string]*trackedTx),
txnOwner: make(map[string]string),
}
pool := &SessionPool{
sessions: map[string]*Session{session.ID: session},
stopRefresh: make(map[string]func()),
warmupDone: make(chan struct{}),
startTime: time.Now(),
}
close(pool.warmupDone)
handler := &FlightSQLHandler{pool: pool, alloc: memory.DefaultAllocator}

finishOperation, ok := session.beginOperation()
if !ok {
t.Fatal("beginOperation rejected test session")
}

body, err := json.Marshal(server.WorkerWaitSessionIdlePayload{})
if err != nil {
t.Fatalf("marshal request: %v", err)
}
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-duckgres-session", session.ID))
stream := &mockDoActionStream{ctx: ctx}
done := make(chan error, 1)
go func() {
done <- handler.doWaitSessionIdle(body, stream)
}()

select {
case err := <-done:
t.Fatalf("WaitSessionIdle returned before operation released: %v", err)
case <-time.After(50 * time.Millisecond):
}

finishOperation()
select {
case err := <-done:
if err != nil {
t.Fatalf("WaitSessionIdle returned error: %v", err)
}
case <-time.After(time.Second):
t.Fatal("WaitSessionIdle did not return after operation released")
}
if len(stream.results) != 1 {
t.Fatalf("expected one action result, got %d", len(stream.results))
}
}

func TestReleaseQueryHandleReleasesAbandonedOperation(t *testing.T) {
pool := &SessionPool{
sessions: make(map[string]*Session),
stopRefresh: make(map[string]func()),
warmupDone: make(chan struct{}),
startTime: time.Now(),
}
close(pool.warmupDone)
session := &Session{
ID: "session-1",
Username: "alice",
CreatedAt: time.Now(),
queries: make(map[string]*QueryHandle),
txns: make(map[string]*trackedTx),
txnOwner: make(map[string]string),
}
pool.sessions[session.ID] = session
handler := &FlightSQLHandler{pool: pool, alloc: memory.DefaultAllocator}

finishOperation, ok := session.beginOperation()
if !ok {
t.Fatal("beginOperation rejected test session")
}
finishDrain, err := pool.beginDrainWork(false)
if err != nil {
t.Fatalf("beginDrainWork: %v", err)
}
session.queries["query-1"] = &QueryHandle{
Query: "SELECT 1",
createdAt: time.Now(),
finishDrain: finishDrain,
finishOperation: finishOperation,
}

ticket, err := flightsql.CreateStatementQueryTicket([]byte("query-1"))
if err != nil {
t.Fatalf("create ticket: %v", err)
}
body, err := json.Marshal(server.WorkerReleaseQueryHandlePayload{Ticket: ticket})
if err != nil {
t.Fatalf("marshal request: %v", err)
}
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("x-duckgres-session", session.ID))
stream := &mockDoActionStream{ctx: ctx}

if err := handler.doReleaseQueryHandle(body, stream); err != nil {
t.Fatalf("ReleaseQueryHandle: %v", err)
}
if _, ok := session.queries["query-1"]; ok {
t.Fatal("query handle was not removed")
}
if got := pool.ActiveDrainWork(); got != 0 {
t.Fatalf("active drain work=%d, want 0", got)
}
finishOperation2, ok := session.beginOperation()
if !ok {
t.Fatal("operation gate was not released")
}
finishOperation2()
if len(stream.results) != 1 {
t.Fatalf("expected one action result, got %d", len(stream.results))
}
}

func TestCreateSessionSendFailureDestroysSession(t *testing.T) {
pool := &SessionPool{
sessions: make(map[string]*Session),
Expand Down
37 changes: 37 additions & 0 deletions duckdbservice/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ type Session struct {
txnOwner map[string]string
closed bool
operationOpen bool
operationIdle chan struct{}
connWork int
connWorkDone *sync.Cond
handleCounter atomic.Uint64
Expand Down Expand Up @@ -234,13 +235,18 @@ func (s *Session) beginOperation() (func(), bool) {
return nil, false
}
s.operationOpen = true
s.operationIdle = make(chan struct{})
s.mu.Unlock()

var once sync.Once
return func() {
once.Do(func() {
s.mu.Lock()
s.operationOpen = false
if s.operationIdle != nil {
close(s.operationIdle)
s.operationIdle = nil
}
s.mu.Unlock()
})
}, true
Expand All @@ -257,18 +263,45 @@ func (s *Session) beginOperationForTransaction(txnKey string) (func(), bool, boo
return nil, true, false
}
s.operationOpen = true
s.operationIdle = make(chan struct{})
s.mu.Unlock()

var once sync.Once
return func() {
once.Do(func() {
s.mu.Lock()
s.operationOpen = false
if s.operationIdle != nil {
close(s.operationIdle)
s.operationIdle = nil
}
s.mu.Unlock()
})
}, true, true
}

func (s *Session) waitOperationIdle(ctx context.Context) error {
for {
s.mu.RLock()
if !s.operationOpen {
s.mu.RUnlock()
return nil
}
idle := s.operationIdle
s.mu.RUnlock()

if idle == nil {
return errors.New("session operation idle signal missing")
}
select {
case <-idle:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
}

// beginConnWork fences any operation that uses the session connection while a
// raw SQL transaction may be open. It intentionally does not mutate queryActive:
// conn work includes COPY receive and metadata/planning work, while queryActive
Expand Down Expand Up @@ -1651,6 +1684,10 @@ func (s *customActionServer) DoAction(cmd *flight.Action, stream flight.FlightSe
return s.handler.doDestroySession(cmd.Body, stream)
case "HealthCheck":
return s.handler.doHealthCheck(cmd.Body, stream)
case "WaitSessionIdle":
return s.handler.doWaitSessionIdle(cmd.Body, stream)
case "ReleaseQueryHandle":
return s.handler.doReleaseQueryHandle(cmd.Body, stream)
default:
// Fall through to standard flightsql action router (BeginTransaction, etc.)
return s.FlightServer.DoAction(cmd, stream)
Expand Down
Loading
Loading