Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
Expression,
Field,
Selectable,
FunctionExpression,
_PipelineValueExpression,
)
from google.cloud.firestore_v1.types.pipeline import (
StructuredPipeline as StructuredPipeline_pb,
Expand Down Expand Up @@ -90,6 +92,51 @@ def _to_pb(self, **options) -> StructuredPipeline_pb:
options=options,
)

def to_array_expression(self) -> Expression:
"""
Converts this Pipeline into an expression that evaluates to an array of results.
Used for embedding 1:N subqueries into stages like `addFields`.

Example:
>>> # Get a list of all reviewer names for each book
>>> db.pipeline().collection("books").define(Field.of("id").as_("book_id")).add_fields(
... db.pipeline().collection("reviews")
... .where(Field.of("book_id").equal(Variable("book_id")))
... .select(Field.of("reviewer").as_("name"))
... .to_array_expression().as_("reviewers")
... )

Returns:
An :class:`Expression` representing the execution of this pipeline.
"""
return FunctionExpression("array", [_PipelineValueExpression(self)])

def to_scalar_expression(self) -> Expression:
"""
Converts this Pipeline into an expression that evaluates to a single scalar result.
Used for 1:1 lookups or Aggregations when the subquery is expected to return a single value or object.

**Result Unwrapping:**
For simpler access, scalar subqueries producing a single field automatically unwrap that value to the
top level, ignoring the inner alias. If the subquery returns multiple fields, they are preserved as a map.

Example:
>>> # Calculate average rating for each restaurant using a subquery
>>> db.pipeline().collection("restaurants").define(Field.of("id").as_("rid")).add_fields(
... db.pipeline().collection("reviews")
... .where(Field.of("restaurant_id").equal(Variable("rid")))
... .aggregate(AggregateFunction.average("rating").as_("value"))
... .to_scalar_expression().as_("average_rating")
... )

Raises:
RuntimeError: If the result set contains more than one item. If the pipeline has zero results, it evaluates to `null` instead of raising an error.

Returns:
An :class:`Expression` representing the execution of this pipeline.
"""
return FunctionExpression("scalar", [_PipelineValueExpression(self)])

def _append(self, new_stage):
"""
Create a new Pipeline object with a new stage appended
Expand Down Expand Up @@ -610,3 +657,28 @@ def distinct(self, *fields: str | Selectable) -> "_BasePipeline":
A new Pipeline object with this stage appended to the stage list
"""
return self._append(stages.Distinct(*fields))

def define(self, *aliased_expressions: AliasedExpression) -> "_BasePipeline":
"""
Binds one or more expressions to Variables that can be accessed in subsequent stages
or inner subqueries using `Variable`.

Each Variable is defined using an :class:`AliasedExpression`, which pairs an expression with
a name (alias).

Example:
>>> db.pipeline().collection("products").define(
... Field.of("price").multiply(0.9).as_("discountedPrice"),
... Field.of("stock").add(10).as_("newStock")
... ).where(
... Variable("discountedPrice").less_than(100)
... ).select(Field.of("name"), Variable("newStock"))

Args:
*aliased_expressions: One or more :class:`AliasedExpression` defining the Variable names and values.

Returns:
A new Pipeline object with this stage appended to the stage list.
"""
return self._append(stages.Define(*aliased_expressions))

Original file line number Diff line number Diff line change
Expand Up @@ -2634,3 +2634,53 @@ class Rand(FunctionExpression):

def __init__(self):
super().__init__("rand", [], use_infix_repr=False)


class Variable(Expression):
"""
Creates an expression that retrieves the value of a variable bound via `Pipeline.define`.

Example:
>>> # Define a variable "discountedPrice" and use it in a filter
>>> db.pipeline().collection("products").define(
... Field.of("price").multiply(0.9).as_("discountedPrice")
... ).where(Variable("discountedPrice").less_than(100))

Args:
name: The name of the variable to retrieve.
"""

def __init__(self, name: str):
self.name = name

