Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/BenchmarkDotNet.Analyzers/AsyncTypeShapes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@ public static bool IsAsyncEnumerable(ITypeSymbol type, INamedTypeSymbol? asyncEn
return true;
}

if (TryFindPatternGetAsyncEnumerator(type) is { } enumeratorType
&& HasPatternMoveNextAsync(enumeratorType)
&& HasPublicInstanceProperty(enumeratorType, "Current"))
if (TryFindPatternGetAsyncEnumerator(type) is { } enumeratorType)
{
return true;
// Roslyn commits to a found pattern `GetAsyncEnumerator` — if its return type doesn't
// satisfy the await-foreach enumerator shape it reports an error instead of falling back
// to `IAsyncEnumerable<T>`, even when the source also implements the interface. We mirror
// that here so the analyzer's view of binding matches what `await foreach` would actually
// accept.
return HasPatternMoveNextAsync(enumeratorType)
&& HasPublicInstanceProperty(enumeratorType, "Current");
}

if (asyncEnumerableInterfaceSymbol != null)
Expand Down
8 changes: 4 additions & 4 deletions src/BenchmarkDotNet/Code/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,14 @@ private static DeclarationsProvider GetDeclarationsProvider(BenchmarkCase benchm
{
var method = benchmark.Descriptor.WorkloadMethod;

if (method.ReturnType.IsAwaitable())
if (method.ReturnType.IsAwaitable(out var awaitableInfo))
{
return new AsyncDeclarationsProvider(benchmark);
return new AsyncDeclarationsProvider(benchmark, awaitableInfo.ResultType);
}

if (method.ReturnType.IsAsyncEnumerable(out var itemType, out var enumeratorType, out var moveNextAwaitableType))
if (method.ReturnType.IsAsyncEnumerable(out var asyncEnumerableInfo))
{
return new AsyncEnumerableDeclarationsProvider(benchmark, itemType, enumeratorType, moveNextAwaitableType);
return new AsyncEnumerableDeclarationsProvider(benchmark, asyncEnumerableInfo.ItemType, asyncEnumerableInfo.MoveNextAsyncMethod.ReturnType);
}

if (method.ReturnType == typeof(void) && method.HasAttribute<AsyncStateMachineAttribute>())
Expand Down
191 changes: 57 additions & 134 deletions src/BenchmarkDotNet/Code/DeclarationsProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private void Replace(SmartStringBuilder smartStringBuilder, MethodInfo? method,
userImpl = string.Empty;
needsExplicitReturn = true;
}
else if (method.ReturnType.IsAwaitable())
else if (method.ReturnType.IsAwaitable(out _))
{
modifier = "async";
userImpl = $"await {GetMethodPrefix(method)}.{method.Name}();";
Expand Down Expand Up @@ -97,7 +97,7 @@ protected string GetPassArgumentsDirect()
);
}

internal class SyncDeclarationsProvider(BenchmarkCase benchmark) : DeclarationsProvider(benchmark)
internal sealed class SyncDeclarationsProvider(BenchmarkCase benchmark) : DeclarationsProvider(benchmark)
{
public override string[] GetExtraFields() => [];

Expand Down Expand Up @@ -175,8 +175,13 @@ private string GetPassArguments()
);
}

