112 changes: 62 additions & 50 deletions pixi.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -146,7 +146,7 @@ test = [
"pytest-benchmark>=3.4.1,<4",
"pytest-asyncio<=0.24.0",
"py>=1.11.0",
"testcontainers==4.9.0",
"testcontainers>=4.15.0rc2",
"minio==7.2.11",
"python-keycloak==4.2.2",
"cryptography>=43.0",
305 changes: 298 additions & 7 deletions sdk/python/feast/infra/online_stores/mongodb_online_store/mongodb.py
@@ -1,28 +1,33 @@
from __future__ import annotations

import time
from datetime import datetime
from logging import getLogger
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

try:
from pymongo import AsyncMongoClient, MongoClient, UpdateOne
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.operations import SearchIndexModel
except ImportError as e:
from feast.errors import FeastExtrasDependencyImportError

raise FeastExtrasDependencyImportError("mongodb", str(e))

import feast.version
from feast.batch_feature_view import BatchFeatureView
from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.key_encoding_utils import deserialize_entity_key, serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.online_stores.vector_store import VectorStoreConfig
from feast.infra.supported_async_methods import SupportedAsyncMethods
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.stream_feature_view import StreamFeatureView
from feast.type_map import (
feast_value_type_to_python_type,
python_values_to_proto_values,
@@ -33,7 +38,7 @@
DRIVER_METADATA = DriverInfo(name="Feast", version=feast.version.get_version())


class MongoDBOnlineStoreConfig(FeastConfigBaseModel):
class MongoDBOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
"""MongoDB configuration.

For a description of kwargs that may be passed to MongoClient,
@@ -48,6 +53,11 @@ class MongoDBOnlineStoreConfig(FeastConfigBaseModel):
)
collection_suffix: str = "latest"
client_kwargs: Dict[str, Any] = {}
# vector_enabled and similarity are inherited from VectorStoreConfig
vector_index_wait_timeout: int = 60
"""Seconds to wait for a newly created Atlas Search index to become READY."""
vector_index_wait_poll_interval: float = 1.0
"""Seconds between polls when waiting for an Atlas Search index to become READY."""


class MongoDBOnlineStore(OnlineStore):
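
As an aside for reviewers: a minimal sketch of the new configuration surface (not from this diff; any required connection fields are omitted, and the values shown are illustrative):

```python
from feast.infra.online_stores.mongodb_online_store.mongodb import (
    MongoDBOnlineStoreConfig,
)

# Illustrative values; any required connection fields are omitted here.
config = MongoDBOnlineStoreConfig(
    collection_suffix="latest",
    client_kwargs={},
    vector_enabled=True,                  # inherited from VectorStoreConfig
    similarity="cosine",                  # inherited from VectorStoreConfig
    vector_index_wait_timeout=120,        # seconds; raise on small Atlas tiers
    vector_index_wait_poll_interval=2.0,  # seconds between status polls
)
```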
@@ -213,11 +223,135 @@ def online_read(

return self._convert_raw_docs_to_proto(ids, docs, table)

def retrieve_online_documents_v2(
self,
config: RepoConfig,
table: FeatureView,
requested_features: List[str],
embedding: Optional[List[float]],
top_k: int,
distance_metric: Optional[str] = None,
query_string: Optional[str] = None,
include_feature_view_version_metadata: bool = False,
) -> List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[Dict[str, ValueProto]],
]
]:
"""Retrieve documents via MongoDB Atlas Vector Search ($vectorSearch).

Uses the ``$vectorSearch`` aggregation stage to find the *top_k*
documents closest to the provided *embedding* vector. The method
expects that an Atlas vector search index has already been created for
the relevant field (see ``update()``).

Returns a list of 3-tuples ``(event_timestamp, entity_key_proto,
feature_dict)`` where *feature_dict* includes the requested feature
values plus a synthetic ``distance`` key with the vector search score.
"""
if not config.online_store.vector_enabled:
raise ValueError(
"Vector search is not enabled in the online store config. "
"Set vector_enabled=True in MongoDBOnlineStoreConfig."
)
if embedding is None:
raise ValueError(
"An embedding vector must be provided for MongoDB vector search."
)

clxn = self._get_collection(config)

# Identify the vector field on this feature view
vector_fields = [f for f in table.features if f.vector_index]
if not vector_fields:
raise ValueError(
f"Feature view '{table.name}' has no fields with vector_index=True."
)
vector_field = vector_fields[0]
path = f"features.{table.name}.{vector_field.name}"
idx_name = self._vector_search_index_name(table.name, vector_field.name)

# BSON cannot encode numpy float types — ensure native Python floats.
query_vector = [float(v) for v in embedding]

num_candidates = max(top_k * 10, 100)
pipeline: List[Dict[str, Any]] = [
{
"$vectorSearch": {
"index": idx_name,
"path": path,
"queryVector": query_vector,
"numCandidates": num_candidates,
"limit": top_k,
}
},
{
"$addFields": {
"score": {"$meta": "vectorSearchScore"},
}
},
]

results: List[
Tuple[
Optional[datetime],
Optional[EntityKeyProto],
Optional[Dict[str, ValueProto]],
]
] = []

for doc in clxn.aggregate(pipeline):
# Deserialize entity key
entity_key_bin = doc.get("_id")
entity_key_proto = (
deserialize_entity_key(
entity_key_bin,
entity_key_serialization_version=config.entity_key_serialization_version,
)
if entity_key_bin
else None
)

# Event timestamp
event_ts = doc.get("event_timestamps", {}).get(table.name)

# Build feature dict from raw doc values
fv_features = doc.get("features", {}).get(table.name, {})

# Convert raw values → ValueProto for each requested feature
feature_dict: Dict[str, ValueProto] = {}
feature_type_map = {f.name: f.dtype.to_value_type() for f in table.features}

for feat_name in requested_features:
raw_val = fv_features.get(feat_name)
if raw_val is not None:
vtype = feature_type_map.get(feat_name)
if vtype is not None:
protos = python_values_to_proto_values(
[raw_val], feature_type=vtype
)
feature_dict[feat_name] = protos[0]
else:
# Fall back: try to store as-is
feature_dict[feat_name] = _python_value_to_proto(raw_val)

# Add distance (vector search score)
score = doc.get("score", 0.0)
feature_dict["distance"] = ValueProto(float_val=float(score))

results.append((event_ts, entity_key_proto, feature_dict))

return results
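
A usage sketch for reviewers (not part of this diff): the method is normally reached through Feast's high-level document-retrieval API. The feature view and field names below are hypothetical, and the exact high-level signature may vary by Feast version:

```python
from feast import FeatureStore

store = FeatureStore(repo_path=".")

# "docs" is a hypothetical feature view with a vector_index=True embedding field.
response = store.retrieve_online_documents_v2(
    features=["docs:embedding", "docs:text"],
    query=[0.1, 0.2, 0.3, 0.4],  # query embedding as native Python floats
    top_k=5,
)
print(response.to_dict())  # includes the synthetic "distance" key per result
```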

def update(
self,
config: RepoConfig,
tables_to_delete: Sequence[FeatureView],
tables_to_keep: Sequence[FeatureView],
tables_to_keep: Sequence[
Union[BatchFeatureView, StreamFeatureView, FeatureView]
],
entities_to_delete: Sequence[Entity],
entities_to_keep: Sequence[Entity],
partial: bool,
@@ -238,20 +372,32 @@ def update(
}
We remove any feature views named in tables_to_delete.
The entities are serialized into the _id. No schema needs to be adjusted.

When ``vector_enabled`` is set in the online store config, Atlas Vector
Search indexes are automatically created for feature views containing
fields with ``vector_index=True`` and dropped for deleted feature views.
"""
if not isinstance(config.online_store, MongoDBOnlineStoreConfig):
raise RuntimeError(f"{config.online_store.type = }. It must be mongodb.")

online_config = config.online_store
clxn = self._get_collection(repo_config=config)

# --- Remove deleted feature views (data + vector search indexes) ---
if tables_to_delete:
unset_fields = {}
for fv in tables_to_delete:
unset_fields[f"features.{fv.name}"] = ""
unset_fields[f"event_timestamps.{fv.name}"] = ""

clxn.update_many({}, {"$unset": unset_fields})

if online_config.vector_enabled:
self._drop_vector_indexes_for_tables(clxn, tables_to_delete)

# --- Create vector search indexes for kept feature views ---
if online_config.vector_enabled:
self._ensure_vector_indexes(clxn, tables_to_keep, online_config)

# Note: entities_to_delete contains Entity definitions (metadata), not entity instances.
# Like other online stores, we don't need to do anything with entities_to_delete here.

@@ -286,6 +432,136 @@ async def close(self) -> None:
# Helpers
# ------------------------------------------------------------------

# -- Vector Search helpers ------------------------------------------

@staticmethod
def _vector_search_index_name(fv_name: str, field_name: str) -> str:
"""Canonical Atlas vector search index name for a (feature_view, field) pair."""
return f"{fv_name}__{field_name}__vs_index"

@staticmethod
def _vector_search_index_definition(
path: str,
num_dimensions: int,
similarity: str,
) -> dict:
"""Return a vector search index definition for ``SearchIndexModel``."""
return {
"fields": [
{
"type": "vector",
"path": path,
"numDimensions": num_dimensions,
"similarity": similarity,
}
]
}

@staticmethod
def _wait_for_index_ready(
collection: Collection,
index_name: str,
timeout: int,
poll_interval: float = 1.0,
) -> None:
"""Poll until the named Atlas Search index reaches READY status.

Raises ``TimeoutError`` if the index does not become queryable
within *timeout* seconds.
"""
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
for idx in collection.list_search_indexes(name=index_name):
if idx.get("status") == "READY" or idx.get("queryable") is True:
return
time.sleep(poll_interval)
raise TimeoutError(
f"Atlas search index '{index_name}' did not reach READY "
f"within {timeout}s. Increase vector_index_wait_timeout in "
f"MongoDBOnlineStoreConfig if the index needs more time."
)

def _drop_vector_indexes_for_tables(
self,
collection: Collection,
tables: Sequence[FeatureView],
) -> None:
"""Drop all Atlas vector search indexes belonging to the given feature views."""
existing = {idx["name"] for idx in collection.list_search_indexes()}
for fv in tables:
for field in fv.features:
idx_name = self._vector_search_index_name(fv.name, field.name)
if idx_name in existing:
logger.info("Dropping vector search index: %s", idx_name)
collection.drop_search_index(idx_name)

def _ensure_vector_indexes(
self,
collection: Collection,
tables: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]],
online_config: MongoDBOnlineStoreConfig,
) -> None:
"""Create Atlas vector search indexes for vector-indexed fields if they don't exist.

Currently creates one index per (feature_view, vector_field) pair.
A future optimization could consolidate all vector fields into a
single composite index with multiple field definitions, reducing
cluster-wide index count and memory overhead. See Atlas limits:
hard cap of 2,500 search indexes per cluster; smaller tiers (M10)
may degrade well before that.
"""
db = collection.database
if collection.name not in db.list_collection_names():
db.create_collection(collection.name)

existing = {idx["name"] for idx in collection.list_search_indexes()}

for fv in tables:
vector_fields = [f for f in fv.features if f.vector_index]
for field in vector_fields:
idx_name = self._vector_search_index_name(fv.name, field.name)
if idx_name in existing:
logger.debug("Vector search index '%s' already exists", idx_name)
continue

path = f"features.{fv.name}.{field.name}"
num_dimensions = field.vector_length
if not num_dimensions:
raise ValueError(
f"Field '{field.name}' in feature view '{fv.name}' has "
f"vector_index=True but vector_length is not set."
)

similarity = (
field.vector_search_metric or online_config.similarity or "cosine"
)

definition = self._vector_search_index_definition(
path, num_dimensions, similarity
)
search_index_model = SearchIndexModel(
definition=definition,
name=idx_name,
type="vectorSearch",
)
logger.info(
"Creating Atlas vector search index '%s' on path '%s' "
"(dims=%d, similarity=%s)",
idx_name,
path,
num_dimensions,
similarity,
)
collection.create_search_index(model=search_index_model)
self._wait_for_index_ready(
collection,
idx_name,
online_config.vector_index_wait_timeout,
online_config.vector_index_wait_poll_interval,
)
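
To make the helper's output concrete: for a hypothetical feature view `docs` with a 384-dimensional `embedding` field and cosine similarity, the code above submits the equivalent of:

```python
from pymongo.operations import SearchIndexModel

# Equivalent of _vector_search_index_definition(...) plus the model wrapper,
# for fv.name="docs", field.name="embedding" (hypothetical names).
SearchIndexModel(
    definition={
        "fields": [
            {
                "type": "vector",
                "path": "features.docs.embedding",
                "numDimensions": 384,
                "similarity": "cosine",
            }
        ]
    },
    name="docs__embedding__vs_index",
    type="vectorSearch",
)
```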

# -- Connection helpers ---------------------------------------------

def _get_client(self, config: RepoConfig):
"""Returns a connection to the server."""
online_store_config = config.online_store
@@ -494,5 +770,20 @@ async def online_write_batch_async(
progress(len(data))


# TODO
# - Vector Search (requires atlas image in testcontainers or similar)
def _python_value_to_proto(value: Any) -> ValueProto:
"""Best-effort conversion of a single Python value to ValueProto."""
if isinstance(value, float):
return ValueProto(float_val=value)
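# bool is a subclass of int in Python, so it must be checked before int.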
elif isinstance(value, bool):
return ValueProto(bool_val=value)
elif isinstance(value, int):
return ValueProto(int64_val=value)
elif isinstance(value, str):
return ValueProto(string_val=value)
elif isinstance(value, bytes):
return ValueProto(bytes_val=value)
elif isinstance(value, list) and all(isinstance(v, float) for v in value):
proto = ValueProto()
proto.float_list_val.val.extend(value)
return proto
return ValueProto()
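
A quick illustrative check of the fallback (not from the PR's test suite). Note that `float_list_val.val` is a repeated 32-bit float, so only exactly representable values round-trip unchanged:

```python
p = _python_value_to_proto([1.0, 2.5])
assert list(p.float_list_val.val) == [1.0, 2.5]   # exactly representable in float32
assert _python_value_to_proto(True).bool_val is True  # hits the bool branch, not int64
assert _python_value_to_proto(7).int64_val == 7
```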