2828HAS_NUMBA = find_spec ("numba" ) is not None
2929HAS_PYARROW = find_spec ("pyarrow" ) is not None
3030
31+ if HAS_NUMPY :
32+ import numpy as np
33+ if HAS_SCIPY :
34+ import scipy # type: ignore # noqa: PGH003
35+ if HAS_JAX :
36+ import jax # type: ignore # noqa: PGH003
37+ import jax .numpy as jnp # type: ignore # noqa: PGH003
38+ if HAS_XARRAY :
39+ import xarray # type: ignore # noqa: PGH003
40+ if HAS_TENSORFLOW :
41+ import tensorflow as tf # type: ignore # noqa: PGH003
42+ if HAS_SQLALCHEMY :
43+ import sqlalchemy # type: ignore # noqa: PGH003
44+ if HAS_PYARROW :
45+ import pyarrow as pa # type: ignore # noqa: PGH003
46+ if HAS_PANDAS :
47+ import pandas # noqa: ICN001
48+ if HAS_TORCH :
49+ import torch # type: ignore # noqa: PGH003
50+ if HAS_NUMBA :
51+ import numba # type: ignore # noqa: PGH003
52+ from numba .core .dispatcher import Dispatcher # type: ignore # noqa: PGH003
53+ from numba .typed import Dict as NumbaDict # type: ignore # noqa: PGH003
54+ from numba .typed import List as NumbaList # type: ignore # noqa: PGH003
55+ if HAS_PYRSISTENT :
56+ import pyrsistent # type: ignore # noqa: PGH003
57+
3158# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
3259# These paths vary between test runs but are logically equivalent
3360PYTEST_TEMP_PATH_PATTERN = re .compile (r"/tmp/pytest-of-[^/]+/pytest-\d+/" ) # noqa: S108
3663# Created by tempfile.mkdtemp() or tempfile.TemporaryDirectory()
3764PYTHON_TEMPFILE_PATTERN = re .compile (r"/tmp/tmp[a-zA-Z0-9_]+/" ) # noqa: S108
3865
66+ _DICT_KEYS_TYPE = type ({}.keys ())
67+ _DICT_VALUES_TYPE = type ({}.values ())
68+ _DICT_ITEMS_TYPE = type ({}.items ())
69+
70+ _EQUALITY_TYPES = (
71+ int ,
72+ bool ,
73+ complex ,
74+ type (None ),
75+ type (Ellipsis ),
76+ decimal .Decimal ,
77+ set ,
78+ bytes ,
79+ bytearray ,
80+ memoryview ,
81+ frozenset ,
82+ enum .Enum ,
83+ type ,
84+ range ,
85+ slice ,
86+ OrderedDict ,
87+ types .GenericAlias ,
88+ * ((_union_type ,) if (_union_type := getattr (types , "UnionType" , None )) else ()),
89+ )
90+
3991
4092def _normalize_temp_path (path : str ) -> str :
4193 """Normalize temporary file paths by replacing session-specific components.
@@ -145,27 +197,7 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
145197 return _normalize_temp_path (orig ) == _normalize_temp_path (new )
146198 return False
147199
148- _equality_types = (
149- int ,
150- bool ,
151- complex ,
152- type (None ),
153- type (Ellipsis ),
154- decimal .Decimal ,
155- set ,
156- bytes ,
157- bytearray ,
158- memoryview ,
159- frozenset ,
160- enum .Enum ,
161- type ,
162- range ,
163- slice ,
164- OrderedDict ,
165- types .GenericAlias ,
166- * ((_union_type ,) if (_union_type := getattr (types , "UnionType" , None )) else ()),
167- )
168- if isinstance (orig , _equality_types ):
200+ if isinstance (orig , _EQUALITY_TYPES ):
169201 return orig == new
170202 if isinstance (orig , float ):
171203 if math .isnan (orig ) and math .isnan (new ):
@@ -184,9 +216,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
184216 return comparator (orig_referent , new_referent , superset_obj )
185217
186218 if HAS_JAX :
187- import jax # type: ignore # noqa: PGH003
188- import jax .numpy as jnp # type: ignore # noqa: PGH003
189-
190219 # Handle JAX arrays first to avoid boolean context errors in other conditions
191220 if isinstance (orig , jax .Array ):
192221 if orig .dtype != new .dtype :
@@ -197,15 +226,11 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
197226
198227 # Handle xarray objects before numpy to avoid boolean context errors
199228 if HAS_XARRAY :
200- import xarray # type: ignore # noqa: PGH003
201-
202229 if isinstance (orig , (xarray .Dataset , xarray .DataArray )):
203230 return orig .identical (new )
204231
205232 # Handle TensorFlow objects early to avoid boolean context errors
206233 if HAS_TENSORFLOW :
207- import tensorflow as tf # type: ignore # noqa: PGH003
208-
209234 if isinstance (orig , tf .Tensor ):
210235 if orig .dtype != new .dtype :
211236 return False
@@ -231,7 +256,9 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
231256 if not comparator (orig .dense_shape .numpy (), new .dense_shape .numpy (), superset_obj ):
232257 return False
233258 return comparator (orig .indices .numpy (), new .indices .numpy (), superset_obj ) and comparator (
234- orig .values .numpy (), new .values .numpy (), superset_obj
259+ orig .values .numpy (), # noqa: PD011
260+ new .values .numpy (), # noqa: PD011
261+ superset_obj ,
235262 )
236263
237264 if isinstance (orig , tf .RaggedTensor ):
@@ -242,8 +269,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
242269 return comparator (orig .to_list (), new .to_list (), superset_obj )
243270
244271 if HAS_SQLALCHEMY :
245- import sqlalchemy # type: ignore # noqa: PGH003
246-
247272 try :
248273 insp = sqlalchemy .inspection .inspect (orig )
249274 insp = sqlalchemy .inspection .inspect (new )
@@ -259,8 +284,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
259284 except sqlalchemy .exc .NoInspectionAvailable :
260285 pass
261286
262- if HAS_SCIPY :
263- import scipy # type: ignore # noqa: PGH003
264287 # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
265288 if isinstance (orig , dict ) and not (HAS_SCIPY and isinstance (orig , scipy .sparse .spmatrix )):
266289 if superset_obj :
@@ -279,21 +302,14 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
279302 return comparator (dict (orig ), dict (new ), superset_obj )
280303
281304 # Handle dict view types (dict_keys, dict_values, dict_items)
282- # Use type name checking since these are not directly importable types
283- type_name = type (orig ).__name__
284- if type_name == "dict_keys" :
285- # dict_keys can be compared as sets (order doesn't matter)
305+ if isinstance (orig , _DICT_KEYS_TYPE ):
286306 return comparator (set (orig ), set (new ))
287- if type_name == "dict_values" :
288- # dict_values need element-wise comparison (order matters)
307+ if isinstance (orig , _DICT_VALUES_TYPE ):
289308 return comparator (list (orig ), list (new ))
290- if type_name == "dict_items" :
291- # Convert to dict for order-insensitive comparison (handles unhashable values)
309+ if isinstance (orig , _DICT_ITEMS_TYPE ):
292310 return comparator (dict (orig ), dict (new ), superset_obj )
293311
294312 if HAS_NUMPY :
295- import numpy as np
296-
297313 if isinstance (orig , (np .datetime64 , np .timedelta64 )):
298314 # Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime
299315 if np .isnat (orig ) and np .isnat (new ):
@@ -355,8 +371,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
355371 return (orig != new ).nnz == 0
356372
357373 if HAS_PYARROW :
358- import pyarrow as pa # type: ignore # noqa: PGH003
359-
360374 if isinstance (orig , pa .Table ):
361375 if orig .schema != new .schema :
362376 return False
@@ -399,8 +413,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
399413 return bool (orig .equals (new ))
400414
401415 if HAS_PANDAS :
402- import pandas # noqa: ICN001
403-
404416 if isinstance (
405417 orig , (pandas .DataFrame , pandas .Series , pandas .Index , pandas .Categorical , pandas .arrays .SparseArray )
406418 ):
@@ -431,8 +443,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
431443 pass
432444
433445 if HAS_TORCH :
434- import torch # type: ignore # noqa: PGH003
435-
436446 if isinstance (orig , torch .Tensor ):
437447 if orig .dtype != new .dtype :
438448 return False
@@ -451,11 +461,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
451461 return orig == new
452462
453463 if HAS_NUMBA :
454- import numba
455- from numba .core .dispatcher import Dispatcher
456- from numba .typed import Dict as NumbaDict
457- from numba .typed import List as NumbaList
458-
459464 # Handle numba typed List
460465 if isinstance (orig , NumbaList ):
461466 if len (orig ) != len (new ):
@@ -487,8 +492,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
487492 return orig .py_func is new .py_func
488493
489494 if HAS_PYRSISTENT :
490- import pyrsistent # type: ignore # noqa: PGH003
491-
492495 if isinstance (
493496 orig ,
494497 (
@@ -534,7 +537,7 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
534537 # If the object passed has a user defined __eq__ method, use that
535538 # This could fail if the user defined __eq__ is defined with C-extensions
536539 try :
537- if hasattr (orig , "__eq__" ) and str ( type ( orig .__eq__ )) == "<class 'method'>" :
540+ if hasattr (orig , "__eq__" ) and isinstance ( orig .__eq__ , types . MethodType ) :
538541 return orig == new
539542 except Exception :
540543 pass
0 commit comments