diff --git a/cli/beamable.common/Runtime/Promise.cs b/cli/beamable.common/Runtime/Promise.cs index ba51993f5b..08f8621a53 100644 --- a/cli/beamable.common/Runtime/Promise.cs +++ b/cli/beamable.common/Runtime/Promise.cs @@ -450,7 +450,19 @@ void ICriticalNotifyCompletion.UnsafeOnCompleted(Action continuation) void INotifyCompletion.OnCompleted(Action continuation) { - ((ICriticalNotifyCompletion)this).UnsafeOnCompleted(continuation); + // Mirror TaskAwaiter.OnCompleted: flow ExecutionContext so AsyncLocal + // values set by the caller survive the await and are restored on the + // thread that resolves the promise. + var capturedContext = System.Threading.ExecutionContext.Capture(); + if (capturedContext == null) + { + ((ICriticalNotifyCompletion)this).UnsafeOnCompleted(continuation); + return; + } + ((ICriticalNotifyCompletion)this).UnsafeOnCompleted(() => + { + System.Threading.ExecutionContext.Run(capturedContext, s => ((Action)s)(), continuation); + }); } /// @@ -1253,6 +1265,18 @@ public void SetStateMachine(IAsyncStateMachine machine) _stateMachine = machine; } + private void MoveNextWithCapturedContext(System.Threading.ExecutionContext capturedContext) + { + if (capturedContext == null) + { + _stateMachine.MoveNext(); + } + else + { + System.Threading.ExecutionContext.Run(capturedContext, s => ((IAsyncStateMachine)s).MoveNext(), _stateMachine); + } + } + public void AwaitOnCompleted( ref TAwaiter awaiter, ref TStateMachine stateMachine) where TAwaiter : INotifyCompletion @@ -1264,10 +1288,11 @@ public void AwaitOnCompleted( _stateMachine.SetStateMachine(stateMachine); } - awaiter.OnCompleted(() => - { - _stateMachine.MoveNext(); - }); + // Mirror AsyncTaskMethodBuilder: capture ExecutionContext at the await + // point and restore it before resuming the state machine so AsyncLocal + // values flow across the await. + var capturedContext = System.Threading.ExecutionContext.Capture(); + awaiter.OnCompleted(() => MoveNextWithCapturedContext(capturedContext)); } public void AwaitUnsafeOnCompleted( @@ -1275,7 +1300,14 @@ public void AwaitUnsafeOnCompleted( where TAwaiter : ICriticalNotifyCompletion where TStateMachine : IAsyncStateMachine { - AwaitOnCompleted(ref awaiter, ref stateMachine); + if (_stateMachine == null) + { + _stateMachine = stateMachine; + _stateMachine.SetStateMachine(stateMachine); + } + + var capturedContext = System.Threading.ExecutionContext.Capture(); + awaiter.UnsafeOnCompleted(() => MoveNextWithCapturedContext(capturedContext)); } public void Start(ref TStateMachine stateMachine) @@ -1313,6 +1345,18 @@ public void SetStateMachine(IAsyncStateMachine machine) _stateMachine = machine; } + private void MoveNextWithCapturedContext(System.Threading.ExecutionContext capturedContext) + { + if (capturedContext == null) + { + _stateMachine.MoveNext(); + } + else + { + System.Threading.ExecutionContext.Run(capturedContext, s => ((IAsyncStateMachine)s).MoveNext(), _stateMachine); + } + } + public void AwaitOnCompleted( ref TAwaiter awaiter, ref TStateMachine stateMachine) where TAwaiter : INotifyCompletion @@ -1324,10 +1368,8 @@ public void AwaitOnCompleted( _stateMachine.SetStateMachine(stateMachine); } - awaiter.OnCompleted(() => - { - _stateMachine.MoveNext(); - }); + var capturedContext = System.Threading.ExecutionContext.Capture(); + awaiter.OnCompleted(() => MoveNextWithCapturedContext(capturedContext)); } public void AwaitUnsafeOnCompleted( @@ -1335,7 +1377,14 @@ public void AwaitUnsafeOnCompleted( where TAwaiter : ICriticalNotifyCompletion where TStateMachine : IAsyncStateMachine { - AwaitOnCompleted(ref awaiter, ref stateMachine); + if (_stateMachine == null) + { + _stateMachine = stateMachine; + _stateMachine.SetStateMachine(stateMachine); + } + + var capturedContext = System.Threading.ExecutionContext.Capture(); + awaiter.UnsafeOnCompleted(() => MoveNextWithCapturedContext(capturedContext)); } public void Start(ref TStateMachine stateMachine)