Skip to content

Commit ea83a73

Browse files
committed
feat: extract sensitive field redaction logic into a dedicated utility module
1 parent b8dce5a commit ea83a73

5 files changed

Lines changed: 135 additions & 108 deletions

File tree

src/apcore/builtin_steps.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,10 @@ async def execute(self, ctx: PipelineContext) -> StepResult:
262262
result = await self._handler.check_approval(token)
263263
elif self._executor is not None and hasattr(self._executor, "_check_approval_async"):
264264
await self._executor._check_approval_async(
265-
module, ctx.module_id, ctx.inputs, ctx.context,
265+
module,
266+
ctx.module_id,
267+
ctx.inputs,
268+
ctx.context,
266269
)
267270
return StepResult(action="continue")
268271
else:
@@ -317,11 +320,15 @@ async def execute(self, ctx: PipelineContext) -> StepResult:
317320
try:
318321
if hasattr(self._middleware_manager, "execute_before_async"):
319322
inputs, executed = await self._middleware_manager.execute_before_async(
320-
ctx.module_id, ctx.inputs, ctx.context,
323+
ctx.module_id,
324+
ctx.inputs,
325+
ctx.context,
321326
)
322327
else:
323328
inputs, executed = self._middleware_manager.execute_before(
324-
ctx.module_id, ctx.inputs, ctx.context,
329+
ctx.module_id,
330+
ctx.inputs,
331+
ctx.context,
325332
)
326333
ctx.inputs = inputs
327334
ctx.executed_middlewares = list(executed)
@@ -333,11 +340,19 @@ async def execute(self, ctx: PipelineContext) -> StepResult:
333340
# on_error recovery for non-chain errors
334341
if hasattr(self._middleware_manager, "execute_on_error_async"):
335342
recovery = await self._middleware_manager.execute_on_error_async(
336-
ctx.module_id, ctx.inputs, exc, ctx.context, ctx.executed_middlewares,
343+
ctx.module_id,
344+
ctx.inputs,
345+
exc,
346+
ctx.context,
347+
ctx.executed_middlewares,
337348
)
338349
elif hasattr(self._middleware_manager, "execute_on_error"):
339350
recovery = self._middleware_manager.execute_on_error(
340-
ctx.module_id, ctx.inputs, exc, ctx.context, ctx.executed_middlewares,
351+
ctx.module_id,
352+
ctx.inputs,
353+
exc,
354+
ctx.context,
355+
ctx.executed_middlewares,
341356
)
342357
else:
343358
recovery = None
@@ -415,7 +430,7 @@ async def execute(self, ctx: PipelineContext) -> StepResult:
415430
# Redact sensitive fields after successful validation
416431
schema_dict_fn = getattr(input_schema, "model_json_schema", None)
417432
if schema_dict_fn is not None and callable(schema_dict_fn):
418-
from apcore.executor import redact_sensitive
433+
from apcore.utils.redaction import redact_sensitive
419434

420435
schema = schema_dict_fn()
421436
redacted = redact_sensitive(ctx.inputs, schema)
@@ -512,7 +527,8 @@ async def execute(self, ctx: PipelineContext) -> StepResult:
512527
except asyncio.TimeoutError:
513528
timeout_ms = int((timeout_s or 0) * 1000)
514529
raise ModuleTimeoutError(
515-
module_id=ctx.module_id, timeout_ms=timeout_ms,
530+
module_id=ctx.module_id,
531+
timeout_ms=timeout_ms,
516532
) from None
517533
except (ExecutionCancelledError, ModuleTimeoutError, InvalidInputError):
518534
raise
@@ -568,7 +584,7 @@ async def execute(self, ctx: PipelineContext) -> StepResult:
568584
if ctx.context is not None and hasattr(ctx.context, "data"):
569585
schema_dict_fn = getattr(output_schema, "model_json_schema", None)
570586
if schema_dict_fn is not None and callable(schema_dict_fn):
571-
from apcore.executor import redact_sensitive
587+
from apcore.utils.redaction import redact_sensitive
572588

573589
schema = schema_dict_fn()
574590
redacted = redact_sensitive(ctx.output, schema)
@@ -605,11 +621,17 @@ async def execute(self, ctx: PipelineContext) -> StepResult:
605621
if self._middleware_manager is not None:
606622
if hasattr(self._middleware_manager, "execute_after_async"):
607623
output = await self._middleware_manager.execute_after_async(
608-
ctx.module_id, ctx.inputs, ctx.output or {}, ctx.context,
624+
ctx.module_id,
625+
ctx.inputs,
626+
ctx.output or {},
627+
ctx.context,
609628
)
610629
else:
611630
output = self._middleware_manager.execute_after(
612-
ctx.module_id, ctx.inputs, ctx.output or {}, ctx.context,
631+
ctx.module_id,
632+
ctx.inputs,
633+
ctx.output or {},
634+
ctx.context,
613635
)
614636
ctx.output = output
615637
return StepResult(action="continue")

src/apcore/executor.py

Lines changed: 23 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,15 @@
88
from __future__ import annotations
99

