77from collections .abc import Mapping
88from dataclasses import asdict , dataclass
99from datetime import datetime , timezone
10- from urllib .parse import urlsplit
10+ from urllib .parse import urlencode , urlsplit
1111
1212from predicate_contracts import ProofEvent
1313
@@ -22,6 +22,11 @@ class ControlPlaneClientConfig:
2222 max_retries : int = 2
2323 backoff_initial_s : float = 0.2
2424 fail_open : bool = True
25+ sync_enabled : bool = False
26+ sync_wait_timeout_s : float = 15.0
27+ sync_poll_interval_ms : int = 200
28+ sync_project_id : str | None = None
29+ sync_environment : str | None = None
2530
2631
2732@dataclass (frozen = True )
@@ -80,6 +85,89 @@ def authority_check(tenant_id: str, project_id: str, credits: int = 1) -> UsageC
8085 )
8186
8287
88+ @dataclass (frozen = True )
89+ class RemoteRevocation :
90+ revocation_id : str
91+ type : str
92+ principal_id : str | None = None
93+ intent_hash : str | None = None
94+ tags : tuple [str , ...] = ()
95+ reason : str | None = None
96+ created_at : str = ""
97+
98+
99+ @dataclass (frozen = True )
100+ class AuthoritySyncSnapshot :
101+ changed : bool
102+ sync_token : str
103+ tenant_id : str
104+ project_id : str | None = None
105+ environment : str | None = None
106+ policy_id : str | None = None
107+ policy_revision : int | None = None
108+ policy_document : dict [str , object ] | None = None
109+ revocations : tuple [RemoteRevocation , ...] = ()
110+
111+ @staticmethod
112+ def from_payload (payload : Mapping [str , object ]) -> AuthoritySyncSnapshot :
113+ revocations_payload = payload .get ("revocations" )
114+ parsed_revocations : list [RemoteRevocation ] = []
115+ if isinstance (revocations_payload , list ):
116+ for item in revocations_payload :
117+ if not isinstance (item , Mapping ):
118+ continue
119+ raw_tags = item .get ("tags" )
120+ tags : tuple [str , ...] = ()
121+ if isinstance (raw_tags , list ):
122+ tags = tuple (str (tag ) for tag in raw_tags if isinstance (tag , str ))
123+ parsed_revocations .append (
124+ RemoteRevocation (
125+ revocation_id = str (item .get ("revocation_id" , "" )),
126+ type = str (item .get ("type" , "" )),
127+ principal_id = (
128+ str (item ["principal_id" ])
129+ if isinstance (item .get ("principal_id" ), str )
130+ else None
131+ ),
132+ intent_hash = (
133+ str (item ["intent_hash" ])
134+ if isinstance (item .get ("intent_hash" ), str )
135+ else None
136+ ),
137+ tags = tags ,
138+ reason = str (item ["reason" ]) if isinstance (item .get ("reason" ), str ) else None ,
139+ created_at = str (item .get ("created_at" , "" )),
140+ )
141+ )
142+ policy_document = payload .get ("policy_document" )
143+ raw_policy_revision = payload .get ("policy_revision" )
144+ policy_revision : int | None = None
145+ if isinstance (raw_policy_revision , int ):
146+ policy_revision = raw_policy_revision
147+ elif isinstance (raw_policy_revision , str ) and raw_policy_revision .strip () != "" :
148+ try :
149+ policy_revision = int (raw_policy_revision )
150+ except ValueError :
151+ policy_revision = None
152+ return AuthoritySyncSnapshot (
153+ changed = bool (payload .get ("changed" , False )),
154+ sync_token = str (payload .get ("sync_token" , "" )),
155+ tenant_id = str (payload .get ("tenant_id" , "" )),
156+ project_id = (
157+ str (payload ["project_id" ]) if isinstance (payload .get ("project_id" ), str ) else None
158+ ),
159+ environment = (
160+ str (payload ["environment" ]) if isinstance (payload .get ("environment" ), str ) else None
161+ ),
162+ policy_id = (
163+ str (payload ["policy_id" ]) if isinstance (payload .get ("policy_id" ), str ) else None
164+ ),
165+ policy_revision = policy_revision ,
166+ policy_document = (dict (policy_document ) if isinstance (policy_document , dict ) else None ),
167+ revocations = tuple (parsed_revocations ),
168+ )
169+
170+
83171class ControlPlaneClient :
84172 def __init__ (self , config : ControlPlaneClientConfig ) -> None :
85173 self .config = config
@@ -100,6 +188,29 @@ def send_usage_records(self, records: tuple[UsageCreditRecord, ...]) -> bool:
100188 def send_audit_payload (self , payload : Mapping [str , object ]) -> bool :
101189 return self ._post_json ("/v1/audit/events:batch" , payload )
102190
191+ def poll_authority_updates (
192+ self ,
193+ current_token : str | None ,
194+ wait_timeout_s : float = 15.0 ,
195+ poll_interval_ms : int = 200 ,
196+ project_id : str | None = None ,
197+ environment : str | None = None ,
198+ ) -> AuthoritySyncSnapshot :
199+ query : dict [str , str | float | int ] = {
200+ "tenant_id" : self .config .tenant_id ,
201+ "wait_timeout_s" : max (0.0 , float (wait_timeout_s )),
202+ "poll_interval_ms" : max (50 , int (poll_interval_ms )),
203+ }
204+ if current_token is not None and current_token .strip () != "" :
205+ query ["current_token" ] = current_token
206+ if project_id is not None and project_id .strip () != "" :
207+ query ["project_id" ] = project_id
208+ if environment is not None and environment .strip () != "" :
209+ query ["environment" ] = environment
210+ path = "/v1/sync/authority-updates?" + urlencode (query )
211+ payload = self ._get_json (path )
212+ return AuthoritySyncSnapshot .from_payload (payload )
213+
103214 def _post_json (self , path : str , payload : Mapping [str , object ]) -> bool :
104215 attempts = self .config .max_retries + 1
105216 for attempt in range (attempts ):
@@ -115,6 +226,20 @@ def _post_json(self, path: str, payload: Mapping[str, object]) -> bool:
115226 time .sleep (self .config .backoff_initial_s * (2 ** attempt ))
116227 return False
117228
229+ def _get_json (self , path : str ) -> Mapping [str , object ]:
230+ attempts = self .config .max_retries + 1
231+ for attempt in range (attempts ):
232+ try :
233+ return self ._get_json_once (path )
234+ except Exception as exc :
235+ is_last_attempt = attempt == attempts - 1
236+ if is_last_attempt :
237+ if self .config .fail_open :
238+ return {}
239+ raise RuntimeError (f"control-plane request failed: { path } " ) from exc
240+ time .sleep (self .config .backoff_initial_s * (2 ** attempt ))
241+ return {}
242+
118243 def _post_json_once (self , path : str , payload : Mapping [str , object ]) -> None :
119244 target_path = path if path .startswith ("/" ) else f"/{ path } "
120245 connection = self ._new_connection ()
@@ -131,6 +256,25 @@ def _post_json_once(self, path: str, payload: Mapping[str, object]) -> None:
131256 if response .status >= 400 :
132257 raise RuntimeError (f"HTTP { response .status } : { content } " )
133258
259+ def _get_json_once (self , path : str ) -> Mapping [str , object ]:
260+ target_path = path if path .startswith ("/" ) else f"/{ path } "
261+ connection = self ._new_connection ()
262+ headers : dict [str , str ] = {}
263+ if self .config .auth_token :
264+ headers ["Authorization" ] = f"Bearer { self .config .auth_token } "
265+ try :
266+ connection .request ("GET" , target_path , headers = headers )
267+ response = connection .getresponse ()
268+ content = response .read ().decode ("utf-8" )
269+ finally :
270+ connection .close ()
271+ if response .status >= 400 :
272+ raise RuntimeError (f"HTTP { response .status } : { content } " )
273+ loaded = json .loads (content ) if content .strip () != "" else {}
274+ if not isinstance (loaded , dict ):
275+ raise RuntimeError ("Expected object JSON payload from control-plane GET response." )
276+ return loaded
277+
134278 def _new_connection (self ) -> http .client .HTTPConnection :
135279 if self ._base .scheme == "https" :
136280 return http .client .HTTPSConnection (self ._base .netloc , timeout = self .config .timeout_s )
0 commit comments