Skip to content

Commit 4fb6013

Browse files
committed
Handle cases like .First(where) and .Sum
1 parent aa7cf4f commit 4fb6013

2 files changed

Lines changed: 94 additions & 10 deletions

File tree

samples/BasicSample/Program.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,31 @@ public static void Main(string[] args)
134134
{
135135
Console.WriteLine($"User name: {u.FullName}");
136136
}
137+
138+
foreach (var u in dbContext.Users.ToList())
139+
{
140+
Console.WriteLine($"User name: {u.FullName}");
141+
}
142+
143+
foreach (var u in dbContext.Users.OrderBy(x => x.FullName))
144+
{
145+
Console.WriteLine($"User name: {u.FullName}");
146+
}
147+
}
148+
149+
{
150+
foreach (var u in dbContext.Users.Where(x => x.TotalSpent >= 1))
151+
{
152+
Console.WriteLine($"User name: {u.FullName}");
153+
}
137154
}
138155

139156
{
140157
var result = dbContext.Users.FirstOrDefault();
141158
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");
159+
160+
result = dbContext.Users.FirstOrDefault(x => x.TotalSpent > 1);
161+
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");
142162
}
143163

144164
{

src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Collections.Generic;
1+
using System.Collections;
2+
using System.Collections.Generic;
23
using System.Diagnostics.CodeAnalysis;
34
using System.Linq;
45
using System.Linq.Expressions;
@@ -17,9 +18,26 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor
1718
private bool _disableRootRewrite;
1819
private IEntityType? _entityType;
1920

21+
private readonly MethodInfo _select;
22+
private readonly MethodInfo _where;
23+
2024
public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver)
2125
{
2226
_resolver = projectionExpressionResolver;
27+
_select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
28+
.Where(x => x.Name == nameof(Queryable.Select))
29+
.First(x =>
30+
x.GetParameters().Last().ParameterType // Expression<Func<T, Ret>>
31+
.GetGenericArguments().First() // Func<T, Ret>
32+
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
33+
);
34+
_where = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
35+
.Where(x => x.Name == nameof(Queryable.Where))
36+
.First(x =>
37+
x.GetParameters().Last().ParameterType // Expression<Func<T, Ret>>
38+
.GetGenericArguments().First() // Func<T, Ret>
39+
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
40+
);
2341
}
2442

2543
bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression)
@@ -45,6 +63,7 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
4563

4664
if (_disableRootRewrite)
4765
{
66+
// This boolean is enabled when a "Select" is encountered
4867
return ret;
4968
}
5069

@@ -53,10 +72,62 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
5372
// Probably a First() or ToList()
5473
case MethodCallExpression { Arguments.Count: > 0, Object: null } call when _entityType != null:
5574
{
75+
// if return type != IQueryable {
76+
// if return type is IEnuberable {
77+
// // case of a ToList()
78+
// return (ret.arg[0]).Select(...).ToList() or the other method
79+
// } else {
80+
// // case of a Max()
81+
// return ret;
82+
// }
83+
// } else if retrun type == entitytype {
84+
// // case of a first()
85+
// return obj.MyMap(x => new Obj {});
86+
// }
87+
88+
89+
if (call.Method.ReturnType.IsAssignableTo(typeof(IQueryable)))
90+
{
91+
// Generic case where the return type is still a IQueryable<T>
92+
return _AddProjectableSelect(call, _entityType);
93+
}
94+
95+
if (call.Method.ReturnType == _entityType.ClrType)
96+
{
97+
// case of a .First(), .SingleAsync()
98+
if (call.Arguments.Count != 1 && true /* Add && arg.count == 1 exist */)
99+
{
100+
// .First(x => whereCondition), since we need to add a select after the last condition but
101+
// before the query become executed by EF (before the .First()), we rewrite the .First(where)
102+
// as .Where(where).Select(x => ...).First()
103+
104+
var where = Expression.Call(null, _where.MakeGenericMethod(_entityType.ClrType), call.Arguments);
105+
// The call instance is based on the wrong polymorphied method.
106+
var first = call.Method.DeclaringType?.GetMethods()
107+
.FirstOrDefault(x => x.Name == call.Method.Name && x.GetParameters().Length == 1);
108+
if (first == null)
109+
{
110+
// Unknown case that should not happen.
111+
return call;
112+
}
113+
114+
return Expression.Call(null, first.MakeGenericMethod(_entityType.ClrType), _AddProjectableSelect(where, _entityType));
115+
}
116+
117+
// .First() without arguments is the same case as bellow so we let it fallthrough
118+
}
119+
else if (!call.Method.ReturnType.IsAssignableTo(typeof(IEnumerable)))
120+
{
121+
// case of something like a .Max(), .Sum()
122+
return call;
123+
}
124+
125+
// return type is IEnumerable<EntityType> or EntityType (in case of fallthrough from a .First())
126+
127+
// case of something like .ToList(), .ToArrayAsync()
56128
var self = _AddProjectableSelect(call.Arguments.First(), _entityType);
57129
return call.Update(null, call.Arguments.Skip(1).Prepend(self));
58130
}
59-
// Probably a foreach call
60131
case QueryRootExpression root:
61132
return _AddProjectableSelect(root, root.EntityType);
62133
default:
@@ -170,14 +241,7 @@ private Expression _AddProjectableSelect(Expression node, IEntityType entityType
170241
.Where(x => projectableProperties.All(y => x.Name != y.Name && x.Name != $"<{y.Name}>k__BackingField"));
171242

172243
// Replace db.Entities to db.Entities.Select(x => new Entity { Property1 = x.Property1, Rewritted = rewrittedProperty })
173-
var select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
174-
.Where(x => x.Name == nameof(Queryable.Select))
175-
.First(x =>
176-
x.GetParameters().Last().ParameterType // Expression<Func<T, Ret>>
177-
.GetGenericArguments().First() // Func<T, Ret>
178-
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
179-
)
180-
.MakeGenericMethod(entityType.ClrType, entityType.ClrType);
244+
var select = _select.MakeGenericMethod(entityType.ClrType, entityType.ClrType);
181245
var xParam = Expression.Parameter(entityType.ClrType);
182246
return Expression.Call(
183247
null,

0 commit comments

Comments
 (0)