Skip to content

Commit ff68c39

Browse files
committed
refactor: decouple ContextKey from Context using a structural protocol and reformat code for consistency
Signed-off-by: tercel <tercel.yi@gmail.com>
1 parent 8f34971 commit ff68c39

10 files changed

Lines changed: 89 additions & 87 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "apcore"
7-
version = "0.15.1"
7+
version = "0.16.0"
88
description = "Schema-driven module standard for AI-perceivable interfaces"
99
readme = "README.md"
1010
requires-python = ">=3.11"

src/apcore/acl.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ def register_condition(cls, key: str, handler: ACLConditionHandler) -> None:
8484

8585
@classmethod
8686
def _evaluate_conditions(
87-
cls, conditions: dict[str, Any], context: Context,
87+
cls,
88+
conditions: dict[str, Any],
89+
context: Context,
8890
) -> bool:
8991
"""Evaluate all conditions with AND logic. Fail-closed on unknown."""
9092
for key, value in conditions.items():
@@ -100,7 +102,8 @@ def _evaluate_conditions(
100102
if inspect.isawaitable(result):
101103
result.close() # prevent "coroutine never awaited" warning
102104
_logger.warning(
103-
"Async condition %r in sync context — treated as unsatisfied. Use async_check().", key,
105+
"Async condition %r in sync context — treated as unsatisfied. Use async_check().",
106+
key,
104107
)
105108
return False
106109
if not result:
@@ -109,7 +112,9 @@ def _evaluate_conditions(
109112

110113
@classmethod
111114
async def _evaluate_conditions_async(
112-
cls, conditions: dict[str, Any], context: Context,
115+
cls,
116+
conditions: dict[str, Any],
117+
context: Context,
113118
) -> bool:
114119
"""Async variant. Awaits async handlers, calls sync handlers directly."""
115120
for key, value in conditions.items():

src/apcore/builtin_steps.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,6 @@
1717
import pydantic
1818

1919
from apcore.context import Context
20-
from apcore.errors import (
21-
ACLDeniedError,
22-
ModuleNotFoundError,
23-
SchemaValidationError,
24-
)
25-
from apcore.executor import REDACTED_VALUE, redact_sensitive
2620
from apcore.pipeline import (
2721
BaseStep,
2822
ExecutionStrategy,
@@ -363,9 +357,7 @@ async def execute(self, ctx: PipelineContext) -> StepResult:
363357
output = await module.execute(inputs, ctx.context)
364358
else:
365359
loop = asyncio.get_event_loop()
366-
output = await loop.run_in_executor(
367-
None, module.execute, inputs, ctx.context
368-
)
360+
output = await loop.run_in_executor(None, module.execute, inputs, ctx.context)
369361
ctx.output = output
370362
except Exception as exc:
371363
return StepResult(action="abort", explanation=f"Execution error: {exc}")

src/apcore/context.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,7 @@ def serialize(self) -> dict[str, Any]:
143143
result["identity"] = None
144144
if self.redacted_inputs is not None:
145145
result["redacted_inputs"] = dict(self.redacted_inputs)
146-
result["data"] = {
147-
k: v for k, v in self.data.items() if not k.startswith("_")
148-
}
146+
result["data"] = {k: v for k, v in self.data.items() if not k.startswith("_")}
149147
return result
150148

151149
@classmethod
@@ -162,8 +160,7 @@ def deserialize(cls, data: dict[str, Any]) -> Context:
162160
version = data.get("_context_version", 1)
163161
if version > 1:
164162
_logger.warning(
165-
"Unknown _context_version %d (expected 1). "
166-
"Proceeding with best-effort deserialization.",
163+
"Unknown _context_version %d (expected 1). " "Proceeding with best-effort deserialization.",
167164
version,
168165
)
169166

src/apcore/context_key.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass
6-
from typing import Generic, TypeVar
7-
8-
from apcore.context import Context
6+
from typing import Any, Generic, Protocol, TypeVar
97

108
T = TypeVar("T")
119

1210
_MISSING = object()
1311

1412

13+
class _ContextLike(Protocol):
14+
"""Structural protocol for any object with a ``data`` mapping."""
15+
16+
data: dict[str, Any]
17+
18+
1519
@dataclass(frozen=True)
1620
class ContextKey(Generic[T]):
1721
"""Typed key for context.data with namespace isolation.
@@ -22,20 +26,20 @@ class ContextKey(Generic[T]):
2226

2327
name: str
2428

25-
def get(self, ctx: Context, default: T | None = None) -> T | None: # type: ignore[type-var]
29+
def get(self, ctx: _ContextLike, default: T | None = None) -> T | None: # type: ignore[type-var]
2630
"""Return the value for this key, or *default* if absent."""
2731
value = ctx.data.get(self.name, _MISSING)
2832
return default if value is _MISSING else value # type: ignore[return-value]
2933

30-
def set(self, ctx: Context, value: T) -> None:
34+
def set(self, ctx: _ContextLike, value: T) -> None:
3135
"""Store *value* under this key in context.data."""
3236
ctx.data[self.name] = value
3337

34-
def delete(self, ctx: Context) -> None:
38+
def delete(self, ctx: _ContextLike) -> None:
3539
"""Remove this key from context.data (no-op if absent)."""
3640
ctx.data.pop(self.name, None)
3741

38-
def exists(self, ctx: Context) -> bool:
42+
def exists(self, ctx: _ContextLike) -> bool:
3943
"""Return True if this key is present in context.data."""
4044
return self.name in ctx.data
4145

src/apcore/executor.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def __init__(
216216
)
217217
if strategy is None:
218218
from apcore.builtin_steps import build_standard_strategy
219+
219220
self._strategy = build_standard_strategy(**strategy_kwargs)
220221
elif isinstance(strategy, str):
221222
self._strategy = self._resolve_strategy_name(strategy, **strategy_kwargs)
@@ -416,9 +417,7 @@ def call(
416417
errors=_convert_validation_errors(e),
417418
) from e
418419

419-
REDACTED_OUTPUT.set(ctx, redact_sensitive(
420-
output, module.output_schema.model_json_schema()
421-
))
420+
REDACTED_OUTPUT.set(ctx, redact_sensitive(output, module.output_schema.model_json_schema()))
422421

423422
# Step 10 -- Middleware After
424423
output = self._middleware_manager.execute_after(module_id, inputs, output, ctx)
@@ -922,9 +921,7 @@ async def call_async(
922921
errors=_convert_validation_errors(e),
923922
) from e
924923

925-
REDACTED_OUTPUT.set(ctx, redact_sensitive(
926-
output, module.output_schema.model_json_schema()
927-
))
924+
REDACTED_OUTPUT.set(ctx, redact_sensitive(output, module.output_schema.model_json_schema()))
928925

929926
# Step 10 -- Middleware After (async-aware)
930927
output = await self._middleware_manager.execute_after_async(module_id, inputs, output, ctx)
@@ -1048,9 +1045,7 @@ async def stream(
10481045
errors=_convert_validation_errors(e),
10491046
) from e
10501047

1051-
REDACTED_OUTPUT.set(ctx, redact_sensitive(
1052-
output, module.output_schema.model_json_schema()
1053-
))
1048+
REDACTED_OUTPUT.set(ctx, redact_sensitive(output, module.output_schema.model_json_schema()))
10541049

10551050
# Step 10 -- Middleware After (async-aware)
10561051
output = await self._middleware_manager.execute_after_async(module_id, effective_inputs, output, ctx)
@@ -1073,9 +1068,7 @@ async def stream(
10731068
errors=_convert_validation_errors(e),
10741069
) from e
10751070

1076-
REDACTED_OUTPUT.set(ctx, redact_sensitive(
1077-
accumulated, module.output_schema.model_json_schema()
1078-
))
1071+
REDACTED_OUTPUT.set(ctx, redact_sensitive(accumulated, module.output_schema.model_json_schema()))
10791072

10801073
# Step 10 -- Middleware After on accumulated result (async-aware)
10811074
accumulated = await self._middleware_manager.execute_after_async(
@@ -1226,9 +1219,7 @@ def call_with_trace(
12261219

12271220
if loop is None:
12281221
return asyncio.run(engine.run(effective_strategy, pipe_ctx))
1229-
return self._run_in_new_thread(
1230-
engine.run(effective_strategy, pipe_ctx), module_id, None
1231-
)
1222+
return self._run_in_new_thread(engine.run(effective_strategy, pipe_ctx), module_id, None)
12321223

12331224
async def call_async_with_trace(
12341225
self,
@@ -1263,7 +1254,8 @@ async def call_async_with_trace(
12631254
return await engine.run(effective_strategy, pipe_ctx)
12641255

12651256
def _effective_strategy(
1266-
self, strategy: ExecutionStrategy | str | None,
1257+
self,
1258+
strategy: ExecutionStrategy | str | None,
12671259
) -> ExecutionStrategy:
12681260
"""Return the strategy to use for a call, resolving strings."""
12691261
if strategy is None:

src/apcore/pipeline.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,9 @@ class PipelineEngine:
182182
"""Executes a pipeline strategy step by step."""
183183

184184
async def run(
185-
self, strategy: ExecutionStrategy, ctx: PipelineContext,
185+
self,
186+
strategy: ExecutionStrategy,
187+
ctx: PipelineContext,
186188
) -> tuple[Any, PipelineTrace]:
187189
"""Run all steps in the strategy, returning the final output and trace."""
188190
trace = PipelineTrace(
@@ -200,13 +202,15 @@ async def run(
200202
try:
201203
result = await step.execute(ctx)
202204
except Exception as exc:
203-
trace.steps.append(StepTrace(
204-
name=step.name,
205-
duration_ms=(time.monotonic() - step_start) * 1000,
206-
result=StepResult(action="abort", explanation=str(exc)),
207-
skipped=False,
208-
decision_point=False,
209-
))
205+
trace.steps.append(
206+
StepTrace(
207+
name=step.name,
208+
duration_ms=(time.monotonic() - step_start) * 1000,
209+
result=StepResult(action="abort", explanation=str(exc)),
210+
skipped=False,
211+
decision_point=False,
212+
)
213+
)
210214
trace.total_duration_ms = (time.monotonic() - start) * 1000
211215
raise
212216

@@ -235,13 +239,15 @@ async def run(
235239
target_idx = j
236240
break
237241
# Record skipped steps in trace
238-
trace.steps.append(StepTrace(
239-
name=steps[j].name,
240-
duration_ms=0,
241-
result=StepResult(action="continue"),
242-
skipped=True,
243-
decision_point=False,
244-
))
242+
trace.steps.append(
243+
StepTrace(
244+
name=steps[j].name,
245+
duration_ms=0,
246+
result=StepResult(action="continue"),
247+
skipped=True,
248+
decision_point=False,
249+
)
250+
)
245251
if target_idx is None:
246252
raise StepNotFoundError(
247253
f"skip_to target '{target}' not found",
@@ -253,9 +259,7 @@ async def run(
253259

254260
trace.success = True
255261
trace.total_duration_ms = (time.monotonic() - start) * 1000
256-
final_output = (
257-
ctx.validated_output if ctx.validated_output is not None else ctx.output
258-
)
262+
final_output = ctx.validated_output if ctx.validated_output is not None else ctx.output
259263
return final_output, trace
260264

261265

tests/test_acl_conditions.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,9 @@
1010

1111
from apcore.acl import ACL, ACLRule
1212
from apcore.acl_handlers import (
13-
ACLConditionHandler,
14-
AsyncACLConditionHandler,
1513
SyncACLConditionHandler,
1614
_IdentityTypesHandler,
1715
_MaxCallDepthHandler,
18-
_NotHandler,
19-
_OrHandler,
2016
_RolesHandler,
2117
)
2218
from apcore.context import Context, Identity
@@ -190,18 +186,24 @@ class TestOrHandler:
190186
def test_or_passes_when_any_match(self) -> None:
191187
"""AC-011: $or evaluates with OR logic."""
192188
ctx = _make_context(identity_type="user", roles=["admin"])
193-
acl = _make_acl_with_condition("$or", [
194-
{"roles": ["admin"]},
195-
{"identity_types": ["service"]},
196-
])
189+
acl = _make_acl_with_condition(
190+
"$or",
191+
[
192+
{"roles": ["admin"]},
193+
{"identity_types": ["service"]},
194+
],
195+
)
197196
assert acl.check("caller", "target", context=ctx) is True
198197

199198
def test_or_fails_when_none_match(self) -> None:
200199
ctx = _make_context(identity_type="user", roles=["viewer"])
201-
acl = _make_acl_with_condition("$or", [
202-
{"roles": ["admin"]},
203-
{"identity_types": ["service"]},
204-
])
200+
acl = _make_acl_with_condition(
201+
"$or",
202+
[
203+
{"roles": ["admin"]},
204+
{"identity_types": ["service"]},
205+
],
206+
)
205207
assert acl.check("caller", "target", context=ctx) is False
206208

207209
def test_or_empty_list_returns_false(self) -> None:
@@ -217,10 +219,13 @@ def test_or_non_list_returns_false(self) -> None:
217219

218220
def test_or_skips_non_dict_elements(self) -> None:
219221
ctx = _make_context(roles=["admin"])
220-
acl = _make_acl_with_condition("$or", [
221-
"invalid_string",
222-
{"roles": ["admin"]},
223-
])
222+
acl = _make_acl_with_condition(
223+
"$or",
224+
[
225+
"invalid_string",
226+
{"roles": ["admin"]},
227+
],
228+
)
224229
assert acl.check("caller", "target", context=ctx) is True
225230

226231

@@ -244,18 +249,24 @@ class TestNestedCompound:
244249
def test_nested_or_with_and(self) -> None:
245250
"""AC-032: Nested compound conditions."""
246251
ctx = _make_context(identity_type="service", call_chain=["a", "b"])
247-
acl = _make_acl_with_condition("$or", [
248-
{"roles": ["admin"]},
249-
{"identity_types": ["service"], "max_call_depth": 5},
250-
])
252+
acl = _make_acl_with_condition(
253+
"$or",
254+
[
255+
{"roles": ["admin"]},
256+
{"identity_types": ["service"], "max_call_depth": 5},
257+
],
258+
)
251259
assert acl.check("caller", "target", context=ctx) is True
252260

253261
def test_nested_or_with_and_fails_when_depth_exceeded(self) -> None:
254262
ctx = _make_context(identity_type="service", call_chain=["a"] * 10)
255-
acl = _make_acl_with_condition("$or", [
256-
{"roles": ["admin"]},
257-
{"identity_types": ["service"], "max_call_depth": 5},
258-
])
263+
acl = _make_acl_with_condition(
264+
"$or",
265+
[
266+
{"roles": ["admin"]},
267+
{"identity_types": ["service"], "max_call_depth": 5},
268+
],
269+
)
259270
assert acl.check("caller", "target", context=ctx) is False
260271

261272

tests/test_builtin_steps.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
BaseStep,
2626
ExecutionStrategy,
2727
PipelineContext,
28-
StepResult,
2928
)
3029

3130

tests/test_context_serialization.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,7 @@ class TestContextDeserialize:
9898
def test_deserialize_roundtrip(self) -> None:
9999
"""Serialize then deserialize preserves fields."""
100100
ctx = Context.create(executor=None)
101-
ctx.identity = Identity(
102-
id="user-1", type="user", roles=("admin",), attrs={"org": "acme"}
103-
)
101+
ctx.identity = Identity(id="user-1", type="user", roles=("admin",), attrs={"org": "acme"})
104102
ctx.data["app.counter"] = 42
105103
serialized = ctx.serialize()
106104
restored = Context.deserialize(serialized)

0 commit comments

Comments
 (0)