1- using System ;
2- using System . Collections . Generic ;
1+ using System . Collections . Generic ;
32using System . Diagnostics . CodeAnalysis ;
43using System . Linq ;
54using System . Linq . Expressions ;
65using System . Reflection ;
76using EntityFrameworkCore . Projectables . Extensions ;
7+ using Microsoft . EntityFrameworkCore . Metadata ;
88using Microsoft . EntityFrameworkCore . Query ;
99
1010namespace EntityFrameworkCore . Projectables . Services
@@ -13,11 +13,15 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor
1313 {
1414 readonly IProjectionExpressionResolver _resolver ;
1515 readonly ExpressionArgumentReplacer _expressionArgumentReplacer = new ( ) ;
16+ readonly QueryRootReplacer _queryRootReplacer ;
1617 readonly Dictionary < MemberInfo , LambdaExpression ? > _projectableMemberCache = new ( ) ;
18+ private bool _disableRootRewrite = false ;
19+ private IEntityType ? _entityType ;
1720
1821 public ProjectableExpressionReplacer ( IProjectionExpressionResolver projectionExpressionResolver )
1922 {
2023 _resolver = projectionExpressionResolver ;
24+ _queryRootReplacer = new ( _resolver ) ;
2125 }
2226
2327 bool TryGetReflectedExpression ( MemberInfo memberInfo , [ NotNullWhen ( true ) ] out LambdaExpression ? reflectedExpression )
@@ -36,11 +40,42 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
3640 return reflectedExpression is not null ;
3741 }
3842
43+ [ return : NotNullIfNotNull ( nameof ( node ) ) ]
44+ public override Expression ? Visit ( Expression ? node )
45+ {
46+ var ret = base . Visit ( node ) ;
47+
48+ if ( _disableRootRewrite )
49+ {
50+ return ret ;
51+ }
52+
53+ switch ( node )
54+ {
55+ // Probably a First() or ToList()
56+ case MethodCallExpression { Arguments . Count : > 0 } call when _entityType != null :
57+ {
58+ var self = _AddProjectableSelect ( call . Arguments . First ( ) , _entityType ) ;
59+ return call . Update ( null , call . Arguments . Skip ( 1 ) . Prepend ( self ) ) ;
60+ }
61+ // Probably a foreach call
62+ case QueryRootExpression root :
63+ return _AddProjectableSelect ( root , root . EntityType ) ;
64+ default :
65+ return ret ;
66+ }
67+ }
68+
3969 protected override Expression VisitMethodCall ( MethodCallExpression node )
4070 {
4171 // Get the overriding methodInfo based on te type of the received of this expression
4272 var methodInfo = node . Object ? . Type . GetConcreteMethod ( node . Method ) ?? node . Method ;
4373
74+ if ( methodInfo . Name == nameof ( Queryable . Select ) )
75+ {
76+ _disableRootRewrite = true ;
77+ }
78+
4479 if ( TryGetReflectedExpression ( methodInfo , out var reflectedExpression ) )
4580 {
4681 for ( var parameterIndex = 0 ; parameterIndex < reflectedExpression . Parameters . Count ; parameterIndex ++ )
@@ -110,12 +145,16 @@ PropertyInfo property when nodeExpression is not null
110145
111146 protected override Expression VisitExtension ( Expression node )
112147 {
113- if ( node is not QueryRootExpression root )
148+ if ( node is QueryRootExpression root )
114149 {
115- return node ;
150+ _entityType = root . EntityType ;
116151 }
152+ return base . VisitExtension ( node ) ;
153+ }
117154
118- var projectableProperties = root . EntityType . ClrType . GetProperties ( )
155+ private Expression _AddProjectableSelect ( Expression node , IEntityType entityType )
156+ {
157+ var projectableProperties = entityType . ClrType . GetProperties ( )
119158 . Where ( x => x . IsDefined ( typeof ( ProjectableAttribute ) , false ) )
120159 . Where ( x => x . CanWrite )
121160 . ToList ( ) ;
@@ -125,7 +164,7 @@ protected override Expression VisitExtension(Expression node)
125164 return node ;
126165 }
127166
128- var properties = root . EntityType . GetProperties ( )
167+ var properties = entityType . GetProperties ( )
129168 . Where ( x => ! x . IsShadowProperty ( ) )
130169 . Select ( x => x . GetMemberInfo ( false , false ) )
131170 // Remove projectable properties from the ef properties. Since properties returned here for auto
@@ -140,15 +179,15 @@ protected override Expression VisitExtension(Expression node)
140179 . GetGenericArguments ( ) . First ( ) // Func<T, Ret>
141180 . GetGenericArguments ( ) . Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
142181 )
143- . MakeGenericMethod ( root . EntityType . ClrType , root . EntityType . ClrType ) ;
144- var xParam = Expression . Parameter ( root . EntityType . ClrType ) ;
182+ . MakeGenericMethod ( entityType . ClrType , entityType . ClrType ) ;
183+ var xParam = Expression . Parameter ( entityType . ClrType ) ;
145184 return Expression . Call (
146185 null ,
147186 select ,
148187 node ,
149188 Expression . Lambda (
150189 Expression . MemberInit (
151- Expression . New ( root . EntityType . ClrType ) ,
190+ Expression . New ( entityType . ClrType ) ,
152191 properties . Select ( x => Expression . Bind ( x , Expression . MakeMemberAccess ( xParam , x ) ) )
153192 . Concat ( projectableProperties
154193 . Select ( x => Expression . Bind ( x , _ReplaceParam ( _resolver . FindGeneratedExpression ( x ) , xParam ) ) )
0 commit comments