diff --git a/schemas/dab.draft.schema.json b/schemas/dab.draft.schema.json index 6aa8f02641..5be2a04b36 100644 --- a/schemas/dab.draft.schema.json +++ b/schemas/dab.draft.schema.json @@ -534,6 +534,7 @@ } } }, + "compression": { "type": "object", "description": "Configures HTTP response compression settings.", @@ -1487,6 +1488,67 @@ } } }, + "semantic-search": { + "type": "object", + "description": "Semantic search configuration for this entity.", + "additionalProperties": false, + "properties": { + "enabled": { + "$ref": "#/$defs/boolean-or-string", + "description": "Enables semantic search for this entity.", + "default": false + }, + "redis-index-name": { + "type": "string", + "description": "Name of the Redis vector index used for semantic search for this entity." + }, + "redis-index-type": { + "type": "string", + "description": "Redis index storage type used by the semantic search index.", + "enum": ["hash", "json"], + "default": "hash" + }, + "redis-index-multiplier": { + "type": "integer", + "description": "Multiplier applied to requested result count when querying Redis.", + "default": 2, + "minimum": 1, + "maximum": 10 + }, + "similarity-threshold": { + "type": "number", + "description": "Minimum Redis similarity value required for a semantic match.", + "default": 0.8, + "minimum": 0.0, + "maximum": 1.0 + }, + "input-description": { + "type": "string", + "description": "Description surfaced in API metadata for semantic search input.", + "default": "Natural language value used for semantic search." + }, + "output-description": { + "type": "string", + "description": "Description surfaced in API metadata for semantic distance output.", + "default": "Semantic distance score returned by semantic search." 
+ } + }, + "allOf": [ + { + "if": { + "properties": { + "enabled": { + "const": true + } + }, + "required": ["enabled"] + }, + "then": { + "required": ["redis-index-name"] + } + } + ] + }, "mcp": { "oneOf": [ { diff --git a/src/Cli/Commands/AddOptions.cs b/src/Cli/Commands/AddOptions.cs index 8fab1937d7..e00053fce0 100644 --- a/src/Cli/Commands/AddOptions.cs +++ b/src/Cli/Commands/AddOptions.cs @@ -36,15 +36,22 @@ public AddOptions( string? cacheTtlSeconds, string? cacheLevel, string? healthEnabled, - string? description, - IEnumerable? parametersNameCollection, - IEnumerable? parametersDescriptionCollection, - IEnumerable? parametersRequiredCollection, - IEnumerable? parametersDefaultCollection, - IEnumerable? fieldsNameCollection, - IEnumerable? fieldsAliasCollection, - IEnumerable? fieldsDescriptionCollection, - IEnumerable? fieldsPrimaryKeyCollection, + string? semanticSearchEnabled = null, + string? semanticSearchRedisIndexName = null, + string? semanticSearchRedisIndexType = null, + string? semanticSearchRedisIndexMultiplier = null, + string? semanticSearchSimilarityThreshold = null, + string? semanticSearchInputDescription = null, + string? semanticSearchOutputDescription = null, + string? description = null, + IEnumerable? parametersNameCollection = null, + IEnumerable? parametersDescriptionCollection = null, + IEnumerable? parametersRequiredCollection = null, + IEnumerable? parametersDefaultCollection = null, + IEnumerable? fieldsNameCollection = null, + IEnumerable? fieldsAliasCollection = null, + IEnumerable? fieldsDescriptionCollection = null, + IEnumerable? fieldsPrimaryKeyCollection = null, string? mcpDmlTools = null, string? mcpCustomTool = null, string? 
config = null @@ -66,6 +73,13 @@ public AddOptions( cacheTtlSeconds, cacheLevel, healthEnabled, + semanticSearchEnabled, + semanticSearchRedisIndexName, + semanticSearchRedisIndexType, + semanticSearchRedisIndexMultiplier, + semanticSearchSimilarityThreshold, + semanticSearchInputDescription, + semanticSearchOutputDescription, description, parametersNameCollection, parametersDescriptionCollection, diff --git a/src/Cli/Commands/ConfigureOptions.cs b/src/Cli/Commands/ConfigureOptions.cs index 12f5f55455..b64de19981 100644 --- a/src/Cli/Commands/ConfigureOptions.cs +++ b/src/Cli/Commands/ConfigureOptions.cs @@ -61,6 +61,8 @@ public ConfigureOptions( int? runtimePaginationMaxPageSize = null, int? runtimePaginationDefaultPageSize = null, bool? runtimePaginationNextLinkRelative = null, + string? runtimeCacheLevel2Provider = null, + string? runtimeCacheLevel2ConnectionString = null, CompressionLevel? runtimeCompressionLevel = null, bool? runtimeHealthEnabled = null, int? runtimeHealthCacheTtlSeconds = null, @@ -161,6 +163,8 @@ public ConfigureOptions( RuntimePaginationMaxPageSize = runtimePaginationMaxPageSize; RuntimePaginationDefaultPageSize = runtimePaginationDefaultPageSize; RuntimePaginationNextLinkRelative = runtimePaginationNextLinkRelative; + RuntimeCacheLevel2Provider = runtimeCacheLevel2Provider; + RuntimeCacheLevel2ConnectionString = runtimeCacheLevel2ConnectionString; // Compression RuntimeCompressionLevel = runtimeCompressionLevel; // Health @@ -342,6 +346,12 @@ public ConfigureOptions( [Option("runtime.pagination.next-link-relative", Required = false, HelpText = "Use relative URLs for pagination next links. Default: false (boolean).")] public bool? RuntimePaginationNextLinkRelative { get; } + [Option("runtime.cache.level-2.provider", Required = false, HelpText = "Set level-2 cache provider. Allowed values: redis.")] + public string? 
RuntimeCacheLevel2Provider { get; } + + [Option("runtime.cache.level-2.connection-string", Required = false, HelpText = "Set level-2 cache connection string.")] + public string? RuntimeCacheLevel2ConnectionString { get; } + [Option("runtime.compression.level", Required = false, HelpText = "Set the response compression level. Allowed values: optimal (default), fastest, none.")] public CompressionLevel? RuntimeCompressionLevel { get; } diff --git a/src/Cli/Commands/EntityOptions.cs b/src/Cli/Commands/EntityOptions.cs index 76988d4c75..a0da509448 100644 --- a/src/Cli/Commands/EntityOptions.cs +++ b/src/Cli/Commands/EntityOptions.cs @@ -27,6 +27,13 @@ public EntityOptions( string? cacheTtlSeconds, string? cacheLevel, string? healthEnabled, + string? semanticSearchEnabled, + string? semanticSearchRedisIndexName, + string? semanticSearchRedisIndexType, + string? semanticSearchRedisIndexMultiplier, + string? semanticSearchSimilarityThreshold, + string? semanticSearchInputDescription, + string? semanticSearchOutputDescription, string? description, IEnumerable? parametersNameCollection, IEnumerable? 
parametersDescriptionCollection, @@ -58,6 +65,13 @@ public EntityOptions( CacheTtlSeconds = cacheTtlSeconds; CacheLevel = cacheLevel; HealthEnabled = healthEnabled; + SemanticSearchEnabled = semanticSearchEnabled; + SemanticSearchRedisIndexName = semanticSearchRedisIndexName; + SemanticSearchRedisIndexType = semanticSearchRedisIndexType; + SemanticSearchRedisIndexMultiplier = semanticSearchRedisIndexMultiplier; + SemanticSearchSimilarityThreshold = semanticSearchSimilarityThreshold; + SemanticSearchInputDescription = semanticSearchInputDescription; + SemanticSearchOutputDescription = semanticSearchOutputDescription; Description = description; ParametersNameCollection = parametersNameCollection; ParametersDescriptionCollection = parametersDescriptionCollection; @@ -120,6 +134,27 @@ public EntityOptions( [Option("health.enabled", Required = false, HelpText = "Enable health checks for this entity. Default: true (boolean).")] public string? HealthEnabled { get; } + [Option("semantic-search.enabled", Required = false, HelpText = "Enable semantic search for this entity. Accepted values are true/false.")] + public string? SemanticSearchEnabled { get; } + + [Option("semantic-search.redis-index-name", Required = false, HelpText = "Name of the Redis vector index to use for semantic search.")] + public string? SemanticSearchRedisIndexName { get; } + + [Option("semantic-search.redis-index-type", Required = false, HelpText = "Redis index type for semantic search. Allowed values: hash, json.")] + public string? SemanticSearchRedisIndexType { get; } + + [Option("semantic-search.redis-index-multiplier", Required = false, HelpText = "Multiplier for Redis candidate retrieval. Allowed values: 1-10.")] + public string? SemanticSearchRedisIndexMultiplier { get; } + + [Option("semantic-search.similarity-threshold", Required = false, HelpText = "Default semantic similarity threshold. Allowed values: 0.0-1.0.")] + public string? 
SemanticSearchSimilarityThreshold { get; } + + [Option("semantic-search.input-description", Required = false, HelpText = "Description for semantic search input metadata.")] + public string? SemanticSearchInputDescription { get; } + + [Option("semantic-search.output-description", Required = false, HelpText = "Description for semantic distance output metadata.")] + public string? SemanticSearchOutputDescription { get; } + [Option("description", Required = false, HelpText = "Description of the entity.")] public string? Description { get; } diff --git a/src/Cli/Commands/UpdateOptions.cs b/src/Cli/Commands/UpdateOptions.cs index 206bcd704d..28ad4afdb5 100644 --- a/src/Cli/Commands/UpdateOptions.cs +++ b/src/Cli/Commands/UpdateOptions.cs @@ -44,15 +44,22 @@ public UpdateOptions( string? cacheTtlSeconds, string? cacheLevel, string? healthEnabled, - string? description, - IEnumerable? parametersNameCollection, - IEnumerable? parametersDescriptionCollection, - IEnumerable? parametersRequiredCollection, - IEnumerable? parametersDefaultCollection, - IEnumerable? fieldsNameCollection, - IEnumerable? fieldsAliasCollection, - IEnumerable? fieldsDescriptionCollection, - IEnumerable? fieldsPrimaryKeyCollection, + string? semanticSearchEnabled = null, + string? semanticSearchRedisIndexName = null, + string? semanticSearchRedisIndexType = null, + string? semanticSearchRedisIndexMultiplier = null, + string? semanticSearchSimilarityThreshold = null, + string? semanticSearchInputDescription = null, + string? semanticSearchOutputDescription = null, + string? description = null, + IEnumerable? parametersNameCollection = null, + IEnumerable? parametersDescriptionCollection = null, + IEnumerable? parametersRequiredCollection = null, + IEnumerable? parametersDefaultCollection = null, + IEnumerable? fieldsNameCollection = null, + IEnumerable? fieldsAliasCollection = null, + IEnumerable? fieldsDescriptionCollection = null, + IEnumerable? fieldsPrimaryKeyCollection = null, string? 
mcpDmlTools = null, string? mcpCustomTool = null, string? config = null) @@ -72,6 +79,13 @@ public UpdateOptions( cacheTtlSeconds, cacheLevel, healthEnabled, + semanticSearchEnabled, + semanticSearchRedisIndexName, + semanticSearchRedisIndexType, + semanticSearchRedisIndexMultiplier, + semanticSearchSimilarityThreshold, + semanticSearchInputDescription, + semanticSearchOutputDescription, description, parametersNameCollection, parametersDescriptionCollection, diff --git a/src/Cli/ConfigGenerator.cs b/src/Cli/ConfigGenerator.cs index ae3087debb..c513f1ce30 100644 --- a/src/Cli/ConfigGenerator.cs +++ b/src/Cli/ConfigGenerator.cs @@ -472,6 +472,14 @@ public static bool TryAddNewEntity(AddOptions options, RuntimeConfig initialRunt EntityGraphQLOptions graphqlOptions = ConstructGraphQLTypeDetails(options.GraphQLType, graphQLOperationsForStoredProcedures); EntityCacheOptions? cacheOptions = ConstructCacheOptions(options.CacheEnabled, options.CacheTtlSeconds, options.CacheLevel); EntityHealthCheckConfig? entityHealthOptions = ConstructEntityHealthOptions(options.HealthEnabled); + EntitySemanticSearchOptions? semanticSearchOptions = ConstructSemanticSearchOptions( + options.SemanticSearchEnabled, + options.SemanticSearchRedisIndexName, + options.SemanticSearchRedisIndexType, + options.SemanticSearchRedisIndexMultiplier, + options.SemanticSearchSimilarityThreshold, + options.SemanticSearchInputDescription, + options.SemanticSearchOutputDescription); EntityMcpOptions? mcpOptions = null; if (options.McpDmlTools is not null || options.McpCustomTool is not null) @@ -495,6 +503,7 @@ public static bool TryAddNewEntity(AddOptions options, RuntimeConfig initialRunt Relationships: null, Mappings: null, Cache: cacheOptions, + SemanticSearch: semanticSearchOptions, Description: string.IsNullOrWhiteSpace(options.Description) ? 
null : options.Description, Mcp: mcpOptions, Health: entityHealthOptions); @@ -1002,7 +1011,9 @@ private static bool TryUpdateConfiguredRuntimeOptions( // Cache: Enabled and TTL if (options.RuntimeCacheEnabled != null || - options.RuntimeCacheTTL != null) + options.RuntimeCacheTTL != null || + options.RuntimeCacheLevel2Provider != null || + options.RuntimeCacheLevel2ConnectionString != null) { RuntimeCacheOptions? updatedCacheOptions = runtimeConfig?.Runtime?.Cache ?? new(); bool status = TryUpdateConfiguredCacheValues(options, ref updatedCacheOptions); @@ -1559,6 +1570,43 @@ private static bool TryUpdateConfiguredCacheValues( } } + // Runtime.Cache.level-2.provider + updatedValue = options?.RuntimeCacheLevel2Provider; + if (updatedValue != null) + { + string provider = ((string)updatedValue).Trim(); + if (!string.Equals(provider, "redis", StringComparison.OrdinalIgnoreCase)) + { + _logger.LogError("Failed to update Runtime.Cache.level-2.provider as '{updatedValue}'. Supported value is 'redis'.", updatedValue); + return false; + } + + RuntimeCacheLevel2Options currentLevel2 = updatedCacheOptions?.Level2 ?? new(); + updatedCacheOptions = updatedCacheOptions! with + { + Level2 = currentLevel2 with + { + Provider = provider.ToLowerInvariant() + } + }; + _logger.LogInformation("Updated RuntimeConfig with Runtime.Cache.level-2.provider as '{updatedValue}'", updatedValue); + } + + // Runtime.Cache.level-2.connection-string + updatedValue = options?.RuntimeCacheLevel2ConnectionString; + if (updatedValue != null) + { + RuntimeCacheLevel2Options currentLevel2 = updatedCacheOptions?.Level2 ?? new(); + updatedCacheOptions = updatedCacheOptions! 
with
+                {
+                    Level2 = currentLevel2 with
+                    {
+                        ConnectionString = (string)updatedValue
+                    }
+                };
+                _logger.LogInformation("Updated RuntimeConfig with Runtime.Cache.level-2.connection-string.");
+            }
+
             return true;
         }
         catch (Exception ex)
@@ -2327,6 +2375,16 @@ public static bool TryUpdateExistingEntity(UpdateOptions options, RuntimeConfig
             EntityActionFields? updatedFields = GetFieldsForOperation(options.FieldsToInclude, options.FieldsToExclude);
             EntityCacheOptions? updatedCacheOptions = ConstructCacheOptions(options.CacheEnabled, options.CacheTtlSeconds, options.CacheLevel) ?? entity.Cache;
             EntityHealthCheckConfig? updatedEntityHealthOptions = ConstructEntityHealthOptions(options.HealthEnabled) ?? entity.Health;
+            // Seed from the entity's current semantic-search settings so options not passed to update are preserved.
+            EntitySemanticSearchOptions? updatedSemanticSearchOptions = ConstructSemanticSearchOptions(
+                options.SemanticSearchEnabled,
+                options.SemanticSearchRedisIndexName,
+                options.SemanticSearchRedisIndexType,
+                options.SemanticSearchRedisIndexMultiplier,
+                options.SemanticSearchSimilarityThreshold,
+                options.SemanticSearchInputDescription,
+                options.SemanticSearchOutputDescription,
+                existingOptions: entity.SemanticSearch);
 
             // Determine if the entity is or will be a stored procedure
             bool isStoredProcedureAfterUpdate = doOptionsRepresentStoredProcedure || (isCurrentEntityStoredProcedure && options.SourceType is null);
@@ -2588,6 +2646,7 @@ public static bool TryUpdateExistingEntity(UpdateOptions options, RuntimeConfig
                 Relationships: updatedRelationships,
                 Mappings: updatedMappings,
                 Cache: updatedCacheOptions,
+                SemanticSearch: updatedSemanticSearchOptions ?? entity.SemanticSearch,
                 Description: string.IsNullOrWhiteSpace(options.Description) ? entity.Description : options.Description,
                 Mcp: updatedMcpOptions,
                 Health: updatedEntityHealthOptions
diff --git a/src/Cli/Utils.cs b/src/Cli/Utils.cs
index 8f309407a8..f15bb8d55a 100644
--- a/src/Cli/Utils.cs
+++ b/src/Cli/Utils.cs
@@ -956,6 +956,110 @@ public static EntityGraphQLOptions ConstructGraphQLTypeDetails(string? graphQL,
         return new EntityHealthCheckConfig(enabled: isEnabled);
     }
 
+    /// <summary>
+    /// Constructs the EntitySemanticSearchOptions for Add/Update from the raw CLI option strings.
+    /// Options that were not supplied fall back to <paramref name="existingOptions"/> when it is given (Update),
+    /// otherwise to the EntitySemanticSearchOptions defaults (Add), so updating one option keeps the others.
+    /// Numeric options are parsed with the invariant culture so that "0.8" means the same on every machine.
+    /// </summary>
+    /// <param name="semanticSearchEnabled">Raw --semantic-search.enabled value; must parse as true/false.</param>
+    /// <param name="semanticSearchRedisIndexName">Raw --semantic-search.redis-index-name value.</param>
+    /// <param name="semanticSearchRedisIndexType">Raw --semantic-search.redis-index-type value; hash or json (case-insensitive).</param>
+    /// <param name="semanticSearchRedisIndexMultiplier">Raw --semantic-search.redis-index-multiplier value; integer in [1,10].</param>
+    /// <param name="semanticSearchSimilarityThreshold">Raw --semantic-search.similarity-threshold value; decimal in [0.0,1.0].</param>
+    /// <param name="semanticSearchInputDescription">Raw --semantic-search.input-description value.</param>
+    /// <param name="semanticSearchOutputDescription">Raw --semantic-search.output-description value.</param>
+    /// <param name="existingOptions">Settings already configured on the entity, used for options that were not supplied.</param>
+    /// <returns>
+    /// The resulting options when at least one semantic option is provided and every provided value is valid.
+    /// Null when no semantic option is provided, and also when a value is invalid (an error is logged first).
+    /// </returns>
+    public static EntitySemanticSearchOptions? ConstructSemanticSearchOptions(
+        string? semanticSearchEnabled,
+        string? semanticSearchRedisIndexName,
+        string? semanticSearchRedisIndexType,
+        string? semanticSearchRedisIndexMultiplier,
+        string? semanticSearchSimilarityThreshold,
+        string? semanticSearchInputDescription,
+        string? semanticSearchOutputDescription,
+        EntitySemanticSearchOptions? existingOptions = null)
+    {
+        // Nothing semantic-search related was passed on the command line.
+        if (semanticSearchEnabled is null
+            && semanticSearchRedisIndexName is null
+            && semanticSearchRedisIndexType is null
+            && semanticSearchRedisIndexMultiplier is null
+            && semanticSearchSimilarityThreshold is null
+            && semanticSearchInputDescription is null
+            && semanticSearchOutputDescription is null)
+        {
+            return null;
+        }
+
+        bool enabled = existingOptions?.Enabled ?? false;
+        if (semanticSearchEnabled is not null && !bool.TryParse(semanticSearchEnabled, out enabled))
+        {
+            _logger.LogError("Invalid format for --semantic-search.enabled. Accepted values are true/false.");
+            return null;
+        }
+
+        string redisIndexType = existingOptions?.RedisIndexType ?? EntitySemanticSearchOptions.DEFAULT_REDIS_INDEX_TYPE;
+        if (!string.IsNullOrWhiteSpace(semanticSearchRedisIndexType))
+        {
+            if (!string.Equals(semanticSearchRedisIndexType, "hash", StringComparison.OrdinalIgnoreCase)
+                && !string.Equals(semanticSearchRedisIndexType, "json", StringComparison.OrdinalIgnoreCase))
+            {
+                _logger.LogError("Invalid format for --semantic-search.redis-index-type. Accepted values are hash/json.");
+                return null;
+            }
+
+            redisIndexType = semanticSearchRedisIndexType.ToLowerInvariant();
+        }
+
+        int redisIndexMultiplier = existingOptions?.RedisIndexMultiplier ?? EntitySemanticSearchOptions.DEFAULT_REDIS_INDEX_MULTIPLIER;
+        if (!string.IsNullOrWhiteSpace(semanticSearchRedisIndexMultiplier)
+            && !int.TryParse(semanticSearchRedisIndexMultiplier, System.Globalization.NumberStyles.Integer, System.Globalization.CultureInfo.InvariantCulture, out redisIndexMultiplier))
+        {
+            _logger.LogError("Invalid format for --semantic-search.redis-index-multiplier. Accepted values are integer values in range [1,10].");
+            return null;
+        }
+
+        if (redisIndexMultiplier < 1 || redisIndexMultiplier > 10)
+        {
+            _logger.LogError("Invalid value for --semantic-search.redis-index-multiplier. Accepted values are integer values in range [1,10].");
+            return null;
+        }
+
+        double similarityThreshold = existingOptions?.SimilarityThreshold ?? EntitySemanticSearchOptions.DEFAULT_SIMILARITY_THRESHOLD;
+        if (!string.IsNullOrWhiteSpace(semanticSearchSimilarityThreshold)
+            && !double.TryParse(semanticSearchSimilarityThreshold, System.Globalization.NumberStyles.Float, System.Globalization.CultureInfo.InvariantCulture, out similarityThreshold))
+        {
+            _logger.LogError("Invalid format for --semantic-search.similarity-threshold. Accepted values are decimal values in range [0.0,1.0].");
+            return null;
+        }
+
+        if (similarityThreshold < 0.0 || similarityThreshold > 1.0)
+        {
+            _logger.LogError("Invalid value for --semantic-search.similarity-threshold. Accepted values are decimal values in range [0.0,1.0].");
+            return null;
+        }
+
+        return new EntitySemanticSearchOptions
+        {
+            Enabled = enabled,
+            RedisIndexName = semanticSearchRedisIndexName ?? existingOptions?.RedisIndexName,
+            RedisIndexType = redisIndexType,
+            RedisIndexMultiplier = redisIndexMultiplier,
+            SimilarityThreshold = similarityThreshold,
+            InputDescription = string.IsNullOrWhiteSpace(semanticSearchInputDescription)
+                ? (existingOptions?.InputDescription ?? EntitySemanticSearchOptions.DEFAULT_INPUT_DESCRIPTION)
+                : semanticSearchInputDescription,
+            OutputDescription = string.IsNullOrWhiteSpace(semanticSearchOutputDescription)
+                ? (existingOptions?.OutputDescription ?? EntitySemanticSearchOptions.DEFAULT_OUTPUT_DESCRIPTION)
+                : semanticSearchOutputDescription
+        };
+    }
+
     ///
     /// Constructs the EntityMcpOptions for Add/Update.
     ///
diff --git a/src/Config/ObjectModel/Entity.cs b/src/Config/ObjectModel/Entity.cs
index 9c9ba17f2c..637c219744 100644
--- a/src/Config/ObjectModel/Entity.cs
+++ b/src/Config/ObjectModel/Entity.cs
@@ -37,6 +37,8 @@ public record Entity
     public Dictionary? Mappings { get; init; }
     public Dictionary? Relationships { get; init; }
     public EntityCacheOptions? Cache { get; init; }
+    [JsonPropertyName("semantic-search")]
+    public EntitySemanticSearchOptions? SemanticSearch { get; init; }
     public EntityHealthCheckConfig? Health { get; init; }
 
     [JsonConverter(typeof(EntityMcpOptionsConverterFactory))]
@@ -62,7 +64,8 @@ public Entity(
         EntityHealthCheckConfig? Health = null,
         string? Description = null,
         EntityMcpOptions? Mcp = null,
-        bool IsAutoentity = false)
+        bool IsAutoentity = false,
+        EntitySemanticSearchOptions? SemanticSearch = null)
     {
         this.Health = Health;
         this.Source = Source;
@@ -73,6 +76,7 @@ public Entity(
         this.Mappings = Mappings;
         this.Relationships = Relationships;
         this.Cache = Cache;
+        this.SemanticSearch = SemanticSearch;
         this.IsLinkingEntity = IsLinkingEntity;
         this.Description = Description;
         this.Mcp = Mcp;
diff --git a/src/Config/ObjectModel/EntitySemanticSearchOptions.cs b/src/Config/ObjectModel/EntitySemanticSearchOptions.cs
new file mode 100644
index 0000000000..1175330501
--- /dev/null
+++ b/src/Config/ObjectModel/EntitySemanticSearchOptions.cs
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Text.Json.Serialization;
+
+namespace Azure.DataApiBuilder.Config.ObjectModel;
+
+///
+/// Entity specific semantic search configuration.
+/// +public record EntitySemanticSearchOptions +{ + public const int DEFAULT_REDIS_INDEX_MULTIPLIER = 2; + public const double DEFAULT_SIMILARITY_THRESHOLD = 0.8; + public const string DEFAULT_REDIS_INDEX_TYPE = "hash"; + public const string DEFAULT_INPUT_DESCRIPTION = "Natural language value used for semantic search."; + public const string DEFAULT_OUTPUT_DESCRIPTION = "Semantic distance score returned by semantic search."; + + [JsonPropertyName("enabled")] + public bool Enabled { get; init; } = false; + + [JsonPropertyName("redis-index-name")] + public string? RedisIndexName { get; init; } + + [JsonPropertyName("redis-index-type")] + public string RedisIndexType { get; init; } = DEFAULT_REDIS_INDEX_TYPE; + + [JsonPropertyName("redis-index-multiplier")] + public int RedisIndexMultiplier { get; init; } = DEFAULT_REDIS_INDEX_MULTIPLIER; + + [JsonPropertyName("similarity-threshold")] + public double SimilarityThreshold { get; init; } = DEFAULT_SIMILARITY_THRESHOLD; + + [JsonPropertyName("input-description")] + public string InputDescription { get; init; } = DEFAULT_INPUT_DESCRIPTION; + + [JsonPropertyName("output-description")] + public string OutputDescription { get; init; } = DEFAULT_OUTPUT_DESCRIPTION; +} diff --git a/src/Core/Configurations/RuntimeConfigValidator.cs b/src/Core/Configurations/RuntimeConfigValidator.cs index 33f60d7e48..bb210b1d9f 100644 --- a/src/Core/Configurations/RuntimeConfigValidator.cs +++ b/src/Core/Configurations/RuntimeConfigValidator.cs @@ -66,6 +66,26 @@ public class RuntimeConfigValidator : IConfigValidator // Error messages. 
public const string INVALID_CLAIMS_IN_POLICY_ERR_MSG = "One or more claim types supplied in the database policy are not supported.";
 
+    private const string SEMANTIC_SEARCH_REDIS_REQUIREMENT_ERR_MSG =
+        "Semantic search requires runtime.cache.level-2.provider to be 'redis' and runtime.cache.level-2.connection-string to be configured.";
+
+    private const string SEMANTIC_SEARCH_EMBEDDING_REQUIREMENT_ERR_MSG =
+        "Semantic search requires runtime.embeddings to be configured and enabled.";
+
+    private static readonly HashSet<string> _reservedSemanticRestNames =
+    [
+        "semantic_search",
+        "semantic_threshold",
+        "semantic_distance"
+    ];
+
+    private static readonly HashSet<string> _reservedSemanticGraphQLNames =
+    [
+        "semanticSearch",
+        "semanticThreshold",
+        "semanticDistance"
+    ];
+
     public RuntimeConfigValidator(
         RuntimeConfigProvider runtimeConfigProvider,
         IFileSystem fileSystem,
@@ -1011,6 +1031,153 @@ public void ValidateEntityConfiguration(RuntimeConfig runtimeConfig)
                 ValidateNameRequirements(entity.GraphQL.Singular);
                 ValidateNameRequirements(entity.GraphQL.Plural);
             }
+
+            ValidateSemanticSearchConfiguration(runtimeConfig, entityName, entity);
+        }
+    }
+
+    /// <summary>
+    /// Validates the semantic-search settings of one entity; does nothing when semantic search is absent or disabled. Each failed check is reported through HandleOrRecordException.
+    /// </summary>
+    private void ValidateSemanticSearchConfiguration(RuntimeConfig runtimeConfig, string entityName, Entity entity)
+    {
+        EntitySemanticSearchOptions? semantic = entity.SemanticSearch;
+
+        if (semantic is null || !semantic.Enabled)
+        {
+            return;
+        }
+
+        if (entity.Source.Type is not EntitySourceType.Table and not EntitySourceType.View)
+        {
+            HandleOrRecordException(new DataApiBuilderException(
+                message: "Semantic search is only supported for entities with source type 'table' or 'view'.",
+                statusCode: HttpStatusCode.ServiceUnavailable,
+                subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError));
+        }
+
+        if (string.IsNullOrWhiteSpace(semantic.RedisIndexName))
+        {
+            HandleOrRecordException(new DataApiBuilderException(
+                message: $"semantic-search.redis-index-name is required when semantic-search.enabled is true for entity '{entityName}'.",
+                statusCode: HttpStatusCode.ServiceUnavailable,
+                subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError));
+        }
+
+        if (!HasValidSemanticRedisConfiguration(runtimeConfig))
+        {
+            HandleOrRecordException(new DataApiBuilderException(
+                message: SEMANTIC_SEARCH_REDIS_REQUIREMENT_ERR_MSG,
+                statusCode: HttpStatusCode.ServiceUnavailable,
+                subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError));
+        }
+
+        if (!HasValidSemanticEmbeddingConfiguration(runtimeConfig))
+        {
+            HandleOrRecordException(new DataApiBuilderException(
+                message: SEMANTIC_SEARCH_EMBEDDING_REQUIREMENT_ERR_MSG,
+                statusCode: HttpStatusCode.ServiceUnavailable,
+                subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError));
+        }
+
+        if (semantic.RedisIndexMultiplier < 1 || semantic.RedisIndexMultiplier > 10)
+        {
+            HandleOrRecordException(new DataApiBuilderException(
+                message: $"semantic-search.redis-index-multiplier for entity '{entityName}' must be an integer between 1 and 10.",
+                statusCode: HttpStatusCode.ServiceUnavailable,
+                subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError));
+        }
+
+        if (semantic.SimilarityThreshold < 0.0 || semantic.SimilarityThreshold > 1.0)
+        {
+            HandleOrRecordException(new DataApiBuilderException(
+                message: $"semantic-search.similarity-threshold for entity '{entityName}' must be a decimal value between 0.0 and 1.0.",
+                statusCode: HttpStatusCode.ServiceUnavailable,
+                subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError));
+        }
+
+        ValidateSemanticReservedNames(entityName, entity);
+    }
+
+    /// <summary>
+    /// Returns true when runtime.cache.level-2 is present, its provider equals EntityCacheOptions.L2_CACHE_PROVIDER (case-insensitive) and its connection string is not blank.
+    /// </summary>
+    private static bool HasValidSemanticRedisConfiguration(RuntimeConfig runtimeConfig)
+    {
+        RuntimeCacheLevel2Options? level2 = runtimeConfig.Runtime?.Cache?.Level2;
+        return level2 is not null
+            && string.Equals(level2.Provider, EntityCacheOptions.L2_CACHE_PROVIDER, StringComparison.OrdinalIgnoreCase)
+            && !string.IsNullOrWhiteSpace(level2.ConnectionString);
+    }
+
+    /// <summary>
+    /// Returns true when runtime.embeddings is present and enabled.
+    /// </summary>
+    private static bool HasValidSemanticEmbeddingConfiguration(RuntimeConfig runtimeConfig)
+    {
+        EmbeddingsOptions? embeddings = runtimeConfig.Runtime?.Embeddings;
+        return embeddings is not null && embeddings.Enabled;
+    }
+
+    ///
+    /// Ensures configured field names do not collide with semantic reserved names.
+ /// + private void ValidateSemanticReservedNames(string entityName, Entity entity) + { + HashSet exposedFieldNames = new(StringComparer.OrdinalIgnoreCase); + + if (entity.Fields is not null) + { + foreach (FieldMetadata field in entity.Fields) + { + if (!string.IsNullOrWhiteSpace(field.Name)) + { + exposedFieldNames.Add(field.Name); + } + + if (!string.IsNullOrWhiteSpace(field.Alias)) + { + exposedFieldNames.Add(field.Alias); + } + } + } + + if (entity.Mappings is not null) + { + foreach (KeyValuePair mapping in entity.Mappings) + { + if (!string.IsNullOrWhiteSpace(mapping.Key)) + { + exposedFieldNames.Add(mapping.Key); + } + + if (!string.IsNullOrWhiteSpace(mapping.Value)) + { + exposedFieldNames.Add(mapping.Value); + } + } + } + + foreach (string reserved in _reservedSemanticRestNames) + { + if (exposedFieldNames.Contains(reserved)) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"Entity '{entityName}' cannot enable semantic search because field name '{reserved}' is reserved.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } + } + + foreach (string reserved in _reservedSemanticGraphQLNames) + { + if (exposedFieldNames.Contains(reserved)) + { + HandleOrRecordException(new DataApiBuilderException( + message: $"Entity '{entityName}' cannot enable semantic search because field name '{reserved}' is reserved.", + statusCode: HttpStatusCode.ServiceUnavailable, + subStatusCode: DataApiBuilderException.SubStatusCodes.ConfigValidationError)); + } } } diff --git a/src/Core/Models/RestRequestContexts/RestRequestContext.cs b/src/Core/Models/RestRequestContexts/RestRequestContext.cs index e9987730a0..716f9cd96b 100644 --- a/src/Core/Models/RestRequestContexts/RestRequestContext.cs +++ b/src/Core/Models/RestRequestContexts/RestRequestContext.cs @@ -95,6 +95,31 @@ protected RestRequestContext(string entityName, DatabaseObject dbo) /// public int? 
First { get; set; } + + /// + /// Optional semantic search input provided by the caller. + /// + public string? SemanticSearch { get; set; } + + /// + /// Optional semantic threshold override provided by the caller. + /// + public double? SemanticThreshold { get; set; } + + /// + /// True when semantic distance should be included in REST output. + /// + public bool IncludeSemanticDistanceInResponse { get; set; } + + /// + /// Semantic distance keyed by a stable primary-key signature generated from result rows. + /// + public Dictionary? SemanticDistanceByPrimaryKeySignature { get; set; } + + /// + /// True when user explicitly requested semantic distance in the projection. + /// + public bool IsSemanticDistanceExplicitlySelected { get; set; } /// /// Is the result supposed to be multiple values or not. /// diff --git a/src/Core/Models/SemanticSearchCandidate.cs b/src/Core/Models/SemanticSearchCandidate.cs new file mode 100644 index 0000000000..046520b63c --- /dev/null +++ b/src/Core/Models/SemanticSearchCandidate.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace Azure.DataApiBuilder.Core.Models; + +/// +/// Represents one semantic candidate with SQL column values extracted from +/// the semantic document and primary key values used for dedupe/output mapping. +/// +public record SemanticSearchCandidate( + IReadOnlyDictionary PrimaryKeyValues, + IReadOnlyDictionary ColumnValues, + double Similarity); diff --git a/src/Core/Models/SemanticSearchConstants.cs b/src/Core/Models/SemanticSearchConstants.cs new file mode 100644 index 0000000000..dbca8d9426 --- /dev/null +++ b/src/Core/Models/SemanticSearchConstants.cs @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +namespace Azure.DataApiBuilder.Core.Models; + +public static class SemanticSearchConstants +{ + public const string REST_SEARCH_QUERY_PARAM = "$semantic_search"; + public const string REST_THRESHOLD_QUERY_PARAM = "$semantic_threshold"; + public const string REST_DISTANCE_FIELD = "semantic_distance"; + + public const string GRAPHQL_SEARCH_ARGUMENT = "semanticSearch"; + public const string GRAPHQL_THRESHOLD_ARGUMENT = "semanticThreshold"; + public const string GRAPHQL_DISTANCE_FIELD = "semanticDistance"; +} diff --git a/src/Core/Parsers/RequestParser.cs b/src/Core/Parsers/RequestParser.cs index 081018e820..0870077dc9 100644 --- a/src/Core/Parsers/RequestParser.cs +++ b/src/Core/Parsers/RequestParser.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Globalization; using System.Net; using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Services; @@ -35,6 +36,14 @@ public class RequestParser /// Prefix used for specifying paging in the query string of the URL. /// public const string AFTER_URL = "$after"; + /// + /// Prefix used for specifying semantic search text in the query string of the URL. + /// + public const string SEMANTIC_SEARCH_URL = SemanticSearchConstants.REST_SEARCH_QUERY_PARAM; + /// + /// Prefix used for specifying semantic similarity threshold in the query string of the URL. 
+ /// + public const string SEMANTIC_THRESHOLD_URL = SemanticSearchConstants.REST_THRESHOLD_QUERY_PARAM; /// /// Parses the primary key string to identify the field names composing the key @@ -138,6 +147,14 @@ public static void ParseQueryString(RestRequestContext context, ISqlMetadataProv subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); } + if (ContainsSemanticDistanceOrderByToken(rawSortValue)) + { + throw new DataApiBuilderException( + message: "semantic_distance cannot be used in orderBy.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + (context.OrderByClauseInUrl, context.OrderByClauseOfBackingColumns) = GenerateOrderByLists(context, sqlMetadataProvider, $"?{SORT_URL}={rawSortValue}"); break; case AFTER_URL: @@ -146,6 +163,22 @@ public static void ParseQueryString(RestRequestContext context, ISqlMetadataProv case FIRST_URL: context.First = RequestValidator.CheckFirstValidity(context.ParsedQueryString[key]!); break; + case SEMANTIC_SEARCH_URL: + context.SemanticSearch = context.ParsedQueryString[key]; + break; + case SEMANTIC_THRESHOLD_URL: + if (!double.TryParse(context.ParsedQueryString[key], NumberStyles.Float, CultureInfo.InvariantCulture, out double semanticThreshold) + || semanticThreshold < 0.0 + || semanticThreshold > 1.0) + { + throw new DataApiBuilderException( + message: "semantic_threshold must be a decimal value between 0.0 and 1.0.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + context.SemanticThreshold = semanticThreshold; + break; default: throw new DataApiBuilderException( message: $"Invalid Query Parameter: {key}", @@ -155,6 +188,28 @@ public static void ParseQueryString(RestRequestContext context, ISqlMetadataProv } } + private static bool ContainsSemanticDistanceOrderByToken(string rawSortValue) + { + string decoded = Uri.UnescapeDataString(rawSortValue); + + foreach (string part in 
decoded.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)) + { + string[] tokens = part.Split(' ', StringSplitOptions.RemoveEmptyEntries); + if (tokens.Length == 0) + { + continue; + } + + string columnToken = tokens[0].Trim('\'', '"'); + if (string.Equals(columnToken, SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + } + + return false; + } + /// /// Create List of OrderByColumn from an OrderByClause Abstract Syntax Tree /// and return that list as List since OrderByColumn is a Column. diff --git a/src/Core/Resolvers/Factories/QueryEngineFactory.cs b/src/Core/Resolvers/Factories/QueryEngineFactory.cs index 1d2ae2935d..1260db344a 100644 --- a/src/Core/Resolvers/Factories/QueryEngineFactory.cs +++ b/src/Core/Resolvers/Factories/QueryEngineFactory.cs @@ -9,6 +9,7 @@ using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Services.Cache; using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Service.Exceptions; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; @@ -32,6 +33,7 @@ public class QueryEngineFactory : IQueryEngineFactory private readonly IAuthorizationResolver _authorizationResolver; private readonly GQLFilterParser _gQLFilterParser; private readonly DabCacheService _cache; + private readonly ISemanticSearchService _semanticSearchService; private readonly ILogger _logger; /// @@ -42,6 +44,7 @@ public QueryEngineFactory(RuntimeConfigProvider runtimeConfigProvider, IHttpContextAccessor contextAccessor, IAuthorizationResolver authorizationResolver, GQLFilterParser gQLFilterParser, + ISemanticSearchService semanticSearchService, ILogger logger, DabCacheService cache, HotReloadEventHandler? 
handler) @@ -55,6 +58,7 @@ public QueryEngineFactory(RuntimeConfigProvider runtimeConfigProvider, _contextAccessor = contextAccessor; _authorizationResolver = authorizationResolver; _gQLFilterParser = gQLFilterParser; + _semanticSearchService = semanticSearchService; _cache = cache; _logger = logger; @@ -75,7 +79,8 @@ public void ConfigureQueryEngines() _gQLFilterParser, _logger, _runtimeConfigProvider, - _cache); + _cache, + _semanticSearchService); _queryEngines.Add(DatabaseType.MSSQL, queryEngine); _queryEngines.Add(DatabaseType.MySQL, queryEngine); _queryEngines.Add(DatabaseType.PostgreSQL, queryEngine); diff --git a/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs b/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs index 430b58c54f..af15fecf1c 100644 --- a/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs +++ b/src/Core/Resolvers/Sql Query Structures/SqlQueryStructure.cs @@ -4,6 +4,7 @@ using System.Data; using System.Net; using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; @@ -92,6 +93,80 @@ public class SqlQueryStructure : BaseSqlQueryStructure /// public GroupByMetadata GroupByMetadata { get; private set; } + /// + /// True when semantic search arguments are present on the request. + /// + public bool SemanticSearchRequested { get; private set; } + + /// + /// True when a user supplied ordering expression. + /// + public bool HasUserSpecifiedOrdering { get; private set; } + + /// + /// Primary-key signature to semantic distance map. + /// + public Dictionary SemanticDistanceByPrimaryKeySignature { get; } + + /// + /// Applies a semantic narrowing predicate of the form: + /// (col1 = ... AND col2 = ...) OR (...) 
+ /// + public void ApplySemanticCandidates(IReadOnlyList candidates) + { + if (candidates.Count == 0) + { + Predicates.Add(Predicate.MakeFalsePredicate()); + return; + } + + Predicate? combinedOr = null; + + foreach (SemanticSearchCandidate candidate in candidates) + { + Predicate? andPredicate = null; + foreach ((string columnName, object? value) in candidate.ColumnValues) + { + if (value is null) + { + continue; + } + + string parameterName = MakeDbConnectionParam( + GetParamAsSystemType(value.ToString()!, columnName, GetColumnSystemType(columnName)), + columnName); + + Predicate equalsPredicate = new( + new PredicateOperand(new Column(DatabaseObject.SchemaName, DatabaseObject.Name, columnName, SourceAlias)), + PredicateOperation.Equal, + new PredicateOperand(parameterName), + addParenthesis: true); + + andPredicate = andPredicate is null + ? equalsPredicate + : new Predicate(new PredicateOperand(andPredicate), PredicateOperation.AND, new PredicateOperand(equalsPredicate), addParenthesis: true); + } + + if (andPredicate is null) + { + continue; + } + + combinedOr = combinedOr is null + ? andPredicate + : new Predicate(new PredicateOperand(combinedOr), PredicateOperation.OR, new PredicateOperand(andPredicate), addParenthesis: true); + } + + if (combinedOr is null) + { + Predicates.Add(Predicate.MakeFalsePredicate()); + } + else + { + Predicates.Add(combinedOr); + } + } + public virtual string? CacheControlOption { get; set; } public const string CACHE_CONTROL = "Cache-Control"; @@ -280,6 +355,7 @@ public SqlQueryStructure( IsListQuery = context.IsMany; SourceAlias = $"{DatabaseObject.SchemaName}_{DatabaseObject.Name}"; AddFields(context, sqlMetadataProvider); + SemanticSearchRequested = !string.IsNullOrWhiteSpace(context.SemanticSearch); foreach (KeyValuePair predicate in context.PrimaryKeyValuePairs) { sqlMetadataProvider.TryGetBackingColumn(EntityName, predicate.Key, out string? 
backingColumn); @@ -293,6 +369,7 @@ public SqlQueryStructure( // to only Find, we populate the SourceAlias in this constructor where we know we have a Find operation. OrderByColumns = context.OrderByClauseOfBackingColumns is not null ? context.OrderByClauseOfBackingColumns : PrimaryKeyAsOrderByColumns(); + HasUserSpecifiedOrdering = context.OrderByClauseOfBackingColumns is not null; foreach (OrderByColumn column in OrderByColumns) { @@ -444,6 +521,18 @@ private SqlQueryStructure( EntityName = sqlMetadataProvider.GetDatabaseType() == DatabaseType.DWSQL ? GraphQLUtils.GetEntityNameFromContext(ctx) : _underlyingFieldType.Name; bool isGroupByQuery = queryField?.Name.Value == QueryBuilder.GROUP_BY_FIELD_NAME; + SemanticSearchRequested = queryParams.TryGetValue(QueryBuilder.SEMANTIC_SEARCH_ARGUMENT_NAME, out object? semanticValue) + && semanticValue is string semanticSearchText + && !string.IsNullOrWhiteSpace(semanticSearchText); + + if (SemanticSearchRequested && isGroupByQuery) + { + throw new DataApiBuilderException( + message: "Semantic search is not supported for aggregate queries.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + if (GraphQLUtils.TryExtractGraphQLFieldModelName(_underlyingFieldType.Directives, out string? 
modelName)) { EntityName = modelName; @@ -520,6 +609,7 @@ private SqlQueryStructure( if (orderByObject is not null) { OrderByColumns = ProcessGqlOrderByArg((List)orderByObject, queryArgumentSchemas[QueryBuilder.ORDER_BY_FIELD_NAME], isGroupByQuery); + HasUserSpecifiedOrdering = true; } } @@ -575,6 +665,7 @@ private SqlQueryStructure( PaginationMetadata = new(this); GroupByMetadata = new(); ColumnLabelToParam = new(); + SemanticDistanceByPrimaryKeySignature = new(StringComparer.Ordinal); FilterPredicates = string.Empty; OrderByColumns = new(); AddCacheControlOptions(httpRequestHeaders); @@ -733,6 +824,9 @@ void ProcessPaginationFields(IReadOnlyList paginationSelections) // TODO : This is inefficient and could lead to errors. we should rewrite this to use the ISelection API. private void AddGraphQLFields(IReadOnlyList selections, RuntimeConfigProvider runtimeConfigProvider) { + bool entitySemanticEnabled = runtimeConfigProvider.GetConfig().Entities.TryGetValue(EntityName, out Entity? entity) + && entity.SemanticSearch?.Enabled is true; + foreach (ISelectionNode node in selections) { if (node.Kind == SyntaxKind.FragmentSpread) @@ -784,6 +878,12 @@ private void AddGraphQLFields(IReadOnlyList selections, RuntimeC if (field.SelectionSet is null) { + if (entitySemanticEnabled + && string.Equals(fieldName, SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD, StringComparison.Ordinal)) + { + continue; + } + if (MetadataProvider.TryGetBackingColumn(EntityName, fieldName, out string? name) && !string.IsNullOrWhiteSpace(name)) { @@ -796,6 +896,14 @@ private void AddGraphQLFields(IReadOnlyList selections, RuntimeC } else { + if (SemanticSearchRequested) + { + throw new DataApiBuilderException( + message: "Semantic search is not supported with relationship expansion.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + ObjectField? 
subschemaField = _underlyingFieldType.Fields[fieldName]; if (_ctx == null) diff --git a/src/Core/Resolvers/SqlQueryEngine.cs b/src/Core/Resolvers/SqlQueryEngine.cs index f567251771..26a3a91985 100644 --- a/src/Core/Resolvers/SqlQueryEngine.cs +++ b/src/Core/Resolvers/SqlQueryEngine.cs @@ -5,6 +5,7 @@ using System.Text.Json; using System.Text.Json.Nodes; using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Configurations; using Azure.DataApiBuilder.Core.Models; @@ -12,6 +13,7 @@ using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.Cache; using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Service.Exceptions; using Azure.DataApiBuilder.Service.GraphQLBuilder; using Azure.DataApiBuilder.Service.GraphQLBuilder.Queries; @@ -37,6 +39,7 @@ public class SqlQueryEngine : IQueryEngine private readonly RuntimeConfigProvider _runtimeConfigProvider; private readonly GQLFilterParser _gQLFilterParser; private readonly DabCacheService _cache; + private readonly ISemanticSearchService _semanticSearchService; // // Constructor. @@ -49,7 +52,8 @@ public SqlQueryEngine( GQLFilterParser gQLFilterParser, ILogger logger, RuntimeConfigProvider runtimeConfigProvider, - DabCacheService cache) + DabCacheService cache, + ISemanticSearchService? semanticSearchService = null) { _queryFactory = queryFactory; _sqlMetadataProviderFactory = sqlMetadataProviderFactory; @@ -59,6 +63,7 @@ public SqlQueryEngine( _logger = logger; _runtimeConfigProvider = runtimeConfigProvider; _cache = cache; + _semanticSearchService = semanticSearchService ?? NoOpSemanticSearchService.Instance; } /// @@ -79,16 +84,32 @@ public SqlQueryEngine( _runtimeConfigProvider, _gQLFilterParser); + (bool shouldReturnEmpty, JsonDocument? 
emptySemanticResponse) = await TryApplySemanticNarrowingAsync( + structure, + dataSourceName, + parameters, + GraphQLUtils.GetEntityNameFromContext(context)); + if (shouldReturnEmpty) + { + return new Tuple(emptySemanticResponse, structure.PaginationMetadata); + } + if (structure.PaginationMetadata.IsPaginated) { + JsonDocument? queryResult = await ExecuteAsync(structure, dataSourceName); + JsonDocument? processedResult = ApplySemanticDistanceAndOrderingIfNeeded(queryResult, structure, dataSourceName, includeRestField: false, includeGraphQlField: true); + return new Tuple( - SqlPaginationUtil.CreatePaginationConnectionFromJsonDocument(await ExecuteAsync(structure, dataSourceName), structure.PaginationMetadata, structure.GroupByMetadata), + SqlPaginationUtil.CreatePaginationConnectionFromJsonDocument(processedResult, structure.PaginationMetadata, structure.GroupByMetadata), structure.PaginationMetadata); } else { + JsonDocument? queryResult = await ExecuteAsync(structure, dataSourceName); + JsonDocument? processedResult = ApplySemanticDistanceAndOrderingIfNeeded(queryResult, structure, dataSourceName, includeRestField: false, includeGraphQlField: true); + return new Tuple( - await ExecuteAsync(structure, dataSourceName), + processedResult, structure.PaginationMetadata); } } @@ -189,7 +210,230 @@ await ExecuteAsync(structure, dataSourceName, isMultipleCreateOperation: true), _runtimeConfigProvider, _gQLFilterParser, _httpContextAccessor.HttpContext!); - return await ExecuteAsync(structure, dataSourceName); + + (bool shouldReturnEmpty, JsonDocument? emptySemanticResponse) = await TryApplySemanticNarrowingAsync( + structure, + dataSourceName, + graphQLParameters: null, + context.EntityName, + context); + if (shouldReturnEmpty) + { + return emptySemanticResponse; + } + + JsonDocument? 
response = await ExecuteAsync(structure, dataSourceName); + return ApplySemanticDistanceAndOrderingIfNeeded(response, structure, dataSourceName, includeRestField: true, includeGraphQlField: false); + } + + private async Task<(bool shouldReturnEmpty, JsonDocument? emptyResponse)> TryApplySemanticNarrowingAsync( + SqlQueryStructure structure, + string dataSourceName, + IDictionary? graphQLParameters, + string entityName, + FindRequestContext? restContext = null) + { + string? semanticSearchText = restContext?.SemanticSearch; + if (semanticSearchText is null + && graphQLParameters is not null + && graphQLParameters.TryGetValue(QueryBuilder.SEMANTIC_SEARCH_ARGUMENT_NAME, out object? graphQlSearchValue) + && graphQlSearchValue is string graphQlSearchText) + { + semanticSearchText = graphQlSearchText; + } + + bool hasSemanticThresholdOnly = restContext?.SemanticThreshold is not null + || (graphQLParameters is not null && graphQLParameters.ContainsKey(QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME)); + + bool semanticRequested = !string.IsNullOrWhiteSpace(semanticSearchText) || hasSemanticThresholdOnly; + + RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig(); + if (!runtimeConfig.Entities.TryGetValue(entityName, out Entity? entity) + || entity.SemanticSearch is null + || !entity.SemanticSearch.Enabled) + { + if (semanticRequested) + { + throw new DataApiBuilderException( + message: $"Semantic search is not enabled for entity '{entityName}'.", + statusCode: System.Net.HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + return (false, null); + } + + if (string.IsNullOrWhiteSpace(semanticSearchText)) + { + return (false, null); + } + + if (graphQLParameters is not null + && graphQLParameters.TryGetValue(QueryBuilder.PAGINATION_TOKEN_ARGUMENT_NAME, out object? 
afterArg) + && afterArg is string afterToken + && !string.IsNullOrWhiteSpace(afterToken)) + { + throw new DataApiBuilderException( + message: "Pagination continuation tokens are not supported when semantic search is used.", + statusCode: System.Net.HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (graphQLParameters is not null + && graphQLParameters.TryGetValue(QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME, out object? graphQlThreshold) + && graphQlThreshold is double graphQlThresholdDouble + && (graphQlThresholdDouble < 0.0 || graphQlThresholdDouble > 1.0)) + { + throw new DataApiBuilderException( + message: "semantic_threshold must be a decimal value between 0.0 and 1.0.", + statusCode: System.Net.HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + double effectiveThreshold = restContext?.SemanticThreshold + ?? (graphQLParameters is not null && graphQLParameters.TryGetValue(QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME, out object? gqlThresholdObj) && gqlThresholdObj is double gqlThresholdVal + ? gqlThresholdVal + : entity.SemanticSearch.SimilarityThreshold); + + int first = restContext?.First + ?? (graphQLParameters is not null && graphQLParameters.TryGetValue(QueryBuilder.PAGE_START_ARGUMENT_NAME, out object? firstObj) && firstObj is int gqlFirst ? 
gqlFirst : (int)runtimeConfig.DefaultPageSize()); + int redisTop = first * entity.SemanticSearch.RedisIndexMultiplier; + + SourceDefinition sourceDefinition = _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName).GetSourceDefinition(entityName); + IReadOnlyList primaryKeyColumns = sourceDefinition.PrimaryKey; + + IReadOnlyList candidates = await _semanticSearchService.GetCandidatesAsync( + entityName, + entity.SemanticSearch, + primaryKeyColumns, + semanticSearchText, + effectiveThreshold, + redisTop); + + if (candidates.Count == 0) + { + return (true, JsonDocument.Parse("[]")); + } + + // De-duplicate candidate keys while keeping the highest similarity. + Dictionary deduped = new(StringComparer.Ordinal); + foreach (SemanticSearchCandidate candidate in candidates) + { + string signature = BuildPrimaryKeySignatureFromValues(primaryKeyColumns, candidate.PrimaryKeyValues); + if (!deduped.TryGetValue(signature, out SemanticSearchCandidate? existing) || candidate.Similarity > existing.Similarity) + { + deduped[signature] = candidate; + } + } + + IReadOnlyList narrowedCandidates = deduped.Values.ToList(); + structure.ApplySemanticCandidates(narrowedCandidates); + foreach (KeyValuePair kvp in deduped) + { + structure.SemanticDistanceByPrimaryKeySignature[kvp.Key] = kvp.Value.Similarity; + } + + return (false, null); + } + + private JsonDocument? ApplySemanticDistanceAndOrderingIfNeeded( + JsonDocument? response, + SqlQueryStructure structure, + string dataSourceName, + bool includeRestField, + bool includeGraphQlField) + { + if (response is null || structure.SemanticDistanceByPrimaryKeySignature.Count == 0) + { + return response; + } + + SourceDefinition sourceDefinition = _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName).GetSourceDefinition(structure.EntityName); + if (response.RootElement.ValueKind is not JsonValueKind.Array) + { + return response; + } + + List<(double? 
distance, JsonObject row)> rows = new(); + foreach (JsonElement element in response.RootElement.EnumerateArray()) + { + JsonObject? row = JsonObject.Create(element); + if (row is null) + { + continue; + } + + string signature = BuildPrimaryKeySignatureFromRow(sourceDefinition.PrimaryKey, structure.EntityName, dataSourceName, element); + double? distance = null; + if (structure.SemanticDistanceByPrimaryKeySignature.TryGetValue(signature, out double mappedDistance)) + { + distance = mappedDistance; + } + + if (includeRestField) + { + row[SemanticSearchConstants.REST_DISTANCE_FIELD] = distance.HasValue ? JsonValue.Create(distance.Value) : null; + } + + if (includeGraphQlField) + { + row[SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD] = distance.HasValue ? JsonValue.Create(distance.Value) : null; + } + + rows.Add((distance, row)); + } + + if (!structure.HasUserSpecifiedOrdering) + { + rows = rows.OrderByDescending(r => r.distance ?? double.MinValue).ToList(); + } + + JsonElement updated = JsonSerializer.SerializeToElement(rows.Select(r => r.row).ToList()); + + response.Dispose(); + return JsonDocument.Parse(updated.GetRawText()); + } + + private string BuildPrimaryKeySignatureFromRow( + IReadOnlyList primaryKeyColumns, + string entityName, + string dataSourceName, + JsonElement row) + { + ISqlMetadataProvider metadataProvider = _sqlMetadataProviderFactory.GetMetadataProvider(dataSourceName); + Dictionary values = new(StringComparer.Ordinal); + foreach (string primaryKeyColumn in primaryKeyColumns) + { + if (!metadataProvider.TryGetExposedColumnName(entityName, primaryKeyColumn, out string? 
exposedName) + || string.IsNullOrWhiteSpace(exposedName) + || !row.TryGetProperty(exposedName, out JsonElement valueElement)) + { + values[primaryKeyColumn] = null; + continue; + } + + values[primaryKeyColumn] = valueElement.ValueKind switch + { + JsonValueKind.Null => null, + JsonValueKind.String => valueElement.GetString(), + _ => valueElement.ToString() + }; + } + + return BuildPrimaryKeySignatureFromValues(primaryKeyColumns, values); + } + + private static string BuildPrimaryKeySignatureFromValues( + IReadOnlyList primaryKeyColumns, + IReadOnlyDictionary values) + { + return string.Join("|", primaryKeyColumns.Select(pk => + { + string serializedValue = values.TryGetValue(pk, out object? value) + ? value?.ToString() ?? "null" + : "null"; + return $"{pk}={serializedValue}"; + })); } /// diff --git a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs index 979da52eb6..ae527195f4 100644 --- a/src/Core/Services/OpenAPI/OpenApiDocumentor.cs +++ b/src/Core/Services/OpenAPI/OpenApiDocumentor.cs @@ -12,6 +12,7 @@ using Azure.DataApiBuilder.Config.ObjectModel; using Azure.DataApiBuilder.Core.Authorization; using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Parsers; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Core.Services.OpenAPI; @@ -374,6 +375,7 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour // Operations including primary key Dictionary pkOperations = CreateOperations( entityName: entityName, + entity: entity, sourceDefinition: sourceDefinition, includePrimaryKeyPathComponent: true, configuredRestOperations: configuredRestOperations, @@ -398,6 +400,7 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour // Operations excluding primary key Dictionary operations = CreateOperations( entityName: entityName, + entity: entity, sourceDefinition: 
sourceDefinition, includePrimaryKeyPathComponent: false, configuredRestOperations: configuredRestOperations, @@ -434,6 +437,7 @@ private OpenApiPaths BuildPaths(RuntimeEntities entities, string defaultDataSour /// Collection of operation types and associated definitions. private Dictionary CreateOperations( string entityName, + Entity entity, SourceDefinition sourceDefinition, bool includePrimaryKeyPathComponent, Dictionary configuredRestOperations, @@ -447,7 +451,7 @@ private Dictionary CreateOperations( if (configuredRestOperations[OperationType.Get]) { OpenApiOperation getOperation = CreateBaseOperation(description: GETONE_DESCRIPTION, tags: tags); - AddQueryParameters(getOperation.Parameters); + AddQueryParameters(getOperation.Parameters, entity); getOperation.Responses.Add(HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName)); openApiPathItemOperations.Add(OperationType.Get, getOperation); } @@ -492,7 +496,7 @@ private Dictionary CreateOperations( if (configuredRestOperations[OperationType.Get]) { OpenApiOperation getAllOperation = CreateBaseOperation(description: GETALL_DESCRIPTION, tags: tags); - AddQueryParameters(getAllOperation.Parameters); + AddQueryParameters(getAllOperation.Parameters, entity); getAllOperation.Responses.Add( HttpStatusCode.OK.ToString("D"), CreateOpenApiResponse(description: nameof(HttpStatusCode.OK), responseObjectSchemaName: entityName, includeNextLink: true)); @@ -542,12 +546,27 @@ private Dictionary CreateOperations( /// Helper method to add query parameters like $select, $first, $orderby etc. to get and getAll operations for tables/views. /// /// List of parameters for the operation. 
- private static void AddQueryParameters(IList parameters) + private static void AddQueryParameters(IList parameters, Entity entity) { foreach (OpenApiParameter openApiParameter in _tableAndViewQueryParameters) { parameters.Add(openApiParameter); } + + if (entity.SemanticSearch?.Enabled is true) + { + parameters.Add(GetOpenApiQueryParameter( + name: RequestParser.SEMANTIC_SEARCH_URL, + description: entity.SemanticSearch.InputDescription, + required: false, + type: "string")); + + parameters.Add(GetOpenApiQueryParameter( + name: RequestParser.SEMANTIC_THRESHOLD_URL, + description: "Minimum semantic similarity threshold between 0.0 and 1.0.", + required: false, + type: "number")); + } } /// @@ -1507,6 +1526,19 @@ private static OpenApiSchema CreateComponentSchema( } } + if (!isRequestBodySchema + && entityConfig?.SemanticSearch?.Enabled is true + && !properties.ContainsKey(SemanticSearchConstants.REST_DISTANCE_FIELD)) + { + properties.Add(SemanticSearchConstants.REST_DISTANCE_FIELD, new OpenApiSchema() + { + Type = "number", + Description = entityConfig.SemanticSearch.OutputDescription, + ReadOnly = true, + Nullable = true + }); + } + OpenApiSchema schema = new() { Type = SCHEMA_OBJECT_TYPE, diff --git a/src/Core/Services/RequestValidator.cs b/src/Core/Services/RequestValidator.cs index 0b6b0734ad..2a05bbd0c5 100644 --- a/src/Core/Services/RequestValidator.cs +++ b/src/Core/Services/RequestValidator.cs @@ -272,6 +272,15 @@ public void ValidateInsertRequestContext(InsertRequestContext insertRequestCtx) { ISqlMetadataProvider sqlMetadataProvider = GetSqlMetadataProvider(insertRequestCtx.EntityName); + if (insertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.REST_DISTANCE_FIELD) + || insertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD)) + { + throw new DataApiBuilderException( + message: "semantic_distance is read-only.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: 
DataApiBuilderException.SubStatusCodes.BadRequest); + } + IEnumerable fieldsInRequestBody = insertRequestCtx.FieldValuePairsInBody.Keys; SourceDefinition sourceDefinition = TryGetSourceDefinition(insertRequestCtx.EntityName, sqlMetadataProvider); @@ -356,6 +365,16 @@ public void ValidateInsertRequestContext(InsertRequestContext insertRequestCtx) public void ValidateUpsertRequestContext(UpsertRequestContext upsertRequestCtx, bool isPrimaryKeyInUrl = true) { ISqlMetadataProvider sqlMetadataProvider = GetSqlMetadataProvider(upsertRequestCtx.EntityName); + + if (upsertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.REST_DISTANCE_FIELD) + || upsertRequestCtx.FieldValuePairsInBody.ContainsKey(SemanticSearchConstants.GRAPHQL_DISTANCE_FIELD)) + { + throw new DataApiBuilderException( + message: "semantic_distance is read-only.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + IEnumerable fieldsInRequestBody = upsertRequestCtx.FieldValuePairsInBody.Keys; bool isRequestBodyStrict = IsRequestBodyStrict(); SourceDefinition sourceDefinition = TryGetSourceDefinition(upsertRequestCtx.EntityName, sqlMetadataProvider); diff --git a/src/Core/Services/RestService.cs b/src/Core/Services/RestService.cs index 5014d942bb..c70d9b619a 100644 --- a/src/Core/Services/RestService.cs +++ b/src/Core/Services/RestService.cs @@ -194,6 +194,7 @@ RequestValidator requestValidator context.RawQueryString = queryString; context.ParsedQueryString = HttpUtility.ParseQueryString(queryString); RequestParser.ParseQueryString(context, sqlMetadataProvider); + ValidateSemanticRestRequest(context, operationType, _runtimeConfigProvider.GetConfig()); } } @@ -356,6 +357,71 @@ private bool IsHttpMethodAllowedForStoredProcedure(string entityName) return false; } + private static void ValidateSemanticRestRequest(RestRequestContext context, EntityActionOperation operationType, RuntimeConfig runtimeConfig) + { + bool 
hasSemanticSearch = !string.IsNullOrWhiteSpace(context.SemanticSearch); + bool hasSemanticThreshold = context.SemanticThreshold.HasValue; + bool isSemanticRequest = hasSemanticSearch || hasSemanticThreshold; + + if (context.FieldsToBeReturned.Contains(SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparer.OrdinalIgnoreCase)) + { + context.FieldsToBeReturned = context.FieldsToBeReturned + .Where(field => !string.Equals(field, SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparison.OrdinalIgnoreCase)) + .ToList(); + context.IsSemanticDistanceExplicitlySelected = true; + } + + if (!isSemanticRequest) + { + if (context.IsSemanticDistanceExplicitlySelected) + { + throw new DataApiBuilderException( + message: "semantic_distance can only be selected when semantic search is used.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + return; + } + + if (operationType is not EntityActionOperation.Read) + { + throw new DataApiBuilderException( + message: "Semantic search is only supported for read operations.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (!runtimeConfig.Entities.TryGetValue(context.EntityName, out Entity? 
entity) + || entity.SemanticSearch is null + || !entity.SemanticSearch.Enabled) + { + throw new DataApiBuilderException( + message: $"Semantic search is not enabled for entity '{context.EntityName}'.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (!string.IsNullOrWhiteSpace(context.After)) + { + throw new DataApiBuilderException( + message: "Pagination continuation tokens are not supported when semantic search is used.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + if (context.OrderByClauseInUrl is not null + && context.OrderByClauseInUrl.Any(col => string.Equals(col.ColumnName, SemanticSearchConstants.REST_DISTANCE_FIELD, StringComparison.OrdinalIgnoreCase))) + { + throw new DataApiBuilderException( + message: "semantic_distance cannot be used in orderBy.", + statusCode: HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + + context.IncludeSemanticDistanceInResponse = context.FieldsToBeReturned.Count == 0 || context.IsSemanticDistanceExplicitlySelected; + } + /// /// Gets the list of HTTP methods defined for entities representing stored procedures. /// When no explicit REST method configuration is present for a stored procedure entity, diff --git a/src/Core/Services/SemanticSearch/ISemanticSearchService.cs b/src/Core/Services/SemanticSearch/ISemanticSearchService.cs new file mode 100644 index 0000000000..f3c124edfd --- /dev/null +++ b/src/Core/Services/SemanticSearch/ISemanticSearchService.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Models; + +namespace Azure.DataApiBuilder.Core.Services.SemanticSearch; + +public interface ISemanticSearchService +{ + Task> GetCandidatesAsync( + string entityName, + EntitySemanticSearchOptions options, + IReadOnlyList primaryKeyColumns, + string semanticSearchValue, + double similarityThreshold, + int top); +} diff --git a/src/Core/Services/SemanticSearch/NoOpSemanticSearchService.cs b/src/Core/Services/SemanticSearch/NoOpSemanticSearchService.cs new file mode 100644 index 0000000000..dfabecb133 --- /dev/null +++ b/src/Core/Services/SemanticSearch/NoOpSemanticSearchService.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Models; + +namespace Azure.DataApiBuilder.Core.Services.SemanticSearch; + +/// +/// Placeholder semantic search service used when no embedding/vector integration is configured. +/// +public sealed class NoOpSemanticSearchService : ISemanticSearchService +{ + public static readonly NoOpSemanticSearchService Instance = new(); + + private NoOpSemanticSearchService() + { + } + + public Task> GetCandidatesAsync( + string entityName, + EntitySemanticSearchOptions options, + IReadOnlyList primaryKeyColumns, + string semanticSearchValue, + double similarityThreshold, + int top) + { + return Task.FromResult>([]); + } +} diff --git a/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs b/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs index 5441dcc766..13b79d994a 100644 --- a/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs +++ b/src/Service.GraphQLBuilder/Queries/InputTypeBuilder.cs @@ -53,7 +53,8 @@ private static List GenerateOrderByInputFieldsForBuilt // Skip scalar array fields (e.g., PostgreSQL int[], text[]) - they cannot be ordered. 
// Non-scalar list types (e.g., Cosmos nested object arrays) are not skipped // because they are handled as relationship navigations. - if (field.Type.IsListType() && IsBuiltInType(field.Type)) + if (string.Equals(field.Name.Value, QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME, StringComparison.Ordinal) + || (field.Type.IsListType() && IsBuiltInType(field.Type))) { continue; } @@ -122,9 +123,10 @@ private static List GenerateFilterInputFieldsForBuiltI // which are read-only and cannot be filtered. Cosmos scalar arrays like // tags: [String] do NOT have @autoGenerated and remain filterable // (using ARRAY_CONTAINS). - if (field.Type.IsListType() - && IsBuiltInType(field.Type) - && field.Directives.Any(d => d.Name.Value == AutoGeneratedDirectiveType.DirectiveName)) + if (string.Equals(field.Name.Value, QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME, StringComparison.Ordinal) + || (field.Type.IsListType() + && IsBuiltInType(field.Type) + && field.Directives.Any(d => d.Name.Value == AutoGeneratedDirectiveType.DirectiveName))) { continue; } diff --git a/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs b/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs index a2cc63b2c2..69d57c86cd 100644 --- a/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs +++ b/src/Service.GraphQLBuilder/Queries/QueryBuilder.cs @@ -30,6 +30,9 @@ public static class QueryBuilder public const string GROUP_BY_AGGREGATE_FIELD_ARG_NAME = "field"; public const string GROUP_BY_AGGREGATE_FIELD_DISTINCT_NAME = "distinct"; public const string GROUP_BY_AGGREGATE_FIELD_HAVING_NAME = "having"; + public const string SEMANTIC_SEARCH_ARGUMENT_NAME = "semanticSearch"; + public const string SEMANTIC_THRESHOLD_ARGUMENT_NAME = "semanticThreshold"; + public const string SEMANTIC_DISTANCE_FIELD_NAME = "semanticDistance"; // Define the enabled database types for aggregation public static readonly HashSet AggregationEnabledDatabaseTypes = new() @@ -191,21 +194,29 @@ public static FieldDefinitionNode GenerateGetAllQuery( 
location: null, new NameNode(GenerateListQueryName(name.Value, entity)), new StringValueNode($"Get a list of all the {GetDefinedSingularName(name.Value, entity)} items from the database"), - QueryArgumentsForField(filterInputName, orderByInputName), + QueryArgumentsForField(filterInputName, orderByInputName, includeSemanticArguments: entity.SemanticSearch?.Enabled ?? false), new NonNullTypeNode(new NamedTypeNode(returnType.Name)), fieldDefinitionNodeDirectives ); } - public static List QueryArgumentsForField(string filterInputName, string orderByInputName) + public static List QueryArgumentsForField(string filterInputName, string orderByInputName, bool includeSemanticArguments = false) { - return new() + List args = new() { new(location: null, new NameNode(PAGE_START_ARGUMENT_NAME), description: new StringValueNode("The number of items to return from the page start point"), new IntType().ToTypeNode(), defaultValue: null, new List()), new(location: null, new NameNode(PAGINATION_TOKEN_ARGUMENT_NAME), new StringValueNode("A pagination token from a previous query to continue through a paginated list"), new StringType().ToTypeNode(), defaultValue: null, new List()), new(location: null, new NameNode(FILTER_FIELD_NAME), new StringValueNode("Filter options for query"), new NamedTypeNode(filterInputName), defaultValue: null, new List()), new(location: null, new NameNode(ORDER_BY_FIELD_NAME), new StringValueNode("Ordering options for query"), new NamedTypeNode(orderByInputName), defaultValue: null, new List()), }; + + if (includeSemanticArguments) + { + args.Add(new(location: null, new NameNode(SEMANTIC_SEARCH_ARGUMENT_NAME), new StringValueNode("Natural language value used for semantic search."), new StringType().ToTypeNode(), defaultValue: null, new List())); + args.Add(new(location: null, new NameNode(SEMANTIC_THRESHOLD_ARGUMENT_NAME), new StringValueNode("Minimum semantic similarity threshold between 0.0 and 1.0."), new FloatType().ToTypeNode(), defaultValue: null, new 
List())); + } + + return args; } public static ObjectTypeDefinitionNode AddQueryArgumentsForRelationships(ObjectTypeDefinitionNode node, Dictionary inputObjects) diff --git a/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs b/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs index 622376dc13..1da4e56197 100644 --- a/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs +++ b/src/Service.GraphQLBuilder/Sql/SchemaConverter.cs @@ -210,6 +210,20 @@ private static ObjectTypeDefinitionNode CreateObjectTypeDefinitionForTableOrView } } + if (configEntity.SemanticSearch?.Enabled is true) + { + List semanticDistanceDirectives = [new DirectiveNode(AutoGeneratedDirectiveType.DirectiveName)]; + FieldDefinitionNode semanticDistanceField = new( + location: null, + name: new NameNode(QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME), + description: new StringValueNode(configEntity.SemanticSearch.OutputDescription), + arguments: [], + type: new NamedTypeNode(new NameNode(FLOAT_TYPE)), + directives: semanticDistanceDirectives); + + fieldDefinitionNodes.TryAdd(QueryBuilder.SEMANTIC_DISTANCE_FIELD_NAME, semanticDistanceField); + } + // A linking entity is not exposed in the runtime config file but is used by DAB to support multiple mutations on entities with M:N relationship. // Hence we don't need to process relationships for the linking entity itself. 
if (!configEntity.IsLinkingEntity) diff --git a/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs b/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs index ffe636f6db..89a747df15 100644 --- a/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs +++ b/src/Service.Tests/GraphQLBuilder/MultipleMutationBuilderTests.cs @@ -16,6 +16,7 @@ using Azure.DataApiBuilder.Core.Services; using Azure.DataApiBuilder.Core.Services.Cache; using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Service.GraphQLBuilder; using Azure.DataApiBuilder.Service.GraphQLBuilder.Directives; using Azure.DataApiBuilder.Service.GraphQLBuilder.Mutations; @@ -439,6 +440,7 @@ private static async Task GetGQLSchemaCreator(RuntimeConfi contextAccessor: httpContextAccessor.Object, authorizationResolver: authorizationResolver, gQLFilterParser: graphQLFilterParser, + semanticSearchService: NoOpSemanticSearchService.Instance, logger: queryEngineLogger.Object, cache: cacheService, handler: null); diff --git a/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs b/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs index 93a3df9778..4eb727a59f 100644 --- a/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs +++ b/src/Service.Tests/GraphQLBuilder/QueryBuilderTests.cs @@ -229,6 +229,54 @@ type Foo @model(name:""Foo"") { Assert.AreEqual("Boolean", hasNextPageField.Type.NamedType().Name.Value, "hasNextPage should be Boolean type"); } + [TestMethod] + [TestCategory("Query Generation")] + [TestCategory("Collection access")] + public void CollectionQueryAddsSemanticArgumentsWhenEnabled() + { + string gql = + @" +type Foo @model(name:""Foo"") { + id: ID! 
+} + "; + + DocumentNode root = Utf8GraphQLParser.Parse(gql); + Dictionary entityNameToDatabaseType = new() + { + { "Foo", DatabaseType.CosmosDB_NoSQL } + }; + + RuntimeEntities entities = new(new Dictionary + { + { + "Foo", + new Entity( + Source: new("Foo", EntitySourceType.Table, null, null), + GraphQL: new("Foo", "Foos"), + Fields: null, + Rest: new(), + Permissions: [], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions { Enabled = true, RedisIndexName = "idx:foo" }) + } + }); + + DocumentNode queryRoot = QueryBuilder.Build( + root, + entityNameToDatabaseType, + entities, + inputTypes: new(), + entityPermissionsMap: _entityPermissions + ); + + ObjectTypeDefinitionNode query = GetQueryNode(queryRoot); + FieldDefinitionNode collectionField = query.Fields.First(f => f.Name.Value == "foos"); + Assert.IsTrue(collectionField.Arguments.Any(a => a.Name.Value == QueryBuilder.SEMANTIC_SEARCH_ARGUMENT_NAME)); + Assert.IsTrue(collectionField.Arguments.Any(a => a.Name.Value == QueryBuilder.SEMANTIC_THRESHOLD_ARGUMENT_NAME)); + } + [TestMethod] public void PrimaryKeyFieldAsQueryInput() { diff --git a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs index bbb4874d1a..6546f41821 100644 --- a/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs +++ b/src/Service.Tests/UnitTests/ConfigValidationUnitTests.cs @@ -3579,6 +3579,162 @@ public void ValidateEmbeddingsOptions_EndpointDisabled_SkipsValidation() configValidator.ValidateEmbeddingsOptions(runtimeConfig); } + [TestMethod] + public void ValidateSemanticSearchRequiresRedisLevel2Configuration() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, KeyFields: null), + Fields: null, + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new 
EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true, + RedisIndexName = "idx:product-semantic" + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null)), + Entities: new(new Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "Semantic search requires runtime.cache.level-2.provider to be 'redis' and runtime.cache.level-2.connection-string to be configured.", + ex.Message); + } + + [TestMethod] + public void ValidateSemanticSearchRequiresRedisIndexNameWhenEnabled() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, KeyFields: null), + Fields: null, + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new RuntimeCacheOptions + { + Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") + }), + Entities: new(new 
Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "semantic-search.redis-index-name is required when semantic-search.enabled is true for entity 'Product'.", + ex.Message); + } + + [TestMethod] + public void ValidateSemanticSearchRequiresEmbeddingsConfiguration() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, KeyFields: null), + Fields: null, + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true, + RedisIndexName = "idx:product-semantic" + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new RuntimeCacheOptions + { + Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") + }), + Entities: new(new Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "Semantic search requires runtime.embeddings to be configured and enabled.", + ex.Message); + } + + [TestMethod] + public void ValidateSemanticSearchRejectsReservedFieldName() + { + RuntimeConfigValidator configValidator = InitializeRuntimeConfigValidator(); + + Entity entity = new( + Source: new EntitySource(Object: "dbo.Product", Type: EntitySourceType.Table, Parameters: null, 
KeyFields: null), + Fields: [new FieldMetadata { Name = "id" }, new FieldMetadata { Name = "semantic_distance" }], + GraphQL: new EntityGraphQLOptions("Product", "Products"), + Rest: new EntityRestOptions(), + Permissions: [new EntityPermission("anonymous", [new EntityAction(EntityActionOperation.Read, null, null)])], + Mappings: null, + Relationships: null, + SemanticSearch: new EntitySemanticSearchOptions + { + Enabled = true, + RedisIndexName = "idx:product-semantic" + }); + + RuntimeConfig runtimeConfig = new( + Schema: "UnitTestSchema", + DataSource: new DataSource(DatabaseType: DatabaseType.MSSQL, "Server=.", Options: null), + Runtime: new( + Rest: new(), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null), + Cache: new RuntimeCacheOptions + { + Level2 = new RuntimeCacheLevel2Options(Provider: "redis", ConnectionString: "localhost:6379") + }, + Embeddings: new EmbeddingsOptions( + Provider: EmbeddingProviderType.OpenAI, + BaseUrl: "https://api.openai.com", + ApiKey: "test-api-key", + Enabled: true)), + Entities: new(new Dictionary { ["Product"] = entity })); + + DataApiBuilderException ex = Assert.ThrowsException(() => configValidator.ValidateEntityConfiguration(runtimeConfig)); + Assert.AreEqual( + "Entity 'Product' cannot enable semantic search because field name 'semantic_distance' is reserved.", + ex.Message); + } + private static RuntimeConfigValidator InitializeRuntimeConfigValidator() { MockFileSystem fileSystem = new(); diff --git a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs index 0aeb309c61..609c633166 100644 --- a/src/Service.Tests/UnitTests/RequestParserUnitTests.cs +++ b/src/Service.Tests/UnitTests/RequestParserUnitTests.cs @@ -1,8 +1,15 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
+using System; +using System.Reflection; +using Azure.DataApiBuilder.Config.DatabasePrimitives; +using Azure.DataApiBuilder.Core.Models; using Azure.DataApiBuilder.Core.Parsers; +using Azure.DataApiBuilder.Core.Services; +using Azure.DataApiBuilder.Service.Exceptions; using Microsoft.VisualStudio.TestTools.UnitTesting; +using Moq; namespace Azure.DataApiBuilder.Service.Tests.UnitTests { @@ -14,6 +21,9 @@ namespace Azure.DataApiBuilder.Service.Tests.UnitTests [TestClass] public class RequestParserUnitTests { + private const string DEFAULT_ENTITY = "Book"; + private const string DEFAULT_SCHEMA = "dbo"; + /// /// Tests that ExtractRawQueryParameter correctly extracts URL-encoded /// parameter values, preserving special characters like ampersand (&). @@ -76,5 +86,84 @@ public void ExtractRawQueryParameter_HandlesEdgeCases(string queryString, string Assert.AreEqual(expectedValue, result, $"Expected '{expectedValue}' but got '{result}' for parameter '{parameterName}' in query '{queryString}'"); } + + [TestMethod] + public void ParseQueryString_SetsSemanticInputs_ForUserRequest() + { + FindRequestContext context = new( + entityName: DEFAULT_ENTITY, + dbo: new DatabaseTable(DEFAULT_SCHEMA, DEFAULT_ENTITY), + isList: true) + { + RawQueryString = "?$semantic_search=wireless%20headphones&$semantic_threshold=0.83" + }; + + context.ParsedQueryString.Add(RequestParser.SEMANTIC_SEARCH_URL, "wireless headphones"); + context.ParsedQueryString.Add(RequestParser.SEMANTIC_THRESHOLD_URL, "0.83"); + + RequestParser.ParseQueryString(context, new Mock().Object); + + Assert.AreEqual("wireless headphones", context.SemanticSearch); + Assert.AreEqual(0.83, context.SemanticThreshold); + } + + [DataTestMethod] + [DataRow("-0.01")] + [DataRow("1.01")] + [DataRow("not-a-number")] + public void ParseQueryString_RejectsInvalidSemanticThreshold_ForUserRequest(string threshold) + { + FindRequestContext context = new( + entityName: DEFAULT_ENTITY, + dbo: new DatabaseTable(DEFAULT_SCHEMA, 
DEFAULT_ENTITY), + isList: true) + { + RawQueryString = $"?$semantic_threshold={threshold}" + }; + + context.ParsedQueryString.Add(RequestParser.SEMANTIC_THRESHOLD_URL, threshold); + + DataApiBuilderException ex = Assert.ThrowsException( + () => RequestParser.ParseQueryString(context, new Mock().Object)); + + StringAssert.Contains(ex.Message, "semantic_threshold must be a decimal value between 0.0 and 1.0."); + } + + [DataTestMethod] + [DataRow("?$orderby=semantic_distance%20asc")] + [DataRow("?$orderby=Semantic_Distance%20desc")] + public void ParseQueryString_RejectsOrderBySemanticDistance_ForUserRequest(string rawQuery) + { + FindRequestContext context = new( + entityName: DEFAULT_ENTITY, + dbo: new DatabaseTable(DEFAULT_SCHEMA, DEFAULT_ENTITY), + isList: true) + { + RawQueryString = rawQuery + }; + + context.ParsedQueryString.Add(RequestParser.SORT_URL, rawQuery.Contains("desc", StringComparison.OrdinalIgnoreCase) ? "Semantic_Distance desc" : "semantic_distance asc"); + + DataApiBuilderException ex = Assert.ThrowsException( + () => RequestParser.ParseQueryString(context, new Mock().Object)); + + StringAssert.Contains(ex.Message, "semantic_distance cannot be used in orderBy."); + } + + [DataTestMethod] + [DataRow("semantic_distance asc", true)] + [DataRow("Semantic_Distance desc", true)] + [DataRow("semantic_distance_score asc", false)] + [DataRow("title asc,semantic_distance_score desc", false)] + public void ContainsSemanticDistanceOrderByToken_UsesExactColumnMatch(string rawSortValue, bool expected) + { + MethodInfo method = typeof(RequestParser).GetMethod( + "ContainsSemanticDistanceOrderByToken", + BindingFlags.NonPublic | BindingFlags.Static); + + Assert.IsNotNull(method, "Expected private helper method to exist."); + bool actual = (bool)method!.Invoke(null, new object[] { rawSortValue })!; + Assert.AreEqual(expected, actual); + } } } diff --git a/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs 
b/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs index 35c86abc82..77f057cd0d 100644 --- a/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs +++ b/src/Service.Tests/UnitTests/RequestValidatorUnitTests.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Net; using System.Text; +using System.Text.Json; using Azure.DataApiBuilder.Config; using Azure.DataApiBuilder.Config.DatabasePrimitives; using Azure.DataApiBuilder.Config.ObjectModel; @@ -327,6 +328,58 @@ public void PrimaryKeyWithNoValueTest() primaryKeyRoute = "id//title/"; PerformRequestParserPrimaryKeyTest(findRequestContext, primaryKeyRoute, expectsException: true); } + + [TestMethod] + public void InsertRequestBodyCannotSetSemanticDistanceField() + { + RuntimeConfig mockConfig = new( + Schema: "", + DataSource: new(DatabaseType.PostgreSQL, "", new()), + Runtime: new( + Rest: new(Path: "/api"), + GraphQL: new(), + Mcp: new(), + Host: new(Cors: null, Authentication: null) + ), + Entities: new(new Dictionary + { + { + DEFAULT_NAME, + new Entity( + Source: new(DEFAULT_NAME, EntitySourceType.Table, null, null), + GraphQL: new(DEFAULT_NAME, DEFAULT_NAME), + Fields: null, + Rest: new(), + Permissions: [], + Mappings: null, + Relationships: null) + } + }) + ); + + MockFileSystem fileSystem = new(); + fileSystem.AddFile(FileSystemRuntimeConfigLoader.DEFAULT_CONFIG_FILE_NAME, new MockFileData(mockConfig.ToJson())); + FileSystemRuntimeConfigLoader loader = new(fileSystem); + RuntimeConfigProvider provider = new(loader); + + Mock metadataProviderFactory = new(); + metadataProviderFactory.Setup(x => x.GetMetadataProvider(It.IsAny())).Returns(_mockMetadataStore.Object); + RequestValidator requestValidator = new(metadataProviderFactory.Object, provider); + + using JsonDocument body = JsonDocument.Parse("{\"semantic_distance\":0.81}"); + InsertRequestContext context = new( + entityName: DEFAULT_NAME, + dbo: GetDbo(DEFAULT_SCHEMA, DEFAULT_NAME), + insertPayloadRoot: body.RootElement, + operationType: 
EntityActionOperation.Insert); + + DataApiBuilderException ex = Assert.ThrowsException( + () => requestValidator.ValidateInsertRequestContext(context)); + + Assert.AreEqual(HttpStatusCode.BadRequest, ex.StatusCode); + Assert.AreEqual(DataApiBuilderException.SubStatusCodes.BadRequest, ex.SubStatusCode); + StringAssert.Contains(ex.Message, "semantic_distance is read-only."); + } #endregion #region Helper Methods /// diff --git a/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs new file mode 100644 index 0000000000..019bf31e7f --- /dev/null +++ b/src/Service/Services/SemanticSearch/RedisSemanticSearchService.cs @@ -0,0 +1,495 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure.DataApiBuilder.Config.DatabasePrimitives; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Core.Services.Embeddings; +using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; +using Azure.DataApiBuilder.Service.Exceptions; +using Azure.Identity; +using Microsoft.Extensions.Logging; +using StackExchange.Redis; + +namespace Azure.DataApiBuilder.Service.Services.SemanticSearch; + +/// +/// Resolves semantic candidates by: +/// 1) generating/retrieving an embedding vector for semantic_search input, +/// 2) performing a Redis FT.SEARCH KNN query, +/// 3) extracting SQL column/value pairs from Redis hash/json documents. 
+/// +public sealed class RedisSemanticSearchService : ISemanticSearchService, IDisposable +{ + private const string VECTOR_SCORE_FIELD = "__vector_score"; + private const string DEFAULT_VECTOR_FIELD = "embedding"; + + private readonly RuntimeConfigProvider _runtimeConfigProvider; + private readonly IMetadataProviderFactory _metadataProviderFactory; + private readonly IEmbeddingService? _embeddingService; + private readonly ILogger _logger; + private readonly SemaphoreSlim _redisConnectionGate = new(1, 1); + + private string? _cachedRedisConnectionString; + private IConnectionMultiplexer? _cachedRedisMultiplexer; + + public RedisSemanticSearchService( + RuntimeConfigProvider runtimeConfigProvider, + IMetadataProviderFactory metadataProviderFactory, + IEnumerable embeddingServices, + ILogger logger) + { + _runtimeConfigProvider = runtimeConfigProvider; + _metadataProviderFactory = metadataProviderFactory; + _embeddingService = embeddingServices.FirstOrDefault(); + _logger = logger; + } + + public async Task> GetCandidatesAsync( + string entityName, + EntitySemanticSearchOptions options, + IReadOnlyList primaryKeyColumns, + string semanticSearchValue, + double similarityThreshold, + int top) + { + RuntimeConfig runtimeConfig = _runtimeConfigProvider.GetConfig(); + string? 
connectionString = runtimeConfig.Runtime?.Cache?.Level2?.ConnectionString; + + if (string.IsNullOrWhiteSpace(options.RedisIndexName) + || string.IsNullOrWhiteSpace(connectionString) + || top <= 0) + { + return []; + } + + float[] embedding = await GetEmbeddingAsync(semanticSearchValue); + if (embedding.Length == 0) + { + return []; + } + + string dataSourceName = runtimeConfig.GetDataSourceNameFromEntityName(entityName); + SourceDefinition sourceDefinition = _metadataProviderFactory.GetMetadataProvider(dataSourceName).GetSourceDefinition(entityName); + HashSet sourceColumns = new(sourceDefinition.Columns.Keys, StringComparer.OrdinalIgnoreCase); + + try + { + IConnectionMultiplexer multiplexer = await GetConnectionMultiplexerAsync(connectionString); + IDatabase db = multiplexer.GetDatabase(); + + string vectorFieldName = await ResolveVectorFieldNameAsync(db, options.RedisIndexName!, options.RedisIndexType); + byte[] vectorBytes = ToRedisVectorBytes(embedding); + + RedisResult rawResult = await db.ExecuteAsync( + "FT.SEARCH", + options.RedisIndexName!, + $"*=>[KNN {top} @{vectorFieldName} $vec AS {VECTOR_SCORE_FIELD}]", + "PARAMS", + "2", + "vec", + vectorBytes, + "SORTBY", + VECTOR_SCORE_FIELD, + "DIALECT", + "2", + "LIMIT", + "0", + top.ToString(CultureInfo.InvariantCulture)); + + return ParseCandidates(rawResult, options.RedisIndexType, sourceColumns, primaryKeyColumns, similarityThreshold); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Semantic search query failed for entity {entityName} and index {indexName}.", entityName, options.RedisIndexName); + throw new DataApiBuilderException( + message: $"Semantic search index '{options.RedisIndexName}' for entity '{entityName}' was not found or could not be queried.", + statusCode: System.Net.HttpStatusCode.BadRequest, + subStatusCode: DataApiBuilderException.SubStatusCodes.BadRequest); + } + } + + private static byte[] ToRedisVectorBytes(float[] vector) + { + byte[] bytes = new byte[vector.Length * 
sizeof(float)]; + Buffer.BlockCopy(vector, 0, bytes, 0, bytes.Length); + return bytes; + } + + private async Task GetEmbeddingAsync(string semanticSearchValue) + { + if (TryParseVectorText(semanticSearchValue, out float[]? parsedVector) && parsedVector is not null) + { + return parsedVector; + } + + if (_embeddingService is null || !_embeddingService.IsEnabled) + { + return []; + } + + EmbeddingResult result = await _embeddingService.TryEmbedAsync(semanticSearchValue); + if (!result.Success || result.Embedding is null) + { + return []; + } + + return result.Embedding; + } + + private static bool TryParseVectorText(string text, out float[]? vector) + { + vector = null; + + try + { + using JsonDocument json = JsonDocument.Parse(text); + if (json.RootElement.ValueKind is not JsonValueKind.Array) + { + return false; + } + + List values = []; + foreach (JsonElement element in json.RootElement.EnumerateArray()) + { + if (!element.TryGetSingle(out float current)) + { + return false; + } + + values.Add(current); + } + + vector = values.ToArray(); + return vector.Length > 0; + } + catch + { + return false; + } + } + + private static async Task CreateConnectionMultiplexerAsync(string connectionString) + { + ConfigurationOptions options = ConfigurationOptions.Parse(connectionString); + + if (Startup.ShouldUseEntraAuthForRedis(options)) + { + options = await options.ConfigureForAzureWithTokenCredentialAsync(new DefaultAzureCredential()); + } + + return await ConnectionMultiplexer.ConnectAsync(options); + } + + private async Task GetConnectionMultiplexerAsync(string connectionString) + { + if (_cachedRedisMultiplexer is not null + && string.Equals(_cachedRedisConnectionString, connectionString, StringComparison.Ordinal)) + { + return _cachedRedisMultiplexer; + } + + await _redisConnectionGate.WaitAsync(); + try + { + if (_cachedRedisMultiplexer is not null + && string.Equals(_cachedRedisConnectionString, connectionString, StringComparison.Ordinal)) + { + return 
_cachedRedisMultiplexer; + } + + IConnectionMultiplexer multiplexer = await CreateConnectionMultiplexerAsync(connectionString); + IConnectionMultiplexer? previous = _cachedRedisMultiplexer; + + _cachedRedisMultiplexer = multiplexer; + _cachedRedisConnectionString = connectionString; + + previous?.Dispose(); + return _cachedRedisMultiplexer; + } + finally + { + _redisConnectionGate.Release(); + } + } + + public void Dispose() + { + _cachedRedisMultiplexer?.Dispose(); + _redisConnectionGate.Dispose(); + } + + private static async Task ResolveVectorFieldNameAsync(IDatabase db, string indexName, string redisIndexType) + { + try + { + RedisResult infoResult = await db.ExecuteAsync("FT.INFO", indexName); + if (infoResult.IsNull) + { + return DEFAULT_VECTOR_FIELD; + } + + RedisResult[]? infoItems = AsRedisArray(infoResult); + if (infoItems is not null) + { + for (int i = 0; i + 1 < infoItems.Length; i += 2) + { + string? key = infoItems[i].ToString(); + if (!string.Equals(key, "attributes", StringComparison.OrdinalIgnoreCase)) + { + continue; + } + + RedisResult[]? attributes = AsRedisArray(infoItems[i + 1]); + if (attributes is null) + { + continue; + } + + foreach (RedisResult attribute in attributes) + { + RedisResult[]? tokens = AsRedisArray(attribute); + if (tokens is null) + { + continue; + } + + bool isVector = false; + string? identifier = null; + string? alias = null; + + for (int t = 0; t + 1 < tokens.Length; t += 2) + { + string? tokenKey = tokens[t].ToString(); + string? tokenValue = tokens[t + 1].ToString(); + + if (string.Equals(tokenKey, "type", StringComparison.OrdinalIgnoreCase) + && string.Equals(tokenValue, "VECTOR", StringComparison.OrdinalIgnoreCase)) + { + isVector = true; + } + else if (string.Equals(tokenKey, "identifier", StringComparison.OrdinalIgnoreCase)) + { + identifier = tokenValue; + } + else if (string.Equals(tokenKey, "attribute", StringComparison.OrdinalIgnoreCase)) + { + alias = tokenValue; + } + } + + if (isVector) + { + string? 
selected = string.Equals(redisIndexType, "json", StringComparison.OrdinalIgnoreCase) + ? alias ?? identifier + : identifier ?? alias; + + if (!string.IsNullOrWhiteSpace(selected)) + { + return selected.TrimStart('$', '.'); + } + } + } + } + } + } + catch + { + // Fall back to default vector field name; caller will throw a semantic index error if query fails. + } + + return DEFAULT_VECTOR_FIELD; + } + + private static IReadOnlyList ParseCandidates( + RedisResult rawResult, + string redisIndexType, + HashSet sourceColumns, + IReadOnlyList primaryKeyColumns, + double similarityThreshold) + { + RedisResult[]? rows = AsRedisArray(rawResult); + if (rawResult.IsNull || rows is null || rows.Length < 3) + { + return []; + } + + List results = []; + for (int i = 1; i + 1 < rows.Length; i += 2) + { + RedisResult[]? payload = AsRedisArray(rows[i + 1]); + if (payload is null) + { + continue; + } + + Dictionary extracted = ExtractDocumentFields(payload, redisIndexType); + double similarity = TryReadSimilarity(extracted); + + if (similarity < similarityThreshold) + { + continue; + } + + Dictionary sqlColumns = new(StringComparer.OrdinalIgnoreCase); + foreach ((string key, object? value) in extracted) + { + string normalized = NormalizeFieldName(key); + if (sourceColumns.Contains(normalized)) + { + sqlColumns[normalized] = value; + } + } + + if (sqlColumns.Count == 0) + { + continue; + } + + Dictionary primaryKeys = new(StringComparer.OrdinalIgnoreCase); + bool hasAllPrimaryKeys = true; + foreach (string pk in primaryKeyColumns) + { + if (!sqlColumns.TryGetValue(pk, out object? pkValue) || pkValue is null) + { + hasAllPrimaryKeys = false; + break; + } + + primaryKeys[pk] = pkValue; + } + + if (!hasAllPrimaryKeys) + { + continue; + } + + results.Add(new SemanticSearchCandidate(primaryKeys, sqlColumns, similarity)); + } + + return results; + } + + private static RedisResult[]? AsRedisArray(RedisResult result) + { + try + { + RedisResult[]? 
value = (RedisResult[]?)result; + return value; + } + catch + { + return null; + } + } + + private static Dictionary ExtractDocumentFields(RedisResult[] payload, string redisIndexType) + { + Dictionary fields = new(StringComparer.OrdinalIgnoreCase); + + for (int i = 0; i + 1 < payload.Length; i += 2) + { + string? key = payload[i].ToString(); + string? value = payload[i + 1].ToString(); + if (string.IsNullOrWhiteSpace(key)) + { + continue; + } + + if (string.Equals(redisIndexType, "json", StringComparison.OrdinalIgnoreCase) + && string.Equals(key, "$", StringComparison.Ordinal)) + { + MergeJsonDocumentFields(fields, value); + continue; + } + + fields[key] = value; + } + + return fields; + } + + private static void MergeJsonDocumentFields(Dictionary fields, string? json) + { + if (string.IsNullOrWhiteSpace(json)) + { + return; + } + + try + { + using JsonDocument doc = JsonDocument.Parse(json); + if (doc.RootElement.ValueKind is not JsonValueKind.Object) + { + return; + } + + foreach (JsonProperty property in doc.RootElement.EnumerateObject()) + { + fields[property.Name] = property.Value.ValueKind switch + { + JsonValueKind.Null => null, + JsonValueKind.String => property.Value.GetString(), + JsonValueKind.Number => property.Value.ToString(), + JsonValueKind.True => true, + JsonValueKind.False => false, + _ => property.Value.ToString() + }; + } + } + catch + { + // Ignore malformed JSON payloads and continue with other fields. + } + } + + private static double TryReadSimilarity(Dictionary fields) + { + if (!fields.TryGetValue(VECTOR_SCORE_FIELD, out object? scoreObj) + || scoreObj is null + || !double.TryParse(scoreObj.ToString(), NumberStyles.Float, CultureInfo.InvariantCulture, out double rawScore)) + { + return 0.0; + } + + // Redis KNN score is distance; normalize to similarity for semantic-threshold semantics. 
+ double similarity = 1.0 - rawScore; + if (similarity < 0.0) + { + similarity = 0.0; + } + + if (similarity > 1.0) + { + similarity = 1.0; + } + + return similarity; + } + + private static string NormalizeFieldName(string field) + { + string normalized = field.Trim(); + + if (normalized.StartsWith("$.", StringComparison.Ordinal)) + { + normalized = normalized[2..]; + } + + if (normalized.Contains('.', StringComparison.Ordinal)) + { + normalized = normalized[(normalized.LastIndexOf('.') + 1)..]; + } + + return normalized; + } +} diff --git a/src/Service/Startup.cs b/src/Service/Startup.cs index 876b63378e..e40d08801a 100644 --- a/src/Service/Startup.cs +++ b/src/Service/Startup.cs @@ -28,11 +28,13 @@ using Azure.DataApiBuilder.Core.Services.Embeddings; using Azure.DataApiBuilder.Core.Services.MetadataProviders; using Azure.DataApiBuilder.Core.Services.OpenAPI; +using Azure.DataApiBuilder.Core.Services.SemanticSearch; using Azure.DataApiBuilder.Core.Telemetry; using Azure.DataApiBuilder.Mcp.Core; using Azure.DataApiBuilder.Service.Controllers; using Azure.DataApiBuilder.Service.Exceptions; using Azure.DataApiBuilder.Service.HealthCheck; +using Azure.DataApiBuilder.Service.Services.SemanticSearch; using Azure.DataApiBuilder.Service.Telemetry; using Azure.DataApiBuilder.Service.Utilities; using Azure.Identity; @@ -315,6 +317,8 @@ public void ConfigureServices(IServiceCollection services) services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); + services.AddHttpClient(); + services.AddSingleton(); // ILogger explicit creation required for logger to use --LogLevel startup argument specified. services.AddSingleton>(implementationFactory: (serviceProvider) =>