Skip to content

Commit 2bdc933

Browse files
committed
feat: implement Context redesign, Annotations extension, and ACL condition handlers
Context: - Add ContextKey<T> typed accessor with get/set/delete/exists/scoped - Define built-in context keys (_apcore.mw.* convention) - Migrate middleware from raw string keys to typed ContextKey - Add Context.serialize()/deserialize() with _context_version:1 Annotations: - Add extra: dict extension field to ModuleAnnotations (frozen dataclass) - Change pagination_style from Literal to str for extensibility - Add DEFAULT_ANNOTATIONS constant and from_dict() classmethod - Add __post_init__ for list→tuple coercion and extra copy ACL: - Add ACLConditionHandler protocol (sync + async) - Add register_condition() class-level handler registry - Extract identity_types/roles/max_call_depth into handler classes - Add $or and $not compound condition operators - Add async_check() method alongside sync check() - Replace hardcoded if/else with handler dispatch (fail-closed) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: tercel <tercel.yi@gmail.com>
1 parent c83947a commit 2bdc933

23 files changed

Lines changed: 1375 additions & 57 deletions

src/apcore/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@
1818
)
1919
from apcore.cancel import CancelToken, ExecutionCancelledError
2020
from apcore.context import Context, ContextFactory, Identity
21+
from apcore.context_key import ContextKey
22+
from apcore.context_keys import (
23+
LOGGING_START,
24+
METRICS_STARTS,
25+
REDACTED_OUTPUT,
26+
RETRY_COUNT_BASE,
27+
TRACING_SAMPLED,
28+
TRACING_SPANS,
29+
)
2130
from apcore.registry import Registry
2231
from apcore.client import APCore
2332
from apcore.registry.registry import (
@@ -33,6 +42,7 @@
3342

3443
# Module types
3544
from apcore.module import (
45+
DEFAULT_ANNOTATIONS,
3646
Module,
3747
ModuleAnnotations,
3848
ModuleExample,
@@ -322,6 +332,13 @@ def enable(module_id: str, reason: str = "Enabled via APCore client") -> dict[st
322332
"ExecutionCancelledError",
323333
"Context",
324334
"ContextFactory",
335+
"ContextKey",
336+
"TRACING_SPANS",
337+
"TRACING_SAMPLED",
338+
"METRICS_STARTS",
339+
"LOGGING_START",
340+
"REDACTED_OUTPUT",
341+
"RETRY_COUNT_BASE",
325342
"Identity",
326343
"Registry",
327344
"Executor",
@@ -350,6 +367,7 @@ def enable(module_id: str, reason: str = "Enabled via APCore client") -> dict[st
350367
"AutoApproveHandler",
351368
"CallbackApprovalHandler",
352369
# Module types
370+
"DEFAULT_ANNOTATIONS",
353371
"Module",
354372
"ModuleAnnotations",
355373
"ModuleExample",

src/apcore/acl.py

Lines changed: 168 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,32 @@
66

77
from __future__ import annotations
88

9+
import inspect
910
import logging
1011
import os
1112
import threading
1213
from dataclasses import dataclass, field
1314
from datetime import datetime, timezone
14-
from typing import Any, Callable
15+
from typing import Any, Callable, ClassVar
1516

1617
import yaml
1718

19+
from apcore.acl_handlers import (
20+
ACLConditionHandler,
21+
_IdentityTypesHandler,
22+
_MaxCallDepthHandler,
23+
_NotHandler,
24+
_OrHandler,
25+
_RolesHandler,
26+
)
1827
from apcore.context import Context
1928
from apcore.errors import ACLRuleError, ConfigNotFoundError
2029
from apcore.utils.pattern import match_pattern
2130

2231
__all__ = ["ACLRule", "AuditEntry", "ACL"]
2332

33+
_logger = logging.getLogger(__name__)
34+
2435

2536
@dataclass
2637
class ACLRule:
@@ -64,6 +75,59 @@ class ACL:
6475
remove_rule, reload) are safe to call concurrently.
6576
"""
6677

78+
_condition_handlers: ClassVar[dict[str, ACLConditionHandler]] = {}
79+
80+
@classmethod
81+
def register_condition(cls, key: str, handler: ACLConditionHandler) -> None:
82+
"""Register a condition handler. Replaces existing handler for same key."""
83+
cls._condition_handlers[key] = handler
84+
85+
@classmethod
86+
def _evaluate_conditions(
87+
cls, conditions: dict[str, Any], context: Context,
88+
) -> bool:
89+
"""Evaluate all conditions with AND logic. Fail-closed on unknown."""
90+
for key, value in conditions.items():
91+
handler = cls._condition_handlers.get(key)
92+
if handler is None:
93+
_logger.warning("Unknown ACL condition %r — treated as unsatisfied", key)
94+
return False
95+
try:
96+
result = handler.evaluate(value, context)
97+
except Exception:
98+
_logger.exception("Handler for condition %r raised — treated as unsatisfied", key)
99+
return False
100+
if inspect.isawaitable(result):
101+
result.close() # prevent "coroutine never awaited" warning
102+
_logger.warning(
103+
"Async condition %r in sync context — treated as unsatisfied. Use async_check().", key,
104+
)
105+
return False
106+
if not result:
107+
return False
108+
return True
109+
110+
@classmethod
111+
async def _evaluate_conditions_async(
112+
cls, conditions: dict[str, Any], context: Context,
113+
) -> bool:
114+
"""Async variant. Awaits async handlers, calls sync handlers directly."""
115+
for key, value in conditions.items():
116+
handler = cls._condition_handlers.get(key)
117+
if handler is None:
118+
_logger.warning("Unknown ACL condition %r — treated as unsatisfied", key)
119+
return False
120+
try:
121+
result = handler.evaluate(value, context)
122+
if inspect.isawaitable(result):
123+
result = await result
124+
except Exception:
125+
_logger.exception("Handler for condition %r raised — treated as unsatisfied", key)
126+
return False
127+
if not result:
128+
return False
129+
return True
130+
67131
def __init__(
68132
self,
69133
rules: list[ACLRule],
@@ -224,6 +288,97 @@ def check(
224288
audit_logger(entry)
225289
return default_decision
226290

291+
async def async_check(
292+
self,
293+
caller_id: str | None,
294+
target_id: str,
295+
context: Context | None = None,
296+
) -> bool:
297+
"""Async ACL check. Supports both sync and async condition handlers.
298+
299+
Args:
300+
caller_id: The calling module ID, or None for external calls.
301+
target_id: The target module ID being called.
302+
context: Optional execution context for conditional rules.
303+
304+
Returns:
305+
True if the call is allowed, False if denied.
306+
"""
307+
effective_caller = "@external" if caller_id is None else caller_id
308+
309+
with self._lock:
310+
rules = list(self._rules)
311+
default_effect = self._default_effect
312+
audit_logger = self._audit_logger
313+
314+
for idx, rule in enumerate(rules):
315+
if await self._matches_rule_async(rule, effective_caller, target_id, context):
316+
decision = rule.effect == "allow"
317+
self._logger.debug(
318+
"ACL async_check: caller=%s target=%s decision=%s rule=%s",
319+
caller_id,
320+
target_id,
321+
"allow" if decision else "deny",
322+
rule.description or "(no description)",
323+
)
324+
if audit_logger is not None:
325+
entry = self._build_audit_entry(
326+
caller_id=effective_caller,
327+
target_id=target_id,
328+
decision="allow" if decision else "deny",
329+
reason="rule_match",
330+
matched_rule=rule,
331+
matched_rule_index=idx,
332+
context=context,
333+
)
334+
audit_logger(entry)
335+
return decision
336+
337+
default_decision = default_effect == "allow"
338+
self._logger.debug(
339+
"ACL async_check: caller=%s target=%s decision=%s rule=default",
340+
caller_id,
341+
target_id,
342+
"allow" if default_decision else "deny",
343+
)
344+
if audit_logger is not None:
345+
reason = "no_rules" if not rules else "default_effect"
346+
entry = self._build_audit_entry(
347+
caller_id=effective_caller,
348+
target_id=target_id,
349+
decision="allow" if default_decision else "deny",
350+
reason=reason,
351+
matched_rule=None,
352+
matched_rule_index=None,
353+
context=context,
354+
)
355+
audit_logger(entry)
356+
return default_decision
357+
358+
async def _matches_rule_async(
359+
self,
360+
rule: ACLRule,
361+
caller: str,
362+
target: str,
363+
context: Context | None,
364+
) -> bool:
365+
"""Async version of _matches_rule that awaits async condition handlers."""
366+
caller_match = any(self._match_pattern(p, caller, context) for p in rule.callers)
367+
if not caller_match:
368+
return False
369+
370+
target_match = any(self._match_pattern(p, target, context) for p in rule.targets)
371+
if not target_match:
372+
return False
373+
374+
if rule.conditions is not None:
375+
if context is None:
376+
return False
377+
if not await self._evaluate_conditions_async(rule.conditions, context):
378+
return False
379+
380+
return True
381+
227382
def _build_audit_entry(
228383
self,
229384
*,
@@ -309,22 +464,7 @@ def _check_conditions(self, conditions: dict[str, Any], context: Context | None)
309464
"""
310465
if context is None:
311466
return False
312-
313-
if "identity_types" in conditions:
314-
if context.identity is None or context.identity.type not in conditions["identity_types"]:
315-
return False
316-
317-
if "roles" in conditions:
318-
if context.identity is None:
319-
return False
320-
if not set(context.identity.roles) & set(conditions["roles"]):
321-
return False
322-
323-
if "max_call_depth" in conditions:
324-
if len(context.call_chain) > conditions["max_call_depth"]:
325-
return False
326-
327-
return True
467+
return self._evaluate_conditions(conditions, context)
328468

329469
def add_rule(self, rule: ACLRule) -> None:
330470
"""Add a rule at position 0 (highest priority).
@@ -366,3 +506,14 @@ def reload(self) -> None:
366506
with self._lock:
367507
self._rules = reloaded._rules
368508
self._default_effect = reloaded._default_effect
509+
510+
511+
# ---------------------------------------------------------------------------
512+
# Auto-register built-in handlers at module load time
513+
# ---------------------------------------------------------------------------
514+
515+
ACL.register_condition("identity_types", _IdentityTypesHandler())
516+
ACL.register_condition("roles", _RolesHandler())
517+
ACL.register_condition("max_call_depth", _MaxCallDepthHandler())
518+
ACL.register_condition("$or", _OrHandler(ACL._evaluate_conditions))
519+
ACL.register_condition("$not", _NotHandler(ACL._evaluate_conditions))

src/apcore/acl_handlers.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
"""Built-in ACL condition handlers and handler protocols.
2+
3+
Defines the ACLConditionHandler protocol (sync and async variants),
4+
three basic handlers (identity_types, roles, max_call_depth), and
5+
two compound operators ($or, $not).
6+
"""
7+
8+
from __future__ import annotations
9+
10+
from typing import Any, Callable, Protocol, Union, runtime_checkable
11+
12+
from apcore.context import Context
13+
14+
__all__ = [
15+
"SyncACLConditionHandler",
16+
"AsyncACLConditionHandler",
17+
"ACLConditionHandler",
18+
]
19+
20+
21+
@runtime_checkable
22+
class SyncACLConditionHandler(Protocol):
23+
"""Sync condition handler protocol."""
24+
25+
def evaluate(self, value: Any, context: Context) -> bool: ...
26+
27+
28+
@runtime_checkable
29+
class AsyncACLConditionHandler(Protocol):
30+
"""Async condition handler protocol."""
31+
32+
async def evaluate(self, value: Any, context: Context) -> bool: ...
33+
34+
35+
ACLConditionHandler = Union[SyncACLConditionHandler, AsyncACLConditionHandler]
36+
37+
# Type alias for the recursive evaluation function used by compound handlers.
38+
_EvalFn = Callable[[dict[str, Any], Context], bool]
39+
40+
41+
# ---------------------------------------------------------------------------
42+
# Basic handlers
43+
# ---------------------------------------------------------------------------
44+
45+
46+
class _IdentityTypesHandler:
47+
"""Check context.identity.type is in the allowed list."""
48+
49+
def evaluate(self, value: Any, context: Context) -> bool:
50+
if not isinstance(value, list) or context.identity is None:
51+
return False
52+
return context.identity.type in value
53+
54+
55+
class _RolesHandler:
56+
"""Check at least one role overlaps between identity and required roles."""
57+
58+
def evaluate(self, value: Any, context: Context) -> bool:
59+
if not isinstance(value, list) or context.identity is None:
60+
return False
61+
return bool(set(context.identity.roles) & set(value))
62+
63+
64+
class _MaxCallDepthHandler:
65+
"""Check call chain length does not exceed threshold."""
66+
67+
def evaluate(self, value: Any, context: Context) -> bool:
68+
if not isinstance(value, int):
69+
return False
70+
return len(context.call_chain) <= value
71+
72+
73+
# ---------------------------------------------------------------------------
74+
# Compound handlers
75+
# ---------------------------------------------------------------------------
76+
77+
78+
class _OrHandler:
79+
"""$or: list of condition dicts. Returns True if ANY sub-set passes."""
80+
81+
def __init__(self, evaluate_fn: _EvalFn) -> None:
82+
self._evaluate = evaluate_fn
83+
84+
def evaluate(self, value: Any, context: Context) -> bool:
85+
if not isinstance(value, list):
86+
return False
87+
for sub in value:
88+
if not isinstance(sub, dict):
89+
continue
90+
if self._evaluate(sub, context):
91+
return True
92+
return False
93+
94+
95+
class _NotHandler:
96+
"""$not: single condition dict. Returns True if the sub-set FAILS."""
97+
98+
def __init__(self, evaluate_fn: _EvalFn) -> None:
99+
self._evaluate = evaluate_fn
100+
101+
def evaluate(self, value: Any, context: Context) -> bool:
102+
if not isinstance(value, dict):
103+
return False
104+
return not self._evaluate(value, context)

0 commit comments

Comments
 (0)