Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions azure/durable_functions/models/DurableOrchestrationClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,44 @@ async def suspend(self, instance_id: str, reason: str) -> None:
if error_message:
raise Exception(error_message)

async def restart(self, instance_id: str,
Comment thread
nytian marked this conversation as resolved.
restart_with_new_instance_id: bool = True) -> str:
"""Restart an orchestration instance with its original input.

Parameters
----------
instance_id : str
The ID of the orchestration instance to restart.
restart_with_new_instance_id : bool
If True, the restarted instance will use a new instance ID.
If False, the restarted instance will reuse the original instance ID.

Raises
------
Exception:
When the instance with the given ID is not found.

Returns
-------
str
The instance ID of the restarted orchestration.
"""
status = await self.get_status(instance_id, show_input=True)

if not status or status.name is None:
raise Exception(
f"An orchestration with the instanceId {instance_id} was not found.")
Comment thread
nytian marked this conversation as resolved.
Outdated

if restart_with_new_instance_id:
return await self.start_new(
orchestration_function_name=status.name,
client_input=status.input_)
else:
return await self.start_new(
orchestration_function_name=status.name,
instance_id=status.instance_id,
client_input=status.input_)

async def resume(self, instance_id: str, reason: str) -> None:
"""Resume the specified orchestration instance.

Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def get_binding_string():
"resumePostUri": f"{BASE_URL}/instances/INSTANCEID/resume?reason="
"{text}&taskHub="
f"{TASK_HUB_NAME}&connection=Storage&code={AUTH_CODE}",
"restartPostUri": f"{BASE_URL}/instances/INSTANCEID/restart?taskHub="
f"{TASK_HUB_NAME}&connection=Storage&code={AUTH_CODE}",
},
"rpcBaseUrl": RPC_BASE_URL
}
Expand Down
76 changes: 75 additions & 1 deletion tests/models/test_DurableOrchestrationClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,11 @@ def test_create_check_status_response(binding_string):
"resumePostUri":
r"http://test_azure.net/runtime/webhooks/durabletask/instances/"
r"2e2568e7-a906-43bd-8364-c81733c5891e/resume"
r"?reason={text}&taskHub=TASK_HUB_NAME&connection=Storage&code=AUTH_CODE"
r"?reason={text}&taskHub=TASK_HUB_NAME&connection=Storage&code=AUTH_CODE",
"restartPostUri":
r"http://test_azure.net/runtime/webhooks/durabletask/instances/"
r"2e2568e7-a906-43bd-8364-c81733c5891e/restart"
r"?taskHub=TASK_HUB_NAME&connection=Storage&code=AUTH_CODE"
}
for key, _ in http_management_payload.items():
http_management_payload[key] = replace_stand_in_bits(http_management_payload[key])
Expand Down Expand Up @@ -742,6 +746,76 @@ async def test_post_500_resume(binding_string):
await client.resume(TEST_INSTANCE_ID, raw_reason)


@pytest.mark.asyncio
async def test_restart_with_new_instance_id(binding_string):
"""Test restart creates a new instance with a new ID by default."""
orchestrator_name = "MyOrchestrator"
original_input = {"key": "value"}
new_instance_id = "new-instance-id-1234"

get_mock = MockRequest(
expected_url=f"{RPC_BASE_URL}instances/{TEST_INSTANCE_ID}?showInput=True",
response=[200, dict(
name=orchestrator_name,
instanceId=TEST_INSTANCE_ID,
createdTime=TEST_CREATED_TIME,
lastUpdatedTime=TEST_LAST_UPDATED_TIME,
runtimeStatus="Completed",
input=original_input)])

post_mock = MockRequest(
expected_url=f"{RPC_BASE_URL}orchestrators/{orchestrator_name}",
response=[202, {"id": new_instance_id}])

client = DurableOrchestrationClient(binding_string)
client._get_async_request = get_mock.get
client._post_async_request = post_mock.post

result = await client.restart(TEST_INSTANCE_ID)
assert result == new_instance_id


@pytest.mark.asyncio
async def test_restart_with_same_instance_id(binding_string):
"""Test restart reuses the original instance ID when restartWithNewInstanceId is False."""
orchestrator_name = "MyOrchestrator"
original_input = {"key": "value"}

get_mock = MockRequest(
expected_url=f"{RPC_BASE_URL}instances/{TEST_INSTANCE_ID}?showInput=True",
response=[200, dict(
name=orchestrator_name,
instanceId=TEST_INSTANCE_ID,
createdTime=TEST_CREATED_TIME,
lastUpdatedTime=TEST_LAST_UPDATED_TIME,
runtimeStatus="Completed",
input=original_input)])

post_mock = MockRequest(
expected_url=f"{RPC_BASE_URL}orchestrators/{orchestrator_name}/{TEST_INSTANCE_ID}",
response=[202, {"id": TEST_INSTANCE_ID}])

client = DurableOrchestrationClient(binding_string)
client._get_async_request = get_mock.get
client._post_async_request = post_mock.post

result = await client.restart(TEST_INSTANCE_ID, restart_with_new_instance_id=False)
assert result == TEST_INSTANCE_ID


@pytest.mark.asyncio
async def test_restart_instance_not_found(binding_string):
"""Test restart raises exception when instance is not found."""
get_mock = MockRequest(
expected_url=f"{RPC_BASE_URL}instances/{TEST_INSTANCE_ID}?showInput=True",
response=[404, dict(createdTime=None, lastUpdatedTime=None)])

client = DurableOrchestrationClient(binding_string)
client._get_async_request = get_mock.get

with pytest.raises(Exception) as ex:
await client.restart(TEST_INSTANCE_ID)
assert f"instanceId {TEST_INSTANCE_ID} was not found" in str(ex.value)
# Tests for function_invocation_id parameter
def test_client_stores_function_invocation_id(binding_string):
"""Test that the client stores the function_invocation_id parameter."""
Expand Down