diff --git a/pipelines.go b/pipelines.go index ed7ef0b..5c55953 100644 --- a/pipelines.go +++ b/pipelines.go @@ -379,11 +379,19 @@ func WithWaitGroup(wg *sync.WaitGroup) Option { } type config struct { + // bufferSize is the size of the buffer for every output channel created for this stage. bufferSize int - workers int - outputs int + // workers is the number of goroutines on which to run this stage. + workers int + // outputs is the number of output channels for this stage. + outputs int + // doNotClose is true if pipelines should not close output channels for this stage. Should be false in nearly every + // case. Stage should provide a close function if this functionality is unused. + doNotClose bool + // doneCancel is a context.CancelFunc to call when this stage is halted. doneCancel context.CancelFunc - wg *sync.WaitGroup + // wg is a sync.WaitGroup to decrement when this stage is halted. + wg *sync.WaitGroup } func makeOutputChannels[T any](c config) []chan T { @@ -435,8 +443,10 @@ func doWithConf[T any](ctx context.Context, doIt func(context.Context, ...chan T // run without a worker pool to avoid overhead from the WaitGroup go func() { defer func() { - for _, ch := range outs { - close(ch) + if !conf.doNotClose { + for _, ch := range outs { + close(ch) + } } conf.cancel() conf.done() @@ -455,8 +465,10 @@ func doWithConf[T any](ctx context.Context, doIt func(context.Context, ...chan T if id == 0 { // first thread closes the output channel. poolStopped.Wait() defer func() { - for _, ch := range outs { - close(ch) + if !conf.doNotClose { + for _, ch := range outs { + close(ch) + } } }() conf.cancel() @@ -507,27 +519,26 @@ func Reduce[S, T any](ctx context.Context, in <-chan S, f func(T, S) T) (T, erro } } -// ErrorSink provides an error-handling solution for pipelines created by this package. It manages a -// pipeline stage which can receive fatal and non-fatal errors that may occur during the course of a pipeline. +// ErrorSink provides an error-handling solution for pipelines created by this package. It manages a pipeline stage +// which can receive fatal and non-fatal errors that may occur during the course of a pipeline. +// +// ErrorSinks are safe to use from multiple pipeline stages concurrently. They must be closed to avoid leaking +// resources. ErrorSinks are eventually consistent. Calls to All() return as many errors as have been processed to date. +// At any point in time, there is no guarantee that all errors sent to ErrorSink have been received. type ErrorSink struct { - errors chan errWrapper - cancel context.CancelFunc - errs []error - needLock bool - lock *sync.Mutex - wg *sync.WaitGroup + errors chan errWrapper + cancel context.CancelFunc + errs []error + lock *sync.Mutex } // NewErrorSink returns a new ErrorSink, along with a context which is cancelled when a fatal error is sent to the // ErrorSink. Starts a new, configurable pipeline stage which catches any errors reported. func NewErrorSink(ctx context.Context, opts ...Option) (context.Context, *ErrorSink) { ctx, cancel := context.WithCancel(ctx) - result := &ErrorSink{cancel: cancel, wg: &sync.WaitGroup{}} + result := &ErrorSink{cancel: cancel, lock: &sync.Mutex{}} config := configure(opts) - if config.workers > 1 { - result.needLock = true - result.lock = &sync.Mutex{} - } + config.doNotClose = true outs := doWithConf(ctx, func(ctx context.Context, in ...chan errWrapper) { result.doErrSink(ctx, in[0]) @@ -544,7 +555,6 @@ func (s *ErrorSink) doErrSink(ctx context.Context, errors chan errWrapper) { return case werr := <-errors: s.appendErr(werr.err) - s.wg.Done() if werr.isFatal { s.cancel() } @@ -553,32 +563,34 @@ func (s *ErrorSink) doErrSink(ctx context.Context, errors chan errWrapper) { } func (s *ErrorSink) appendErr(err error) { - if s.needLock { // a lock is only needed in case the ErrorSink was started with - s.lock.Lock() - defer s.lock.Unlock() - } + s.lock.Lock() + defer s.lock.Unlock() s.errs = append(s.errs, err) } // Fatal sends a fatal error to this ErrorSink, cancelling the child context which was created by NewErrorSink, // as well as reporting this error. func (s *ErrorSink) Fatal(err error) { - s.wg.Add(1) s.errors <- errWrapper{isFatal: true, err: err} } // Error sends a non-fatal error to this ErrorSink, which is reported and included along with All() func (s *ErrorSink) Error(err error) { - s.wg.Add(1) s.errors <- errWrapper{isFatal: false, err: err} } +// Close closes the channel used to send errors for this ErrorSink. Each ErrorSink must be closed to avoid resources +// leaks +func (s *ErrorSink) Close() { + close(s.errors) +} + // All returns all errors which have been received by this ErrorSink so far. Subsequent calls to All can return strictly -// more errors, but will never return fewer errors. The only way to be certain that all errors from a pipeline have been -// reported is to pass WithWaitGroup to every pipeline stage which sends an error to this ErrorSink and wait for all -// stages to terminate before calling All(). +// more errors, but will never return fewer errors. While all errors sent to an ErrorSink eventually end up being +// reported, there is no timeframe within which they are all guaranteed to be available. func (s *ErrorSink) All() []error { - s.wg.Wait() + s.lock.Lock() + defer s.lock.Unlock() return s.errs } diff --git a/pipelines_test.go b/pipelines_test.go index db4a3a9..38d0bf6 100644 --- a/pipelines_test.go +++ b/pipelines_test.go @@ -6,6 +6,7 @@ import ( "github.com/matryer/is" "github.com/splunk/pipelines" "net/http" + "regexp" "sort" "strconv" "sync" @@ -541,11 +542,11 @@ func TestErrorSink(t *testing.T) { }) all := errs.All() - sort.Slice(all, func(i, j int) bool { - return all[i].Error() < all[j].Error() - }) is.Equal(err, nil) - is.Equal(toStr(all), []string{"1!", "2!", "3!", "4!", "err1", "err2", "err3", "err4"}) + rgx := regexp.MustCompile(`\d!|err\d`) + for _, err := range all { // not all errors will be reported; every error that does should match + is.True(rgx.MatchString(err.Error())) + } }) t.Run("fatal errors cancel returned context", func(t *testing.T) { @@ -563,13 +564,6 @@ func TestErrorSink(t *testing.T) { }) } -func toStr(errs []error) (out []string) { - for _, err := range errs { - out = append(out, err.Error()) - } - return out -} - func testWithWaitGroup[S, T any](t *testing.T, stage func(context.Context, pipelines.Option, <-chan S) <-chan T) { t.Helper() t.Run("WithWaitGroup", func(t *testing.T) { @@ -728,6 +722,7 @@ func ExampleErrorSink() { defer cancel() ctx, errs := pipelines.NewErrorSink(ctx) + defer errs.Close() urls := pipelines.Chan([]string{ "https://httpstat.us/200",