1- using System . Collections . Generic ;
1+ using System . Collections ;
2+ using System . Collections . Generic ;
23using System . Diagnostics . CodeAnalysis ;
34using System . Linq ;
45using System . Linq . Expressions ;
@@ -17,9 +18,26 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor
1718 private bool _disableRootRewrite ;
1819 private IEntityType ? _entityType ;
1920
21+ private readonly MethodInfo _select ;
22+ private readonly MethodInfo _where ;
23+
2024 public ProjectableExpressionReplacer ( IProjectionExpressionResolver projectionExpressionResolver )
2125 {
2226 _resolver = projectionExpressionResolver ;
27+ _select = typeof ( Queryable ) . GetMethods ( BindingFlags . Static | BindingFlags . Public )
28+ . Where ( x => x . Name == nameof ( Queryable . Select ) )
29+ . First ( x =>
30+ x . GetParameters ( ) . Last ( ) . ParameterType // Expression<Func<T, Ret>>
31+ . GetGenericArguments ( ) . First ( ) // Func<T, Ret>
32+ . GetGenericArguments ( ) . Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
33+ ) ;
34+ _where = typeof ( Queryable ) . GetMethods ( BindingFlags . Static | BindingFlags . Public )
35+ . Where ( x => x . Name == nameof ( Queryable . Where ) )
36+ . First ( x =>
37+ x . GetParameters ( ) . Last ( ) . ParameterType // Expression<Func<T, Ret>>
38+ . GetGenericArguments ( ) . First ( ) // Func<T, Ret>
39+ . GetGenericArguments ( ) . Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
40+ ) ;
2341 }
2442
2543 bool TryGetReflectedExpression ( MemberInfo memberInfo , [ NotNullWhen ( true ) ] out LambdaExpression ? reflectedExpression )
@@ -45,6 +63,7 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
4563
4664 if ( _disableRootRewrite )
4765 {
66+ // This boolean is enabled when a "Select" is encountered
4867 return ret ;
4968 }
5069
@@ -53,10 +72,62 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
5372 // Probably a First() or ToList()
5473 case MethodCallExpression { Arguments . Count : > 0 , Object : null } call when _entityType != null :
5574 {
75+ // if return type != IQueryable {
76+ // if return type is IEnuberable {
77+ // // case of a ToList()
78+ // return (ret.arg[0]).Select(...).ToList() or the other method
79+ // } else {
80+ // // case of a Max()
81+ // return ret;
82+ // }
83+ // } else if retrun type == entitytype {
84+ // // case of a first()
85+ // return obj.MyMap(x => new Obj {});
86+ // }
87+
88+
89+ if ( call . Method . ReturnType . IsAssignableTo ( typeof ( IQueryable ) ) )
90+ {
91+ // Generic case where the return type is still a IQueryable<T>
92+ return _AddProjectableSelect ( call , _entityType ) ;
93+ }
94+
95+ if ( call . Method . ReturnType == _entityType . ClrType )
96+ {
97+ // case of a .First(), .SingleAsync()
98+ if ( call . Arguments . Count != 1 && true /* Add && arg.count == 1 exist */ )
99+ {
100+ // .First(x => whereCondition), since we need to add a select after the last condition but
101+ // before the query become executed by EF (before the .First()), we rewrite the .First(where)
102+ // as .Where(where).Select(x => ...).First()
103+
104+ var where = Expression . Call ( null , _where . MakeGenericMethod ( _entityType . ClrType ) , call . Arguments ) ;
105+ // The call instance is based on the wrong polymorphied method.
106+ var first = call . Method . DeclaringType ? . GetMethods ( )
107+ . FirstOrDefault ( x => x . Name == call . Method . Name && x . GetParameters ( ) . Length == 1 ) ;
108+ if ( first == null )
109+ {
110+ // Unknown case that should not happen.
111+ return call ;
112+ }
113+
114+ return Expression . Call ( null , first . MakeGenericMethod ( _entityType . ClrType ) , _AddProjectableSelect ( where , _entityType ) ) ;
115+ }
116+
117+ // .First() without arguments is the same case as bellow so we let it fallthrough
118+ }
119+ else if ( ! call . Method . ReturnType . IsAssignableTo ( typeof ( IEnumerable ) ) )
120+ {
121+ // case of something like a .Max(), .Sum()
122+ return call ;
123+ }
124+
125+ // return type is IEnumerable<EntityType> or EntityType (in case of fallthrough from a .First())
126+
127+ // case of something like .ToList(), .ToArrayAsync()
56128 var self = _AddProjectableSelect ( call . Arguments . First ( ) , _entityType ) ;
57129 return call . Update ( null , call . Arguments . Skip ( 1 ) . Prepend ( self ) ) ;
58130 }
59- // Probably a foreach call
60131 case QueryRootExpression root :
61132 return _AddProjectableSelect ( root , root . EntityType ) ;
62133 default :
@@ -170,14 +241,7 @@ private Expression _AddProjectableSelect(Expression node, IEntityType entityType
170241 . Where ( x => projectableProperties . All ( y => x . Name != y . Name && x . Name != $ "<{ y . Name } >k__BackingField") ) ;
171242
172243 // Replace db.Entities to db.Entities.Select(x => new Entity { Property1 = x.Property1, Rewritted = rewrittedProperty })
173- var select = typeof ( Queryable ) . GetMethods ( BindingFlags . Static | BindingFlags . Public )
174- . Where ( x => x . Name == nameof ( Queryable . Select ) )
175- . First ( x =>
176- x . GetParameters ( ) . Last ( ) . ParameterType // Expression<Func<T, Ret>>
177- . GetGenericArguments ( ) . First ( ) // Func<T, Ret>
178- . GetGenericArguments ( ) . Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
179- )
180- . MakeGenericMethod ( entityType . ClrType , entityType . ClrType ) ;
244+ var select = _select . MakeGenericMethod ( entityType . ClrType , entityType . ClrType ) ;
181245 var xParam = Expression . Parameter ( entityType . ClrType ) ;
182246 return Expression . Call (
183247 null ,
0 commit comments