|
3 | 3 |
|
4 | 4 | from beanie.odm.enums import SortDirection |
5 | 5 | from beanie.odm.operators.find import BaseFindOperator |
6 | | -from beanie.operators import GT, LT, In |
| 6 | +from beanie.operators import GT, LT, NE, Eq, In |
7 | 7 | from monggregate import Pipeline, S |
8 | 8 |
|
9 | 9 | from app.db.docs import ExecutionDocument, SagaDocument |
|
14 | 14 | class SagaRepository: |
15 | 15 | def _filter_conditions(self, saga_filter: SagaFilter) -> list[BaseFindOperator]: |
16 | 16 | """Build Beanie query conditions from SagaFilter.""" |
17 | | - conditions = [ |
18 | | - SagaDocument.state == saga_filter.state if saga_filter.state else None, |
19 | | - In(SagaDocument.execution_id, saga_filter.execution_ids) if saga_filter.execution_ids else None, |
20 | | - SagaDocument.context_data["user_id"] == saga_filter.user_id if saga_filter.user_id else None, |
21 | | - SagaDocument.saga_name == saga_filter.saga_name if saga_filter.saga_name else None, |
22 | | - GT(SagaDocument.created_at, saga_filter.created_after) if saga_filter.created_after else None, |
23 | | - LT(SagaDocument.created_at, saga_filter.created_before) if saga_filter.created_before else None, |
24 | | - ] |
| 17 | + conditions: list[BaseFindOperator] = [] |
| 18 | + if saga_filter.state: |
| 19 | + conditions.append(Eq(SagaDocument.state, saga_filter.state)) |
| 20 | + if saga_filter.execution_ids: |
| 21 | + conditions.append(In(SagaDocument.execution_id, saga_filter.execution_ids)) |
| 22 | + if saga_filter.user_id: |
| 23 | + conditions.append(Eq(SagaDocument.context_data["user_id"], saga_filter.user_id)) |
| 24 | + if saga_filter.saga_name: |
| 25 | + conditions.append(Eq(SagaDocument.saga_name, saga_filter.saga_name)) |
| 26 | + if saga_filter.created_after: |
| 27 | + conditions.append(GT(SagaDocument.created_at, saga_filter.created_after)) |
| 28 | + if saga_filter.created_before: |
| 29 | + conditions.append(LT(SagaDocument.created_at, saga_filter.created_before)) |
25 | 30 | if saga_filter.error_status is True: |
26 | | - conditions.append(SagaDocument.error_message != None) # noqa: E711 |
| 31 | + conditions.append(NE(SagaDocument.error_message, None)) |
27 | 32 | elif saga_filter.error_status is False: |
28 | | - conditions.append(SagaDocument.error_message == None) # noqa: E711 |
29 | | - return [c for c in conditions if c is not None] |
| 33 | + conditions.append(Eq(SagaDocument.error_message, None)) |
| 34 | + return conditions |
30 | 35 |
|
31 | 36 | async def upsert_saga(self, saga: Saga) -> bool: |
32 | 37 | existing = await SagaDocument.find_one(SagaDocument.saga_id == saga.saga_id) |
@@ -55,11 +60,9 @@ async def get_saga(self, saga_id: str) -> Saga | None: |
55 | 60 | async def get_sagas_by_execution( |
56 | 61 | self, execution_id: str, state: SagaState | None = None, limit: int = 100, skip: int = 0 |
57 | 62 | ) -> SagaListResult: |
58 | | - conditions = [ |
59 | | - SagaDocument.execution_id == execution_id, |
60 | | - SagaDocument.state == state if state else None, |
61 | | - ] |
62 | | - conditions = [c for c in conditions if c is not None] |
| 63 | + conditions: list[BaseFindOperator] = [Eq(SagaDocument.execution_id, execution_id)] |
| 64 | + if state: |
| 65 | + conditions.append(Eq(SagaDocument.state, state)) |
63 | 66 |
|
64 | 67 | query = SagaDocument.find(*conditions) |
65 | 68 | total = await query.count() |
@@ -135,10 +138,10 @@ async def get_saga_statistics(self, saga_filter: SagaFilter | None = None) -> di |
135 | 138 | states[doc["_id"]] = doc["count"] |
136 | 139 |
|
137 | 140 | # Average duration for completed sagas |
138 | | - completed_conditions = [ |
| 141 | + completed_conditions: list[BaseFindOperator] = [ |
139 | 142 | *conditions, |
140 | | - SagaDocument.state == SagaState.COMPLETED, |
141 | | - SagaDocument.completed_at != None, # noqa: E711 |
| 143 | + Eq(SagaDocument.state, SagaState.COMPLETED), |
| 144 | + NE(SagaDocument.completed_at, None), |
142 | 145 | ] |
143 | 146 | duration_pipeline = ( |
144 | 147 | Pipeline() |
|
0 commit comments