2121# limitations under the License.
2222# -----------------------------------------------------------------------------
2323from __future__ import annotations
24+
2425from typing import Callable , Iterator
25- from ...encoding import Name , BinaryStr , FormalName , NonStrictName , Component
26+
27+ from ...encoding import BinaryStr , Component , FormalName , Name , NonStrictName
2628from ...security import Keychain
2729from ..security_v2 import parse_certificate
2830from . import binary as bny
2931from .compiler import top_order
3032
3133
32- __all__ = [' UserFn' , ' LvsModelError' , ' Checker' , ' DEFAULT_USER_FNS' ]
34+ __all__ = [" UserFn" , " LvsModelError" , " Checker" , " DEFAULT_USER_FNS" ]
3335
3436
3537UserFn = Callable [[BinaryStr , list [BinaryStr ]], bool ]
@@ -44,6 +46,7 @@ class LvsModelError(Exception):
4446 """
4547 Raised when the input LVS model is malformed.
4648 """
49+
4750 pass
4851
4952
@@ -71,8 +74,11 @@ def __init__(self, model: bny.LvsModel, user_fns: dict[str, UserFn]):
7174
7275 def _sanity_check (self ):
7376 """Basic sanity check. Also collect info for other testing."""
74- if self .model .version > bny .VERSION :
75- raise LvsModelError (f'Unrecognized LVS model version { self .model .version } ' )
77+ if (
78+ self .model .version is None
79+ or not bny .MIN_SUPPORTED_VERSION <= self .model .version <= bny .VERSION
80+ ):
81+ raise LvsModelError (f"Unsupported LVS model version { self .model .version } " )
7682 self ._model_fns = set ()
7783 self ._trust_roots = set ()
7884 in_deg_nodes = set ()
@@ -81,39 +87,50 @@ def _sanity_check(self):
8187
8288 def dfs (cur , par ):
8389 if cur >= len (self .model .nodes ):
84- raise LvsModelError (f' Non-existing node id { cur } ' )
90+ raise LvsModelError (f" Non-existing node id { cur } " )
8591 node = self .model .nodes [cur ]
8692 if node .id != cur :
87- raise LvsModelError (f' Malformed node id { cur } ' )
93+ raise LvsModelError (f" Malformed node id { cur } " )
8894 if par and node .parent != par :
89- raise LvsModelError (f' Node { cur } has a wrong parent' )
95+ raise LvsModelError (f" Node { cur } has a wrong parent" )
9096 for ve in node .v_edges :
9197 if ve .dest is None or not ve .value :
92- raise LvsModelError (f' Node { cur } has a malformed edge' )
98+ raise LvsModelError (f" Node { cur } has a malformed edge" )
9399 dfs (ve .dest , cur )
94100 for pe in node .p_edges :
95101 if pe .dest is None or pe .tag is None :
96- raise LvsModelError (f' Node { cur } has a malformed edge' )
102+ raise LvsModelError (f" Node { cur } has a malformed edge" )
97103 dfs (pe .dest , cur )
98104 for cons in pe .cons_sets :
99105 for op in cons .options :
100- branch = [not not op .value , op .tag is not None , op .fn is not None ].count (True )
106+ branch = [
107+ not not op .value ,
108+ op .tag is not None ,
109+ op .fn is not None ,
110+ ].count (True )
101111 if branch != 1 :
102- raise LvsModelError (f'Edge { cur } ->{ pe .dest } has a malformed condition' )
112+ raise LvsModelError (
113+ f"Edge { cur } ->{ pe .dest } has a malformed condition"
114+ )
103115 if op .fn is not None :
104116 if not op .fn .fn_id :
105- raise LvsModelError (f'Edge { cur } ->{ pe .dest } has a malformed condition' )
117+ raise LvsModelError (
118+ f"Edge { cur } ->{ pe .dest } has a malformed condition"
119+ )
106120 self ._model_fns .add (op .fn .fn_id )
107121 for key_node_id in node .sign_cons :
108122 if key_node_id >= len (self .model .nodes ):
109- raise LvsModelError (f'Node { cur } is signed by a non-existing key { key_node_id } ' )
123+ raise LvsModelError (
124+ f"Node { cur } is signed by a non-existing key { key_node_id } "
125+ )
110126 in_deg_nodes .add (key_node_id )
111127 adj_lst [cur ].append (key_node_id )
112128
113129 dfs (self .model .start_id , None )
114130 top_order (nodes_id_lst , adj_lst )
115- self ._trust_roots = {n for n in in_deg_nodes
116- if not self .model .nodes [n ].sign_cons }
131+ self ._trust_roots = {
132+ n for n in in_deg_nodes if not self .model .nodes [n ].sign_cons
133+ }
117134
118135 def validate_user_fns (self ) -> bool :
119136 """Check if all user functions required by the model is defined."""
@@ -131,7 +148,7 @@ def root_of_trust(self) -> set[str]:
131148 if node .rule_name :
132149 ret = ret | set (node .rule_name )
133150 else :
134- ret = ret | {'#_' + str (cur )}
151+ ret = ret | {"#_" + str (cur )}
135152 return ret
136153
137154 def save (self ) -> bytes :
@@ -152,12 +169,22 @@ def load(binary_model: BinaryStr, user_fns: dict[str, UserFn]):
152169 return Checker (model , user_fns )
153170
154171 def _context_to_name (self , context : dict [int , BinaryStr ]) -> dict [str , BinaryStr ]:
155- named_tag = {self ._symbols [tag ]: val for tag , val in context .items () if tag in self ._symbols }
156- annon_tag = {str (tag ): val for tag , val in context .items () if tag not in self ._symbols }
172+ named_tag = {
173+ self ._symbols [tag ]: val
174+ for tag , val in context .items ()
175+ if tag in self ._symbols
176+ }
177+ annon_tag = {
178+ str (tag ): val for tag , val in context .items () if tag not in self ._symbols
179+ }
157180 return named_tag | annon_tag
158181
159- def _check_cons (self , value : BinaryStr , context : dict [int , BinaryStr ],
160- cons_set : list [bny .PatternConstraint ]) -> bool :
182+ def _check_cons (
183+ self ,
184+ value : BinaryStr ,
185+ context : dict [int , BinaryStr ],
186+ cons_set : list [bny .PatternConstraint ],
187+ ) -> bool :
161188 for cons in cons_set :
162189 satisfied = False
163190 for op in cons .options :
@@ -172,7 +199,7 @@ def _check_cons(self, value: BinaryStr, context: dict[int, BinaryStr],
172199 else :
173200 fn_id = op .fn .fn_id
174201 if fn_id not in self .user_fns :
175- raise LvsModelError (f' User function { fn_id } is undefined' )
202+ raise LvsModelError (f" User function { fn_id } is undefined" )
176203 args = [context .get (arg .tag , arg .value ) for arg in op .fn .args ]
177204 if self .user_fns [fn_id ](value , args ):
178205 satisfied = True
@@ -181,7 +208,9 @@ def _check_cons(self, value: BinaryStr, context: dict[int, BinaryStr],
181208 return False
182209 return True
183210
184- def _match (self , name : FormalName , context : dict [int , BinaryStr ]) -> Iterator [tuple [int , dict [int , BinaryStr ]]]:
211+ def _match (
212+ self , name : FormalName , context : dict [int , BinaryStr ]
213+ ) -> Iterator [tuple [int , dict [int , BinaryStr ]]]:
185214 cur = self .model .start_id
186215 edge_index = - 1
187216 edge_indices = []
@@ -239,7 +268,9 @@ def _match(self, name: FormalName, context: dict[int, BinaryStr]) -> Iterator[tu
239268 del context [last_tag ]
240269 cur = node .parent
241270
242- def match (self , name : NonStrictName ) -> Iterator [tuple [list [str ], dict [str , BinaryStr ]]]:
271+ def match (
272+ self , name : NonStrictName
273+ ) -> Iterator [tuple [list [str ], dict [str , BinaryStr ]]]:
243274 """
244275 Iterate all matches of a given name.
245276
@@ -257,7 +288,7 @@ def match(self, name: NonStrictName) -> Iterator[tuple[list[str], dict[str, Bina
257288 if node .rule_name :
258289 rule_name = node .rule_name
259290 else :
260- rule_name = ['#_' + str (node_id )]
291+ rule_name = ["#_" + str (node_id )]
261292 yield rule_name , self ._context_to_name (context )
262293
263294 def check (self , pkt_name : NonStrictName , key_name : NonStrictName ) -> bool :
@@ -302,14 +333,19 @@ def suggest(self, pkt_name: NonStrictName, keychain: Keychain) -> FormalName:
302333 if self .check (pkt_name , cert_name ):
303334 cert = parse_certificate (key [cert_name ].data )
304335 # This is to avoid self-signed certificate
305- if (not cert .signature_info or not cert .signature_info .key_locator
306- or not cert .signature_info .key_locator .name ):
336+ if (
337+ not cert .signature_info
338+ or not cert .signature_info .key_locator
339+ or not cert .signature_info .key_locator .name
340+ ):
307341 continue
308342 if self .check (cert_name , cert .signature_info .key_locator .name ):
309343 return cert_name
310344
311345
312346DEFAULT_USER_FNS = {
313- '$eq' : lambda c , args : all (x == c for x in args ),
314- '$eq_type' : lambda c , args : all (Component .get_type (x ) == Component .get_type (c ) for x in args ),
347+ "$eq" : lambda c , args : all (x == c for x in args ),
348+ "$eq_type" : lambda c , args : all (
349+ Component .get_type (x ) == Component .get_type (c ) for x in args
350+ ),
315351}
0 commit comments