Skip to content

Commit 1771c41

Browse files
authored
Merge pull request #57 from rhodon-jargon/interface-support
Add support for interfaces in generic methods
2 parents eb02f0a + 3b98b1e commit 1771c41

7 files changed

Lines changed: 155 additions & 66 deletions

src/EntityFrameworkCore.Projectables/Extensions/TypeExtensions.cs

Lines changed: 93 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,89 +25,128 @@ public static IEnumerable<Type> GetNestedTypePath(this Type type)
2525
yield return type;
2626
}
2727

28-
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
28+
private static bool CanHaveOverridingMethod(this Type derivedType, MethodInfo methodInfo)
2929
{
3030
// We only need to search for virtual instance methods who are not declared on the derivedType
3131
if (derivedType == methodInfo.DeclaringType || methodInfo.IsStatic || !methodInfo.IsVirtual)
3232
{
33-
return methodInfo;
33+
return false;
3434
}
3535

3636
if (!derivedType.IsAssignableTo(methodInfo.DeclaringType))
3737
{
3838
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
3939
}
4040

41-
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
41+
return true;
42+
}
4243

43-
foreach (var derivedMethodInfo in derivedMethods)
44+
private static int? GetOverridingMethodIndex(this MethodInfo methodInfo, MethodInfo[]? allDerivedMethods)
45+
{
46+
if (allDerivedMethods is { Length: > 0 })
4447
{
45-
if (HasCompatibleSignature(methodInfo, derivedMethodInfo))
48+
var baseDefinition = methodInfo.GetBaseDefinition();
49+
for (var i = 0; i < allDerivedMethods.Length; i++)
4650
{
47-
return derivedMethodInfo;
51+
var derivedMethodInfo = allDerivedMethods[i];
52+
if (derivedMethodInfo.GetBaseDefinition() == baseDefinition)
53+
{
54+
return i;
55+
}
4856
}
4957
}
5058

51-
// No derived methods were found. Return the original methodInfo
52-
return methodInfo;
59+
return null;
60+
}
5361

54-
static bool HasCompatibleSignature(MethodInfo methodInfo, MethodInfo derivedMethodInfo)
62+
public static MethodInfo GetOverridingMethod(this Type derivedType, MethodInfo methodInfo)
63+
{
64+
if (!derivedType.CanHaveOverridingMethod(methodInfo))
5565
{
56-
if (methodInfo.Name != derivedMethodInfo.Name)
57-
{
58-
return false;
59-
}
66+
return methodInfo;
67+
}
68+
69+
var derivedMethods = derivedType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
6070

61-
var methodParameters = methodInfo.GetParameters();
71+
return methodInfo.GetOverridingMethodIndex(derivedMethods) is { } i
72+
? derivedMethods[i]
73+
// No derived methods were found. Return the original methodInfo
74+
: methodInfo;
75+
}
6276

