From 3ccdef962e0fa17b6b4df49359edf6dff6d9fb59 Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Fri, 5 Jun 2026 09:59:51 -0700 Subject: [PATCH 01/12] DAB Phase 4 Implementation --- schemas/dab.draft.schema.json | 61 ++ src/Cli/Commands/AddOptions.cs | 32 +- src/Cli/Commands/ConfigureOptions.cs | 10 + src/Cli/Commands/EntityOptions.cs | 35 ++ src/Cli/Commands/UpdateOptions.cs | 32 +- src/Cli/ConfigGenerator.cs | 59 +- src/Cli/Utils.cs | 88 +++ src/Config/ObjectModel/Entity.cs | 6 +- .../EntitySemanticSearchOptions.cs | 39 ++ .../Configurations/RuntimeConfigValidator.cs | 147 +++++ .../RestRequestContexts/RestRequestContext.cs | 25 + src/Core/Models/SemanticSearchCandidate.cs | 13 + src/Core/Models/SemanticSearchConstants.cs | 15 + src/Core/Parsers/RequestParser.cs | 33 ++ .../Resolvers/Factories/QueryEngineFactory.cs | 7 +- .../Sql Query Structures/SqlQueryStructure.cs | 108 ++++ src/Core/Resolvers/SqlQueryEngine.cs | 257 ++++++++- .../Services/OpenAPI/OpenApiDocumentor.cs | 38 +- src/Core/Services/RequestValidator.cs | 21 +- src/Core/Services/RestService.cs | 72 +++ .../SemanticSearch/ISemanticSearchService.cs | 18 + .../NoOpSemanticSearchService.cs | 30 + .../Queries/InputTypeBuilder.cs | 11 + .../Queries/QueryBuilder.cs | 16 +- .../Sql/SchemaConverter.cs | 14 + .../MultipleMutationBuilderTests.cs | 2 + .../GraphQLBuilder/QueryBuilderTests.cs | 48 ++ .../UnitTests/ConfigValidationUnitTests.cs | 112 ++++ .../UnitTests/RequestParserUnitTests.cs | 8 +- .../UnitTests/RequestValidatorUnitTests.cs | 53 ++ .../UnitTests/SqlQueryExecutorUnitTests.cs | 8 +- .../RedisSemanticSearchService.cs | 538 ++++++++++++++++++ src/Service/Startup.cs | 4 + 33 files changed, 1920 insertions(+), 40 deletions(-) create mode 100644 src/Config/ObjectModel/EntitySemanticSearchOptions.cs create mode 100644 src/Core/Models/SemanticSearchCandidate.cs create mode 100644 src/Core/Models/SemanticSearchConstants.cs create mode 100644 src/Core/Services/SemanticSearch/ISemanticSearchService.cs create 
mode 100644 src/Core/Services/SemanticSearch/NoOpSemanticSearchService.cs create mode 100644 src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index 84f8e5cfbd..185d1b3903 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -1245,6 +1245,67 @@ } } }, + "semantic-search": { + "type": "object", + "description": "Semantic search configuration for this entity.", + "additionalProperties": false, + "properties": { + "enabled": { + "$ref": "#/$defs/boolean-or-string", + "description": "Enables semantic search for this entity.", + "default": false + }, + "redis-index-name": { + "type": "string", + "description": "Name of the Redis vector index used for semantic search for this entity." + }, + "redis-index-type": { + "type": "string", + "description": "Redis index storage type used by the semantic search index.", + "enum": ["hash", "json"], + "default": "hash" + }, + "redis-index-multiplier": { + "type": "integer", + "description": "Multiplier applied to requested result count when querying Redis.", + "default": 2, + "minimum": 1, + "maximum": 10 + }, + "similarity-threshold": { + "type": "number", + "description": "Minimum Redis similarity value required for a semantic match.", + "default": 0.8, + "minimum": 0.0, + "maximum": 1.0 + }, + "input-description": { + "type": "string", + "description": "Description surfaced in API metadata for semantic search input.", + "default": "Natural language value used for semantic search." + }, + "output-description": { + "type": "string", + "description": "Description surfaced in API metadata for semantic distance output.", + "default": "Semantic distance score returned by semantic search." 
+ } + }, + "allOf": [ + { + "if": { + "properties": { + "enabled": { + "const": true + } + }, + "required": ["enabled"] + }, + "then": { + "required": ["redis-index-name"] + } + } + ] + }, "mcp": { "oneOf": [ { diff --git a/src/Cli/Commands/AddOptions.cs b/src/Cli/Commands/AddOptions.cs index e7e378d94b..6fee8a834a 100644 --- a/src/Cli/Commands/AddOptions.cs +++ b/src/Cli/Commands/AddOptions.cs @@ -34,15 +34,22 @@ public AddOptions( string? policyDatabase, string? cacheEnabled, string? cacheTtl, - string? description, - IEnumerable? parametersNameCollection, - IEnumerable? parametersDescriptionCollection, - IEnumerable? parametersRequiredCollection, - IEnumerable? parametersDefaultCollection, - IEnumerable? fieldsNameCollection, - IEnumerable? fieldsAliasCollection, - IEnumerable? fieldsDescriptionCollection, - IEnumerable? fieldsPrimaryKeyCollection, + string? semanticSearchEnabled = null, + string? semanticSearchRedisIndexName = null, + string? semanticSearchRedisIndexType = null, + string? semanticSearchRedisIndexMultiplier = null, + string? semanticSearchSimilarityThreshold = null, + string? semanticSearchInputDescription = null, + string? semanticSearchOutputDescription = null, + string? description = null, + IEnumerable? parametersNameCollection = null, + IEnumerable? parametersDescriptionCollection = null, + IEnumerable? parametersRequiredCollection = null, + IEnumerable? parametersDefaultCollection = null, + IEnumerable? fieldsNameCollection = null, + IEnumerable? fieldsAliasCollection = null, + IEnumerable? fieldsDescriptionCollection = null, + IEnumerable? fieldsPrimaryKeyCollection = null, string? mcpDmlTools = null, string? mcpCustomTool = null, string? 
config = null @@ -62,6 +69,13 @@ public AddOptions( policyDatabase, cacheEnabled, cacheTtl, + semanticSearchEnabled, + semanticSearchRedisIndexName, + semanticSearchRedisIndexType, + semanticSearchRedisIndexMultiplier, + semanticSearchSimilarityThreshold, + semanticSearchInputDescription, + semanticSearchOutputDescription, description, parametersNameCollection, parametersDescriptionCollection, diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index be076d7983..8fd4d5bc42 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -53,6 +53,8 @@ public ConfigureOptions( int? runtimeMcpDmlToolsAggregateRecordsQueryTimeout = null, bool? runtimeCacheEnabled = null, int? runtimeCacheTtl = null, + string? runtimeCacheLevel2Provider = null, + string? runtimeCacheLevel2ConnectionString = null, CompressionLevel? runtimeCompressionLevel = null, HostMode? runtimeHostMode = null, IEnumerable? runtimeHostCorsOrigins = null, @@ -117,6 +119,8 @@ public ConfigureOptions( // Cache RuntimeCacheEnabled = runtimeCacheEnabled; RuntimeCacheTTL = runtimeCacheTtl; + RuntimeCacheLevel2Provider = runtimeCacheLevel2Provider; + RuntimeCacheLevel2ConnectionString = runtimeCacheLevel2ConnectionString; // Compression RuntimeCompressionLevel = runtimeCompressionLevel; // Host @@ -242,6 +246,12 @@ public ConfigureOptions( [Option("runtime.cache.ttl-seconds", Required = false, HelpText = "Customize the DAB cache's global default time to live in seconds. Default: 5 seconds (Integer).")] public int? RuntimeCacheTTL { get; } + [Option("runtime.cache.level-2.provider", Required = false, HelpText = "Set level-2 cache provider. Allowed values: redis.")] + public string? RuntimeCacheLevel2Provider { get; } + + [Option("runtime.cache.level-2.connection-string", Required = false, HelpText = "Set level-2 cache connection string.")] + public string? 
RuntimeCacheLevel2ConnectionString { get; } + [Option("runtime.compression.level", Required = false, HelpText = "Set the response compression level. Allowed values: optimal (default), fastest, none.")] public CompressionLevel? RuntimeCompressionLevel { get; } diff --git a/src/Cli/Commands/EntityOptions.cs b/src/Cli/Commands/EntityOptions.cs index 3b2b77d9b2..66b534bd6e 100644 --- a/src/Cli/Commands/EntityOptions.cs +++ b/src/Cli/Commands/EntityOptions.cs @@ -25,6 +25,13 @@ public EntityOptions( string? policyDatabase, string? cacheEnabled, string? cacheTtl, + string? semanticSearchEnabled, + string? semanticSearchRedisIndexName, + string? semanticSearchRedisIndexType, + string? semanticSearchRedisIndexMultiplier, + string? semanticSearchSimilarityThreshold, + string? semanticSearchInputDescription, + string? semanticSearchOutputDescription, string? description, IEnumerable? parametersNameCollection, IEnumerable? parametersDescriptionCollection, @@ -54,6 +61,13 @@ public EntityOptions( PolicyDatabase = policyDatabase; CacheEnabled = cacheEnabled; CacheTtl = cacheTtl; + SemanticSearchEnabled = semanticSearchEnabled; + SemanticSearchRedisIndexName = semanticSearchRedisIndexName; + SemanticSearchRedisIndexType = semanticSearchRedisIndexType; + SemanticSearchRedisIndexMultiplier = semanticSearchRedisIndexMultiplier; + SemanticSearchSimilarityThreshold = semanticSearchSimilarityThreshold; + SemanticSearchInputDescription = semanticSearchInputDescription; + SemanticSearchOutputDescription = semanticSearchOutputDescription; Description = description; ParametersNameCollection = parametersNameCollection; ParametersDescriptionCollection = parametersDescriptionCollection; @@ -110,6 +124,27 @@ public EntityOptions( [Option("cache.ttl", Required = false, HelpText = "Specify time to live in seconds for cache entries for Entity.")] public string? CacheTtl { get; } + [Option("semantic-search.enabled", Required = false, HelpText = "Enable semantic search for this entity. 
Accepted values are true/false.")] + public string? SemanticSearchEnabled { get; } + + [Option("semantic-search.redis-index-name", Required = false, HelpText = "Name of the Redis vector index to use for semantic search.")] + public string? SemanticSearchRedisIndexName { get; } + + [Option("semantic-search.redis-index-type", Required = false, HelpText = "Redis index type for semantic search. Allowed values: hash, json.")] + public string? SemanticSearchRedisIndexType { get; } + + [Option("semantic-search.redis-index-multiplier", Required = false, HelpText = "Multiplier for Redis candidate retrieval. Allowed values: 1-10.")] + public string? SemanticSearchRedisIndexMultiplier { get; } + + [Option("semantic-search.similarity-threshold", Required = false, HelpText = "Default semantic similarity threshold. Allowed values: 0.0-1.0.")] + public string? SemanticSearchSimilarityThreshold { get; } + + [Option("semantic-search.input-description", Required = false, HelpText = "Description for semantic search input metadata.")] + public string? SemanticSearchInputDescription { get; } + + [Option("semantic-search.output-description", Required = false, HelpText = "Description for semantic distance output metadata.")] + public string? SemanticSearchOutputDescription { get; } + [Option("description", Required = false, HelpText = "Description of the entity.")] public string? Description { get; } diff --git a/src/Cli/Commands/UpdateOptions.cs b/src/Cli/Commands/UpdateOptions.cs index 050afa2ddb..573c40baa6 100644 --- a/src/Cli/Commands/UpdateOptions.cs +++ b/src/Cli/Commands/UpdateOptions.cs @@ -42,15 +42,22 @@ public UpdateOptions( string? policyDatabase, string? cacheEnabled, string? cacheTtl, - string? description, - IEnumerable? parametersNameCollection, - IEnumerable? parametersDescriptionCollection, - IEnumerable? parametersRequiredCollection, - IEnumerable? parametersDefaultCollection, - IEnumerable? fieldsNameCollection, - IEnumerable? fieldsAliasCollection, - IEnumerable? 
fieldsDescriptionCollection, - IEnumerable? fieldsPrimaryKeyCollection, + string? semanticSearchEnabled = null, + string? semanticSearchRedisIndexName = null, + string? semanticSearchRedisIndexType = null, + string? semanticSearchRedisIndexMultiplier = null, + string? semanticSearchSimilarityThreshold = null, + string? semanticSearchInputDescription = null, + string? semanticSearchOutputDescription = null, + string? description = null, + IEnumerable? parametersNameCollection = null, + IEnumerable? parametersDescriptionCollection = null, + IEnumerable? parametersRequiredCollection = null, + IEnumerable? parametersDefaultCollection = null, + IEnumerable? fieldsNameCollection = null, + IEnumerable? fieldsAliasCollection = null, + IEnumerable? fieldsDescriptionCollection = null, + IEnumerable? fieldsPrimaryKeyCollection = null, string? mcpDmlTools = null, string? mcpCustomTool = null, string? config = null) @@ -68,6 +75,13 @@ public UpdateOptions( policyDatabase, cacheEnabled, cacheTtl, + semanticSearchEnabled, + semanticSearchRedisIndexName, + semanticSearchRedisIndexType, + semanticSearchRedisIndexMultiplier, + semanticSearchSimilarityThreshold, + semanticSearchInputDescription, + semanticSearchOutputDescription, description, parametersNameCollection, parametersDescriptionCollection, diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index 8cec9dd239..e450cc9ad2 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -465,6 +465,14 @@ public static bool TryAddNewEntity(AddOptions options, RuntimeConfig initialRunt EntityRestOptions restOptions = ConstructRestOptions(options.RestRoute, SupportedRestMethods, initialRuntimeConfig.DataSource.DatabaseType == DatabaseType.CosmosDB_NoSQL); EntityGraphQLOptions graphqlOptions = ConstructGraphQLTypeDetails(options.GraphQLType, graphQLOperationsForStoredProcedures); EntityCacheOptions? 
cacheOptions = ConstructCacheOptions(options.CacheEnabled, options.CacheTtl); + EntitySemanticSearchOptions? semanticSearchOptions = ConstructSemanticSearchOptions( + options.SemanticSearchEnabled, + options.SemanticSearchRedisIndexName, + options.SemanticSearchRedisIndexType, + options.SemanticSearchRedisIndexMultiplier, + options.SemanticSearchSimilarityThreshold, + options.SemanticSearchInputDescription, + options.SemanticSearchOutputDescription); EntityMcpOptions? mcpOptions = null; if (options.McpDmlTools is not null || options.McpCustomTool is not null) @@ -488,6 +496,7 @@ public static bool TryAddNewEntity(AddOptions options, RuntimeConfig initialRunt Relationships: null, Mappings: null, Cache: cacheOptions, + SemanticSearch: semanticSearchOptions, Description: string.IsNullOrWhiteSpace(options.Description) ? null : options.Description, Mcp: mcpOptions); @@ -974,7 +983,9 @@ private static bool TryUpdateConfiguredRuntimeOptions( // Cache: Enabled and TTL if (options.RuntimeCacheEnabled != null || - options.RuntimeCacheTTL != null) + options.RuntimeCacheTTL != null || + options.RuntimeCacheLevel2Provider != null || + options.RuntimeCacheLevel2ConnectionString != null) { RuntimeCacheOptions? updatedCacheOptions = runtimeConfig?.Runtime?.Cache ?? new(); bool status = TryUpdateConfiguredCacheValues(options, ref updatedCacheOptions); @@ -1388,6 +1399,43 @@ private static bool TryUpdateConfiguredCacheValues( } } + // Runtime.Cache.level-2.provider + updatedValue = options?.RuntimeCacheLevel2Provider; + if (updatedValue != null) + { + string provider = ((string)updatedValue).Trim(); + if (!string.Equals(provider, "redis", StringComparison.OrdinalIgnoreCase)) + { + _logger.LogError("Failed to update Runtime.Cache.level-2.provider as '{updatedValue}'. Supported value is 'redis'.", updatedValue); + return false; + } + + RuntimeCacheLevel2Options currentLevel2 = updatedCacheOptions?.Level2 ?? new(); + updatedCacheOptions = updatedCacheOptions! 
with + { + Level2 = currentLevel2 with + { + Provider = provider.ToLowerInvariant() + } + }; + _logger.LogInformation("Updated RuntimeConfig with Runtime.Cache.level-2.provider as '{updatedValue}'", updatedValue); + } + + // Runtime.Cache.level-2.connection-string + updatedValue = options?.RuntimeCacheLevel2ConnectionString; + if (updatedValue != null) + { + RuntimeCacheLevel2Options currentLevel2 = updatedCacheOptions?.Level2 ?? new(); + updatedCacheOptions = updatedCacheOptions! with + { + Level2 = currentLevel2 with + { + ConnectionString = (string)updatedValue + } + }; + _logger.LogInformation("Updated RuntimeConfig with Runtime.Cache.level-2.connection-string."); + } + return true; } catch (Exception ex) @@ -1851,6 +1899,14 @@ public static bool TryUpdateExistingEntity(UpdateOptions options, RuntimeConfig EntityActionPolicy? updatedPolicy = GetPolicyForOperation(options.PolicyRequest, options.PolicyDatabase); EntityActionFields? updatedFields = GetFieldsForOperation(options.FieldsToInclude, options.FieldsToExclude); EntityCacheOptions? updatedCacheOptions = ConstructCacheOptions(options.CacheEnabled, options.CacheTtl); + EntitySemanticSearchOptions? updatedSemanticSearchOptions = ConstructSemanticSearchOptions( + options.SemanticSearchEnabled, + options.SemanticSearchRedisIndexName, + options.SemanticSearchRedisIndexType, + options.SemanticSearchRedisIndexMultiplier, + options.SemanticSearchSimilarityThreshold, + options.SemanticSearchInputDescription, + options.SemanticSearchOutputDescription); // Determine if the entity is or will be a stored procedure bool isStoredProcedureAfterUpdate = doOptionsRepresentStoredProcedure || (isCurrentEntityStoredProcedure && options.SourceType is null); @@ -2112,6 +2168,7 @@ public static bool TryUpdateExistingEntity(UpdateOptions options, RuntimeConfig Relationships: updatedRelationships, Mappings: updatedMappings, Cache: updatedCacheOptions, + SemanticSearch: updatedSemanticSearchOptions ?? 
entity.SemanticSearch, Description: string.IsNullOrWhiteSpace(options.Description) ? entity.Description : options.Description, Mcp: updatedMcpOptions ); diff --git a/src/Cli/Utils.cs b/src/Cli/Utils.cs index c1ff7f2a99..40ff2a8c75 100644 --- a/src/Cli/Utils.cs +++ b/src/Cli/Utils.cs @@ -893,6 +893,94 @@ public static EntityGraphQLOptions ConstructGraphQLTypeDetails(string? graphQL, return cacheOptions with { Enabled = isEnabled, TtlSeconds = ttl, UserProvidedTtlOptions = isCacheTtlUserProvided }; } + /// + /// Constructs the EntitySemanticSearchOptions for Add/Update. + /// + /// EntitySemanticSearchOptions when at least one semantic option is provided, null otherwise. + public static EntitySemanticSearchOptions? ConstructSemanticSearchOptions( + string? semanticSearchEnabled, + string? semanticSearchRedisIndexName, + string? semanticSearchRedisIndexType, + string? semanticSearchRedisIndexMultiplier, + string? semanticSearchSimilarityThreshold, + string? semanticSearchInputDescription, + string? semanticSearchOutputDescription) + { + if (semanticSearchEnabled is null + && semanticSearchRedisIndexName is null + && semanticSearchRedisIndexType is null + && semanticSearchRedisIndexMultiplier is null + && semanticSearchSimilarityThreshold is null + && semanticSearchInputDescription is null + && semanticSearchOutputDescription is null) + { + return null; + } + + bool enabled = false; + if (semanticSearchEnabled is not null && !bool.TryParse(semanticSearchEnabled, out enabled)) + { + _logger.LogError("Invalid format for --semantic-search.enabled. 
Accepted values are true/false."); + return null; + } + + string redisIndexType = EntitySemanticSearchOptions.DEFAULT_REDIS_INDEX_TYPE; + if (!string.IsNullOrWhiteSpace(semanticSearchRedisIndexType)) + { + if (!string.Equals(semanticSearchRedisIndexType, "hash", StringComparison.OrdinalIgnoreCase) + && !string.Equals(semanticSearchRedisIndexType, "json", StringComparison.OrdinalIgnoreCase)) + { + _logger.LogError("Invalid format for --semantic-search.redis-index-type. Accepted values are hash/json."); + return null; + } + + redisIndexType = semanticSearchRedisIndexType.ToLowerInvariant(); + } + + int redisIndexMultiplier = EntitySemanticSearchOptions.DEFAULT_REDIS_INDEX_MULTIPLIER; + if (!string.IsNullOrWhiteSpace(semanticSearchRedisIndexMultiplier) + && !int.TryParse(semanticSearchRedisIndexMultiplier, out redisIndexMultiplier)) + { + _logger.LogError("Invalid format for --semantic-search.redis-index-multiplier. Accepted values are integer values in range [1,10]."); + return null; + } + + if (redisIndexMultiplier < 1 || redisIndexMultiplier > 10) + { + _logger.LogError("Invalid value for --semantic-search.redis-index-multiplier. Accepted values are integer values in range [1,10]."); + return null; + } + + double similarityThreshold = EntitySemanticSearchOptions.DEFAULT_SIMILARITY_THRESHOLD; + if (!string.IsNullOrWhiteSpace(semanticSearchSimilarityThreshold) + && !double.TryParse(semanticSearchSimilarityThreshold, out similarityThreshold)) + { + _logger.LogError("Invalid format for --semantic-search.similarity-threshold. Accepted values are decimal values in range [0.0,1.0]."); + return null; + } + + if (similarityThreshold < 0.0 || similarityThreshold > 1.0) + { + _logger.LogError("Invalid value for --semantic-search.similarity-threshold. 
Accepted values are decimal values in range [0.0,1.0]."); + return null; + } + + return new EntitySemanticSearchOptions + { + Enabled = enabled, + RedisIndexName = semanticSearchRedisIndexName, + RedisIndexType = redisIndexType, + RedisIndexMultiplier = redisIndexMultiplier, + SimilarityThreshold = similarityThreshold, + InputDescription = string.IsNullOrWhiteSpace(semanticSearchInputDescription) + ? EntitySemanticSearchOptions.DEFAULT_INPUT_DESCRIPTION + : semanticSearchInputDescription, + OutputDescription = string.IsNullOrWhiteSpace(semanticSearchOutputDescription) + ? EntitySemanticSearchOptions.DEFAULT_OUTPUT_DESCRIPTION + : semanticSearchOutputDescription + }; + } + /// /// Constructs the EntityMcpOptions for Add/Update. /// diff --git a/src/Config/ObjectModel/Entity.cs b/src/Config/ObjectModel/Entity.cs index 9c9ba17f2c..637c219744 100644 --- a/src/Config/ObjectModel/Entity.cs +++ b/src/Config/ObjectModel/Entity.cs @@ -37,6 +37,8 @@ public record Entity public Dictionary? Mappings { get; init; } public Dictionary? Relationships { get; init; } public EntityCacheOptions? Cache { get; init; } + [JsonPropertyName("semantic-search")] + public EntitySemanticSearchOptions? SemanticSearch { get; init; } public EntityHealthCheckConfig? Health { get; init; } [JsonConverter(typeof(EntityMcpOptionsConverterFactory))] @@ -62,7 +64,8 @@ public Entity( EntityHealthCheckConfig? Health = null, string? Description = null, EntityMcpOptions? Mcp = null, - bool IsAutoentity = false) + bool IsAutoentity = false, + EntitySemanticSearchOptions? 
SemanticSearch = null) { this.Health = Health; this.Source = Source; @@ -73,6 +76,7 @@ public Entity( this.Mappings = Mappings; this.Relationships = Relationships; this.Cache = Cache; + this.SemanticSearch = SemanticSearch; this.IsLinkingEntity = IsLinkingEntity; this.Description = Description; this.Mcp = Mcp; diff --git a/src/Config/ObjectModel/EntitySemanticSearchOptions.cs b/src/Config/ObjectModel/EntitySemanticSearchOptions.cs new file mode 100644 index 0000000000..1175330501 --- /dev/null +++ b/src/Config/ObjectModel/EntitySemanticSearchOptions.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel; + +/// +/// Entity specific semantic search configuration. +/// +public record EntitySemanticSearchOptions +{ + public const int DEFAULT_REDIS_INDEX_MULTIPLIER = 2; + public const double DEFAULT_SIMILARITY_THRESHOLD = 0.8; + public const string DEFAULT_REDIS_INDEX_TYPE = "hash"; + public const string DEFAULT_INPUT_DESCRIPTION = "Natural language value used for semantic search."; + public const string DEFAULT_OUTPUT_DESCRIPTION = "Semantic distance score returned by semantic search."; + + [JsonPropertyName("enabled")] + public bool Enabled { get; init; } = false; + + [JsonPropertyName("redis-index-name")] + public string? 
RedisIndexName { get; init; } + + [JsonPropertyName("redis-index-type")] + public string RedisIndexType { get; init; } = DEFAULT_REDIS_INDEX_TYPE; + + [JsonPropertyName("redis-index-multiplier")] + public int RedisIndexMultiplier { get; init; } = DEFAULT_REDIS_INDEX_MULTIPLIER; + + [JsonPropertyName("similarity-threshold")] + public double SimilarityThreshold { get; init; } = DEFAULT_SIMILARITY_THRESHOLD; + + [JsonPropertyName("input-description")] + public string InputDescription { get; init; } = DEFAULT_INPUT_DESCRIPTION; + + [JsonPropertyName("output-description")] + public string OutputDescription { get; init; } = DEFAULT_OUTPUT_DESCRIPTION; +} diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs index c8d86e8e11..7283fe09e4 100644 --- a/src/Core/Configurations/RuntimeConfigValidator.cs +++ b/src/Core/Configurations/RuntimeConfigValidator.cs @@ -65,6 +65,23 @@ public class RuntimeConfigValidator : IConfigValidator // Error messages. 
public const string INVALID_CLAIMS_IN_POLICY_ERR_MSG = "One or more claim types supplied in the database policy are not supported."; + private const string SEMANTIC_SEARCH_REDIS_REQUIREMENT_ERR_MSG = + "Semantic search requires runtime.cache.level-2.provider to be 'redis' and runtime.cache.level-2.connection-string to be configured."; + + private static readonly HashSet _reservedSemanticRestNames = + [ + "semantic_search", + "semantic_threshold", + "semantic_distance" + ]; + + private static readonly HashSet _reservedSemanticGraphQLNames = + [ + "semanticSearch", + "semanticThreshold", + "semanticDistance" + ]; + public RuntimeConfigValidator( RuntimeConfigProvider runtimeConfigProvider, IFileSystem fileSystem, @@ -706,6 +723,136 @@ public void ValidateEntityConfiguration(RuntimeConfig runtimeConfig) ValidateNameRequirements(entity.GraphQL.Singular); ValidateNameRequirements(entity.GraphQL.Plural); } + + ValidateSemanticSearchConfiguration(runtimeConfig, entityName, entity); + } + } + + /// + /// Validates semantic-search settings configured for an entity. + /// + private void ValidateSemanticSearchConfiguration(RuntimeConfig runtimeConfig, string entityName, Entity entity) + { + EntitySemanticSearchOptions? 
semantic = entity.SemanticSearch; + + if (semantic is null || !semantic.Enabled) + { + return; + } + + if (entity.Source.Type is not EntitySourceType.Table and not EntitySourceType.View) + { + HandleOrRecordException(new DataApiBuilderException( + message: "Semantic search is only supported for entities with source type 'table' or 'view'.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + if (string.IsNullOrWhiteSpace(semantic.RedisIndexName)) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"semantic-search.redis-index-name is required when semantic-search.enabled is true for entity '{entityName}'.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + if (!HasValidSemanticRedisConfiguration(runtimeConfig)) + { + HandleOrRecordException(new DataApiBuilderException( + message: SEMANTIC_SEARCH_REDIS_REQUIREMENT_ERR_MSG, + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + if (semantic.RedisIndexMultiplier < 1 || semantic.RedisIndexMultiplier > 10) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"semantic-search.redis-index-multiplier for entity '{entityName}' must be an integer between 1 and 10.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + if (semantic.SimilarityThreshold < 0.0 || semantic.SimilarityThreshold > 1.0) + { + HandleOrRecordException(new DataApiBuilderException( + message: "semantic_threshold must be a decimal value between 0.0 and 1.0.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + ValidateSemanticReservedNames(entityName, entity); + } + + /// + /// Validates Redis 
provider/connection-string prerequisites for semantic-search. + /// + private static bool HasValidSemanticRedisConfiguration(RuntimeConfig runtimeConfig) + { + RuntimeCacheLevel2Options? level2 = runtimeConfig.Runtime?.Cache?.Level2; + return level2 is not null + && string.Equals(level2.Provider, EntityCacheOptions.L2_CACHE_PROVIDER, StringComparison.OrdinalIgnoreCase) + && !string.IsNullOrWhiteSpace(level2.ConnectionString); + } + + /// + /// Ensures configured field names do not collide with semantic reserved names. + /// + private void ValidateSemanticReservedNames(string entityName, Entity entity) + { + HashSet exposedFieldNames = new(StringComparer.OrdinalIgnoreCase); + + if (entity.Fields is not null) + { + foreach (FieldMetadata field in entity.Fields) + { + if (!string.IsNullOrWhiteSpace(field.Name)) + { + exposedFieldNames.Add(field.Name); + } + + if (!string.IsNullOrWhiteSpace(field.Alias)) + { + exposedFieldNames.Add(field.Alias); + } + } + } + + if (entity.Mappings is not null) + { + foreach (KeyValuePair mapping in entity.Mappings) + { + if (!string.IsNullOrWhiteSpace(mapping.Key)) + { + exposedFieldNames.Add(mapping.Key); + } + + if (!string.IsNullOrWhiteSpace(mapping.Value)) + { + exposedFieldNames.Add(mapping.Value); + } + } + } + + foreach (string reserved in _reservedSemanticRestNames) + { + if (exposedFieldNames.Contains(reserved)) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"Entity '{entityName}' cannot enable semantic search because field name '{reserved}' is reserved.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + } + + foreach (string reserved in _reservedSemanticGraphQLNames) + { + if (exposedFieldNames.Contains(reserved)) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"Entity '{entityName}' cannot enable semantic search because field name '{reserved}' is reserved.", + statusCode: 
HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } } } diff --git a/src/Core/Models/RestRequestContexts/RestRequestContext.cs b/src/Core/Models/RestRequestContexts/RestRequestContext.cs index e9987730a0..716f9cd96b 100644 --- a/src/Core/Models/RestRequestContexts/RestRequestContext.cs +++ b/src/Core/Models/RestRequestContexts/RestRequestContext.cs @@ -95,6 +95,31 @@ protected RestRequestContext(string entityName, DatabaseObject dbo) /// public int? First { get; set; } + + /// + /// Optional semantic search input provided by the caller. + /// + public string? SemanticSearch { get; set; } + + /// + /// Optional semantic threshold override provided by the caller. + /// + public double? SemanticThreshold { get; set; } + + /// + /// True when semantic distance should be included in REST output. + /// + public bool IncludeSemanticDistanceInResponse { get; set; } + + /// + /// Semantic distance keyed by a stable primary-key signature generated from result rows. + /// + public Dictionary? SemanticDistanceByPrimaryKeySignature { get; set; } + + /// + /// True when user explicitly requested semantic distance in the projection. + /// + public bool IsSemanticDistanceExplicitlySelected { get; set; } /// /// Is the result supposed to be multiple values or not. /// diff --git a/src/Core/Models/SemanticSearchCandidate.cs b/src/Core/Models/SemanticSearchCandidate.cs new file mode 100644 index 0000000000..65251ffcfd --- /dev/null +++ b/src/Core/Models/SemanticSearchCandidate.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Core.Models; + +/// +/// Represents one semantic candidate with SQL column values extracted from +/// the semantic document and primary key values used for dedupe/output mapping. 
+/// +public record SemanticSearchCandidate( + IReadOnlyDictionary PrimaryKeyValues, + IReadOnlyDictionary ColumnValues, + double Distance); diff --git a/src/Core/Models/SemanticSearchConstants.cs b/src/Core/Models/SemanticSearchConstants.cs new file mode 100644 index 0000000000..dbca8d9426 --- /dev/null +++ b/src/Core/Models/SemanticSearchConstants.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Core.Models; + +public static class SemanticSearchConstants +{ + public const string REST_SEARCH_QUERY_PARAM = "$semantic_search"; + public const string REST_THRESHOLD_QUERY_PARAM = "$semantic_threshold"; + public const string REST_DISTANCE_FIELD = "semantic_distance"; + + public const string GRAPHQL_SEARCH_ARGUMENT = "semanticSearch"; + public const string GRAPHQL_THRESHOLD_ARGUMENT = "semanticThreshold"; + public const string GRAPHQL_DISTANCE_FIELD = "semanticDistance"; +} diff --git a/src/Core/Parsers/RequestParser.cs b/src/Core/Parsers/RequestParser.cs index 081018e820..90d3eb3c58 100644 --- a/src/Core/Parsers/RequestParser.cs +++ b/src/Core/Parsers/RequestParser.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Globalization; using System.Net; using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Services; @@ -35,6 +36,14 @@ public class RequestParser /// Prefix used for specifying paging in the query string of the URL. /// public const string AFTER_URL = "$after"; + /// + /// Prefix used for specifying semantic search text in the query string of the URL. + /// + public const string SEMANTIC_SEARCH_URL = SemanticSearchConstants.REST_SEARCH_QUERY_PARAM; + /// + /// Prefix used for specifying semantic similarity threshold in the query string of the URL. 
+ /// + public const string SEMANTIC_THRESHOLD_URL = SemanticSearchConstants.REST_THRESHOLD_QUERY_PARAM; /// /// Parses the primary key string to identify the field names composing the key @@ -138,6 +147,14 @@ public static void ParseQueryString(RestRequestContext context, ISqlMetadataProv subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); } + if (rawSortValue.Contains(SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparison.OrdinalIgnoreCase)) + { + throw new DataApiBuilderException( + message: "semantic_distance cannot be used in orderBy.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + (context.OrderByClauseInUrl, context.OrderByClauseOfBackingColumns) = GenerateOrderByLists(context, sqlMetadataProvider, $"?{SORT_URL}={rawSortValue}"); break; case AFTER_URL: @@ -146,6 +163,22 @@ public static void ParseQueryString(RestRequestContext context, ISqlMetadataProv case FIRST_URL: context.First = RequestValidator.CheckFirstValidity(context.ParsedQueryString[key]!); break; + case SEMANTIC_SEARCH_URL: + context.SemanticSearch = context.ParsedQueryString[key]; + break; + case SEMANTIC_THRESHOLD_URL: + if (!double.TryParse(context.ParsedQueryString[key], NumberStyles.Float, CultureInfo.InvariantCulture, out double semanticThreshold) + || semanticThreshold < 0.0 + || semanticThreshold > 1.0) + { + throw new DataApiBuilderException( + message: "semantic_threshold must be a decimal value between 0.0 and 1.0.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + context.SemanticThreshold = semanticThreshold; + break; default: throw new DataApiBuilderException( message: $"Invalid Query Parameter: {key}", diff --git a/src/Core/Resolvers/Factories/QueryEngineFactory.cs b/src/Core/Resolvers/Factories/QueryEngineFactory.cs index 1d2ae2935d..1260db344a 100644 --- a/src/Core/Resolvers/Factories/QueryEngineFactory.cs +++ 
b/src/Core/Resolvers/Factories/QueryEngineFactory.cs @@ -9,6 +9,7 @@ using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Services.Cache; using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Service.Exceptions; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; @@ -32,6 +33,7 @@ public class QueryEngineFactory : IQueryEngineFactory private readonly IAuthorizationResolver _authorizationResolver; private readonly GQLFilterParser _gQLFilterParser; private readonly DabCacheService _cache; + private readonly ISemanticSearchService _semanticSearchService; private readonly ILogger _logger; /// @@ -42,6 +44,7 @@ public QueryEngineFactory(RuntimeConfigProvider runtimeConfigProvider, IHttpContextAccessor contextAccessor, IAuthorizationResolver authorizationResolver, GQLFilterParser gQLFilterParser, + ISemanticSearchService semanticSearchService, ILogger logger, DabCacheService cache, HotReloadEventHandler? 
handler) @@ -55,6 +58,7 @@ public QueryEngineFactory(RuntimeConfigProvider runtimeConfigProvider, _contextAccessor = contextAccessor; _authorizationResolver = authorizationResolver; _gQLFilterParser = gQLFilterParser; + _semanticSearchService = semanticSearchService; _cache = cache; _logger = logger; @@ -75,7 +79,8 @@ public void ConfigureQueryEngines() _gQLFilterParser, _logger, _runtimeConfigProvider, - _cache); + _cache, + _semanticSearchService); _queryEngines.Add(DatabaseType.MSSQL, queryEngine); _queryEngines.Add(DatabaseType.MySQL, queryEngine); _queryEngines.Add(DatabaseType.PostgreSQL, queryEngine); diff --git a/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs b/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs index 0370b44b83..1e03b0d695 100644 --- a/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs +++ b/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs @@ -4,6 +4,7 @@ using System.Data; using System.Net; using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; @@ -92,6 +93,80 @@ public class SqlQueryStructure : BaseSqlQueryStructure /// public GroupByMetadata GroupByMetadata { get; private set; } + /// + /// True when semantic search arguments are present on the request. + /// + public bool SemanticSearchRequested { get; private set; } + + /// + /// True when a user supplied ordering expression. + /// + public bool HasUserSpecifiedOrdering { get; private set; } + + /// + /// Primary-key signature to semantic distance map. + /// + public Dictionary SemanticDistanceByPrimaryKeySignature { get; } + + /// + /// Applies a semantic narrowing predicate of the form: + /// (col1 = ... AND col2 = ...) OR (...) 
+ /// + public void ApplySemanticCandidates(IReadOnlyList candidates) + { + if (candidates.Count == 0) + { + Predicates.Add(Predicate.MakeFalsePredicate()); + return; + } + + Predicate? combinedOr = null; + + foreach (SemanticSearchCandidate candidate in candidates) + { + Predicate? andPredicate = null; + foreach ((string columnName, object? value) in candidate.ColumnValues) + { + if (value is null) + { + continue; + } + + string parameterName = MakeDbConnectionParam( + GetParamAsSystemType(value.ToString()!, columnName, GetColumnSystemType(columnName)), + columnName); + + Predicate equalsPredicate = new( + new PredicateOperand(new Column(DatabaseObject.SchemaName, DatabaseObject.Name, columnName, SourceAlias)), + PredicateOperation.Equal, + new PredicateOperand(parameterName), + addParenthesis: true); + + andPredicate = andPredicate is null + ? equalsPredicate + : new Predicate(new PredicateOperand(andPredicate), PredicateOperation.AND, new PredicateOperand(equalsPredicate), addParenthesis: true); + } + + if (andPredicate is null) + { + continue; + } + + combinedOr = combinedOr is null + ? andPredicate + : new Predicate(new PredicateOperand(combinedOr), PredicateOperation.OR, new PredicateOperand(andPredicate), addParenthesis: true); + } + + if (combinedOr is null) + { + Predicates.Add(Predicate.MakeFalsePredicate()); + } + else + { + Predicates.Add(combinedOr); + } + } + public virtual string? CacheControlOption { get; set; } public const string CACHE_CONTROL = "Cache-Control"; @@ -280,6 +355,7 @@ public SqlQueryStructure( IsListQuery = context.IsMany; SourceAlias = $"{DatabaseObject.SchemaName}_{DatabaseObject.Name}"; AddFields(context, sqlMetadataProvider); + SemanticSearchRequested = !string.IsNullOrWhiteSpace(context.SemanticSearch); foreach (KeyValuePair predicate in context.PrimaryKeyValuePairs) { sqlMetadataProvider.TryGetBackingColumn(EntityName, predicate.Key, out string? 
backingColumn); @@ -293,6 +369,7 @@ public SqlQueryStructure( // to only Find, we populate the SourceAlias in this constructor where we know we have a Find operation. OrderByColumns = context.OrderByClauseOfBackingColumns is not null ? context.OrderByClauseOfBackingColumns : PrimaryKeyAsOrderByColumns(); + HasUserSpecifiedOrdering = context.OrderByClauseOfBackingColumns is not null; foreach (OrderByColumn column in OrderByColumns) { @@ -444,6 +521,18 @@ private SqlQueryStructure( EntityName = sqlMetadataProvider.GetDatabaseType() == DatabaseType.DWSQL ? GraphQLUtils.GetEntityNameFromContext(ctx) : _underlyingFieldType.Name; bool isGroupByQuery = queryField?.Name.Value == QueryBuilder.GROUP_BY_FIELD_NAME; + SemanticSearchRequested = queryParams.TryGetValue(QueryBuilder.SEMANTIC_SEARCH_ARGUMENT_NAME, out object? semanticValue) + && semanticValue is string semanticSearchText + && !string.IsNullOrWhiteSpace(semanticSearchText); + + if (SemanticSearchRequested && isGroupByQuery) + { + throw new DataApiBuilderException( + message: "Semantic search is not supported for aggregate queries.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + if (GraphQLUtils.TryExtractGraphQLFieldModelName(_underlyingFieldType.Directives, out string? 
modelName)) { EntityName = modelName; @@ -520,6 +609,7 @@ private SqlQueryStructure( if (orderByObject is not null) { OrderByColumns = ProcessGqlOrderByArg((List)orderByObject, queryArgumentSchemas[QueryBuilder.ORDER_BY_FIELD_NAME], isGroupByQuery); + HasUserSpecifiedOrdering = true; } } @@ -575,6 +665,7 @@ private SqlQueryStructure( PaginationMetadata = new(this); GroupByMetadata = new(); ColumnLabelToParam = new(); + SemanticDistanceByPrimaryKeySignature = new(StringComparer.Ordinal); FilterPredicates = string.Empty; OrderByColumns = new(); AddCacheControlOptions(httpRequestHeaders); @@ -733,6 +824,9 @@ void ProcessPaginationFields(IReadOnlyList paginationSelections) // TODO : This is inefficient and could lead to errors. we should rewrite this to use the ISelection API. private void AddGraphQLFields(IReadOnlyList selections, RuntimeConfigProvider runtimeConfigProvider) { + bool entitySemanticEnabled = runtimeConfigProvider.GetConfig().Entities.TryGetValue(EntityName, out Entity? entity) + && entity.SemanticSearch?.Enabled is true; + foreach (ISelectionNode node in selections) { if (node.Kind == SyntaxKind.FragmentSpread) @@ -784,6 +878,12 @@ private void AddGraphQLFields(IReadOnlyList selections, RuntimeC if (field.SelectionSet is null) { + if (entitySemanticEnabled + && string.Equals(fieldName, SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD, StringComparison.Ordinal)) + { + continue; + } + if (MetadataProvider.TryGetBackingColumn(EntityName, fieldName, out string? name) && !string.IsNullOrWhiteSpace(name)) { @@ -796,6 +896,14 @@ private void AddGraphQLFields(IReadOnlyList selections, RuntimeC } else { + if (SemanticSearchRequested) + { + throw new DataApiBuilderException( + message: "Semantic search is not supported with relationship expansion.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + ObjectField? 
subschemaField = _underlyingFieldType.Fields[fieldName]; if (_ctx == null) diff --git a/src/Core/Resolvers/SqlQueryEngine.cs b/src/Core/Resolvers/SqlQueryEngine.cs index f567251771..12d804328f 100644 --- a/src/Core/Resolvers/SqlQueryEngine.cs +++ b/src/Core/Resolvers/SqlQueryEngine.cs @@ -4,6 +4,7 @@ using System.Text; using System.Text.Json; using System.Text.Json.Nodes; +using System.Linq; using Azure.DataApiBuilder.Auth; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Configurations; @@ -12,7 +13,9 @@ using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.Cache; using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Service.Exceptions; +using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Service.GraphQLBuilder; using Azure.DataApiBuilder.Service.GraphQLBuilder.Queries; using HotChocolate.Resolvers; @@ -37,6 +40,7 @@ public class SqlQueryEngine : IQueryEngine private readonly RuntimeConfigProvider _runtimeConfigProvider; private readonly GQLFilterParser _gQLFilterParser; private readonly DabCacheService _cache; + private readonly ISemanticSearchService _semanticSearchService; // // Constructor. @@ -49,7 +53,8 @@ public SqlQueryEngine( GQLFilterParser gQLFilterParser, ILogger logger, RuntimeConfigProvider runtimeConfigProvider, - DabCacheService cache) + DabCacheService cache, + ISemanticSearchService? semanticSearchService = null) { _queryFactory = queryFactory; _sqlMetadataProviderFactory = sqlMetadataProviderFactory; @@ -59,6 +64,7 @@ public SqlQueryEngine( _logger = logger; _runtimeConfigProvider = runtimeConfigProvider; _cache = cache; + _semanticSearchService = semanticSearchService ?? NoOpSemanticSearchService.Instance; } /// @@ -79,16 +85,32 @@ public SqlQueryEngine( _runtimeConfigProvider, _gQLFilterParser); + (bool shouldReturnEmpty, JsonDocument? 
emptySemanticResponse) = await TryApplySemanticNarrowingAsync( + structure, + dataSourceName, + parameters, + GraphQLUtils.GetEntityNameFromContext(context)); + if (shouldReturnEmpty) + { + return new Tuple(emptySemanticResponse, structure.PaginationMetadata); + } + if (structure.PaginationMetadata.IsPaginated) { + JsonDocument? queryResult = await ExecuteAsync(structure, dataSourceName); + JsonDocument? processedResult = ApplySemanticDistanceAndOrderingIfNeeded(queryResult, structure, dataSourceName, includeRestField: false, includeGraphQlField: true); + return new Tuple( - SqlPaginationUtil.CreatePaginationConnectionFromJsonDocument(await ExecuteAsync(structure, dataSourceName), structure.PaginationMetadata, structure.GroupByMetadata), + SqlPaginationUtil.CreatePaginationConnectionFromJsonDocument(processedResult, structure.PaginationMetadata, structure.GroupByMetadata), structure.PaginationMetadata); } else { + JsonDocument? queryResult = await ExecuteAsync(structure, dataSourceName); + JsonDocument? processedResult = ApplySemanticDistanceAndOrderingIfNeeded(queryResult, structure, dataSourceName, includeRestField: false, includeGraphQlField: true); + return new Tuple( - await ExecuteAsync(structure, dataSourceName), + processedResult, structure.PaginationMetadata); } } @@ -189,7 +211,234 @@ await ExecuteAsync(structure, dataSourceName, isMultipleCreateOperation: true), _runtimeConfigProvider, _gQLFilterParser, _httpContextAccessor.HttpContext!); - return await ExecuteAsync(structure, dataSourceName); + + (bool shouldReturnEmpty, JsonDocument? emptySemanticResponse) = await TryApplySemanticNarrowingAsync( + structure, + dataSourceName, + graphQLParameters: null, + context.EntityName, + context); + if (shouldReturnEmpty) + { + return emptySemanticResponse; + } + + JsonDocument? 
response = await ExecuteAsync(structure, dataSourceName); + return ApplySemanticDistanceAndOrderingIfNeeded(response, structure, dataSourceName, includeRestField: true, includeGraphQlField: false, context); + } + + private async Task<(bool shouldReturnEmpty, JsonDocument? emptyResponse)> TryApplySemanticNarrowingAsync( + SqlQueryStructure structure, + string dataSourceName, + IDictionary? graphQLParameters, + string entityName, + FindRequestContext? restContext = null) + { + JsonDocument? emptyResponse = null; + + string? semanticSearchText = restContext?.SemanticSearch; + if (semanticSearchText is null + && graphQLParameters is not null + && graphQLParameters.TryGetValue(QueryBuilder.SEMANTIC_SEARCH_ARGUMENT_NAME, out object? graphQlSearchValue) + && graphQlSearchValue is string graphQlSearchText) + { + semanticSearchText = graphQlSearchText; + } + + bool hasSemanticThresholdOnly = restContext?.SemanticThreshold is not null + || (graphQLParameters is not null && graphQLParameters.ContainsKey(QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME)); + + bool semanticRequested = !string.IsNullOrWhiteSpace(semanticSearchText) || hasSemanticThresholdOnly; + + RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig(); + if (!runtimeConfig.Entities.TryGetValue(entityName, out Entity? entity) + || entity.SemanticSearch is null + || !entity.SemanticSearch.Enabled) + { + if (semanticRequested) + { + throw new DataApiBuilderException( + message: $"Semantic search is not enabled for entity '{entityName}'.", + statusCode: System.Net.HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + return (false, null); + } + + if (string.IsNullOrWhiteSpace(semanticSearchText)) + { + return (false, null); + } + + if (graphQLParameters is not null + && graphQLParameters.TryGetValue(QueryBuilder.PAGINATION_TOKEN_ARGUMENT_NAME, out object? 
afterArg) + && afterArg is string afterToken + && !string.IsNullOrWhiteSpace(afterToken)) + { + throw new DataApiBuilderException( + message: "Pagination continuation tokens are not supported when semantic search is used.", + statusCode: System.Net.HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (graphQLParameters is not null + && graphQLParameters.TryGetValue(QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME, out object? graphQlThreshold) + && graphQlThreshold is double graphQlThresholdDouble + && (graphQlThresholdDouble < 0.0 || graphQlThresholdDouble > 1.0)) + { + throw new DataApiBuilderException( + message: "semantic_threshold must be a decimal value between 0.0 and 1.0.", + statusCode: System.Net.HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + double effectiveThreshold = restContext?.SemanticThreshold + ?? (graphQLParameters is not null && graphQLParameters.TryGetValue(QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME, out object? gqlThresholdObj) && gqlThresholdObj is double gqlThresholdVal + ? gqlThresholdVal + : entity.SemanticSearch.SimilarityThreshold); + + int first = restContext?.First + ?? (graphQLParameters is not null && graphQLParameters.TryGetValue(QueryBuilder.PAGE_START_ARGUMENT_NAME, out object? firstObj) && firstObj is int gqlFirst ? 
gqlFirst : (int)runtimeConfig.DefaultPageSize()); + int redisTop = first * entity.SemanticSearch.RedisIndexMultiplier; + + SourceDefinition sourceDefinition = _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName).GetSourceDefinition(entityName); + IReadOnlyList primaryKeyColumns = sourceDefinition.PrimaryKey; + + IReadOnlyList candidates = await _semanticSearchService.GetCandidatesAsync( + entityName, + entity.SemanticSearch, + primaryKeyColumns, + semanticSearchText, + effectiveThreshold, + redisTop); + + if (candidates.Count == 0) + { + emptyResponse = JsonDocument.Parse("[]"); + return (true, emptyResponse); + } + + // De-duplicate candidate keys while keeping the highest distance. + Dictionary deduped = new(StringComparer.Ordinal); + foreach (SemanticSearchCandidate candidate in candidates) + { + string signature = BuildPrimaryKeySignatureFromValues(primaryKeyColumns, candidate.PrimaryKeyValues); + if (!deduped.TryGetValue(signature, out SemanticSearchCandidate? existing) || candidate.Distance > existing.Distance) + { + deduped[signature] = candidate; + } + } + + IReadOnlyList narrowedCandidates = deduped.Values.ToList(); + structure.ApplySemanticCandidates(narrowedCandidates); + foreach (KeyValuePair kvp in deduped) + { + structure.SemanticDistanceByPrimaryKeySignature[kvp.Key] = kvp.Value.Distance; + } + + return (false, null); + } + + private JsonDocument? ApplySemanticDistanceAndOrderingIfNeeded( + JsonDocument? response, + SqlQueryStructure structure, + string dataSourceName, + bool includeRestField, + bool includeGraphQlField, + FindRequestContext? restContext = null) + { + if (response is null || structure.SemanticDistanceByPrimaryKeySignature.Count == 0) + { + return response; + } + + SourceDefinition sourceDefinition = _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName).GetSourceDefinition(structure.EntityName); + if (response.RootElement.ValueKind is not JsonValueKind.Array) + { + return response; + } + + List<(double? 
distance, JsonObject row)> rows = new(); + foreach (JsonElement element in response.RootElement.EnumerateArray()) + { + JsonObject? row = JsonObject.Create(element); + if (row is null) + { + continue; + } + + string signature = BuildPrimaryKeySignatureFromRow(sourceDefinition.PrimaryKey, structure.EntityName, dataSourceName, element); + double? distance = null; + if (structure.SemanticDistanceByPrimaryKeySignature.TryGetValue(signature, out double mappedDistance)) + { + distance = mappedDistance; + } + + if (includeRestField) + { + row[SemanticSearchConstants.REST_DISTANCE_FIELD] = distance.HasValue ? JsonValue.Create(distance.Value) : null; + } + + if (includeGraphQlField) + { + row[SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD] = distance.HasValue ? JsonValue.Create(distance.Value) : null; + } + + rows.Add((distance, row)); + } + + if (!structure.HasUserSpecifiedOrdering) + { + rows = rows.OrderByDescending(r => r.distance ?? double.MinValue).ToList(); + } + + JsonElement updated = JsonSerializer.SerializeToElement(rows.Select(r => r.row).ToList()); + + response.Dispose(); + return JsonDocument.Parse(updated.GetRawText()); + } + + private string BuildPrimaryKeySignatureFromRow( + IReadOnlyList primaryKeyColumns, + string entityName, + string dataSourceName, + JsonElement row) + { + ISqlMetadataProvider metadataProvider = _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName); + Dictionary values = new(StringComparer.Ordinal); + foreach (string primaryKeyColumn in primaryKeyColumns) + { + if (!metadataProvider.TryGetExposedColumnName(entityName, primaryKeyColumn, out string? 
exposedName) + || string.IsNullOrWhiteSpace(exposedName) + || !row.TryGetProperty(exposedName, out JsonElement valueElement)) + { + values[primaryKeyColumn] = null; + continue; + } + + values[primaryKeyColumn] = valueElement.ValueKind switch + { + JsonValueKind.Null => null, + JsonValueKind.String => valueElement.GetString(), + _ => valueElement.ToString() + }; + } + + return BuildPrimaryKeySignatureFromValues(primaryKeyColumns, values); + } + + private static string BuildPrimaryKeySignatureFromValues( + IReadOnlyList primaryKeyColumns, + IReadOnlyDictionary values) + { + return string.Join("|", primaryKeyColumns.Select(pk => + { + string serializedValue = values.TryGetValue(pk, out object? value) + ? value?.ToString() ?? "null" + : "null"; + return $"{pk}={serializedValue}"; + })); } /// diff --git a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs index 2aaba3eed3..43113449b0 100644 --- a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs +++ b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs @@ -12,6 +12,7 @@ using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Parsers; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Core.Services.OpenAPI; @@ -374,6 +375,7 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour // Operations including primary key Dictionary pkOperations = CreateOperations( entityName: entityName, + entity: entity, sourceDefinition: sourceDefinition, includePrimaryKeyPathComponent: true, configuredRestOperations: configuredRestOperations, @@ -397,6 +399,7 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour // Operations excluding primary key Dictionary operations = CreateOperations( entityName: entityName, + entity: entity, sourceDefinition: 
sourceDefinition, includePrimaryKeyPathComponent: false, configuredRestOperations: configuredRestOperations, @@ -432,6 +435,7 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour /// Collection of operation types and associated definitions. private Dictionary CreateOperations( string entityName, + Entity entity, SourceDefinition sourceDefinition, bool includePrimaryKeyPathComponent, Dictionary configuredRestOperations, @@ -444,7 +448,7 @@ private Dictionary CreateOperations( if (configuredRestOperations[OperationType.Get]) { OpenApiOperation getOperation = CreateBaseOperation(description: GETONE_DESCRIPTION, tags: tags); - AddQueryParameters(getOperation.Parameters); + AddQueryParameters(getOperation.Parameters, entity); getOperation.Responses.Add(HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName)); openApiPathItemOperations.Add(OperationType.Get, getOperation); } @@ -487,7 +491,7 @@ private Dictionary CreateOperations( if (configuredRestOperations[OperationType.Get]) { OpenApiOperation getAllOperation = CreateBaseOperation(description: GETALL_DESCRIPTION, tags: tags); - AddQueryParameters(getAllOperation.Parameters); + AddQueryParameters(getAllOperation.Parameters, entity); getAllOperation.Responses.Add( HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName, includeNextLink: true)); @@ -537,12 +541,27 @@ private Dictionary CreateOperations( /// Helper method to add query parameters like $select, $first, $orderby etc. to get and getAll operations for tables/views. /// /// List of parameters for the operation. 
- private static void AddQueryParameters(IList parameters) + private static void AddQueryParameters(IList parameters, Entity entity) { foreach (OpenApiParameter openApiParameter in _tableAndViewQueryParameters) { parameters.Add(openApiParameter); } + + if (entity.SemanticSearch?.Enabled is true) + { + parameters.Add(GetOpenApiQueryParameter( + name: RequestParser.SEMANTIC_SEARCH_URL, + description: entity.SemanticSearch.InputDescription, + required: false, + type: "string")); + + parameters.Add(GetOpenApiQueryParameter( + name: RequestParser.SEMANTIC_THRESHOLD_URL, + description: "Minimum semantic similarity threshold between 0.0 and 1.0.", + required: false, + type: "number")); + } } /// @@ -1500,6 +1519,19 @@ private static OpenApiSchema CreateComponentSchema( } } + if (!isRequestBodySchema + && entityConfig?.SemanticSearch?.Enabled is true + && !properties.ContainsKey(SemanticSearchConstants.REST_DISTANCE_FIELD)) + { + properties.Add(SemanticSearchConstants.REST_DISTANCE_FIELD, new OpenApiSchema() + { + Type = "number", + Description = entityConfig.SemanticSearch.OutputDescription, + ReadOnly = true, + Nullable = true + }); + } + OpenApiSchema schema = new() { Type = SCHEMA_OBJECT_TYPE, diff --git a/src/Core/Services/RequestValidator.cs b/src/Core/Services/RequestValidator.cs index aef6ca9ab3..e96e7523ba 100644 --- a/src/Core/Services/RequestValidator.cs +++ b/src/Core/Services/RequestValidator.cs @@ -272,6 +272,15 @@ public void ValidateInsertRequestContext(InsertRequestContext insertRequestCtx) { ISqlMetadataProvider sqlMetadataProvider = GetSqlMetadataProvider(insertRequestCtx.EntityName); + if (insertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.REST_DISTANCE_FIELD) + || insertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD)) + { + throw new DataApiBuilderException( + message: "semantic_distance is read-only.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: 
DataApiBuilderException.SubStatusCodes.BadRequest); + } + IEnumerable fieldsInRequestBody = insertRequestCtx.FieldValuePairsInBody.Keys; SourceDefinition sourceDefinition = TryGetSourceDefinition(insertRequestCtx.EntityName, sqlMetadataProvider); @@ -356,6 +365,16 @@ public void ValidateInsertRequestContext(InsertRequestContext insertRequestCtx) public void ValidateUpsertRequestContext(UpsertRequestContext upsertRequestCtx, bool isPrimaryKeyInUrl = true) { ISqlMetadataProvider sqlMetadataProvider = GetSqlMetadataProvider(upsertRequestCtx.EntityName); + + if (upsertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.REST_DISTANCE_FIELD) + || upsertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD)) + { + throw new DataApiBuilderException( + message: "semantic_distance is read-only.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + IEnumerable fieldsInRequestBody = upsertRequestCtx.FieldValuePairsInBody.Keys; bool isRequestBodyStrict = IsRequestBodyStrict(); SourceDefinition sourceDefinition = TryGetSourceDefinition(upsertRequestCtx.EntityName, sqlMetadataProvider); @@ -401,7 +420,7 @@ public void ValidateUpsertRequestContext(UpsertRequestContext upsertRequestCtx, else { // Body-based PK: non-auto-generated PK columns MUST be present. - // Auto-generated PK columns are skipped — they cannot be supplied by the caller. + // Auto-generated PK columns are skipped — they cannot be supplied by the caller. 
if (column.Value.IsAutoGenerated) { continue; diff --git a/src/Core/Services/RestService.cs b/src/Core/Services/RestService.cs index 2bfab7e05f..c9523f1160 100644 --- a/src/Core/Services/RestService.cs +++ b/src/Core/Services/RestService.cs @@ -194,6 +194,7 @@ RequestValidator requestValidator context.RawQueryString = queryString; context.ParsedQueryString = HttpUtility.ParseQueryString(queryString); RequestParser.ParseQueryString(context, sqlMetadataProvider); + ValidateSemanticRestRequest(context, operationType, _runtimeConfigProvider.GetConfig()); } } @@ -235,6 +236,12 @@ private async Task DispatchQuery(RestRequestContext context, Data if (context is FindRequestContext findRequestContext) { + if (findRequestContext.IncludeSemanticDistanceInResponse + && !findRequestContext.FieldsToBeReturned.Contains(SemanticSearchConstants.REST_DISTANCE_FIELD)) + { + findRequestContext.FieldsToBeReturned.Add(SemanticSearchConstants.REST_DISTANCE_FIELD); + } + using JsonDocument? restApiResponse = await queryEngine.ExecuteAsync(findRequestContext); return restApiResponse is null ? 
SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), findRequestContext, _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName), _runtimeConfigProvider.GetConfig(), GetHttpContext()) : SqlResponseHelpers.FormatFindResult(restApiResponse.RootElement.Clone(), findRequestContext, _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName), _runtimeConfigProvider.GetConfig(), GetHttpContext()); @@ -356,6 +363,71 @@ private bool IsHttpMethodAllowedForStoredProcedure(string entityName) return false; } + private static void ValidateSemanticRestRequest(RestRequestContext context, EntityActionOperation operationType, RuntimeConfig runtimeConfig) + { + bool hasSemanticSearch = !string.IsNullOrWhiteSpace(context.SemanticSearch); + bool hasSemanticThreshold = context.SemanticThreshold.HasValue; + bool isSemanticRequest = hasSemanticSearch || hasSemanticThreshold; + + if (context.FieldsToBeReturned.Contains(SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparer.OrdinalIgnoreCase)) + { + context.FieldsToBeReturned = context.FieldsToBeReturned + .Where(field => !string.Equals(field, SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparison.OrdinalIgnoreCase)) + .ToList(); + context.IsSemanticDistanceExplicitlySelected = true; + } + + if (!isSemanticRequest) + { + if (context.IsSemanticDistanceExplicitlySelected) + { + throw new DataApiBuilderException( + message: "semantic_distance can only be selected when semantic search is used.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + return; + } + + if (operationType is not EntityActionOperation.Read) + { + throw new DataApiBuilderException( + message: "Semantic search is only supported for read operations.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (!runtimeConfig.Entities.TryGetValue(context.EntityName, out Entity? 
entity) + || entity.SemanticSearch is null + || !entity.SemanticSearch.Enabled) + { + throw new DataApiBuilderException( + message: $"Semantic search is not enabled for entity '{context.EntityName}'.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (!string.IsNullOrWhiteSpace(context.After)) + { + throw new DataApiBuilderException( + message: "Pagination continuation tokens are not supported when semantic search is used.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (context.OrderByClauseInUrl is not null + && context.OrderByClauseInUrl.Any(col => string.Equals(col.ColumnName, SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparison.OrdinalIgnoreCase))) + { + throw new DataApiBuilderException( + message: "semantic_distance cannot be used in orderBy.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + context.IncludeSemanticDistanceInResponse = context.FieldsToBeReturned.Count == 0 || context.IsSemanticDistanceExplicitlySelected; + } + /// /// Gets the list of HTTP methods defined for entities representing stored procedures. /// When no explicit REST method configuration is present for a stored procedure entity, diff --git a/src/Core/Services/SemanticSearch/ISemanticSearchService.cs b/src/Core/Services/SemanticSearch/ISemanticSearchService.cs new file mode 100644 index 0000000000..f3c124edfd --- /dev/null +++ b/src/Core/Services/SemanticSearch/ISemanticSearchService.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Models; + +namespace Azure.DataApiBuilder.Core.Services.SemanticSearch; + +public interface ISemanticSearchService +{ + Task> GetCandidatesAsync( + string entityName, + EntitySemanticSearchOptions options, + IReadOnlyList primaryKeyColumns, + string semanticSearchValue, + double similarityThreshold, + int top); +} diff --git a/src/Core/Services/SemanticSearch/NoOpSemanticSearchService.cs b/src/Core/Services/SemanticSearch/NoOpSemanticSearchService.cs new file mode 100644 index 0000000000..dfabecb133 --- /dev/null +++ b/src/Core/Services/SemanticSearch/NoOpSemanticSearchService.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Models; + +namespace Azure.DataApiBuilder.Core.Services.SemanticSearch; + +/// +/// Placeholder semantic search service used when no embedding/vector integration is configured. +/// +public sealed class NoOpSemanticSearchService : ISemanticSearchService +{ + public static readonly NoOpSemanticSearchService Instance = new(); + + private NoOpSemanticSearchService() + { + } + + public Task> GetCandidatesAsync( + string entityName, + EntitySemanticSearchOptions options, + IReadOnlyList primaryKeyColumns, + string semanticSearchValue, + double similarityThreshold, + int top) + { + return Task.FromResult>([]); + } +} diff --git a/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs b/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs index 58ff41c504..9fd42b5e3b 100644 --- a/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs +++ b/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
+using System; using Azure.DataApiBuilder.Service.GraphQLBuilder.Directives; using Azure.DataApiBuilder.Service.GraphQLBuilder.GraphQLTypes; using HotChocolate.Language; @@ -50,6 +51,11 @@ private static List GenerateOrderByInputFieldsForBuilt List inputFields = new(); foreach (FieldDefinitionNode field in node.Fields) { + if (string.Equals(field.Name.Value, QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME, StringComparison.Ordinal)) + { + continue; + } + if (IsBuiltInType(field.Type)) { inputFields.Add( @@ -110,6 +116,11 @@ private static List GenerateFilterInputFieldsForBuiltI List inputFields = new(); foreach (FieldDefinitionNode field in objectTypeDefinitionNode.Fields) { + if (string.Equals(field.Name.Value, QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME, StringComparison.Ordinal)) + { + continue; + } + string fieldTypeName = field.Type.NamedType().Name.Value; if (IsBuiltInType(field.Type)) { diff --git a/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs b/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs index a2cc63b2c2..100759522f 100644 --- a/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs +++ b/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs @@ -30,6 +30,8 @@ public static class QueryBuilder public const string GROUP_BY_AGGREGATE_FIELD_ARG_NAME = "field"; public const string GROUP_BY_AGGREGATE_FIELD_DISTINCT_NAME = "distinct"; public const string GROUP_BY_AGGREGATE_FIELD_HAVING_NAME = "having"; + public const string SEMANTIC_SEARCH_ARGUMENT_NAME = SemanticSearchConstants.GRAPHQL_SEARCH_ARGUMENT; + public const string SEMANTIC_THRESHOLD_ARGUMENT_NAME = SemanticSearchConstants.GRAPHQL_THRESHOLD_ARGUMENT; // Define the enabled database types for aggregation public static readonly HashSet AggregationEnabledDatabaseTypes = new() @@ -191,21 +193,29 @@ public static FieldDefinitionNode GenerateGetAllQuery( location: null, new NameNode(GenerateListQueryName(name.Value, entity)), new StringValueNode($"Get a list of all the {GetDefinedSingularName(name.Value, entity)} 
items from the database"), - QueryArgumentsForField(filterInputName, orderByInputName), + QueryArgumentsForField(filterInputName, orderByInputName, includeSemanticArguments: entity.SemanticSearch?.Enabled ?? false), new NonNullTypeNode(new NamedTypeNode(returnType.Name)), fieldDefinitionNodeDirectives ); } - public static List QueryArgumentsForField(string filterInputName, string orderByInputName) + public static List QueryArgumentsForField(string filterInputName, string orderByInputName, bool includeSemanticArguments = false) { - return new() + List args = new() { new(location: null, new NameNode(PAGE_START_ARGUMENT_NAME), description: new StringValueNode("The number of items to return from the page start point"), new IntType().ToTypeNode(), defaultValue: null, new List()), new(location: null, new NameNode(PAGINATION_TOKEN_ARGUMENT_NAME), new StringValueNode("A pagination token from a previous query to continue through a paginated list"), new StringType().ToTypeNode(), defaultValue: null, new List()), new(location: null, new NameNode(FILTER_FIELD_NAME), new StringValueNode("Filter options for query"), new NamedTypeNode(filterInputName), defaultValue: null, new List()), new(location: null, new NameNode(ORDER_BY_FIELD_NAME), new StringValueNode("Ordering options for query"), new NamedTypeNode(orderByInputName), defaultValue: null, new List()), }; + + if (includeSemanticArguments) + { + args.Add(new(location: null, new NameNode(SEMANTIC_SEARCH_ARGUMENT_NAME), new StringValueNode("Natural language value used for semantic search."), new StringType().ToTypeNode(), defaultValue: null, new List())); + args.Add(new(location: null, new NameNode(SEMANTIC_THRESHOLD_ARGUMENT_NAME), new StringValueNode("Minimum semantic similarity threshold between 0.0 and 1.0."), new FloatType().ToTypeNode(), defaultValue: null, new List())); + } + + return args; } public static ObjectTypeDefinitionNode AddQueryArgumentsForRelationships(ObjectTypeDefinitionNode node, Dictionary inputObjects) 
diff --git a/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs b/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs index 76057a76dc..ca13a09803 100644 --- a/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs +++ b/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs @@ -210,6 +210,20 @@ private static ObjectTypeDefinitionNode CreateObjectTypeDefinitionForTableOrView } } + if (configEntity.SemanticSearch?.Enabled is true) + { + List semanticDistanceDirectives = [new DirectiveNode(AutoGeneratedDirectiveType.DirectiveName)]; + FieldDefinitionNode semanticDistanceField = new( + location: null, + name: new NameNode(QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME), + description: new StringValueNode(configEntity.SemanticSearch.OutputDescription), + arguments: [], + type: new NamedTypeNode(new NameNode(FLOAT_TYPE)), + directives: semanticDistanceDirectives); + + fieldDefinitionNodes.TryAdd(QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME, semanticDistanceField); + } + // A linking entity is not exposed in the runtime config file but is used by DAB to support multiple mutations on entities with M:N relationship. // Hence we don't need to process relationships for the linking entity itself. 
if (!configEntity.IsLinkingEntity) diff --git a/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs b/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs index ffe636f6db..89a747df15 100644 --- a/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs +++ b/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs @@ -16,6 +16,7 @@ using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.Cache; using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Service.GraphQLBuilder; using Azure.DataApiBuilder.Service.GraphQLBuilder.Directives; using Azure.DataApiBuilder.Service.GraphQLBuilder.Mutations; @@ -439,6 +440,7 @@ private static async Task GetGQLSchemaCreator(RuntimeConfi contextAccessor: httpContextAccessor.Object, authorizationResolver: authorizationResolver, gQLFilterParser: graphQLFilterParser, + semanticSearchService: NoOpSemanticSearchService.Instance, logger: queryEngineLogger.Object, cache: cacheService, handler: null); diff --git a/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs b/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs index c257a86054..e672422770 100644 --- a/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs +++ b/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs @@ -211,6 +211,54 @@ type Foo @model(name:""Foo"") { Assert.AreEqual("Boolean", hasNextPageField.Type.NamedType().Name.Value, "hasNextPage should be Boolean type"); } + [TestMethod] + [TestCategory("Query Generation")] + [TestCategory("Collection access")] + public void CollectionQueryAddsSemanticArgumentsWhenEnabled() + { + string gql = + @" +type Foo @model(name:""Foo"") { + id: ID! 
+} + "; + + DocumentNode root = Utf8GraphQLParser.Parse(gql); + Dictionary entityNameToDatabaseType = new() + { + { "Foo", DatabaseType.CosmosDB_NoSQL } + }; + + RuntimeEntities entities = new(new Dictionary + { + { + "Foo", + new Entity( + Source: new("Foo", EntitySourceType.Table, null, null), + GraphQL: new("Foo", "Foos"), + Fields: null, + Rest: new(), + Permissions: [], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions { Enabled = true, RedisIndexName = "idx:foo" }) + } + }); + + DocumentNode queryRoot = QueryBuilder.Build( + root, + entityNameToDatabaseType, + entities, + inputTypes: new(), + entityPermissionsMap: _entityPermissions + ); + + ObjectTypeDefinitionNode query = GetQueryNode(queryRoot); + FieldDefinitionNode collectionField = query.Fields.First(f => f.Name.Value == "foos"); + Assert.IsTrue(collectionField.Arguments.Any(a => a.Name.Value == QueryBuilder.SEMANTIC_SEARCH_ARGUMENT_NAME)); + Assert.IsTrue(collectionField.Arguments.Any(a => a.Name.Value == QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME)); + } + [TestMethod] public void PrimaryKeyFieldAsQueryInput() { diff --git a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs index 05561e4cf9..f0cbc216e3 100644 --- a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs +++ b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs @@ -3040,6 +3040,118 @@ public void ValidateMaxResponseSizeInConfig( } } + [TestMethod] + public void ValidateSemanticSearchRequiresRedisLevel2Configuration() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, KeyFields: null), + Fields: null, + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new 
EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true, + RedisIndexName = "idx:product-semantic" + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null)), + Entities: new(new Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "Semantic search requires runtime.cache.level-2.provider to be 'redis' and runtime.cache.level-2.connection-string to be configured.", + ex.Message); + } + + [TestMethod] + public void ValidateSemanticSearchRequiresRedisIndexNameWhenEnabled() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, KeyFields: null), + Fields: null, + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new RuntimeCacheOptions + { + Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") + }), + Entities: new(new Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = 
Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "semantic-search.redis-index-name is required when semantic-search.enabled is true for entity 'Product'.", + ex.Message); + } + + [TestMethod] + public void ValidateSemanticSearchRejectsReservedFieldName() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, KeyFields: null), + Fields: [new FieldMetadata { Name = "id" }, new FieldMetadata { Name = "semantic_distance" }], + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true, + RedisIndexName = "idx:product-semantic" + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new RuntimeCacheOptions + { + Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") + }), + Entities: new(new Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "Entity 'Product' cannot enable semantic search because field name 'semantic_distance' is reserved.", + ex.Message); + } + private static RuntimeConfigValidator InitializeRuntimeConfigValidator() { MockFileSystem fileSystem = new(); diff --git a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs index 4da3266271..0aeb309c61 
100644 --- a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs +++ b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs @@ -35,7 +35,7 @@ public class RequestParserUnitTests public void ExtractRawQueryParameter_PreservesEncoding(string queryString, string parameterName, string expectedValue) { // Call the internal method directly (no reflection needed) - string? result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); + string result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); Assert.AreEqual(expectedValue, result, $"Expected '{expectedValue}' but got '{result}' for parameter '{parameterName}' in query '{queryString}'"); @@ -49,10 +49,10 @@ public void ExtractRawQueryParameter_PreservesEncoding(string queryString, strin [DataRow("", "$filter", DisplayName = "Empty query string")] [DataRow(null, "$filter", DisplayName = "Null query string")] [DataRow("?otherParam=value", "$filter", DisplayName = "Different parameter")] - public void ExtractRawQueryParameter_ReturnsNull_WhenParameterNotFound(string? queryString, string parameterName) + public void ExtractRawQueryParameter_ReturnsNull_WhenParameterNotFound(string queryString, string parameterName) { // Call the internal method directly (no reflection needed) - string? result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); + string result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); Assert.IsNull(result, $"Expected null but got '{result}' for parameter '{parameterName}' in query '{queryString}'"); @@ -71,7 +71,7 @@ public void ExtractRawQueryParameter_ReturnsNull_WhenParameterNotFound(string? q public void ExtractRawQueryParameter_HandlesEdgeCases(string queryString, string parameterName, string expectedValue) { // Call the internal method directly (no reflection needed) - string? 
result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); + string result = RequestParser.ExtractRawQueryParameter(queryString, parameterName); Assert.AreEqual(expectedValue, result, $"Expected '{expectedValue}' but got '{result}' for parameter '{parameterName}' in query '{queryString}'"); diff --git a/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs b/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs index 35c86abc82..77f057cd0d 100644 --- a/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs +++ b/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Net; using System.Text; +using System.Text.Json; using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; @@ -327,6 +328,58 @@ public void PrimaryKeyWithNoValueTest() primaryKeyRoute = "id//title/"; PerformRequestParserPrimaryKeyTest(findRequestContext, primaryKeyRoute, expectsException: true); } + + [TestMethod] + public void InsertRequestBodyCannotSetSemanticDistanceField() + { + RuntimeConfig mockConfig = new( + Schema: "", + DataSource: new(DatabaseType.PostgreSQL, "", new()), + Runtime: new( + Rest: new(Path: "/api"), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null) + ), + Entities: new(new Dictionary + { + { + DEFAULT_NAME, + new Entity( + Source: new(DEFAULT_NAME, EntitySourceType.Table, null, null), + GraphQL: new(DEFAULT_NAME, DEFAULT_NAME), + Fields: null, + Rest: new(), + Permissions: [], + Mappings: null, + Relationships: null) + } + }) + ); + + MockFileSystem fileSystem = new(); + fileSystem.AddFile(FileSystemRuntimeConfigLoader.DEFAULT_CONFIG_FILE_NAME, new MockFileData(mockConfig.ToJson())); + FileSystemRuntimeConfigLoader loader = new(fileSystem); + RuntimeConfigProvider provider = new(loader); + + Mock metadataProviderFactory = new(); + metadataProviderFactory.Setup(x => 
x.GetMetadataProvider(It.IsAny())).Returns(_mockMetadataStore.Object); + RequestValidator requestValidator = new(metadataProviderFactory.Object, provider); + + using JsonDocument body = JsonDocument.Parse("{\"semantic_distance\":0.81}"); + InsertRequestContext context = new( + entityName: DEFAULT_NAME, + dbo: GetDbo(DEFAULT_SCHEMA, DEFAULT_NAME), + insertPayloadRoot: body.RootElement, + operationType: EntityActionOperation.Insert); + + DataApiBuilderException ex = Assert.ThrowsException( + () => requestValidator.ValidateInsertRequestContext(context)); + + Assert.AreEqual(HttpStatusCode.BadRequest, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.BadRequest, ex.SubStatusCode); + StringAssert.Contains(ex.Message, "semantic_distance is read-only."); + } #endregion #region Helper Methods /// diff --git a/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs b/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs index 198a9baca8..024febded0 100644 --- a/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs +++ b/src/Service.Tests/UnitTests/SqlQueryExecutorUnitTests.cs @@ -950,8 +950,8 @@ public void TestOboNoUserContext_UsesBaseConnectionString() [DataRow(null, null, "iss and oid/sub", DisplayName = "Authenticated user with no claims throws OboAuthenticationFailure")] public void TestOboEnabled_AuthenticatedUserMissingClaims_ThrowsException( - string? issuer, - string? objectId, + string issuer, + string objectId, string missingClaimDescription) { // Arrange - Create an authenticated HttpContext with incomplete claims @@ -986,8 +986,8 @@ public void TestOboEnabled_AuthenticatedUserMissingClaims_ThrowsException( /// The oid claim value, or null to omit. /// A configured HttpContextAccessor mock with authenticated user. private static Mock CreateHttpContextAccessorWithAuthenticatedUserMissingClaims( - string? issuer, - string? 
objectId) + string issuer, + string objectId) { Mock httpContextAccessor = new(); DefaultHttpContext context = new(); diff --git a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs new file mode 100644 index 0000000000..726137a53f --- /dev/null +++ b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs @@ -0,0 +1,538 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.DatabasePrimitives; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; +using Azure.DataApiBuilder.Service.Exceptions; +using Azure.Identity; +using Microsoft.Azure.StackExchangeRedis; +using Microsoft.Extensions.Logging; +using StackExchange.Redis; + +namespace Azure.DataApiBuilder.Service.Services.SemanticSearch; + +/// +/// Resolves semantic candidates by: +/// 1) generating/retrieving an embedding vector for semantic_search input, +/// 2) performing a Redis FT.SEARCH KNN query, +/// 3) extracting SQL column/value pairs from Redis hash/json documents. 
+/// +public sealed class RedisSemanticSearchService : ISemanticSearchService +{ + private const string EMBED_ENDPOINT_ENV = "DAB_SEMANTIC_EMBED_ENDPOINT"; + private const string EMBED_API_KEY_ENV = "DAB_SEMANTIC_EMBED_API_KEY"; + private const string VECTOR_SCORE_FIELD = "__vector_score"; + private const string DEFAULT_VECTOR_FIELD = "embedding"; + + private readonly RuntimeConfigProvider _runtimeConfigProvider; + private readonly IMetadataProviderFactory _metadataProviderFactory; + private readonly IHttpClientFactory _httpClientFactory; + private readonly ILogger _logger; + + public RedisSemanticSearchService( + RuntimeConfigProvider runtimeConfigProvider, + IMetadataProviderFactory metadataProviderFactory, + IHttpClientFactory httpClientFactory, + ILogger logger) + { + _runtimeConfigProvider = runtimeConfigProvider; + _metadataProviderFactory = metadataProviderFactory; + _httpClientFactory = httpClientFactory; + _logger = logger; + } + + public async Task> GetCandidatesAsync( + string entityName, + EntitySemanticSearchOptions options, + IReadOnlyList primaryKeyColumns, + string semanticSearchValue, + double similarityThreshold, + int top) + { + RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig(); + string? 
connectionString = runtimeConfig.Runtime?.Cache?.Level2?.ConnectionString; + + if (string.IsNullOrWhiteSpace(options.RedisIndexName) + || string.IsNullOrWhiteSpace(connectionString) + || top <= 0) + { + return []; + } + + float[] embedding = await GetEmbeddingAsync(semanticSearchValue); + if (embedding.Length == 0) + { + return []; + } + + string dataSourceName = runtimeConfig.GetDataSourceNameFromEntityName(entityName); + SourceDefinition sourceDefinition = _metadataProviderFactory.GetMetadataProvider(dataSourceName).GetSourceDefinition(entityName); + HashSet sourceColumns = new(sourceDefinition.Columns.Keys, StringComparer.OrdinalIgnoreCase); + + try + { + using IConnectionMultiplexer multiplexer = await CreateConnectionMultiplexerAsync(connectionString); + IDatabase db = multiplexer.GetDatabase(); + + string vectorFieldName = await ResolveVectorFieldNameAsync(db, options.RedisIndexName!, options.RedisIndexType); + byte[] vectorBytes = ToRedisVectorBytes(embedding); + + RedisResult rawResult = await db.ExecuteAsync( + "FT.SEARCH", + options.RedisIndexName!, + $"*=>[KNN {top} @{vectorFieldName} $vec AS {VECTOR_SCORE_FIELD}]", + "PARAMS", + "2", + "vec", + vectorBytes, + "SORTBY", + VECTOR_SCORE_FIELD, + "DIALECT", + "2", + "LIMIT", + "0", + top.ToString(CultureInfo.InvariantCulture)); + + return ParseCandidates(rawResult, options.RedisIndexType, sourceColumns, primaryKeyColumns, similarityThreshold); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Semantic search query failed for entity {entityName} and index {indexName}.", entityName, options.RedisIndexName); + throw new DataApiBuilderException( + message: $"Semantic search index '{options.RedisIndexName}' for entity '{entityName}' was not found or could not be queried.", + statusCode: System.Net.HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + } + + private static byte[] ToRedisVectorBytes(float[] vector) + { + byte[] bytes = new byte[vector.Length * 
sizeof(float)]; + Buffer.BlockCopy(vector, 0, bytes, 0, bytes.Length); + return bytes; + } + + private async Task GetEmbeddingAsync(string semanticSearchValue) + { + if (TryParseVectorText(semanticSearchValue, out float[]? parsedVector) && parsedVector is not null) + { + return parsedVector; + } + + string? endpoint = Environment.GetEnvironmentVariable(EMBED_ENDPOINT_ENV); + if (string.IsNullOrWhiteSpace(endpoint)) + { + return []; + } + + HttpClient client = _httpClientFactory.CreateClient(); + string? apiKey = Environment.GetEnvironmentVariable(EMBED_API_KEY_ENV); + if (!string.IsNullOrWhiteSpace(apiKey)) + { + client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); + } + + string payload = JsonSerializer.Serialize(new Dictionary { { "input", semanticSearchValue } }); + using HttpRequestMessage request = new(HttpMethod.Post, endpoint) + { + Content = new StringContent(payload, Encoding.UTF8, "application/json") + }; + + using HttpResponseMessage response = await client.SendAsync(request); + if (!response.IsSuccessStatusCode) + { + return []; + } + + string body = await response.Content.ReadAsStringAsync(); + return TryExtractEmbedding(body, out float[]? embedding) && embedding is not null ? embedding : []; + } + + private static bool TryParseVectorText(string text, out float[]? vector) + { + vector = null; + + try + { + using JsonDocument json = JsonDocument.Parse(text); + if (json.RootElement.ValueKind is not JsonValueKind.Array) + { + return false; + } + + List values = []; + foreach (JsonElement element in json.RootElement.EnumerateArray()) + { + if (!element.TryGetSingle(out float current)) + { + return false; + } + + values.Add(current); + } + + vector = values.ToArray(); + return vector.Length > 0; + } + catch + { + return false; + } + } + + private static bool TryExtractEmbedding(string responseBody, out float[]? 
embedding) + { + embedding = null; + + try + { + using JsonDocument json = JsonDocument.Parse(responseBody); + + if (json.RootElement.ValueKind is JsonValueKind.Array) + { + return TryReadEmbeddingArray(json.RootElement, out embedding); + } + + if (json.RootElement.ValueKind is JsonValueKind.Object) + { + if (json.RootElement.TryGetProperty("embedding", out JsonElement directEmbedding) + && TryReadEmbeddingArray(directEmbedding, out embedding)) + { + return true; + } + + if (json.RootElement.TryGetProperty("data", out JsonElement data) + && data.ValueKind is JsonValueKind.Array + && data.GetArrayLength() > 0) + { + JsonElement first = data[0]; + if (first.ValueKind is JsonValueKind.Object + && first.TryGetProperty("embedding", out JsonElement nestedEmbedding) + && TryReadEmbeddingArray(nestedEmbedding, out embedding)) + { + return true; + } + } + } + + return false; + } + catch + { + return false; + } + } + + private static bool TryReadEmbeddingArray(JsonElement array, out float[]? embedding) + { + embedding = null; + + if (array.ValueKind is not JsonValueKind.Array) + { + return false; + } + + List values = []; + foreach (JsonElement item in array.EnumerateArray()) + { + if (!item.TryGetSingle(out float current)) + { + return false; + } + + values.Add(current); + } + + embedding = values.ToArray(); + return embedding.Length > 0; + } + + private static async Task CreateConnectionMultiplexerAsync(string connectionString) + { + ConfigurationOptions options = ConfigurationOptions.Parse(connectionString); + + if (Startup.ShouldUseEntraAuthForRedis(options)) + { + options = await options.ConfigureForAzureWithTokenCredentialAsync(new DefaultAzureCredential()); + } + + return await ConnectionMultiplexer.ConnectAsync(options); + } + + private static async Task ResolveVectorFieldNameAsync(IDatabase db, string indexName, string redisIndexType) + { + try + { + RedisResult infoResult = await db.ExecuteAsync("FT.INFO", indexName); + if (infoResult.IsNull) + { + return 
DEFAULT_VECTOR_FIELD; + } + + RedisResult[]? infoItems = AsRedisArray(infoResult); + if (infoItems is not null) + { + for (int i = 0; i + 1 < infoItems.Length; i += 2) + { + string? key = infoItems[i].ToString(); + if (!string.Equals(key, "attributes", StringComparison.OrdinalIgnoreCase)) + { + continue; + } + + RedisResult[]? attributes = AsRedisArray(infoItems[i + 1]); + if (attributes is null) + { + continue; + } + + foreach (RedisResult attribute in attributes) + { + RedisResult[]? tokens = AsRedisArray(attribute); + if (tokens is null) + { + continue; + } + + bool isVector = false; + string? identifier = null; + string? alias = null; + + for (int t = 0; t + 1 < tokens.Length; t += 2) + { + string? tokenKey = tokens[t].ToString(); + string? tokenValue = tokens[t + 1].ToString(); + + if (string.Equals(tokenKey, "type", StringComparison.OrdinalIgnoreCase) + && string.Equals(tokenValue, "VECTOR", StringComparison.OrdinalIgnoreCase)) + { + isVector = true; + } + else if (string.Equals(tokenKey, "identifier", StringComparison.OrdinalIgnoreCase)) + { + identifier = tokenValue; + } + else if (string.Equals(tokenKey, "attribute", StringComparison.OrdinalIgnoreCase)) + { + alias = tokenValue; + } + } + + if (isVector) + { + string? selected = string.Equals(redisIndexType, "json", StringComparison.OrdinalIgnoreCase) + ? alias ?? identifier + : identifier ?? alias; + + if (!string.IsNullOrWhiteSpace(selected)) + { + return selected.TrimStart('$', '.'); + } + } + } + } + } + } + catch + { + // Fall back to default vector field name; caller will throw a semantic index error if query fails. + } + + return DEFAULT_VECTOR_FIELD; + } + + private static IReadOnlyList ParseCandidates( + RedisResult rawResult, + string redisIndexType, + HashSet sourceColumns, + IReadOnlyList primaryKeyColumns, + double similarityThreshold) + { + RedisResult[]? 
rows = AsRedisArray(rawResult); + if (rawResult.IsNull || rows is null || rows.Length < 3) + { + return []; + } + + List results = []; + for (int i = 1; i + 1 < rows.Length; i += 2) + { + RedisResult[]? payload = AsRedisArray(rows[i + 1]); + if (payload is null) + { + continue; + } + + Dictionary extracted = ExtractDocumentFields(payload, redisIndexType); + double similarity = TryReadSimilarity(extracted); + + if (similarity < similarityThreshold) + { + continue; + } + + Dictionary sqlColumns = new(StringComparer.OrdinalIgnoreCase); + foreach ((string key, object? value) in extracted) + { + string normalized = NormalizeFieldName(key); + if (sourceColumns.Contains(normalized)) + { + sqlColumns[normalized] = value; + } + } + + if (sqlColumns.Count == 0) + { + continue; + } + + Dictionary primaryKeys = new(StringComparer.OrdinalIgnoreCase); + bool hasAllPrimaryKeys = true; + foreach (string pk in primaryKeyColumns) + { + if (!sqlColumns.TryGetValue(pk, out object? pkValue) || pkValue is null) + { + hasAllPrimaryKeys = false; + break; + } + + primaryKeys[pk] = pkValue; + } + + if (!hasAllPrimaryKeys) + { + continue; + } + + results.Add(new SemanticSearchCandidate(primaryKeys, sqlColumns, similarity)); + } + + return results; + } + + private static RedisResult[]? AsRedisArray(RedisResult result) + { + try + { + RedisResult[]? value = (RedisResult[]?)result; + return value; + } + catch + { + return null; + } + } + + private static Dictionary ExtractDocumentFields(RedisResult[] payload, string redisIndexType) + { + Dictionary fields = new(StringComparer.OrdinalIgnoreCase); + + for (int i = 0; i + 1 < payload.Length; i += 2) + { + string? key = payload[i].ToString(); + string? 
value = payload[i + 1].ToString(); + if (string.IsNullOrWhiteSpace(key)) + { + continue; + } + + if (string.Equals(redisIndexType, "json", StringComparison.OrdinalIgnoreCase) + && string.Equals(key, "$", StringComparison.Ordinal)) + { + MergeJsonDocumentFields(fields, value); + continue; + } + + fields[key] = value; + } + + return fields; + } + + private static void MergeJsonDocumentFields(Dictionary fields, string? json) + { + if (string.IsNullOrWhiteSpace(json)) + { + return; + } + + try + { + using JsonDocument doc = JsonDocument.Parse(json); + if (doc.RootElement.ValueKind is not JsonValueKind.Object) + { + return; + } + + foreach (JsonProperty property in doc.RootElement.EnumerateObject()) + { + fields[property.Name] = property.Value.ValueKind switch + { + JsonValueKind.Null => null, + JsonValueKind.String => property.Value.GetString(), + JsonValueKind.Number => property.Value.ToString(), + JsonValueKind.True => true, + JsonValueKind.False => false, + _ => property.Value.ToString() + }; + } + } + catch + { + // Ignore malformed JSON payloads and continue with other fields. + } + } + + private static double TryReadSimilarity(Dictionary fields) + { + if (!fields.TryGetValue(VECTOR_SCORE_FIELD, out object? scoreObj) + || scoreObj is null + || !double.TryParse(scoreObj.ToString(), NumberStyles.Float, CultureInfo.InvariantCulture, out double rawScore)) + { + return 0.0; + } + + // Redis KNN score is distance; normalize to similarity for semantic-threshold semantics. 
+ double similarity = 1.0 - rawScore; + if (similarity < 0.0) + { + similarity = 0.0; + } + + if (similarity > 1.0) + { + similarity = 1.0; + } + + return similarity; + } + + private static string NormalizeFieldName(string field) + { + string normalized = field.Trim(); + + if (normalized.StartsWith("$.", StringComparison.Ordinal)) + { + normalized = normalized[2..]; + } + + if (normalized.Contains('.', StringComparison.Ordinal)) + { + normalized = normalized[(normalized.LastIndexOf('.') + 1)..]; + } + + return normalized; + } +} \ No newline at end of file diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index f084185929..6d27fc04da 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -26,11 +26,13 @@ using Azure.DataApiBuilder.Core.Services.Cache; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Core.Services.OpenAPI; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Core.Telemetry; using Azure.DataApiBuilder.Mcp.Core; using Azure.DataApiBuilder.Service.Controllers; using Azure.DataApiBuilder.Service.Exceptions; using Azure.DataApiBuilder.Service.HealthCheck; +using Azure.DataApiBuilder.Service.Services.SemanticSearch; using Azure.DataApiBuilder.Service.Telemetry; using Azure.DataApiBuilder.Service.Utilities; using Azure.Identity; @@ -312,6 +314,8 @@ public void ConfigureServices(IServiceCollection services) services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddHttpClient(); + services.AddSingleton(); // ILogger explicit creation required for logger to use --LogLevel startup argument specified. 
services.AddSingleton>(implementationFactory: (serviceProvider) => From e7c997c3f655881c5ecffa6645da93ba09b8d1d7 Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Wed, 10 Jun 2026 07:45:06 -0700 Subject: [PATCH 02/12] More features and tests --- schemas/dab.draft.schema.json | 15 +++++ src/Cli.Tests/ConfigureOptionsTests.cs | 48 ++++++++++++++++ src/Cli/Commands/ConfigureOptions.cs | 10 ++++ src/Cli/ConfigGenerator.cs | 57 +++++++++++++++++++ src/Config/ObjectModel/RuntimeOptions.cs | 4 ++ .../RuntimeSemanticSearchOptions.cs | 31 ++++++++++ .../Configurations/RuntimeConfigValidator.cs | 21 +++++++ .../Queries/QueryBuilder.cs | 5 +- .../UnitTests/ConfigValidationUnitTests.cs | 42 +++++++++++++- .../RedisSemanticSearchService.cs | 10 ++-- 10 files changed, 234 insertions(+), 9 deletions(-) create mode 100644 src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index 185d1b3903..0857901674 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -490,6 +490,21 @@ } } }, + "semantic-search": { + "type": "object", + "description": "Runtime semantic-search configuration.", + "additionalProperties": false, + "properties": { + "embedding-endpoint": { + "type": "string", + "description": "Endpoint used to generate embeddings from semantic-search input text." + }, + "embedding-api-key": { + "type": "string", + "description": "Optional API key used as bearer token for embedding endpoint calls." 
+ } + } + }, "compression": { "type": "object", "description": "Configures HTTP response compression settings.", diff --git a/src/Cli.Tests/ConfigureOptionsTests.cs b/src/Cli.Tests/ConfigureOptionsTests.cs index 5824b2e054..c71eade35d 100644 --- a/src/Cli.Tests/ConfigureOptionsTests.cs +++ b/src/Cli.Tests/ConfigureOptionsTests.cs @@ -541,6 +541,54 @@ public void TestUpdateTTLForCacheSettings(int updatedTtlValue) Assert.AreEqual(updatedTtlValue, runtimeConfig.Runtime.Cache.TtlSeconds); } + /// + /// Tests that running "dab configure --runtime.semantic-search.embedding-endpoint {value}" updates runtime semantic-search embedding endpoint. + /// + [TestMethod] + public void TestUpdateEmbeddingEndpointForSemanticSearchRuntimeSettings() + { + // Arrange -> all the setup which includes creating options. + SetupFileSystemWithInitialConfig(INITIAL_CONFIG); + string updatedEmbeddingEndpoint = "https://example.org/embed"; + + // Act: Attempts to update embedding endpoint value + ConfigureOptions options = new( + runtimeSemanticSearchEmbeddingEndpoint: updatedEmbeddingEndpoint, + config: TEST_RUNTIME_CONFIG_FILE + ); + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert: Validate the embedding endpoint value is updated + Assert.IsTrue(isSuccess); + string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? runtimeConfig)); + Assert.AreEqual(updatedEmbeddingEndpoint, runtimeConfig.Runtime?.SemanticSearch?.EmbeddingEndpoint); + } + + /// + /// Tests that running "dab configure --runtime.semantic-search.embedding-api-key {value}" updates runtime semantic-search embedding API key. + /// + [TestMethod] + public void TestUpdateEmbeddingApiKeyForSemanticSearchRuntimeSettings() + { + // Arrange -> all the setup which includes creating options. 
+ SetupFileSystemWithInitialConfig(INITIAL_CONFIG); + string updatedEmbeddingApiKey = "test-api-key"; + + // Act: Attempts to update embedding API key value + ConfigureOptions options = new( + runtimeSemanticSearchEmbeddingApiKey: updatedEmbeddingApiKey, + config: TEST_RUNTIME_CONFIG_FILE + ); + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert: Validate the embedding API key value is updated + Assert.IsTrue(isSuccess); + string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? runtimeConfig)); + Assert.AreEqual(updatedEmbeddingApiKey, runtimeConfig.Runtime?.SemanticSearch?.EmbeddingApiKey); + } + /// /// Tests that running "dab configure --runtime.compression.level {value}" on a config with various values results /// in runtime config update. Takes in updated value for compression.level and diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index 8fd4d5bc42..70f4d586b4 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -55,6 +55,8 @@ public ConfigureOptions( int? runtimeCacheTtl = null, string? runtimeCacheLevel2Provider = null, string? runtimeCacheLevel2ConnectionString = null, + string? runtimeSemanticSearchEmbeddingEndpoint = null, + string? runtimeSemanticSearchEmbeddingApiKey = null, CompressionLevel? runtimeCompressionLevel = null, HostMode? runtimeHostMode = null, IEnumerable? 
runtimeHostCorsOrigins = null, @@ -121,6 +123,8 @@ public ConfigureOptions( RuntimeCacheTTL = runtimeCacheTtl; RuntimeCacheLevel2Provider = runtimeCacheLevel2Provider; RuntimeCacheLevel2ConnectionString = runtimeCacheLevel2ConnectionString; + RuntimeSemanticSearchEmbeddingEndpoint = runtimeSemanticSearchEmbeddingEndpoint; + RuntimeSemanticSearchEmbeddingApiKey = runtimeSemanticSearchEmbeddingApiKey; // Compression RuntimeCompressionLevel = runtimeCompressionLevel; // Host @@ -252,6 +256,12 @@ public ConfigureOptions( [Option("runtime.cache.level-2.connection-string", Required = false, HelpText = "Set level-2 cache connection string.")] public string? RuntimeCacheLevel2ConnectionString { get; } + [Option("runtime.semantic-search.embedding-endpoint", Required = false, HelpText = "Set semantic-search embedding endpoint.")] + public string? RuntimeSemanticSearchEmbeddingEndpoint { get; } + + [Option("runtime.semantic-search.embedding-api-key", Required = false, HelpText = "Set semantic-search embedding API key.")] + public string? RuntimeSemanticSearchEmbeddingApiKey { get; } + [Option("runtime.compression.level", Required = false, HelpText = "Set the response compression level. Allowed values: optimal (default), fastest, none.")] public CompressionLevel? RuntimeCompressionLevel { get; } diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index e450cc9ad2..f4c4b9edab 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -999,6 +999,22 @@ private static bool TryUpdateConfiguredRuntimeOptions( } } + // Semantic Search: Embedding endpoint and API key + if (options.RuntimeSemanticSearchEmbeddingEndpoint != null || + options.RuntimeSemanticSearchEmbeddingApiKey != null) + { + RuntimeSemanticSearchOptions? updatedSemanticSearchOptions = runtimeConfig?.Runtime?.SemanticSearch ?? new(); + bool status = TryUpdateConfiguredSemanticSearchValues(options, ref updatedSemanticSearchOptions); + if (status) + { + runtimeConfig = runtimeConfig! 
with { Runtime = runtimeConfig.Runtime! with { SemanticSearch = updatedSemanticSearchOptions } }; + } + else + { + return false; + } + } + // Compression: Level if (options.RuntimeCompressionLevel != null) { @@ -1445,6 +1461,47 @@ private static bool TryUpdateConfiguredCacheValues( } } + /// + /// Attempts to update the Config parameters in the semantic-search runtime settings based on the provided value. + /// + /// options. + /// updatedSemanticSearchOptions. + /// True if updates succeed, else false. + private static bool TryUpdateConfiguredSemanticSearchValues( + ConfigureOptions options, + ref RuntimeSemanticSearchOptions? updatedSemanticSearchOptions) + { + try + { + if (options?.RuntimeSemanticSearchEmbeddingEndpoint is not null) + { + RuntimeSemanticSearchOptions current = updatedSemanticSearchOptions ?? new(); + updatedSemanticSearchOptions = current with + { + EmbeddingEndpoint = options.RuntimeSemanticSearchEmbeddingEndpoint + }; + _logger.LogInformation("Updated RuntimeConfig with runtime.semantic-search.embedding-endpoint."); + } + + if (options?.RuntimeSemanticSearchEmbeddingApiKey is not null) + { + RuntimeSemanticSearchOptions current = updatedSemanticSearchOptions ?? new(); + updatedSemanticSearchOptions = current with + { + EmbeddingApiKey = options.RuntimeSemanticSearchEmbeddingApiKey + }; + _logger.LogInformation("Updated RuntimeConfig with runtime.semantic-search.embedding-api-key."); + } + + return true; + } + catch (Exception ex) + { + _logger.LogError("Failed to update RuntimeConfig.SemanticSearch with exception message: {exceptionMessage}.", ex.Message); + return false; + } + } + /// /// Attempts to update the Config parameters in the Compression runtime settings based on the provided value. 
/// Validates user-provided parameters and then returns true if the updated Compression options diff --git a/src/Config/ObjectModel/RuntimeOptions.cs b/src/Config/ObjectModel/RuntimeOptions.cs index 525ea8d089..fa76db17da 100644 --- a/src/Config/ObjectModel/RuntimeOptions.cs +++ b/src/Config/ObjectModel/RuntimeOptions.cs @@ -15,6 +15,8 @@ public record RuntimeOptions public string? BaseRoute { get; init; } public TelemetryOptions? Telemetry { get; init; } public RuntimeCacheOptions? Cache { get; init; } + [JsonPropertyName("semantic-search")] + public RuntimeSemanticSearchOptions? SemanticSearch { get; init; } public PaginationOptions? Pagination { get; init; } public RuntimeHealthCheckConfig? Health { get; init; } public CompressionOptions? Compression { get; init; } @@ -28,6 +30,7 @@ public RuntimeOptions( string? BaseRoute = null, TelemetryOptions? Telemetry = null, RuntimeCacheOptions? Cache = null, + RuntimeSemanticSearchOptions? SemanticSearch = null, PaginationOptions? Pagination = null, RuntimeHealthCheckConfig? Health = null, CompressionOptions? Compression = null) @@ -39,6 +42,7 @@ public RuntimeOptions( this.BaseRoute = BaseRoute; this.Telemetry = Telemetry; this.Cache = Cache; + this.SemanticSearch = SemanticSearch; this.Pagination = Pagination; this.Health = Health; this.Compression = Compression; diff --git a/src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs b/src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs new file mode 100644 index 0000000000..1c72667de2 --- /dev/null +++ b/src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel; + +/// +/// Runtime semantic search configuration. +/// +public record RuntimeSemanticSearchOptions +{ + /// + /// Endpoint used to generate embeddings from semantic-search input text. 
+ /// + [JsonPropertyName("embedding-endpoint")] + public string? EmbeddingEndpoint { get; init; } = null; + + /// + /// Optional API key used as a bearer token for embedding endpoint calls. + /// + [JsonPropertyName("embedding-api-key")] + public string? EmbeddingApiKey { get; init; } = null; + + [JsonConstructor] + public RuntimeSemanticSearchOptions(string? EmbeddingEndpoint = null, string? EmbeddingApiKey = null) + { + this.EmbeddingEndpoint = EmbeddingEndpoint; + this.EmbeddingApiKey = EmbeddingApiKey; + } +} \ No newline at end of file diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs index 7283fe09e4..a9219964c4 100644 --- a/src/Core/Configurations/RuntimeConfigValidator.cs +++ b/src/Core/Configurations/RuntimeConfigValidator.cs @@ -68,6 +68,9 @@ public class RuntimeConfigValidator : IConfigValidator private const string SEMANTIC_SEARCH_REDIS_REQUIREMENT_ERR_MSG = "Semantic search requires runtime.cache.level-2.provider to be 'redis' and runtime.cache.level-2.connection-string to be configured."; + private const string SEMANTIC_SEARCH_EMBEDDING_REQUIREMENT_ERR_MSG = + "Semantic search requires runtime.semantic-search.embedding-endpoint to be configured."; + private static readonly HashSet _reservedSemanticRestNames = [ "semantic_search", @@ -764,6 +767,14 @@ private void ValidateSemanticSearchConfiguration(RuntimeConfig runtimeConfig, st subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); } + if (!HasValidSemanticEmbeddingConfiguration(runtimeConfig)) + { + HandleOrRecordException(new DataApiBuilderException( + message: SEMANTIC_SEARCH_EMBEDDING_REQUIREMENT_ERR_MSG, + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + if (semantic.RedisIndexMultiplier < 1 || semantic.RedisIndexMultiplier > 10) { HandleOrRecordException(new DataApiBuilderException( @@ -794,6 +805,16 @@ private static 
bool HasValidSemanticRedisConfiguration(RuntimeConfig runtimeConf && !string.IsNullOrWhiteSpace(level2.ConnectionString); } + /// + /// Validates embedding endpoint prerequisites for semantic-search. + /// + private static bool HasValidSemanticEmbeddingConfiguration(RuntimeConfig runtimeConfig) + { + RuntimeSemanticSearchOptions? semanticSearch = runtimeConfig.Runtime?.SemanticSearch; + return semanticSearch is not null + && !string.IsNullOrWhiteSpace(semanticSearch.EmbeddingEndpoint); + } + /// /// Ensures configured field names do not collide with semantic reserved names. /// diff --git a/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs b/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs index 100759522f..69d57c86cd 100644 --- a/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs +++ b/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs @@ -30,8 +30,9 @@ public static class QueryBuilder public const string GROUP_BY_AGGREGATE_FIELD_ARG_NAME = "field"; public const string GROUP_BY_AGGREGATE_FIELD_DISTINCT_NAME = "distinct"; public const string GROUP_BY_AGGREGATE_FIELD_HAVING_NAME = "having"; - public const string SEMANTIC_SEARCH_ARGUMENT_NAME = SemanticSearchConstants.GRAPHQL_SEARCH_ARGUMENT; - public const string SEMANTIC_THRESHOLD_ARGUMENZT_NAME = SemanticSearchConstants.GRAPHQL_THRESHOLD_ARGUMENT; + public const string SEMANTIC_SEARCH_ARGUMENT_NAME = "semanticSearch"; + public const string SEMANTIC_THRESHOLD_ARGUMENT_NAME = "semanticThreshold"; + public const string SEMANTIC_DISTANCE_FIELD_NAME = "semanticDistance"; // Define the enabled database types for aggregation public static readonly HashSet AggregationEnabledDatabaseTypes = new() diff --git a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs index f0cbc216e3..e7d6d457eb 100644 --- a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs +++ b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs @@ -3113,6 +3113,45 @@ public void 
ValidateSemanticSearchRequiresRedisIndexNameWhenEnabled() ex.Message); } + [TestMethod] + public void ValidateSemanticSearchRequiresEmbeddingEndpointConfiguration() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, KeyFields: null), + Fields: null, + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true, + RedisIndexName = "idx:product-semantic" + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new RuntimeCacheOptions + { + Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") + }), + Entities: new(new Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "Semantic search requires runtime.semantic-search.embedding-endpoint to be configured.", + ex.Message); + } + [TestMethod] public void ValidateSemanticSearchRejectsReservedFieldName() { @@ -3143,7 +3182,8 @@ public void ValidateSemanticSearchRejectsReservedFieldName() Cache: new RuntimeCacheOptions { Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") - }), + }, + SemanticSearch: new RuntimeSemanticSearchOptions(EmbeddingEndpoint: "https://example.org/embed")), Entities: new(new Dictionary { ["Product"] = entity })); DataApiBuilderException ex = 
Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); diff --git a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs index 726137a53f..0b67657187 100644 --- a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs +++ b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs @@ -31,8 +31,6 @@ namespace Azure.DataApiBuilder.Service.Services.SemanticSearch; /// public sealed class RedisSemanticSearchService : ISemanticSearchService { - private const string EMBED_ENDPOINT_ENV = "DAB_SEMANTIC_EMBED_ENDPOINT"; - private const string EMBED_API_KEY_ENV = "DAB_SEMANTIC_EMBED_API_KEY"; private const string VECTOR_SCORE_FIELD = "__vector_score"; private const string DEFAULT_VECTOR_FIELD = "embedding"; @@ -71,7 +69,7 @@ public async Task> GetCandidatesAsync( return []; } - float[] embedding = await GetEmbeddingAsync(semanticSearchValue); + float[] embedding = await GetEmbeddingAsync(runtimeConfig, semanticSearchValue); if (embedding.Length == 0) { return []; @@ -124,21 +122,21 @@ private static byte[] ToRedisVectorBytes(float[] vector) return bytes; } - private async Task GetEmbeddingAsync(string semanticSearchValue) + private async Task GetEmbeddingAsync(RuntimeConfig runtimeConfig, string semanticSearchValue) { if (TryParseVectorText(semanticSearchValue, out float[]? parsedVector) && parsedVector is not null) { return parsedVector; } - string? endpoint = Environment.GetEnvironmentVariable(EMBED_ENDPOINT_ENV); + string? endpoint = runtimeConfig.Runtime?.SemanticSearch?.EmbeddingEndpoint; if (string.IsNullOrWhiteSpace(endpoint)) { return []; } HttpClient client = _httpClientFactory.CreateClient(); - string? apiKey = Environment.GetEnvironmentVariable(EMBED_API_KEY_ENV); + string? 
apiKey = runtimeConfig.Runtime?.SemanticSearch?.EmbeddingApiKey; if (!string.IsNullOrWhiteSpace(apiKey)) { client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); From 990087d466ed8cd7c91061d77736d1358cc296b4 Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Fri, 5 Jun 2026 09:59:51 -0700 Subject: [PATCH 03/12] DAB Phase 4 Implementation --- schemas/dab.draft.schema.json | 61 ++ src/Cli/Commands/AddOptions.cs | 14 + src/Cli/Commands/ConfigureOptions.cs | 10 + src/Cli/Commands/EntityOptions.cs | 35 ++ src/Cli/Commands/UpdateOptions.cs | 14 + src/Cli/ConfigGenerator.cs | 59 +- src/Cli/Utils.cs | 88 +++ src/Config/ObjectModel/Entity.cs | 6 +- .../EntitySemanticSearchOptions.cs | 39 ++ .../Configurations/RuntimeConfigValidator.cs | 147 +++++ .../RestRequestContexts/RestRequestContext.cs | 25 + src/Core/Models/SemanticSearchCandidate.cs | 13 + src/Core/Models/SemanticSearchConstants.cs | 15 + src/Core/Parsers/RequestParser.cs | 33 ++ .../Resolvers/Factories/QueryEngineFactory.cs | 7 +- .../Sql Query Structures/SqlQueryStructure.cs | 108 ++++ src/Core/Resolvers/SqlQueryEngine.cs | 257 ++++++++- .../Services/OpenAPI/OpenApiDocumentor.cs | 38 +- src/Core/Services/RequestValidator.cs | 19 + src/Core/Services/RestService.cs | 72 +++ .../SemanticSearch/ISemanticSearchService.cs | 18 + .../NoOpSemanticSearchService.cs | 30 + .../Queries/InputTypeBuilder.cs | 11 +- .../Queries/QueryBuilder.cs | 16 +- .../Sql/SchemaConverter.cs | 14 + .../MultipleMutationBuilderTests.cs | 2 + .../GraphQLBuilder/QueryBuilderTests.cs | 48 ++ .../UnitTests/ConfigValidationUnitTests.cs | 112 ++++ .../UnitTests/RequestValidatorUnitTests.cs | 53 ++ .../RedisSemanticSearchService.cs | 538 ++++++++++++++++++ src/Service/Startup.cs | 4 + 31 files changed, 1889 insertions(+), 17 deletions(-) create mode 100644 src/Config/ObjectModel/EntitySemanticSearchOptions.cs create mode 100644 src/Core/Models/SemanticSearchCandidate.cs create mode 100644 
src/Core/Models/SemanticSearchConstants.cs create mode 100644 src/Core/Services/SemanticSearch/ISemanticSearchService.cs create mode 100644 src/Core/Services/SemanticSearch/NoOpSemanticSearchService.cs create mode 100644 src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index 6aa8f02641..94b731bd56 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -1487,6 +1487,67 @@ } } }, + "semantic-search": { + "type": "object", + "description": "Semantic search configuration for this entity.", + "additionalProperties": false, + "properties": { + "enabled": { + "$ref": "#/$defs/boolean-or-string", + "description": "Enables semantic search for this entity.", + "default": false + }, + "redis-index-name": { + "type": "string", + "description": "Name of the Redis vector index used for semantic search for this entity." + }, + "redis-index-type": { + "type": "string", + "description": "Redis index storage type used by the semantic search index.", + "enum": ["hash", "json"], + "default": "hash" + }, + "redis-index-multiplier": { + "type": "integer", + "description": "Multiplier applied to requested result count when querying Redis.", + "default": 2, + "minimum": 1, + "maximum": 10 + }, + "similarity-threshold": { + "type": "number", + "description": "Minimum Redis similarity value required for a semantic match.", + "default": 0.8, + "minimum": 0.0, + "maximum": 1.0 + }, + "input-description": { + "type": "string", + "description": "Description surfaced in API metadata for semantic search input.", + "default": "Natural language value used for semantic search." + }, + "output-description": { + "type": "string", + "description": "Description surfaced in API metadata for semantic distance output.", + "default": "Semantic distance score returned by semantic search." 
+ } + }, + "allOf": [ + { + "if": { + "properties": { + "enabled": { + "const": true + } + }, + "required": ["enabled"] + }, + "then": { + "required": ["redis-index-name"] + } + } + ] + }, "mcp": { "oneOf": [ { diff --git a/src/Cli/Commands/AddOptions.cs b/src/Cli/Commands/AddOptions.cs index 8fab1937d7..0bf66d539f 100644 --- a/src/Cli/Commands/AddOptions.cs +++ b/src/Cli/Commands/AddOptions.cs @@ -45,6 +45,13 @@ public AddOptions( IEnumerable? fieldsAliasCollection, IEnumerable? fieldsDescriptionCollection, IEnumerable? fieldsPrimaryKeyCollection, + string? semanticSearchEnabled = null, + string? semanticSearchRedisIndexName = null, + string? semanticSearchRedisIndexType = null, + string? semanticSearchRedisIndexMultiplier = null, + string? semanticSearchSimilarityThreshold = null, + string? semanticSearchInputDescription = null, + string? semanticSearchOutputDescription = null, string? mcpDmlTools = null, string? mcpCustomTool = null, string? config = null @@ -66,6 +73,13 @@ public AddOptions( cacheTtlSeconds, cacheLevel, healthEnabled, + semanticSearchEnabled, + semanticSearchRedisIndexName, + semanticSearchRedisIndexType, + semanticSearchRedisIndexMultiplier, + semanticSearchSimilarityThreshold, + semanticSearchInputDescription, + semanticSearchOutputDescription, description, parametersNameCollection, parametersDescriptionCollection, diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index 12f5f55455..b64de19981 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -61,6 +61,8 @@ public ConfigureOptions( int? runtimePaginationMaxPageSize = null, int? runtimePaginationDefaultPageSize = null, bool? runtimePaginationNextLinkRelative = null, + string? runtimeCacheLevel2Provider = null, + string? runtimeCacheLevel2ConnectionString = null, CompressionLevel? runtimeCompressionLevel = null, bool? runtimeHealthEnabled = null, int? 
runtimeHealthCacheTtlSeconds = null, @@ -161,6 +163,8 @@ public ConfigureOptions( RuntimePaginationMaxPageSize = runtimePaginationMaxPageSize; RuntimePaginationDefaultPageSize = runtimePaginationDefaultPageSize; RuntimePaginationNextLinkRelative = runtimePaginationNextLinkRelative; + RuntimeCacheLevel2Provider = runtimeCacheLevel2Provider; + RuntimeCacheLevel2ConnectionString = runtimeCacheLevel2ConnectionString; // Compression RuntimeCompressionLevel = runtimeCompressionLevel; // Health @@ -342,6 +346,12 @@ public ConfigureOptions( [Option("runtime.pagination.next-link-relative", Required = false, HelpText = "Use relative URLs for pagination next links. Default: false (boolean).")] public bool? RuntimePaginationNextLinkRelative { get; } + [Option("runtime.cache.level-2.provider", Required = false, HelpText = "Set level-2 cache provider. Allowed values: redis.")] + public string? RuntimeCacheLevel2Provider { get; } + + [Option("runtime.cache.level-2.connection-string", Required = false, HelpText = "Set level-2 cache connection string.")] + public string? RuntimeCacheLevel2ConnectionString { get; } + [Option("runtime.compression.level", Required = false, HelpText = "Set the response compression level. Allowed values: optimal (default), fastest, none.")] public CompressionLevel? RuntimeCompressionLevel { get; } diff --git a/src/Cli/Commands/EntityOptions.cs b/src/Cli/Commands/EntityOptions.cs index 76988d4c75..a0da509448 100644 --- a/src/Cli/Commands/EntityOptions.cs +++ b/src/Cli/Commands/EntityOptions.cs @@ -27,6 +27,13 @@ public EntityOptions( string? cacheTtlSeconds, string? cacheLevel, string? healthEnabled, + string? semanticSearchEnabled, + string? semanticSearchRedisIndexName, + string? semanticSearchRedisIndexType, + string? semanticSearchRedisIndexMultiplier, + string? semanticSearchSimilarityThreshold, + string? semanticSearchInputDescription, + string? semanticSearchOutputDescription, string? description, IEnumerable? 
parametersNameCollection, IEnumerable? parametersDescriptionCollection, @@ -58,6 +65,13 @@ public EntityOptions( CacheTtlSeconds = cacheTtlSeconds; CacheLevel = cacheLevel; HealthEnabled = healthEnabled; + SemanticSearchEnabled = semanticSearchEnabled; + SemanticSearchRedisIndexName = semanticSearchRedisIndexName; + SemanticSearchRedisIndexType = semanticSearchRedisIndexType; + SemanticSearchRedisIndexMultiplier = semanticSearchRedisIndexMultiplier; + SemanticSearchSimilarityThreshold = semanticSearchSimilarityThreshold; + SemanticSearchInputDescription = semanticSearchInputDescription; + SemanticSearchOutputDescription = semanticSearchOutputDescription; Description = description; ParametersNameCollection = parametersNameCollection; ParametersDescriptionCollection = parametersDescriptionCollection; @@ -120,6 +134,27 @@ public EntityOptions( [Option("health.enabled", Required = false, HelpText = "Enable health checks for this entity. Default: true (boolean).")] public string? HealthEnabled { get; } + [Option("semantic-search.enabled", Required = false, HelpText = "Enable semantic search for this entity. Accepted values are true/false.")] + public string? SemanticSearchEnabled { get; } + + [Option("semantic-search.redis-index-name", Required = false, HelpText = "Name of the Redis vector index to use for semantic search.")] + public string? SemanticSearchRedisIndexName { get; } + + [Option("semantic-search.redis-index-type", Required = false, HelpText = "Redis index type for semantic search. Allowed values: hash, json.")] + public string? SemanticSearchRedisIndexType { get; } + + [Option("semantic-search.redis-index-multiplier", Required = false, HelpText = "Multiplier for Redis candidate retrieval. Allowed values: 1-10.")] + public string? SemanticSearchRedisIndexMultiplier { get; } + + [Option("semantic-search.similarity-threshold", Required = false, HelpText = "Default semantic similarity threshold. Allowed values: 0.0-1.0.")] + public string? 
SemanticSearchSimilarityThreshold { get; } + + [Option("semantic-search.input-description", Required = false, HelpText = "Description for semantic search input metadata.")] + public string? SemanticSearchInputDescription { get; } + + [Option("semantic-search.output-description", Required = false, HelpText = "Description for semantic distance output metadata.")] + public string? SemanticSearchOutputDescription { get; } + [Option("description", Required = false, HelpText = "Description of the entity.")] public string? Description { get; } diff --git a/src/Cli/Commands/UpdateOptions.cs b/src/Cli/Commands/UpdateOptions.cs index 206bcd704d..d5c63e8875 100644 --- a/src/Cli/Commands/UpdateOptions.cs +++ b/src/Cli/Commands/UpdateOptions.cs @@ -53,6 +53,13 @@ public UpdateOptions( IEnumerable? fieldsAliasCollection, IEnumerable? fieldsDescriptionCollection, IEnumerable? fieldsPrimaryKeyCollection, + string? semanticSearchEnabled = null, + string? semanticSearchRedisIndexName = null, + string? semanticSearchRedisIndexType = null, + string? semanticSearchRedisIndexMultiplier = null, + string? semanticSearchSimilarityThreshold = null, + string? semanticSearchInputDescription = null, + string? semanticSearchOutputDescription = null, string? mcpDmlTools = null, string? mcpCustomTool = null, string? 
config = null) @@ -72,6 +79,13 @@ public UpdateOptions( cacheTtlSeconds, cacheLevel, healthEnabled, + semanticSearchEnabled, + semanticSearchRedisIndexName, + semanticSearchRedisIndexType, + semanticSearchRedisIndexMultiplier, + semanticSearchSimilarityThreshold, + semanticSearchInputDescription, + semanticSearchOutputDescription, description, parametersNameCollection, parametersDescriptionCollection, diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index ae3087debb..c513f1ce30 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -472,6 +472,14 @@ public static bool TryAddNewEntity(AddOptions options, RuntimeConfig initialRunt EntityGraphQLOptions graphqlOptions = ConstructGraphQLTypeDetails(options.GraphQLType, graphQLOperationsForStoredProcedures); EntityCacheOptions? cacheOptions = ConstructCacheOptions(options.CacheEnabled, options.CacheTtlSeconds, options.CacheLevel); EntityHealthCheckConfig? entityHealthOptions = ConstructEntityHealthOptions(options.HealthEnabled); + EntitySemanticSearchOptions? semanticSearchOptions = ConstructSemanticSearchOptions( + options.SemanticSearchEnabled, + options.SemanticSearchRedisIndexName, + options.SemanticSearchRedisIndexType, + options.SemanticSearchRedisIndexMultiplier, + options.SemanticSearchSimilarityThreshold, + options.SemanticSearchInputDescription, + options.SemanticSearchOutputDescription); EntityMcpOptions? mcpOptions = null; if (options.McpDmlTools is not null || options.McpCustomTool is not null) @@ -495,6 +503,7 @@ public static bool TryAddNewEntity(AddOptions options, RuntimeConfig initialRunt Relationships: null, Mappings: null, Cache: cacheOptions, + SemanticSearch: semanticSearchOptions, Description: string.IsNullOrWhiteSpace(options.Description) ? 
null : options.Description, Mcp: mcpOptions, Health: entityHealthOptions); @@ -1002,7 +1011,9 @@ private static bool TryUpdateConfiguredRuntimeOptions( // Cache: Enabled and TTL if (options.RuntimeCacheEnabled != null || - options.RuntimeCacheTTL != null) + options.RuntimeCacheTTL != null || + options.RuntimeCacheLevel2Provider != null || + options.RuntimeCacheLevel2ConnectionString != null) { RuntimeCacheOptions? updatedCacheOptions = runtimeConfig?.Runtime?.Cache ?? new(); bool status = TryUpdateConfiguredCacheValues(options, ref updatedCacheOptions); @@ -1559,6 +1570,43 @@ private static bool TryUpdateConfiguredCacheValues( } } + // Runtime.Cache.level-2.provider + updatedValue = options?.RuntimeCacheLevel2Provider; + if (updatedValue != null) + { + string provider = ((string)updatedValue).Trim(); + if (!string.Equals(provider, "redis", StringComparison.OrdinalIgnoreCase)) + { + _logger.LogError("Failed to update Runtime.Cache.level-2.provider as '{updatedValue}'. Supported value is 'redis'.", updatedValue); + return false; + } + + RuntimeCacheLevel2Options currentLevel2 = updatedCacheOptions?.Level2 ?? new(); + updatedCacheOptions = updatedCacheOptions! with + { + Level2 = currentLevel2 with + { + Provider = provider.ToLowerInvariant() + } + }; + _logger.LogInformation("Updated RuntimeConfig with Runtime.Cache.level-2.provider as '{updatedValue}'", updatedValue); + } + + // Runtime.Cache.level-2.connection-string + updatedValue = options?.RuntimeCacheLevel2ConnectionString; + if (updatedValue != null) + { + RuntimeCacheLevel2Options currentLevel2 = updatedCacheOptions?.Level2 ?? new(); + updatedCacheOptions = updatedCacheOptions! 
with + { + Level2 = currentLevel2 with + { + ConnectionString = (string)updatedValue + } + }; + _logger.LogInformation("Updated RuntimeConfig with Runtime.Cache.level-2.connection-string."); + } + return true; } catch (Exception ex) @@ -2327,6 +2375,14 @@ public static bool TryUpdateExistingEntity(UpdateOptions options, RuntimeConfig EntityActionFields? updatedFields = GetFieldsForOperation(options.FieldsToInclude, options.FieldsToExclude); EntityCacheOptions? updatedCacheOptions = ConstructCacheOptions(options.CacheEnabled, options.CacheTtlSeconds, options.CacheLevel) ?? entity.Cache; EntityHealthCheckConfig? updatedEntityHealthOptions = ConstructEntityHealthOptions(options.HealthEnabled) ?? entity.Health; + EntitySemanticSearchOptions? updatedSemanticSearchOptions = ConstructSemanticSearchOptions( + options.SemanticSearchEnabled, + options.SemanticSearchRedisIndexName, + options.SemanticSearchRedisIndexType, + options.SemanticSearchRedisIndexMultiplier, + options.SemanticSearchSimilarityThreshold, + options.SemanticSearchInputDescription, + options.SemanticSearchOutputDescription); // Determine if the entity is or will be a stored procedure bool isStoredProcedureAfterUpdate = doOptionsRepresentStoredProcedure || (isCurrentEntityStoredProcedure && options.SourceType is null); @@ -2588,6 +2644,7 @@ public static bool TryUpdateExistingEntity(UpdateOptions options, RuntimeConfig Relationships: updatedRelationships, Mappings: updatedMappings, Cache: updatedCacheOptions, + SemanticSearch: updatedSemanticSearchOptions ?? entity.SemanticSearch, Description: string.IsNullOrWhiteSpace(options.Description) ? entity.Description : options.Description, Mcp: updatedMcpOptions, Health: updatedEntityHealthOptions diff --git a/src/Cli/Utils.cs b/src/Cli/Utils.cs index 8f309407a8..f15bb8d55a 100644 --- a/src/Cli/Utils.cs +++ b/src/Cli/Utils.cs @@ -956,6 +956,94 @@ public static EntityGraphQLOptions ConstructGraphQLTypeDetails(string? 
graphQL, return new EntityHealthCheckConfig(enabled: isEnabled); } + /// + /// Constructs the EntitySemanticSearchOptions for Add/Update. + /// + /// EntitySemanticSearchOptions when at least one semantic option is provided, null otherwise. + public static EntitySemanticSearchOptions? ConstructSemanticSearchOptions( + string? semanticSearchEnabled, + string? semanticSearchRedisIndexName, + string? semanticSearchRedisIndexType, + string? semanticSearchRedisIndexMultiplier, + string? semanticSearchSimilarityThreshold, + string? semanticSearchInputDescription, + string? semanticSearchOutputDescription) + { + if (semanticSearchEnabled is null + && semanticSearchRedisIndexName is null + && semanticSearchRedisIndexType is null + && semanticSearchRedisIndexMultiplier is null + && semanticSearchSimilarityThreshold is null + && semanticSearchInputDescription is null + && semanticSearchOutputDescription is null) + { + return null; + } + + bool enabled = false; + if (semanticSearchEnabled is not null && !bool.TryParse(semanticSearchEnabled, out enabled)) + { + _logger.LogError("Invalid format for --semantic-search.enabled. Accepted values are true/false."); + return null; + } + + string redisIndexType = EntitySemanticSearchOptions.DEFAULT_REDIS_INDEX_TYPE; + if (!string.IsNullOrWhiteSpace(semanticSearchRedisIndexType)) + { + if (!string.Equals(semanticSearchRedisIndexType, "hash", StringComparison.OrdinalIgnoreCase) + && !string.Equals(semanticSearchRedisIndexType, "json", StringComparison.OrdinalIgnoreCase)) + { + _logger.LogError("Invalid format for --semantic-search.redis-index-type. 
Accepted values are hash/json."); + return null; + } + + redisIndexType = semanticSearchRedisIndexType.ToLowerInvariant(); + } + + int redisIndexMultiplier = EntitySemanticSearchOptions.DEFAULT_REDIS_INDEX_MULTIPLIER; + if (!string.IsNullOrWhiteSpace(semanticSearchRedisIndexMultiplier) + && !int.TryParse(semanticSearchRedisIndexMultiplier, System.Globalization.NumberStyles.Integer, System.Globalization.CultureInfo.InvariantCulture, out redisIndexMultiplier)) + { + _logger.LogError("Invalid format for --semantic-search.redis-index-multiplier. Accepted values are integer values in range [1,10]."); + return null; + } + + if (redisIndexMultiplier < 1 || redisIndexMultiplier > 10) + { + _logger.LogError("Invalid value for --semantic-search.redis-index-multiplier. Accepted values are integer values in range [1,10]."); + return null; + } + + double similarityThreshold = EntitySemanticSearchOptions.DEFAULT_SIMILARITY_THRESHOLD; + if (!string.IsNullOrWhiteSpace(semanticSearchSimilarityThreshold) + && !double.TryParse(semanticSearchSimilarityThreshold, System.Globalization.NumberStyles.Float, System.Globalization.CultureInfo.InvariantCulture, out similarityThreshold)) + { + _logger.LogError("Invalid format for --semantic-search.similarity-threshold. Accepted values are decimal values in range [0.0,1.0]."); + return null; + } + + if (!(similarityThreshold >= 0.0 && similarityThreshold <= 1.0)) + { + _logger.LogError("Invalid value for --semantic-search.similarity-threshold. Accepted values are decimal values in range [0.0,1.0]."); + return null; + } + + return new EntitySemanticSearchOptions + { + Enabled = enabled, + RedisIndexName = semanticSearchRedisIndexName, + RedisIndexType = redisIndexType, + RedisIndexMultiplier = redisIndexMultiplier, + SimilarityThreshold = similarityThreshold, + InputDescription = string.IsNullOrWhiteSpace(semanticSearchInputDescription) + ? EntitySemanticSearchOptions.DEFAULT_INPUT_DESCRIPTION + : semanticSearchInputDescription, + OutputDescription = string.IsNullOrWhiteSpace(semanticSearchOutputDescription) + ?
EntitySemanticSearchOptions.DEFAULT_OUTPUT_DESCRIPTION + : semanticSearchOutputDescription + }; + } + /// /// Constructs the EntityMcpOptions for Add/Update. /// diff --git a/src/Config/ObjectModel/Entity.cs b/src/Config/ObjectModel/Entity.cs index 9c9ba17f2c..637c219744 100644 --- a/src/Config/ObjectModel/Entity.cs +++ b/src/Config/ObjectModel/Entity.cs @@ -37,6 +37,8 @@ public record Entity public Dictionary? Mappings { get; init; } public Dictionary? Relationships { get; init; } public EntityCacheOptions? Cache { get; init; } + [JsonPropertyName("semantic-search")] + public EntitySemanticSearchOptions? SemanticSearch { get; init; } public EntityHealthCheckConfig? Health { get; init; } [JsonConverter(typeof(EntityMcpOptionsConverterFactory))] @@ -62,7 +64,8 @@ public Entity( EntityHealthCheckConfig? Health = null, string? Description = null, EntityMcpOptions? Mcp = null, - bool IsAutoentity = false) + bool IsAutoentity = false, + EntitySemanticSearchOptions? SemanticSearch = null) { this.Health = Health; this.Source = Source; @@ -73,6 +76,7 @@ public Entity( this.Mappings = Mappings; this.Relationships = Relationships; this.Cache = Cache; + this.SemanticSearch = SemanticSearch; this.IsLinkingEntity = IsLinkingEntity; this.Description = Description; this.Mcp = Mcp; diff --git a/src/Config/ObjectModel/EntitySemanticSearchOptions.cs b/src/Config/ObjectModel/EntitySemanticSearchOptions.cs new file mode 100644 index 0000000000..1175330501 --- /dev/null +++ b/src/Config/ObjectModel/EntitySemanticSearchOptions.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel; + +/// +/// Entity specific semantic search configuration. 
+/// +public record EntitySemanticSearchOptions +{ + public const int DEFAULT_REDIS_INDEX_MULTIPLIER = 2; + public const double DEFAULT_SIMILARITY_THRESHOLD = 0.8; + public const string DEFAULT_REDIS_INDEX_TYPE = "hash"; + public const string DEFAULT_INPUT_DESCRIPTION = "Natural language value used for semantic search."; + public const string DEFAULT_OUTPUT_DESCRIPTION = "Semantic distance score returned by semantic search."; + + [JsonPropertyName("enabled")] + public bool Enabled { get; init; } = false; + + [JsonPropertyName("redis-index-name")] + public string? RedisIndexName { get; init; } + + [JsonPropertyName("redis-index-type")] + public string RedisIndexType { get; init; } = DEFAULT_REDIS_INDEX_TYPE; + + [JsonPropertyName("redis-index-multiplier")] + public int RedisIndexMultiplier { get; init; } = DEFAULT_REDIS_INDEX_MULTIPLIER; + + [JsonPropertyName("similarity-threshold")] + public double SimilarityThreshold { get; init; } = DEFAULT_SIMILARITY_THRESHOLD; + + [JsonPropertyName("input-description")] + public string InputDescription { get; init; } = DEFAULT_INPUT_DESCRIPTION; + + [JsonPropertyName("output-description")] + public string OutputDescription { get; init; } = DEFAULT_OUTPUT_DESCRIPTION; +} diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs index 33f60d7e48..70b559ceca 100644 --- a/src/Core/Configurations/RuntimeConfigValidator.cs +++ b/src/Core/Configurations/RuntimeConfigValidator.cs @@ -66,6 +66,23 @@ public class RuntimeConfigValidator : IConfigValidator // Error messages. 
public const string INVALID_CLAIMS_IN_POLICY_ERR_MSG = "One or more claim types supplied in the database policy are not supported."; + private const string SEMANTIC_SEARCH_REDIS_REQUIREMENT_ERR_MSG = + "Semantic search requires runtime.cache.level-2.provider to be 'redis' and runtime.cache.level-2.connection-string to be configured."; + + private static readonly HashSet _reservedSemanticRestNames = + [ + "semantic_search", + "semantic_threshold", + "semantic_distance" + ]; + + private static readonly HashSet _reservedSemanticGraphQLNames = + [ + "semanticSearch", + "semanticThreshold", + "semanticDistance" + ]; + public RuntimeConfigValidator( RuntimeConfigProvider runtimeConfigProvider, IFileSystem fileSystem, @@ -1011,6 +1028,136 @@ public void ValidateEntityConfiguration(RuntimeConfig runtimeConfig) ValidateNameRequirements(entity.GraphQL.Singular); ValidateNameRequirements(entity.GraphQL.Plural); } + + ValidateSemanticSearchConfiguration(runtimeConfig, entityName, entity); + } + } + + /// + /// Validates semantic-search settings configured for an entity. + /// + private void ValidateSemanticSearchConfiguration(RuntimeConfig runtimeConfig, string entityName, Entity entity) + { + EntitySemanticSearchOptions? 
semantic = entity.SemanticSearch; + + if (semantic is null || !semantic.Enabled) + { + return; + } + + if (entity.Source.Type is not EntitySourceType.Table and not EntitySourceType.View) + { + HandleOrRecordException(new DataApiBuilderException( + message: "Semantic search is only supported for entities with source type 'table' or 'view'.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + if (string.IsNullOrWhiteSpace(semantic.RedisIndexName)) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"semantic-search.redis-index-name is required when semantic-search.enabled is true for entity '{entityName}'.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + if (!HasValidSemanticRedisConfiguration(runtimeConfig)) + { + HandleOrRecordException(new DataApiBuilderException( + message: SEMANTIC_SEARCH_REDIS_REQUIREMENT_ERR_MSG, + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + if (semantic.RedisIndexMultiplier < 1 || semantic.RedisIndexMultiplier > 10) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"semantic-search.redis-index-multiplier for entity '{entityName}' must be an integer between 1 and 10.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + if (!(semantic.SimilarityThreshold >= 0.0 && semantic.SimilarityThreshold <= 1.0)) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"semantic-search.similarity-threshold for entity '{entityName}' must be a decimal value between 0.0 and 1.0.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + + ValidateSemanticReservedNames(entityName, entity); + } + + /// + /// Validates Redis
provider/connection-string prerequisites for semantic-search. + /// + private static bool HasValidSemanticRedisConfiguration(RuntimeConfig runtimeConfig) + { + RuntimeCacheLevel2Options? level2 = runtimeConfig.Runtime?.Cache?.Level2; + return level2 is not null + && string.Equals(level2.Provider, EntityCacheOptions.L2_CACHE_PROVIDER, StringComparison.OrdinalIgnoreCase) + && !string.IsNullOrWhiteSpace(level2.ConnectionString); + } + + /// + /// Ensures configured field names do not collide with semantic reserved names. + /// + private void ValidateSemanticReservedNames(string entityName, Entity entity) + { + HashSet exposedFieldNames = new(StringComparer.OrdinalIgnoreCase); + + if (entity.Fields is not null) + { + foreach (FieldMetadata field in entity.Fields) + { + if (!string.IsNullOrWhiteSpace(field.Name)) + { + exposedFieldNames.Add(field.Name); + } + + if (!string.IsNullOrWhiteSpace(field.Alias)) + { + exposedFieldNames.Add(field.Alias); + } + } + } + + if (entity.Mappings is not null) + { + foreach (KeyValuePair mapping in entity.Mappings) + { + if (!string.IsNullOrWhiteSpace(mapping.Key)) + { + exposedFieldNames.Add(mapping.Key); + } + + if (!string.IsNullOrWhiteSpace(mapping.Value)) + { + exposedFieldNames.Add(mapping.Value); + } + } + } + + foreach (string reserved in _reservedSemanticRestNames) + { + if (exposedFieldNames.Contains(reserved)) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"Entity '{entityName}' cannot enable semantic search because field name '{reserved}' is reserved.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + } + + foreach (string reserved in _reservedSemanticGraphQLNames) + { + if (exposedFieldNames.Contains(reserved)) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"Entity '{entityName}' cannot enable semantic search because field name '{reserved}' is reserved.", + statusCode: 
HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } } } diff --git a/src/Core/Models/RestRequestContexts/RestRequestContext.cs b/src/Core/Models/RestRequestContexts/RestRequestContext.cs index e9987730a0..716f9cd96b 100644 --- a/src/Core/Models/RestRequestContexts/RestRequestContext.cs +++ b/src/Core/Models/RestRequestContexts/RestRequestContext.cs @@ -95,6 +95,31 @@ protected RestRequestContext(string entityName, DatabaseObject dbo) /// public int? First { get; set; } + + /// + /// Optional semantic search input provided by the caller. + /// + public string? SemanticSearch { get; set; } + + /// + /// Optional semantic threshold override provided by the caller. + /// + public double? SemanticThreshold { get; set; } + + /// + /// True when semantic distance should be included in REST output. + /// + public bool IncludeSemanticDistanceInResponse { get; set; } + + /// + /// Semantic distance keyed by a stable primary-key signature generated from result rows. + /// + public Dictionary? SemanticDistanceByPrimaryKeySignature { get; set; } + + /// + /// True when user explicitly requested semantic distance in the projection. + /// + public bool IsSemanticDistanceExplicitlySelected { get; set; } /// /// Is the result supposed to be multiple values or not. /// diff --git a/src/Core/Models/SemanticSearchCandidate.cs b/src/Core/Models/SemanticSearchCandidate.cs new file mode 100644 index 0000000000..65251ffcfd --- /dev/null +++ b/src/Core/Models/SemanticSearchCandidate.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Core.Models; + +/// +/// Represents one semantic candidate with SQL column values extracted from +/// the semantic document and primary key values used for dedupe/output mapping. 
+/// +public record SemanticSearchCandidate( + IReadOnlyDictionary PrimaryKeyValues, + IReadOnlyDictionary ColumnValues, + double Distance); diff --git a/src/Core/Models/SemanticSearchConstants.cs b/src/Core/Models/SemanticSearchConstants.cs new file mode 100644 index 0000000000..dbca8d9426 --- /dev/null +++ b/src/Core/Models/SemanticSearchConstants.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Core.Models; + +public static class SemanticSearchConstants +{ + public const string REST_SEARCH_QUERY_PARAM = "$semantic_search"; + public const string REST_THRESHOLD_QUERY_PARAM = "$semantic_threshold"; + public const string REST_DISTANCE_FIELD = "semantic_distance"; + + public const string GRAPHQL_SEARCH_ARGUMENT = "semanticSearch"; + public const string GRAPHQL_THRESHOLD_ARGUMENT = "semanticThreshold"; + public const string GRAPHQL_DISTANCE_FIELD = "semanticDistance"; +} diff --git a/src/Core/Parsers/RequestParser.cs b/src/Core/Parsers/RequestParser.cs index 081018e820..90d3eb3c58 100644 --- a/src/Core/Parsers/RequestParser.cs +++ b/src/Core/Parsers/RequestParser.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Globalization; using System.Net; using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Services; @@ -35,6 +36,14 @@ public class RequestParser /// Prefix used for specifying paging in the query string of the URL. /// public const string AFTER_URL = "$after"; + /// + /// Prefix used for specifying semantic search text in the query string of the URL. + /// + public const string SEMANTIC_SEARCH_URL = SemanticSearchConstants.REST_SEARCH_QUERY_PARAM; + /// + /// Prefix used for specifying semantic similarity threshold in the query string of the URL. 
+ /// + public const string SEMANTIC_THRESHOLD_URL = SemanticSearchConstants.REST_THRESHOLD_QUERY_PARAM; /// /// Parses the primary key string to identify the field names composing the key @@ -138,6 +147,14 @@ public static void ParseQueryString(RestRequestContext context, ISqlMetadataProv subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); } + if (rawSortValue.Contains(SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparison.OrdinalIgnoreCase)) + { + throw new DataApiBuilderException( + message: "semantic_distance cannot be used in orderBy.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + (context.OrderByClauseInUrl, context.OrderByClauseOfBackingColumns) = GenerateOrderByLists(context, sqlMetadataProvider, $"?{SORT_URL}={rawSortValue}"); break; case AFTER_URL: @@ -146,6 +163,22 @@ public static void ParseQueryString(RestRequestContext context, ISqlMetadataProv case FIRST_URL: context.First = RequestValidator.CheckFirstValidity(context.ParsedQueryString[key]!); break; + case SEMANTIC_SEARCH_URL: + context.SemanticSearch = context.ParsedQueryString[key]; + break; + case SEMANTIC_THRESHOLD_URL: + if (!double.TryParse(context.ParsedQueryString[key], NumberStyles.Float, CultureInfo.InvariantCulture, out double semanticThreshold) + || !(semanticThreshold >= 0.0 + && semanticThreshold <= 1.0)) + { + throw new DataApiBuilderException( + message: "semantic_threshold must be a decimal value between 0.0 and 1.0.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + context.SemanticThreshold = semanticThreshold; + break; default: throw new DataApiBuilderException( message: $"Invalid Query Parameter: {key}", diff --git a/src/Core/Resolvers/Factories/QueryEngineFactory.cs b/src/Core/Resolvers/Factories/QueryEngineFactory.cs index 1d2ae2935d..1260db344a 100644 --- a/src/Core/Resolvers/Factories/QueryEngineFactory.cs +++
b/src/Core/Resolvers/Factories/QueryEngineFactory.cs @@ -9,6 +9,7 @@ using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Services.Cache; using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Service.Exceptions; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; @@ -32,6 +33,7 @@ public class QueryEngineFactory : IQueryEngineFactory private readonly IAuthorizationResolver _authorizationResolver; private readonly GQLFilterParser _gQLFilterParser; private readonly DabCacheService _cache; + private readonly ISemanticSearchService _semanticSearchService; private readonly ILogger _logger; /// @@ -42,6 +44,7 @@ public QueryEngineFactory(RuntimeConfigProvider runtimeConfigProvider, IHttpContextAccessor contextAccessor, IAuthorizationResolver authorizationResolver, GQLFilterParser gQLFilterParser, + ISemanticSearchService semanticSearchService, ILogger logger, DabCacheService cache, HotReloadEventHandler? 
handler) @@ -55,6 +58,7 @@ public QueryEngineFactory(RuntimeConfigProvider runtimeConfigProvider, _contextAccessor = contextAccessor; _authorizationResolver = authorizationResolver; _gQLFilterParser = gQLFilterParser; + _semanticSearchService = semanticSearchService; _cache = cache; _logger = logger; @@ -75,7 +79,8 @@ public void ConfigureQueryEngines() _gQLFilterParser, _logger, _runtimeConfigProvider, - _cache); + _cache, + _semanticSearchService); _queryEngines.Add(DatabaseType.MSSQL, queryEngine); _queryEngines.Add(DatabaseType.MySQL, queryEngine); _queryEngines.Add(DatabaseType.PostgreSQL, queryEngine); diff --git a/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs b/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs index 430b58c54f..af15fecf1c 100644 --- a/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs +++ b/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs @@ -4,6 +4,7 @@ using System.Data; using System.Net; using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; @@ -92,6 +93,80 @@ public class SqlQueryStructure : BaseSqlQueryStructure /// public GroupByMetadata GroupByMetadata { get; private set; } + /// + /// True when semantic search arguments are present on the request. + /// + public bool SemanticSearchRequested { get; private set; } + + /// + /// True when a user supplied ordering expression. + /// + public bool HasUserSpecifiedOrdering { get; private set; } + + /// + /// Primary-key signature to semantic distance map. + /// + public Dictionary SemanticDistanceByPrimaryKeySignature { get; } + + /// + /// Applies a semantic narrowing predicate of the form: + /// (col1 = ... AND col2 = ...) OR (...) 
+ /// + public void ApplySemanticCandidates(IReadOnlyList candidates) + { + if (candidates.Count == 0) + { + Predicates.Add(Predicate.MakeFalsePredicate()); + return; + } + + Predicate? combinedOr = null; + + foreach (SemanticSearchCandidate candidate in candidates) + { + Predicate? andPredicate = null; + foreach ((string columnName, object? value) in candidate.ColumnValues) + { + if (value is null) + { + continue; + } + + string parameterName = MakeDbConnectionParam( + GetParamAsSystemType(value.ToString()!, columnName, GetColumnSystemType(columnName)), + columnName); + + Predicate equalsPredicate = new( + new PredicateOperand(new Column(DatabaseObject.SchemaName, DatabaseObject.Name, columnName, SourceAlias)), + PredicateOperation.Equal, + new PredicateOperand(parameterName), + addParenthesis: true); + + andPredicate = andPredicate is null + ? equalsPredicate + : new Predicate(new PredicateOperand(andPredicate), PredicateOperation.AND, new PredicateOperand(equalsPredicate), addParenthesis: true); + } + + if (andPredicate is null) + { + continue; + } + + combinedOr = combinedOr is null + ? andPredicate + : new Predicate(new PredicateOperand(combinedOr), PredicateOperation.OR, new PredicateOperand(andPredicate), addParenthesis: true); + } + + if (combinedOr is null) + { + Predicates.Add(Predicate.MakeFalsePredicate()); + } + else + { + Predicates.Add(combinedOr); + } + } + public virtual string? CacheControlOption { get; set; } public const string CACHE_CONTROL = "Cache-Control"; @@ -280,6 +355,7 @@ public SqlQueryStructure( IsListQuery = context.IsMany; SourceAlias = $"{DatabaseObject.SchemaName}_{DatabaseObject.Name}"; AddFields(context, sqlMetadataProvider); + SemanticSearchRequested = !string.IsNullOrWhiteSpace(context.SemanticSearch); foreach (KeyValuePair predicate in context.PrimaryKeyValuePairs) { sqlMetadataProvider.TryGetBackingColumn(EntityName, predicate.Key, out string? 
backingColumn); @@ -293,6 +369,7 @@ public SqlQueryStructure( // to only Find, we populate the SourceAlias in this constructor where we know we have a Find operation. OrderByColumns = context.OrderByClauseOfBackingColumns is not null ? context.OrderByClauseOfBackingColumns : PrimaryKeyAsOrderByColumns(); + HasUserSpecifiedOrdering = context.OrderByClauseOfBackingColumns is not null; foreach (OrderByColumn column in OrderByColumns) { @@ -444,6 +521,18 @@ private SqlQueryStructure( EntityName = sqlMetadataProvider.GetDatabaseType() == DatabaseType.DWSQL ? GraphQLUtils.GetEntityNameFromContext(ctx) : _underlyingFieldType.Name; bool isGroupByQuery = queryField?.Name.Value == QueryBuilder.GROUP_BY_FIELD_NAME; + SemanticSearchRequested = queryParams.TryGetValue(QueryBuilder.SEMANTIC_SEARCH_ARGUMENT_NAME, out object? semanticValue) + && semanticValue is string semanticSearchText + && !string.IsNullOrWhiteSpace(semanticSearchText); + + if (SemanticSearchRequested && isGroupByQuery) + { + throw new DataApiBuilderException( + message: "Semantic search is not supported for aggregate queries.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + if (GraphQLUtils.TryExtractGraphQLFieldModelName(_underlyingFieldType.Directives, out string? 
modelName)) { EntityName = modelName; @@ -520,6 +609,7 @@ private SqlQueryStructure( if (orderByObject is not null) { OrderByColumns = ProcessGqlOrderByArg((List)orderByObject, queryArgumentSchemas[QueryBuilder.ORDER_BY_FIELD_NAME], isGroupByQuery); + HasUserSpecifiedOrdering = true; } } @@ -575,6 +665,7 @@ private SqlQueryStructure( PaginationMetadata = new(this); GroupByMetadata = new(); ColumnLabelToParam = new(); + SemanticDistanceByPrimaryKeySignature = new(StringComparer.Ordinal); FilterPredicates = string.Empty; OrderByColumns = new(); AddCacheControlOptions(httpRequestHeaders); @@ -733,6 +824,9 @@ void ProcessPaginationFields(IReadOnlyList paginationSelections) // TODO : This is inefficient and could lead to errors. we should rewrite this to use the ISelection API. private void AddGraphQLFields(IReadOnlyList selections, RuntimeConfigProvider runtimeConfigProvider) { + bool entitySemanticEnabled = runtimeConfigProvider.GetConfig().Entities.TryGetValue(EntityName, out Entity? entity) + && entity.SemanticSearch?.Enabled is true; + foreach (ISelectionNode node in selections) { if (node.Kind == SyntaxKind.FragmentSpread) @@ -784,6 +878,12 @@ private void AddGraphQLFields(IReadOnlyList selections, RuntimeC if (field.SelectionSet is null) { + if (entitySemanticEnabled + && string.Equals(fieldName, SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD, StringComparison.Ordinal)) + { + continue; + } + if (MetadataProvider.TryGetBackingColumn(EntityName, fieldName, out string? name) && !string.IsNullOrWhiteSpace(name)) { @@ -796,6 +896,14 @@ private void AddGraphQLFields(IReadOnlyList selections, RuntimeC } else { + if (SemanticSearchRequested) + { + throw new DataApiBuilderException( + message: "Semantic search is not supported with relationship expansion.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + ObjectField? 
subschemaField = _underlyingFieldType.Fields[fieldName]; if (_ctx == null) diff --git a/src/Core/Resolvers/SqlQueryEngine.cs b/src/Core/Resolvers/SqlQueryEngine.cs index f567251771..12d804328f 100644 --- a/src/Core/Resolvers/SqlQueryEngine.cs +++ b/src/Core/Resolvers/SqlQueryEngine.cs @@ -4,6 +4,7 @@ using System.Text; using System.Text.Json; using System.Text.Json.Nodes; +using System.Linq; using Azure.DataApiBuilder.Auth; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Configurations; @@ -12,7 +13,9 @@ using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.Cache; using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Service.Exceptions; +using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Service.GraphQLBuilder; using Azure.DataApiBuilder.Service.GraphQLBuilder.Queries; using HotChocolate.Resolvers; @@ -37,6 +40,7 @@ public class SqlQueryEngine : IQueryEngine private readonly RuntimeConfigProvider _runtimeConfigProvider; private readonly GQLFilterParser _gQLFilterParser; private readonly DabCacheService _cache; + private readonly ISemanticSearchService _semanticSearchService; // // Constructor. @@ -49,7 +53,8 @@ public SqlQueryEngine( GQLFilterParser gQLFilterParser, ILogger logger, RuntimeConfigProvider runtimeConfigProvider, - DabCacheService cache) + DabCacheService cache, + ISemanticSearchService? semanticSearchService = null) { _queryFactory = queryFactory; _sqlMetadataProviderFactory = sqlMetadataProviderFactory; @@ -59,6 +64,7 @@ public SqlQueryEngine( _logger = logger; _runtimeConfigProvider = runtimeConfigProvider; _cache = cache; + _semanticSearchService = semanticSearchService ?? NoOpSemanticSearchService.Instance; } /// @@ -79,16 +85,32 @@ public SqlQueryEngine( _runtimeConfigProvider, _gQLFilterParser); + (bool shouldReturnEmpty, JsonDocument? 
emptySemanticResponse) = await TryApplySemanticNarrowingAsync( + structure, + dataSourceName, + parameters, + GraphQLUtils.GetEntityNameFromContext(context)); + if (shouldReturnEmpty) + { + return new Tuple(emptySemanticResponse, structure.PaginationMetadata); + } + if (structure.PaginationMetadata.IsPaginated) { + JsonDocument? queryResult = await ExecuteAsync(structure, dataSourceName); + JsonDocument? processedResult = ApplySemanticDistanceAndOrderingIfNeeded(queryResult, structure, dataSourceName, includeRestField: false, includeGraphQlField: true); + return new Tuple( - SqlPaginationUtil.CreatePaginationConnectionFromJsonDocument(await ExecuteAsync(structure, dataSourceName), structure.PaginationMetadata, structure.GroupByMetadata), + SqlPaginationUtil.CreatePaginationConnectionFromJsonDocument(processedResult, structure.PaginationMetadata, structure.GroupByMetadata), structure.PaginationMetadata); } else { + JsonDocument? queryResult = await ExecuteAsync(structure, dataSourceName); + JsonDocument? processedResult = ApplySemanticDistanceAndOrderingIfNeeded(queryResult, structure, dataSourceName, includeRestField: false, includeGraphQlField: true); + return new Tuple( - await ExecuteAsync(structure, dataSourceName), + processedResult, structure.PaginationMetadata); } } @@ -189,7 +211,234 @@ await ExecuteAsync(structure, dataSourceName, isMultipleCreateOperation: true), _runtimeConfigProvider, _gQLFilterParser, _httpContextAccessor.HttpContext!); - return await ExecuteAsync(structure, dataSourceName); + + (bool shouldReturnEmpty, JsonDocument? emptySemanticResponse) = await TryApplySemanticNarrowingAsync( + structure, + dataSourceName, + graphQLParameters: null, + context.EntityName, + context); + if (shouldReturnEmpty) + { + return emptySemanticResponse; + } + + JsonDocument? 
response = await ExecuteAsync(structure, dataSourceName); + return ApplySemanticDistanceAndOrderingIfNeeded(response, structure, dataSourceName, includeRestField: true, includeGraphQlField: false, context); + } + + private async Task<(bool shouldReturnEmpty, JsonDocument? emptyResponse)> TryApplySemanticNarrowingAsync( + SqlQueryStructure structure, + string dataSourceName, + IDictionary? graphQLParameters, + string entityName, + FindRequestContext? restContext = null) + { + JsonDocument? emptyResponse = null; + + string? semanticSearchText = restContext?.SemanticSearch; + if (semanticSearchText is null + && graphQLParameters is not null + && graphQLParameters.TryGetValue(QueryBuilder.SEMANTIC_SEARCH_ARGUMENT_NAME, out object? graphQlSearchValue) + && graphQlSearchValue is string graphQlSearchText) + { + semanticSearchText = graphQlSearchText; + } + + bool hasSemanticThresholdOnly = restContext?.SemanticThreshold is not null + || (graphQLParameters is not null && graphQLParameters.ContainsKey(QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME)); + + bool semanticRequested = !string.IsNullOrWhiteSpace(semanticSearchText) || hasSemanticThresholdOnly; + + RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig(); + if (!runtimeConfig.Entities.TryGetValue(entityName, out Entity? entity) + || entity.SemanticSearch is null + || !entity.SemanticSearch.Enabled) + { + if (semanticRequested) + { + throw new DataApiBuilderException( + message: $"Semantic search is not enabled for entity '{entityName}'.", + statusCode: System.Net.HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + return (false, null); + } + + if (string.IsNullOrWhiteSpace(semanticSearchText)) + { + return (false, null); + } + + if (graphQLParameters is not null + && graphQLParameters.TryGetValue(QueryBuilder.PAGINATION_TOKEN_ARGUMENT_NAME, out object? 
afterArg) + && afterArg is string afterToken + && !string.IsNullOrWhiteSpace(afterToken)) + { + throw new DataApiBuilderException( + message: "Pagination continuation tokens are not supported when semantic search is used.", + statusCode: System.Net.HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (graphQLParameters is not null + && graphQLParameters.TryGetValue(QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME, out object? graphQlThreshold) + && graphQlThreshold is double graphQlThresholdDouble + && (graphQlThresholdDouble < 0.0 || graphQlThresholdDouble > 1.0)) + { + throw new DataApiBuilderException( + message: "semantic_threshold must be a decimal value between 0.0 and 1.0.", + statusCode: System.Net.HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + double effectiveThreshold = restContext?.SemanticThreshold + ?? (graphQLParameters is not null && graphQLParameters.TryGetValue(QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME, out object? gqlThresholdObj) && gqlThresholdObj is double gqlThresholdVal + ? gqlThresholdVal + : entity.SemanticSearch.SimilarityThreshold); + + int first = restContext?.First + ?? (graphQLParameters is not null && graphQLParameters.TryGetValue(QueryBuilder.PAGE_START_ARGUMENT_NAME, out object? firstObj) && firstObj is int gqlFirst ? 
gqlFirst : (int)runtimeConfig.DefaultPageSize()); + int redisTop = first * entity.SemanticSearch.RedisIndexMultiplier; + + SourceDefinition sourceDefinition = _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName).GetSourceDefinition(entityName); + IReadOnlyList primaryKeyColumns = sourceDefinition.PrimaryKey; + + IReadOnlyList candidates = await _semanticSearchService.GetCandidatesAsync( + entityName, + entity.SemanticSearch, + primaryKeyColumns, + semanticSearchText, + effectiveThreshold, + redisTop); + + if (candidates.Count == 0) + { + emptyResponse = JsonDocument.Parse("[]"); + return (true, emptyResponse); + } + + // De-duplicate candidate keys while keeping the highest distance. + Dictionary deduped = new(StringComparer.Ordinal); + foreach (SemanticSearchCandidate candidate in candidates) + { + string signature = BuildPrimaryKeySignatureFromValues(primaryKeyColumns, candidate.PrimaryKeyValues); + if (!deduped.TryGetValue(signature, out SemanticSearchCandidate? existing) || candidate.Distance > existing.Distance) + { + deduped[signature] = candidate; + } + } + + IReadOnlyList narrowedCandidates = deduped.Values.ToList(); + structure.ApplySemanticCandidates(narrowedCandidates); + foreach (KeyValuePair kvp in deduped) + { + structure.SemanticDistanceByPrimaryKeySignature[kvp.Key] = kvp.Value.Distance; + } + + return (false, null); + } + + private JsonDocument? ApplySemanticDistanceAndOrderingIfNeeded( + JsonDocument? response, + SqlQueryStructure structure, + string dataSourceName, + bool includeRestField, + bool includeGraphQlField, + FindRequestContext? restContext = null) + { + if (response is null || structure.SemanticDistanceByPrimaryKeySignature.Count == 0) + { + return response; + } + + SourceDefinition sourceDefinition = _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName).GetSourceDefinition(structure.EntityName); + if (response.RootElement.ValueKind is not JsonValueKind.Array) + { + return response; + } + + List<(double? 
distance, JsonObject row)> rows = new(); + foreach (JsonElement element in response.RootElement.EnumerateArray()) + { + JsonObject? row = JsonObject.Create(element); + if (row is null) + { + continue; + } + + string signature = BuildPrimaryKeySignatureFromRow(sourceDefinition.PrimaryKey, structure.EntityName, dataSourceName, element); + double? distance = null; + if (structure.SemanticDistanceByPrimaryKeySignature.TryGetValue(signature, out double mappedDistance)) + { + distance = mappedDistance; + } + + if (includeRestField) + { + row[SemanticSearchConstants.REST_DISTANCE_FIELD] = distance.HasValue ? JsonValue.Create(distance.Value) : null; + } + + if (includeGraphQlField) + { + row[SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD] = distance.HasValue ? JsonValue.Create(distance.Value) : null; + } + + rows.Add((distance, row)); + } + + if (!structure.HasUserSpecifiedOrdering) + { + rows = rows.OrderByDescending(r => r.distance ?? double.MinValue).ToList(); + } + + JsonElement updated = JsonSerializer.SerializeToElement(rows.Select(r => r.row).ToList()); + + response.Dispose(); + return JsonDocument.Parse(updated.GetRawText()); + } + + private string BuildPrimaryKeySignatureFromRow( + IReadOnlyList primaryKeyColumns, + string entityName, + string dataSourceName, + JsonElement row) + { + ISqlMetadataProvider metadataProvider = _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName); + Dictionary values = new(StringComparer.Ordinal); + foreach (string primaryKeyColumn in primaryKeyColumns) + { + if (!metadataProvider.TryGetExposedColumnName(entityName, primaryKeyColumn, out string? 
exposedName) + || string.IsNullOrWhiteSpace(exposedName) + || !row.TryGetProperty(exposedName, out JsonElement valueElement)) + { + values[primaryKeyColumn] = null; + continue; + } + + values[primaryKeyColumn] = valueElement.ValueKind switch + { + JsonValueKind.Null => null, + JsonValueKind.String => valueElement.GetString(), + _ => valueElement.ToString() + }; + } + + return BuildPrimaryKeySignatureFromValues(primaryKeyColumns, values); + } + + private static string BuildPrimaryKeySignatureFromValues( + IReadOnlyList primaryKeyColumns, + IReadOnlyDictionary values) + { + return string.Join("|", primaryKeyColumns.Select(pk => + { + string serializedValue = values.TryGetValue(pk, out object? value) + ? value?.ToString() ?? "null" + : "null"; + return $"{pk}={serializedValue}"; + })); } /// diff --git a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs index 979da52eb6..ae527195f4 100644 --- a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs +++ b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs @@ -12,6 +12,7 @@ using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Parsers; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Core.Services.OpenAPI; @@ -374,6 +375,7 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour // Operations including primary key Dictionary pkOperations = CreateOperations( entityName: entityName, + entity: entity, sourceDefinition: sourceDefinition, includePrimaryKeyPathComponent: true, configuredRestOperations: configuredRestOperations, @@ -398,6 +400,7 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour // Operations excluding primary key Dictionary operations = CreateOperations( entityName: entityName, + entity: entity, sourceDefinition: 
sourceDefinition, includePrimaryKeyPathComponent: false, configuredRestOperations: configuredRestOperations, @@ -434,6 +437,7 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour /// Collection of operation types and associated definitions. private Dictionary CreateOperations( string entityName, + Entity entity, SourceDefinition sourceDefinition, bool includePrimaryKeyPathComponent, Dictionary configuredRestOperations, @@ -447,7 +451,7 @@ private Dictionary CreateOperations( if (configuredRestOperations[OperationType.Get]) { OpenApiOperation getOperation = CreateBaseOperation(description: GETONE_DESCRIPTION, tags: tags); - AddQueryParameters(getOperation.Parameters); + AddQueryParameters(getOperation.Parameters, entity); getOperation.Responses.Add(HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName)); openApiPathItemOperations.Add(OperationType.Get, getOperation); } @@ -492,7 +496,7 @@ private Dictionary CreateOperations( if (configuredRestOperations[OperationType.Get]) { OpenApiOperation getAllOperation = CreateBaseOperation(description: GETALL_DESCRIPTION, tags: tags); - AddQueryParameters(getAllOperation.Parameters); + AddQueryParameters(getAllOperation.Parameters, entity); getAllOperation.Responses.Add( HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName, includeNextLink: true)); @@ -542,12 +546,27 @@ private Dictionary CreateOperations( /// Helper method to add query parameters like $select, $first, $orderby etc. to get and getAll operations for tables/views. /// /// List of parameters for the operation. 
- private static void AddQueryParameters(IList parameters) + private static void AddQueryParameters(IList parameters, Entity entity) { foreach (OpenApiParameter openApiParameter in _tableAndViewQueryParameters) { parameters.Add(openApiParameter); } + + if (entity.SemanticSearch?.Enabled is true) + { + parameters.Add(GetOpenApiQueryParameter( + name: RequestParser.SEMANTIC_SEARCH_URL, + description: entity.SemanticSearch.InputDescription, + required: false, + type: "string")); + + parameters.Add(GetOpenApiQueryParameter( + name: RequestParser.SEMANTIC_THRESHOLD_URL, + description: "Minimum semantic similarity threshold between 0.0 and 1.0.", + required: false, + type: "number")); + } } /// @@ -1507,6 +1526,19 @@ private static OpenApiSchema CreateComponentSchema( } } + if (!isRequestBodySchema + && entityConfig?.SemanticSearch?.Enabled is true + && !properties.ContainsKey(SemanticSearchConstants.REST_DISTANCE_FIELD)) + { + properties.Add(SemanticSearchConstants.REST_DISTANCE_FIELD, new OpenApiSchema() + { + Type = "number", + Description = entityConfig.SemanticSearch.OutputDescription, + ReadOnly = true, + Nullable = true + }); + } + OpenApiSchema schema = new() { Type = SCHEMA_OBJECT_TYPE, diff --git a/src/Core/Services/RequestValidator.cs b/src/Core/Services/RequestValidator.cs index 0b6b0734ad..2a05bbd0c5 100644 --- a/src/Core/Services/RequestValidator.cs +++ b/src/Core/Services/RequestValidator.cs @@ -272,6 +272,15 @@ public void ValidateInsertRequestContext(InsertRequestContext insertRequestCtx) { ISqlMetadataProvider sqlMetadataProvider = GetSqlMetadataProvider(insertRequestCtx.EntityName); + if (insertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.REST_DISTANCE_FIELD) + || insertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD)) + { + throw new DataApiBuilderException( + message: "semantic_distance is read-only.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: 
DataApiBuilderException.SubStatusCodes.BadRequest); + } + IEnumerable fieldsInRequestBody = insertRequestCtx.FieldValuePairsInBody.Keys; SourceDefinition sourceDefinition = TryGetSourceDefinition(insertRequestCtx.EntityName, sqlMetadataProvider); @@ -356,6 +365,16 @@ public void ValidateInsertRequestContext(InsertRequestContext insertRequestCtx) public void ValidateUpsertRequestContext(UpsertRequestContext upsertRequestCtx, bool isPrimaryKeyInUrl = true) { ISqlMetadataProvider sqlMetadataProvider = GetSqlMetadataProvider(upsertRequestCtx.EntityName); + + if (upsertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.REST_DISTANCE_FIELD) + || upsertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD)) + { + throw new DataApiBuilderException( + message: "semantic_distance is read-only.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + IEnumerable fieldsInRequestBody = upsertRequestCtx.FieldValuePairsInBody.Keys; bool isRequestBodyStrict = IsRequestBodyStrict(); SourceDefinition sourceDefinition = TryGetSourceDefinition(upsertRequestCtx.EntityName, sqlMetadataProvider); diff --git a/src/Core/Services/RestService.cs b/src/Core/Services/RestService.cs index 5014d942bb..5c6705f0d5 100644 --- a/src/Core/Services/RestService.cs +++ b/src/Core/Services/RestService.cs @@ -194,6 +194,7 @@ RequestValidator requestValidator context.RawQueryString = queryString; context.ParsedQueryString = HttpUtility.ParseQueryString(queryString); RequestParser.ParseQueryString(context, sqlMetadataProvider); + ValidateSemanticRestRequest(context, operationType, _runtimeConfigProvider.GetConfig()); } } @@ -235,6 +236,12 @@ private async Task DispatchQuery(RestRequestContext context, Data if (context is FindRequestContext findRequestContext) { + if (findRequestContext.IncludeSemanticDistanceInResponse + && 
!findRequestContext.FieldsToBeReturned.Contains(SemanticSearchConstants.REST_DISTANCE_FIELD)) + { + findRequestContext.FieldsToBeReturned.Add(SemanticSearchConstants.REST_DISTANCE_FIELD); + } + using JsonDocument? restApiResponse = await queryEngine.ExecuteAsync(findRequestContext); return restApiResponse is null ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), findRequestContext, _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName), _runtimeConfigProvider.GetConfig(), GetHttpContext()) : SqlResponseHelpers.FormatFindResult(restApiResponse.RootElement.Clone(), findRequestContext, _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName), _runtimeConfigProvider.GetConfig(), GetHttpContext()); @@ -356,6 +363,71 @@ private bool IsHttpMethodAllowedForStoredProcedure(string entityName) return false; } + private static void ValidateSemanticRestRequest(RestRequestContext context, EntityActionOperation operationType, RuntimeConfig runtimeConfig) + { + bool hasSemanticSearch = !string.IsNullOrWhiteSpace(context.SemanticSearch); + bool hasSemanticThreshold = context.SemanticThreshold.HasValue; + bool isSemanticRequest = hasSemanticSearch || hasSemanticThreshold; + + if (context.FieldsToBeReturned.Contains(SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparer.OrdinalIgnoreCase)) + { + context.FieldsToBeReturned = context.FieldsToBeReturned + .Where(field => !string.Equals(field, SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparison.OrdinalIgnoreCase)) + .ToList(); + context.IsSemanticDistanceExplicitlySelected = true; + } + + if (!isSemanticRequest) + { + if (context.IsSemanticDistanceExplicitlySelected) + { + throw new DataApiBuilderException( + message: "semantic_distance can only be selected when semantic search is used.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + return; + } + + if (operationType is not EntityActionOperation.Read) + 
{ + throw new DataApiBuilderException( + message: "Semantic search is only supported for read operations.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (!runtimeConfig.Entities.TryGetValue(context.EntityName, out Entity? entity) + || entity.SemanticSearch is null + || !entity.SemanticSearch.Enabled) + { + throw new DataApiBuilderException( + message: $"Semantic search is not enabled for entity '{context.EntityName}'.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (!string.IsNullOrWhiteSpace(context.After)) + { + throw new DataApiBuilderException( + message: "Pagination continuation tokens are not supported when semantic search is used.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (context.OrderByClauseInUrl is not null + && context.OrderByClauseInUrl.Any(col => string.Equals(col.ColumnName, SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparison.OrdinalIgnoreCase))) + { + throw new DataApiBuilderException( + message: "semantic_distance cannot be used in orderBy.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + context.IncludeSemanticDistanceInResponse = context.FieldsToBeReturned.Count == 0 || context.IsSemanticDistanceExplicitlySelected; + } + /// /// Gets the list of HTTP methods defined for entities representing stored procedures. /// When no explicit REST method configuration is present for a stored procedure entity, diff --git a/src/Core/Services/SemanticSearch/ISemanticSearchService.cs b/src/Core/Services/SemanticSearch/ISemanticSearchService.cs new file mode 100644 index 0000000000..f3c124edfd --- /dev/null +++ b/src/Core/Services/SemanticSearch/ISemanticSearchService.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. 
+// Licensed under the MIT License. + +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Models; + +namespace Azure.DataApiBuilder.Core.Services.SemanticSearch; + +public interface ISemanticSearchService +{ + Task> GetCandidatesAsync( + string entityName, + EntitySemanticSearchOptions options, + IReadOnlyList primaryKeyColumns, + string semanticSearchValue, + double similarityThreshold, + int top); +} diff --git a/src/Core/Services/SemanticSearch/NoOpSemanticSearchService.cs b/src/Core/Services/SemanticSearch/NoOpSemanticSearchService.cs new file mode 100644 index 0000000000..dfabecb133 --- /dev/null +++ b/src/Core/Services/SemanticSearch/NoOpSemanticSearchService.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Models; + +namespace Azure.DataApiBuilder.Core.Services.SemanticSearch; + +/// +/// Placeholder semantic search service used when no embedding/vector integration is configured. +/// +public sealed class NoOpSemanticSearchService : ISemanticSearchService +{ + public static readonly NoOpSemanticSearchService Instance = new(); + + private NoOpSemanticSearchService() + { + } + + public Task> GetCandidatesAsync( + string entityName, + EntitySemanticSearchOptions options, + IReadOnlyList primaryKeyColumns, + string semanticSearchValue, + double similarityThreshold, + int top) + { + return Task.FromResult>([]); + } +} diff --git a/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs b/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs index 5441dcc766..69a16d4a7d 100644 --- a/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs +++ b/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
+using System; using Azure.DataApiBuilder.Service.GraphQLBuilder.Directives; using Azure.DataApiBuilder.Service.GraphQLBuilder.GraphQLTypes; using HotChocolate.Language; @@ -53,7 +54,8 @@ private static List GenerateOrderByInputFieldsForBuilt // Skip scalar array fields (e.g., PostgreSQL int[], text[]) - they cannot be ordered. // Non-scalar list types (e.g., Cosmos nested object arrays) are not skipped // because they are handled as relationship navigations. - if (field.Type.IsListType() && IsBuiltInType(field.Type)) + if (string.Equals(field.Name.Value, QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME, StringComparison.Ordinal) + || (field.Type.IsListType() && IsBuiltInType(field.Type))) { continue; } @@ -122,9 +124,10 @@ private static List GenerateFilterInputFieldsForBuiltI // which are read-only and cannot be filtered. Cosmos scalar arrays like // tags: [String] do NOT have @autoGenerated and remain filterable // (using ARRAY_CONTAINS). - if (field.Type.IsListType() - && IsBuiltInType(field.Type) - && field.Directives.Any(d => d.Name.Value == AutoGeneratedDirectiveType.DirectiveName)) + if (string.Equals(field.Name.Value, QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME, StringComparison.Ordinal) + || (field.Type.IsListType() + && IsBuiltInType(field.Type) + && field.Directives.Any(d => d.Name.Value == AutoGeneratedDirectiveType.DirectiveName))) { continue; } diff --git a/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs b/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs index a2cc63b2c2..100759522f 100644 --- a/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs +++ b/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs @@ -30,6 +30,8 @@ public static class QueryBuilder public const string GROUP_BY_AGGREGATE_FIELD_ARG_NAME = "field"; public const string GROUP_BY_AGGREGATE_FIELD_DISTINCT_NAME = "distinct"; public const string GROUP_BY_AGGREGATE_FIELD_HAVING_NAME = "having"; + public const string SEMANTIC_SEARCH_ARGUMENT_NAME = 
SemanticSearchConstants.GRAPHQL_SEARCH_ARGUMENT; + public const string SEMANTIC_THRESHOLD_ARGUMENT_NAME = SemanticSearchConstants.GRAPHQL_THRESHOLD_ARGUMENT; // Define the enabled database types for aggregation public static readonly HashSet AggregationEnabledDatabaseTypes = new() @@ -191,21 +193,29 @@ public static FieldDefinitionNode GenerateGetAllQuery( location: null, new NameNode(GenerateListQueryName(name.Value, entity)), new StringValueNode($"Get a list of all the {GetDefinedSingularName(name.Value, entity)} items from the database"), - QueryArgumentsForField(filterInputName, orderByInputName), + QueryArgumentsForField(filterInputName, orderByInputName, includeSemanticArguments: entity.SemanticSearch?.Enabled ?? false), new NonNullTypeNode(new NamedTypeNode(returnType.Name)), fieldDefinitionNodeDirectives ); } - public static List QueryArgumentsForField(string filterInputName, string orderByInputName) + public static List QueryArgumentsForField(string filterInputName, string orderByInputName, bool includeSemanticArguments = false) { - return new() + List args = new() { new(location: null, new NameNode(PAGE_START_ARGUMENT_NAME), description: new StringValueNode("The number of items to return from the page start point"), new IntType().ToTypeNode(), defaultValue: null, new List()), new(location: null, new NameNode(PAGINATION_TOKEN_ARGUMENT_NAME), new StringValueNode("A pagination token from a previous query to continue through a paginated list"), new StringType().ToTypeNode(), defaultValue: null, new List()), new(location: null, new NameNode(FILTER_FIELD_NAME), new StringValueNode("Filter options for query"), new NamedTypeNode(filterInputName), defaultValue: null, new List()), new(location: null, new NameNode(ORDER_BY_FIELD_NAME), new StringValueNode("Ordering options for query"), new NamedTypeNode(orderByInputName), defaultValue: null, new List()), }; + + if (includeSemanticArguments) + { + args.Add(new(location: null, new 
NameNode(SEMANTIC_SEARCH_ARGUMENT_NAME), new StringValueNode("Natural language value used for semantic search."), new StringType().ToTypeNode(), defaultValue: null, new List())); + args.Add(new(location: null, new NameNode(SEMANTIC_THRESHOLD_ARGUMENT_NAME), new StringValueNode("Minimum semantic similarity threshold between 0.0 and 1.0."), new FloatType().ToTypeNode(), defaultValue: null, new List())); + } + + return args; } public static ObjectTypeDefinitionNode AddQueryArgumentsForRelationships(ObjectTypeDefinitionNode node, Dictionary inputObjects) diff --git a/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs b/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs index 622376dc13..1da4e56197 100644 --- a/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs +++ b/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs @@ -210,6 +210,20 @@ private static ObjectTypeDefinitionNode CreateObjectTypeDefinitionForTableOrView } } + if (configEntity.SemanticSearch?.Enabled is true) + { + List semanticDistanceDirectives = [new DirectiveNode(AutoGeneratedDirectiveType.DirectiveName)]; + FieldDefinitionNode semanticDistanceField = new( + location: null, + name: new NameNode(QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME), + description: new StringValueNode(configEntity.SemanticSearch.OutputDescription), + arguments: [], + type: new NamedTypeNode(new NameNode(FLOAT_TYPE)), + directives: semanticDistanceDirectives); + + fieldDefinitionNodes.TryAdd(QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME, semanticDistanceField); + } + // A linking entity is not exposed in the runtime config file but is used by DAB to support multiple mutations on entities with M:N relationship. // Hence we don't need to process relationships for the linking entity itself. 
if (!configEntity.IsLinkingEntity) diff --git a/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs b/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs index ffe636f6db..89a747df15 100644 --- a/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs +++ b/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs @@ -16,6 +16,7 @@ using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.Cache; using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Service.GraphQLBuilder; using Azure.DataApiBuilder.Service.GraphQLBuilder.Directives; using Azure.DataApiBuilder.Service.GraphQLBuilder.Mutations; @@ -439,6 +440,7 @@ private static async Task GetGQLSchemaCreator(RuntimeConfi contextAccessor: httpContextAccessor.Object, authorizationResolver: authorizationResolver, gQLFilterParser: graphQLFilterParser, + semanticSearchService: NoOpSemanticSearchService.Instance, logger: queryEngineLogger.Object, cache: cacheService, handler: null); diff --git a/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs b/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs index 93a3df9778..4eb727a59f 100644 --- a/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs +++ b/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs @@ -229,6 +229,54 @@ type Foo @model(name:""Foo"") { Assert.AreEqual("Boolean", hasNextPageField.Type.NamedType().Name.Value, "hasNextPage should be Boolean type"); } + [TestMethod] + [TestCategory("Query Generation")] + [TestCategory("Collection access")] + public void CollectionQueryAddsSemanticArgumentsWhenEnabled() + { + string gql = + @" +type Foo @model(name:""Foo"") { + id: ID! 
+} + "; + + DocumentNode root = Utf8GraphQLParser.Parse(gql); + Dictionary entityNameToDatabaseType = new() + { + { "Foo", DatabaseType.CosmosDB_NoSQL } + }; + + RuntimeEntities entities = new(new Dictionary + { + { + "Foo", + new Entity( + Source: new("Foo", EntitySourceType.Table, null, null), + GraphQL: new("Foo", "Foos"), + Fields: null, + Rest: new(), + Permissions: [], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions { Enabled = true, RedisIndexName = "idx:foo" }) + } + }); + + DocumentNode queryRoot = QueryBuilder.Build( + root, + entityNameToDatabaseType, + entities, + inputTypes: new(), + entityPermissionsMap: _entityPermissions + ); + + ObjectTypeDefinitionNode query = GetQueryNode(queryRoot); + FieldDefinitionNode collectionField = query.Fields.First(f => f.Name.Value == "foos"); + Assert.IsTrue(collectionField.Arguments.Any(a => a.Name.Value == QueryBuilder.SEMANTIC_SEARCH_ARGUMENT_NAME)); + Assert.IsTrue(collectionField.Arguments.Any(a => a.Name.Value == QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME)); + } + [TestMethod] public void PrimaryKeyFieldAsQueryInput() { diff --git a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs index bbb4874d1a..ef1864115d 100644 --- a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs +++ b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs @@ -3579,6 +3579,118 @@ public void ValidateEmbeddingsOptions_EndpointDisabled_SkipsValidation() configValidator.ValidateEmbeddingsOptions(runtimeConfig); } + [TestMethod] + public void ValidateSemanticSearchRequiresRedisLevel2Configuration() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, KeyFields: null), + Fields: null, + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new 
EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true, + RedisIndexName = "idx:product-semantic" + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null)), + Entities: new(new Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "Semantic search requires runtime.cache.level-2.provider to be 'redis' and runtime.cache.level-2.connection-string to be configured.", + ex.Message); + } + + [TestMethod] + public void ValidateSemanticSearchRequiresRedisIndexNameWhenEnabled() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, KeyFields: null), + Fields: null, + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new RuntimeCacheOptions + { + Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") + }), + Entities: new(new 
Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "semantic-search.redis-index-name is required when semantic-search.enabled is true for entity 'Product'.", + ex.Message); + } + + [TestMethod] + public void ValidateSemanticSearchRejectsReservedFieldName() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, KeyFields: null), + Fields: [new FieldMetadata { Name = "id" }, new FieldMetadata { Name = "semantic_distance" }], + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true, + RedisIndexName = "idx:product-semantic" + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new RuntimeCacheOptions + { + Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") + }), + Entities: new(new Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "Entity 'Product' cannot enable semantic search because field name 'semantic_distance' is reserved.", + ex.Message); + } + private static RuntimeConfigValidator InitializeRuntimeConfigValidator() { MockFileSystem fileSystem = new(); diff --git a/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs 
b/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs index 35c86abc82..77f057cd0d 100644 --- a/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs +++ b/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Net; using System.Text; +using System.Text.Json; using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; @@ -327,6 +328,58 @@ public void PrimaryKeyWithNoValueTest() primaryKeyRoute = "id//title/"; PerformRequestParserPrimaryKeyTest(findRequestContext, primaryKeyRoute, expectsException: true); } + + [TestMethod] + public void InsertRequestBodyCannotSetSemanticDistanceField() + { + RuntimeConfig mockConfig = new( + Schema: "", + DataSource: new(DatabaseType.PostgreSQL, "", new()), + Runtime: new( + Rest: new(Path: "/api"), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null) + ), + Entities: new(new Dictionary + { + { + DEFAULT_NAME, + new Entity( + Source: new(DEFAULT_NAME, EntitySourceType.Table, null, null), + GraphQL: new(DEFAULT_NAME, DEFAULT_NAME), + Fields: null, + Rest: new(), + Permissions: [], + Mappings: null, + Relationships: null) + } + }) + ); + + MockFileSystem fileSystem = new(); + fileSystem.AddFile(FileSystemRuntimeConfigLoader.DEFAULT_CONFIG_FILE_NAME, new MockFileData(mockConfig.ToJson())); + FileSystemRuntimeConfigLoader loader = new(fileSystem); + RuntimeConfigProvider provider = new(loader); + + Mock metadataProviderFactory = new(); + metadataProviderFactory.Setup(x => x.GetMetadataProvider(It.IsAny())).Returns(_mockMetadataStore.Object); + RequestValidator requestValidator = new(metadataProviderFactory.Object, provider); + + using JsonDocument body = JsonDocument.Parse("{\"semantic_distance\":0.81}"); + InsertRequestContext context = new( + entityName: DEFAULT_NAME, + dbo: GetDbo(DEFAULT_SCHEMA, DEFAULT_NAME), + insertPayloadRoot: body.RootElement, + operationType: 
EntityActionOperation.Insert);
+
+            DataApiBuilderException ex = Assert.ThrowsException(
+                () => requestValidator.ValidateInsertRequestContext(context));
+
+            Assert.AreEqual(HttpStatusCode.BadRequest, ex.StatusCode);
+            Assert.AreEqual(DataApiBuilderException.SubStatusCodes.BadRequest, ex.SubStatusCode);
+            StringAssert.Contains(ex.Message, "semantic_distance is read-only.");
+        }
         #endregion
         #region Helper Methods
         ///
