Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,20 @@ def sort(self, *orders: stages.Ordering) -> "_BasePipeline":
"""
return self._append(stages.Sort(*orders))

def search(self, options: stages.SearchOptions) -> "_BasePipeline":
    """
    Appends a search stage to this pipeline.

    The stage filters documents according to the query carried by `options`.

    Args:
        options: The `SearchOptions` instance that configures the search.

    Returns:
        A new Pipeline object with this stage appended to the stage list
    """
    search_stage = stages.Search(options)
    return self._append(search_stage)

def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline":
"""
Performs a pseudo-random sampling of the documents from the previous stage.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,53 @@ def less_than_or_equal(
[self, self._cast_to_expr_or_convert_to_constant(other)],
)

@expose_as_static
def between(
    self, lower: Expression | float, upper: Expression | float
) -> "BooleanExpression":
    """Checks whether the result of this expression lies within a closed range.

    Both bounds are inclusive. The result is the `And` of
    `greater_than_or_equal(lower)` and `less_than_or_equal(upper)`.

    Example:
        >>> # Check if the 'age' field is between 18 and 65
        >>> Field.of("age").between(18, 65)

    Args:
        lower: Lower bound (inclusive) of the range.
        upper: Upper bound (inclusive) of the range.

    Returns:
        A new `BooleanExpression` representing the between comparison.
    """
    at_least_lower = self.greater_than_or_equal(lower)
    at_most_upper = self.less_than_or_equal(upper)
    return And(at_least_lower, at_most_upper)

@expose_as_static
def geo_distance(self, other: Expression | GeoPoint) -> "FunctionExpression":
    """Builds a `geo_distance` function expression: the distance in meters
    between the location held by this expression and the query location.

    Note: This Expression can only be used within a `Search` stage.

    Example:
        >>> # Calculate distance between the 'location' field and a target GeoPoint
        >>> Field.of("location").geo_distance(target_point)

    Args:
        other: Compute distance to this GeoPoint expression or constant value.

    Returns:
        A new `FunctionExpression` representing the distance.
    """
    # A raw GeoPoint is wrapped as a constant; an Expression is used as given.
    target = self._cast_to_expr_or_convert_to_constant(other)
    return FunctionExpression("geo_distance", [self, target])