63-
var derivedMethodParameters = derivedMethodInfo.GetParameters();
64-
if (methodParameters.Length != derivedMethodParameters.Length)
65-
{
66-
return false;
67-
}
77+
public static PropertyInfo GetOverridingProperty(this Type derivedType, PropertyInfo propertyInfo)
78+
{
79+
var accessor = propertyInfo.GetAccessors()[0];
6880

69-
// Match all parameters
70-
for (var parameterIndex = 0; parameterIndex < methodParameters.Length; parameterIndex++)
71-
{
72-
var parameter = methodParameters[parameterIndex];
73-
var derivedParameter = derivedMethodParameters[parameterIndex];
81+
if (!derivedType.CanHaveOverridingMethod(accessor))
82+
{
83+
return propertyInfo;
84+
}
85+
86+
var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
87+
var derivedPropertyMethods = derivedProperties
88+
.Select((Func<PropertyInfo, MethodInfo?>)
89+
(propertyInfo.GetMethod == accessor ? p => p.GetMethod : p => p.SetMethod))
90+
.OfType<MethodInfo>().ToArray();
91+
92+
return accessor.GetOverridingMethodIndex(derivedPropertyMethods) is { } i
93+
? derivedProperties[i]
94+
// No derived methods were found. Return the original methodInfo
95+
: propertyInfo;
96+
}
7497

75-
if (parameter.ParameterType.IsGenericParameter)
76-
{
77-
if (!derivedParameter.ParameterType.IsGenericParameter)
78-
{
79-
return false;
80-
}
81-
}
82-
else
83-
{
84-
if (parameter.ParameterType != derivedParameter.ParameterType)
85-
{
86-
return false;
87-
}
88-
}
89-
}
98+
public static MethodInfo GetImplementingMethod(this Type derivedType, MethodInfo methodInfo)
99+
{
100+
var interfaceType = methodInfo.DeclaringType;
101+
// We only need to search for interface methods
102+
if (interfaceType?.IsInterface != true || derivedType.IsInterface || methodInfo.IsStatic || !methodInfo.IsVirtual)
103+
{
104+
return methodInfo;
105+
}
90106

91-
// Match the number of generic type arguments
92-
if (methodInfo.IsGenericMethodDefinition)
93-
{
94-
var methodGenericParameters = methodInfo.GetGenericArguments();
107+
if (!derivedType.IsAssignableTo(interfaceType))
108+
{
109+
throw new ArgumentException("MethodInfo needs to be declared on the type hierarchy", nameof(methodInfo));
110+
}
95111

96-
if (!derivedMethodInfo.IsGenericMethodDefinition)
97-
{
98-
return false;
99-
}
112+
var interfaceMap = derivedType.GetInterfaceMap(interfaceType);
113+
for (var i = 0; i < interfaceMap.InterfaceMethods.Length; i++)
114+
{
115+
if (interfaceMap.InterfaceMethods[i] == methodInfo)
116+
{
117+
return interfaceMap.TargetMethods[i];
118+
}
119+
}
100120

101-
var derivedGenericArguments = derivedMethodInfo.GetGenericArguments();
121+
throw new ApplicationException(
122+
$"The interface map for {derivedType} doesn't contain the implemented method for {methodInfo}!");
123+
}
102124

103-
if (methodGenericParameters.Length != derivedGenericArguments.Length)
104-
{
105-
return false;
106-
}
107-
}
125+
public static PropertyInfo GetImplementingProperty(this Type derivedType, PropertyInfo propertyInfo)
126+
{
127+
var accessor = propertyInfo.GetAccessors()[0];
108128

109-
return true;
129+
var implementingAccessor = derivedType.GetImplementingMethod(accessor);
130+
if (implementingAccessor == accessor)
131+
{
132+
return propertyInfo;
110133
}
134+
135+
var derivedProperties = derivedType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);
136+
137+
return derivedProperties.First(propertyInfo.GetMethod == accessor
138+
? p => p.GetMethod == implementingAccessor
139+
: p => p.SetMethod == implementingAccessor);
111140
}
141+
142+
public static MethodInfo GetConcreteMethod(this Type derivedType, MethodInfo methodInfo)
143+
=> methodInfo.DeclaringType?.IsInterface == true
144+
? derivedType.GetImplementingMethod(methodInfo)
145+
: derivedType.GetOverridingMethod(methodInfo);
146+
147+
public static PropertyInfo GetConcreteProperty(this Type derivedType, PropertyInfo propertyInfo)
148+
=> propertyInfo.DeclaringType?.IsInterface == true
149+
? derivedType.GetImplementingProperty(propertyInfo)
150+
: derivedType.GetOverridingProperty(propertyInfo);
112151
}
113152
}