diff --git a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs
new file mode 100644
index 0000000000..726137a53f
--- /dev/null
+++ b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs
@@ -0,0 +1,614 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.Net.Http;
+using System.Net.Http.Headers;
+using System.Text;
+using System.Text.Json;
+using System.Threading.Tasks;
+using Azure.DataApiBuilder.Config.DatabasePrimitives;
+using Azure.DataApiBuilder.Config.ObjectModel;
+using Azure.DataApiBuilder.Core.Configurations;
+using Azure.DataApiBuilder.Core.Models;
+using Azure.DataApiBuilder.Core.Services.MetadataProviders;
+using Azure.DataApiBuilder.Core.Services.SemanticSearch;
+using Azure.DataApiBuilder.Service.Exceptions;
+using Azure.Identity;
+using Microsoft.Azure.StackExchangeRedis;
+using Microsoft.Extensions.Logging;
+using StackExchange.Redis;
+
+namespace Azure.DataApiBuilder.Service.Services.SemanticSearch;
+
+/// <summary>
+/// Resolves semantic candidates by:
+/// 1) generating/retrieving an embedding vector for semantic_search input,
+/// 2) performing a Redis FT.SEARCH KNN query,
+/// 3) extracting SQL column/value pairs from Redis hash/json documents.
+/// </summary>
+public sealed class RedisSemanticSearchService : ISemanticSearchService
+{
+    private const string EMBED_ENDPOINT_ENV = "DAB_SEMANTIC_EMBED_ENDPOINT";
+    private const string EMBED_API_KEY_ENV = "DAB_SEMANTIC_EMBED_API_KEY";
+    private const string VECTOR_SCORE_FIELD = "__vector_score";
+    private const string DEFAULT_VECTOR_FIELD = "embedding";
+
+    private readonly RuntimeConfigProvider _runtimeConfigProvider;
+    private readonly IMetadataProviderFactory _metadataProviderFactory;
+    private readonly IHttpClientFactory _httpClientFactory;
+    private readonly ILogger<RedisSemanticSearchService> _logger;
+
+    /// <summary>
+    /// Creates the service from its injected collaborators.
+    /// </summary>
+    public RedisSemanticSearchService(
+        RuntimeConfigProvider runtimeConfigProvider,
+        IMetadataProviderFactory metadataProviderFactory,
+        IHttpClientFactory httpClientFactory,
+        ILogger<RedisSemanticSearchService> logger)
+    {
+        _runtimeConfigProvider = runtimeConfigProvider;
+        _metadataProviderFactory = metadataProviderFactory;
+        _httpClientFactory = httpClientFactory;
+        _logger = logger;
+    }
+
+    /// <summary>
+    /// Embeds <paramref name="semanticSearchValue"/>, runs a KNN query against the entity's Redis index and
+    /// returns the hits that meet <paramref name="similarityThreshold"/> and carry every primary key column.
+    /// </summary>
+    /// <param name="entityName">Entity whose source definition supplies the accepted column names.</param>
+    /// <param name="options">Entity semantic-search settings; the index name and index type are read.</param>
+    /// <param name="primaryKeyColumns">Columns that must all be present and non-null in a Redis document.</param>
+    /// <param name="semanticSearchValue">Search text, or a JSON array of numbers used directly as the vector.</param>
+    /// <param name="similarityThreshold">Minimum similarity (1 - Redis score, clamped to [0, 1]) for a hit to be kept.</param>
+    /// <param name="top">Number of nearest neighbours to request; non-positive values yield no candidates.</param>
+    /// <returns>Matching candidates; empty when the index name, connection string or embedding is unavailable.</returns>
+    /// <exception cref="DataApiBuilderException">BadRequest when connecting to Redis or running the query fails.</exception>
+    public async Task<IReadOnlyList<SemanticSearchCandidate>> GetCandidatesAsync(
+        string entityName,
+        EntitySemanticSearchOptions options,
+        IReadOnlyList<string> primaryKeyColumns,
+        string semanticSearchValue,
+        double similarityThreshold,
+        int top)
+    {
+        RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig();
+        string? connectionString = runtimeConfig.Runtime?.Cache?.Level2?.ConnectionString;
+
+        if (string.IsNullOrWhiteSpace(options.RedisIndexName)
+            || string.IsNullOrWhiteSpace(connectionString)
+            || top <= 0)
+        {
+            return [];
+        }
+
+        float[] embedding = await GetEmbeddingAsync(semanticSearchValue);
+        if (embedding.Length == 0)
+        {
+            return [];
+        }
+
+        string dataSourceName = runtimeConfig.GetDataSourceNameFromEntityName(entityName);
+        SourceDefinition sourceDefinition = _metadataProviderFactory.GetMetadataProvider(dataSourceName).GetSourceDefinition(entityName);
+        HashSet<string> sourceColumns = new(sourceDefinition.Columns.Keys, StringComparer.OrdinalIgnoreCase);
+
+        try
+        {
+            // NOTE(review): a multiplexer is created and disposed per call; StackExchange.Redis guidance is to
+            // share one instance. Left as-is here because sharing changes this service's lifetime handling.
+            using IConnectionMultiplexer multiplexer = await CreateConnectionMultiplexerAsync(connectionString);
+            IDatabase db = multiplexer.GetDatabase();
+
+            string vectorFieldName = await ResolveVectorFieldNameAsync(db, options.RedisIndexName!);
+            byte[] vectorBytes = ToRedisVectorBytes(embedding);
+
+            RedisResult rawResult = await db.ExecuteAsync(
+                "FT.SEARCH",
+                options.RedisIndexName!,
+                $"*=>[KNN {top} @{vectorFieldName} $vec AS {VECTOR_SCORE_FIELD}]",
+                "PARAMS",
+                "2",
+                "vec",
+                vectorBytes,
+                "SORTBY",
+                VECTOR_SCORE_FIELD,
+                "DIALECT",
+                "2",
+                "LIMIT",
+                "0",
+                top.ToString(CultureInfo.InvariantCulture));
+
+            return ParseCandidates(rawResult, options.RedisIndexType, sourceColumns, primaryKeyColumns, similarityThreshold);
+        }
+        catch (Exception ex)
+        {
+            _logger.LogWarning(ex, "Semantic search query failed for entity {entityName} and index {indexName}.", entityName, options.RedisIndexName);
+            throw new DataApiBuilderException(
+                message: $"Semantic search index '{options.RedisIndexName}' for entity '{entityName}' was not found or could not be queried.",
+                statusCode: System.Net.HttpStatusCode.BadRequest,
+                subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest);
+        }
+    }
+
+    /// <summary>
+    /// Copies the floats into a byte array in host byte order, the blob passed as the $vec query parameter.
+    /// </summary>
+    private static byte[] ToRedisVectorBytes(float[] vector)
+    {
+        byte[] bytes = new byte[vector.Length * sizeof(float)];
+        Buffer.BlockCopy(vector, 0, bytes, 0, bytes.Length);
+        return bytes;
+    }
+
+    /// <summary>
+    /// Produces the query vector: the input itself when it is a JSON array of numbers, otherwise the
+    /// embedding returned by POSTing {"input": text} to the endpoint named by DAB_SEMANTIC_EMBED_ENDPOINT.
+    /// </summary>
+    /// <param name="semanticSearchValue">Search text or literal vector supplied by the caller.</param>
+    /// <returns>The vector, or an empty array when no endpoint is configured or the call does not succeed.</returns>
+    private async Task<float[]> GetEmbeddingAsync(string semanticSearchValue)
+    {
+        if (TryParseVectorText(semanticSearchValue, out float[]? parsedVector) && parsedVector is not null)
+        {
+            return parsedVector;
+        }
+
+        string? endpoint = Environment.GetEnvironmentVariable(EMBED_ENDPOINT_ENV);
+        if (string.IsNullOrWhiteSpace(endpoint))
+        {
+            return [];
+        }
+
+        string payload = JsonSerializer.Serialize(new Dictionary<string, string> { { "input", semanticSearchValue } });
+        using HttpRequestMessage request = new(HttpMethod.Post, endpoint)
+        {
+            Content = new StringContent(payload, Encoding.UTF8, "application/json")
+        };
+
+        // Put the bearer token on this request rather than on HttpClient.DefaultRequestHeaders so the
+        // credential never becomes state of a client instance.
+        string? apiKey = Environment.GetEnvironmentVariable(EMBED_API_KEY_ENV);
+        if (!string.IsNullOrWhiteSpace(apiKey))
+        {
+            request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", apiKey);
+        }
+
+        try
+        {
+            HttpClient client = _httpClientFactory.CreateClient();
+            using HttpResponseMessage response = await client.SendAsync(request);
+            if (!response.IsSuccessStatusCode)
+            {
+                _logger.LogWarning("Semantic search embedding endpoint returned status code {statusCode}.", (int)response.StatusCode);
+                return [];
+            }
+
+            string body = await response.Content.ReadAsStringAsync();
+            return TryExtractEmbedding(body, out float[]? embedding) && embedding is not null ? embedding : [];
+        }
+        catch (Exception ex) when (ex is HttpRequestException or TaskCanceledException)
+        {
+            // Network failures and timeouts are treated like a non-success status: no embedding, no candidates.
+            _logger.LogWarning(ex, "Semantic search embedding request failed.");
+            return [];
+        }
+    }
+
+    /// <summary>
+    /// Interprets <paramref name="text"/> as a literal vector when it is a non-empty JSON array of numbers.
+    /// </summary>
+    /// <returns>True when <paramref name="vector"/> holds at least one value.</returns>
+    private static bool TryParseVectorText(string text, out float[]? vector)
+    {
+        vector = null;
+
+        // Only a JSON array can be a literal vector. Checking the first significant character up front
+        // avoids throwing and catching a JsonException for ordinary natural-language input.
+        if (string.IsNullOrWhiteSpace(text) || text.TrimStart()[0] != '[')
+        {
+            return false;
+        }
+
+        try
+        {
+            using JsonDocument json = JsonDocument.Parse(text);
+            return TryReadEmbeddingArray(json.RootElement, out vector);
+        }
+        catch (JsonException)
+        {
+            return false;
+        }
+    }
+
+    /// <summary>
+    /// Extracts an embedding from an endpoint response shaped as a bare array, as {"embedding": [...]},
+    /// or as {"data": [{"embedding": [...]}, ...]} (first element is used).
+    /// </summary>
+    /// <returns>True when a non-empty numeric array was found.</returns>
+    private static bool TryExtractEmbedding(string responseBody, out float[]? embedding)
+    {
+        embedding = null;
+
+        if (string.IsNullOrWhiteSpace(responseBody))
+        {
+            return false;
+        }
+
+        try
+        {
+            using JsonDocument json = JsonDocument.Parse(responseBody);
+            JsonElement root = json.RootElement;
+
+            if (root.ValueKind is JsonValueKind.Array)
+            {
+                return TryReadEmbeddingArray(root, out embedding);
+            }
+
+            if (root.ValueKind is not JsonValueKind.Object)
+            {
+                return false;
+            }
+
+            if (root.TryGetProperty("embedding", out JsonElement directEmbedding)
+                && TryReadEmbeddingArray(directEmbedding, out embedding))
+            {
+                return true;
+            }
+
+            return root.TryGetProperty("data", out JsonElement data)
+                && data.ValueKind is JsonValueKind.Array
+                && data.GetArrayLength() > 0
+                && data[0].ValueKind is JsonValueKind.Object
+                && data[0].TryGetProperty("embedding", out JsonElement nestedEmbedding)
+                && TryReadEmbeddingArray(nestedEmbedding, out embedding);
+        }
+        catch (JsonException)
+        {
+            return false;
+        }
+    }
+
+    /// <summary>
+    /// Reads a JSON array of numbers into a float array.
+    /// </summary>
+    /// <returns>True when <paramref name="array"/> is an array of numbers with at least one element.</returns>
+    private static bool TryReadEmbeddingArray(JsonElement array, out float[]? embedding)
+    {
+        embedding = null;
+
+        if (array.ValueKind is not JsonValueKind.Array)
+        {
+            return false;
+        }
+
+        List<float> values = [];
+        foreach (JsonElement item in array.EnumerateArray())
+        {
+            // JsonElement.TryGetSingle throws for non-numeric elements, so check the kind first.
+            if (item.ValueKind is not JsonValueKind.Number || !item.TryGetSingle(out float current))
+            {
+                return false;
+            }
+
+            values.Add(current);
+        }
+
+        embedding = values.ToArray();
+        return embedding.Length > 0;
+    }
+
+    /// <summary>
+    /// Connects to Redis using the supplied connection string, switching to Entra ID token
+    /// authentication when Startup.ShouldUseEntraAuthForRedis says so.
+    /// </summary>
+    private static async Task<IConnectionMultiplexer> CreateConnectionMultiplexerAsync(string connectionString)
+    {
+        ConfigurationOptions options = ConfigurationOptions.Parse(connectionString);
+
+        if (Startup.ShouldUseEntraAuthForRedis(options))
+        {
+            options = await options.ConfigureForAzureWithTokenCredentialAsync(new DefaultAzureCredential());
+        }
+
+        return await ConnectionMultiplexer.ConnectAsync(options);
+    }
+
+    /// <summary>
+    /// Looks up the name of the first VECTOR attribute of the index via FT.INFO.
+    /// </summary>
+    /// <returns>The attribute name to use as @field in the KNN clause, or "embedding" when it cannot be determined.</returns>
+    private static async Task<string> ResolveVectorFieldNameAsync(IDatabase db, string indexName)
+    {
+        try
+        {
+            RedisResult infoResult = await db.ExecuteAsync("FT.INFO", indexName);
+            RedisResult[]? infoItems = infoResult.IsNull ? null : AsRedisArray(infoResult);
+            if (infoItems is null)
+            {
+                return DEFAULT_VECTOR_FIELD;
+            }
+
+            // The FT.INFO reply is read as a flat [key, value, key, value, ...] list; only "attributes" matters here.
+            for (int i = 0; i + 1 < infoItems.Length; i += 2)
+            {
+                if (!string.Equals(infoItems[i].ToString(), "attributes", StringComparison.OrdinalIgnoreCase))
+                {
+                    continue;
+                }
+
+                RedisResult[]? attributes = AsRedisArray(infoItems[i + 1]);
+                if (attributes is null)
+                {
+                    continue;
+                }
+
+                foreach (RedisResult attribute in attributes)
+                {
+                    string? vectorField = ReadVectorAttributeName(AsRedisArray(attribute));
+                    if (!string.IsNullOrWhiteSpace(vectorField))
+                    {
+                        return vectorField;
+                    }
+                }
+            }
+        }
+        catch
+        {
+            // Fall back to default vector field name; caller will throw a semantic index error if query fails.
+        }
+
+        return DEFAULT_VECTOR_FIELD;
+    }
+
+    /// <summary>
+    /// Returns the queryable name of one FT.INFO attribute entry when that entry is a VECTOR field.
+    /// </summary>
+    /// <param name="tokens">Flat [key, value, ...] description of a single index attribute, or null.</param>
+    /// <returns>The attribute name with any leading '$' / '.' removed; null when the entry is not a vector field.</returns>
+    private static string? ReadVectorAttributeName(RedisResult[]? tokens)
+    {
+        if (tokens is null)
+        {
+            return null;
+        }
+
+        bool isVector = false;
+        string? identifier = null;
+        string? alias = null;
+
+        for (int t = 0; t + 1 < tokens.Length; t += 2)
+        {
+            string? tokenKey = tokens[t].ToString();
+            string? tokenValue = tokens[t + 1].ToString();
+
+            if (string.Equals(tokenKey, "type", StringComparison.OrdinalIgnoreCase)
+                && string.Equals(tokenValue, "VECTOR", StringComparison.OrdinalIgnoreCase))
+            {
+                isVector = true;
+            }
+            else if (string.Equals(tokenKey, "identifier", StringComparison.OrdinalIgnoreCase))
+            {
+                identifier = tokenValue;
+            }
+            else if (string.Equals(tokenKey, "attribute", StringComparison.OrdinalIgnoreCase))
+            {
+                alias = tokenValue;
+            }
+        }
+
+        if (!isVector)
+        {
+            return null;
+        }
+
+        // Queries address a field by its attribute name (the AS alias), for hash and JSON indexes alike;
+        // the identifier (hash field name or JSONPath) is only a fallback when no attribute is reported.
+        string? selected = string.IsNullOrWhiteSpace(alias) ? identifier : alias;
+        return selected?.TrimStart('$', '.');
+    }
+
+    /// <summary>
+    /// Converts an FT.SEARCH reply, read as [total, key, fields, key, fields, ...], into candidates.
+    /// A document is kept when its similarity meets the threshold and it supplies every primary key column.
+    /// </summary>
+    /// <param name="rawResult">Reply of the FT.SEARCH command.</param>
+    /// <param name="redisIndexType">"json" enables unpacking of the "$" document field; other values do not.</param>
+    /// <param name="sourceColumns">Column names of the entity source; other document fields are dropped.</param>
+    /// <param name="primaryKeyColumns">Columns required to be present and non-null.</param>
+    /// <param name="similarityThreshold">Minimum similarity for a document to be kept.</param>
+    /// <returns>Candidates in the order Redis returned them.</returns>
+    private static IReadOnlyList<SemanticSearchCandidate> ParseCandidates(
+        RedisResult rawResult,
+        string redisIndexType,
+        HashSet<string> sourceColumns,
+        IReadOnlyList<string> primaryKeyColumns,
+        double similarityThreshold)
+    {
+        RedisResult[]? rows = AsRedisArray(rawResult);
+        if (rawResult.IsNull || rows is null || rows.Length < 3)
+        {
+            return [];
+        }
+
+        List<SemanticSearchCandidate> results = [];
+        for (int i = 1; i + 1 < rows.Length; i += 2)
+        {
+            RedisResult[]? payload = AsRedisArray(rows[i + 1]);
+            if (payload is null)
+            {
+                continue;
+            }
+
+            Dictionary<string, object?> extracted = ExtractDocumentFields(payload, redisIndexType);
+            double similarity = TryReadSimilarity(extracted);
+
+            if (similarity < similarityThreshold)
+            {
+                continue;
+            }
+
+            Dictionary<string, object?> sqlColumns = new(StringComparer.OrdinalIgnoreCase);
+            foreach ((string key, object? value) in extracted)
+            {
+                string normalized = NormalizeFieldName(key);
+                if (sourceColumns.Contains(normalized))
+                {
+                    sqlColumns[normalized] = value;
+                }
+            }
+
+            if (sqlColumns.Count == 0)
+            {
+                continue;
+            }
+
+            Dictionary<string, object?> primaryKeys = new(StringComparer.OrdinalIgnoreCase);
+            bool hasAllPrimaryKeys = true;
+            foreach (string pk in primaryKeyColumns)
+            {
+                if (!sqlColumns.TryGetValue(pk, out object? pkValue) || pkValue is null)
+                {
+                    hasAllPrimaryKeys = false;
+                    break;
+                }
+
+                primaryKeys[pk] = pkValue;
+            }
+
+            if (!hasAllPrimaryKeys)
+            {
+                continue;
+            }
+
+            results.Add(new SemanticSearchCandidate(primaryKeys, sqlColumns, similarity));
+        }
+
+        return results;
+    }
+
+    /// <summary>
+    /// Casts a reply to an array of replies, returning null when the cast is not possible.
+    /// </summary>
+    private static RedisResult[]? AsRedisArray(RedisResult result)
+    {
+        try
+        {
+            RedisResult[]? value = (RedisResult[]?)result;
+            return value;
+        }
+        catch
+        {
+            return null;
+        }
+    }
+
+    /// <summary>
+    /// Flattens one document's [name, value, ...] reply into a case-insensitive dictionary. For "json"
+    /// indexes the "$" entry is parsed and its top-level properties are merged in.
+    /// </summary>
+    private static Dictionary<string, object?> ExtractDocumentFields(RedisResult[] payload, string redisIndexType)
+    {
+        Dictionary<string, object?> fields = new(StringComparer.OrdinalIgnoreCase);
+
+        for (int i = 0; i + 1 < payload.Length; i += 2)
+        {
+            string? key = payload[i].ToString();
+            string? value = payload[i + 1].ToString();
+            if (string.IsNullOrWhiteSpace(key))
+            {
+                continue;
+            }
+
+            if (string.Equals(redisIndexType, "json", StringComparison.OrdinalIgnoreCase)
+                && string.Equals(key, "$", StringComparison.Ordinal))
+            {
+                MergeJsonDocumentFields(fields, value);
+                continue;
+            }
+
+            fields[key] = value;
+        }
+
+        return fields;
+    }
+
+    /// <summary>
+    /// Adds the top-level properties of a JSON object to <paramref name="fields"/>. Strings and booleans
+    /// keep their type, null stays null, and every other kind is stored as its raw JSON text.
+    /// </summary>
+    private static void MergeJsonDocumentFields(Dictionary<string, object?> fields, string? json)
+    {
+        if (string.IsNullOrWhiteSpace(json))
+        {
+            return;
+        }
+
+        try
+        {
+            using JsonDocument doc = JsonDocument.Parse(json);
+            if (doc.RootElement.ValueKind is not JsonValueKind.Object)
+            {
+                return;
+            }
+
+            foreach (JsonProperty property in doc.RootElement.EnumerateObject())
+            {
+                fields[property.Name] = property.Value.ValueKind switch
+                {
+                    JsonValueKind.Null => null,
+                    JsonValueKind.String => property.Value.GetString(),
+                    JsonValueKind.Number => property.Value.ToString(),
+                    JsonValueKind.True => true,
+                    JsonValueKind.False => false,
+                    _ => property.Value.ToString()
+                };
+            }
+        }
+        catch (JsonException)
+        {
+            // Ignore malformed JSON payloads and continue with other fields.
+        }
+    }
+
+    /// <summary>
+    /// Reads the KNN score from the document fields and converts it to a similarity in [0, 1].
+    /// </summary>
+    /// <returns>1 - score clamped to [0, 1]; 0 when the score field is missing or not a number.</returns>
+    private static double TryReadSimilarity(Dictionary<string, object?> fields)
+    {
+        if (!fields.TryGetValue(VECTOR_SCORE_FIELD, out object? scoreObj)
+            || scoreObj is null
+            || !double.TryParse(scoreObj.ToString(), NumberStyles.Float, CultureInfo.InvariantCulture, out double rawScore))
+        {
+            return 0.0;
+        }
+
+        // Redis KNN score is distance; normalize to similarity for semantic-threshold semantics.
+        return Math.Clamp(1.0 - rawScore, 0.0, 1.0);
+    }
+
+    /// <summary>
+    /// Reduces a Redis field name to a bare column name: trims whitespace, drops a leading "$." and
+    /// keeps only the segment after the last '.'.
+    /// </summary>
+    private static string NormalizeFieldName(string field)
+    {
+        string normalized = field.Trim();
+
+        if (normalized.StartsWith("$.", StringComparison.Ordinal))
+        {
+            normalized = normalized[2..];
+        }
+
+        if (normalized.Contains('.', StringComparison.Ordinal))
+        {
+            normalized = normalized[(normalized.LastIndexOf('.') + 1)..];
+        }
+
+        return normalized;
+    }
+}
\ No newline at end of file
diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs
index 876b63378e..e40d08801a 100644
--- a/src/Service/Startup.cs
+++ b/src/Service/Startup.cs
@@ -28,11 +28,13 @@
 using Azure.DataApiBuilder.Core.Services.Embeddings;
 using Azure.DataApiBuilder.Core.Services.MetadataProviders;
 using Azure.DataApiBuilder.Core.Services.OpenAPI;