internal class AsyncDeclarationsProvider(BenchmarkCase benchmark) : DeclarationsProvider(benchmark)
internal abstract class AsyncDeclarationsProviderBase(BenchmarkCase benchmark) : DeclarationsProvider(benchmark)
{
// Type used to drive the WorkloadCore builder selection. For ordinary awaitables it's the workload
// method's own return type, but `IAsyncEnumerable<T>` has no GetAwaiter, so AsyncEnumerableDeclarationsProvider
// overrides this to expose the MoveNextAsync awaitable as a proxy.
protected virtual Type WorkloadAwaitableReturnType => Descriptor.WorkloadMethod.ReturnType;

public override string[] GetExtraFields() =>
[
$"public {typeof(WorkloadValueTaskSource).GetCorrectCSharpTypeName()} workloadContinuerAndValueTaskSource;",
Expand All @@ -193,87 +198,6 @@ protected override string GetExtraGlobalSetupImpl()
protected override string GetExtraGlobalCleanupImpl()
=> "this.__fieldsContainer.workloadContinuerAndValueTaskSource.Complete();";

protected override SmartStringBuilder ReplaceCore(SmartStringBuilder smartStringBuilder)
{
// Unlike sync calls, async calls suffer from unrolling, so we multiply the invokeCount by the unroll factor and delegate the implementation to *NoUnroll methods.
int unrollFactor = Benchmark.Job.ResolveValue(RunMode.UnrollFactorCharacteristic, EnvironmentResolver.Instance);
string passArguments = GetPassArgumentsDirect();
string workloadMethodCall = GetWorkloadMethodCall(passArguments);
bool hasAsyncMethodBuilderAttribute = TryGetAsyncMethodBuilderAttribute(out var asyncMethodBuilderAttribute);
Type workloadCoreReturnType = GetWorkloadCoreReturnType(hasAsyncMethodBuilderAttribute, Descriptor.WorkloadMethod.ReturnType);
string finalReturn = GetFinalReturn(workloadCoreReturnType);
string coreImpl = $$"""
private {{CoreReturnType}} OverheadActionUnroll({{CoreParameters}})
{
return this.OverheadActionNoUnroll(invokeCount * {{unrollFactor}}, clock);
}

private {{CoreReturnType}} OverheadActionNoUnroll({{CoreParameters}})
{
{{StartClockSyncCode}}
while (--invokeCount >= 0)
{
this.__Overhead({{passArguments}});
}
{{ReturnSyncCode}}
}

private {{CoreReturnType}} WorkloadActionUnroll({{CoreParameters}})
{
return this.WorkloadActionNoUnroll(invokeCount * {{unrollFactor}}, clock);
}

private {{CoreReturnType}} WorkloadActionNoUnroll({{CoreParameters}})
{
this.__fieldsContainer.invokeCount = invokeCount;
this.__fieldsContainer.clock = clock;
// The source is allocated and the workload loop started in __GlobalSetup,
// so this hot path is branchless and allocation-free.
return this.__fieldsContainer.workloadContinuerAndValueTaskSource.Continue();
}

private async void __StartWorkload()
{
await __WorkloadCore();
}

{{asyncMethodBuilderAttribute}}
private async {{workloadCoreReturnType.GetCorrectCSharpTypeName()}} __WorkloadCore()
{
try
{
if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.GetIsComplete())
{
{{finalReturn}}
}
while (true)
{
{{typeof(StartedClock).GetCorrectCSharpTypeName()}} startedClock = {{typeof(ClockExtensions).GetCorrectCSharpTypeName()}}.Start(this.__fieldsContainer.clock);
while (--this.__fieldsContainer.invokeCount >= 0)
{
// Necessary because of error CS4004: Cannot await in an unsafe context
{{Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName()}} awaitable;
unsafe { awaitable = {{workloadMethodCall}} }
await awaitable;
}
if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed()))
{
{{finalReturn}}
}
}
}
catch (global::System.Exception e)
{
__fieldsContainer.workloadContinuerAndValueTaskSource.SetException(e);
{{finalReturn}}
}
}
""";

return smartStringBuilder
.Replace("$CoreImpl$", coreImpl);
}

protected bool TryGetAsyncMethodBuilderAttribute(out string asyncMethodBuilderAttribute)
{
asyncMethodBuilderAttribute = string.Empty;
Expand Down Expand Up @@ -322,32 +246,15 @@ protected static string GetFinalReturn(Type workloadCoreReturnType)
? "return;"
: $"return default({finalReturnType.GetCorrectCSharpTypeName()});";
}
}

internal class AsyncEnumerableDeclarationsProvider(BenchmarkCase benchmark, Type itemType, Type enumeratorType, Type moveNextAwaitableType) : AsyncDeclarationsProvider(benchmark)
{
protected override SmartStringBuilder ReplaceCore(SmartStringBuilder smartStringBuilder)
{
// Unlike sync calls, async calls suffer from unrolling, so we multiply the invokeCount by the unroll factor and delegate the implementation to *NoUnroll methods.
int unrollFactor = Benchmark.Job.ResolveValue(RunMode.UnrollFactorCharacteristic, EnvironmentResolver.Instance);
string passArguments = GetPassArgumentsDirect();
string workloadMethodCall = GetWorkloadMethodCall(passArguments);
string itemTypeName = itemType.GetCorrectCSharpTypeName();
string enumerableTypeName = Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName();
string enumeratorTypeName = enumeratorType.GetCorrectCSharpTypeName();
// We hand-roll the `await foreach` desugaring (explicit GetAsyncEnumerator + while-loop) instead
// of using the C# `await foreach` keyword: that keeps the IL byte-for-byte aligned with
// AsyncEnumerableCoreEmitter, which doesn't wrap the iteration in the try/catch + wrap field
// pattern Roslyn emits for the keyword form.
string disposeAsyncCall = ResolveDisposeAsync() is { } disposeAsyncMethod
? $"await enumerator.{disposeAsyncMethod.Name}();"
: string.Empty;
// IAsyncEnumerable<T> has no GetAwaiter, so its own return type can't drive the WorkloadCore
// builder. Use MoveNextAsync's return type as the proxy and feed it through the same resolver
// as the awaitable path: any result the awaitable produces (typically `bool`) is discarded by
// `__StartWorkload`'s `await`, and `[AsyncCallerType]` / `[AsyncMethodBuilder]` overrides on
// the workload method still apply.
bool hasAsyncMethodBuilderAttribute = TryGetAsyncMethodBuilderAttribute(out var asyncMethodBuilderAttribute);
Type workloadCoreReturnType = GetWorkloadCoreReturnType(hasAsyncMethodBuilderAttribute, moveNextAwaitableType);
Type workloadCoreReturnType = GetWorkloadCoreReturnType(hasAsyncMethodBuilderAttribute, WorkloadAwaitableReturnType);
string finalReturn = GetFinalReturn(workloadCoreReturnType);
string coreImpl = $$"""
private {{CoreReturnType}} OverheadActionUnroll({{CoreParameters}})
Expand Down Expand Up @@ -383,7 +290,7 @@ private async void __StartWorkload()
{
await __WorkloadCore();
}

{{asyncMethodBuilderAttribute}}
private async {{workloadCoreReturnType.GetCorrectCSharpTypeName()}} __WorkloadCore()
{
Expand All @@ -395,23 +302,12 @@ private async void __StartWorkload()
}
while (true)
{
{{itemTypeName}} lastItem = default({{itemTypeName}});
{{typeof(StartedClock).GetCorrectCSharpTypeName()}} startedClock = {{typeof(ClockExtensions).GetCorrectCSharpTypeName()}}.Start(this.__fieldsContainer.clock);
while (--this.__fieldsContainer.invokeCount >= 0)
{
// Necessary because of error CS4004: Cannot await in an unsafe context
{{enumerableTypeName}} enumerable;
unsafe { enumerable = {{workloadMethodCall}} }
{{enumeratorTypeName}} enumerator = enumerable.GetAsyncEnumerator();
while (await enumerator.MoveNextAsync())
{
lastItem = enumerator.Current;
}
{{disposeAsyncCall}}
{{GetCallAndConsumeImpl(workloadMethodCall)}}
}
{{typeof(ClockSpan).GetCorrectCSharpTypeName()}} elapsed = startedClock.GetElapsed();
{{typeof(DeadCodeEliminationHelper).GetCorrectCSharpTypeName()}}.KeepAliveWithoutBoxing(lastItem);
if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(elapsed))
if (await this.__fieldsContainer.workloadContinuerAndValueTaskSource.SetResultAndGetIsComplete(startedClock.GetElapsed()))
{
{{finalReturn}}
}
Expand All @@ -429,24 +325,51 @@ private async void __StartWorkload()
.Replace("$CoreImpl$", coreImpl);
}

// Roslyn's `await foreach` resolution: prefer a public instance DisposeAsync with all-optional
// params whose awaiter's GetResult returns void; otherwise fall back to IAsyncDisposable.
// Returns null if neither shape matches, in which case the template skips the dispose call.
private MethodInfo? ResolveDisposeAsync()
protected abstract string GetCallAndConsumeImpl(string workloadMethodCall);
}

internal class AsyncDeclarationsProvider(BenchmarkCase benchmark, Type resultType) : AsyncDeclarationsProviderBase(benchmark)
{
protected override string GetCallAndConsumeImpl(string workloadMethodCall)
{
var disposeAsyncMethod = enumeratorType
.GetMethods(BindingFlags.Public | BindingFlags.Instance)
.FirstOrDefault(m => m.Name == nameof(IAsyncDisposable.DisposeAsync)
&& m.GetParameters().All(p => p.IsOptional)
&& m.ReturnType.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)
?.ReturnType
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)
?.ReturnType == typeof(void));
if (disposeAsyncMethod is not null)
return disposeAsyncMethod;
if (typeof(IAsyncDisposable).IsAssignableFrom(enumeratorType))
return typeof(IAsyncDisposable).GetMethod(nameof(IAsyncDisposable.DisposeAsync));
return null;
string awaitStatement;
if (resultType == typeof(void))
{
awaitStatement = "await awaitable;";
}
else
{
var resultTypeName = resultType.GetCorrectCSharpTypeName();
awaitStatement = $"""
{resultTypeName} result = await awaitable;
{typeof(DeadCodeEliminationHelper).GetCorrectCSharpTypeName()}.KeepAliveWithoutBoxing<{resultTypeName}>(in result);
""";
}
return $$"""
// Necessary because of error CS4004: Cannot await in an unsafe context
{{Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName()}} awaitable;
unsafe { awaitable = {{workloadMethodCall}} }
{{awaitStatement}}
""";
}
}

internal class AsyncEnumerableDeclarationsProvider(BenchmarkCase benchmark, Type itemType, Type moveNextAwaitableType) : AsyncDeclarationsProviderBase(benchmark)
{
protected override Type WorkloadAwaitableReturnType => moveNextAwaitableType;

protected override string GetCallAndConsumeImpl(string workloadMethodCall)
{
string itemTypeName = itemType.GetCorrectCSharpTypeName();
return $$"""
// Necessary because of error CS4004: Cannot await in an unsafe context
{{Descriptor.WorkloadMethod.ReturnType.GetCorrectCSharpTypeName()}} enumerable;
unsafe { enumerable = {{workloadMethodCall}} }
await foreach ({{itemTypeName}} item in enumerable)
{
{{typeof(DeadCodeEliminationHelper).GetCorrectCSharpTypeName()}}.KeepAliveWithoutBoxing<{{itemTypeName}}>(in item);
}
""";
}
}
}
2 changes: 1 addition & 1 deletion src/BenchmarkDotNet/Engines/Consumer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public void Consume<T>(in T value)
else if (default(T) == null && !typeof(T).IsValueType)
Consume((object?)value);
else
DeadCodeEliminationHelper.KeepAliveWithoutBoxingReadonly(value); // non-primitive and nullable value types
DeadCodeEliminationHelper.KeepAliveWithoutBoxing(in value); // non-primitive and nullable value types
}
}
}
Loading
Loading