|
| 1 | +package utils |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "iter" |
| 6 | + "sync" |
| 7 | + "sync/atomic" |
| 8 | + |
| 9 | + "golang.org/x/sync/errgroup" |
| 10 | +) |
| 11 | + |
| 12 | +// Mutex guards access to object of type T. |
| 13 | +type Mutex[T any] struct { |
| 14 | + mu sync.Mutex |
| 15 | + value T |
| 16 | +} |
| 17 | + |
| 18 | +// NewMutex creates a new Mutex with given object. |
| 19 | +func NewMutex[T any](value T) (m Mutex[T]) { |
| 20 | + m.value = value |
| 21 | + // nolint:nakedret |
| 22 | + return |
| 23 | +} |
| 24 | + |
| 25 | +// Lock returns an iterator which locks the mutex and yields the guarded object. |
| 26 | +// The mutex is unlocked when the iterator is done. |
| 27 | +// If the mutex is nil, the iterator is a no-op. |
| 28 | +func (m *Mutex[T]) Lock() iter.Seq[T] { |
| 29 | + return func(yield func(val T) bool) { |
| 30 | + m.mu.Lock() |
| 31 | + defer m.mu.Unlock() |
| 32 | + _ = yield(m.value) |
| 33 | + } |
| 34 | +} |
| 35 | + |
| 36 | +// version of the value stored in an atomic watch. |
| 37 | +type version[T any] struct { |
| 38 | + updated chan struct{} |
| 39 | + value T |
| 40 | +} |
| 41 | + |
| 42 | +// newVersion constructs a new active version. |
| 43 | +func newVersion[T any](value T) *version[T] { |
| 44 | + return &version[T]{make(chan struct{}), value} |
| 45 | +} |
| 46 | + |
| 47 | +type atomicWatch[T any] struct { |
| 48 | + ptr atomic.Pointer[version[T]] |
| 49 | +} |
| 50 | + |
| 51 | +type AtomicSend[T any] struct { |
| 52 | + atomicWatch[T] |
| 53 | +} |
| 54 | + |
| 55 | +// Store updates the value of the atomic watch. |
| 56 | +func (w *AtomicSend[T]) Send(value T) { |
| 57 | + close(w.ptr.Swap(newVersion(value)).updated) |
| 58 | +} |
| 59 | + |
| 60 | +// Update conditionally updates the value of the atomic watch. |
| 61 | +func (w *AtomicSend[T]) Update(f func(T) (T, bool)) { |
| 62 | + old := w.ptr.Load() |
| 63 | + if value, ok := f(old.value); ok { |
| 64 | + w.ptr.Store(newVersion(value)) |
| 65 | + close(old.updated) |
| 66 | + } |
| 67 | +} |
| 68 | + |
| 69 | +func NewAtomicSend[T any](value T) (w AtomicSend[T]) { |
| 70 | + w.atomicWatch.ptr.Store(newVersion(value)) |
| 71 | + // nolint:nakedret |
| 72 | + return |
| 73 | +} |
| 74 | + |
| 75 | +func (w *AtomicSend[T]) Subscribe() AtomicRecv[T] { |
| 76 | + return AtomicRecv[T]{&w.atomicWatch} |
| 77 | +} |
| 78 | + |
| 79 | +// AtomicWatch stores a pointer to an IMMUTABLE value. |
| 80 | +// Loading and waiting for updates do NOT require locking. |
| 81 | +// TODO(gprusak): remove mutex and rename to AtomicSend, |
| 82 | +// this will allow for sharing a mutex across multiple AtomicSenders. |
| 83 | +type AtomicWatch[T any] struct { |
| 84 | + atomicWatch[T] |
| 85 | + mu sync.Mutex |
| 86 | +} |
| 87 | + |
| 88 | +// AtomicRecv is a read-only reference to AtomicWatch. |
| 89 | +type AtomicRecv[T any] struct{ *atomicWatch[T] } |
| 90 | + |
| 91 | +// NewAtomicWatch creates a new AtomicWatch with the given initial value. |
| 92 | +func NewAtomicWatch[T any](value T) (w AtomicWatch[T]) { |
| 93 | + w.ptr.Store(newVersion(value)) |
| 94 | + // nolint:nakedret |
| 95 | + return |
| 96 | +} |
| 97 | + |
| 98 | +// Subscribe returns a view-only API of the atomic watch. |
| 99 | +func (w *AtomicWatch[T]) Subscribe() AtomicRecv[T] { |
| 100 | + return AtomicRecv[T]{&w.atomicWatch} |
| 101 | +} |
| 102 | + |
| 103 | +// Load returns the current value of the atomic watch. |
| 104 | +// Does not do any locking. |
| 105 | +func (w *atomicWatch[T]) Load() T { return w.ptr.Load().value } |
| 106 | + |
| 107 | +// Store updates the value of the atomic watch. |
| 108 | +func (w *AtomicWatch[T]) Store(value T) { |
| 109 | + w.mu.Lock() |
| 110 | + defer w.mu.Unlock() |
| 111 | + close(w.ptr.Swap(newVersion(value)).updated) |
| 112 | +} |
| 113 | + |
| 114 | +// Update conditionally updates the value of the atomic watch. |
| 115 | +func (w *AtomicWatch[T]) Update(f func(T) (T, bool)) { |
| 116 | + w.mu.Lock() |
| 117 | + defer w.mu.Unlock() |
| 118 | + old := w.ptr.Load() |
| 119 | + if value, ok := f(old.value); ok { |
| 120 | + w.ptr.Store(newVersion(value)) |
| 121 | + close(old.updated) |
| 122 | + } |
| 123 | +} |
| 124 | + |
| 125 | +// Wait waits for the value of the atomic watch to satisfy the predicate. |
| 126 | +// Does not do any locking. |
| 127 | +func (w *atomicWatch[T]) Wait(ctx context.Context, pred func(T) bool) (T, error) { |
| 128 | + for { |
| 129 | + v := w.ptr.Load() |
| 130 | + if pred(v.value) { |
| 131 | + return v.value, nil |
| 132 | + } |
| 133 | + select { |
| 134 | + case <-ctx.Done(): |
| 135 | + return Zero[T](), ctx.Err() |
| 136 | + case <-v.updated: |
| 137 | + } |
| 138 | + } |
| 139 | +} |
| 140 | + |
| 141 | +// Iter executes sequentially the function f on each value of the atomic watch. |
| 142 | +// Context passed to f is canceled when the next value is available. |
| 143 | +// Exits when the returned error is different from nil and context.Canceled, |
| 144 | +// or when the context passed to Iter is canceled (after f exits). |
| 145 | +func (w *atomicWatch[T]) Iter(ctx context.Context, f func(ctx context.Context, v T) error) error { |
| 146 | + for ctx.Err() == nil { |
| 147 | + v := w.ptr.Load() |
| 148 | + g, ctx := errgroup.WithContext(ctx) |
| 149 | + g.Go(func() error { return f(ctx, v.value) }) |
| 150 | + g.Go(func() error { |
| 151 | + select { |
| 152 | + case <-ctx.Done(): |
| 153 | + case <-v.updated: |
| 154 | + } |
| 155 | + return context.Canceled |
| 156 | + }) |
| 157 | + if err := IgnoreCancel(g.Wait()); err != nil { |
| 158 | + return err |
| 159 | + } |
| 160 | + } |
| 161 | + return ctx.Err() |
| 162 | +} |
| 163 | + |
| 164 | +// WatchCtrl controls the locked object in a Watch. |
| 165 | +// It is provided only in the iterator returned by Lock(). |
| 166 | +// Should NOT be stored anywhere. |
| 167 | +type WatchCtrl struct { |
| 168 | + mu sync.Mutex |
| 169 | + updated chan struct{} |
| 170 | +} |
| 171 | + |
| 172 | +// Watch stores a value of type T. |
| 173 | +// Essentially a mutex, that can be awaited for updates. |
| 174 | +type Watch[T any] struct { |
| 175 | + ctrl WatchCtrl |
| 176 | + val T |
| 177 | +} |
| 178 | + |
| 179 | +// NewWatch constructs a new watch with the given value. |
| 180 | +// Note that value in the watch cannot be changed, so T |
| 181 | +// should be a pointer type if updates are required. |
| 182 | +func NewWatch[T any](val T) Watch[T] { |
| 183 | + return Watch[T]{ |
| 184 | + WatchCtrl{updated: make(chan struct{})}, |
| 185 | + val, |
| 186 | + } |
| 187 | +} |
| 188 | + |
| 189 | +// Wait waits for the value in the watch to be updated. |
| 190 | +// Should be called only after locking the watch, i.e. within Lock() iterator. |
| 191 | +// It unlocks -> waits for the update -> locks again. |
| 192 | +func (c *WatchCtrl) Wait(ctx context.Context) error { |
| 193 | + updated := c.updated |
| 194 | + c.mu.Unlock() |
| 195 | + defer c.mu.Lock() |
| 196 | + select { |
| 197 | + case <-ctx.Done(): |
| 198 | + return ctx.Err() |
| 199 | + case <-updated: |
| 200 | + return nil |
| 201 | + } |
| 202 | +} |
| 203 | + |
| 204 | +// WaitUntil waits for the value in the watch to satisfy the predicate. |
| 205 | +// Should be called only after locking the watch, i.e. within Lock() iterator. |
| 206 | +// The predicate is evaluated under the lock, so it can access the guarded object. |
| 207 | +func (c *WatchCtrl) WaitUntil(ctx context.Context, pred func() bool) error { |
| 208 | + for !pred() { |
| 209 | + if err := c.Wait(ctx); err != nil { |
| 210 | + return err |
| 211 | + } |
| 212 | + } |
| 213 | + return nil |
| 214 | +} |
| 215 | + |
| 216 | +// Updated signals waiters that the value in the watch has been updated. |
| 217 | +func (c *WatchCtrl) Updated() { |
| 218 | + close(c.updated) |
| 219 | + c.updated = make(chan struct{}) |
| 220 | +} |
| 221 | + |
| 222 | +// Lock returns an iterator which locks the watch and yields the guarded object. |
| 223 | +// The watch is unlocked when the iterator is done. |
| 224 | +// If the watch is nil, the iterator is a no-op. |
| 225 | +// Additionally the WatchCtrl object is provided to the yield function: |
| 226 | +// * to unlock -> wait for the update -> lock again, call ctrl.Wait(ctx) |
| 227 | +// * to signal an update, call ctrl.Updated(). |
| 228 | +func (w *Watch[T]) Lock() iter.Seq2[T, *WatchCtrl] { |
| 229 | + return func(yield func(val T, ctrl *WatchCtrl) bool) { |
| 230 | + w.ctrl.mu.Lock() |
| 231 | + defer w.ctrl.mu.Unlock() |
| 232 | + _ = yield(w.val, &w.ctrl) |
| 233 | + } |
| 234 | +} |
0 commit comments