@expose_as_static
def equal_any(
self, array: Array | Sequence[Expression | CONSTANT_TYPE] | Expression
Expand Down Expand Up @@ -2889,3 +2936,44 @@ class Rand(FunctionExpression):

def __init__(self):
    """Set up the `rand` function expression, which takes no arguments."""
    no_args = []
    super().__init__("rand", no_args, use_infix_repr=False)


class Score(FunctionExpression):
    """Search score reflecting how topical a document is with respect to all
    of the text predicates (`queryMatch`) in the search query.

    When `SearchOptions.query` is not set, or contains no text predicates,
    this topicality score is always `0`.

    Note: This Expression can only be used within a `Search` stage.

    Returns:
        A new `Expression` representing the score operation.
    """

    def __init__(self):
        # `score` is a zero-argument function expression.
        no_args = []
        super().__init__("score", no_args, use_infix_repr=False)


class DocumentMatches(BooleanExpression):
    """Boolean expression representing a document match query.

    Note: This Expression can only be used within a `Search` stage.

    Example:
        >>> # Find documents matching the query string
        >>> DocumentMatches("search query")

    Args:
        query: The search query string or expression.

    Returns:
        A new `BooleanExpression` representing the document match.
    """

    def __init__(self, query: Expression | str):
        # A plain string becomes a constant; an Expression is passed through.
        query_arg = Expression._cast_to_expr_or_convert_to_constant(query)
        super().__init__("document_matches", [query_arg], use_infix_repr=False)

Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
AliasedExpression,
BooleanExpression,
CONSTANT_TYPE,
DocumentMatches,
Expression,
Field,
Ordering,
Expand Down Expand Up @@ -109,6 +110,81 @@ def percentage(value: float):
return SampleOptions(value, mode=SampleOptions.Mode.PERCENT)


class QueryEnhancement(Enum):
    """Query expansion behavior used by full-text search expressions.

    Each member's value is the lowercase string form of its name, so a member
    can be looked up from that string, e.g. `QueryEnhancement("required")`.
    """

    DISABLED = "disabled"
    REQUIRED = "required"
    PREFERRED = "preferred"


class SearchOptions:
    """Options for configuring the `Search` pipeline stage."""

    def __init__(
        self,
        query: str | BooleanExpression,
        limit: Optional[int] = None,
        retrieval_depth: Optional[int] = None,
        sort: Optional[Sequence[Ordering] | Ordering] = None,
        add_fields: Optional[Sequence[Selectable]] = None,
        select: Optional[Sequence[Selectable | str]] = None,
        offset: Optional[int] = None,
        query_enhancement: Optional[str | QueryEnhancement] = None,
        language_code: Optional[str] = None,
    ):
        """
        Initializes a SearchOptions instance.

        Args:
            query (str | BooleanExpression): Specifies the search query that will be used to query and score
                documents by the search stage. The query can be expressed as an `Expression`, which will be
                used to score and filter the results. Not all expressions supported by Pipelines are supported
                in the Search query. The query can also be expressed as a string in the Search DSL; a string
                is wrapped in `DocumentMatches`.
            limit (Optional[int]): The maximum number of documents to return from the Search stage.
            retrieval_depth (Optional[int]): The maximum number of documents for the search stage to score.
                Documents will be processed in the pre-sort order specified by the search index.
            sort (Optional[Sequence[Ordering] | Ordering]): Orderings specify how the input documents are
                sorted. A single `Ordering` is stored as a one-element list.
            add_fields (Optional[Sequence[Selectable]]): The fields to add to each document, specified as a
                `Selectable`.
            select (Optional[Sequence[Selectable | str]]): The fields to keep or add to each document,
                specified as an array of `Selectable` or strings. Strings are converted to `Field`. A single
                bare string is treated as one field name.
            offset (Optional[int]): The number of documents to skip.
            query_enhancement (Optional[str | QueryEnhancement]): Define the query expansion behavior used by
                full-text search expressions in this search stage. A string is lowercased and converted to the
                matching `QueryEnhancement` member.
            language_code (Optional[str]): The BCP-47 language code of text in the search query, such as
                "en-US" or "sr-Latn".

        Raises:
            ValueError: If `query_enhancement` is a string that does not match any
                `QueryEnhancement` value.
        """
        self.query = DocumentMatches(query) if isinstance(query, str) else query
        self.limit = limit
        self.retrieval_depth = retrieval_depth
        self.sort = [sort] if isinstance(sort, Ordering) else sort
        self.add_fields = add_fields

        if isinstance(select, str):
            # A bare string is itself iterable; without this guard "title"
            # would be split into one Field per character.
            select = [select]
        if select is None:
            self.select = None
        else:
            self.select = [
                Field(item) if isinstance(item, str) else item for item in select
            ]

        self.offset = offset
        if isinstance(query_enhancement, str):
            # Accept the string form case-insensitively.
            query_enhancement = QueryEnhancement(query_enhancement.lower())
        self.query_enhancement = query_enhancement
        self.language_code = language_code

    def __repr__(self):
        """Return a constructor-style string listing `query` and every option that is set."""
        parts = [f"query={self.query!r}"]
        # (attribute name, format with repr()) in constructor order; options
        # left as None are omitted.
        optional_attrs = (
            ("limit", False),
            ("retrieval_depth", False),
            ("sort", False),
            ("add_fields", False),
            ("select", False),
            ("offset", False),
            ("query_enhancement", True),
            ("language_code", True),
        )
        for attr, as_repr in optional_attrs:
            value = getattr(self, attr)
            if value is None:
                continue
            parts.append(f"{attr}={value!r}" if as_repr else f"{attr}={value}")
        return f"{self.__class__.__name__}({', '.join(parts)})"


class UnnestOptions:
"""Options for configuring the `Unnest` pipeline stage.

Expand Down Expand Up @@ -423,6 +499,39 @@ def _pb_args(self):
]


class Search(Stage):
    """Pipeline stage named "search", configured by a `SearchOptions` instance."""

    def __init__(self, options: SearchOptions):
        super().__init__("search")
        self.options = options

    def _pb_args(self) -> list[Value]:
        # This stage has no positional arguments; everything is sent through
        # `_pb_options`.
        return []

    def _pb_options(self) -> dict[str, Value]:
        """Build the option map for this stage, skipping options that are None."""
        cfg = self.options
        pb_options = {}

        # NOTE(review): `query` is a required SearchOptions argument, so this
        # guard only takes effect when the attribute is nevertheless None.
        if cfg.query is not None:
            pb_options["query"] = cfg.query._to_pb()

        if cfg.limit is not None:
            pb_options["limit"] = Value(integer_value=cfg.limit)

        if cfg.retrieval_depth is not None:
            pb_options["retrieval_depth"] = Value(integer_value=cfg.retrieval_depth)

        if cfg.sort is not None:
            ordering_values = [ordering._to_pb() for ordering in cfg.sort]
            pb_options["sort"] = Value(array_value={"values": ordering_values})

        if cfg.add_fields is not None:
            pb_options["add_fields"] = Selectable._to_value(cfg.add_fields)

        if cfg.select is not None:
            pb_options["select"] = Selectable._to_value(cfg.select)

        if cfg.offset is not None:
            pb_options["offset"] = Value(integer_value=cfg.offset)

        if cfg.query_enhancement is not None:
            # The enum's string value is what gets sent.
            pb_options["query_enhancement"] = Value(
                string_value=cfg.query_enhancement.value
            )

        if cfg.language_code is not None:
            pb_options["language_code"] = Value(string_value=cfg.language_code)

        return pb_options


class Select(Stage):
"""Selects or creates a set of fields."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,14 @@ data:
doc_with_nan:
value: "NaN"
doc_with_null:
value: null
value: null
geopoints:
loc1:
name: SF
location: GEOPOINT(37.7749,-122.4194)
loc2:
name: LA
location: GEOPOINT(34.0522,-118.2437)
loc3:
name: NY
location: GEOPOINT(40.7128,-74.0060)
Loading
Loading