|
25 | 25 | CompleteWorkflowExecutionDecisionAttributes, Decision, DecisionType, RespondDecisionTaskCompletedResponse, \ |
26 | 26 | HistoryEvent, EventType, WorkflowType, ScheduleActivityTaskDecisionAttributes, \ |
27 | 27 | CancelWorkflowExecutionDecisionAttributes, StartTimerDecisionAttributes, TimerFiredEventAttributes, \ |
28 | | - FailWorkflowExecutionDecisionAttributes, RecordMarkerDecisionAttributes, Header |
29 | | -from cadence.conversions import json_to_args |
| 28 | + FailWorkflowExecutionDecisionAttributes, RecordMarkerDecisionAttributes, Header, WorkflowQuery, \ |
| 29 | + RespondQueryTaskCompletedRequest, QueryTaskCompletedType, QueryWorkflowResponse |
| 30 | +from cadence.conversions import json_to_args, args_to_json |
30 | 31 | from cadence.decisions import DecisionId, DecisionTarget |
31 | 32 | from cadence.exception_handling import serialize_exception, deserialize_exception |
32 | 33 | from cadence.exceptions import WorkflowTypeNotFound, NonDeterministicWorkflowException, ActivityTaskFailedException, \ |
33 | | - ActivityTaskTimeoutException, SignalNotFound, ActivityFailureException |
| 34 | + ActivityTaskTimeoutException, SignalNotFound, ActivityFailureException, QueryNotFound, QueryDidNotComplete |
34 | 35 | from cadence.state_machines import ActivityDecisionStateMachine, DecisionStateMachine, CompleteWorkflowStateMachine, \ |
35 | 36 | TimerDecisionStateMachine, MarkerDecisionStateMachine |
36 | 37 | from cadence.tchannel import TChannelException |
37 | 38 | from cadence.worker import Worker |
| 39 | +from cadence.workflow import QueryMethod |
38 | 40 | from cadence.workflowservice import WorkflowService |
39 | 41 |
|
40 | 42 | logger = logging.getLogger(__name__) |
@@ -254,6 +256,47 @@ def get_workflow_instance(self): |
254 | 256 | return self.workflow_instance |
255 | 257 |
|
256 | 258 |
|
| 259 | +@dataclass |
| 260 | +class QueryMethodTask(ITask): |
| 261 | + task_id: str = None |
| 262 | + workflow_instance: object = None |
| 263 | + query_name: str = None |
| 264 | + query_input: List = None |
| 265 | + exception_thrown: BaseException = None |
| 266 | + ret_value: object = None |
| 267 | + |
| 268 | + def start(self): |
| 269 | + logger.debug(f"[query-task-{self.task_id}-{self.query_name}] Created") |
| 270 | + self.task = asyncio.get_event_loop().create_task(self.query_main()) |
| 271 | + |
| 272 | + async def query_main(self): |
| 273 | + logger.debug(f"[query-task-{self.task_id}-{self.query_name}] Running") |
| 274 | + current_task.set(self) |
| 275 | + |
| 276 | + if not self.query_name in self.workflow_instance._query_methods: |
| 277 | + self.status = Status.DONE |
| 278 | + self.exception_thrown = QueryNotFound(self.query_name) |
| 279 | + logger.error(f"Query not found: {self.query_name}") |
| 280 | + return |
| 281 | + |
| 282 | + query_proc = self.workflow_instance._query_methods[self.query_name] |
| 283 | + self.status = Status.RUNNING |
| 284 | + |
| 285 | + try: |
| 286 | + logger.info(f"Invoking query {self.query_name}({str(self.query_input)[1:-1]})") |
| 287 | + self.ret_value = await query_proc(self.workflow_instance, *self.query_input) |
| 288 | + logger.info( |
| 289 | + f"Query {self.query_name}({str(self.query_input)[1:-1]}) returned {self.ret_value}") |
| 290 | + except CancelledError: |
| 291 | + logger.debug("Coroutine cancelled (expected)") |
| 292 | + except Exception as ex: |
| 293 | + logger.error( |
| 294 | + f"Query {self.query_name}({str(self.query_input)[1:-1]}) failed", exc_info=1) |
| 295 | + self.exception_thrown = ex |
| 296 | + finally: |
| 297 | + self.status = Status.DONE |
| 298 | + |
| 299 | + |
257 | 300 | @dataclass |
258 | 301 | class SignalMethodTask(ITask): |
259 | 302 | task_id: str = None |
@@ -718,6 +761,28 @@ def handle_marker_recorded(self, event: HistoryEvent): |
718 | 761 | def get_optional_decision_event(self, event_id: int) -> HistoryEvent: |
719 | 762 | return self.decision_events.get_optional_decision_event(event_id) |
720 | 763 |
|
| 764 | + def query(self, decision_task: PollForDecisionTaskResponse, query: WorkflowQuery) -> bytes: |
| 765 | + query_args = query.query_args |
| 766 | + if query_args is None: |
| 767 | + args = [] |
| 768 | + else: |
| 769 | + args = json_to_args(query_args) |
| 770 | + task = QueryMethodTask(task_id=self.execution_id, |
| 771 | + workflow_instance=self.workflow_task.workflow_instance, |
| 772 | + query_name=query.query_type, |
| 773 | + query_input=args, |
| 774 | + decider=self) |
| 775 | + self.tasks.append(task) |
| 776 | + task.start() |
| 777 | + self.event_loop.run_event_loop_once() |
| 778 | + if task.status == Status.DONE: |
| 779 | + if task.exception_thrown: |
| 780 | + raise task.exception_thrown |
| 781 | + else: # ret_value might be None, need to put it in else |
| 782 | + return task.ret_value |
| 783 | + else: |
| 784 | + raise QueryDidNotComplete(f"Query method {query.query_type} with args {query.query_args} did not complete") |
| 785 | + |
721 | 786 |
|
722 | 787 | # noinspection PyUnusedLocal |
723 | 788 | def noop(*args): |
@@ -774,8 +839,16 @@ def run(self): |
774 | 839 | decision_task: PollForDecisionTaskResponse = self.poll() |
775 | 840 | if not decision_task: |
776 | 841 | continue |
777 | | - decisions = self.process_task(decision_task) |
778 | | - self.respond_decisions(decision_task.task_token, decisions) |
| 842 | + if decision_task.query: |
| 843 | + try: |
| 844 | + result = self.process_query(decision_task) |
| 845 | + self.respond_query(decision_task.task_token, result, None) |
| 846 | + except Exception as ex: |
| 847 | + logger.error("Error") |
| 848 | + self.respond_query(decision_task.task_token, None, serialize_exception(ex)) |
| 849 | + else: |
| 850 | + decisions = self.process_task(decision_task) |
| 851 | + self.respond_decisions(decision_task.task_token, decisions) |
779 | 852 | finally: |
780 | 853 | # noinspection PyPep8,PyBroadException |
781 | 854 | try: |
@@ -815,6 +888,32 @@ def process_task(self, decision_task: PollForDecisionTaskResponse) -> List[Decis |
815 | 888 | decider.destroy() |
816 | 889 | return decisions |
817 | 890 |
|
| 891 | + def process_query(self, decision_task: PollForDecisionTaskResponse) -> bytes: |
| 892 | + execution_id = str(decision_task.workflow_execution) |
| 893 | + decider = ReplayDecider(execution_id, decision_task.workflow_type, self.worker) |
| 894 | + decider.decide(decision_task.history.events) |
| 895 | + try: |
| 896 | + result = decider.query(decision_task, decision_task.query) |
| 897 | + return json.dumps(result) |
| 898 | + finally: |
| 899 | + decider.destroy() |
| 900 | + |
| 901 | + def respond_query(self, task_token: bytes, result: bytes = None, error_message: str = None): |
| 902 | + service = self.service |
| 903 | + request = RespondQueryTaskCompletedRequest() |
| 904 | + request.task_token = task_token |
| 905 | + if result: |
| 906 | + request.query_result = result |
| 907 | + request.completed_type = QueryTaskCompletedType.COMPLETED |
| 908 | + else: |
| 909 | + request.error_message = error_message |
| 910 | + request.completed_type = QueryTaskCompletedType.FAILED |
| 911 | + _, err = service.respond_query_task_completed(request) |
| 912 | + if err: |
| 913 | + logger.error("Error invoking RespondDecisionTaskCompleted: %s", err) |
| 914 | + else: |
| 915 | + logger.debug("RespondQueryTaskCompleted successful") |
| 916 | + |
818 | 917 | def respond_decisions(self, task_token: bytes, decisions: List[Decision]): |
819 | 918 | service = self.service |
820 | 919 | request = RespondDecisionTaskCompletedRequest() |
|
0 commit comments