+using Azure.DataApiBuilder.Core.Services.SemanticSearch;
 using Azure.DataApiBuilder.Core.Telemetry;
 using Azure.DataApiBuilder.Mcp.Core;
 using Azure.DataApiBuilder.Service.Controllers;
 using Azure.DataApiBuilder.Service.Exceptions;
 using Azure.DataApiBuilder.Service.HealthCheck;
+using Azure.DataApiBuilder.Service.Services.SemanticSearch;
 using Azure.DataApiBuilder.Service.Telemetry;
 using Azure.DataApiBuilder.Service.Utilities;
 using Azure.Identity;
@@ -315,6 +317,8 @@ public void ConfigureServices(IServiceCollection services)
             services.AddSingleton();
             services.AddSingleton();
             services.AddSingleton();
+            services.AddHttpClient();
+            services.AddSingleton();
 
             // ILogger explicit creation required for logger to use --LogLevel startup argument specified.
services.AddSingleton>(implementationFactory: (serviceProvider) => From 8c3a221996889bcd5e9a73d25aa8f887e1863f15 Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Wed, 10 Jun 2026 07:45:06 -0700 Subject: [PATCH 04/12] More features and tests --- schemas/dab.draft.schema.json | 15 +++++ src/Cli.Tests/ConfigureOptionsTests.cs | 48 ++++++++++++++++ src/Cli/Commands/ConfigureOptions.cs | 10 ++++ src/Cli/ConfigGenerator.cs | 57 +++++++++++++++++++ src/Config/ObjectModel/RuntimeOptions.cs | 4 ++ .../RuntimeSemanticSearchOptions.cs | 31 ++++++++++ .../Configurations/RuntimeConfigValidator.cs | 21 +++++++ .../Queries/QueryBuilder.cs | 5 +- .../UnitTests/ConfigValidationUnitTests.cs | 42 +++++++++++++- .../RedisSemanticSearchService.cs | 10 ++-- 10 files changed, 234 insertions(+), 9 deletions(-) create mode 100644 src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index 94b731bd56..da9455af1f 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -534,6 +534,21 @@ } } }, + "semantic-search": { + "type": "object", + "description": "Runtime semantic-search configuration.", + "additionalProperties": false, + "properties": { + "embedding-endpoint": { + "type": "string", + "description": "Endpoint used to generate embeddings from semantic-search input text." + }, + "embedding-api-key": { + "type": "string", + "description": "Optional API key used as bearer token for embedding endpoint calls." 
+ } + } + }, "compression": { "type": "object", "description": "Configures HTTP response compression settings.", diff --git a/src/Cli.Tests/ConfigureOptionsTests.cs b/src/Cli.Tests/ConfigureOptionsTests.cs index f1155fb74d..35731cac71 100644 --- a/src/Cli.Tests/ConfigureOptionsTests.cs +++ b/src/Cli.Tests/ConfigureOptionsTests.cs @@ -548,6 +548,54 @@ public void TestUpdateTTLForCacheSettings(int updatedTtlValue) Assert.AreEqual(updatedTtlValue, runtimeConfig.Runtime.Cache.TtlSeconds); } + /// + /// Tests that running "dab configure --runtime.semantic-search.embedding-endpoint {value}" updates runtime semantic-search embedding endpoint. + /// + [TestMethod] + public void TestUpdateEmbeddingEndpointForSemanticSearchRuntimeSettings() + { + // Arrange -> all the setup which includes creating options. + SetupFileSystemWithInitialConfig(INITIAL_CONFIG); + string updatedEmbeddingEndpoint = "https://example.org/embed"; + + // Act: Attempts to update embedding endpoint value + ConfigureOptions options = new( + runtimeSemanticSearchEmbeddingEndpoint: updatedEmbeddingEndpoint, + config: TEST_RUNTIME_CONFIG_FILE + ); + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert: Validate the embedding endpoint value is updated + Assert.IsTrue(isSuccess); + string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? runtimeConfig)); + Assert.AreEqual(updatedEmbeddingEndpoint, runtimeConfig.Runtime?.SemanticSearch?.EmbeddingEndpoint); + } + + /// + /// Tests that running "dab configure --runtime.semantic-search.embedding-api-key {value}" updates runtime semantic-search embedding API key. + /// + [TestMethod] + public void TestUpdateEmbeddingApiKeyForSemanticSearchRuntimeSettings() + { + // Arrange -> all the setup which includes creating options. 
+ SetupFileSystemWithInitialConfig(INITIAL_CONFIG); + string updatedEmbeddingApiKey = "test-api-key"; + + // Act: Attempts to update embedding API key value + ConfigureOptions options = new( + runtimeSemanticSearchEmbeddingApiKey: updatedEmbeddingApiKey, + config: TEST_RUNTIME_CONFIG_FILE + ); + bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); + + // Assert: Validate the embedding API key value is updated + Assert.IsTrue(isSuccess); + string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); + Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? runtimeConfig)); + Assert.AreEqual(updatedEmbeddingApiKey, runtimeConfig.Runtime?.SemanticSearch?.EmbeddingApiKey); + } + /// /// Tests that running "dab configure --runtime.compression.level {value}" on a config with various values results /// in runtime config update. Takes in updated value for compression.level and diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index b64de19981..0b17ab4599 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -63,6 +63,8 @@ public ConfigureOptions( bool? runtimePaginationNextLinkRelative = null, string? runtimeCacheLevel2Provider = null, string? runtimeCacheLevel2ConnectionString = null, + string? runtimeSemanticSearchEmbeddingEndpoint = null, + string? runtimeSemanticSearchEmbeddingApiKey = null, CompressionLevel? runtimeCompressionLevel = null, bool? runtimeHealthEnabled = null, int? 
runtimeHealthCacheTtlSeconds = null, @@ -165,6 +167,8 @@ public ConfigureOptions( RuntimePaginationNextLinkRelative = runtimePaginationNextLinkRelative; RuntimeCacheLevel2Provider = runtimeCacheLevel2Provider; RuntimeCacheLevel2ConnectionString = runtimeCacheLevel2ConnectionString; + RuntimeSemanticSearchEmbeddingEndpoint = runtimeSemanticSearchEmbeddingEndpoint; + RuntimeSemanticSearchEmbeddingApiKey = runtimeSemanticSearchEmbeddingApiKey; // Compression RuntimeCompressionLevel = runtimeCompressionLevel; // Health @@ -352,6 +356,12 @@ public ConfigureOptions( [Option("runtime.cache.level-2.connection-string", Required = false, HelpText = "Set level-2 cache connection string.")] public string? RuntimeCacheLevel2ConnectionString { get; } + [Option("runtime.semantic-search.embedding-endpoint", Required = false, HelpText = "Set semantic-search embedding endpoint.")] + public string? RuntimeSemanticSearchEmbeddingEndpoint { get; } + + [Option("runtime.semantic-search.embedding-api-key", Required = false, HelpText = "Set semantic-search embedding API key.")] + public string? RuntimeSemanticSearchEmbeddingApiKey { get; } + [Option("runtime.compression.level", Required = false, HelpText = "Set the response compression level. Allowed values: optimal (default), fastest, none.")] public CompressionLevel? RuntimeCompressionLevel { get; } diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index c513f1ce30..a743825357 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -1060,6 +1060,22 @@ private static bool TryUpdateConfiguredRuntimeOptions( runtimeConfig = runtimeConfig! with { Runtime = runtimeConfig.Runtime! with { Pagination = updatedPaginationOptions } }; } + // Semantic Search: Embedding endpoint and API key + if (options.RuntimeSemanticSearchEmbeddingEndpoint != null || + options.RuntimeSemanticSearchEmbeddingApiKey != null) + { + RuntimeSemanticSearchOptions? 
updatedSemanticSearchOptions = runtimeConfig?.Runtime?.SemanticSearch ?? new(); + bool status = TryUpdateConfiguredSemanticSearchValues(options, ref updatedSemanticSearchOptions); + if (status) + { + runtimeConfig = runtimeConfig! with { Runtime = runtimeConfig.Runtime! with { SemanticSearch = updatedSemanticSearchOptions } }; + } + else + { + return false; + } + } + // Compression: Level if (options.RuntimeCompressionLevel != null) { @@ -1616,6 +1632,47 @@ private static bool TryUpdateConfiguredCacheValues( } } + /// + /// Attempts to update the Config parameters in the semantic-search runtime settings based on the provided value. + /// + /// options. + /// updatedSemanticSearchOptions. + /// True if updates succeed, else false. + private static bool TryUpdateConfiguredSemanticSearchValues( + ConfigureOptions options, + ref RuntimeSemanticSearchOptions? updatedSemanticSearchOptions) + { + try + { + if (options?.RuntimeSemanticSearchEmbeddingEndpoint is not null) + { + RuntimeSemanticSearchOptions current = updatedSemanticSearchOptions ?? new(); + updatedSemanticSearchOptions = current with + { + EmbeddingEndpoint = options.RuntimeSemanticSearchEmbeddingEndpoint + }; + _logger.LogInformation("Updated RuntimeConfig with runtime.semantic-search.embedding-endpoint."); + } + + if (options?.RuntimeSemanticSearchEmbeddingApiKey is not null) + { + RuntimeSemanticSearchOptions current = updatedSemanticSearchOptions ?? new(); + updatedSemanticSearchOptions = current with + { + EmbeddingApiKey = options.RuntimeSemanticSearchEmbeddingApiKey + }; + _logger.LogInformation("Updated RuntimeConfig with runtime.semantic-search.embedding-api-key."); + } + + return true; + } + catch (Exception ex) + { + _logger.LogError("Failed to update RuntimeConfig.SemanticSearch with exception message: {exceptionMessage}.", ex.Message); + return false; + } + } + /// /// Attempts to update the Config parameters in the Compression runtime settings based on the provided value. 
/// Validates user-provided parameters and then returns true if the updated Compression options diff --git a/src/Config/ObjectModel/RuntimeOptions.cs b/src/Config/ObjectModel/RuntimeOptions.cs index 1d5ad86db0..1b3c681fac 100644 --- a/src/Config/ObjectModel/RuntimeOptions.cs +++ b/src/Config/ObjectModel/RuntimeOptions.cs @@ -16,6 +16,8 @@ public record RuntimeOptions public string? BaseRoute { get; init; } public TelemetryOptions? Telemetry { get; init; } public RuntimeCacheOptions? Cache { get; init; } + [JsonPropertyName("semantic-search")] + public RuntimeSemanticSearchOptions? SemanticSearch { get; init; } public PaginationOptions? Pagination { get; init; } public RuntimeHealthCheckConfig? Health { get; init; } public EmbeddingsOptions? Embeddings { get; init; } @@ -30,6 +32,7 @@ public RuntimeOptions( string? BaseRoute = null, TelemetryOptions? Telemetry = null, RuntimeCacheOptions? Cache = null, + RuntimeSemanticSearchOptions? SemanticSearch = null, PaginationOptions? Pagination = null, RuntimeHealthCheckConfig? Health = null, EmbeddingsOptions? Embeddings = null, @@ -42,6 +45,7 @@ public RuntimeOptions( this.BaseRoute = BaseRoute; this.Telemetry = Telemetry; this.Cache = Cache; + this.SemanticSearch = SemanticSearch; this.Pagination = Pagination; this.Health = Health; this.Embeddings = Embeddings; diff --git a/src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs b/src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs new file mode 100644 index 0000000000..1c72667de2 --- /dev/null +++ b/src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.DataApiBuilder.Config.ObjectModel; + +/// +/// Runtime semantic search configuration. +/// +public record RuntimeSemanticSearchOptions +{ + /// + /// Endpoint used to generate embeddings from semantic-search input text. 
+ /// + [JsonPropertyName("embedding-endpoint")] + public string? EmbeddingEndpoint { get; init; } = null; + + /// + /// Optional API key used as a bearer token for embedding endpoint calls. + /// + [JsonPropertyName("embedding-api-key")] + public string? EmbeddingApiKey { get; init; } = null; + + [JsonConstructor] + public RuntimeSemanticSearchOptions(string? EmbeddingEndpoint = null, string? EmbeddingApiKey = null) + { + this.EmbeddingEndpoint = EmbeddingEndpoint; + this.EmbeddingApiKey = EmbeddingApiKey; + } +} \ No newline at end of file diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs index 70b559ceca..7d1265836b 100644 --- a/src/Core/Configurations/RuntimeConfigValidator.cs +++ b/src/Core/Configurations/RuntimeConfigValidator.cs @@ -69,6 +69,9 @@ public class RuntimeConfigValidator : IConfigValidator private const string SEMANTIC_SEARCH_REDIS_REQUIREMENT_ERR_MSG = "Semantic search requires runtime.cache.level-2.provider to be 'redis' and runtime.cache.level-2.connection-string to be configured."; + private const string SEMANTIC_SEARCH_EMBEDDING_REQUIREMENT_ERR_MSG = + "Semantic search requires runtime.semantic-search.embedding-endpoint to be configured."; + private static readonly HashSet _reservedSemanticRestNames = [ "semantic_search", @@ -1069,6 +1072,14 @@ private void ValidateSemanticSearchConfiguration(RuntimeConfig runtimeConfig, st subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); } + if (!HasValidSemanticEmbeddingConfiguration(runtimeConfig)) + { + HandleOrRecordException(new DataApiBuilderException( + message: SEMANTIC_SEARCH_EMBEDDING_REQUIREMENT_ERR_MSG, + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + if (semantic.RedisIndexMultiplier < 1 || semantic.RedisIndexMultiplier > 10) { HandleOrRecordException(new DataApiBuilderException( @@ -1099,6 +1110,16 @@ private static 
bool HasValidSemanticRedisConfiguration(RuntimeConfig runtimeConf && !string.IsNullOrWhiteSpace(level2.ConnectionString); } + /// + /// Validates embedding endpoint prerequisites for semantic-search. + /// + private static bool HasValidSemanticEmbeddingConfiguration(RuntimeConfig runtimeConfig) + { + RuntimeSemanticSearchOptions? semanticSearch = runtimeConfig.Runtime?.SemanticSearch; + return semanticSearch is not null + && !string.IsNullOrWhiteSpace(semanticSearch.EmbeddingEndpoint); + } + /// /// Ensures configured field names do not collide with semantic reserved names. /// diff --git a/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs b/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs index 100759522f..69d57c86cd 100644 --- a/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs +++ b/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs @@ -30,8 +30,9 @@ public static class QueryBuilder public const string GROUP_BY_AGGREGATE_FIELD_ARG_NAME = "field"; public const string GROUP_BY_AGGREGATE_FIELD_DISTINCT_NAME = "distinct"; public const string GROUP_BY_AGGREGATE_FIELD_HAVING_NAME = "having"; - public const string SEMANTIC_SEARCH_ARGUMENT_NAME = SemanticSearchConstants.GRAPHQL_SEARCH_ARGUMENT; - public const string SEMANTIC_THRESHOLD_ARGUMENZT_NAME = SemanticSearchConstants.GRAPHQL_THRESHOLD_ARGUMENT; + public const string SEMANTIC_SEARCH_ARGUMENT_NAME = "semanticSearch"; + public const string SEMANTIC_THRESHOLD_ARGUMENT_NAME = "semanticThreshold"; + public const string SEMANTIC_DISTANCE_FIELD_NAME = "semanticDistance"; // Define the enabled database types for aggregation public static readonly HashSet AggregationEnabledDatabaseTypes = new() diff --git a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs index ef1864115d..ee3cf2dbdf 100644 --- a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs +++ b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs @@ -3652,6 +3652,45 @@ public void 
ValidateSemanticSearchRequiresRedisIndexNameWhenEnabled() ex.Message); } + [TestMethod] + public void ValidateSemanticSearchRequiresEmbeddingEndpointConfiguration() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, KeyFields: null), + Fields: null, + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true, + RedisIndexName = "idx:product-semantic" + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new RuntimeCacheOptions + { + Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") + }), + Entities: new(new Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "Semantic search requires runtime.semantic-search.embedding-endpoint to be configured.", + ex.Message); + } + [TestMethod] public void ValidateSemanticSearchRejectsReservedFieldName() { @@ -3682,7 +3721,8 @@ public void ValidateSemanticSearchRejectsReservedFieldName() Cache: new RuntimeCacheOptions { Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") - }), + }, + SemanticSearch: new RuntimeSemanticSearchOptions(EmbeddingEndpoint: "https://example.org/embed")), Entities: new(new Dictionary { ["Product"] = entity })); DataApiBuilderException ex = 
Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); diff --git a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs index 726137a53f..0b67657187 100644 --- a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs +++ b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs @@ -31,8 +31,6 @@ namespace Azure.DataApiBuilder.Service.Services.SemanticSearch; /// public sealed class RedisSemanticSearchService : ISemanticSearchService { - private const string EMBED_ENDPOINT_ENV = "DAB_SEMANTIC_EMBED_ENDPOINT"; - private const string EMBED_API_KEY_ENV = "DAB_SEMANTIC_EMBED_API_KEY"; private const string VECTOR_SCORE_FIELD = "__vector_score"; private const string DEFAULT_VECTOR_FIELD = "embedding"; @@ -71,7 +69,7 @@ public async Task> GetCandidatesAsync( return []; } - float[] embedding = await GetEmbeddingAsync(semanticSearchValue); + float[] embedding = await GetEmbeddingAsync(runtimeConfig, semanticSearchValue); if (embedding.Length == 0) { return []; @@ -124,21 +122,21 @@ private static byte[] ToRedisVectorBytes(float[] vector) return bytes; } - private async Task GetEmbeddingAsync(string semanticSearchValue) + private async Task GetEmbeddingAsync(RuntimeConfig runtimeConfig, string semanticSearchValue) { if (TryParseVectorText(semanticSearchValue, out float[]? parsedVector) && parsedVector is not null) { return parsedVector; } - string? endpoint = Environment.GetEnvironmentVariable(EMBED_ENDPOINT_ENV); + string? endpoint = runtimeConfig.Runtime?.SemanticSearch?.EmbeddingEndpoint; if (string.IsNullOrWhiteSpace(endpoint)) { return []; } HttpClient client = _httpClientFactory.CreateClient(); - string? apiKey = Environment.GetEnvironmentVariable(EMBED_API_KEY_ENV); + string? 
apiKey = runtimeConfig.Runtime?.SemanticSearch?.EmbeddingApiKey; if (!string.IsNullOrWhiteSpace(apiKey)) { client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); From 6fb288cff1292e43662a0f5bf4437993b8bd73cb Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Wed, 10 Jun 2026 12:38:31 -0700 Subject: [PATCH 05/12] Fix duplicate semantic search helper after merge --- src/Cli/Utils.cs | 88 ------------------------------------------------ 1 file changed, 88 deletions(-) diff --git a/src/Cli/Utils.cs b/src/Cli/Utils.cs index db5a699ae0..f15bb8d55a 100644 --- a/src/Cli/Utils.cs +++ b/src/Cli/Utils.cs @@ -1044,94 +1044,6 @@ public static EntityGraphQLOptions ConstructGraphQLTypeDetails(string? graphQL, }; } - /// - /// Constructs the EntitySemanticSearchOptions for Add/Update. - /// - /// EntitySemanticSearchOptions when at least one semantic option is provided, null otherwise. - public static EntitySemanticSearchOptions? ConstructSemanticSearchOptions( - string? semanticSearchEnabled, - string? semanticSearchRedisIndexName, - string? semanticSearchRedisIndexType, - string? semanticSearchRedisIndexMultiplier, - string? semanticSearchSimilarityThreshold, - string? semanticSearchInputDescription, - string? semanticSearchOutputDescription) - { - if (semanticSearchEnabled is null - && semanticSearchRedisIndexName is null - && semanticSearchRedisIndexType is null - && semanticSearchRedisIndexMultiplier is null - && semanticSearchSimilarityThreshold is null - && semanticSearchInputDescription is null - && semanticSearchOutputDescription is null) - { - return null; - } - - bool enabled = false; - if (semanticSearchEnabled is not null && !bool.TryParse(semanticSearchEnabled, out enabled)) - { - _logger.LogError("Invalid format for --semantic-search.enabled. 
Accepted values are true/false."); - return null; - } - - string redisIndexType = EntitySemanticSearchOptions.DEFAULT_REDIS_INDEX_TYPE; - if (!string.IsNullOrWhiteSpace(semanticSearchRedisIndexType)) - { - if (!string.Equals(semanticSearchRedisIndexType, "hash", StringComparison.OrdinalIgnoreCase) - && !string.Equals(semanticSearchRedisIndexType, "json", StringComparison.OrdinalIgnoreCase)) - { - _logger.LogError("Invalid format for --semantic-search.redis-index-type. Accepted values are hash/json."); - return null; - } - - redisIndexType = semanticSearchRedisIndexType.ToLowerInvariant(); - } - - int redisIndexMultiplier = EntitySemanticSearchOptions.DEFAULT_REDIS_INDEX_MULTIPLIER; - if (!string.IsNullOrWhiteSpace(semanticSearchRedisIndexMultiplier) - && !int.TryParse(semanticSearchRedisIndexMultiplier, out redisIndexMultiplier)) - { - _logger.LogError("Invalid format for --semantic-search.redis-index-multiplier. Accepted values are integer values in range [1,10]."); - return null; - } - - if (redisIndexMultiplier < 1 || redisIndexMultiplier > 10) - { - _logger.LogError("Invalid value for --semantic-search.redis-index-multiplier. Accepted values are integer values in range [1,10]."); - return null; - } - - double similarityThreshold = EntitySemanticSearchOptions.DEFAULT_SIMILARITY_THRESHOLD; - if (!string.IsNullOrWhiteSpace(semanticSearchSimilarityThreshold) - && !double.TryParse(semanticSearchSimilarityThreshold, out similarityThreshold)) - { - _logger.LogError("Invalid format for --semantic-search.similarity-threshold. Accepted values are decimal values in range [0.0,1.0]."); - return null; - } - - if (similarityThreshold < 0.0 || similarityThreshold > 1.0) - { - _logger.LogError("Invalid value for --semantic-search.similarity-threshold. 
Accepted values are decimal values in range [0.0,1.0]."); - return null; - } - - return new EntitySemanticSearchOptions - { - Enabled = enabled, - RedisIndexName = semanticSearchRedisIndexName, - RedisIndexType = redisIndexType, - RedisIndexMultiplier = redisIndexMultiplier, - SimilarityThreshold = similarityThreshold, - InputDescription = string.IsNullOrWhiteSpace(semanticSearchInputDescription) - ? EntitySemanticSearchOptions.DEFAULT_INPUT_DESCRIPTION - : semanticSearchInputDescription, - OutputDescription = string.IsNullOrWhiteSpace(semanticSearchOutputDescription) - ? EntitySemanticSearchOptions.DEFAULT_OUTPUT_DESCRIPTION - : semanticSearchOutputDescription - }; - } - /// /// Constructs the EntityMcpOptions for Add/Update. /// From eede53faf4c0842ea83f754dadb655873939172a Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Thu, 11 Jun 2026 03:43:02 -0700 Subject: [PATCH 06/12] Remove separate embeddings config within semantic search --- schemas/dab.draft.schema.json | 31 +- src/Cli.Tests/ConfigureOptionsTests.cs | 48 --- src/Cli/Commands/ConfigureOptions.cs | 10 - src/Cli/ConfigGenerator.cs | 57 ---- src/Config/ObjectModel/RuntimeOptions.cs | 4 - .../RuntimeSemanticSearchOptions.cs | 31 -- .../Configurations/RuntimeConfigValidator.cs | 9 +- .../UnitTests/ConfigValidationUnitTests.cs | 10 +- .../UnitTests/RequestParserUnitTests.cs | 66 +++++ .../UnitTests/SemanticSearchTextFlowTests.cs | 276 ++++++++++++++++++ .../RedisSemanticSearchService.cs | 105 +------ 11 files changed, 364 insertions(+), 283 deletions(-) delete mode 100644 src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs create mode 100644 src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index 129356d95b..5be2a04b36 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -534,36 +534,7 @@ } } }, - "semantic-search": { - "type": "object", - "description": "Runtime semantic-search configuration.", 
- "additionalProperties": false, - "properties": { - "embedding-endpoint": { - "type": "string", - "description": "Endpoint used to generate embeddings from semantic-search input text." - }, - "embedding-api-key": { - "type": "string", - "description": "Optional API key used as bearer token for embedding endpoint calls." - } - } - }, - "semantic-search": { - "type": "object", - "description": "Runtime semantic-search configuration.", - "additionalProperties": false, - "properties": { - "embedding-endpoint": { - "type": "string", - "description": "Endpoint used to generate embeddings from semantic-search input text." - }, - "embedding-api-key": { - "type": "string", - "description": "Optional API key used as bearer token for embedding endpoint calls." - } - } - }, + "compression": { "type": "object", "description": "Configures HTTP response compression settings.", diff --git a/src/Cli.Tests/ConfigureOptionsTests.cs b/src/Cli.Tests/ConfigureOptionsTests.cs index 35731cac71..f1155fb74d 100644 --- a/src/Cli.Tests/ConfigureOptionsTests.cs +++ b/src/Cli.Tests/ConfigureOptionsTests.cs @@ -548,54 +548,6 @@ public void TestUpdateTTLForCacheSettings(int updatedTtlValue) Assert.AreEqual(updatedTtlValue, runtimeConfig.Runtime.Cache.TtlSeconds); } - /// - /// Tests that running "dab configure --runtime.semantic-search.embedding-endpoint {value}" updates runtime semantic-search embedding endpoint. - /// - [TestMethod] - public void TestUpdateEmbeddingEndpointForSemanticSearchRuntimeSettings() - { - // Arrange -> all the setup which includes creating options. 
- SetupFileSystemWithInitialConfig(INITIAL_CONFIG); - string updatedEmbeddingEndpoint = "https://example.org/embed"; - - // Act: Attempts to update embedding endpoint value - ConfigureOptions options = new( - runtimeSemanticSearchEmbeddingEndpoint: updatedEmbeddingEndpoint, - config: TEST_RUNTIME_CONFIG_FILE - ); - bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); - - // Assert: Validate the embedding endpoint value is updated - Assert.IsTrue(isSuccess); - string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); - Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? runtimeConfig)); - Assert.AreEqual(updatedEmbeddingEndpoint, runtimeConfig.Runtime?.SemanticSearch?.EmbeddingEndpoint); - } - - /// - /// Tests that running "dab configure --runtime.semantic-search.embedding-api-key {value}" updates runtime semantic-search embedding API key. - /// - [TestMethod] - public void TestUpdateEmbeddingApiKeyForSemanticSearchRuntimeSettings() - { - // Arrange -> all the setup which includes creating options. - SetupFileSystemWithInitialConfig(INITIAL_CONFIG); - string updatedEmbeddingApiKey = "test-api-key"; - - // Act: Attempts to update embedding API key value - ConfigureOptions options = new( - runtimeSemanticSearchEmbeddingApiKey: updatedEmbeddingApiKey, - config: TEST_RUNTIME_CONFIG_FILE - ); - bool isSuccess = TryConfigureSettings(options, _runtimeConfigLoader!, _fileSystem!); - - // Assert: Validate the embedding API key value is updated - Assert.IsTrue(isSuccess); - string updatedConfig = _fileSystem!.File.ReadAllText(TEST_RUNTIME_CONFIG_FILE); - Assert.IsTrue(RuntimeConfigLoader.TryParseConfig(updatedConfig, out RuntimeConfig? 
runtimeConfig)); - Assert.AreEqual(updatedEmbeddingApiKey, runtimeConfig.Runtime?.SemanticSearch?.EmbeddingApiKey); - } - /// /// Tests that running "dab configure --runtime.compression.level {value}" on a config with various values results /// in runtime config update. Takes in updated value for compression.level and diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index 0b17ab4599..b64de19981 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -63,8 +63,6 @@ public ConfigureOptions( bool? runtimePaginationNextLinkRelative = null, string? runtimeCacheLevel2Provider = null, string? runtimeCacheLevel2ConnectionString = null, - string? runtimeSemanticSearchEmbeddingEndpoint = null, - string? runtimeSemanticSearchEmbeddingApiKey = null, CompressionLevel? runtimeCompressionLevel = null, bool? runtimeHealthEnabled = null, int? runtimeHealthCacheTtlSeconds = null, @@ -167,8 +165,6 @@ public ConfigureOptions( RuntimePaginationNextLinkRelative = runtimePaginationNextLinkRelative; RuntimeCacheLevel2Provider = runtimeCacheLevel2Provider; RuntimeCacheLevel2ConnectionString = runtimeCacheLevel2ConnectionString; - RuntimeSemanticSearchEmbeddingEndpoint = runtimeSemanticSearchEmbeddingEndpoint; - RuntimeSemanticSearchEmbeddingApiKey = runtimeSemanticSearchEmbeddingApiKey; // Compression RuntimeCompressionLevel = runtimeCompressionLevel; // Health @@ -356,12 +352,6 @@ public ConfigureOptions( [Option("runtime.cache.level-2.connection-string", Required = false, HelpText = "Set level-2 cache connection string.")] public string? RuntimeCacheLevel2ConnectionString { get; } - [Option("runtime.semantic-search.embedding-endpoint", Required = false, HelpText = "Set semantic-search embedding endpoint.")] - public string? 
RuntimeSemanticSearchEmbeddingEndpoint { get; } - - [Option("runtime.semantic-search.embedding-api-key", Required = false, HelpText = "Set semantic-search embedding API key.")] - public string? RuntimeSemanticSearchEmbeddingApiKey { get; } - [Option("runtime.compression.level", Required = false, HelpText = "Set the response compression level. Allowed values: optimal (default), fastest, none.")] public CompressionLevel? RuntimeCompressionLevel { get; } diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index a743825357..c513f1ce30 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -1060,22 +1060,6 @@ private static bool TryUpdateConfiguredRuntimeOptions( runtimeConfig = runtimeConfig! with { Runtime = runtimeConfig.Runtime! with { Pagination = updatedPaginationOptions } }; } - // Semantic Search: Embedding endpoint and API key - if (options.RuntimeSemanticSearchEmbeddingEndpoint != null || - options.RuntimeSemanticSearchEmbeddingApiKey != null) - { - RuntimeSemanticSearchOptions? updatedSemanticSearchOptions = runtimeConfig?.Runtime?.SemanticSearch ?? new(); - bool status = TryUpdateConfiguredSemanticSearchValues(options, ref updatedSemanticSearchOptions); - if (status) - { - runtimeConfig = runtimeConfig! with { Runtime = runtimeConfig.Runtime! with { SemanticSearch = updatedSemanticSearchOptions } }; - } - else - { - return false; - } - } - // Compression: Level if (options.RuntimeCompressionLevel != null) { @@ -1632,47 +1616,6 @@ private static bool TryUpdateConfiguredCacheValues( } } - /// - /// Attempts to update the Config parameters in the semantic-search runtime settings based on the provided value. - /// - /// options. - /// updatedSemanticSearchOptions. - /// True if updates succeed, else false. - private static bool TryUpdateConfiguredSemanticSearchValues( - ConfigureOptions options, - ref RuntimeSemanticSearchOptions? 
updatedSemanticSearchOptions) - { - try - { - if (options?.RuntimeSemanticSearchEmbeddingEndpoint is not null) - { - RuntimeSemanticSearchOptions current = updatedSemanticSearchOptions ?? new(); - updatedSemanticSearchOptions = current with - { - EmbeddingEndpoint = options.RuntimeSemanticSearchEmbeddingEndpoint - }; - _logger.LogInformation("Updated RuntimeConfig with runtime.semantic-search.embedding-endpoint."); - } - - if (options?.RuntimeSemanticSearchEmbeddingApiKey is not null) - { - RuntimeSemanticSearchOptions current = updatedSemanticSearchOptions ?? new(); - updatedSemanticSearchOptions = current with - { - EmbeddingApiKey = options.RuntimeSemanticSearchEmbeddingApiKey - }; - _logger.LogInformation("Updated RuntimeConfig with runtime.semantic-search.embedding-api-key."); - } - - return true; - } - catch (Exception ex) - { - _logger.LogError("Failed to update RuntimeConfig.SemanticSearch with exception message: {exceptionMessage}.", ex.Message); - return false; - } - } - /// /// Attempts to update the Config parameters in the Compression runtime settings based on the provided value. /// Validates user-provided parameters and then returns true if the updated Compression options diff --git a/src/Config/ObjectModel/RuntimeOptions.cs b/src/Config/ObjectModel/RuntimeOptions.cs index 1b3c681fac..1d5ad86db0 100644 --- a/src/Config/ObjectModel/RuntimeOptions.cs +++ b/src/Config/ObjectModel/RuntimeOptions.cs @@ -16,8 +16,6 @@ public record RuntimeOptions public string? BaseRoute { get; init; } public TelemetryOptions? Telemetry { get; init; } public RuntimeCacheOptions? Cache { get; init; } - [JsonPropertyName("semantic-search")] - public RuntimeSemanticSearchOptions? SemanticSearch { get; init; } public PaginationOptions? Pagination { get; init; } public RuntimeHealthCheckConfig? Health { get; init; } public EmbeddingsOptions? Embeddings { get; init; } @@ -32,7 +30,6 @@ public RuntimeOptions( string? BaseRoute = null, TelemetryOptions? 
Telemetry = null, RuntimeCacheOptions? Cache = null, - RuntimeSemanticSearchOptions? SemanticSearch = null, PaginationOptions? Pagination = null, RuntimeHealthCheckConfig? Health = null, EmbeddingsOptions? Embeddings = null, @@ -45,7 +42,6 @@ public RuntimeOptions( this.BaseRoute = BaseRoute; this.Telemetry = Telemetry; this.Cache = Cache; - this.SemanticSearch = SemanticSearch; this.Pagination = Pagination; this.Health = Health; this.Embeddings = Embeddings; diff --git a/src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs b/src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs deleted file mode 100644 index 1c72667de2..0000000000 --- a/src/Config/ObjectModel/RuntimeSemanticSearchOptions.cs +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Text.Json.Serialization; - -namespace Azure.DataApiBuilder.Config.ObjectModel; - -/// -/// Runtime semantic search configuration. -/// -public record RuntimeSemanticSearchOptions -{ - /// - /// Endpoint used to generate embeddings from semantic-search input text. - /// - [JsonPropertyName("embedding-endpoint")] - public string? EmbeddingEndpoint { get; init; } = null; - - /// - /// Optional API key used as a bearer token for embedding endpoint calls. - /// - [JsonPropertyName("embedding-api-key")] - public string? EmbeddingApiKey { get; init; } = null; - - [JsonConstructor] - public RuntimeSemanticSearchOptions(string? EmbeddingEndpoint = null, string? 
EmbeddingApiKey = null) - { - this.EmbeddingEndpoint = EmbeddingEndpoint; - this.EmbeddingApiKey = EmbeddingApiKey; - } -} \ No newline at end of file diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs index 7d1265836b..bb210b1d9f 100644 --- a/src/Core/Configurations/RuntimeConfigValidator.cs +++ b/src/Core/Configurations/RuntimeConfigValidator.cs @@ -70,7 +70,7 @@ public class RuntimeConfigValidator : IConfigValidator "Semantic search requires runtime.cache.level-2.provider to be 'redis' and runtime.cache.level-2.connection-string to be configured."; private const string SEMANTIC_SEARCH_EMBEDDING_REQUIREMENT_ERR_MSG = - "Semantic search requires runtime.semantic-search.embedding-endpoint to be configured."; + "Semantic search requires runtime.embeddings to be configured and enabled."; private static readonly HashSet _reservedSemanticRestNames = [ @@ -1111,13 +1111,12 @@ private static bool HasValidSemanticRedisConfiguration(RuntimeConfig runtimeConf } /// - /// Validates embedding endpoint prerequisites for semantic-search. + /// Validates embeddings prerequisites for semantic-search. /// private static bool HasValidSemanticEmbeddingConfiguration(RuntimeConfig runtimeConfig) { - RuntimeSemanticSearchOptions? semanticSearch = runtimeConfig.Runtime?.SemanticSearch; - return semanticSearch is not null - && !string.IsNullOrWhiteSpace(semanticSearch.EmbeddingEndpoint); + EmbeddingsOptions? 
embeddings = runtimeConfig.Runtime?.Embeddings; + return embeddings is not null && embeddings.Enabled; } /// diff --git a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs index ee3cf2dbdf..6546f41821 100644 --- a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs +++ b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs @@ -3653,7 +3653,7 @@ public void ValidateSemanticSearchRequiresRedisIndexNameWhenEnabled() } [TestMethod] - public void ValidateSemanticSearchRequiresEmbeddingEndpointConfiguration() + public void ValidateSemanticSearchRequiresEmbeddingsConfiguration() { RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); @@ -3687,7 +3687,7 @@ public void ValidateSemanticSearchRequiresEmbeddingEndpointConfiguration() DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); Assert.AreEqual( - "Semantic search requires runtime.semantic-search.embedding-endpoint to be configured.", + "Semantic search requires runtime.embeddings to be configured and enabled.", ex.Message); } @@ -3722,7 +3722,11 @@ public void ValidateSemanticSearchRejectsReservedFieldName() { Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") }, - SemanticSearch: new RuntimeSemanticSearchOptions(EmbeddingEndpoint: "https://example.org/embed")), + Embeddings: new EmbeddingsOptions( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key", + Enabled: true)), Entities: new(new Dictionary { ["Product"] = entity })); DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); diff --git a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs index 0aeb309c61..8796e953bb 100644 --- a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs +++ 
b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs @@ -1,8 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System; using Azure.DataApiBuilder.Core.Parsers; +using Azure.DataApiBuilder.Config.DatabasePrimitives; +using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Core.Services; +using Azure.DataApiBuilder.Service.Exceptions; using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; namespace Azure.DataApiBuilder.Service.Tests.UnitTests { @@ -14,6 +20,9 @@ namespace Azure.DataApiBuilder.Service.Tests.UnitTests [TestClass] public class RequestParserUnitTests { + private const string DEFAULT_ENTITY = "Book"; + private const string DEFAULT_SCHEMA = "dbo"; + /// /// Tests that ExtractRawQueryParameter correctly extracts URL-encoded /// parameter values, preserving special characters like ampersand (&). @@ -76,5 +85,62 @@ public void ExtractRawQueryParameter_HandlesEdgeCases(string queryString, string Assert.AreEqual(expectedValue, result, $"Expected '{expectedValue}' but got '{result}' for parameter '{parameterName}' in query '{queryString}'"); } + + [TestMethod] + public void ParseQueryString_SetsSemanticInputs_ForUserRequest() + { + FindRequestContext context = new( + entityName: DEFAULT_ENTITY, + dbo: new DatabaseTable(DEFAULT_SCHEMA, DEFAULT_ENTITY), + isList: true); + + context.RawQueryString = "?$semantic_search=wireless%20headphones&$semantic_threshold=0.83"; + context.ParsedQueryString.Add(RequestParser.SEMANTIC_SEARCH_URL, "wireless headphones"); + context.ParsedQueryString.Add(RequestParser.SEMANTIC_THRESHOLD_URL, "0.83"); + + RequestParser.ParseQueryString(context, new Mock().Object); + + Assert.AreEqual("wireless headphones", context.SemanticSearch); + Assert.AreEqual(0.83, context.SemanticThreshold); + } + + [DataTestMethod] + [DataRow("-0.01")] + [DataRow("1.01")] + [DataRow("not-a-number")] + public void ParseQueryString_RejectsInvalidSemanticThreshold_ForUserRequest(string threshold) + 
{ + FindRequestContext context = new( + entityName: DEFAULT_ENTITY, + dbo: new DatabaseTable(DEFAULT_SCHEMA, DEFAULT_ENTITY), + isList: true); + + context.RawQueryString = $"?$semantic_threshold={threshold}"; + context.ParsedQueryString.Add(RequestParser.SEMANTIC_THRESHOLD_URL, threshold); + + DataApiBuilderException ex = Assert.ThrowsException( + () => RequestParser.ParseQueryString(context, new Mock().Object)); + + StringAssert.Contains(ex.Message, "semantic_threshold must be a decimal value between 0.0 and 1.0."); + } + + [DataTestMethod] + [DataRow("?$orderby=semantic_distance%20asc")] + [DataRow("?$orderby=Semantic_Distance%20desc")] + public void ParseQueryString_RejectsOrderBySemanticDistance_ForUserRequest(string rawQuery) + { + FindRequestContext context = new( + entityName: DEFAULT_ENTITY, + dbo: new DatabaseTable(DEFAULT_SCHEMA, DEFAULT_ENTITY), + isList: true); + + context.RawQueryString = rawQuery; + context.ParsedQueryString.Add(RequestParser.SORT_URL, rawQuery.Contains("desc", StringComparison.OrdinalIgnoreCase) ? "Semantic_Distance desc" : "semantic_distance asc"); + + DataApiBuilderException ex = Assert.ThrowsException( + () => RequestParser.ParseQueryString(context, new Mock().Object)); + + StringAssert.Contains(ex.Message, "semantic_distance cannot be used in orderBy."); + } } } diff --git a/src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs b/src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs new file mode 100644 index 0000000000..228be7113e --- /dev/null +++ b/src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using System.Collections.Generic; +using System.Threading; +using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Core.Services.Embeddings; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; + +#nullable enable + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests; + +/// +/// Integration tests for the semantic search TEXT-to-EMBEDDING flow. +/// Tests that customers provide TEXT input (not embeddings) to $semantic_search, +/// and DAB properly converts it through the embedding API (with cache checks) +/// then performs vector search and database queries. +/// +[TestClass] +public class SemanticSearchTextFlowTests +{ + /// + /// FLOW VALIDATION 1: Confirm TEXT input is accepted (not just embeddings) + /// Expected: DAB should accept plain text like "laptop" in $semantic_search + /// Current Implementation: RedisSemanticSearchService.GetEmbeddingAsync() tries TryParseVectorText + /// If that fails (plain text), it calls EmbeddingService.TryEmbedAsync(text) + /// + [TestMethod] + public void SemanticSearchFlow_ValidatesTextAcceptance() + { + // The implementation validates that: + // 1. USER provides TEXT: "laptop computers" (not [0.8, 0.1, 0.05, 0.05]) + // 2. DAB extracts this text from $semantic_search parameter + // 3. RedisSemanticSearchService receives the text + // 4. Calls GetEmbeddingAsync("laptop computers") + // 5. TryParseVectorText("laptop computers") → returns false (not JSON array) + // 6. 
Falls back to EmbeddingService.TryEmbedAsync("laptop computers") + + string textInput = "laptop computers"; + + // This validates that text input is NOT treated as embeddings + Assert.IsFalse(textInput.StartsWith("["), "User input should not start with ["); + Assert.IsTrue(textInput.Contains("laptop"), "Should contain user text"); + } + + /// + /// FLOW VALIDATION 2: Confirm EmbedAsync is called for TEXT input + /// Expected: When GetEmbeddingAsync receives TEXT, it calls EmbedAsync (not direct vector use) + /// Current Implementation: RedisSemanticSearchService.GetEmbeddingAsync() line 131 calls + /// _embeddingService.TryEmbedAsync(semanticSearchValue) + /// + [TestMethod] + public void SemanticSearchFlow_CallsEmbedAsyncForText() + { + // Implementation confirms: + // In RedisSemanticSearchService.GetEmbeddingAsync(): + // 1. Line 126: TryParseVectorText(semanticSearchValue, out float[]? parsedVector) + // 2. Line 127: if succeeds, return parsedVector (shortcut for direct vectors) + // 3. Line 131: else, call _embeddingService.TryEmbedAsync(semanticSearchValue) + // 4. 
Line 132-134: Get result from embedding service + + // This means: + // TEXT input → TryEmbedAsync → embedding API + // VECTOR input (JSON array) → direct use (shortcut) + + var mockEmbeddingService = new Mock(); + float[] expectedEmbedding = [0.8f, 0.1f, 0.05f, 0.05f]; + + mockEmbeddingService + .Setup(s => s.TryEmbedAsync("laptop", It.IsAny())) + .ReturnsAsync(new EmbeddingResult(Success: true, Embedding: expectedEmbedding, ErrorMessage: null)); + + // This validates the concept - EmbedAsync is set up to be called for TEXT + Assert.IsNotNull(mockEmbeddingService); + Assert.AreEqual(4, expectedEmbedding.Length); + } + + /// + /// FLOW VALIDATION 3: Confirm EmbedAsync checks caches (L1 & L2) + /// Expected: EmbeddingService.EmbedAsync() checks L1 cache, then L2 Redis cache, + /// then calls API only if not in either cache + /// Current Implementation: EmbeddingService.EmbedWithCacheInfoAsync() (line 433) + /// calls _cache.GetOrSetAsync() which uses FusionCache with L1 & L2 support + /// + [TestMethod] + public void SemanticSearchFlow_EmbedAsyncChecksCaches() + { + // Implementation confirms: + // In EmbeddingService.EmbedWithCacheInfoAsync() (line 433): + // 1. Line 436: string cacheKey = CreateCacheKey(text) + // - Creates SHA256 hash of text to ensure deterministic cache key + // - Format: "embedding:{provider}:{model}:{sha256_hash}" + // 2. Line 438: _cache.GetOrSetAsync(key: cacheKey, async (ctx, ct) => { ... }) + // - FusionCache.GetOrSetAsync checks L1 (in-memory) cache first + // - If L1 miss and L2 is configured, checks L2 (Redis) + // - If both miss, executes the lambda to call embedding API + // 3. Line 449: _cache.GetOrSetAsync() returns cached or newly computed result + + // Cache flow: + // First call with "laptop": Check L1 → miss → Check L2 → miss → Call API → Cache in L1 & L2 + // Second call with "laptop": Check L1 → hit → return cached embedding (no API call!) 
+ + const string textInput = "high performance laptop"; + + // CreateCacheKey format verification + // The cache key will be: "embedding:{provider}:{model}:{sha256_hash_of_text}" + // This ensures same text always produces same cache key + + Assert.IsFalse(string.IsNullOrEmpty(textInput), "Text input should not be empty"); + } + + /// + /// FLOW VALIDATION 4: Confirm Redis vector search uses embeddings + /// Expected: After getting embeddings from EmbedAsync, use them for Redis FT.SEARCH + /// Current Implementation: RedisSemanticSearchService.GetCandidatesAsync() line 74 + /// gets embedding, then line 81 converts to bytes, then line 83 executes FT.SEARCH + /// + [TestMethod] + public void SemanticSearchFlow_UsesEmbeddingsForRedisVectorSearch() + { + // Implementation flow: + // 1. Line 74: float[] embedding = await GetEmbeddingAsync(semanticSearchValue) + // - Takes TEXT input "laptop" + // - Returns embedding vector [0.8, 0.1, 0.05, 0.05] + // 2. Line 81: byte[] vectorBytes = ToRedisVectorBytes(embedding) + // - Converts float array to bytes for Redis + // 3. Line 83-92: db.ExecuteAsync("FT.SEARCH", ...) 
+ // - Executes Redis FT.SEARCH with KNN (k-nearest neighbors) + // - Uses the embedding vector for similarity matching + // - Returns matching documents with similarity scores + + float[] embedding = [0.8f, 0.1f, 0.05f, 0.05f]; + + // Validate byte conversion + byte[] vectorBytes = new byte[embedding.Length * sizeof(float)]; + System.Buffer.BlockCopy(embedding, 0, vectorBytes, 0, vectorBytes.Length); + + Assert.AreEqual(16, vectorBytes.Length); // 4 floats * 4 bytes each = 16 bytes + } + + /// + /// FLOW VALIDATION 5: Confirm Redis results are converted to SemanticSearchCandidate + /// Expected: Redis returns docs, ParseCandidates extracts primary keys and column values + /// Current Implementation: RedisSemanticSearchService.ParseCandidates() (line 270) + /// + [TestMethod] + public void SemanticSearchFlow_ParsesCandidatesFromRedisResults() + { + // Implementation confirms: + // In RedisSemanticSearchService.ParseCandidates() (line 270): + // Iterates through Redis results and creates SemanticSearchCandidate objects + // + // SemanticSearchCandidate record (line 10): + // public record SemanticSearchCandidate( + // IReadOnlyDictionary PrimaryKeyValues, // {id: 1} + // IReadOnlyDictionary ColumnValues, // {name: "Laptop"} + // double Distance); // 0.92 + + var mockCandidate = new SemanticSearchCandidate( + new Dictionary { { "id", 1 } }.AsReadOnly(), + new Dictionary { { "name", "Laptop" } }.AsReadOnly(), + 0.92 + ); + + Assert.IsNotNull(mockCandidate.PrimaryKeyValues); + Assert.IsNotNull(mockCandidate.ColumnValues); + Assert.AreEqual(0.92, mockCandidate.Distance); + Assert.IsTrue(mockCandidate.PrimaryKeyValues.ContainsKey("id")); + Assert.IsTrue(mockCandidate.ColumnValues.ContainsKey("name")); + } + + /// + /// FLOW VALIDATION 6: Confirm database query uses semantic candidates + /// Expected: SqlQueryEngine.ApplySemanticCandidates() adds WHERE clause with primary keys + /// Current Implementation: SqlQueryEngine.ApplySemanticCandidates() applies candidates + /// 
+ [TestMethod] + public void SemanticSearchFlow_DatabaseQueryUsesCandidates() + { + // Implementation flow in SqlQueryEngine (line 337): + // 1. structure.ApplySemanticCandidates(narrowedCandidates) + // - Adds SemanticSearchCandidate objects to SqlQueryStructure + // 2. This modifies the query to add WHERE clause: + // SELECT id, name, description, semantic_distance + // FROM Products + // WHERE id IN (1, 2, 3, ...) // primary keys from Redis results + // 3. Database is queried only for rows that matched semantic search + // 4. Result rows are enriched with semantic_distance from stored distances + + var candidate1 = new SemanticSearchCandidate( + new Dictionary { { "id", 1 } }.AsReadOnly(), + new Dictionary { { "name", "Laptop" } }.AsReadOnly(), + 0.92 + ); + + var candidate2 = new SemanticSearchCandidate( + new Dictionary { { "id", 2 } }.AsReadOnly(), + new Dictionary { { "name", "Desktop" } }.AsReadOnly(), + 0.85 + ); + + List candidates = [candidate1, candidate2]; + + Assert.AreEqual(2, candidates.Count); + Assert.AreEqual(1, candidate1.PrimaryKeyValues["id"]); + Assert.AreEqual(2, candidate2.PrimaryKeyValues["id"]); + } + + /// + /// COMPLETE FLOW SUMMARY: TEXT → EMBEDDING → REDIS → DB → RESPONSE + /// + /// Request: GET /api/Products?$semantic_search=laptop&$semantic_threshold=0.7 + /// + /// Step 1: RequestParser.ParseQueryString() + /// - Extracts $semantic_search="laptop" (TEXT, not embeddings) + /// - Extracts $semantic_threshold=0.7 + /// - Stores in FindRequestContext + /// + /// Step 2: RestService.ExecuteAsync(FindRequestContext) + /// - Calls SqlQueryEngine.TryPopulateSemanticSearchInformation() + /// + /// Step 3: SqlQueryEngine.TryPopulateSemanticSearchInformation() + /// - Calls RedisSemanticSearchService.GetCandidatesAsync("laptop", 0.7) + /// + /// Step 4: RedisSemanticSearchService.GetCandidatesAsync() + /// - Calls GetEmbeddingAsync("laptop") + /// - TryParseVectorText("laptop") → fails (plain text) + /// - Calls 
EmbeddingService.TryEmbedAsync("laptop") + /// - EmbeddingService checks L1 cache → miss + /// - EmbeddingService checks L2 Redis cache → miss + /// - EmbeddingService calls embedding API → [0.8, 0.1, 0.05, 0.05] + /// - Stores in L1 cache + /// - Stores in L2 Redis cache + /// - Returns embedding to GetCandidatesAsync + /// - Converts embedding to bytes + /// - Executes Redis FT.SEARCH with KNN vector search + /// - Returns SemanticSearchCandidate objects + /// + /// Step 5: SqlQueryEngine.ApplySemanticCandidates() + /// - Adds candidates to query structure + /// - Modifies query to WHERE id IN (primary_keys_from_redis) + /// - Stores semantic_distance for each candidate + /// + /// Step 6: Database Query Execution + /// - SELECT id, name, description FROM Products WHERE id IN (...) + /// - Returns rows that matched semantic search + /// + /// Step 7: Response Enrichment + /// - Adds semantic_distance field to each row from stored distances + /// - Optionally orders by semantic_distance + /// + /// Response: [ + /// { "id": 1, "name": "Laptop", "description": "...", "semantic_distance": 0.92 } + /// ] + /// + [TestMethod] + public void SemanticSearchFlow_CompleteEndToEndFlow() + { + // This test validates the CONCEPT of the complete flow + // The actual integration test runs the full flow with real Redis + DB + + string userTextInput = "laptop computers"; + double threshold = 0.7; + + // Simulate the flow steps + Assert.IsFalse(userTextInput.StartsWith("["), "TEXT input (not embeddings)"); + Assert.AreEqual("laptop computers", userTextInput, "TEXT is preserved"); + Assert.AreEqual(0.7, threshold, "Threshold is preserved"); + + // After embedding, vector search, DB query, response building: + // Response would contain results with semantic_distance field + } +} + diff --git a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs index 0b67657187..ffd30b68f3 100644 --- 
a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs +++ b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs @@ -4,20 +4,17 @@ using System; using System.Collections.Generic; using System.Globalization; -using System.Net.Http; -using System.Net.Http.Headers; -using System.Text; using System.Text.Json; using System.Threading.Tasks; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Core.Services.Embeddings; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Service.Exceptions; using Azure.Identity; -using Microsoft.Azure.StackExchangeRedis; using Microsoft.Extensions.Logging; using StackExchange.Redis; @@ -36,18 +33,18 @@ public sealed class RedisSemanticSearchService : ISemanticSearchService private readonly RuntimeConfigProvider _runtimeConfigProvider; private readonly IMetadataProviderFactory _metadataProviderFactory; - private readonly IHttpClientFactory _httpClientFactory; + private readonly IEmbeddingService? _embeddingService; private readonly ILogger _logger; public RedisSemanticSearchService( RuntimeConfigProvider runtimeConfigProvider, IMetadataProviderFactory metadataProviderFactory, - IHttpClientFactory httpClientFactory, + IEmbeddingService? 
embeddingService, ILogger logger) { _runtimeConfigProvider = runtimeConfigProvider; _metadataProviderFactory = metadataProviderFactory; - _httpClientFactory = httpClientFactory; + _embeddingService = embeddingService; _logger = logger; } @@ -69,7 +66,7 @@ public async Task> GetCandidatesAsync( return []; } - float[] embedding = await GetEmbeddingAsync(runtimeConfig, semanticSearchValue); + float[] embedding = await GetEmbeddingAsync(semanticSearchValue); if (embedding.Length == 0) { return []; @@ -122,40 +119,25 @@ private static byte[] ToRedisVectorBytes(float[] vector) return bytes; } - private async Task GetEmbeddingAsync(RuntimeConfig runtimeConfig, string semanticSearchValue) + private async Task GetEmbeddingAsync(string semanticSearchValue) { if (TryParseVectorText(semanticSearchValue, out float[]? parsedVector) && parsedVector is not null) { return parsedVector; } - string? endpoint = runtimeConfig.Runtime?.SemanticSearch?.EmbeddingEndpoint; - if (string.IsNullOrWhiteSpace(endpoint)) + if (_embeddingService is null || !_embeddingService.IsEnabled) { return []; } - HttpClient client = _httpClientFactory.CreateClient(); - string? apiKey = runtimeConfig.Runtime?.SemanticSearch?.EmbeddingApiKey; - if (!string.IsNullOrWhiteSpace(apiKey)) - { - client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); - } - - string payload = JsonSerializer.Serialize(new Dictionary { { "input", semanticSearchValue } }); - using HttpRequestMessage request = new(HttpMethod.Post, endpoint) - { - Content = new StringContent(payload, Encoding.UTF8, "application/json") - }; - - using HttpResponseMessage response = await client.SendAsync(request); - if (!response.IsSuccessStatusCode) + EmbeddingResult result = await _embeddingService.TryEmbedAsync(semanticSearchValue); + if (!result.Success || result.Embedding is null) { return []; } - string body = await response.Content.ReadAsStringAsync(); - return TryExtractEmbedding(body, out float[]? 
embedding) && embedding is not null ? embedding : []; + return result.Embedding; } private static bool TryParseVectorText(string text, out float[]? vector) @@ -190,73 +172,6 @@ private static bool TryParseVectorText(string text, out float[]? vector) } } - private static bool TryExtractEmbedding(string responseBody, out float[]? embedding) - { - embedding = null; - - try - { - using JsonDocument json = JsonDocument.Parse(responseBody); - - if (json.RootElement.ValueKind is JsonValueKind.Array) - { - return TryReadEmbeddingArray(json.RootElement, out embedding); - } - - if (json.RootElement.ValueKind is JsonValueKind.Object) - { - if (json.RootElement.TryGetProperty("embedding", out JsonElement directEmbedding) - && TryReadEmbeddingArray(directEmbedding, out embedding)) - { - return true; - } - - if (json.RootElement.TryGetProperty("data", out JsonElement data) - && data.ValueKind is JsonValueKind.Array - && data.GetArrayLength() > 0) - { - JsonElement first = data[0]; - if (first.ValueKind is JsonValueKind.Object - && first.TryGetProperty("embedding", out JsonElement nestedEmbedding) - && TryReadEmbeddingArray(nestedEmbedding, out embedding)) - { - return true; - } - } - } - - return false; - } - catch - { - return false; - } - } - - private static bool TryReadEmbeddingArray(JsonElement array, out float[]? 
embedding) - { - embedding = null; - - if (array.ValueKind is not JsonValueKind.Array) - { - return false; - } - - List values = []; - foreach (JsonElement item in array.EnumerateArray()) - { - if (!item.TryGetSingle(out float current)) - { - return false; - } - - values.Add(current); - } - - embedding = values.ToArray(); - return embedding.Length > 0; - } - private static async Task CreateConnectionMultiplexerAsync(string connectionString) { ConfigurationOptions options = ConfigurationOptions.Parse(connectionString); From 3cbca26066fcd57bda385ffc243ab072f8b58732 Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Thu, 11 Jun 2026 04:35:39 -0700 Subject: [PATCH 07/12] Fix semantic search flow tests dictionary readonly usage --- .../UnitTests/SemanticSearchTextFlowTests.cs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs b/src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs index 228be7113e..5aae7326ea 100644 --- a/src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs +++ b/src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs @@ -159,8 +159,8 @@ public void SemanticSearchFlow_ParsesCandidatesFromRedisResults() // double Distance); // 0.92 var mockCandidate = new SemanticSearchCandidate( - new Dictionary { { "id", 1 } }.AsReadOnly(), - new Dictionary { { "name", "Laptop" } }.AsReadOnly(), + new Dictionary { { "id", 1 } }, + new Dictionary { { "name", "Laptop" } }, 0.92 ); @@ -190,14 +190,14 @@ public void SemanticSearchFlow_DatabaseQueryUsesCandidates() // 4. 
Result rows are enriched with semantic_distance from stored distances var candidate1 = new SemanticSearchCandidate( - new Dictionary { { "id", 1 } }.AsReadOnly(), - new Dictionary { { "name", "Laptop" } }.AsReadOnly(), + new Dictionary { { "id", 1 } }, + new Dictionary { { "name", "Laptop" } }, 0.92 ); var candidate2 = new SemanticSearchCandidate( - new Dictionary { { "id", 2 } }.AsReadOnly(), - new Dictionary { { "name", "Desktop" } }.AsReadOnly(), + new Dictionary { { "id", 2 } }, + new Dictionary { { "name", "Desktop" } }, 0.85 ); From 601d837739c751edba0641e794798ae06fa6f406 Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Thu, 11 Jun 2026 06:42:19 -0700 Subject: [PATCH 08/12] Fix test failure --- .../Services/SemanticSearch/RedisSemanticSearchService.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs index ffd30b68f3..bc02aeb259 100644 --- a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs +++ b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Globalization; +using System.Linq; using System.Text.Json; using System.Threading.Tasks; using Azure.DataApiBuilder.Config.DatabasePrimitives; @@ -39,12 +40,12 @@ public sealed class RedisSemanticSearchService : ISemanticSearchService public RedisSemanticSearchService( RuntimeConfigProvider runtimeConfigProvider, IMetadataProviderFactory metadataProviderFactory, - IEmbeddingService? 
embeddingService, + IEnumerable embeddingServices, ILogger logger) { _runtimeConfigProvider = runtimeConfigProvider; _metadataProviderFactory = metadataProviderFactory; - _embeddingService = embeddingService; + _embeddingService = embeddingServices.FirstOrDefault(); _logger = logger; } From f3521330b1125176718499676962212548a8808b Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Thu, 11 Jun 2026 07:10:44 -0700 Subject: [PATCH 09/12] Fix PR comments Part 1 --- src/Core/Models/SemanticSearchCandidate.cs | 2 +- src/Core/Parsers/RequestParser.cs | 24 +- src/Core/Resolvers/SqlQueryEngine.cs | 6 +- src/Core/Services/RestService.cs | 6 - .../UnitTests/RequestParserUnitTests.cs | 17 ++ .../UnitTests/SemanticSearchTextFlowTests.cs | 276 ------------------ .../RedisSemanticSearchService.cs | 47 ++- 7 files changed, 89 insertions(+), 289 deletions(-) delete mode 100644 src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs diff --git a/src/Core/Models/SemanticSearchCandidate.cs b/src/Core/Models/SemanticSearchCandidate.cs index 65251ffcfd..0c8497364d 100644 --- a/src/Core/Models/SemanticSearchCandidate.cs +++ b/src/Core/Models/SemanticSearchCandidate.cs @@ -10,4 +10,4 @@ namespace Azure.DataApiBuilder.Core.Models; public record SemanticSearchCandidate( IReadOnlyDictionary PrimaryKeyValues, IReadOnlyDictionary ColumnValues, - double Distance); + double Similarity); diff --git a/src/Core/Parsers/RequestParser.cs b/src/Core/Parsers/RequestParser.cs index 90d3eb3c58..0870077dc9 100644 --- a/src/Core/Parsers/RequestParser.cs +++ b/src/Core/Parsers/RequestParser.cs @@ -147,7 +147,7 @@ public static void ParseQueryString(RestRequestContext context, ISqlMetadataProv subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); } - if (rawSortValue.Contains(SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparison.OrdinalIgnoreCase)) + if (ContainsSemanticDistanceOrderByToken(rawSortValue)) { throw new DataApiBuilderException( message: "semantic_distance cannot be used in 
orderBy.", @@ -188,6 +188,28 @@ public static void ParseQueryString(RestRequestContext context, ISqlMetadataProv } } + private static bool ContainsSemanticDistanceOrderByToken(string rawSortValue) + { + string decoded = Uri.UnescapeDataString(rawSortValue); + + foreach (string part in decoded.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)) + { + string[] tokens = part.Split(' ', StringSplitOptions.RemoveEmptyEntries); + if (tokens.Length == 0) + { + continue; + } + + string columnToken = tokens[0].Trim('\'', '"'); + if (string.Equals(columnToken, SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + } + + return false; + } + /// /// Create List of OrderByColumn from an OrderByClause Abstract Syntax Tree /// and return that list as List since OrderByColumn is a Column. diff --git a/src/Core/Resolvers/SqlQueryEngine.cs b/src/Core/Resolvers/SqlQueryEngine.cs index 12d804328f..47b7daaaa1 100644 --- a/src/Core/Resolvers/SqlQueryEngine.cs +++ b/src/Core/Resolvers/SqlQueryEngine.cs @@ -319,12 +319,12 @@ await ExecuteAsync(structure, dataSourceName, isMultipleCreateOperation: true), return (true, emptyResponse); } - // De-duplicate candidate keys while keeping the highest distance. + // De-duplicate candidate keys while keeping the highest similarity. Dictionary deduped = new(StringComparer.Ordinal); foreach (SemanticSearchCandidate candidate in candidates) { string signature = BuildPrimaryKeySignatureFromValues(primaryKeyColumns, candidate.PrimaryKeyValues); - if (!deduped.TryGetValue(signature, out SemanticSearchCandidate? existing) || candidate.Distance > existing.Distance) + if (!deduped.TryGetValue(signature, out SemanticSearchCandidate? 
existing) || candidate.Similarity > existing.Similarity) { deduped[signature] = candidate; } @@ -334,7 +334,7 @@ await ExecuteAsync(structure, dataSourceName, isMultipleCreateOperation: true), structure.ApplySemanticCandidates(narrowedCandidates); foreach (KeyValuePair kvp in deduped) { - structure.SemanticDistanceByPrimaryKeySignature[kvp.Key] = kvp.Value.Distance; + structure.SemanticDistanceByPrimaryKeySignature[kvp.Key] = kvp.Value.Similarity; } return (false, null); diff --git a/src/Core/Services/RestService.cs b/src/Core/Services/RestService.cs index 5c6705f0d5..c70d9b619a 100644 --- a/src/Core/Services/RestService.cs +++ b/src/Core/Services/RestService.cs @@ -236,12 +236,6 @@ private async Task DispatchQuery(RestRequestContext context, Data if (context is FindRequestContext findRequestContext) { - if (findRequestContext.IncludeSemanticDistanceInResponse - && !findRequestContext.FieldsToBeReturned.Contains(SemanticSearchConstants.REST_DISTANCE_FIELD)) - { - findRequestContext.FieldsToBeReturned.Add(SemanticSearchConstants.REST_DISTANCE_FIELD); - } - using JsonDocument? restApiResponse = await queryEngine.ExecuteAsync(findRequestContext); return restApiResponse is null ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), findRequestContext, _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName), _runtimeConfigProvider.GetConfig(), GetHttpContext()) : SqlResponseHelpers.FormatFindResult(restApiResponse.RootElement.Clone(), findRequestContext, _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName), _runtimeConfigProvider.GetConfig(), GetHttpContext()); diff --git a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs index 8796e953bb..5508939d19 100644 --- a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs +++ b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
using System; +using System.Reflection; using Azure.DataApiBuilder.Core.Parsers; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Core.Models; @@ -142,5 +143,21 @@ public void ParseQueryString_RejectsOrderBySemanticDistance_ForUserRequest(strin StringAssert.Contains(ex.Message, "semantic_distance cannot be used in orderBy."); } + + [DataTestMethod] + [DataRow("semantic_distance asc", true)] + [DataRow("Semantic_Distance desc", true)] + [DataRow("semantic_distance_score asc", false)] + [DataRow("title asc,semantic_distance_score desc", false)] + public void ContainsSemanticDistanceOrderByToken_UsesExactColumnMatch(string rawSortValue, bool expected) + { + MethodInfo method = typeof(RequestParser).GetMethod( + "ContainsSemanticDistanceOrderByToken", + BindingFlags.NonPublic | BindingFlags.Static); + + Assert.IsNotNull(method, "Expected private helper method to exist."); + bool actual = (bool)method!.Invoke(null, new object[] { rawSortValue })!; + Assert.AreEqual(expected, actual); + } } } diff --git a/src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs b/src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs deleted file mode 100644 index 5aae7326ea..0000000000 --- a/src/Service.Tests/UnitTests/SemanticSearchTextFlowTests.cs +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Collections.Generic; -using System.Threading; -using Azure.DataApiBuilder.Core.Models; -using Azure.DataApiBuilder.Core.Services.Embeddings; -using Microsoft.VisualStudio.TestTools.UnitTesting; -using Moq; - -#nullable enable - -namespace Azure.DataApiBuilder.Service.Tests.UnitTests; - -/// -/// Integration tests for the semantic search TEXT-to-EMBEDDING flow. -/// Tests that customers provide TEXT input (not embeddings) to $semantic_search, -/// and DAB properly converts it through the embedding API (with cache checks) -/// then performs vector search and database queries. 
-/// -[TestClass] -public class SemanticSearchTextFlowTests -{ - /// - /// FLOW VALIDATION 1: Confirm TEXT input is accepted (not just embeddings) - /// Expected: DAB should accept plain text like "laptop" in $semantic_search - /// Current Implementation: RedisSemanticSearchService.GetEmbeddingAsync() tries TryParseVectorText - /// If that fails (plain text), it calls EmbeddingService.TryEmbedAsync(text) - /// - [TestMethod] - public void SemanticSearchFlow_ValidatesTextAcceptance() - { - // The implementation validates that: - // 1. USER provides TEXT: "laptop computers" (not [0.8, 0.1, 0.05, 0.05]) - // 2. DAB extracts this text from $semantic_search parameter - // 3. RedisSemanticSearchService receives the text - // 4. Calls GetEmbeddingAsync("laptop computers") - // 5. TryParseVectorText("laptop computers") → returns false (not JSON array) - // 6. Falls back to EmbeddingService.TryEmbedAsync("laptop computers") - - string textInput = "laptop computers"; - - // This validates that text input is NOT treated as embeddings - Assert.IsFalse(textInput.StartsWith("["), "User input should not start with ["); - Assert.IsTrue(textInput.Contains("laptop"), "Should contain user text"); - } - - /// - /// FLOW VALIDATION 2: Confirm EmbedAsync is called for TEXT input - /// Expected: When GetEmbeddingAsync receives TEXT, it calls EmbedAsync (not direct vector use) - /// Current Implementation: RedisSemanticSearchService.GetEmbeddingAsync() line 131 calls - /// _embeddingService.TryEmbedAsync(semanticSearchValue) - /// - [TestMethod] - public void SemanticSearchFlow_CallsEmbedAsyncForText() - { - // Implementation confirms: - // In RedisSemanticSearchService.GetEmbeddingAsync(): - // 1. Line 126: TryParseVectorText(semanticSearchValue, out float[]? parsedVector) - // 2. Line 127: if succeeds, return parsedVector (shortcut for direct vectors) - // 3. Line 131: else, call _embeddingService.TryEmbedAsync(semanticSearchValue) - // 4. 
Line 132-134: Get result from embedding service - - // This means: - // TEXT input → TryEmbedAsync → embedding API - // VECTOR input (JSON array) → direct use (shortcut) - - var mockEmbeddingService = new Mock(); - float[] expectedEmbedding = [0.8f, 0.1f, 0.05f, 0.05f]; - - mockEmbeddingService - .Setup(s => s.TryEmbedAsync("laptop", It.IsAny())) - .ReturnsAsync(new EmbeddingResult(Success: true, Embedding: expectedEmbedding, ErrorMessage: null)); - - // This validates the concept - EmbedAsync is set up to be called for TEXT - Assert.IsNotNull(mockEmbeddingService); - Assert.AreEqual(4, expectedEmbedding.Length); - } - - /// - /// FLOW VALIDATION 3: Confirm EmbedAsync checks caches (L1 & L2) - /// Expected: EmbeddingService.EmbedAsync() checks L1 cache, then L2 Redis cache, - /// then calls API only if not in either cache - /// Current Implementation: EmbeddingService.EmbedWithCacheInfoAsync() (line 433) - /// calls _cache.GetOrSetAsync() which uses FusionCache with L1 & L2 support - /// - [TestMethod] - public void SemanticSearchFlow_EmbedAsyncChecksCaches() - { - // Implementation confirms: - // In EmbeddingService.EmbedWithCacheInfoAsync() (line 433): - // 1. Line 436: string cacheKey = CreateCacheKey(text) - // - Creates SHA256 hash of text to ensure deterministic cache key - // - Format: "embedding:{provider}:{model}:{sha256_hash}" - // 2. Line 438: _cache.GetOrSetAsync(key: cacheKey, async (ctx, ct) => { ... }) - // - FusionCache.GetOrSetAsync checks L1 (in-memory) cache first - // - If L1 miss and L2 is configured, checks L2 (Redis) - // - If both miss, executes the lambda to call embedding API - // 3. Line 449: _cache.GetOrSetAsync() returns cached or newly computed result - - // Cache flow: - // First call with "laptop": Check L1 → miss → Check L2 → miss → Call API → Cache in L1 & L2 - // Second call with "laptop": Check L1 → hit → return cached embedding (no API call!) 
- - const string textInput = "high performance laptop"; - - // CreateCacheKey format verification - // The cache key will be: "embedding:{provider}:{model}:{sha256_hash_of_text}" - // This ensures same text always produces same cache key - - Assert.IsFalse(string.IsNullOrEmpty(textInput), "Text input should not be empty"); - } - - /// - /// FLOW VALIDATION 4: Confirm Redis vector search uses embeddings - /// Expected: After getting embeddings from EmbedAsync, use them for Redis FT.SEARCH - /// Current Implementation: RedisSemanticSearchService.GetCandidatesAsync() line 74 - /// gets embedding, then line 81 converts to bytes, then line 83 executes FT.SEARCH - /// - [TestMethod] - public void SemanticSearchFlow_UsesEmbeddingsForRedisVectorSearch() - { - // Implementation flow: - // 1. Line 74: float[] embedding = await GetEmbeddingAsync(semanticSearchValue) - // - Takes TEXT input "laptop" - // - Returns embedding vector [0.8, 0.1, 0.05, 0.05] - // 2. Line 81: byte[] vectorBytes = ToRedisVectorBytes(embedding) - // - Converts float array to bytes for Redis - // 3. Line 83-92: db.ExecuteAsync("FT.SEARCH", ...) 
- // - Executes Redis FT.SEARCH with KNN (k-nearest neighbors) - // - Uses the embedding vector for similarity matching - // - Returns matching documents with similarity scores - - float[] embedding = [0.8f, 0.1f, 0.05f, 0.05f]; - - // Validate byte conversion - byte[] vectorBytes = new byte[embedding.Length * sizeof(float)]; - System.Buffer.BlockCopy(embedding, 0, vectorBytes, 0, vectorBytes.Length); - - Assert.AreEqual(16, vectorBytes.Length); // 4 floats * 4 bytes each = 16 bytes - } - - /// - /// FLOW VALIDATION 5: Confirm Redis results are converted to SemanticSearchCandidate - /// Expected: Redis returns docs, ParseCandidates extracts primary keys and column values - /// Current Implementation: RedisSemanticSearchService.ParseCandidates() (line 270) - /// - [TestMethod] - public void SemanticSearchFlow_ParsesCandidatesFromRedisResults() - { - // Implementation confirms: - // In RedisSemanticSearchService.ParseCandidates() (line 270): - // Iterates through Redis results and creates SemanticSearchCandidate objects - // - // SemanticSearchCandidate record (line 10): - // public record SemanticSearchCandidate( - // IReadOnlyDictionary PrimaryKeyValues, // {id: 1} - // IReadOnlyDictionary ColumnValues, // {name: "Laptop"} - // double Distance); // 0.92 - - var mockCandidate = new SemanticSearchCandidate( - new Dictionary { { "id", 1 } }, - new Dictionary { { "name", "Laptop" } }, - 0.92 - ); - - Assert.IsNotNull(mockCandidate.PrimaryKeyValues); - Assert.IsNotNull(mockCandidate.ColumnValues); - Assert.AreEqual(0.92, mockCandidate.Distance); - Assert.IsTrue(mockCandidate.PrimaryKeyValues.ContainsKey("id")); - Assert.IsTrue(mockCandidate.ColumnValues.ContainsKey("name")); - } - - /// - /// FLOW VALIDATION 6: Confirm database query uses semantic candidates - /// Expected: SqlQueryEngine.ApplySemanticCandidates() adds WHERE clause with primary keys - /// Current Implementation: SqlQueryEngine.ApplySemanticCandidates() applies candidates - /// - [TestMethod] - public 
void SemanticSearchFlow_DatabaseQueryUsesCandidates() - { - // Implementation flow in SqlQueryEngine (line 337): - // 1. structure.ApplySemanticCandidates(narrowedCandidates) - // - Adds SemanticSearchCandidate objects to SqlQueryStructure - // 2. This modifies the query to add WHERE clause: - // SELECT id, name, description, semantic_distance - // FROM Products - // WHERE id IN (1, 2, 3, ...) // primary keys from Redis results - // 3. Database is queried only for rows that matched semantic search - // 4. Result rows are enriched with semantic_distance from stored distances - - var candidate1 = new SemanticSearchCandidate( - new Dictionary { { "id", 1 } }, - new Dictionary { { "name", "Laptop" } }, - 0.92 - ); - - var candidate2 = new SemanticSearchCandidate( - new Dictionary { { "id", 2 } }, - new Dictionary { { "name", "Desktop" } }, - 0.85 - ); - - List candidates = [candidate1, candidate2]; - - Assert.AreEqual(2, candidates.Count); - Assert.AreEqual(1, candidate1.PrimaryKeyValues["id"]); - Assert.AreEqual(2, candidate2.PrimaryKeyValues["id"]); - } - - /// - /// COMPLETE FLOW SUMMARY: TEXT → EMBEDDING → REDIS → DB → RESPONSE - /// - /// Request: GET /api/Products?$semantic_search=laptop&$semantic_threshold=0.7 - /// - /// Step 1: RequestParser.ParseQueryString() - /// - Extracts $semantic_search="laptop" (TEXT, not embeddings) - /// - Extracts $semantic_threshold=0.7 - /// - Stores in FindRequestContext - /// - /// Step 2: RestService.ExecuteAsync(FindRequestContext) - /// - Calls SqlQueryEngine.TryPopulateSemanticSearchInformation() - /// - /// Step 3: SqlQueryEngine.TryPopulateSemanticSearchInformation() - /// - Calls RedisSemanticSearchService.GetCandidatesAsync("laptop", 0.7) - /// - /// Step 4: RedisSemanticSearchService.GetCandidatesAsync() - /// - Calls GetEmbeddingAsync("laptop") - /// - TryParseVectorText("laptop") → fails (plain text) - /// - Calls EmbeddingService.TryEmbedAsync("laptop") - /// - EmbeddingService checks L1 cache → miss - /// - 
EmbeddingService checks L2 Redis cache → miss - /// - EmbeddingService calls embedding API → [0.8, 0.1, 0.05, 0.05] - /// - Stores in L1 cache - /// - Stores in L2 Redis cache - /// - Returns embedding to GetCandidatesAsync - /// - Converts embedding to bytes - /// - Executes Redis FT.SEARCH with KNN vector search - /// - Returns SemanticSearchCandidate objects - /// - /// Step 5: SqlQueryEngine.ApplySemanticCandidates() - /// - Adds candidates to query structure - /// - Modifies query to WHERE id IN (primary_keys_from_redis) - /// - Stores semantic_distance for each candidate - /// - /// Step 6: Database Query Execution - /// - SELECT id, name, description FROM Products WHERE id IN (...) - /// - Returns rows that matched semantic search - /// - /// Step 7: Response Enrichment - /// - Adds semantic_distance field to each row from stored distances - /// - Optionally orders by semantic_distance - /// - /// Response: [ - /// { "id": 1, "name": "Laptop", "description": "...", "semantic_distance": 0.92 } - /// ] - /// - [TestMethod] - public void SemanticSearchFlow_CompleteEndToEndFlow() - { - // This test validates the CONCEPT of the complete flow - // The actual integration test runs the full flow with real Redis + DB - - string userTextInput = "laptop computers"; - double threshold = 0.7; - - // Simulate the flow steps - Assert.IsFalse(userTextInput.StartsWith("["), "TEXT input (not embeddings)"); - Assert.AreEqual("laptop computers", userTextInput, "TEXT is preserved"); - Assert.AreEqual(0.7, threshold, "Threshold is preserved"); - - // After embedding, vector search, DB query, response building: - // Response would contain results with semantic_distance field - } -} - diff --git a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs index bc02aeb259..f3ac8efd97 100644 --- a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs +++ 
b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs @@ -6,6 +6,7 @@ using System.Globalization; using System.Linq; using System.Text.Json; +using System.Threading; using System.Threading.Tasks; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; @@ -27,7 +28,7 @@ namespace Azure.DataApiBuilder.Service.Services.SemanticSearch; /// 2) performing a Redis FT.SEARCH KNN query, /// 3) extracting SQL column/value pairs from Redis hash/json documents. /// -public sealed class RedisSemanticSearchService : ISemanticSearchService +public sealed class RedisSemanticSearchService : ISemanticSearchService, IDisposable { private const string VECTOR_SCORE_FIELD = "__vector_score"; private const string DEFAULT_VECTOR_FIELD = "embedding"; @@ -36,6 +37,10 @@ public sealed class RedisSemanticSearchService : ISemanticSearchService private readonly IMetadataProviderFactory _metadataProviderFactory; private readonly IEmbeddingService? _embeddingService; private readonly ILogger _logger; + private readonly SemaphoreSlim _redisConnectionGate = new(1, 1); + + private string? _cachedRedisConnectionString; + private IConnectionMultiplexer? 
_cachedRedisMultiplexer; public RedisSemanticSearchService( RuntimeConfigProvider runtimeConfigProvider, @@ -79,7 +84,7 @@ public async Task> GetCandidatesAsync( try { - using IConnectionMultiplexer multiplexer = await CreateConnectionMultiplexerAsync(connectionString); + IConnectionMultiplexer multiplexer = await GetConnectionMultiplexerAsync(connectionString); IDatabase db = multiplexer.GetDatabase(); string vectorFieldName = await ResolveVectorFieldNameAsync(db, options.RedisIndexName!, options.RedisIndexType); @@ -185,6 +190,44 @@ private static async Task CreateConnectionMultiplexerAsy return await ConnectionMultiplexer.ConnectAsync(options); } + private async Task GetConnectionMultiplexerAsync(string connectionString) + { + if (_cachedRedisMultiplexer is not null + && string.Equals(_cachedRedisConnectionString, connectionString, StringComparison.Ordinal)) + { + return _cachedRedisMultiplexer; + } + + await _redisConnectionGate.WaitAsync(); + try + { + if (_cachedRedisMultiplexer is not null + && string.Equals(_cachedRedisConnectionString, connectionString, StringComparison.Ordinal)) + { + return _cachedRedisMultiplexer; + } + + IConnectionMultiplexer multiplexer = await CreateConnectionMultiplexerAsync(connectionString); + IConnectionMultiplexer? 
previous = _cachedRedisMultiplexer; + + _cachedRedisMultiplexer = multiplexer; + _cachedRedisConnectionString = connectionString; + + previous?.Dispose(); + return _cachedRedisMultiplexer; + } + finally + { + _redisConnectionGate.Release(); + } + } + + public void Dispose() + { + _cachedRedisMultiplexer?.Dispose(); + _redisConnectionGate.Dispose(); + } + private static async Task ResolveVectorFieldNameAsync(IDatabase db, string indexName, string redisIndexType) { try From 6afa92038492c04f0f2a307ba550e77192b2346f Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Thu, 11 Jun 2026 07:16:06 -0700 Subject: [PATCH 10/12] Fix linting issues --- src/Core/Resolvers/SqlQueryEngine.cs | 10 +++------- .../Queries/InputTypeBuilder.cs | 1 - .../UnitTests/RequestParserUnitTests.cs | 18 ++++++++++++------ 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/Core/Resolvers/SqlQueryEngine.cs b/src/Core/Resolvers/SqlQueryEngine.cs index 47b7daaaa1..a5c26febb1 100644 --- a/src/Core/Resolvers/SqlQueryEngine.cs +++ b/src/Core/Resolvers/SqlQueryEngine.cs @@ -224,7 +224,7 @@ await ExecuteAsync(structure, dataSourceName, isMultipleCreateOperation: true), } JsonDocument? response = await ExecuteAsync(structure, dataSourceName); - return ApplySemanticDistanceAndOrderingIfNeeded(response, structure, dataSourceName, includeRestField: true, includeGraphQlField: false, context); + return ApplySemanticDistanceAndOrderingIfNeeded(response, structure, dataSourceName, includeRestField: true, includeGraphQlField: false); } private async Task<(bool shouldReturnEmpty, JsonDocument? emptyResponse)> TryApplySemanticNarrowingAsync( @@ -234,8 +234,6 @@ await ExecuteAsync(structure, dataSourceName, isMultipleCreateOperation: true), string entityName, FindRequestContext? restContext = null) { - JsonDocument? emptyResponse = null; - string? 
semanticSearchText = restContext?.SemanticSearch; if (semanticSearchText is null && graphQLParameters is not null @@ -315,8 +313,7 @@ await ExecuteAsync(structure, dataSourceName, isMultipleCreateOperation: true), if (candidates.Count == 0) { - emptyResponse = JsonDocument.Parse("[]"); - return (true, emptyResponse); + return (true, JsonDocument.Parse("[]")); } // De-duplicate candidate keys while keeping the highest similarity. @@ -345,8 +342,7 @@ await ExecuteAsync(structure, dataSourceName, isMultipleCreateOperation: true), SqlQueryStructure structure, string dataSourceName, bool includeRestField, - bool includeGraphQlField, - FindRequestContext? restContext = null) + bool includeGraphQlField) { if (response is null || structure.SemanticDistanceByPrimaryKeySignature.Count == 0) { diff --git a/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs b/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs index 69a16d4a7d..13b79d994a 100644 --- a/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs +++ b/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
-using System; using Azure.DataApiBuilder.Service.GraphQLBuilder.Directives; using Azure.DataApiBuilder.Service.GraphQLBuilder.GraphQLTypes; using HotChocolate.Language; diff --git a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs index 5508939d19..a91496d805 100644 --- a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs +++ b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs @@ -93,9 +93,11 @@ public void ParseQueryString_SetsSemanticInputs_ForUserRequest() FindRequestContext context = new( entityName: DEFAULT_ENTITY, dbo: new DatabaseTable(DEFAULT_SCHEMA, DEFAULT_ENTITY), - isList: true); + isList: true) + { + RawQueryString = "?$semantic_search=wireless%20headphones&$semantic_threshold=0.83" + }; - context.RawQueryString = "?$semantic_search=wireless%20headphones&$semantic_threshold=0.83"; context.ParsedQueryString.Add(RequestParser.SEMANTIC_SEARCH_URL, "wireless headphones"); context.ParsedQueryString.Add(RequestParser.SEMANTIC_THRESHOLD_URL, "0.83"); @@ -114,9 +116,11 @@ public void ParseQueryString_RejectsInvalidSemanticThreshold_ForUserRequest(stri FindRequestContext context = new( entityName: DEFAULT_ENTITY, dbo: new DatabaseTable(DEFAULT_SCHEMA, DEFAULT_ENTITY), - isList: true); + isList: true) + { + RawQueryString = $"?$semantic_threshold={threshold}" + }; - context.RawQueryString = $"?$semantic_threshold={threshold}"; context.ParsedQueryString.Add(RequestParser.SEMANTIC_THRESHOLD_URL, threshold); DataApiBuilderException ex = Assert.ThrowsException( @@ -133,9 +137,11 @@ public void ParseQueryString_RejectsOrderBySemanticDistance_ForUserRequest(strin FindRequestContext context = new( entityName: DEFAULT_ENTITY, dbo: new DatabaseTable(DEFAULT_SCHEMA, DEFAULT_ENTITY), - isList: true); + isList: true) + { + RawQueryString = rawQuery + }; - context.RawQueryString = rawQuery; context.ParsedQueryString.Add(RequestParser.SORT_URL, rawQuery.Contains("desc", StringComparison.OrdinalIgnoreCase) ? 
"Semantic_Distance desc" : "semantic_distance asc"); DataApiBuilderException ex = Assert.ThrowsException( From 6f2c78657d466929334a1d17807f7415723e3cd4 Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Thu, 11 Jun 2026 07:29:26 -0700 Subject: [PATCH 11/12] Linting issues part 2 --- src/Core/Models/SemanticSearchCandidate.cs | 6 +++--- src/Core/Resolvers/SqlQueryEngine.cs | 3 +-- src/Service.Tests/UnitTests/RequestParserUnitTests.cs | 2 +- .../Services/SemanticSearch/RedisSemanticSearchService.cs | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/Core/Models/SemanticSearchCandidate.cs b/src/Core/Models/SemanticSearchCandidate.cs index 0c8497364d..046520b63c 100644 --- a/src/Core/Models/SemanticSearchCandidate.cs +++ b/src/Core/Models/SemanticSearchCandidate.cs @@ -8,6 +8,6 @@ namespace Azure.DataApiBuilder.Core.Models; /// the semantic document and primary key values used for dedupe/output mapping. /// public record SemanticSearchCandidate( - IReadOnlyDictionary PrimaryKeyValues, - IReadOnlyDictionary ColumnValues, - double Similarity); + IReadOnlyDictionary PrimaryKeyValues, + IReadOnlyDictionary ColumnValues, + double Similarity); diff --git a/src/Core/Resolvers/SqlQueryEngine.cs b/src/Core/Resolvers/SqlQueryEngine.cs index a5c26febb1..26a3a91985 100644 --- a/src/Core/Resolvers/SqlQueryEngine.cs +++ b/src/Core/Resolvers/SqlQueryEngine.cs @@ -4,8 +4,8 @@ using System.Text; using System.Text.Json; using System.Text.Json.Nodes; -using System.Linq; using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; @@ -15,7 +15,6 @@ using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Service.Exceptions; -using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Service.GraphQLBuilder; 
using Azure.DataApiBuilder.Service.GraphQLBuilder.Queries; using HotChocolate.Resolvers; diff --git a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs index a91496d805..609c633166 100644 --- a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs +++ b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs @@ -3,9 +3,9 @@ using System; using System.Reflection; -using Azure.DataApiBuilder.Core.Parsers; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Core.Parsers; using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Service.Exceptions; using Microsoft.VisualStudio.TestTools.UnitTesting; diff --git a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs index f3ac8efd97..019bf31e7f 100644 --- a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs +++ b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs @@ -492,4 +492,4 @@ private static string NormalizeFieldName(string field) return normalized; } -} \ No newline at end of file +} From 938b0be76aeb2a3259bf4891aa245f5dc4b566a0 Mon Sep 17 00:00:00 2001 From: AJ Tiwari Date: Thu, 11 Jun 2026 12:22:41 -0700 Subject: [PATCH 12/12] Fix tests failures --- src/Cli/Commands/AddOptions.cs | 18 +++++++++--------- src/Cli/Commands/UpdateOptions.cs | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/Cli/Commands/AddOptions.cs b/src/Cli/Commands/AddOptions.cs index 0bf66d539f..e00053fce0 100644 --- a/src/Cli/Commands/AddOptions.cs +++ b/src/Cli/Commands/AddOptions.cs @@ -36,15 +36,6 @@ public AddOptions( string? cacheTtlSeconds, string? cacheLevel, string? healthEnabled, - string? description, - IEnumerable? parametersNameCollection, - IEnumerable? parametersDescriptionCollection, - IEnumerable? parametersRequiredCollection, - IEnumerable? 
parametersDefaultCollection, - IEnumerable? fieldsNameCollection, - IEnumerable? fieldsAliasCollection, - IEnumerable? fieldsDescriptionCollection, - IEnumerable? fieldsPrimaryKeyCollection, string? semanticSearchEnabled = null, string? semanticSearchRedisIndexName = null, string? semanticSearchRedisIndexType = null, @@ -52,6 +43,15 @@ public AddOptions( string? semanticSearchSimilarityThreshold = null, string? semanticSearchInputDescription = null, string? semanticSearchOutputDescription = null, + string? description = null, + IEnumerable? parametersNameCollection = null, + IEnumerable? parametersDescriptionCollection = null, + IEnumerable? parametersRequiredCollection = null, + IEnumerable? parametersDefaultCollection = null, + IEnumerable? fieldsNameCollection = null, + IEnumerable? fieldsAliasCollection = null, + IEnumerable? fieldsDescriptionCollection = null, + IEnumerable? fieldsPrimaryKeyCollection = null, string? mcpDmlTools = null, string? mcpCustomTool = null, string? config = null diff --git a/src/Cli/Commands/UpdateOptions.cs b/src/Cli/Commands/UpdateOptions.cs index d5c63e8875..28ad4afdb5 100644 --- a/src/Cli/Commands/UpdateOptions.cs +++ b/src/Cli/Commands/UpdateOptions.cs @@ -44,15 +44,6 @@ public UpdateOptions( string? cacheTtlSeconds, string? cacheLevel, string? healthEnabled, - string? description, - IEnumerable? parametersNameCollection, - IEnumerable? parametersDescriptionCollection, - IEnumerable? parametersRequiredCollection, - IEnumerable? parametersDefaultCollection, - IEnumerable? fieldsNameCollection, - IEnumerable? fieldsAliasCollection, - IEnumerable? fieldsDescriptionCollection, - IEnumerable? fieldsPrimaryKeyCollection, string? semanticSearchEnabled = null, string? semanticSearchRedisIndexName = null, string? semanticSearchRedisIndexType = null, @@ -60,6 +51,15 @@ public UpdateOptions( string? semanticSearchSimilarityThreshold = null, string? semanticSearchInputDescription = null, string? 
semanticSearchOutputDescription = null, + string? description = null, + IEnumerable? parametersNameCollection = null, + IEnumerable? parametersDescriptionCollection = null, + IEnumerable? parametersRequiredCollection = null, + IEnumerable? parametersDefaultCollection = null, + IEnumerable? fieldsNameCollection = null, + IEnumerable? fieldsAliasCollection = null, + IEnumerable? fieldsDescriptionCollection = null, + IEnumerable? fieldsPrimaryKeyCollection = null, string? mcpDmlTools = null, string? mcpCustomTool = null, string? config = null)