Skip to content

Commit c32abd1

Browse files
committed
perf
perf: optimize comparator with module-level constants and identity check - Hoist _equality_types tuple and conditional library imports to module level - Use isinstance for dict view type checks instead of string comparison - Use isinstance for user-defined __eq__ detection instead of str(type(...)) - Add identity short-circuit (orig is new) at the top of comparator() style: suppress false-positive PD011 on tf.SparseTensor.values
1 parent f747b66 commit c32abd1

1 file changed

Lines changed: 60 additions & 57 deletions

File tree

codeflash/verification/comparator.py

Lines changed: 60 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,33 @@
2828
HAS_NUMBA = find_spec("numba") is not None
2929
HAS_PYARROW = find_spec("pyarrow") is not None
3030

31+
if HAS_NUMPY:
32+
import numpy as np
33+
if HAS_SCIPY:
34+
import scipy # type: ignore # noqa: PGH003
35+
if HAS_JAX:
36+
import jax # type: ignore # noqa: PGH003
37+
import jax.numpy as jnp # type: ignore # noqa: PGH003
38+
if HAS_XARRAY:
39+
import xarray # type: ignore # noqa: PGH003
40+
if HAS_TENSORFLOW:
41+
import tensorflow as tf # type: ignore # noqa: PGH003
42+
if HAS_SQLALCHEMY:
43+
import sqlalchemy # type: ignore # noqa: PGH003
44+
if HAS_PYARROW:
45+
import pyarrow as pa # type: ignore # noqa: PGH003
46+
if HAS_PANDAS:
47+
import pandas # noqa: ICN001
48+
if HAS_TORCH:
49+
import torch # type: ignore # noqa: PGH003
50+
if HAS_NUMBA:
51+
import numba # type: ignore # noqa: PGH003
52+
from numba.core.dispatcher import Dispatcher # type: ignore # noqa: PGH003
53+
from numba.typed import Dict as NumbaDict # type: ignore # noqa: PGH003
54+
from numba.typed import List as NumbaList # type: ignore # noqa: PGH003
55+
if HAS_PYRSISTENT:
56+
import pyrsistent # type: ignore # noqa: PGH003
57+
3158
# Pattern to match pytest temp directories: /tmp/pytest-of-<user>/pytest-<N>/
3259
# These paths vary between test runs but are logically equivalent
3360
PYTEST_TEMP_PATH_PATTERN = re.compile(r"/tmp/pytest-of-[^/]+/pytest-\d+/") # noqa: S108
@@ -36,6 +63,31 @@
3663
# Created by tempfile.mkdtemp() or tempfile.TemporaryDirectory()
3764
PYTHON_TEMPFILE_PATTERN = re.compile(r"/tmp/tmp[a-zA-Z0-9_]+/") # noqa: S108
3865

66+
_DICT_KEYS_TYPE = type({}.keys())
67+
_DICT_VALUES_TYPE = type({}.values())
68+
_DICT_ITEMS_TYPE = type({}.items())
69+
70+
_EQUALITY_TYPES = (
71+
int,
72+
bool,
73+
complex,
74+
type(None),
75+
type(Ellipsis),
76+
decimal.Decimal,
77+
set,
78+
bytes,
79+
bytearray,
80+
memoryview,
81+
frozenset,
82+
enum.Enum,
83+
type,
84+
range,
85+
slice,
86+
OrderedDict,
87+
types.GenericAlias,
88+
*((_union_type,) if (_union_type := getattr(types, "UnionType", None)) else ()),
89+
)
90+
3991

