|
1 | 1 | from logging import getLogger |
2 | | -from typing import Any, List |
| 2 | +from typing import Any, List, Optional |
3 | 3 |
|
4 | 4 | import pydantic |
5 | 5 | from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult |
@@ -108,23 +108,28 @@ async def on_error( |
108 | 108 | return |
109 | 109 | if current_step_num == len(steps) - 1: |
110 | 110 | return |
111 | | - await self.fail_pipeline(steps[-1].task_id) |
| 111 | + await self.fail_pipeline(steps[-1].task_id, result.error) |
112 | 112 |
|
113 | | - async def fail_pipeline(self, last_task_id: str) -> None: |
| 113 | + async def fail_pipeline( |
| 114 | + self, |
| 115 | + last_task_id: str, |
| 116 | + abort: Optional[BaseException] = None, |
| 117 | + ) -> None: |
114 | 118 | """ |
115 | 119 | This function aborts pipeline. |
116 | 120 |
|
117 | 121 | This is done by setting error result for |
118 | 122 | the last task in the pipeline. |
119 | 123 |
|
120 | 124 | :param last_task_id: id of the last task. |
| 125 | + :param abort: caught earlier exception or default |
121 | 126 | """ |
122 | 127 | await self.broker.result_backend.set_result( |
123 | 128 | last_task_id, |
124 | 129 | TaskiqResult( |
125 | 130 | is_err=True, |
126 | 131 | return_value=None, # type: ignore |
127 | | - error=AbortPipeline("Execution aborted."), |
| 132 | + error=abort or AbortPipeline("Execution aborted."), |
128 | 133 | execution_time=0, |
129 | 134 | log="Error found while executing pipeline.", |
130 | 135 | ), |
|
0 commit comments