src/EntityFrameworkCore.Projectables/Services/ProjectableExpressionReplacer.cs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
4242
protected override Expression VisitMethodCall(MethodCallExpression node)
4343
{
4444
// Get the overriding methodInfo based on te type of the received of this expression
45-
var methodInfo = node.Object?.Type.GetOverridingMethod(node.Method) ?? node.Method;
45+
var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method;
4646

4747
if (TryGetReflectedExpression(methodInfo, out var reflectedExpression))
4848
{
@@ -74,16 +74,25 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
7474

7575
protected override Expression VisitMember(MemberExpression node)
7676
{
77-
var nodeMember = node.Expression switch {
78-
{ Type: { } } => node.Expression.Type.GetMember(node.Member.Name, node.Member.MemberType, BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static)[0],
77+
var nodeExpression = node.Expression switch {
78+
UnaryExpression { NodeType: ExpressionType.Convert, Type: { IsInterface: true } type, Operand: { } operand }
79+
when type.IsAssignableFrom(operand.Type)
80+
// This is an interface member. Operand contains the concrete (or at least more concrete) expression,
81+
// from which we can try to find the concrete member.
82+
=> operand,
83+
_ => node.Expression
84+
};
85+
var nodeMember = node.Member switch {
86+
PropertyInfo property when nodeExpression is not null
87+
=> nodeExpression.Type.GetConcreteProperty(property),
7988
_ => node.Member
8089
};
8190

8291
if (TryGetReflectedExpression(nodeMember, out var reflectedExpression))
8392
{
84-
if (node.Expression is not null)
93+
if (nodeExpression is not null)
8594
{
86-
_expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], node.Expression);
95+
_expressionArgumentReplacer.ParameterArgumentMapping.Add(reflectedExpression.Parameters[0], nodeExpression);
8796
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
8897
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
8998

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
SELECT 2
2+
FROM [Concrete] AS [c]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
SELECT 2
2+
FROM [Concrete] AS [c]
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
SELECT 2
2-
FROM [Concrete] AS [c]
2+
FROM [MoreConcrete] AS [m]
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
SELECT 2
2-
FROM [Concrete] AS [c]
2+
FROM [MoreConcrete] AS [m]

tests/EntityFrameworkCore.Projectables.FunctionalTests/InheritedModelTests.cs

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,13 @@ namespace EntityFrameworkCore.Projectables.FunctionalTests
2020
[UsesVerify]
2121
public class InheritedModelTests
2222
{
23-
public abstract class Base
23+
public interface IBase
24+
{
25+
int ComputedProperty { get; }
26+
int ComputedMethod();
27+
}
28+
29+
public abstract class Base : IBase
2430
{
2531
public int Id { get; set; }
2632

@@ -62,9 +68,9 @@ public Task ProjectOverOverriddenPropertyImplementation()
6268
[Fact]
6369
public Task ProjectOverInheritedPropertyImplementation()
6470
{
65-
using var dbContext = new SampleDbContext<Concrete>();
71+
using var dbContext = new SampleDbContext<MoreConcrete>();
6672

67-
var query = dbContext.Set<Concrete>()
73+
var query = dbContext.Set<MoreConcrete>()
6874
.Select(x => x.ComputedProperty);
6975

7076
return Verifier.Verify(query.ToQueryString());
@@ -84,12 +90,43 @@ public Task ProjectOverOverriddenMethodImplementation()
8490
[Fact]
8591
public Task ProjectOverInheritedMethodImplementation()
8692
{
87-
using var dbContext = new SampleDbContext<Concrete>();
93+
using var dbContext = new SampleDbContext<MoreConcrete>();
8894

89-
var query = dbContext.Set<Concrete>()
95+
var query = dbContext.Set<MoreConcrete>()
9096
.Select(x => x.ComputedMethod());
9197

9298
return Verifier.Verify(query.ToQueryString());
9399
}
100+
101+
[Fact]
102+
public Task ProjectOverImplementedProperty()
103+
{
104+
using var dbContext = new SampleDbContext<Concrete>();
105+
106+
var query = dbContext.Set<Concrete>().SelectComputedProperty();
107+
108+
return Verifier.Verify(query.ToQueryString());
109+
}
110+
111+
[Fact]
112+
public Task ProjectOverImplementedMethod()
113+
{
114+
using var dbContext = new SampleDbContext<Concrete>();
115+
116+
var query = dbContext.Set<Concrete>().SelectComputedMethod();
117+
118+
return Verifier.Verify(query.ToQueryString());
119+
}
120+
}
121+
122+
public static class ModelExtensions
123+
{
124+
public static IQueryable<int> SelectComputedProperty<TConcrete>(this IQueryable<TConcrete> concretes)
125+
where TConcrete : InheritedModelTests.IBase
126+
=> concretes.Select(x => x.ComputedProperty);
127+
128+
public static IQueryable<int> SelectComputedMethod<TConcrete>(this IQueryable<TConcrete> concretes)
129+
where TConcrete : InheritedModelTests.IBase
130+
=> concretes.Select(x => x.ComputedMethod());
94131
}
95132
}

0 commit comments

Comments
 (0)