Skip to content

Commit f5e6fbb

Browse files
authored
Add FunctionInvocationId header propagation for client operations (#593)
* Add FunctionInvocationId header propagation for client operations Propagates the Azure Functions invocation ID to the Durable Functions host via the X-Azure-Functions-InvocationId HTTP header, enabling correlation between worker-side function invocations and host-side orchestration events. - Modified http_utils.py to accept optional function_invocation_id parameter - Updated DurableOrchestrationClient to pass invocation ID to HTTP calls - Added optional function_invocation_id parameter to DurableApp.client decorator - Added unit tests for header propagation Related to Azure/azure-functions-durable-extension#3317
1 parent d9cf4b8 commit f5e6fbb

6 files changed

Lines changed: 144 additions & 16 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## Unreleased
6+
7+
### Added
8+
9+
- Client operation correlation logging: `FunctionInvocationId` is now propagated via HTTP headers to the host for client operations, enabling correlation with host logs.
10+
511
## 1.0.0b6
612

713
- [Create timer](https://github.com/Azure/azure-functions-durable-python/issues/35) functionality available

azure/durable_functions/decorators/durable_app.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,14 @@ async def df_client_middleware(*args, **kwargs):
195195
# construct rich object from it,
196196
# and assign parameter to that rich object
197197
starter = kwargs[parameter_name]
198-
client = client_constructor(starter)
198+
199+
# Try to extract the function invocation ID from the context for correlation
200+
function_invocation_id = None
201+
context = kwargs.get('context')
202+
if context is not None and hasattr(context, 'invocation_id'):
203+
function_invocation_id = context.invocation_id
204+
205+
client = client_constructor(starter, function_invocation_id)
199206
kwargs[parameter_name] = client
200207

201208
# Invoke user code with rich DF Client binding

azure/durable_functions/models/DurableOrchestrationClient.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,16 @@ class DurableOrchestrationClient:
2626
orchestration instances.
2727
"""
2828

29-
def __init__(self, context: str):
29+
def __init__(self, context: str, function_invocation_id: Optional[str] = None):
30+
"""Initialize a DurableOrchestrationClient.
31+
32+
Parameters
33+
----------
34+
context : str
35+
The JSON-encoded client binding context.
36+
function_invocation_id : Optional[str]
37+
The function invocation ID for correlation with host-side logs.
38+
"""
3039
self.task_hub_name: str
3140
self._uniqueWebHookOrigins: List[str]
3241
self._event_name_placeholder: str = "{eventName}"
@@ -39,6 +48,7 @@ def __init__(self, context: str):
3948
self._show_history_query_key: str = "showHistory"
4049
self._show_history_output_query_key: str = "showHistoryOutput"
4150
self._show_input_query_key: str = "showInput"
51+
self._function_invocation_id: Optional[str] = function_invocation_id
4252
self._orchestration_bindings: DurableOrchestrationBindings = \
4353
DurableOrchestrationBindings.from_json(context)
4454
self._post_async_request = post_async_request
@@ -84,7 +94,8 @@ async def start_new(self,
8494
request_url,
8595
self._get_json_input(client_input),
8696
trace_parent,
87-
trace_state)
97+
trace_state,
98+
self._function_invocation_id)
8899

89100
status_code: int = response[0]
90101
if status_code <= 202 and response[1]:
@@ -256,7 +267,10 @@ async def raise_event(
256267
request_url = self._get_raise_event_url(
257268
instance_id, event_name, task_hub_name, connection_name)
258269

259-
response = await self._post_async_request(request_url, json.dumps(event_data))
270+
response = await self._post_async_request(
271+
request_url,
272+
json.dumps(event_data),
273+
function_invocation_id=self._function_invocation_id)
260274

261275
switch_statement = {
262276
202: lambda: None,
@@ -445,7 +459,10 @@ async def terminate(self, instance_id: str, reason: str) -> None:
445459
"""
446460
request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \
447461
f"terminate?reason={quote(reason)}"
448-
response = await self._post_async_request(request_url, None)
462+
response = await self._post_async_request(
463+
request_url,
464+
None,
465+
function_invocation_id=self._function_invocation_id)
449466
switch_statement = {
450467
202: lambda: None, # instance in progress
451468
410: lambda: None, # instance failed or terminated
@@ -564,7 +581,8 @@ async def signal_entity(self, entityId: EntityId, operation_name: str,
564581
request_url,
565582
json.dumps(operation_input) if operation_input else None,
566583
trace_parent,
567-
trace_state)
584+
trace_state,
585+
self._function_invocation_id)
568586

569587
switch_statement = {
570588
202: lambda: None # signal accepted
@@ -714,7 +732,10 @@ async def rewind(self,
714732
raise Exception("The Python SDK only supports RPC endpoints."
715733
+ "Please remove the `localRpcEnabled` setting from host.json")
716734

717-
response = await self._post_async_request(request_url, None)
735+
response = await self._post_async_request(
736+
request_url,
737+
None,
738+
function_invocation_id=self._function_invocation_id)
718739
status: int = response[0]
719740
ex_msg: str = ""
720741
if status == 200 or status == 202:
@@ -753,7 +774,10 @@ async def suspend(self, instance_id: str, reason: str) -> None:
753774
"""
754775
request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \
755776
f"suspend?reason={quote(reason)}"
756-
response = await self._post_async_request(request_url, None)
777+
response = await self._post_async_request(
778+
request_url,
779+
None,
780+
function_invocation_id=self._function_invocation_id)
757781
switch_statement = {
758782
202: lambda: None, # instance is suspended
759783
410: lambda: None, # instance completed
@@ -788,7 +812,10 @@ async def resume(self, instance_id: str, reason: str) -> None:
788812
"""
789813
request_url = f"{self._orchestration_bindings.rpc_base_url}instances/{instance_id}/" \
790814
f"resume?reason={quote(reason)}"
791-
response = await self._post_async_request(request_url, None)
815+
response = await self._post_async_request(
816+
request_url,
817+
None,
818+
function_invocation_id=self._function_invocation_id)
792819
switch_statement = {
793820
202: lambda: None, # instance is resumed
794821
410: lambda: None, # instance completed

azure/durable_functions/models/TaskOrchestrationExecutor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def get_orchestrator_state_str(self) -> str:
276276
message contains in it the string representation of the orchestration's
277277
state
278278
"""
279-
if(self.output is not None):
279+
if (self.output is not None):
280280
try:
281281
# Attempt to serialize the output. If serialization fails, raise an
282282
# error indicating that the orchestration output is not serializable,

azure/durable_functions/models/utils/http_utils.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ async def _close_session() -> None:
8080
async def post_async_request(url: str,
8181
data: Any = None,
8282
trace_parent: str = None,
83-
trace_state: str = None) -> List[Union[int, Any]]:
83+
trace_state: str = None,
84+
function_invocation_id: str = None) -> List[Union[int, Any]]:
8485
"""Post request with the data provided to the url provided.
8586
8687
Parameters
@@ -93,6 +94,8 @@ async def post_async_request(url: str,
9394
traceparent header to send with the request
9495
trace_state: str
9596
tracestate header to send with the request
97+
function_invocation_id: str
98+
function invocation ID header to send for correlation
9699
97100
Returns
98101
-------
@@ -105,6 +108,8 @@ async def post_async_request(url: str,
105108
headers["traceparent"] = trace_parent
106109
if trace_state:
107110
headers["tracestate"] = trace_state
111+
if function_invocation_id:
112+
headers["X-Azure-Functions-InvocationId"] = function_invocation_id
108113

109114
try:
110115
async with session.post(url, json=data, headers=headers) as response:
@@ -120,23 +125,29 @@ async def post_async_request(url: str,
120125
raise
121126

122127

123-
async def get_async_request(url: str) -> List[Any]:
128+
async def get_async_request(url: str,
129+
function_invocation_id: str = None) -> List[Any]:
124130
"""Get the data from the url provided.
125131
126132
Parameters
127133
----------
128134
url: str
129135
url to get the data from
136+
function_invocation_id: str
137+
function invocation ID header to send for correlation
130138
131139
Returns
132140
-------
133141
[int, Any]
134142
Tuple with the Response status code and the data returned from the request
135143
"""
136144
session = await _get_session()
145+
headers = {}
146+
if function_invocation_id:
147+
headers["X-Azure-Functions-InvocationId"] = function_invocation_id
137148

138149
try:
139-
async with session.get(url) as response:
150+
async with session.get(url, headers=headers) as response:
140151
data = await response.json(content_type=None)
141152
if data is None:
142153
data = ""
@@ -147,23 +158,29 @@ async def get_async_request(url: str) -> List[Any]:
147158
raise
148159

149160

150-
async def delete_async_request(url: str) -> List[Union[int, Any]]:
161+
async def delete_async_request(url: str,
162+
function_invocation_id: str = None) -> List[Union[int, Any]]:
151163
"""Delete the data from the url provided.
152164
153165
Parameters
154166
----------
155167
url: str
156168
url to delete the data from
169+
function_invocation_id: str
170+
function invocation ID header to send for correlation
157171
158172
Returns
159173
-------
160174
[int, Any]
161175
Tuple with the Response status code and the data returned from the request
162176
"""
163177
session = await _get_session()
178+
headers = {}
179+
if function_invocation_id:
180+
headers["X-Azure-Functions-InvocationId"] = function_invocation_id
164181

165182
try:
166-
async with session.delete(url) as response:
183+
async with session.delete(url, headers=headers) as response:
167184
data = await response.json(content_type=None)
168185
return [response.status, data]
169186
except (aiohttp.ClientError, asyncio.TimeoutError):

tests/models/test_DurableOrchestrationClient.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ async def delete(self, url: str):
6767
assert url == self._expected_url
6868
return self._response
6969

70-
async def post(self, url: str, data: Any = None, trace_parent: str = None, trace_state: str = None):
70+
async def post(self, url: str, data: Any = None, trace_parent: str = None,
71+
trace_state: str = None, function_invocation_id: str = None):
7172
assert url == self._expected_url
7273
return self._response
7374

@@ -739,3 +740,73 @@ async def test_post_500_resume(binding_string):
739740

740741
with pytest.raises(Exception):
741742
await client.resume(TEST_INSTANCE_ID, raw_reason)
743+
744+
745+
# Tests for function_invocation_id parameter
746+
def test_client_stores_function_invocation_id(binding_string):
747+
"""Test that the client stores the function_invocation_id parameter."""
748+
invocation_id = "test-invocation-123"
749+
client = DurableOrchestrationClient(binding_string, function_invocation_id=invocation_id)
750+
assert client._function_invocation_id == invocation_id
751+
752+
753+
def test_client_stores_none_when_no_invocation_id(binding_string):
754+
"""Test that the client stores None when no invocation ID is provided."""
755+
client = DurableOrchestrationClient(binding_string)
756+
assert client._function_invocation_id is None
757+
758+
759+
class MockRequestWithInvocationId:
760+
"""Mock request class that verifies function_invocation_id is passed."""
761+
762+
def __init__(self, expected_url: str, response: [int, any], expected_invocation_id: str = None):
763+
self._expected_url = expected_url
764+
self._response = response
765+
self._expected_invocation_id = expected_invocation_id
766+
self._received_invocation_id = None
767+
768+
@property
769+
def received_invocation_id(self):
770+
return self._received_invocation_id
771+
772+
async def post(self, url: str, data: Any = None, trace_parent: str = None,
773+
trace_state: str = None, function_invocation_id: str = None):
774+
assert url == self._expected_url
775+
self._received_invocation_id = function_invocation_id
776+
if self._expected_invocation_id is not None:
777+
assert function_invocation_id == self._expected_invocation_id
778+
return self._response
779+
780+
781+
@pytest.mark.asyncio
782+
async def test_start_new_passes_invocation_id(binding_string):
783+
"""Test that start_new passes the function_invocation_id to the HTTP request."""
784+
invocation_id = "test-invocation-456"
785+
function_name = "MyOrchestrator"
786+
787+
mock_request = MockRequestWithInvocationId(
788+
expected_url=f"{RPC_BASE_URL}orchestrators/{function_name}",
789+
response=[202, {"id": TEST_INSTANCE_ID}],
790+
expected_invocation_id=invocation_id)
791+
792+
client = DurableOrchestrationClient(binding_string, function_invocation_id=invocation_id)
793+
client._post_async_request = mock_request.post
794+
795+
await client.start_new(function_name)
796+
assert mock_request.received_invocation_id == invocation_id
797+
798+
799+
@pytest.mark.asyncio
800+
async def test_start_new_passes_none_when_no_invocation_id(binding_string):
801+
"""Test that start_new passes None when no invocation ID is provided."""
802+
function_name = "MyOrchestrator"
803+
804+
mock_request = MockRequestWithInvocationId(
805+
expected_url=f"{RPC_BASE_URL}orchestrators/{function_name}",
806+
response=[202, {"id": TEST_INSTANCE_ID}])
807+
808+
client = DurableOrchestrationClient(binding_string)
809+
client._post_async_request = mock_request.post
810+
811+
await client.start_new(function_name)
812+
assert mock_request.received_invocation_id is None

0 commit comments

Comments
 (0)