Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions mod/modregistry/anduin_parallel.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,42 @@ func (w *parallelWorker[T]) run(taskId int, task func(context.Context) (T, error
defer w.wg.Done()
resp, err := task(w.ctx)
if err != nil {
w.errCh <- err
// Non-blocking send: if wait() has already cancelled on a
// prior error and stopped reading errCh, we must not block
// here or the worker goroutine leaks forever. wg.Done runs
// via defer either way, so wait() still terminates.
select {
case w.errCh <- err:
case <-w.ctx.Done():
}
return
}
w.respCh <- &taskResponse[T]{
select {
case w.respCh <- &taskResponse[T]{
taskId: taskId,
response: resp,
}:
case <-w.ctx.Done():
}
}()
}

// close shuts down both result channels (errCh and respCh).
// NOTE(review): this is only safe once every worker goroutine has left its
// channel sends — a late send on a closed channel panics. Presumably the
// caller guarantees this by running wg.Wait() first; confirm against wait().
func (w *parallelWorker[T]) close() {
	close(w.errCh)
	close(w.respCh)
}

func (w *parallelWorker[T]) wait() ([]T, error) {
defer w.close()
// Signal every worker to stop on return and block until they have
// all left their channel sends. Only then is it safe to close the
// channels, otherwise a late worker would panic sending on a closed
// channel.
defer func() {
w.cancel()
w.wg.Wait()
close(w.errCh)
close(w.respCh)
}()

done := make(chan struct{}, 1)
done := make(chan struct{})
go func() {
w.wg.Wait()
done <- struct{}{}
close(done)
}()

var err error
Expand All @@ -82,9 +96,15 @@ func (w *parallelWorker[T]) wait() ([]T, error) {
resp[idx] = r.response
}
return resp, err
case err = <-w.errCh:
w.cancel()
done <- struct{}{}
case e := <-w.errCh:
// First error wins. Cancel the context so any still-running
// workers abort, then keep draining respCh/errCh until the
// wg watcher signals `done`. Do not touch `done` here — the
// watcher owns it.
if err == nil {
err = e
w.cancel()
}
case data := <-w.respCh:
taskResps = append(taskResps, data)
}
Expand Down
109 changes: 109 additions & 0 deletions mod/modregistry/anduin_parallel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright (C) 2014-2026 Anduin Transactions Inc.

package modregistry

import (
"context"
"errors"
"testing"
"time"
)

// waitWithTimeout runs w.wait() on its own goroutine and fails the test if
// it has not returned within timeout. It is the guard the tests below use
// to detect regressions that make wait() block forever.
func waitWithTimeout[T any](t *testing.T, w *parallelWorker[T], timeout time.Duration) ([]T, error) {
	t.Helper()
	type outcome struct {
		values  []T
		failure error
	}
	outcomes := make(chan outcome, 1)
	go func() {
		values, failure := w.wait()
		outcomes <- outcome{values: values, failure: failure}
	}()
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case o := <-outcomes:
		return o.values, o.failure
	case <-timer.C:
		t.Fatalf("parallelWorker.wait() deadlocked (waited %s)", timeout)
		return nil, nil // unreachable: Fatalf stops the test goroutine
	}
}

// When a task returns an error, parallelWorker.wait() must not deadlock.
//
// Previously, on error the main goroutine would re-send on the 1-capacity
// done channel that was already filled by the wg watcher goroutine once
// all workers finished. The second send had no receiver and hung forever,
// producing "fatal error: all goroutines are asleep - deadlock!".
// A task error must surface from parallelWorker.wait() without hanging.
//
// Historical failure mode: once all workers finished, the wg watcher had
// already filled the capacity-1 done channel; the error path then sent on
// done a second time with no receiver, wedging wait() forever and tripping
// "fatal error: all goroutines are asleep - deadlock!".
func TestParallelWorkerErrorNoDeadlock(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	w := newParallelWorker[int](ctx)
	wantErr := errors.New("boom")
	w.run(0, func(_ context.Context) (int, error) {
		return 0, wantErr
	})

	if _, err := waitWithTimeout(t, w, 2*time.Second); !errors.Is(err, wantErr) {
		t.Fatalf("expected %v, got %v", wantErr, err)
	}
}

// Multiple successful tasks arriving in any order must not deadlock.
// Many successful tasks completing in arbitrary order must all be collected
// into their taskId slots, and wait() must terminate normally.
func TestParallelWorkerMultipleSuccess(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	const taskCount = 16
	w := newParallelWorker[int](ctx)
	for id := 0; id < taskCount; id++ {
		id := id // per-iteration copy (pre-Go 1.22 loop-variable semantics)
		w.run(id, func(context.Context) (int, error) {
			return id * id, nil
		})
	}

	got, err := waitWithTimeout(t, w, 2*time.Second)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(got) != taskCount {
		t.Fatalf("expected %d responses, got %d", taskCount, len(got))
	}
	for idx, square := range got {
		if square != idx*idx {
			t.Fatalf("unexpected response at %d: got %d, want %d", idx, square, idx*idx)
		}
	}
}

// Error mixed with slow success tasks must still return the error and not
// deadlock waiting on the in-flight success goroutines.
// An early task error racing against slow successful siblings must still be
// returned by wait() — it must not deadlock waiting on the in-flight
// goroutines.
func TestParallelWorkerErrorWithSlowSiblings(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	w := newParallelWorker[int](ctx)
	failure := errors.New("boom")

	// One immediately failing task...
	w.run(0, func(context.Context) (int, error) {
		return 0, failure
	})
	// ...plus several siblings that either finish slowly or abort on cancel.
	for id := 1; id < 4; id++ {
		w.run(id, func(taskCtx context.Context) (int, error) {
			select {
			case <-taskCtx.Done():
				return 0, taskCtx.Err()
			case <-time.After(200 * time.Millisecond):
				return 42, nil
			}
		})
	}

	if _, err := waitWithTimeout(t, w, 3*time.Second); err == nil {
		t.Fatalf("expected an error, got nil")
	}
}