diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs index 06f9e9b9201..4f4d40306c4 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs @@ -93,7 +93,9 @@ public class SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) nameof(Math.Sign) when arguments is [var arg] && (arg.Type == typeof(decimal) || arg.Type == typeof(double) || arg.Type == typeof(float) || arg.Type == typeof(int) || arg.Type == typeof(long) || arg.Type == typeof(sbyte) || arg.Type == typeof(short)) - => TranslateFunction("SIGN", arg, nullTypeMapping: true), + // T-SQL SIGN returns the same type as its input, but Math.Sign always returns int; + // wrap with a CAST to avoid InvalidCastException at materialization time. + => TranslateSign(arg), nameof(double.DegreesToRadians) when arguments is [var arg] && (arg.Type == typeof(double) || arg.Type == typeof(float)) => TranslateFunction("RADIANS", arg), @@ -115,7 +117,27 @@ public class SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) _ => null }; - SqlExpression TranslateFunction(string sqlFunctionName, SqlExpression arg, bool nullTypeMapping = false) + SqlExpression TranslateSign(SqlExpression arg) + { + var typeMapping = ExpressionExtensions.InferTypeMapping(arg); + + // Use arg.Type (not method.ReturnType) so that the function's type mapping reflects the actual + // T-SQL SIGN return type (same as input). This prevents SqlExpressionSimplifyingExpressionVisitor + // from stripping the CAST as a no-op when both store types would otherwise be "int". + var signFunction = sqlExpressionFactory.Function( + "SIGN", + [sqlExpressionFactory.ApplyTypeMapping(arg, typeMapping)], + nullable: true, + argumentsPropagateNullability: Statics.TrueArrays[1], + arg.Type, + typeMapping); + + return arg.Type == typeof(int) + ? signFunction + : sqlExpressionFactory.Convert(signFunction, typeof(int)); + } + + SqlExpression TranslateFunction(string sqlFunctionName, SqlExpression arg) { var typeMapping = ExpressionExtensions.InferTypeMapping(arg); return sqlExpressionFactory.Function( @@ -124,7 +146,7 @@ SqlExpression TranslateFunction(string sqlFunctionName, SqlExpression arg, bool nullable: true, argumentsPropagateNullability: Statics.TrueArrays[1], method.ReturnType, - nullTypeMapping ? null : typeMapping); + typeMapping); } SqlExpression TranslateBinaryFunction(string sqlFunctionName, SqlExpression arg1, SqlExpression arg2) diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MathTranslationsCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MathTranslationsCosmosTest.cs index f65838697ca..0da472decac 100644 --- a/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MathTranslationsCosmosTest.cs +++ b/test/EFCore.Cosmos.FunctionalTests/Query/Translations/MathTranslationsCosmosTest.cs @@ -444,6 +444,45 @@ public override async Task Sign() SELECT VALUE c FROM root c WHERE (SIGN(c["Double"]) > 0) +""", + // + """ +SELECT VALUE SIGN(c["Double"]) +FROM root c +"""); + } + + public override async Task Sign_decimal() + { + await base.Sign_decimal(); + + AssertSql( + """ +SELECT VALUE c +FROM root c +WHERE (SIGN(c["Decimal"]) > 0) +""", + // + """ +SELECT VALUE SIGN(c["Decimal"]) +FROM root c +"""); + } + + public override async Task Sign_int() + { + await base.Sign_int(); + + AssertSql( + """ +SELECT VALUE c +FROM root c +WHERE (SIGN(c["Int"]) > 0) +""", + // + """ +SELECT VALUE SIGN(c["Int"]) +FROM root c """); } @@ -456,6 +495,11 @@ public override async Task Sign_float() SELECT VALUE c FROM root c WHERE (SIGN(c["Float"]) > 0) +""", + // + """ +SELECT VALUE SIGN(c["Float"]) +FROM root c """); } diff --git a/test/EFCore.Specification.Tests/Query/Translations/MathTranslationsTestBase.cs b/test/EFCore.Specification.Tests/Query/Translations/MathTranslationsTestBase.cs index f6b9e993b27..f9734e32594 100644 --- a/test/EFCore.Specification.Tests/Query/Translations/MathTranslationsTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/Translations/MathTranslationsTestBase.cs @@ -186,12 +186,36 @@ public virtual Task Sqrt_float() => AssertQuery(ss => ss.Set().Where(b => b.Float > 0 && MathF.Sqrt(b.Float) > 0)); [ConditionalFact] - public virtual Task Sign() - => AssertQuery(ss => ss.Set().Where(b => Math.Sign(b.Double) > 0)); + public virtual async Task Sign() + { + await AssertQuery(ss => ss.Set().Where(b => Math.Sign(b.Double) > 0)); + + await AssertQueryScalar(ss => ss.Set().Select(b => Math.Sign(b.Double))); + } + + [ConditionalFact] + public virtual async Task Sign_decimal() + { + await AssertQuery(ss => ss.Set().Where(b => Math.Sign(b.Decimal) > 0)); + + await AssertQueryScalar(ss => ss.Set().Select(b => Math.Sign(b.Decimal))); + } [ConditionalFact] - public virtual Task Sign_float() - => AssertQuery(ss => ss.Set().Where(b => MathF.Sign(b.Float) > 0)); + public virtual async Task Sign_int() + { + await AssertQuery(ss => ss.Set().Where(b => Math.Sign(b.Int) > 0)); + + await AssertQueryScalar(ss => ss.Set().Select(b => Math.Sign(b.Int))); + } + + [ConditionalFact] + public virtual async Task Sign_float() + { + await AssertQuery(ss => ss.Set().Where(b => MathF.Sign(b.Float) > 0)); + + await AssertQueryScalar(ss => ss.Set().Select(b => MathF.Sign(b.Float))); + } [ConditionalFact] public virtual Task Max() diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MathTranslationsSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MathTranslationsSqlServerTest.cs index 3a848d36b84..150ba2c1d5d 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MathTranslationsSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/Translations/MathTranslationsSqlServerTest.cs @@ -461,7 +461,46 @@ public override async Task Sign() """ SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] FROM [BasicTypesEntities] AS [b] -WHERE SIGN([b].[Double]) > 0 +WHERE CAST(SIGN([b].[Double]) AS int) > 0 +""", + // + """ +SELECT CAST(SIGN([b].[Double]) AS int) +FROM [BasicTypesEntities] AS [b] +"""); + } + + public override async Task Sign_decimal() + { + await base.Sign_decimal(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] +FROM [BasicTypesEntities] AS [b] +WHERE CAST(SIGN([b].[Decimal]) AS int) > 0 +""", + // + """ +SELECT CAST(SIGN([b].[Decimal]) AS int) +FROM [BasicTypesEntities] AS [b] +"""); + } + + public override async Task Sign_int() + { + await base.Sign_int(); + + AssertSql( + """ +SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] +FROM [BasicTypesEntities] AS [b] +WHERE SIGN([b].[Int]) > 0 +""", + // + """ +SELECT SIGN([b].[Int]) +FROM [BasicTypesEntities] AS [b] """); } @@ -473,7 +512,12 @@ public override async Task Sign_float() """ SELECT [b].[Id], [b].[Bool], [b].[Byte], [b].[ByteArray], [b].[DateOnly], [b].[DateTime], [b].[DateTimeOffset], [b].[Decimal], [b].[Double], [b].[Enum], [b].[FlagsEnum], [b].[Float], [b].[Guid], [b].[Int], [b].[Long], [b].[Short], [b].[String], [b].[TimeOnly], [b].[TimeSpan] FROM [BasicTypesEntities] AS [b] -WHERE SIGN([b].[Float]) > 0 +WHERE CAST(SIGN([b].[Float]) AS int) > 0 +""", + // + """ +SELECT CAST(SIGN([b].[Float]) AS int) +FROM [BasicTypesEntities] AS [b] """); } diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MathTranslationsSqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MathTranslationsSqliteTest.cs index 1c96938b8de..0fde43e7ad0 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MathTranslationsSqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/Translations/MathTranslationsSqliteTest.cs @@ -416,9 +416,20 @@ public override async Task Sign() SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan" FROM "BasicTypesEntities" AS "b" WHERE sign("b"."Double") > 0.0 +""", + // + """ +SELECT sign("b"."Double") +FROM "BasicTypesEntities" AS "b" """); } + public override async Task Sign_decimal() + => await AssertTranslationFailed(() => base.Sign_decimal()); // SQLite decimal support + + public override async Task Sign_int() + => await AssertTranslationFailed(() => base.Sign_int()); // SQLite int support + public override async Task Sign_float() { await base.Sign_float(); @@ -428,6 +439,11 @@ public override async Task Sign_float() SELECT "b"."Id", "b"."Bool", "b"."Byte", "b"."ByteArray", "b"."DateOnly", "b"."DateTime", "b"."DateTimeOffset", "b"."Decimal", "b"."Double", "b"."Enum", "b"."FlagsEnum", "b"."Float", "b"."Guid", "b"."Int", "b"."Long", "b"."Short", "b"."String", "b"."TimeOnly", "b"."TimeSpan" FROM "BasicTypesEntities" AS "b" WHERE sign("b"."Float") > 0 +""", + // + """ +SELECT sign("b"."Float") +FROM "BasicTypesEntities" AS "b" """); }