1010
import asyncio
11-
import copy
1211
import dataclasses
13-
import inspect
1412
import logging
1513
import threading
16-
import time
1714
from collections.abc import AsyncIterator
1815
from typing import Any, Callable
1916

2017
import pydantic
2118

2219
from apcore.acl import ACL
23-
from apcore.context_keys import REDACTED_OUTPUT
2420
from apcore.approval import ApprovalHandler, ApprovalRequest, ApprovalResult
2521
from apcore.cancel import ExecutionCancelledError
2622
from apcore.config import Config
@@ -51,11 +47,10 @@
5147
StrategyNotFoundError,
5248
)
5349
from apcore.registry import MODULE_ID_PATTERN, Registry
54-
from apcore.utils.call_chain import guard_call_chain
5550

56-
__all__ = ["redact_sensitive", "REDACTED_VALUE", "Executor"]
51+
from apcore.utils.redaction import REDACTED_VALUE, redact_sensitive
5752

58-
REDACTED_VALUE: str = "***REDACTED***"
53+
__all__ = ["redact_sensitive", "REDACTED_VALUE", "Executor"]
5954

6055
_logger = logging.getLogger(__name__)
6156

@@ -123,71 +118,8 @@ def _convert_validation_errors(error: pydantic.ValidationError) -> list[dict[str
123118
# =============================================================================
124119

125120

126-
def redact_sensitive(data: dict[str, Any], schema_dict: dict[str, Any]) -> dict[str, Any]:
127-
"""Redact fields marked with x-sensitive in the schema.
128-
129-
Implements Algorithm A13 from PROTOCOL_SPEC section 9.5.
130-
Returns a deep copy of data with sensitive values replaced by "***REDACTED***".
131-
Also redacts any keys starting with "_secret_" regardless of schema.
132-
133-
Args:
134-
data: The data dict to redact.
135-
schema_dict: A JSON Schema dict that may contain "x-sensitive": true
136-
on individual properties.
137-
138-
Returns:
139-
A new dict with sensitive values replaced. Original data is not modified.
140-
"""
141-
redacted = copy.deepcopy(data)
142-
_redact_fields(redacted, schema_dict)
143-
_redact_secret_prefix(redacted)
144-
return redacted
145-
146-
147-
def _redact_fields(data: dict[str, Any], schema_dict: dict[str, Any]) -> None:
148-
"""In-place redaction based on schema x-sensitive markers."""
149-
properties = schema_dict.get("properties")
150-
if not properties:
151-
return
152-
153-
for field_name, field_schema in properties.items():
154-
if field_name not in data:
155-
continue
156-
157-
value = data[field_name]
158-
159-
# x-sensitive: true on this property
160-
if field_schema.get("x-sensitive") is True:
161-
if value is not None:
162-
data[field_name] = REDACTED_VALUE
163-
continue
164-
165-
# Nested object: recurse
166-
if field_schema.get("type") == "object" and "properties" in field_schema and isinstance(value, dict):
167-
_redact_fields(value, field_schema)
168-
continue
169-
170-
# Array: redact items
171-
if field_schema.get("type") == "array" and "items" in field_schema and isinstance(value, list):
172-
items_schema = field_schema["items"]
173-
if items_schema.get("x-sensitive") is True:
174-
for i, item in enumerate(value):
175-
if item is not None:
176-
value[i] = REDACTED_VALUE
177-
elif items_schema.get("type") == "object" and "properties" in items_schema:
178-
for item in value:
179-
if isinstance(item, dict):
180-
_redact_fields(item, items_schema)
181-
182-
183-
def _redact_secret_prefix(data: dict[str, Any]) -> None:
184-
"""In-place redaction of keys starting with _secret_."""
185-
for key in data:
186-
value = data[key]
187-
if key.startswith("_secret_") and value is not None:
188-
data[key] = REDACTED_VALUE
189-
elif isinstance(value, dict):
190-
_redact_secret_prefix(value)
121+
# redact_sensitive and REDACTED_VALUE moved to apcore.utils.redaction in v0.17
122+
# Re-exported here for backward compatibility.
191123

192124

193125
# =============================================================================
@@ -480,12 +412,8 @@ def validate(
480412
if loop is None:
481413
if self._sync_loop is None or self._sync_loop.is_closed():
482414
self._sync_loop = asyncio.new_event_loop()
483-
return self._sync_loop.run_until_complete(
484-
self._validate_async(module_id, inputs, context)
485-
)
486-
return self._run_in_new_thread(
487-
self._validate_async(module_id, inputs, context), module_id, None
488-
)
415+
return self._sync_loop.run_until_complete(self._validate_async(module_id, inputs, context))
416+
return self._run_in_new_thread(self._validate_async(module_id, inputs, context), module_id, None)
489417

490418
async def _validate_async(
491419
self,
@@ -550,7 +478,11 @@ async def _validate_async(
550478
requires_approval = self._needs_approval(pipe_ctx.module)
551479

552480
# Module-level preflight (optional)
553-
if pipe_ctx.module is not None and hasattr(pipe_ctx.module, "preflight") and callable(pipe_ctx.module.preflight):
481+
if (
482+
pipe_ctx.module is not None
483+
and hasattr(pipe_ctx.module, "preflight")
484+
and callable(pipe_ctx.module.preflight)
485+
):
554486
try:
555487
preflight_warnings = pipe_ctx.module.preflight(inputs, pipe_ctx.context)
556488
if isinstance(preflight_warnings, list) and preflight_warnings:
@@ -719,6 +651,7 @@ def _translate_abort(self, abort: PipelineAbortError) -> ModuleError:
719651
if step == "execute":
720652
if "cancelled" in explanation.lower():
721653
from apcore.cancel import ExecutionCancelledError
654+
722655
return ExecutionCancelledError()
723656
if "deadline" in explanation.lower() or "timed out" in explanation.lower():
724657
return ModuleTimeoutError(module_id="", timeout_ms=0)
@@ -876,7 +809,10 @@ async def stream(
876809
wrapped = propagate_error(exc, module_id, ctx_obj) if ctx_obj else exc
877810
if pipe_ctx.executed_middlewares:
878811
recovery = await self._middleware_manager.execute_on_error_async(
879-
module_id, pipe_ctx.inputs, wrapped, ctx_obj,
812+
module_id,
813+
pipe_ctx.inputs,
814+
wrapped,
815+
ctx_obj,
880816
pipe_ctx.executed_middlewares,
881817
)
882818
if recovery is not None:
@@ -902,7 +838,10 @@ async def stream(
902838
wrapped = propagate_error(exc, module_id, ctx_obj) if ctx_obj else exc
903839
if pipe_ctx.executed_middlewares:
904840
recovery = await self._middleware_manager.execute_on_error_async(
905-
module_id, pipe_ctx.inputs, wrapped, ctx_obj,
841+
module_id,
842+
pipe_ctx.inputs,
843+
wrapped,
844+
ctx_obj,
906845
pipe_ctx.executed_middlewares,
907846
)
908847
if recovery is not None:
@@ -912,8 +851,9 @@ async def stream(
912851

913852
# Phase 3: Output validation + middleware_after on accumulated result
914853
pipe_ctx.output = accumulated
915-
post_steps = [s for s in self._strategy.steps
916-
if s.name in ("output_validation", "middleware_after", "return_result")]
854+
post_steps = [
855+
s for s in self._strategy.steps if s.name in ("output_validation", "middleware_after", "return_result")
856+
]
917857
if post_steps:
918858
post_strategy = ExecutionStrategy("post_stream", post_steps)
919859
try:

src/apcore/pipeline.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,7 @@ async def run(
226226
timeout_ms = getattr(step, "timeout_ms", 0)
227227

228228
# (1) match_modules filter
229-
if match_modules is not None and not _any_match(
230-
match_modules, ctx.module_id
231-
):
229+
if match_modules is not None and not _any_match(match_modules, ctx.module_id):
232230
trace.steps.append(
233231
StepTrace(
234232
name=step.name,
@@ -306,9 +304,7 @@ async def run(
306304
duration = (time.monotonic() - step_start) * 1000
307305
# (4) ignore_errors: log and continue
308306
if ignore_errors:
309-
_logger.warning(
310-
"Step '%s' failed (ignored): %s", step.name, exc
311-
)
307+
_logger.warning("Step '%s' failed (ignored): %s", step.name, exc)
312308
trace.steps.append(
313309
StepTrace(
314310
name=step.name,
@@ -327,9 +323,7 @@ async def run(
327323
StepTrace(
328324
name=step.name,
329325
duration_ms=duration,
330-
result=StepResult(
331-
action="abort", explanation=str(exc)
332-
),
326+
result=StepResult(action="abort", explanation=str(exc)),
333327
)
334328
)
335329
trace.total_duration_ms = (time.monotonic() - start) * 1000

src/apcore/pipeline_config.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,7 @@ def _resolve_step(step_def: dict[str, Any]) -> BaseStep:
115115
def _import_step(handler_path: str, name: str, config: dict[str, Any]) -> BaseStep:
116116
"""Import a step class from a handler path like 'myapp.steps:RateLimitStep'."""
117117
if ":" not in handler_path:
118-
raise ValueError(
119-
f"Invalid handler path '{handler_path}'. Expected format: 'module.path:ClassName'"
120-
)
118+
raise ValueError(f"Invalid handler path '{handler_path}'. Expected format: 'module.path:ClassName'")
121119
module_path, class_name = handler_path.split(":", 1)
122120

123121
import importlib
@@ -193,9 +191,7 @@ def build_strategy_from_config(
193191
if hasattr(step, key):
194192
setattr(step, key, value)
195193
else:
196-
_logger.warning(
197-
"Step '%s' has no field '%s'", step_name, key
198-
)
194+
_logger.warning("Step '%s' has no field '%s'", step_name, key)
199195
break
200196

201197
# (3) Resolve and insert custom steps

0 commit comments

Comments
 (0)