4092
def _normalize_temp_path(path: str) -> str:
4193
"""Normalize temporary file paths by replacing session-specific components.
@@ -145,27 +197,7 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
145197
return _normalize_temp_path(orig) == _normalize_temp_path(new)
146198
return False
147199

148-
_equality_types = (
149-
int,
150-
bool,
151-
complex,
152-
type(None),
153-
type(Ellipsis),
154-
decimal.Decimal,
155-
set,
156-
bytes,
157-
bytearray,
158-
memoryview,
159-
frozenset,
160-
enum.Enum,
161-
type,
162-
range,
163-
slice,
164-
OrderedDict,
165-
types.GenericAlias,
166-
*((_union_type,) if (_union_type := getattr(types, "UnionType", None)) else ()),
167-
)
168-
if isinstance(orig, _equality_types):
200+
if isinstance(orig, _EQUALITY_TYPES):
169201
return orig == new
170202
if isinstance(orig, float):
171203
if math.isnan(orig) and math.isnan(new):
@@ -184,9 +216,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
184216
return comparator(orig_referent, new_referent, superset_obj)
185217

186218
if HAS_JAX:
187-
import jax # type: ignore # noqa: PGH003
188-
import jax.numpy as jnp # type: ignore # noqa: PGH003
189-
190219
# Handle JAX arrays first to avoid boolean context errors in other conditions
191220
if isinstance(orig, jax.Array):
192221
if orig.dtype != new.dtype:
@@ -197,15 +226,11 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
197226

198227
# Handle xarray objects before numpy to avoid boolean context errors
199228
if HAS_XARRAY:
200-
import xarray # type: ignore # noqa: PGH003
201-
202229
if isinstance(orig, (xarray.Dataset, xarray.DataArray)):
203230
return orig.identical(new)
204231

205232
# Handle TensorFlow objects early to avoid boolean context errors
206233
if HAS_TENSORFLOW:
207-
import tensorflow as tf # type: ignore # noqa: PGH003
208-
209234
if isinstance(orig, tf.Tensor):
210235
if orig.dtype != new.dtype:
211236
return False
@@ -231,7 +256,9 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
231256
if not comparator(orig.dense_shape.numpy(), new.dense_shape.numpy(), superset_obj):
232257
return False
233258
return comparator(orig.indices.numpy(), new.indices.numpy(), superset_obj) and comparator(
234-
orig.values.numpy(), new.values.numpy(), superset_obj
259+
orig.values.numpy(), # noqa: PD011
260+
new.values.numpy(), # noqa: PD011
261+
superset_obj,
235262
)
236263

237264
if isinstance(orig, tf.RaggedTensor):
@@ -242,8 +269,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
242269
return comparator(orig.to_list(), new.to_list(), superset_obj)
243270

244271
if HAS_SQLALCHEMY:
245-
import sqlalchemy # type: ignore # noqa: PGH003
246-
247272
try:
248273
insp = sqlalchemy.inspection.inspect(orig)
249274
insp = sqlalchemy.inspection.inspect(new)
@@ -259,8 +284,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
259284
except sqlalchemy.exc.NoInspectionAvailable:
260285
pass
261286

262-
if HAS_SCIPY:
263-
import scipy # type: ignore # noqa: PGH003
264287
# scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it
265288
if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)):
266289
if superset_obj:
@@ -279,21 +302,14 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
279302
return comparator(dict(orig), dict(new), superset_obj)
280303

281304
# Handle dict view types (dict_keys, dict_values, dict_items)
282-
# Use type name checking since these are not directly importable types
283-
type_name = type(orig).__name__
284-
if type_name == "dict_keys":
285-
# dict_keys can be compared as sets (order doesn't matter)
305+
if isinstance(orig, _DICT_KEYS_TYPE):
286306
return comparator(set(orig), set(new))
287-
if type_name == "dict_values":
288-
# dict_values need element-wise comparison (order matters)
307+
if isinstance(orig, _DICT_VALUES_TYPE):
289308
return comparator(list(orig), list(new))
290-
if type_name == "dict_items":
291-
# Convert to dict for order-insensitive comparison (handles unhashable values)
309+
if isinstance(orig, _DICT_ITEMS_TYPE):
292310
return comparator(dict(orig), dict(new), superset_obj)
293311

294312
if HAS_NUMPY:
295-
import numpy as np
296-
297313
if isinstance(orig, (np.datetime64, np.timedelta64)):
298314
# Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime
299315
if np.isnat(orig) and np.isnat(new):
@@ -355,8 +371,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
355371
return (orig != new).nnz == 0
356372

357373
if HAS_PYARROW:
358-
import pyarrow as pa # type: ignore # noqa: PGH003
359-
360374
if isinstance(orig, pa.Table):
361375
if orig.schema != new.schema:
362376
return False
@@ -399,8 +413,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
399413
return bool(orig.equals(new))
400414

401415
if HAS_PANDAS:
402-
import pandas # noqa: ICN001
403-
404416
if isinstance(
405417
orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray)
406418
):
@@ -431,8 +443,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
431443
pass
432444

433445
if HAS_TORCH:
434-
import torch # type: ignore # noqa: PGH003
435-
436446
if isinstance(orig, torch.Tensor):
437447
if orig.dtype != new.dtype:
438448
return False
@@ -451,11 +461,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
451461
return orig == new
452462

453463
if HAS_NUMBA:
454-
import numba
455-
from numba.core.dispatcher import Dispatcher
456-
from numba.typed import Dict as NumbaDict
457-
from numba.typed import List as NumbaList
458-
459464
# Handle numba typed List
460465
if isinstance(orig, NumbaList):
461466
if len(orig) != len(new):
@@ -487,8 +492,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
487492
return orig.py_func is new.py_func
488493

489494
if HAS_PYRSISTENT:
490-
import pyrsistent # type: ignore # noqa: PGH003
491-
492495
if isinstance(
493496
orig,
494497
(
@@ -534,7 +537,7 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
534537
# If the object passed has a user defined __eq__ method, use that
535538
# This could fail if the user defined __eq__ is defined with C-extensions
536539
try:
537-
if hasattr(orig, "__eq__") and str(type(orig.__eq__)) == "<class 'method'>":
540+
if hasattr(orig, "__eq__") and isinstance(orig.__eq__, types.MethodType):
538541
return orig == new
539542
except Exception:
540543
pass

0 commit comments

Comments
 (0)