Skip to content

Commit 1138ff2

Browse files
authored
Context manager for logger handlers in tests (#1293)
1 parent eb513cc commit 1138ff2

4 files changed

Lines changed: 114 additions & 103 deletions

File tree

tests/helpers/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import socket
66
import time
77
import uuid
8-
from collections.abc import Awaitable, Callable, Sequence
8+
from collections.abc import Awaitable, Callable, Iterator, Sequence
99
from contextlib import closing, contextmanager
1010
from dataclasses import dataclass
1111
from datetime import datetime, timedelta, timezone
@@ -440,3 +440,16 @@ def find(
440440
if pred(record):
441441
return record
442442
return None
443+
444+
445+
class LogHandler:
446+
@staticmethod
447+
@contextmanager
448+
def apply(logger: logging.Logger, handler: logging.Handler) -> Iterator[None]:
449+
level = logger.level
450+
logger.addHandler(handler)
451+
try:
452+
yield
453+
finally:
454+
logger.removeHandler(handler)
455+
logger.level = level

tests/test_runtime.py

Lines changed: 69 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from temporalio.worker import Worker
2424
from tests.helpers import (
25+
LogHandler,
2526
assert_eq_eventually,
2627
assert_eventually,
2728
find_free_port,
@@ -80,8 +81,8 @@ async def test_runtime_log_forwarding():
8081
# Create logger with record capture
8182
log_queue: queue.Queue[logging.LogRecord] = queue.Queue()
8283
log_queue_list = cast(list[logging.LogRecord], log_queue.queue)
84+
handler = logging.handlers.QueueHandler(log_queue)
8385
logger = logging.getLogger(f"log-{uuid.uuid4()}")
84-
logger.addHandler(logging.handlers.QueueHandler(log_queue))
8586

8687
async def log_queue_len() -> int:
8788
return len(log_queue_list)
@@ -96,49 +97,50 @@ async def log_queue_len() -> int:
9697
)
9798
)
9899

99-
# Set capture only info logs
100-
logger.setLevel(logging.INFO)
101-
# Write some logs
102-
runtime._core_runtime.write_test_info_log("info1", "extra1")
103-
runtime._core_runtime.write_test_debug_log("debug2", "extra2")
104-
runtime._core_runtime.write_test_info_log("info3", "extra3")
105-
106-
# Check the expected records
107-
await assert_eq_eventually(2, log_queue_len)
108-
assert log_queue_list[0].levelno == logging.INFO
109-
assert log_queue_list[0].message.startswith(
110-
"[sdk_core::temporal_sdk_bridge::runtime] info1"
111-
)
112-
assert (
113-
log_queue_list[0].name
114-
== f"{logger.name}-sdk_core::temporal_sdk_bridge::runtime"
115-
)
116-
assert log_queue_list[0].created == log_queue_list[0].temporal_log.time # type: ignore
117-
assert log_queue_list[0].temporal_log.fields == {"extra_data": "extra1"} # type: ignore
118-
assert log_queue_list[1].levelno == logging.INFO
119-
assert log_queue_list[1].message.startswith(
120-
"[sdk_core::temporal_sdk_bridge::runtime] info3"
121-
)
100+
with LogHandler.apply(logger, handler):
101+
# Set capture only info logs
102+
logger.setLevel(logging.INFO)
103+
# Write some logs
104+
runtime._core_runtime.write_test_info_log("info1", "extra1")
105+
runtime._core_runtime.write_test_debug_log("debug2", "extra2")
106+
runtime._core_runtime.write_test_info_log("info3", "extra3")
107+
108+
# Check the expected records
109+
await assert_eq_eventually(2, log_queue_len)
110+
assert log_queue_list[0].levelno == logging.INFO
111+
assert log_queue_list[0].message.startswith(
112+
"[sdk_core::temporal_sdk_bridge::runtime] info1"
113+
)
114+
assert (
115+
log_queue_list[0].name
116+
== f"{logger.name}-sdk_core::temporal_sdk_bridge::runtime"
117+
)
118+
assert log_queue_list[0].created == log_queue_list[0].temporal_log.time # type: ignore
119+
assert log_queue_list[0].temporal_log.fields == {"extra_data": "extra1"} # type: ignore
120+
assert log_queue_list[1].levelno == logging.INFO
121+
assert log_queue_list[1].message.startswith(
122+
"[sdk_core::temporal_sdk_bridge::runtime] info3"
123+
)
122124

123-
# Clear logs and enable debug and try again
124-
log_queue_list.clear()
125-
logger.setLevel(logging.DEBUG)
126-
runtime._core_runtime.write_test_info_log("info4", "extra4")
127-
runtime._core_runtime.write_test_debug_log("debug5", "extra5")
128-
runtime._core_runtime.write_test_info_log("info6", "extra6")
129-
await assert_eq_eventually(3, log_queue_len)
130-
assert log_queue_list[0].levelno == logging.INFO
131-
assert log_queue_list[0].message.startswith(
132-
"[sdk_core::temporal_sdk_bridge::runtime] info4"
133-
)
134-
assert log_queue_list[1].levelno == logging.DEBUG
135-
assert log_queue_list[1].message.startswith(
136-
"[sdk_core::temporal_sdk_bridge::runtime] debug5"
137-
)
138-
assert log_queue_list[2].levelno == logging.INFO
139-
assert log_queue_list[2].message.startswith(
140-
"[sdk_core::temporal_sdk_bridge::runtime] info6"
141-
)
125+
# Clear logs and enable debug and try again
126+
log_queue_list.clear()
127+
logger.setLevel(logging.DEBUG)
128+
runtime._core_runtime.write_test_info_log("info4", "extra4")
129+
runtime._core_runtime.write_test_debug_log("debug5", "extra5")
130+
runtime._core_runtime.write_test_info_log("info6", "extra6")
131+
await assert_eq_eventually(3, log_queue_len)
132+
assert log_queue_list[0].levelno == logging.INFO
133+
assert log_queue_list[0].message.startswith(
134+
"[sdk_core::temporal_sdk_bridge::runtime] info4"
135+
)
136+
assert log_queue_list[1].levelno == logging.DEBUG
137+
assert log_queue_list[1].message.startswith(
138+
"[sdk_core::temporal_sdk_bridge::runtime] debug5"
139+
)
140+
assert log_queue_list[2].levelno == logging.INFO
141+
assert log_queue_list[2].message.startswith(
142+
"[sdk_core::temporal_sdk_bridge::runtime] info6"
143+
)
142144

143145

144146
@workflow.defn
@@ -152,8 +154,8 @@ async def test_runtime_task_fail_log_forwarding(client: Client):
152154
# Client with lo capturing runtime
153155
log_queue: queue.Queue[logging.LogRecord] = queue.Queue()
154156
log_queue_list = cast(list[logging.LogRecord], log_queue.queue)
157+
handler = logging.handlers.QueueHandler(log_queue)
155158
logger = logging.getLogger(f"log-{uuid.uuid4()}")
156-
logger.addHandler(logging.handlers.QueueHandler(log_queue))
157159
logger.setLevel(logging.WARN)
158160
client = await Client.connect(
159161
client.service_client.config.target_host,
@@ -168,30 +170,32 @@ async def test_runtime_task_fail_log_forwarding(client: Client):
168170
),
169171
)
170172

