Skip to content

Commit 69da43a

Browse files
committed
address code review
1 parent e13bb31 commit 69da43a

12 files changed

Lines changed: 314 additions & 69 deletions

predicate_authority/client.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
from dataclasses import dataclass
45

56
from predicate_authority.guard import ActionGuard
@@ -63,6 +64,27 @@ def from_policy_file(
6364
policy_file=policy_file,
6465
)
6566

67+
@classmethod
68+
def from_env(cls) -> LocalAuthorizationContext:
69+
policy_file = os.getenv("PREDICATE_AUTHORITY_POLICY_FILE")
70+
secret_key = os.getenv("PREDICATE_AUTHORITY_SIGNING_KEY")
71+
ttl_seconds_raw = os.getenv("PREDICATE_AUTHORITY_MANDATE_TTL_SECONDS", "300")
72+
if policy_file is None or policy_file.strip() == "":
73+
raise RuntimeError("PREDICATE_AUTHORITY_POLICY_FILE is required.")
74+
if secret_key is None or secret_key.strip() == "":
75+
raise RuntimeError("PREDICATE_AUTHORITY_SIGNING_KEY is required.")
76+
try:
77+
ttl_seconds = int(ttl_seconds_raw)
78+
except ValueError as exc:
79+
raise RuntimeError(
80+
"PREDICATE_AUTHORITY_MANDATE_TTL_SECONDS must be an integer."
81+
) from exc
82+
return cls.from_policy_file(
83+
policy_file=policy_file,
84+
secret_key=secret_key,
85+
ttl_seconds=ttl_seconds,
86+
)
87+
6688
def authorize(
6789
self,
6890
request: ActionRequest,

predicate_authority/control_plane.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from datetime import datetime, timezone
1010
from urllib.parse import urlsplit
1111

12-
from predicate_contracts import ProofEvent, TraceEmitter
12+
from predicate_contracts import ProofEvent
1313

1414

1515
@dataclass(frozen=True)
@@ -138,7 +138,7 @@ def _new_connection(self) -> http.client.HTTPConnection:
138138

139139

140140
@dataclass
141-
class ControlPlaneTraceEmitter(TraceEmitter):
141+
class ControlPlaneTraceEmitter:
142142
client: ControlPlaneClient
143143
trace_id: str | None = None
144144
emit_usage_credits: bool = True

predicate_authority/daemon.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class DaemonConfig:
5353
host: str = "127.0.0.1"
5454
port: int = 8787
5555
policy_poll_interval_s: float = 2.0
56+
max_request_body_bytes: int = 1_048_576
5657

5758

5859
@dataclass(frozen=True)
@@ -167,6 +168,7 @@ def do_POST(self) -> None: # noqa: N802
167168
"/policy/reload": self._handle_policy_reload,
168169
"/revoke/principal": self._handle_revoke_principal,
169170
"/revoke/intent": self._handle_revoke_intent,
171+
"/revoke/mandate": self._handle_revoke_mandate,
170172
"/identity/task": self._handle_identity_task,
171173
"/identity/revoke": self._handle_identity_revoke,
172174
"/ledger/flush-ack": self._handle_ledger_flush_ack,
@@ -201,6 +203,15 @@ def _handle_revoke_intent(self) -> None:
201203
self.server.daemon_ref.revoke_intent(intent_hash.strip()) # type: ignore[attr-defined]
202204
self._send_json(200, {"ok": True, "intent_hash": intent_hash.strip()})
203205

206+
def _handle_revoke_mandate(self) -> None:
207+
payload = self._read_json_body()
208+
mandate_id = payload.get("mandate_id")
209+
if not isinstance(mandate_id, str) or mandate_id.strip() == "":
210+
self._send_json(400, {"error": "mandate_id is required"})
211+
return
212+
self.server.daemon_ref.revoke_mandate(mandate_id.strip()) # type: ignore[attr-defined]
213+
self._send_json(200, {"ok": True, "mandate_id": mandate_id.strip()})
214+
204215
def _handle_identity_task(self) -> None:
205216
payload = self._read_json_body()
206217
principal_id = payload.get("principal_id")
@@ -273,6 +284,9 @@ def _read_json_body(self) -> dict[str, Any]:
273284
return {}
274285
if content_length <= 0:
275286
return {}
287+
max_body = self.server.daemon_ref.max_request_body_bytes() # type: ignore[attr-defined]
288+
if content_length > max_body:
289+
return {}
276290
payload = self.rfile.read(content_length).decode("utf-8")
277291
try:
278292
loaded = json.loads(payload)
@@ -388,6 +402,12 @@ def revoke_principal(self, principal_id: str) -> None:
388402
def revoke_intent(self, intent_hash: str) -> None:
389403
self._sidecar.revoke_intent_hash(intent_hash)
390404

405+
def revoke_mandate(self, mandate_id: str) -> None:
406+
self._sidecar.revoke_mandate_id(mandate_id)
407+
408+
def max_request_body_bytes(self) -> int:
409+
return max(0, int(self._config.max_request_body_bytes))
410+
391411
def issue_task_identity(
392412
self,
393413
principal_id: str,
@@ -592,6 +612,7 @@ def _build_default_sidecar(
592612
control_plane_config: ControlPlaneBootstrapConfig | None = None,
593613
local_identity_config: LocalIdentityBootstrapConfig | None = None,
594614
identity_bridge: ExchangeTokenBridge | None = None,
615+
mandate_signing_key: str | None = None,
595616
) -> PredicateAuthoritySidecar:
596617
policy_rules: tuple[PolicyRule, ...] = ()
597618
global_max_delegation_depth: int | None = None
@@ -647,7 +668,7 @@ def _build_default_sidecar(
647668

648669
guard = ActionGuard(
649670
policy_engine=policy_engine,
650-
mandate_signer=LocalMandateSigner(secret_key=secrets.token_hex(32)),
671+
mandate_signer=LocalMandateSigner(secret_key=mandate_signing_key or secrets.token_hex(32)),
651672
proof_ledger=proof_ledger,
652673
)
653674
return PredicateAuthoritySidecar(
@@ -709,6 +730,22 @@ def _build_identity_bridge_from_args(args: argparse.Namespace) -> ExchangeTokenB
709730
raise SystemExit(f"Unsupported identity mode: {mode}")
710731

711732

733+
def _resolve_mandate_signing_key(
734+
signing_key_file: str | None,
735+
signing_key_env: str,
736+
) -> str:
737+
if signing_key_file is not None and str(signing_key_file).strip() != "":
738+
key_path = Path(signing_key_file)
739+
if key_path.exists():
740+
loaded = key_path.read_text(encoding="utf-8").strip()
741+
if loaded != "":
742+
return loaded
743+
env_value = os.getenv(signing_key_env)
744+
if env_value is not None and env_value.strip() != "":
745+
return env_value.strip()
746+
return secrets.token_hex(32)
747+
748+
712749
def main() -> None:
713750
parser = argparse.ArgumentParser(description="predicate-authorityd sidecar daemon")
714751
parser.add_argument("--host", default="127.0.0.1")
@@ -816,6 +853,16 @@ def main() -> None:
816853
)
817854
parser.set_defaults(control_plane_fail_open=True)
818855
parser.add_argument("--control-plane-usage-credits-per-decision", type=int, default=1)
856+
parser.add_argument(
857+
"--mandate-signing-key-env",
858+
default="PREDICATE_AUTHORITY_SIGNING_KEY",
859+
help="Env var name for mandate signing key.",
860+
)
861+
parser.add_argument(
862+
"--mandate-signing-key-file",
863+
default=None,
864+
help="Optional file path containing mandate signing key.",
865+
)
819866
args = parser.parse_args()
820867

821868
mode = AuthorityMode(args.mode)
@@ -851,20 +898,26 @@ def main() -> None:
851898
default_ttl_seconds=max(1, int(args.local_identity_default_ttl_s)),
852899
)
853900
identity_bridge = _build_identity_bridge_from_args(args)
901+
mandate_signing_key = _resolve_mandate_signing_key(
902+
signing_key_file=args.mandate_signing_key_file,
903+
signing_key_env=args.mandate_signing_key_env,
904+
)
854905
sidecar = _build_default_sidecar(
855906
mode=mode,
856907
policy_file=args.policy_file,
857908
credential_store_file=args.credential_store_file,
858909
control_plane_config=control_plane_bootstrap,
859910
local_identity_config=local_identity_bootstrap,
860911
identity_bridge=identity_bridge,
912+
mandate_signing_key=mandate_signing_key,
861913
)
862914
daemon = PredicateAuthorityDaemon(
863915
sidecar=sidecar,
864916
config=DaemonConfig(
865917
host=args.host,
866918
port=args.port,
867919
policy_poll_interval_s=args.policy_poll_interval_s,
920+
max_request_body_bytes=1_048_576,
868921
),
869922
flush_worker=FlushWorkerConfig(
870923
enabled=bool(args.flush_worker_enabled),

predicate_authority/local_identity.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -259,11 +259,12 @@ def quarantine_queue_item(self, queue_item_id: str, reason: str) -> bool:
259259
return True
260260

261261
def list_dead_letter_queue(self, limit: int | None = None) -> list[LedgerQueueItem]:
262-
return self.list_flush_queue(
262+
items = self.list_flush_queue(
263263
include_flushed=True,
264264
include_quarantined=True,
265265
limit=limit,
266266
)
267+
return [item for item in items if item.quarantined]
267268

268269
def requeue_item(self, queue_item_id: str, reset_attempts: bool = True) -> bool:
269270
with self._lock:
@@ -317,7 +318,10 @@ def _read_all_unlocked(self) -> dict[str, Any]:
317318
content = self._file_path.read_text(encoding="utf-8").strip()
318319
if content == "":
319320
return {"identities": {}, "flush_queue": {}}
320-
loaded = json.loads(content)
321+
try:
322+
loaded = json.loads(content)
323+
except json.JSONDecodeError:
324+
return {"identities": {}, "flush_queue": {}}
321325
if isinstance(loaded, dict):
322326
if "identities" not in loaded:
323327
loaded["identities"] = {}
@@ -356,7 +360,9 @@ def _parse_queue_item(self, raw: dict[str, Any]) -> LedgerQueueItem | None:
356360
return None
357361

358362
def _write_all_unlocked(self, payload: dict[str, Any]) -> None:
359-
self._file_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
363+
tmp_path = self._file_path.with_name(f"{self._file_path.name}.{uuid.uuid4().hex}.tmp")
364+
tmp_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
365+
os.replace(tmp_path, self._file_path)
360366
self._chmod_file_safe()
361367

362368
def _ensure_store_path(self) -> None:
@@ -366,10 +372,12 @@ def _ensure_store_path(self) -> None:
366372
except OSError:
367373
pass
368374
if not self._file_path.exists():
369-
self._file_path.write_text(
375+
tmp_path = self._file_path.with_name(f"{self._file_path.name}.{uuid.uuid4().hex}.tmp")
376+
tmp_path.write_text(
370377
json.dumps({"identities": {}, "flush_queue": {}}, indent=2),
371378
encoding="utf-8",
372379
)
380+
os.replace(tmp_path, self._file_path)
373381
self._chmod_file_safe()
374382

375383
def _chmod_file_safe(self) -> None:
@@ -380,7 +388,7 @@ def _chmod_file_safe(self) -> None:
380388

381389

382390
@dataclass
383-
class LocalLedgerQueueEmitter(TraceEmitter):
391+
class LocalLedgerQueueEmitter:
384392
registry: LocalIdentityRegistry
385393
source: str = "predicate-authorityd"
386394

@@ -389,7 +397,7 @@ def emit(self, event: ProofEvent) -> None:
389397

390398

391399
@dataclass
392-
class CompositeTraceEmitter(TraceEmitter):
400+
class CompositeTraceEmitter:
393401
emitters: tuple[TraceEmitter, ...]
394402

395403
def emit(self, event: ProofEvent) -> None:

predicate_authority/policy.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from dataclasses import dataclass
44
from fnmatch import fnmatch
5+
from threading import Lock
56

67
from predicate_contracts import ActionRequest, AuthorizationReason, PolicyEffect, PolicyRule
78

@@ -22,15 +23,31 @@ def __init__(
2223
) -> None:
2324
self._rules = rules
2425
self._global_max_delegation_depth = global_max_delegation_depth
26+
self._lock = Lock()
2527

2628
def replace_rules(self, rules: tuple[PolicyRule, ...]) -> None:
27-
self._rules = rules
29+
with self._lock:
30+
self._rules = rules
2831

2932
def set_global_max_delegation_depth(self, max_depth: int | None) -> None:
30-
self._global_max_delegation_depth = max_depth
33+
with self._lock:
34+
self._global_max_delegation_depth = max_depth
35+
36+
def replace_policy(
37+
self,
38+
rules: tuple[PolicyRule, ...],
39+
global_max_delegation_depth: int | None,
40+
) -> None:
41+
with self._lock:
42+
self._rules = rules
43+
self._global_max_delegation_depth = global_max_delegation_depth
3144

3245
def evaluate(self, request: ActionRequest, delegation_depth: int = 0) -> PolicyMatchResult:
33-
matching_rules = [rule for rule in self._rules if self._matches_rule(rule, request)]
46+
with self._lock:
47+
rules = self._rules
48+
global_max_delegation_depth = self._global_max_delegation_depth
49+
50+
matching_rules = [rule for rule in rules if self._matches_rule(rule, request)]
3451
if not matching_rules:
3552
return PolicyMatchResult(
3653
allowed=False,
@@ -50,7 +67,10 @@ def evaluate(self, request: ActionRequest, delegation_depth: int = 0) -> PolicyM
5067
if rule.effect != PolicyEffect.ALLOW:
5168
continue
5269

53-
effective_max_depth = self._effective_max_delegation_depth(rule)
70+
effective_max_depth = self._effective_max_delegation_depth(
71+
global_max_delegation_depth,
72+
rule.max_delegation_depth,
73+
)
5474
if effective_max_depth is not None and delegation_depth > effective_max_depth:
5575
failure = PolicyMatchResult(
5676
allowed=False,
@@ -102,9 +122,11 @@ def _matches_rule(rule: PolicyRule, request: ActionRequest) -> bool:
102122
)
103123
return principal_ok and action_ok and resource_ok
104124

105-
def _effective_max_delegation_depth(self, rule: PolicyRule) -> int | None:
106-
global_max = self._global_max_delegation_depth
107-
rule_max = rule.max_delegation_depth
125+
@staticmethod
126+
def _effective_max_delegation_depth(
127+
global_max: int | None,
128+
rule_max: int | None,
129+
) -> int | None:
108130
if global_max is None:
109131
return rule_max
110132
if rule_max is None:

predicate_authority/proof.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import time
44
from dataclasses import dataclass, field
5+
from threading import Lock
56

67
from predicate_contracts import ActionRequest, AuthorizationDecision, ProofEvent, TraceEmitter
78

@@ -10,6 +11,7 @@
1011
class InMemoryProofLedger:
1112
trace_emitter: TraceEmitter | None = None
1213
events: list[ProofEvent] = field(default_factory=list)
14+
_lock: Lock = field(default_factory=Lock)
1315

1416
def record(self, decision: AuthorizationDecision, request: ActionRequest) -> ProofEvent:
1517
event = ProofEvent(
@@ -22,7 +24,13 @@ def record(self, decision: AuthorizationDecision, request: ActionRequest) -> Pro
2224
mandate_id=decision.mandate.claims.mandate_id if decision.mandate else None,
2325
emitted_at_epoch_s=int(time.time()),
2426
)
25-
self.events.append(event)
26-
if self.trace_emitter is not None:
27-
self.trace_emitter.emit(event)
27+
with self._lock:
28+
self.events.append(event)
29+
trace_emitter = self.trace_emitter
30+
if trace_emitter is not None:
31+
trace_emitter.emit(event)
2832
return event
33+
34+
def event_count(self) -> int:
35+
with self._lock:
36+
return len(self.events)

0 commit comments

Comments
 (0)