def _to_pb(self) -> Value:
return Value(variable_reference_value=self.name)


class _PipelineValueExpression(Expression):
"""Internal wrapper to represent a pipeline as an expression."""

def __init__(self, pipeline):
Comment thread
daniel-sanche marked this conversation as resolved.
Outdated
self.pipeline = pipeline

def _to_pb(self) -> Value:
return Value(pipeline_value=self.pipeline._to_pb())


class CurrentDocument(FunctionExpression):
"""
Creates an expression that represents the current document being processed.

This acts as a handle, allowing you to bind the entire document to a variable or pass the
document itself to a function or subquery.

Example:
>>> # Define the current document as a variable "doc"
>>> db.pipeline().collection("books").define(
... CurrentDocument().as_("doc")
... ).select(Variable("doc").get_field("title"))
"""

def __init__(self):
super().__init__("current_document", [])

Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,25 @@ def literals(
A new Pipeline object with this stage appended to the stage list.
"""
return self._create_pipeline(stages.Literals(*documents))

def subcollection(self, path: str) -> PipelineType:
"""
Creates a new Pipeline targeted at a subcollection relative to the current document context.

This is used inside stages like `addFields` to query physically nested subcollections
without manually joining on IDs.

Example:
>>> db.pipeline().collection("books").add_fields(
... db.pipeline().subcollection("reviews")
... .aggregate(AggregateFunction.average("rating").as_("avg_rating"))
... .to_scalar_expression().as_("average_rating")
... )

Args:
path: The path of the subcollection.

Returns:
A new :class:`Pipeline` instance scoped to the subcollection.
"""
return self._create_pipeline(stages.Subcollection(path))
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,26 @@ def __init__(self, condition: BooleanExpression):

def _pb_args(self):
return [self.condition._to_pb()]


class Define(Stage):
"""Binds one or more expressions to variables."""

def __init__(self, *expressions: AliasedExpression):
super().__init__("let")
self.expressions = list(expressions)

def _pb_args(self) -> list[Value]:
return [Selectable._to_value(self.expressions)]


class Subcollection(Stage):
"""Targets a subcollection relative to the current document."""

def __init__(self, path: str):
super().__init__("subcollection")
self.path = path

def _pb_args(self) -> list[Value]:
return [encode_value(self.path)]

Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
tests:
- description: array_subquery_with_variable
pipeline:
- Collection: publishers
- Where:
- Field.of.equal:
- publisherId
- pub1
- Define:
- Field.of.as_:
- publisherId
- pub_id
- AddFields:
- Pipeline.to_array_expression:
- Collection: books
- Where:
- Field.of.equal:
- publisherId
- Variable: pub_id
as_: books
- Select:
- Field: name
- Field: books
assert_results:
- name: Publisher 1
books:
- title: The Hitchhiker's Guide to the Galaxy
author: Douglas Adams
- title: Pride and Prejudice
author: Jane Austen

- description: scalar_subquery_with_current_document
pipeline:
- Collection: books
- Where:
- Field.of.equal:
- title
- 1984
- Define:
- CurrentDocument.as_: doc
- AddFields:
- Pipeline.to_scalar_expression:
- Collection: reviews
- Where:
- Field.of.equal:
- bookId
- Variable.get_field:
- doc
- __name__
- Aggregate:
- AggregateFunction.average.as_:
- rating
- avg_rating
as_: average_rating
assert_results:
- title: 1984
average_rating: 4.5
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ def _parse_expressions(client, yaml_element: Any):
# find Pipeline objects for Union expressions
other_ppl = yaml_element["Pipeline"]
return parse_pipeline(client, other_ppl)
elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline.to_array_expression":
other_ppl = yaml_element["Pipeline.to_array_expression"]
return parse_pipeline(client, other_ppl).to_array_expression()
elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline.to_scalar_expression":
other_ppl = yaml_element["Pipeline.to_scalar_expression"]
return parse_pipeline(client, other_ppl).to_scalar_expression()
else:
# otherwise, return dict
return {
Expand Down
Loading