diff --git a/src/BenchmarkDotNet.Analyzers/AsyncTypeShapes.cs b/src/BenchmarkDotNet.Analyzers/AsyncTypeShapes.cs index df239b0570..bce65bfab4 100644 --- a/src/BenchmarkDotNet.Analyzers/AsyncTypeShapes.cs +++ b/src/BenchmarkDotNet.Analyzers/AsyncTypeShapes.cs @@ -25,11 +25,15 @@ public static bool IsAsyncEnumerable(ITypeSymbol type, INamedTypeSymbol? asyncEn return true; } - if (TryFindPatternGetAsyncEnumerator(type) is { } enumeratorType - && HasPatternMoveNextAsync(enumeratorType) - && HasPublicInstanceProperty(enumeratorType, "Current")) + if (TryFindPatternGetAsyncEnumerator(type) is { } enumeratorType) { - return true; + // Roslyn commits to a found pattern `GetAsyncEnumerator` — if its return type doesn't + // satisfy the await-foreach enumerator shape it reports an error instead of falling back + // to `IAsyncEnumerable`, even when the source also implements the interface. We mirror + // that here so the analyzer's view of binding matches what `await foreach` would actually + // accept. + return HasPatternMoveNextAsync(enumeratorType) + && HasPublicInstanceProperty(enumeratorType, "Current"); } if (asyncEnumerableInterfaceSymbol != null) diff --git a/src/BenchmarkDotNet/Code/CodeGenerator.cs b/src/BenchmarkDotNet/Code/CodeGenerator.cs index 2b87797af6..675b1f94cf 100644 --- a/src/BenchmarkDotNet/Code/CodeGenerator.cs +++ b/src/BenchmarkDotNet/Code/CodeGenerator.cs @@ -113,14 +113,14 @@ private static DeclarationsProvider GetDeclarationsProvider(BenchmarkCase benchm { var method = benchmark.Descriptor.WorkloadMethod; - if (method.ReturnType.IsAwaitable()) + if (method.ReturnType.IsAwaitable(out var awaitableInfo)) { - return new AsyncDeclarationsProvider(benchmark); + return new AsyncDeclarationsProvider(benchmark, awaitableInfo.ResultType); } - if (method.ReturnType.IsAsyncEnumerable(out var itemType, out var enumeratorType, out var moveNextAwaitableType)) + if (method.ReturnType.IsAsyncEnumerable(out var asyncEnumerableInfo)) { - return new AsyncEnumerableDeclarationsProvider(benchmark, itemType, enumeratorType, moveNextAwaitableType); + return new AsyncEnumerableDeclarationsProvider(benchmark, asyncEnumerableInfo.ItemType, asyncEnumerableInfo.MoveNextAsyncMethod.ReturnType); } if (method.ReturnType == typeof(void) && method.HasAttribute()) diff --git a/src/BenchmarkDotNet/Code/DeclarationsProvider.cs b/src/BenchmarkDotNet/Code/DeclarationsProvider.cs index 88d9e4e610..040e1ad945 100644 --- a/src/BenchmarkDotNet/Code/DeclarationsProvider.cs +++ b/src/BenchmarkDotNet/Code/DeclarationsProvider.cs @@ -47,7 +47,7 @@ private void Replace(SmartStringBuilder smartStringBuilder, MethodInfo? method, userImpl = string.Empty; needsExplicitReturn = true; } - else if (method.ReturnType.IsAwaitable()) + else if (method.ReturnType.IsAwaitable(out _)) { modifier = "async"; userImpl = $"await {GetMethodPrefix(method)}.{method.Name}();"; @@ -97,7 +97,7 @@ protected string GetPassArgumentsDirect() ); } - internal class SyncDeclarationsProvider(BenchmarkCase benchmark) : DeclarationsProvider(benchmark) + internal sealed class SyncDeclarationsProvider(BenchmarkCase benchmark) : DeclarationsProvider(benchmark) { public override string[] GetExtraFields() => []; @@ -175,8 +175,13 @@ private string GetPassArguments() ); } - internal class AsyncDeclarationsProvider(BenchmarkCase benchmark) : DeclarationsProvider(benchmark) + internal abstract class AsyncDeclarationsProviderBase(BenchmarkCase benchmark) : DeclarationsProvider(benchmark) { + // Type used to drive the WorkloadCore builder selection. For ordinary awaitables it's the workload + // method's own return type, but `IAsyncEnumerable` has no GetAwaiter, so AsyncEnumerableDeclarationsProvider + // overrides this to expose the MoveNextAsync awaitable as a proxy. + protected virtual Type WorkloadAwaitableReturnType => Descriptor.WorkloadMethod.ReturnType; + public override string[] GetExtraFields() => [ $"public {typeof(WorkloadValueTaskSource).GetCorrectCSharpTypeName()} workloadContinuerAndValueTaskSource;", @@ -193,87 +198,6 @@ protected override string GetExtraGlobalSetupImpl() protected override string GetExtraGlobalCleanupImpl() => "this.__fieldsContainer.workloadContinuerAndValueTaskSource.Complete();"; - protected override SmartStringBuilder ReplaceCore(SmartStringBuilder smartStringBuilder) - { - // Unlike sync calls, async calls suffer from unrolling, so we multiply the invokeCount by the unroll factor and delegate the implementation to *NoUnroll methods. - int unrollFactor = Benchmark.Job.ResolveValue(RunMode.UnrollFactorCharacteristic, EnvironmentResolver.Instance); - string passArguments = GetPassArgumentsDirect(); - string workloadMethodCall = GetWorkloadMethodCall(passArguments); - bool hasAsyncMethodBuilderAttribute = TryGetAsyncMethodBuilderAttribute(out var asyncMethodBuilderAttribute); - Type workloadCoreReturnType = GetWorkloadCoreReturnType(hasAsyncMethodBuilderAttribute, Descriptor.WorkloadMethod.ReturnType); - string finalReturn = GetFinalReturn(workloadCoreReturnType); - string coreImpl = $$""" - private {{CoreReturnType}} OverheadActionUnroll({{CoreParameters}}) - { - return this.OverheadActionNoUnroll(invokeCount * {{unrollFactor}}, clock); - } - - private {{CoreReturnType}} OverheadActionNoUnroll({{CoreParameters}}) - { - {{StartClockSyncCode}} - while (--invokeCount >= 0) - { - this.__Overhead({{passArguments}}); - } - {{ReturnSyncCode}} - } - - private {{CoreReturnType}} WorkloadActionUnroll({{CoreParameters}}) - { - return this.WorkloadActionNoUnroll(invokeCount * {{unrollFactor}}, clock); - } - - private {{CoreReturnType}} WorkloadActionNoUnroll({{CoreParameters}}) - { - this.__fieldsContainer.invokeCount = invokeCount; - this.__fieldsContainer.clock = clock; - // The source is allocated and the workload loop started in __GlobalSetup, - // so this hot path is branchless and allocation-free. - return this.__fieldsContainer.workloadContinuerAndValueTaskSource.Continue(); - } - - private async void __StartWorkload() - { - await __WorkloadCore(); - } - - {{asyncMethodBuilderAttribute}} - private async {{workloadCoreReturnType.GetCorrectCSharpTypeName()}} __WorkloadCore() - { - try - { - if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.GetIsComplete()) - { - {{finalReturn}} - } - while (true) - { - {{typeof(StartedClock).GetCorrectCSharpTypeName()}} startedClock = {{typeof(ClockExtensions).GetCorrectCSharpTypeName()}}.Start(this.__fieldsContainer.clock); - while (--this.__fieldsContainer.invokeCount >= 0) - { - // Necessary because of error CS4004: Cannot await in an unsafe context - {{Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName()}} awaitable; - unsafe { awaitable = {{workloadMethodCall}} } - await awaitable; - } - if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed())) - { - {{finalReturn}} - } - } - } - catch (global::System.Exception e) - { - __fieldsContainer.workloadContinuerAndValueTaskSource.SetException(e); - {{finalReturn}} - } - } - """; - - return smartStringBuilder - .Replace("$CoreImpl$", coreImpl); - } - protected bool TryGetAsyncMethodBuilderAttribute(out string asyncMethodBuilderAttribute) { asyncMethodBuilderAttribute = string.Empty; @@ -322,32 +246,15 @@ protected static string GetFinalReturn(Type workloadCoreReturnType) ? "return;" : $"return default({finalReturnType.GetCorrectCSharpTypeName()});"; } - } - internal class AsyncEnumerableDeclarationsProvider(BenchmarkCase benchmark, Type itemType, Type enumeratorType, Type moveNextAwaitableType) : AsyncDeclarationsProvider(benchmark) - { protected override SmartStringBuilder ReplaceCore(SmartStringBuilder smartStringBuilder) { + // Unlike sync calls, async calls suffer from unrolling, so we multiply the invokeCount by the unroll factor and delegate the implementation to *NoUnroll methods. int unrollFactor = Benchmark.Job.ResolveValue(RunMode.UnrollFactorCharacteristic, EnvironmentResolver.Instance); string passArguments = GetPassArgumentsDirect(); string workloadMethodCall = GetWorkloadMethodCall(passArguments); - string itemTypeName = itemType.GetCorrectCSharpTypeName(); - string enumerableTypeName = Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName(); - string enumeratorTypeName = enumeratorType.GetCorrectCSharpTypeName(); - // We hand-roll the `await foreach` desugaring (explicit GetAsyncEnumerator + while-loop) instead - // of using the C# `await foreach` keyword: that keeps the IL byte-for-byte aligned with - // AsyncEnumerableCoreEmitter, which doesn't wrap the iteration in the try/catch + wrap field - // pattern Roslyn emits for the keyword form. - string disposeAsyncCall = ResolveDisposeAsync() is { } disposeAsyncMethod - ? $"await enumerator.{disposeAsyncMethod.Name}();" - : string.Empty; - // IAsyncEnumerable has no GetAwaiter, so its own return type can't drive the WorkloadCore - // builder. Use MoveNextAsync's return type as the proxy and feed it through the same resolver - // as the awaitable path: any result the awaitable produces (typically `bool`) is discarded by - // `__StartWorkload`'s `await`, and `[AsyncCallerType]` / `[AsyncMethodBuilder]` overrides on - // the workload method still apply. bool hasAsyncMethodBuilderAttribute = TryGetAsyncMethodBuilderAttribute(out var asyncMethodBuilderAttribute); - Type workloadCoreReturnType = GetWorkloadCoreReturnType(hasAsyncMethodBuilderAttribute, moveNextAwaitableType); + Type workloadCoreReturnType = GetWorkloadCoreReturnType(hasAsyncMethodBuilderAttribute, WorkloadAwaitableReturnType); string finalReturn = GetFinalReturn(workloadCoreReturnType); string coreImpl = $$""" private {{CoreReturnType}} OverheadActionUnroll({{CoreParameters}}) @@ -383,7 +290,7 @@ private async void __StartWorkload() { await __WorkloadCore(); } - + {{asyncMethodBuilderAttribute}} private async {{workloadCoreReturnType.GetCorrectCSharpTypeName()}} __WorkloadCore() { @@ -395,23 +302,12 @@ private async void __StartWorkload() } while (true) { - {{itemTypeName}} lastItem = default({{itemTypeName}}); {{typeof(StartedClock).GetCorrectCSharpTypeName()}} startedClock = {{typeof(ClockExtensions).GetCorrectCSharpTypeName()}}.Start(this.__fieldsContainer.clock); while (--this.__fieldsContainer.invokeCount >= 0) { - // Necessary because of error CS4004: Cannot await in an unsafe context - {{enumerableTypeName}} enumerable; - unsafe { enumerable = {{workloadMethodCall}} } - {{enumeratorTypeName}} enumerator = enumerable.GetAsyncEnumerator(); - while (await enumerator.MoveNextAsync()) - { - lastItem = enumerator.Current; - } - {{disposeAsyncCall}} + {{GetCallAndConsumeImpl(workloadMethodCall)}} } - {{typeof(ClockSpan).GetCorrectCSharpTypeName()}} elapsed = startedClock.GetElapsed(); - {{typeof(DeadCodeEliminationHelper).GetCorrectCSharpTypeName()}}.KeepAliveWithoutBoxing(lastItem); - if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(elapsed)) + if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed())) { {{finalReturn}} } @@ -429,24 +325,51 @@ private async void __StartWorkload() .Replace("$CoreImpl$", coreImpl); } - // Roslyn's `await foreach` resolution: prefer a public instance DisposeAsync with all-optional - // params whose awaiter's GetResult returns void; otherwise fall back to IAsyncDisposable. - // Returns null if neither shape matches, in which case the template skips the dispose call. - private MethodInfo? ResolveDisposeAsync() + protected abstract string GetCallAndConsumeImpl(string workloadMethodCall); + } + + internal class AsyncDeclarationsProvider(BenchmarkCase benchmark, Type resultType) : AsyncDeclarationsProviderBase(benchmark) + { + protected override string GetCallAndConsumeImpl(string workloadMethodCall) { - var disposeAsyncMethod = enumeratorType - .GetMethods(BindingFlags.Public | BindingFlags.Instance) - .FirstOrDefault(m => m.Name == nameof(IAsyncDisposable.DisposeAsync) - && m.GetParameters().All(p => p.IsOptional) - && m.ReturnType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance) - ?.ReturnType - .GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance) - ?.ReturnType == typeof(void)); - if (disposeAsyncMethod is not null) - return disposeAsyncMethod; - if (typeof(IAsyncDisposable).IsAssignableFrom(enumeratorType)) - return typeof(IAsyncDisposable).GetMethod(nameof(IAsyncDisposable.DisposeAsync)); - return null; + string awaitStatement; + if (resultType == typeof(void)) + { + awaitStatement = "await awaitable;"; + } + else + { + var resultTypeName = resultType.GetCorrectCSharpTypeName(); + awaitStatement = $""" + {resultTypeName} result = await awaitable; + {typeof(DeadCodeEliminationHelper).GetCorrectCSharpTypeName()}.KeepAliveWithoutBoxing<{resultTypeName}>(in result); + """; + } + return $$""" + // Necessary because of error CS4004: Cannot await in an unsafe context + {{Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName()}} awaitable; + unsafe { awaitable = {{workloadMethodCall}} } + {{awaitStatement}} + """; + } + } + + internal class AsyncEnumerableDeclarationsProvider(BenchmarkCase benchmark, Type itemType, Type moveNextAwaitableType) : AsyncDeclarationsProviderBase(benchmark) + { + protected override Type WorkloadAwaitableReturnType => moveNextAwaitableType; + + protected override string GetCallAndConsumeImpl(string workloadMethodCall) + { + string itemTypeName = itemType.GetCorrectCSharpTypeName(); + return $$""" + // Necessary because of error CS4004: Cannot await in an unsafe context + {{Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName()}} enumerable; + unsafe { enumerable = {{workloadMethodCall}} } + await foreach ({{itemTypeName}} item in enumerable) + { + {{typeof(DeadCodeEliminationHelper).GetCorrectCSharpTypeName()}}.KeepAliveWithoutBoxing<{{itemTypeName}}>(in item); + } + """; } } } \ No newline at end of file diff --git a/src/BenchmarkDotNet/Engines/Consumer.cs b/src/BenchmarkDotNet/Engines/Consumer.cs index 03aca7af83..bf3b3bbfb5 100644 --- a/src/BenchmarkDotNet/Engines/Consumer.cs +++ b/src/BenchmarkDotNet/Engines/Consumer.cs @@ -146,7 +146,7 @@ public void Consume(in T value) else if (default(T) == null && !typeof(T).IsValueType) Consume((object?)value); else - DeadCodeEliminationHelper.KeepAliveWithoutBoxingReadonly(value); // non-primitive and nullable value types + DeadCodeEliminationHelper.KeepAliveWithoutBoxing(in value); // non-primitive and nullable value types } } } \ No newline at end of file diff --git a/src/BenchmarkDotNet/Engines/DeadCodeEliminationHelper.cs b/src/BenchmarkDotNet/Engines/DeadCodeEliminationHelper.cs index 28bdee25d2..0288894851 100644 --- a/src/BenchmarkDotNet/Engines/DeadCodeEliminationHelper.cs +++ b/src/BenchmarkDotNet/Engines/DeadCodeEliminationHelper.cs @@ -1,33 +1,30 @@ using JetBrains.Annotations; using System.Runtime.CompilerServices; -namespace BenchmarkDotNet.Engines -{ - public static class DeadCodeEliminationHelper - { - /// - /// This method can't get inlined, so any value send to it - /// will not get eliminated by the dead code elimination - /// - [MethodImpl(MethodImplOptions.NoInlining)] - [UsedImplicitly] // Used in generated benchmarks - public static void KeepAliveWithoutBoxing(T value) { } +namespace BenchmarkDotNet.Engines; - /// - /// This method can't get inlined, so any value send to it - /// will not get eliminated by the dead code elimination - /// - [MethodImpl(MethodImplOptions.NoInlining)] - [UsedImplicitly] // Used in generated benchmarks - public static void KeepAliveWithoutBoxing(ref T value) { } +[UsedImplicitly] +public static class DeadCodeEliminationHelper +{ + /// + /// This method can't get inlined, so any value send to it + /// will not get eliminated by the dead code elimination + /// + [MethodImpl(MethodImplOptions.NoInlining)] + public static void KeepAliveWithoutBoxing(T value) +#if NET9_0_OR_GREATER + where T : allows ref struct +#endif + { } - /// - /// This method can't get inlined, so any value send to it - /// will not get eliminated by the dead code elimination - /// it's not called KeepAliveWithoutBoxing because compiler would not be able to diff `ref` and `in` - /// - [MethodImpl(MethodImplOptions.NoInlining)] - [UsedImplicitly] // Used in generated benchmarks - public static void KeepAliveWithoutBoxingReadonly(in T value) { } - } + /// + /// This method can't get inlined, so any value send to it + /// will not get eliminated by the dead code elimination + /// + [MethodImpl(MethodImplOptions.NoInlining)] + public static void KeepAliveWithoutBoxing(in T value) +#if NET9_0_OR_GREATER + where T : allows ref struct +#endif + { } } \ No newline at end of file diff --git a/src/BenchmarkDotNet/Extensions/ReflectionExtensions.cs b/src/BenchmarkDotNet/Extensions/ReflectionExtensions.cs index 7386cef78b..4f2eaf0dc4 100644 --- a/src/BenchmarkDotNet/Extensions/ReflectionExtensions.cs +++ b/src/BenchmarkDotNet/Extensions/ReflectionExtensions.cs @@ -252,86 +252,102 @@ internal static bool IsByRefLike(this Type type) => type.IsByRefLike; #endif - internal static bool IsAwaitable(this Type type) + internal static bool IsAwaitable(this Type type, [NotNullWhen(true)] out AwaitableInfo? info) { // This does not handle await extension. - var awaiterType = type + var getAwaiterMethod = type .GetMethods(BindingFlags.Public | BindingFlags.Instance) - .FirstOrDefault(m => m.Name == nameof(Task.GetAwaiter) && m.GetParameters().Length == 0) - ?.ReturnType; - if (awaiterType is null) + .FirstOrDefault(m => m.Name == nameof(Task.GetAwaiter) && m.GetParameters().Length == 0); + if (getAwaiterMethod is null) { + info = null; return false; } - if (!awaiterType + var awaiterType = getAwaiterMethod.ReturnType; + var getResultMethod = awaiterType .GetMethods(BindingFlags.Public | BindingFlags.Instance) - .Any(m => m.Name == nameof(TaskAwaiter.GetResult) && m.GetParameters().Length == 0)) + .FirstOrDefault(m => m.Name == nameof(TaskAwaiter.GetResult) && m.GetParameters().Length == 0); + var isCompletedProperty = awaiterType.GetProperty(nameof(TaskAwaiter.IsCompleted), BindingFlags.Public | BindingFlags.Instance); + if (getResultMethod is null + || isCompletedProperty?.PropertyType != typeof(bool) + || !awaiterType.GetInterfaces().Any(t => typeof(INotifyCompletion).IsAssignableFrom(t))) { + info = null; return false; } - if (awaiterType.GetProperty(nameof(TaskAwaiter.IsCompleted), BindingFlags.Public | BindingFlags.Instance)?.PropertyType != typeof(bool)) - { - return false; - } - return awaiterType.GetInterfaces().Any(type => typeof(INotifyCompletion).IsAssignableFrom(type)); + info = new AwaitableInfo(awaiterType, getAwaiterMethod, getResultMethod, isCompletedProperty, getResultMethod.ReturnType); + return true; } - internal static bool IsAsyncEnumerable(this Type type, [NotNullWhen(true)] out Type? itemType, [NotNullWhen(true)] out Type? enumeratorType, [NotNullWhen(true)] out Type? moveNextAwaitableType) + internal static bool IsAsyncEnumerable(this Type type, [NotNullWhen(true)] out AsyncEnumerableInfo? info) { - // 1. If the type IS exactly IAsyncEnumerable, the interface's T is what `await foreach` will see. - // Also short-circuits the case where there is no public pattern method to find anyway - // (the interface declares it but on the type itself it's only accessible via interface dispatch). - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>)) - { - itemType = type.GetGenericArguments()[0]; - enumeratorType = typeof(IAsyncEnumerator<>).MakeGenericType(itemType); - moveNextAwaitableType = typeof(ValueTask); - return true; - } - // 2. Otherwise mirror the C# `await foreach` resolution: pattern-based lookup wins over the - // interface, so try a public instance GetAsyncEnumerator (with all parameters optional, e.g. - // `CancellationToken` defaulting to default) whose return type has a public instance - // MoveNextAsync awaitable-to-bool (also accepting all-optional params) and a public instance - // Current property. The element type comes from Current so it tracks what the compiler binds - // to, even if the type also explicitly implements IAsyncEnumerable for a different U. - // This does not handle extension GetAsyncEnumerator. - var getAsyncEnumeratorMethod = type + // 1. Pattern first: a public instance GetAsyncEnumerator with all-optional parameters whose + // return type has a public instance MoveNextAsync awaitable-to-bool (also accepting + // all-optional params) and a public instance Current property. Roslyn's `await foreach` + // binds to this in preference to the interface, so we mirror that order. The element type + // comes from Current so it tracks what the compiler binds to, even if the type also + // implements IAsyncEnumerable for a different U. (Extension GetAsyncEnumerator is not + // handled.) + // + // Note: when the type IS exactly IAsyncEnumerable, `GetMethods(Public|Instance)` returns + // the interface's own GetAsyncEnumerator, so this branch also handles that case naturally — + // we just flag it as interface dispatch via the conditional below. + var patternGetAsyncEnumerator = type .GetMethods(BindingFlags.Public | BindingFlags.Instance) .FirstOrDefault(m => m.Name == nameof(IAsyncEnumerable<>.GetAsyncEnumerator) && m.GetParameters().All(p => p.IsOptional)); - if (getAsyncEnumeratorMethod is not null) + if (patternGetAsyncEnumerator is not null) { - var patternEnumeratorType = getAsyncEnumeratorMethod.ReturnType; + var patternEnumeratorType = patternGetAsyncEnumerator.ReturnType; var moveNextAsyncMethod = patternEnumeratorType .GetMethods(BindingFlags.Public | BindingFlags.Instance) .FirstOrDefault(m => m.Name == nameof(IAsyncEnumerator<>.MoveNextAsync) && m.GetParameters().All(p => p.IsOptional)); - if (moveNextAsyncMethod?.ReturnType.IsAwaitable() == true - && moveNextAsyncMethod.ReturnType - .GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)!.ReturnType - .GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)?.ReturnType == typeof(bool) - && patternEnumeratorType.GetProperty(nameof(IAsyncEnumerator<>.Current), BindingFlags.Public | BindingFlags.Instance)?.GetMethod is { } currentGetter) + if (moveNextAsyncMethod?.ReturnType.IsAwaitable(out var moveNextAwaitable) == true + && moveNextAwaitable.ResultType == typeof(bool) + && patternEnumeratorType.GetProperty(nameof(IAsyncEnumerator<>.Current), BindingFlags.Public | BindingFlags.Instance) is { } currentProperty) { - itemType = currentGetter.ReturnType; - enumeratorType = patternEnumeratorType; - moveNextAwaitableType = moveNextAsyncMethod.ReturnType; + info = new AsyncEnumerableInfo( + currentProperty.PropertyType, + patternEnumeratorType, + patternGetAsyncEnumerator, + moveNextAsyncMethod, + moveNextAwaitable, + currentProperty, + IsInterfaceDispatch: type.IsInterface); return true; } + // A public pattern `GetAsyncEnumerator` was found but its return type doesn't satisfy + // the await-foreach enumerator pattern. Roslyn commits to the pattern method once it's + // found and reports an error rather than silently falling back to `IAsyncEnumerable`, + // so we reject here as well — even if the source also implements the interface. + info = null; + return false; } - // 3. Last fallback: the type only implements IAsyncEnumerable via the interface (typically - // an explicit interface implementation), with no matching public pattern method. + // 2. Fallback: no pattern method on the source — bind via the `IAsyncEnumerable` interface + // if the source implements it (typically an explicit interface implementation). foreach (var iface in type.GetInterfaces()) { if (iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>)) { - itemType = iface.GetGenericArguments()[0]; - enumeratorType = typeof(IAsyncEnumerator<>).MakeGenericType(itemType); - moveNextAwaitableType = typeof(ValueTask); + var ifaceItemType = iface.GetGenericArguments()[0]; + var ifaceEnumeratorType = typeof(IAsyncEnumerator<>).MakeGenericType(ifaceItemType); + var ifaceMoveNextAsync = ifaceEnumeratorType.GetMethod(nameof(IAsyncEnumerator<>.MoveNextAsync))!; + // `MoveNextAsync` on `IAsyncEnumerator` returns `ValueTask` which always + // satisfies the awaitable shape; pull the resolved `AwaitableInfo` from IsAwaitable + // rather than constructing it by hand. + ifaceMoveNextAsync.ReturnType.IsAwaitable(out var ifaceMoveNextAwaitable); + info = new AsyncEnumerableInfo( + ifaceItemType, + ifaceEnumeratorType, + iface.GetMethod(nameof(IAsyncEnumerable<>.GetAsyncEnumerator))!, + ifaceMoveNextAsync, + ifaceMoveNextAwaitable!, + ifaceEnumeratorType.GetProperty(nameof(IAsyncEnumerator<>.Current))!, + IsInterfaceDispatch: true); return true; } } - itemType = null; - enumeratorType = null; - moveNextAwaitableType = null; + info = null; return false; } @@ -342,4 +358,32 @@ internal static bool IsAsyncEnumerable(this Type type, [NotNullWhen(true)] out T internal static bool HasAsyncMethodBuilderAttribute(this MemberInfo memberInfo) => memberInfo.GetAsyncMethodBuilderAttribute() != null; } + + /// + /// Everything resolves while binding the + /// awaitable pattern — bundled so callers (emitter, codegen, validators) reuse the same lookups + /// instead of repeating the GetAwaiter/GetResult/IsCompleted reflection. + /// + internal sealed record AwaitableInfo( + Type AwaiterType, + MethodInfo GetAwaiterMethod, + MethodInfo GetResultMethod, + PropertyInfo IsCompletedProperty, + Type ResultType); + + /// + /// Everything resolves while binding the await-foreach + /// pattern — bundled so callers (emitter, codegen, validators) reuse the same lookups instead of + /// repeating the pattern-vs-interface discrimination and the Current-property search. DisposeAsync + /// is only needed by the emitter, so its resolution lives there to keep validator/codegen paths + /// from paying for a lookup they don't use. + /// + internal sealed record AsyncEnumerableInfo( + Type ItemType, + Type EnumeratorType, + MethodInfo GetAsyncEnumeratorMethod, + MethodInfo MoveNextAsyncMethod, + AwaitableInfo MoveNextAwaitable, + PropertyInfo CurrentProperty, + bool IsInterfaceDispatch); } \ No newline at end of file diff --git a/src/BenchmarkDotNet/Helpers/DynamicAwaitHelper.cs b/src/BenchmarkDotNet/Helpers/DynamicAwaitHelper.cs index 5162b8f17e..0523f3ab6e 100644 --- a/src/BenchmarkDotNet/Helpers/DynamicAwaitHelper.cs +++ b/src/BenchmarkDotNet/Helpers/DynamicAwaitHelper.cs @@ -1,3 +1,4 @@ +using BenchmarkDotNet.Extensions; using System.Reflection; using System.Runtime.CompilerServices; @@ -5,66 +6,62 @@ namespace BenchmarkDotNet.Helpers; internal static class DynamicAwaitHelper { - internal static async ValueTask<(bool hasResult, object? result)> AwaitResult(object value, Type declaredType) + internal static async ValueTask<(bool hasResult, object? result)> AwaitResult(object value, AwaitableInfo awaitableInfo) { - var getAwaiterMethod = declaredType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)!; - var awaiterType = getAwaiterMethod.ReturnType; - var getResultMethod = awaiterType.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)!; - var result = await new DynamicAwaitable(getAwaiterMethod, awaiterType, getResultMethod, value); - return (getResultMethod.ReturnType != typeof(void), result); + var result = await new DynamicAwaitable(awaitableInfo, value); + return (awaitableInfo.ResultType != typeof(void), result); } - internal static ValueTask DrainAsyncEnumerableAsync(object asyncEnumerable, Type declaredType) - => EnumerateCoreAsync(asyncEnumerable, declaredType, items: null); + internal static ValueTask DrainAsyncEnumerableAsync(object asyncEnumerable, AsyncEnumerableInfo enumerableInfo) + => EnumerateCoreAsync(asyncEnumerable, enumerableInfo, items: null); - internal static async ValueTask> ToListAsync(object asyncEnumerable, Type declaredType) + internal static async ValueTask> ToListAsync(object asyncEnumerable, AsyncEnumerableInfo enumerableInfo) { List items = []; - await EnumerateCoreAsync(asyncEnumerable, declaredType, items).ConfigureAwait(false); + await EnumerateCoreAsync(asyncEnumerable, enumerableInfo, items).ConfigureAwait(false); return items; } - private static async ValueTask EnumerateCoreAsync(object asyncEnumerable, Type declaredType, List? items) + private static async ValueTask EnumerateCoreAsync(object asyncEnumerable, AsyncEnumerableInfo enumerableInfo, List? items) { - var (getAsyncEnumeratorMethod, getAsyncEnumeratorArgs) = ResolveGetAsyncEnumerator(declaredType); - var enumerator = getAsyncEnumeratorMethod.Invoke(asyncEnumerable, getAsyncEnumeratorArgs)!; - - // Look up enumerator members via GetAsyncEnumerator's declared return type rather than the runtime - // type. For the interface path that's IAsyncEnumerator, whose interface methods dispatch virtually - // — important for compiler-generated async iterator state machines that implement MoveNextAsync / - // Current as explicit interface members and so don't surface them as public instance members on the - // runtime type. For the pattern path it's the concrete enumerator type with public members. - var enumeratorMemberType = getAsyncEnumeratorMethod.ReturnType; - - var moveNextAsyncMethod = enumeratorMemberType - .GetMethods(BindingFlags.Public | BindingFlags.Instance) - .FirstOrDefault(m => m.Name == nameof(IAsyncEnumerator<>.MoveNextAsync) && m.GetParameters().All(p => p.IsOptional)) - ?? throw new InvalidOperationException($"Type {enumeratorMemberType} does not expose a MoveNextAsync method."); - var moveNextAsyncArgs = GetDefaultArgs(moveNextAsyncMethod); - var currentProperty = enumeratorMemberType.GetProperty(nameof(IAsyncEnumerator<>.Current), BindingFlags.Public | BindingFlags.Instance) - ?? throw new InvalidOperationException($"Type {enumeratorMemberType} does not expose a Current property."); - // DisposeAsync is optional for the pattern. Prefer a public instance method on the declared enumerator - // type with all-optional parameters whose awaiter's GetResult returns void; otherwise fall back to the - // IAsyncDisposable interface implementation. - var disposeAsyncMethod = enumeratorMemberType - .GetMethods(BindingFlags.Public | BindingFlags.Instance) - .FirstOrDefault(m => m.Name == nameof(IAsyncDisposable.DisposeAsync) - && m.GetParameters().All(p => p.IsOptional) - && m.ReturnType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance) - ?.ReturnType - .GetMethod(nameof(System.Runtime.CompilerServices.TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance) - ?.ReturnType == typeof(void)) - ?? (typeof(IAsyncDisposable).IsAssignableFrom(enumeratorMemberType) - ? typeof(IAsyncDisposable).GetMethod(nameof(IAsyncDisposable.DisposeAsync)) - : null); + var getAsyncEnumeratorArgs = GetDefaultArgs(enumerableInfo.GetAsyncEnumeratorMethod); + var enumerator = enumerableInfo.GetAsyncEnumeratorMethod.Invoke(asyncEnumerable, getAsyncEnumeratorArgs)!; + + var moveNextAsyncArgs = GetDefaultArgs(enumerableInfo.MoveNextAsyncMethod); + var currentProperty = enumerableInfo.CurrentProperty; + var moveNextAwaitable = enumerableInfo.MoveNextAwaitable; + + // DisposeAsync is optional for the await-foreach pattern. Roslyn matches a public instance + // method named DisposeAsync whose parameters are all optional and whose return type satisfies + // the awaitable pattern with a void GetResult; otherwise it falls back to the IAsyncDisposable + // interface dispatch. + MethodInfo? disposeAsyncMethod = null; + AwaitableInfo? disposeAwaitableInfo = null; + foreach (var candidate in enumerableInfo.EnumeratorType.GetMethods(BindingFlags.Public | BindingFlags.Instance)) + { + if (candidate.Name == nameof(IAsyncDisposable.DisposeAsync) + && candidate.GetParameters().All(p => p.IsOptional) + && candidate.ReturnType.IsAwaitable(out var awaitable) + && awaitable.ResultType == typeof(void)) + { + disposeAsyncMethod = candidate; + disposeAwaitableInfo = awaitable; + break; + } + } + if (disposeAsyncMethod is null && typeof(IAsyncDisposable).IsAssignableFrom(enumerableInfo.EnumeratorType)) + { + disposeAsyncMethod = typeof(IAsyncDisposable).GetMethod(nameof(IAsyncDisposable.DisposeAsync))!; + disposeAsyncMethod.ReturnType.IsAwaitable(out disposeAwaitableInfo); + } var disposeAsyncArgs = disposeAsyncMethod is null ? null : GetDefaultArgs(disposeAsyncMethod); try { while (true) { - var moveNextResult = moveNextAsyncMethod.Invoke(enumerator, moveNextAsyncArgs); - bool hasMore = (bool)(await AwaitDynamicAsync(moveNextAsyncMethod.ReturnType, moveNextResult!).ConfigureAwait(false))!; + var moveNextResult = enumerableInfo.MoveNextAsyncMethod.Invoke(enumerator, moveNextAsyncArgs); + bool hasMore = (bool)(await new DynamicAwaitable(moveNextAwaitable, moveNextResult!))!; if (!hasMore) { break; @@ -79,38 +76,12 @@ private static async ValueTask EnumerateCoreAsync(object asyncEnumerable, Type d var disposeResult = disposeAsyncMethod.Invoke(enumerator, disposeAsyncArgs); if (disposeResult != null) { - await AwaitDynamicAsync(disposeAsyncMethod.ReturnType, disposeResult).ConfigureAwait(false); + await new DynamicAwaitable(disposeAwaitableInfo!, disposeResult); } } } } - private static (MethodInfo method, object?[] args) ResolveGetAsyncEnumerator(Type enumerableType) - { - // Mirror IsAsyncEnumerable's precedence: exact IAsyncEnumerable, then pattern, then interface fallback. - if (enumerableType.IsGenericType && enumerableType.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>)) - { - var method = enumerableType.GetMethod(nameof(IAsyncEnumerable<>.GetAsyncEnumerator))!; - return (method, [CancellationToken.None]); - } - var pattern = enumerableType - .GetMethods(BindingFlags.Public | BindingFlags.Instance) - .FirstOrDefault(m => m.Name == nameof(IAsyncEnumerable<>.GetAsyncEnumerator) && m.GetParameters().All(p => p.IsOptional)); - if (pattern != null) - { - return (pattern, GetDefaultArgs(pattern)); - } - foreach (var iface in enumerableType.GetInterfaces()) - { - if (iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>)) - { - var method = iface.GetMethod(nameof(IAsyncEnumerable<>.GetAsyncEnumerator))!; - return (method, [CancellationToken.None]); - } - } - throw new InvalidOperationException($"Type {enumerableType} is not an async enumerable."); - } - private static object?[] GetDefaultArgs(MethodInfo method) { var parameters = method.GetParameters(); @@ -126,27 +97,19 @@ private static (MethodInfo method, object?[] args) ResolveGetAsyncEnumerator(Typ return args; } - private static async ValueTask AwaitDynamicAsync(Type awaitableType, object awaitable) - { - var getAwaiterMethod = awaitableType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)!; - var awaiterType = getAwaiterMethod.ReturnType; - var getResultMethod = awaiterType.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)!; - return await new DynamicAwaitable(getAwaiterMethod, awaiterType, getResultMethod, awaitable); - } - - private readonly struct DynamicAwaitable(MethodInfo getAwaiterMethod, Type awaiterType, MethodInfo getResultMethod, object awaitable) + private readonly struct DynamicAwaitable(AwaitableInfo awaitableInfo, object awaitable) { public DynamicAwaiter GetAwaiter() - => new(awaiterType, getResultMethod, getAwaiterMethod.Invoke(awaitable, null)); + => new(awaitableInfo, awaitableInfo.GetAwaiterMethod.Invoke(awaitable, null)); } - private readonly struct DynamicAwaiter(Type awaiterType, MethodInfo getResultMethod, object? awaiter) : ICriticalNotifyCompletion + private readonly struct DynamicAwaiter(AwaitableInfo awaitableInfo, object? awaiter) : ICriticalNotifyCompletion { public bool IsCompleted - => awaiterType.GetProperty(nameof(TaskAwaiter.IsCompleted), BindingFlags.Public | BindingFlags.Instance)!.GetMethod!.Invoke(awaiter, null) is true; + => awaitableInfo.IsCompletedProperty.GetMethod!.Invoke(awaiter, null) is true; public object? GetResult() - => getResultMethod.Invoke(awaiter, null); + => awaitableInfo.GetResultMethod.Invoke(awaiter, null); public void OnCompleted(Action continuation) => OnCompletedCore(typeof(INotifyCompletion), nameof(INotifyCompletion.OnCompleted), continuation); @@ -157,7 +120,7 @@ public void UnsafeOnCompleted(Action continuation) private void OnCompletedCore(Type interfaceType, string methodName, Action continuation) { var onCompletedMethod = interfaceType.GetMethod(methodName); - var map = awaiterType.GetInterfaceMap(interfaceType); + var map = awaitableInfo.AwaiterType.GetInterfaceMap(interfaceType); for (int i = 0; i < map.InterfaceMethods.Length; i++) { diff --git a/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/AsyncCoreEmitter.cs b/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/AsyncCoreEmitter.cs index cb3cf9a28a..55f05d5308 100644 --- a/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/AsyncCoreEmitter.cs +++ b/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/AsyncCoreEmitter.cs @@ -1,4 +1,5 @@ using BenchmarkDotNet.Engines; +using BenchmarkDotNet.Extensions; using BenchmarkDotNet.Helpers.Reflection.Emit; using BenchmarkDotNet.Running; using Perfolizer.Horology; @@ -12,7 +13,7 @@ namespace BenchmarkDotNet.Toolchains.InProcess.Emit.Implementation; partial class RunnableEmitter { // TODO: update this to support runtime-async. - private sealed class AsyncCoreEmitter(BuildPartition buildPartition, ModuleBuilder moduleBuilder, BenchmarkBuildInfo benchmark) : AsyncCoreEmitterBase(buildPartition, moduleBuilder, benchmark) + private sealed class AsyncCoreEmitter(BuildPartition buildPartition, ModuleBuilder moduleBuilder, BenchmarkBuildInfo benchmark, AwaitableInfo awaitableInfo) : AsyncCoreEmitterBase(buildPartition, moduleBuilder, benchmark) { protected override void EmitWorkloadCore() { @@ -32,9 +33,7 @@ protected override void EmitWorkloadCore() ); var benchmarkAwaiterField = asyncStateMachineTypeBuilder.DefineField( "<>u__2", - Descriptor.WorkloadMethod.ReturnType - .GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)! - .ReturnType, + awaitableInfo.AwaiterType, FieldAttributes.Private ); EmitMoveNextImpl(); @@ -46,11 +45,21 @@ protected override void EmitWorkloadCore() void EmitMoveNextImpl() { + var resultType = awaitableInfo.ResultType; var isCompleteAwaiterLocal = ilBuilder.DeclareLocal(typeof(ValueTaskAwaiter)); var isCompleteAwaitableLocal = ilBuilder.DeclareLocal(typeof(ValueTask)); + // The value-type awaitable spill local (Roslyn declares one for ValueTask et al. so it + // can take its address for GetAwaiter — reference-type awaitables stay on the stack). var benchmarkAwaitableLocal = Descriptor.WorkloadMethod.ReturnType.IsValueType ? ilBuilder.DeclareLocal(Descriptor.WorkloadMethod.ReturnType) : null; + // Source local for `T result = await awaitable;` — declared only when the awaiter's + // GetResult returns non-void, in which case the template captures the value and pipes + // it through DeadCodeEliminationHelper so the JIT can't elide the producer's work. + // Roslyn places this AFTER the (optional) awaitable spill and BEFORE the awaiter temp. + var resultLocal = resultType == typeof(void) + ? null + : ilBuilder.DeclareLocal(resultType); var benchmarkAwaiterLocal = ilBuilder.DeclareLocal(benchmarkAwaiterField.FieldType); var invokeCountLocal = ilBuilder.DeclareLocal(typeof(long)); var exceptionLocal = ilBuilder.DeclareLocal(typeof(Exception)); @@ -162,18 +171,18 @@ void EmitMoveNextImpl() ilBuilder.Emit(OpCodes.Call, Descriptor.WorkloadMethod); if (benchmarkAwaitableLocal == null) { - ilBuilder.Emit(OpCodes.Callvirt, Descriptor.WorkloadMethod.ReturnType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)!); + ilBuilder.Emit(OpCodes.Callvirt, awaitableInfo.GetAwaiterMethod); } else { ilBuilder.EmitStloc(benchmarkAwaitableLocal); ilBuilder.EmitLdloca(benchmarkAwaitableLocal); - ilBuilder.Emit(OpCodes.Call, Descriptor.WorkloadMethod.ReturnType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)!); + ilBuilder.Emit(OpCodes.Call, awaitableInfo.GetAwaiterMethod); } // if (awaiter.IsCompleted) goto benchmarkContinuationGetResultLabel; ilBuilder.EmitStloc(benchmarkAwaiterLocal); ilBuilder.EmitLdloca(benchmarkAwaiterLocal); - ilBuilder.Emit(OpCodes.Call, benchmarkAwaiterField.FieldType.GetProperty(nameof(TaskAwaiter.IsCompleted), BindingFlags.Public | BindingFlags.Instance)!.GetMethod!); + ilBuilder.Emit(OpCodes.Call, awaitableInfo.IsCompletedProperty.GetMethod!); ilBuilder.Emit(OpCodes.Brtrue, benchmarkContinuationGetResultLabel); // state = 1; <>u__2 = awaiter; ilBuilder.Emit(OpCodes.Ldarg_0); @@ -212,11 +221,21 @@ void EmitMoveNextImpl() // --- Benchmark GetResult --- ilBuilder.MarkLabel(benchmarkContinuationGetResultLabel); ilBuilder.EmitLdloca(benchmarkAwaiterLocal); - var benchmarkAwaiterGetResultMethod = benchmarkAwaiterField.FieldType.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)!; - ilBuilder.Emit(OpCodes.Call, benchmarkAwaiterGetResultMethod); - if (benchmarkAwaiterGetResultMethod.ReturnType != typeof(void)) + ilBuilder.Emit(OpCodes.Call, awaitableInfo.GetResultMethod); + if (resultLocal is not null) { - ilBuilder.Emit(OpCodes.Pop); + // Mirror the template's `T result = await awaitable; KeepAliveWithoutBoxing(in result);`: + // store the GetResult value into a local and pass it to the non-inlined sink so the + // JIT can't elide whatever produced it. + ilBuilder.EmitStloc(resultLocal); + ilBuilder.EmitLdloca(resultLocal); + var keepAliveInMethod = typeof(DeadCodeEliminationHelper) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .First(m => m.Name == nameof(DeadCodeEliminationHelper.KeepAliveWithoutBoxing) + && m.GetParameters().Length == 1 + && m.GetParameters()[0].ParameterType.IsByRef) + .MakeGenericMethod(resultType); + ilBuilder.Emit(OpCodes.Call, keepAliveInMethod); } // --- Benchmark loop: if (--invokeCount >= 0) goto callBenchmarkLabel --- diff --git a/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/AsyncEnumerableCoreEmitter.cs b/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/AsyncEnumerableCoreEmitter.cs index 7819310904..b11c5cc198 100644 --- a/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/AsyncEnumerableCoreEmitter.cs +++ b/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/AsyncEnumerableCoreEmitter.cs @@ -14,15 +14,36 @@ partial class RunnableEmitter { private sealed class AsyncEnumerableCoreEmitter : AsyncCoreEmitterBase { - private readonly Type itemType; - private readonly EnumerableInfo enumerableInfo; + private readonly AsyncEnumerableInfo enumerableInfo; + private readonly MethodInfo? disposeAsyncMethod; + private readonly AwaitableInfo? disposeAwaitableInfo; - public AsyncEnumerableCoreEmitter(BuildPartition buildPartition, ModuleBuilder moduleBuilder, BenchmarkBuildInfo benchmark) + public AsyncEnumerableCoreEmitter(BuildPartition buildPartition, ModuleBuilder moduleBuilder, BenchmarkBuildInfo benchmark, AsyncEnumerableInfo enumerableInfo) : base(buildPartition, moduleBuilder, benchmark) { - var workloadReturnType = benchmark.BenchmarkCase.Descriptor.WorkloadMethod.ReturnType; - enumerableInfo = ResolveEnumerableInfo(workloadReturnType); - itemType = enumerableInfo.ItemType; + this.enumerableInfo = enumerableInfo; + + // DisposeAsync is optional for the await-foreach pattern. Roslyn matches a public instance + // method named DisposeAsync whose parameters are all optional and whose return type satisfies + // the awaitable pattern with a void GetResult; otherwise it falls back to the IAsyncDisposable + // interface dispatch. + foreach (var candidate in enumerableInfo.EnumeratorType.GetMethods(BindingFlags.Public | BindingFlags.Instance)) + { + if (candidate.Name == nameof(IAsyncDisposable.DisposeAsync) + && candidate.GetParameters().All(p => p.IsOptional) + && candidate.ReturnType.IsAwaitable(out var awaitable) + && awaitable.ResultType == typeof(void)) + { + disposeAsyncMethod = candidate; + disposeAwaitableInfo = awaitable; + break; + } + } + if (disposeAsyncMethod is null && typeof(IAsyncDisposable).IsAssignableFrom(enumerableInfo.EnumeratorType)) + { + disposeAsyncMethod = typeof(IAsyncDisposable).GetMethod(nameof(IAsyncDisposable.DisposeAsync))!; + disposeAsyncMethod.ReturnType.IsAwaitable(out disposeAwaitableInfo); + } } protected override void EmitWorkloadCore() @@ -35,26 +56,33 @@ protected override void EmitWorkloadCore() // Field declaration order matches Roslyn (and `<>u__N` numbering follows declaration order): // 1) <>u__1 — first awaiter type to appear in the IL (always ValueTaskAwaiter here, used // by GetIsComplete and SetResultAndGetIsComplete); - // 2) hoisted user locals in source declaration order (lastItem, startedClock); - // 3) the synthetic enumerator local (named in the template so it's 5__N, not - // <>7__wrap{N}); - // 4) <>u__N for any subsequent awaiter type that doesn't already match an existing field — + // 2) hoisted user locals in source declaration order (startedClock); + // 3) <>7__wrap2 — synthesized hoisted enumerator (Roslyn names it `<>7__wrap2` because the + // `await foreach` lowering captures the GetAsyncEnumerator result as an unnamed wrap); + // 4) when DisposeAsync exists, <>7__wrap3 (captured catch object) and <>7__wrap4 (unused state + // discriminator that Roslyn emits as part of its try/finally lowering); + // 5) <>u__N for any subsequent awaiter type that doesn't already match an existing field — // Roslyn dedupes by type, so we reuse <>u__1 when MoveNextAsync also returns ValueTaskAwaiter. var workloadContinuerAwaiterField = asyncStateMachineTypeBuilder.DefineField("<>u__1", typeof(ValueTaskAwaiter), FieldAttributes.Private); - var lastItemField = asyncStateMachineTypeBuilder.DefineField("5__2", itemType, FieldAttributes.Private); - var startedClockField = asyncStateMachineTypeBuilder.DefineField("5__3", typeof(StartedClock), FieldAttributes.Private); - var enumeratorField = asyncStateMachineTypeBuilder.DefineField("5__4", enumerableInfo.EnumeratorType, FieldAttributes.Private); + var startedClockField = asyncStateMachineTypeBuilder.DefineField("5__2", typeof(StartedClock), FieldAttributes.Private); + var enumeratorField = asyncStateMachineTypeBuilder.DefineField("<>7__wrap2", enumerableInfo.EnumeratorType, FieldAttributes.Private); + var capturedExceptionField = disposeAsyncMethod is null + ? null + : asyncStateMachineTypeBuilder.DefineField("<>7__wrap3", typeof(object), FieldAttributes.Private); + var disposeDiscriminatorField = disposeAsyncMethod is null + ? null + : asyncStateMachineTypeBuilder.DefineField("<>7__wrap4", typeof(int), FieldAttributes.Private); int nextAwaiterOrdinal = 2; - var moveNextAwaiterField = enumerableInfo.MoveNextAwaiterType == workloadContinuerAwaiterField.FieldType + var moveNextAwaiterField = enumerableInfo.MoveNextAwaitable.AwaiterType == workloadContinuerAwaiterField.FieldType ? workloadContinuerAwaiterField - : asyncStateMachineTypeBuilder.DefineField($"<>u__{nextAwaiterOrdinal++}", enumerableInfo.MoveNextAwaiterType, FieldAttributes.Private); - var disposeAwaiterField = enumerableInfo.DisposeAsyncMethod is null + : asyncStateMachineTypeBuilder.DefineField($"<>u__{nextAwaiterOrdinal++}", enumerableInfo.MoveNextAwaitable.AwaiterType, FieldAttributes.Private); + var disposeAwaiterField = disposeAsyncMethod is null ? null - : enumerableInfo.DisposeAwaiterType == workloadContinuerAwaiterField.FieldType + : disposeAwaitableInfo!.AwaiterType == workloadContinuerAwaiterField.FieldType ? workloadContinuerAwaiterField - : enumerableInfo.DisposeAwaiterType == moveNextAwaiterField.FieldType + : disposeAwaitableInfo!.AwaiterType == moveNextAwaiterField.FieldType ? moveNextAwaiterField - : asyncStateMachineTypeBuilder.DefineField($"<>u__{nextAwaiterOrdinal++}", enumerableInfo.DisposeAwaiterType!, FieldAttributes.Private); + : asyncStateMachineTypeBuilder.DefineField($"<>u__{nextAwaiterOrdinal++}", disposeAwaitableInfo!.AwaiterType!, FieldAttributes.Private); EmitMoveNextImpl(); var asyncStateMachineType = CompleteAsyncStateMachineType(asyncMethodBuilderType, builderInfo); @@ -73,31 +101,36 @@ void EmitMoveNextImpl() const int StateGetIsComplete = 0; const int StateMoveNextAsync = 1; int StateDisposeAsync = 2; - int StateSetResult = enumerableInfo.DisposeAsyncMethod is null ? 2 : 3; + int StateSetResult = disposeAsyncMethod is null ? 2 : 3; // Local declaration order matches Roslyn so the EmitsSameIL var-by-var diff lines up: // 1) awaiter then awaitable temps for the first await (Roslyn declares awaiter before // awaitable for synthetic `await expr` with no named result); - // 2) ClockSpan elapsed temp for the SetResult arg — Roslyn forces a local because the - // template puts a DCE.KeepAliveWithoutBoxing call between the elapsed read and the - // await suspension; - // 3) the named source local `enumerable` (always — Roslyn declares one even for reference + // 2) the named source local `enumerable` (always — Roslyn declares one even for reference // types because `unsafe { enumerable = ... }` is a separate assignment statement); - // 4) optional default-value materializers for GetAsyncEnumerator's optional params + // 3) optional default-value materializers for GetAsyncEnumerator's optional params // (typically a CancellationToken local pre-declared here so EmitDefaultArgsForOptionalParameters // can reuse it rather than declaring a new local mid-emit and shifting indexes); - // 5) MoveNext awaiter/awaitable temps — REUSED from V_3/V_4 when types match (Roslyn dedupes + // 4) the named source local `item` for `var item = enumerator.Current;` — Roslyn declares + // it after `enumerable` (and after defaultArg materializers) but BEFORE the MoveNext + // awaiter/awaitable temps; + // 5) caught-exception `object` local — only when DisposeAsync exists (the inner catch(object) + // handler that captures any exception thrown by the iteration uses this slot, and it is + // reused after the dispose await for the rethrow logic that reads <>7__wrap3); + // 6) MoveNext awaiter/awaitable temps — REUSED from V_3/V_4 when types match (Roslyn dedupes // by type), otherwise new locals; - // 6) DisposeAsync awaiter/awaitable temps (if applicable); - // 7) the loop-decrement long, then the catch-block Exception local. + // 7) DisposeAsync awaiter/awaitable temps (if applicable); + // 8) the loop-decrement long, then the catch-block Exception local. + // The template inlines `startedClock.GetElapsed()` directly into the SetResult call site, + // so no ClockSpan local is declared. var isCompleteAwaiterLocal = ilBuilder.DeclareLocal(typeof(ValueTaskAwaiter)); var isCompleteAwaitableLocal = ilBuilder.DeclareLocal(typeof(ValueTask)); - var elapsedLocal = ilBuilder.DeclareLocal(typeof(ClockSpan)); - var enumerableLocal = ilBuilder.DeclareLocal(enumerableInfo.WorkloadReturnType); + var enumerableLocal = ilBuilder.DeclareLocal(Descriptor.WorkloadMethod.ReturnType); var defaultArgLocals = enumerableInfo.GetAsyncEnumeratorMethod .GetParameters() .DistinctBy(p => p.ParameterType) .ToDictionary(p => p.ParameterType, p => ilBuilder.DeclareLocal(p.ParameterType)); + var itemLocal = ilBuilder.DeclareLocal(enumerableInfo.ItemType); var moveNextAwaiterLocal = moveNextAwaiterField.FieldType == isCompleteAwaiterLocal.LocalType ? isCompleteAwaiterLocal : ilBuilder.DeclareLocal(moveNextAwaiterField.FieldType); @@ -106,15 +139,18 @@ void EmitMoveNextImpl() ? isCompleteAwaitableLocal : ilBuilder.DeclareLocal(enumerableInfo.MoveNextAsyncMethod.ReturnType)) : null; + var caughtObjectLocal = disposeAsyncMethod is null + ? null + : ilBuilder.DeclareLocal(typeof(object)); LocalBuilder? disposeAwaiterLocal = null; LocalBuilder? disposeAwaitableLocal = null; - if (enumerableInfo.DisposeAsyncMethod is not null) + if (disposeAsyncMethod is not null) { // Same awaiter-then-awaitable order as the GetIsComplete pattern (Roslyn always emits // the awaiter local first for synthetic awaits). disposeAwaiterLocal = ilBuilder.DeclareLocal(disposeAwaiterField!.FieldType); - disposeAwaitableLocal = enumerableInfo.DisposeAsyncMethod.ReturnType.IsValueType - ? ilBuilder.DeclareLocal(enumerableInfo.DisposeAsyncMethod.ReturnType) + disposeAwaitableLocal = disposeAsyncMethod.ReturnType.IsValueType + ? ilBuilder.DeclareLocal(disposeAsyncMethod.ReturnType) : null; } var invokeCountLocal = ilBuilder.DeclareLocal(typeof(long)); @@ -125,9 +161,18 @@ void EmitMoveNextImpl() var startClockLabel = ilBuilder.DefineLabel(); var callBenchmarkLabel = ilBuilder.DefineLabel(); var callBenchmarkLoopLabel = ilBuilder.DefineLabel(); + var loopBodyLabel = ilBuilder.DefineLabel(); var moveNextLoopLabel = ilBuilder.DefineLabel(); var moveNextContinuationLabel = ilBuilder.DefineLabel(); var moveNextGetResultLabel = ilBuilder.DefineLabel(); + // The following labels only matter when DisposeAsync exists — they're the await-foreach + // try/finally lowering's skeleton (state-1 lands at `state1TargetLabel` which is the nop + // right before the inner try opens; after the inner try-catch, control reaches + // `afterInnerTryLabel` and runs the dispose-then-rethrow sequence). + var state1TargetLabel = ilBuilder.DefineLabel(); + var afterInnerTryLabel = ilBuilder.DefineLabel(); + var skipRethrowLabel = ilBuilder.DefineLabel(); + var rethrowDispatchLabel = ilBuilder.DefineLabel(); var startDisposeLabel = ilBuilder.DefineLabel(); var disposeContinuationLabel = ilBuilder.DefineLabel(); var disposeGetResultLabel = ilBuilder.DefineLabel(); @@ -138,7 +183,7 @@ void EmitMoveNextImpl() // largest dispatch state (3 when DisposeAsync exists, 2 otherwise) — Roslyn's switch // emit pre-pushes it as part of bounds-check elimination scaffolding. ilBuilder.EmitLdloc(stateLocal); - ilBuilder.EmitLdc_I4(enumerableInfo.DisposeAsyncMethod is null ? 2 : 3); + ilBuilder.EmitLdc_I4(disposeAsyncMethod is null ? 2 : 3); ilBuilder.Emit(OpCodes.Pop); ilBuilder.Emit(OpCodes.Pop); ilBuilder.Emit(OpCodes.Nop); @@ -150,9 +195,15 @@ void EmitMoveNextImpl() // array is [state-0 → ..., state-1 → ..., state-2 → ..., state-3 → ...]; everything // outside the range falls through to the first-time path. ilBuilder.EmitLdloc(stateLocal); - ilBuilder.Emit(OpCodes.Switch, enumerableInfo.DisposeAsyncMethod is null + // When DisposeAsync exists Roslyn wraps the iteration in a nested try whose body + // contains the MoveNext resume point — but CIL forbids branching into a protected + // region from outside it, so state-1 targets a `nop` just *before* the inner try and + // falls through into it; a redispatch inside the try then jumps to the actual + // moveNextContinuation. Without DisposeAsync there is no nested try and state-1 + // targets the moveNext continuation directly. + ilBuilder.Emit(OpCodes.Switch, disposeAsyncMethod is null ? [getIsCompleteContinuationLabel, moveNextContinuationLabel, setResultContinuationLabel] - : [getIsCompleteContinuationLabel, moveNextContinuationLabel, disposeContinuationLabel, setResultContinuationLabel]); + : [getIsCompleteContinuationLabel, state1TargetLabel, disposeContinuationLabel, setResultContinuationLabel]); EmitGetIsCompleteAwait(StateGetIsComplete); @@ -166,11 +217,8 @@ void EmitMoveNextImpl() ilBuilder.MaybeEmitSetLocalToDefault(returnDefaultLocal); ilBuilder.Emit(OpCodes.Leave, endTryLabel); - // ===== while(true) { startClock → benchmark loop → DCE → SetResultAndGetIsComplete } ===== + // ===== while(true) { startClock → benchmark loop → SetResultAndGetIsComplete } ===== ilBuilder.MarkLabel(startClockLabel); - // lastItem = default; — reset accumulator per measurement - ilBuilder.Emit(OpCodes.Ldarg_0); - ilBuilder.EmitSetFieldToDefault(lastItemField); // startedClock = ClockExtensions.Start(clock); ilBuilder.Emit(OpCodes.Ldarg_0); ilBuilder.EmitLdloc(thisLocal!); @@ -180,12 +228,12 @@ void EmitMoveNextImpl() ilBuilder.Emit(OpCodes.Stfld, startedClockField); ilBuilder.Emit(OpCodes.Br, callBenchmarkLoopLabel); - // ===== Benchmark call: get enumerable, get enumerator, foreach loop, dispose ===== - // Note: there is intentionally no nested CIL try/catch around the foreach. In CIL you can't - // branch into a protected region from outside, and the state-1 (MoveNextAsync) resume needs - // to land in the foreach loop body which would be inside such a region. If the foreach - // throws, the outer SetException catch surfaces it as a benchmark failure (DisposeAsync - // doesn't run in that case — acceptable trade-off given the benchmark process exits anyway). + // ===== Benchmark call: get enumerable, get enumerator, await foreach, dispose ===== + // When DisposeAsync exists Roslyn lowers the `await foreach` into a nested try/catch + // (the catch captures into <>7__wrap3) followed by a dispose-then-rethrow sequence; + // we mirror that exactly so the state-machine layout matches the template's IL. + // Without DisposeAsync there is no nested protected region and the iteration lives + // directly in the outer try. ilBuilder.MarkLabel(callBenchmarkLabel); // enumerable = workload(args); enumerator = enumerable.GetAsyncEnumerator(...); @@ -198,28 +246,62 @@ void EmitMoveNextImpl() ilBuilder.Emit(OpCodes.Call, Descriptor.WorkloadMethod); ilBuilder.EmitStloc(enumerableLocal); ilBuilder.Emit(OpCodes.Ldarg_0); - if (enumerableInfo.WorkloadReturnType.IsValueType) + if (Descriptor.WorkloadMethod.ReturnType.IsValueType) ilBuilder.EmitLdloca(enumerableLocal); else ilBuilder.EmitLdloc(enumerableLocal); EmitGetAsyncEnumeratorCall(); ilBuilder.Emit(OpCodes.Stfld, enumeratorField); - // "Goto check first" pattern Roslyn uses for `while (await MoveNextAsync()) { body }`: - // skip the body on first entry, run the MoveNextAsync check, and only enter the body - // when GetResult returns true. Conventional do-while-style structure (check at the - // bottom) would emit a different sequence of instructions and throw off the - // EmitsSameIL diff. - var loopBodyLabel = ilBuilder.DefineLabel(); - ilBuilder.Emit(OpCodes.Br, moveNextLoopLabel); + if (disposeAsyncMethod is not null) + { + // <>7__wrap3 = null; <>7__wrap4 = 0; — Roslyn initializes both before opening the + // inner try. <>7__wrap3 holds any exception caught by the try; <>7__wrap4 is the + // try/finally state discriminator that Roslyn emits but never actually reads back + // for this lowering shape (kept for IL-equivalence). + ilBuilder.Emit(OpCodes.Ldarg_0); + ilBuilder.Emit(OpCodes.Ldnull); + ilBuilder.Emit(OpCodes.Stfld, capturedExceptionField!); + ilBuilder.Emit(OpCodes.Ldarg_0); + ilBuilder.Emit(OpCodes.Ldc_I4_0); + ilBuilder.Emit(OpCodes.Stfld, disposeDiscriminatorField!); + + // state-1 jumps to this label (just before the inner try opens), then control + // falls through into the try at its first instruction. + ilBuilder.MarkLabel(state1TargetLabel); + ilBuilder.Emit(OpCodes.Nop); + + ilBuilder.BeginExceptionBlock(); + // Re-dispatch state-1 inside the inner protected region — CIL forbids branching + // into the try from outside, so the outer switch lands at the nop above and we + // jump to moveNextContinuationLabel here once execution is safely inside the try. + ilBuilder.EmitLdloc(stateLocal); + ilBuilder.EmitLdc_I4(StateMoveNextAsync); + ilBuilder.Emit(OpCodes.Beq_S, moveNextContinuationLabel); + ilBuilder.Emit(OpCodes.Br_S, moveNextLoopLabel); + } + else + { + // "Goto check first" pattern Roslyn uses for `while (await MoveNextAsync()) { body }`: + // skip the body on first entry, run the MoveNextAsync check, and only enter the + // body when GetResult returns true. + ilBuilder.Emit(OpCodes.Br, moveNextLoopLabel); + } - // --- Loop body: lastItem = enumerator.Current; --- + // --- Loop body: var item = enumerator.Current; DeadCodeEliminationHelper.KeepAliveWithoutBoxing(in item); --- ilBuilder.MarkLabel(loopBodyLabel); ilBuilder.Emit(OpCodes.Ldarg_0); - ilBuilder.Emit(OpCodes.Ldarg_0); EmitLoadEnumeratorForCall(enumeratorField); EmitInvokeEnumeratorMethod(enumerableInfo.CurrentProperty.GetMethod!); - ilBuilder.Emit(OpCodes.Stfld, lastItemField); + ilBuilder.EmitStloc(itemLocal); + ilBuilder.EmitLdloca(itemLocal); + var keepAliveInMethod = typeof(DeadCodeEliminationHelper) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .First(m => m.Name == nameof(DeadCodeEliminationHelper.KeepAliveWithoutBoxing) + && m.GetParameters().Length == 1 + && m.GetParameters()[0].ParameterType.IsByRef) + .MakeGenericMethod(enumerableInfo.ItemType); + ilBuilder.Emit(OpCodes.Call, keepAliveInMethod); ilBuilder.MarkLabel(moveNextLoopLabel); // moveNextAwaitable = enumerator.MoveNextAsync(); @@ -228,17 +310,17 @@ void EmitMoveNextImpl() EmitInvokeEnumeratorMethod(enumerableInfo.MoveNextAsyncMethod); if (moveNextAwaitableLocal is null) { - ilBuilder.Emit(OpCodes.Callvirt, enumerableInfo.MoveNextAwaitableType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)!); + ilBuilder.Emit(OpCodes.Callvirt, enumerableInfo.MoveNextAwaitable.GetAwaiterMethod); } else { ilBuilder.EmitStloc(moveNextAwaitableLocal); ilBuilder.EmitLdloca(moveNextAwaitableLocal); - ilBuilder.Emit(OpCodes.Call, enumerableInfo.MoveNextAwaitableType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)!); + ilBuilder.Emit(OpCodes.Call, enumerableInfo.MoveNextAwaitable.GetAwaiterMethod); } ilBuilder.EmitStloc(moveNextAwaiterLocal); EmitLoadAwaiterAddressOrValue(moveNextAwaiterLocal); - ilBuilder.Emit(OpCodes.Call, moveNextAwaiterField.FieldType.GetProperty(nameof(TaskAwaiter.IsCompleted), BindingFlags.Public | BindingFlags.Instance)!.GetMethod!); + ilBuilder.Emit(OpCodes.Call, enumerableInfo.MoveNextAwaitable.IsCompletedProperty.GetMethod!); ilBuilder.Emit(OpCodes.Brtrue, moveNextGetResultLabel); // state = 1; <>u__moveNext = awaiter; AwaitUnsafeOnCompleted(...); leave; ilBuilder.Emit(OpCodes.Ldarg_0); @@ -269,35 +351,62 @@ void EmitMoveNextImpl() ilBuilder.EmitStloc(stateLocal); ilBuilder.Emit(OpCodes.Stfld, stateField); - // --- GetResult --- if true, loop body again; else fall through to dispose. + // --- GetResult --- if true, loop body again; else fall through. ilBuilder.MarkLabel(moveNextGetResultLabel); EmitLoadAwaiterAddressOrValue(moveNextAwaiterLocal); - ilBuilder.Emit(OpCodes.Call, moveNextAwaiterField.FieldType.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance, null, Type.EmptyTypes, null)!); + ilBuilder.Emit(OpCodes.Call, enumerableInfo.MoveNextAwaitable.GetResultMethod); ilBuilder.Emit(OpCodes.Brtrue, loopBodyLabel); - // ====== Inline finally: DisposeAsync (if applicable) ====== - ilBuilder.MarkLabel(startDisposeLabel); - if (enumerableInfo.DisposeAsyncMethod is not null) + if (disposeAsyncMethod is not null) { + // catch (object) { <>7__wrap3 = caught; } — Roslyn's await-foreach lowering uses + // `catch (object)` (not `catch (Exception)`) so non-Exception throwables are also + // captured and rethrown after dispose. ILGenerator synthesizes `leave` instructions + // at the end of the try body (when transitioning to BeginCatchBlock) and at the end + // of the catch body (when EndExceptionBlock runs); both target the end-of-block + // marker which we re-mark as `afterInnerTryLabel` below. + ilBuilder.BeginCatchBlock(typeof(object)); + ilBuilder.EmitStloc(caughtObjectLocal!); + ilBuilder.Emit(OpCodes.Ldarg_0); + ilBuilder.EmitLdloc(caughtObjectLocal!); + ilBuilder.Emit(OpCodes.Stfld, capturedExceptionField!); + ilBuilder.EndExceptionBlock(); + + // ====== After-foreach: optional await DisposeAsync, then rethrow if captured ====== + ilBuilder.MarkLabel(afterInnerTryLabel); + + // For REFERENCE-type enumerators (typically IAsyncEnumerator): if the enumerator + // field is null we skip the dispose await entirely. Roslyn emits this `ldfld; + // brfalse` guard only for reference enumerators — a value-type Enumerator (struct) + // can never be null, and `ldfld; brfalse` on a multi-word struct is meaningless, + // so the value-type path goes straight to DisposeAsync. + ilBuilder.MarkLabel(startDisposeLabel); + if (!enumerableInfo.EnumeratorType.IsValueType) + { + ilBuilder.Emit(OpCodes.Ldarg_0); + ilBuilder.Emit(OpCodes.Ldfld, enumeratorField); + ilBuilder.Emit(OpCodes.Brfalse_S, skipRethrowLabel); + } + // disposeAwaitable = enumerator.DisposeAsync(); ilBuilder.Emit(OpCodes.Ldarg_0); EmitLoadEnumeratorForCall(enumeratorField); - EmitInvokeEnumeratorMethod(enumerableInfo.DisposeAsyncMethod); + EmitInvokeEnumeratorMethod(disposeAsyncMethod); if (disposeAwaitableLocal is null) { - ilBuilder.Emit(OpCodes.Callvirt, enumerableInfo.DisposeAwaitableType!.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)!); + ilBuilder.Emit(OpCodes.Callvirt, disposeAwaitableInfo!.GetAwaiterMethod); } else { ilBuilder.EmitStloc(disposeAwaitableLocal); ilBuilder.EmitLdloca(disposeAwaitableLocal); - ilBuilder.Emit(OpCodes.Call, enumerableInfo.DisposeAwaitableType!.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)!); + ilBuilder.Emit(OpCodes.Call, disposeAwaitableInfo!.GetAwaiterMethod); } ilBuilder.EmitStloc(disposeAwaiterLocal!); EmitLoadAwaiterAddressOrValue(disposeAwaiterLocal!); - ilBuilder.Emit(OpCodes.Call, disposeAwaiterField!.FieldType.GetProperty(nameof(TaskAwaiter.IsCompleted), BindingFlags.Public | BindingFlags.Instance)!.GetMethod!); + ilBuilder.Emit(OpCodes.Call, disposeAwaitableInfo!.IsCompletedProperty.GetMethod!); ilBuilder.Emit(OpCodes.Brtrue, disposeGetResultLabel); - // state = 3; <>u__dispose = awaiter; AwaitUnsafeOnCompleted; leave + // state = 2; <>u__dispose = awaiter; AwaitUnsafeOnCompleted; leave ilBuilder.Emit(OpCodes.Ldarg_0); ilBuilder.EmitLdc_I4(StateDisposeAsync); ilBuilder.Emit(OpCodes.Dup); @@ -305,15 +414,15 @@ void EmitMoveNextImpl() ilBuilder.Emit(OpCodes.Stfld, stateField); ilBuilder.Emit(OpCodes.Ldarg_0); ilBuilder.EmitLdloc(disposeAwaiterLocal!); - ilBuilder.Emit(OpCodes.Stfld, disposeAwaiterField); + ilBuilder.Emit(OpCodes.Stfld, disposeAwaiterField!); ilBuilder.Emit(OpCodes.Ldarg_0); ilBuilder.Emit(OpCodes.Ldflda, builderField); ilBuilder.EmitLdloca(disposeAwaiterLocal!); ilBuilder.Emit(OpCodes.Ldarg_0); - ilBuilder.Emit(OpCodes.Call, GetAwaitOnCompletedMethod(asyncMethodBuilderType, disposeAwaiterField.FieldType, asyncStateMachineTypeBuilder)); + ilBuilder.Emit(OpCodes.Call, GetAwaitOnCompletedMethod(asyncMethodBuilderType, disposeAwaiterField!.FieldType, asyncStateMachineTypeBuilder)); ilBuilder.Emit(OpCodes.Leave, returnLabel); - // --- Resume from state 3 --- + // --- Resume from state 2 --- ilBuilder.MarkLabel(disposeContinuationLabel); ilBuilder.Emit(OpCodes.Ldarg_0); ilBuilder.Emit(OpCodes.Ldfld, disposeAwaiterField); @@ -328,10 +437,36 @@ void EmitMoveNextImpl() ilBuilder.MarkLabel(disposeGetResultLabel); EmitLoadAwaiterAddressOrValue(disposeAwaiterLocal!); - ilBuilder.Emit(OpCodes.Call, disposeAwaiterField.FieldType.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance, null, Type.EmptyTypes, null)!); + ilBuilder.Emit(OpCodes.Call, disposeAwaitableInfo!.GetResultMethod); + + // Rethrow path: if <>7__wrap3 holds a captured exception, rethrow it (via + // ExceptionDispatchInfo when it's a real Exception, raw `throw` otherwise so + // non-Exception payloads keep their original semantics). + ilBuilder.MarkLabel(skipRethrowLabel); + ilBuilder.Emit(OpCodes.Ldarg_0); + ilBuilder.Emit(OpCodes.Ldfld, capturedExceptionField!); + ilBuilder.EmitStloc(caughtObjectLocal!); + ilBuilder.EmitLdloc(caughtObjectLocal!); + var afterRethrowLabel = ilBuilder.DefineLabel(); + ilBuilder.Emit(OpCodes.Brfalse_S, afterRethrowLabel); + ilBuilder.EmitLdloc(caughtObjectLocal!); + ilBuilder.Emit(OpCodes.Isinst, typeof(Exception)); + ilBuilder.Emit(OpCodes.Dup); + ilBuilder.Emit(OpCodes.Brtrue_S, rethrowDispatchLabel); + ilBuilder.EmitLdloc(caughtObjectLocal!); + ilBuilder.Emit(OpCodes.Throw); + ilBuilder.MarkLabel(rethrowDispatchLabel); + ilBuilder.Emit(OpCodes.Call, typeof(System.Runtime.ExceptionServices.ExceptionDispatchInfo).GetMethod(nameof(System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture))!); + ilBuilder.Emit(OpCodes.Callvirt, typeof(System.Runtime.ExceptionServices.ExceptionDispatchInfo).GetMethod(nameof(System.Runtime.ExceptionServices.ExceptionDispatchInfo.Throw), Type.EmptyTypes)!); + ilBuilder.MarkLabel(afterRethrowLabel); + + // <>7__wrap3 = null; — clear the captured-exception field for the next iteration. + ilBuilder.Emit(OpCodes.Ldarg_0); + ilBuilder.Emit(OpCodes.Ldnull); + ilBuilder.Emit(OpCodes.Stfld, capturedExceptionField!); } - // Reset enumerator field so next iteration starts fresh + // Reset enumerator field so next iteration starts fresh. ilBuilder.Emit(OpCodes.Ldarg_0); ilBuilder.EmitSetFieldToDefault(enumeratorField); @@ -353,26 +488,9 @@ void EmitMoveNextImpl() ilBuilder.Emit(OpCodes.Conv_I8); ilBuilder.Emit(OpCodes.Bge, callBenchmarkLabel); - // Match the template's source order: compute `elapsed` first, then DCE keep-alive on - // `lastItem`, then await SetResult(elapsed). Forcing `elapsed` through a local is what - // Roslyn does too — the DCE call between the elapsed read and the await prevents the - // C# compiler from keeping the ClockSpan on the stack. - ilBuilder.Emit(OpCodes.Ldarg_0); - ilBuilder.Emit(OpCodes.Ldflda, startedClockField); - ilBuilder.Emit(OpCodes.Call, typeof(StartedClock).GetMethod(nameof(StartedClock.GetElapsed), BindingFlags.Public | BindingFlags.Instance)!); - ilBuilder.EmitStloc(elapsedLocal); - - // DCE keep-alive on the accumulated lastItem after the inner loop ends. - ilBuilder.Emit(OpCodes.Ldarg_0); - ilBuilder.Emit(OpCodes.Ldfld, lastItemField); - var keepAliveMethod = typeof(DeadCodeEliminationHelper) - .GetMethods(BindingFlags.Public | BindingFlags.Static) - .First(m => m.Name == nameof(DeadCodeEliminationHelper.KeepAliveWithoutBoxing) - && m.GetParameters().Length == 1 - && !m.GetParameters()[0].ParameterType.IsByRef) - .MakeGenericMethod(itemType); - ilBuilder.Emit(OpCodes.Call, keepAliveMethod); - + // The template inlines `startedClock.GetElapsed()` directly into SetResult — no + // intermediate ClockSpan local. EmitSetResultAndGetIsCompleteAwait emits the GetElapsed + // call as part of the SetResult argument sequence. EmitSetResultAndGetIsCompleteAwait(StateSetResult); ilBuilder.MarkLabel(setResultContinuationLabel); @@ -440,11 +558,14 @@ void EmitGetIsCompleteAwait(int state) void EmitSetResultAndGetIsCompleteAwait(int state) { - // elapsed is now in the elapsed local (computed before the DCE call). + // Inline `startedClock.GetElapsed()` as the SetResult argument — the template no longer + // declares a ClockSpan elapsed local, so neither do we. ilBuilder.EmitLdloc(thisLocal!); ilBuilder.Emit(OpCodes.Ldflda, fieldsContainerField); ilBuilder.Emit(OpCodes.Ldfld, workloadContinuerAndValueTaskSourceField); - ilBuilder.EmitLdloc(elapsedLocal); + ilBuilder.Emit(OpCodes.Ldarg_0); + ilBuilder.Emit(OpCodes.Ldflda, startedClockField); + ilBuilder.Emit(OpCodes.Call, typeof(StartedClock).GetMethod(nameof(StartedClock.GetElapsed), BindingFlags.Public | BindingFlags.Instance)!); ilBuilder.Emit(OpCodes.Callvirt, typeof(WorkloadValueTaskSource).GetMethod(nameof(WorkloadValueTaskSource.SetResultAndGetIsComplete), BindingFlags.Public | BindingFlags.Instance)!); ilBuilder.EmitStloc(isCompleteAwaitableLocal); ilBuilder.EmitLdloca(isCompleteAwaitableLocal); @@ -486,7 +607,7 @@ void EmitResumeFromValueTaskBoolAwait(FieldInfo awaiterField, LocalBuilder await void EmitGetAsyncEnumeratorCall() { EmitDefaultArgsForOptionalParameters(enumerableInfo.GetAsyncEnumeratorMethod); - var opCode = enumerableInfo.IsInterfaceDispatch || !enumerableInfo.WorkloadReturnType.IsValueType + var opCode = enumerableInfo.IsInterfaceDispatch || !Descriptor.WorkloadMethod.ReturnType.IsValueType ? OpCodes.Callvirt : OpCodes.Call; ilBuilder.Emit(opCode, enumerableInfo.GetAsyncEnumeratorMethod); @@ -540,110 +661,5 @@ void EmitLoadAwaiterAddressOrValue(LocalBuilder awaiterLocal) } } } - - // ----------------------------------------------------------------------------------------- - // Resolution: figure out which methods/types the await foreach pattern binds to for the - // workload's return type. Mirrors the C# compiler's resolution rules — pattern first, then - // interface fallback. - // ----------------------------------------------------------------------------------------- - - private sealed record EnumerableInfo( - Type WorkloadReturnType, - Type EnumeratorType, - Type ItemType, - MethodInfo GetAsyncEnumeratorMethod, - MethodInfo MoveNextAsyncMethod, - Type MoveNextAwaitableType, - Type MoveNextAwaiterType, - PropertyInfo CurrentProperty, - MethodInfo? DisposeAsyncMethod, - Type? DisposeAwaitableType, - Type? DisposeAwaiterType, - bool IsInterfaceDispatch); - - private static EnumerableInfo ResolveEnumerableInfo(Type workloadReturnType) - { - // Pattern first: a public instance GetAsyncEnumerator with all-optional parameters. - var patternGetAsyncEnumerator = workloadReturnType - .GetMethods(BindingFlags.Public | BindingFlags.Instance) - .FirstOrDefault(m => m.Name == nameof(IAsyncEnumerable<>.GetAsyncEnumerator) - && m.GetParameters().All(p => p.IsOptional)); - MethodInfo getAsyncEnumeratorMethod; - bool isInterfaceDispatch; - if (patternGetAsyncEnumerator != null) - { - getAsyncEnumeratorMethod = patternGetAsyncEnumerator; - isInterfaceDispatch = false; - } - else - { - Type? iface = null; - if (workloadReturnType.IsGenericType && workloadReturnType.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>)) - { - iface = workloadReturnType; - } - else - { - foreach (var i in workloadReturnType.GetInterfaces()) - { - if (i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>)) - { - iface = i; - break; - } - } - } - if (iface is null) - { - throw new NotSupportedException($"Type {workloadReturnType.GetDisplayName()} is not an async enumerable."); - } - getAsyncEnumeratorMethod = iface.GetMethod(nameof(IAsyncEnumerable<>.GetAsyncEnumerator))!; - isInterfaceDispatch = true; - } - - var enumeratorType = getAsyncEnumeratorMethod.ReturnType; - var moveNextAsyncMethod = enumeratorType - .GetMethods(BindingFlags.Public | BindingFlags.Instance) - .FirstOrDefault(m => m.Name == nameof(IAsyncEnumerator<>.MoveNextAsync) && m.GetParameters().All(p => p.IsOptional)) - ?? throw new NotSupportedException($"Enumerator type {enumeratorType.GetDisplayName()} does not expose MoveNextAsync."); - var moveNextAwaitableType = moveNextAsyncMethod.ReturnType; - var moveNextAwaiterType = moveNextAwaitableType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)!.ReturnType; - - var currentProperty = enumeratorType.GetProperty(nameof(IAsyncEnumerator<>.Current), BindingFlags.Public | BindingFlags.Instance) - ?? throw new NotSupportedException($"Enumerator type {enumeratorType.GetDisplayName()} does not expose Current."); - var itemType = currentProperty.PropertyType; - - // DisposeAsync is optional for the await-foreach pattern. Roslyn matches a public instance - // method named DisposeAsync whose parameters are all optional and whose awaiter's GetResult - // returns void; otherwise it falls back to the IAsyncDisposable interface dispatch. - MethodInfo? disposeAsyncMethod = enumeratorType - .GetMethods(BindingFlags.Public | BindingFlags.Instance) - .FirstOrDefault(m => m.Name == nameof(IAsyncDisposable.DisposeAsync) - && m.GetParameters().All(p => p.IsOptional) - && m.ReturnType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance) - ?.ReturnType - .GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance) - ?.ReturnType == typeof(void)); - if (disposeAsyncMethod is null && typeof(IAsyncDisposable).IsAssignableFrom(enumeratorType)) - { - disposeAsyncMethod = typeof(IAsyncDisposable).GetMethod(nameof(IAsyncDisposable.DisposeAsync)); - } - Type? disposeAwaitableType = disposeAsyncMethod?.ReturnType; - Type? disposeAwaiterType = disposeAwaitableType?.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)!.ReturnType; - - return new EnumerableInfo( - workloadReturnType, - enumeratorType, - itemType, - getAsyncEnumeratorMethod, - moveNextAsyncMethod, - moveNextAwaitableType, - moveNextAwaiterType, - currentProperty, - disposeAsyncMethod, - disposeAwaitableType, - disposeAwaiterType, - isInterfaceDispatch); - } } } diff --git a/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/RunnableEmitter.cs b/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/RunnableEmitter.cs index b0b4ac4903..9fcfd95487 100644 --- a/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/RunnableEmitter.cs +++ b/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/RunnableEmitter.cs @@ -72,10 +72,10 @@ public static Assembly EmitPartitionAssembly(GenerateResult generateResult, Buil { var returnType = benchmark.BenchmarkCase.Descriptor.WorkloadMethod.ReturnType; RunnableEmitter runnableEmitter; - if (returnType.IsAwaitable()) - runnableEmitter = new AsyncCoreEmitter(buildPartition, moduleBuilder, benchmark); - else if (returnType.IsAsyncEnumerable(out _, out _, out _)) - runnableEmitter = new AsyncEnumerableCoreEmitter(buildPartition, moduleBuilder, benchmark); + if (returnType.IsAwaitable(out var awaitableInfo)) + runnableEmitter = new AsyncCoreEmitter(buildPartition, moduleBuilder, benchmark, awaitableInfo); + else if (returnType.IsAsyncEnumerable(out var asyncEnumerableInfo)) + runnableEmitter = new AsyncEnumerableCoreEmitter(buildPartition, moduleBuilder, benchmark, asyncEnumerableInfo); else runnableEmitter = new SyncCoreEmitter(buildPartition, moduleBuilder, benchmark); runnableEmitter.EmitRunnableCore(); diff --git a/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/SetupCleanupEmitter.cs b/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/SetupCleanupEmitter.cs index 8d99c924fe..32a549972f 100644 --- a/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/SetupCleanupEmitter.cs +++ b/src/BenchmarkDotNet/Toolchains/InProcess/Emit/Implementation/Emitters/SetupCleanupEmitter.cs @@ -12,7 +12,7 @@ private enum SetupCleanupKind { Other, GlobalSetup, GlobalCleanup } private void EmitSetupCleanup(string methodName, MethodInfo? methodToCall, SetupCleanupKind kind) { - if (methodToCall?.ReturnType.IsAwaitable() == true) + if (methodToCall?.ReturnType.IsAwaitable(out _) == true) { EmitAsyncSetupCleanup(methodName, methodToCall, kind); } diff --git a/src/BenchmarkDotNet/Toolchains/InProcess/NoEmit/BenchmarkActionFactory.cs b/src/BenchmarkDotNet/Toolchains/InProcess/NoEmit/BenchmarkActionFactory.cs index a8235d561b..b152fc4c6f 100644 --- a/src/BenchmarkDotNet/Toolchains/InProcess/NoEmit/BenchmarkActionFactory.cs +++ b/src/BenchmarkDotNet/Toolchains/InProcess/NoEmit/BenchmarkActionFactory.cs @@ -65,18 +65,18 @@ private static IBenchmarkAction CreateCore(IBenchmarkActionFactory? factory, obj unrollFactor); } - if (resultType.IsAwaitable()) + if (resultType.IsAwaitable(out _)) { throw new NotSupportedException($"Default {nameof(BenchmarkActionFactory)} does not support returning awaitable types except (Value)Task()."); } - if (resultType.IsAsyncEnumerable(out var asyncEnumerableItemType, out _, out _)) + if (resultType.IsAsyncEnumerable(out var asyncEnumerableInfo)) { - var asyncEnumerableInterface = typeof(IAsyncEnumerable<>).MakeGenericType(asyncEnumerableItemType); + var asyncEnumerableInterface = typeof(IAsyncEnumerable<>).MakeGenericType(asyncEnumerableInfo.ItemType); if (asyncEnumerableInterface.IsAssignableFrom(resultType)) { return Create( - typeof(BenchmarkActionAsyncEnumerable<>).MakeGenericType(asyncEnumerableItemType), + typeof(BenchmarkActionAsyncEnumerable<>).MakeGenericType(asyncEnumerableInfo.ItemType), resultInstance, targetMethod, unrollFactor); @@ -84,7 +84,7 @@ private static IBenchmarkAction CreateCore(IBenchmarkActionFactory? factory, obj if (resultType.IsGenericType && resultType.GetGenericTypeDefinition() == typeof(ConfiguredCancelableAsyncEnumerable<>)) { return Create( - typeof(BenchmarkActionConfiguredCancelableAsyncEnumerable<>).MakeGenericType(resultType.GetGenericArguments()[0]), + typeof(BenchmarkActionConfiguredCancelableAsyncEnumerable<>).MakeGenericType(asyncEnumerableInfo.ItemType), resultInstance, targetMethod, unrollFactor); diff --git a/src/BenchmarkDotNet/Toolchains/InProcess/NoEmit/BenchmarkActionImpl.cs b/src/BenchmarkDotNet/Toolchains/InProcess/NoEmit/BenchmarkActionImpl.cs index ee67c909cb..51148cb31e 100644 --- a/src/BenchmarkDotNet/Toolchains/InProcess/NoEmit/BenchmarkActionImpl.cs +++ b/src/BenchmarkDotNet/Toolchains/InProcess/NoEmit/BenchmarkActionImpl.cs @@ -375,7 +375,8 @@ private async Task WorkloadCore() var startedClock = clock!.Start(); while (--invokeCount >= 0) { - await callback(); + var result = await callback(); + DeadCodeEliminationHelper.KeepAliveWithoutBoxing(in result); } if (await workloadValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed())) { @@ -524,7 +525,8 @@ private async ValueTask WorkloadCore() var startedClock = clock!.Start(); while (--invokeCount >= 0) { - await callback(); + var result = await callback(); + DeadCodeEliminationHelper.KeepAliveWithoutBoxing(in result); } if (await workloadValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed())) { @@ -544,18 +546,21 @@ public override void Cleanup() } [AggressivelyOptimizeMethods] -public class BenchmarkActionConfiguredCancelableAsyncEnumerable : BenchmarkActionBase +public class BenchmarkActionAsyncEnumerable : BenchmarkActionBase +#if NET9_0_OR_GREATER + where T : allows ref struct +#endif { - private readonly Func> callback; + private readonly Func> callback; private readonly int unrollFactor; private WorkloadValueTaskSource workloadValueTaskSource = null!; private IClock? clock; private long invokeCount; [SetsRequiredMembers] - public BenchmarkActionConfiguredCancelableAsyncEnumerable(object? instance, MethodInfo method, int unrollFactor) + public BenchmarkActionAsyncEnumerable(object? instance, MethodInfo method, int unrollFactor) { - callback = CreateWorkload>>(instance, method); + callback = CreateWorkload>>(instance, method); this.unrollFactor = unrollFactor; InvokeSingle = InvokeOnce; InvokeUnroll = WorkloadActionUnroll; @@ -595,18 +600,15 @@ private async Task WorkloadCore() } while (true) { - T? lastItem = default; var startedClock = clock!.Start(); while (--invokeCount >= 0) { await foreach (var item in callback()) { - lastItem = item; + DeadCodeEliminationHelper.KeepAliveWithoutBoxing(in item); } } - var elapsed = startedClock.GetElapsed(); - DeadCodeEliminationHelper.KeepAliveWithoutBoxing(lastItem); - if (await workloadValueTaskSource.SetResultAndGetIsComplete(elapsed)) + if (await workloadValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed())) { return; } @@ -623,18 +625,21 @@ public override void Cleanup() } [AggressivelyOptimizeMethods] -public class BenchmarkActionAsyncEnumerable : BenchmarkActionBase +public class BenchmarkActionConfiguredCancelableAsyncEnumerable : BenchmarkActionBase +#if NET10_0_OR_GREATER + where T : allows ref struct +#endif { - private readonly Func> callback; + private readonly Func> callback; private readonly int unrollFactor; private WorkloadValueTaskSource workloadValueTaskSource = null!; private IClock? clock; private long invokeCount; [SetsRequiredMembers] - public BenchmarkActionAsyncEnumerable(object? instance, MethodInfo method, int unrollFactor) + public BenchmarkActionConfiguredCancelableAsyncEnumerable(object? instance, MethodInfo method, int unrollFactor) { - callback = CreateWorkload>>(instance, method); + callback = CreateWorkload>>(instance, method); this.unrollFactor = unrollFactor; InvokeSingle = InvokeOnce; InvokeUnroll = WorkloadActionUnroll; @@ -674,18 +679,15 @@ private async Task WorkloadCore() } while (true) { - T? lastItem = default; var startedClock = clock!.Start(); while (--invokeCount >= 0) { await foreach (var item in callback()) { - lastItem = item; + DeadCodeEliminationHelper.KeepAliveWithoutBoxing(in item); } } - var elapsed = startedClock.GetElapsed(); - DeadCodeEliminationHelper.KeepAliveWithoutBoxing(lastItem); - if (await workloadValueTaskSource.SetResultAndGetIsComplete(elapsed)) + if (await workloadValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed())) { return; } diff --git a/src/BenchmarkDotNet/Validators/AwaitableAsyncEnumerableAmbiguityValidator.cs b/src/BenchmarkDotNet/Validators/AwaitableAsyncEnumerableAmbiguityValidator.cs index c7ed2bc723..c24da3a069 100644 --- a/src/BenchmarkDotNet/Validators/AwaitableAsyncEnumerableAmbiguityValidator.cs +++ b/src/BenchmarkDotNet/Validators/AwaitableAsyncEnumerableAmbiguityValidator.cs @@ -37,10 +37,10 @@ private void CollectErrors(string benchmarkClassName, IEnumerable if (!method.GetCustomAttributes(false).OfType().Any()) continue; - if (!method.ReturnType.IsAsyncEnumerable(out _, out _, out _)) + if (!method.ReturnType.IsAsyncEnumerable(out _)) continue; - if (!method.ReturnType.IsAwaitable()) + if (!method.ReturnType.IsAwaitable(out _)) continue; validationErrors.Add(new ValidationError( diff --git a/src/BenchmarkDotNet/Validators/ExecutionValidator.cs b/src/BenchmarkDotNet/Validators/ExecutionValidator.cs index a6845585bd..8a8658db72 100644 --- a/src/BenchmarkDotNet/Validators/ExecutionValidator.cs +++ b/src/BenchmarkDotNet/Validators/ExecutionValidator.cs @@ -20,23 +20,23 @@ protected override async ValueTask ExecuteBenchmarksAsync(object benchmarkTypeIn { var workloadMethod = benchmark.Descriptor.WorkloadMethod; var result = workloadMethod.Invoke(benchmarkTypeInstance, null); - if (workloadMethod.ReturnType.IsAwaitable()) + if (workloadMethod.ReturnType.IsAwaitable(out var awaitableInfo)) { if (result is null) { errors.Add(new ValidationError(TreatsWarningsAsErrors, $"Awaitable benchmark '{benchmark.DisplayInfo}' returned null", benchmark)); continue; } - await DynamicAwaitHelper.AwaitResult(result, workloadMethod.ReturnType).ConfigureAwait(true); + await DynamicAwaitHelper.AwaitResult(result, awaitableInfo).ConfigureAwait(true); } - else if (workloadMethod.ReturnType.IsAsyncEnumerable(out _, out _, out _)) + else if (workloadMethod.ReturnType.IsAsyncEnumerable(out var asyncEnumerableInfo)) { if (result is null) { errors.Add(new ValidationError(TreatsWarningsAsErrors, $"Async enumerable benchmark '{benchmark.DisplayInfo}' returned null", benchmark)); continue; } - await DynamicAwaitHelper.DrainAsyncEnumerableAsync(result, workloadMethod.ReturnType).ConfigureAwait(true); + await DynamicAwaitHelper.DrainAsyncEnumerableAsync(result, asyncEnumerableInfo).ConfigureAwait(true); } } catch (Exception ex) when (!ExceptionHelper.IsProperCancelation(ex, cancellationToken)) diff --git a/src/BenchmarkDotNet/Validators/ExecutionValidatorBase.cs b/src/BenchmarkDotNet/Validators/ExecutionValidatorBase.cs index 004b2e1940..53781e8a77 100644 --- a/src/BenchmarkDotNet/Validators/ExecutionValidatorBase.cs +++ b/src/BenchmarkDotNet/Validators/ExecutionValidatorBase.cs @@ -116,14 +116,14 @@ private async ValueTask TryToCallGlobalMethod(object benchmarkTypeInsta try { var result = methods[0].Invoke(benchmarkTypeInstance, null); - if (methods[0].ReturnType.IsAwaitable()) + if (methods[0].ReturnType.IsAwaitable(out var awaitableInfo)) { if (result is null) { errors.Add(new ValidationError(TreatsWarningsAsErrors, $"[{GetAttributeName(typeof(T))}] for {benchmarkTypeInstance.GetType().Name} returned null")); return false; } - await DynamicAwaitHelper.AwaitResult(result, methods[0].ReturnType).ConfigureAwait(false); + await DynamicAwaitHelper.AwaitResult(result, awaitableInfo).ConfigureAwait(false); } } catch (Exception ex) when (!ExceptionHelper.IsProperCancelation(ex, cancellationToken)) diff --git a/src/BenchmarkDotNet/Validators/ReturnValueValidator.cs b/src/BenchmarkDotNet/Validators/ReturnValueValidator.cs index 968b674711..f0b52838d6 100644 --- a/src/BenchmarkDotNet/Validators/ReturnValueValidator.cs +++ b/src/BenchmarkDotNet/Validators/ReturnValueValidator.cs @@ -31,27 +31,27 @@ protected override async ValueTask ExecuteBenchmarksAsync(object benchmarkTypeIn InProcessNoEmitRunner.FillMembers(benchmarkTypeInstance, benchmark, cancellationToken); var workloadMethod = benchmark.Descriptor.WorkloadMethod; var result = workloadMethod.Invoke(benchmarkTypeInstance, null); - if (workloadMethod.ReturnType.IsAwaitable()) + if (workloadMethod.ReturnType.IsAwaitable(out var awaitableInfo)) { if (result is null) { errors.Add(new ValidationError(TreatsWarningsAsErrors, $"Awaitable benchmark '{benchmark.DisplayInfo}' returned null", benchmark)); continue; } - (var hasResult, result) = await DynamicAwaitHelper.AwaitResult(result, workloadMethod.ReturnType).ConfigureAwait(true); + (var hasResult, result) = await DynamicAwaitHelper.AwaitResult(result, awaitableInfo).ConfigureAwait(true); if (hasResult) { results.Add((benchmark, result!)); } } - else if (workloadMethod.ReturnType.IsAsyncEnumerable(out _, out _, out _)) + else if (workloadMethod.ReturnType.IsAsyncEnumerable(out var asyncEnumerableInfo)) { if (result is null) { errors.Add(new ValidationError(TreatsWarningsAsErrors, $"Async enumerable benchmark '{benchmark.DisplayInfo}' returned null", benchmark)); continue; } - result = await DynamicAwaitHelper.ToListAsync(result, workloadMethod.ReturnType).ConfigureAwait(true); + result = await DynamicAwaitHelper.ToListAsync(result, asyncEnumerableInfo).ConfigureAwait(true); results.Add((benchmark, result)); } else if (workloadMethod.ReturnType != typeof(void)) diff --git a/src/BenchmarkDotNet/Validators/SetupCleanupValidator.cs b/src/BenchmarkDotNet/Validators/SetupCleanupValidator.cs index 4b00b75819..db2660d57d 100644 --- a/src/BenchmarkDotNet/Validators/SetupCleanupValidator.cs +++ b/src/BenchmarkDotNet/Validators/SetupCleanupValidator.cs @@ -43,8 +43,8 @@ private IEnumerable ValidateReturnType(string benchmarkClass // produce two errors at once. The runtime awaits dual-shaped returns instead of rejecting // them, which matches that validator's warning-not-error severity. if (method.GetCustomAttributes(false).OfType().Any() - && method.ReturnType.IsAsyncEnumerable(out _, out _, out _) - && !method.ReturnType.IsAwaitable()) + && method.ReturnType.IsAsyncEnumerable(out _) + && !method.ReturnType.IsAwaitable(out _)) { yield return new ValidationError( TreatsWarningsAsErrors,