171-
# Start workflow
172-
task_queue = f"task-queue-{uuid.uuid4()}"
173-
async with Worker(client, task_queue=task_queue, workflows=[TaskFailWorkflow]):
174-
handle = await client.start_workflow(
175-
TaskFailWorkflow.run,
176-
id=f"workflow-{uuid.uuid4()}",
177-
task_queue=task_queue,
178-
)
179-
180-
# Wait for log to appear
181-
async def has_log() -> bool:
182-
return any(
183-
l for l in log_queue_list if "Failing workflow task" in l.message
173+
with LogHandler.apply(logger, handler):
174+
# Start workflow
175+
task_queue = f"task-queue-{uuid.uuid4()}"
176+
async with Worker(client, task_queue=task_queue, workflows=[TaskFailWorkflow]):
177+
handle = await client.start_workflow(
178+
TaskFailWorkflow.run,
179+
id=f"workflow-{uuid.uuid4()}",
180+
task_queue=task_queue,
184181
)
185182

186-
await assert_eq_eventually(True, has_log)
183+
# Wait for log to appear
184+
async def has_log() -> bool:
185+
return any(
186+
l for l in log_queue_list if "Failing workflow task" in l.message
187+
)
187188

188-
# Check record
189-
record = next(l for l in log_queue_list if "Failing workflow task" in l.message)
190-
assert record.levelno == logging.WARNING
191-
assert (
192-
record.name == f"{logger.name}-sdk_core::temporalio_sdk_core::worker::workflow"
193-
)
194-
assert record.temporal_log.fields["run_id"] == handle.result_run_id # type: ignore
189+
await assert_eq_eventually(True, has_log)
190+
191+
# Check record
192+
record = next(l for l in log_queue_list if "Failing workflow task" in l.message)
193+
assert record.levelno == logging.WARNING
194+
assert (
195+
record.name
196+
== f"{logger.name}-sdk_core::temporalio_sdk_core::worker::workflow"
197+
)
198+
assert record.temporal_log.fields["run_id"] == handle.result_run_id # type: ignore
195199

196200

197201
async def test_prometheus_histogram_bucket_overrides(client: Client):

tests/worker/test_activity.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
Worker,
4747
WorkerConfig,
4848
)
49+
from tests.helpers import LogHandler
4950
from tests.helpers.worker import (
5051
ExternalWorker,
5152
KSAction,
@@ -1019,20 +1020,15 @@ async def say_hello(name: str) -> str:
10191020

10201021
# Create a queue, add handler to logger, call normal activity, then check
10211022
handler = logging.handlers.QueueHandler(queue.Queue())
1022-
activity.logger.base_logger.addHandler(handler)
1023-
prev_level = activity.logger.base_logger.level
1024-
activity.logger.base_logger.setLevel(logging.INFO)
1025-
try:
1023+
with LogHandler.apply(activity.logger.base_logger, handler):
1024+
activity.logger.base_logger.setLevel(logging.INFO)
10261025
result = await _execute_workflow_with_activity(
10271026
client,
10281027
worker,
10291028
say_hello,
10301029
"Temporal",
10311030
shared_state_manager=shared_state_manager,
10321031
)
1033-
finally:
1034-
activity.logger.base_logger.removeHandler(handler)
1035-
activity.logger.base_logger.setLevel(prev_level)
10361032
assert result.result == "Hello, Temporal!"
10371033
records: list[logging.LogRecord] = list(handler.queue.queue) # type: ignore
10381034
assert len(records) > 0
@@ -1671,9 +1667,8 @@ async def raise_error():
16711667
raise RuntimeError("oh no!")
16721668

16731669
handler = CustomLogHandler()
1674-
activity.logger.base_logger.addHandler(handler)
16751670

1676-
try:
1671+
with LogHandler.apply(activity.logger.base_logger, handler):
16771672
with pytest.raises(WorkflowFailureError) as err:
16781673
await _execute_workflow_with_activity(
16791674
client,
@@ -1686,9 +1681,6 @@ async def raise_error():
16861681
)
16871682
assert handler._trace_identifiers == 1
16881683

1689-
finally:
1690-
activity.logger.base_logger.removeHandler(CustomLogHandler())
1691-
16921684

16931685
async def test_activity_heartbeat_context(
16941686
client: Client,

tests/worker/test_workflow.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
from tests import DEV_SERVER_DOWNLOAD_VERSION
123123
from tests.helpers import (
124124
LogCapturer,
125+
LogHandler,
125126
admitted_update_task,
126127
assert_eq_eventually,
127128
assert_eventually,
@@ -8445,41 +8446,42 @@ async def run(self):
84458446

84468447
class CustomLogHandler(logging.Handler):
84478448
def emit(self, record: logging.LogRecord) -> None:
8448-
import httpx # type: ignore[reportUnusedImport]
8449+
import httpx # type: ignore[reportUnusedImport] # noqa
84498450

84508451

84518452
async def test_disable_logger_sandbox(
84528453
client: Client,
84538454
):
84548455
logger = workflow.logger.logger
8455-
logger.addHandler(CustomLogHandler())
8456-
async with new_worker(
8457-
client,
8458-
DisableLoggerSandbox,
8459-
activities=[],
8460-
) as worker:
8461-
with pytest.raises(WorkflowFailureError):
8462-
await client.execute_workflow(
8463-
DisableLoggerSandbox.run,
8464-
id=f"workflow-{uuid.uuid4()}",
8465-
task_queue=worker.task_queue,
8466-
run_timeout=timedelta(seconds=1),
8467-
retry_policy=RetryPolicy(maximum_attempts=1),
8468-
)
8469-
workflow.logger.unsafe_disable_sandbox()
8470-
await client.execute_workflow(
8471-
DisableLoggerSandbox.run,
8472-
id=f"workflow-{uuid.uuid4()}",
8473-
task_queue=worker.task_queue,
8474-
run_timeout=timedelta(seconds=1),
8475-
retry_policy=RetryPolicy(maximum_attempts=1),
8476-
)
8477-
workflow.logger.unsafe_disable_sandbox(False)
8478-
with pytest.raises(WorkflowFailureError):
8456+
handler = CustomLogHandler()
8457+
with LogHandler.apply(logger, handler):
8458+
async with new_worker(
8459+
client,
8460+
DisableLoggerSandbox,
8461+
activities=[],
8462+
) as worker:
8463+
with pytest.raises(WorkflowFailureError):
8464+
await client.execute_workflow(
8465+
DisableLoggerSandbox.run,
8466+
id=f"workflow-{uuid.uuid4()}",
8467+
task_queue=worker.task_queue,
8468+
run_timeout=timedelta(seconds=1),
8469+
retry_policy=RetryPolicy(maximum_attempts=1),
8470+
)
8471+
workflow.logger.unsafe_disable_sandbox()
84798472
await client.execute_workflow(
84808473
DisableLoggerSandbox.run,
84818474
id=f"workflow-{uuid.uuid4()}",
84828475
task_queue=worker.task_queue,
84838476
run_timeout=timedelta(seconds=1),
84848477
retry_policy=RetryPolicy(maximum_attempts=1),
84858478
)
8479+
workflow.logger.unsafe_disable_sandbox(False)
8480+
with pytest.raises(WorkflowFailureError):
8481+
await client.execute_workflow(
8482+
DisableLoggerSandbox.run,
8483+
id=f"workflow-{uuid.uuid4()}",
8484+
task_queue=worker.task_queue,
8485+
run_timeout=timedelta(seconds=1),
8486+
retry_policy=RetryPolicy(maximum_attempts=1),
8487+
)

0 commit comments

Comments
 (0)