Skip to content

Commit cf70971

Browse files
committed
fix: precommit check
1 parent 1f403ad commit cf70971

5 files changed

Lines changed: 66 additions & 40 deletions

File tree

.pylintrc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[MASTER]
2+
extension-pkg-allow-list=numpy,mkl_random.mklrand
3+
4+
[TYPECHECK]
5+
generated-members=RandomState,min,max

mkl_random/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,17 @@
9393
test = PytestTester(__name__)
9494
del PytestTester
9595

96-
from ._patch import monkey_patch, use_in_numpy, restore, is_patched, patched_names, mkl_random
9796
from mkl_random import interfaces
9897

98+
from ._patch import (
99+
is_patched,
100+
mkl_random,
101+
monkey_patch,
102+
patched_names,
103+
restore,
104+
use_in_numpy,
105+
)
106+
99107
__all__ = [
100108
"MKLRandomState",
101109
"RandomState",

mkl_random/src/_patch.pyx

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ replace NumPy's `Generator`/`default_rng()` unless mkl_random provides fully
3535
compatible replacements.
3636
"""
3737

38-
from threading import Lock, local
3938
from contextlib import ContextDecorator
39+
from threading import Lock, local
4040

4141
import numpy as _np
42+
4243
from . import mklrand as _mr
4344

4445

@@ -106,7 +107,9 @@ class _GlobalPatch:
106107

107108
def _validate_module(self, numpy_module):
108109
if not hasattr(numpy_module, "random"):
109-
raise TypeError("Expected a numpy-like module with a `.random` attribute.")
110+
raise TypeError(
111+
"Expected a numpy-like module with a `.random` attribute."
112+
)
110113

111114
def _apply_patch(self, numpy_module, names, strict):
112115
np_random = numpy_module.random
@@ -125,7 +128,8 @@ class _GlobalPatch:
125128
for name, value in originals.items():
126129
setattr(np_random, name, value)
127130
raise AttributeError(
128-
"Could not patch these names (missing on numpy.random or mkl_random.mklrand): "
131+
"Could not patch these names (missing on numpy.random or "
132+
"mkl_random.mklrand): "
129133
+ ", ".join([str(x) for x in missing])
130134
)
131135

@@ -134,7 +138,13 @@ class _GlobalPatch:
134138
self._originals = originals
135139
self._patched = tuple(patched)
136140

137-
def do_patch(self, numpy_module=None, names=None, strict=False, verbose=False):
141+
def do_patch(
142+
self,
143+
numpy_module=None,
144+
names=None,
145+
strict=False,
146+
verbose=False,
147+
):
138148
if numpy_module is None:
139149
numpy_module = _np
140150
names = self._normalize_names(names)
@@ -148,11 +158,13 @@ class _GlobalPatch:
148158
else:
149159
if self._numpy_module is not numpy_module:
150160
raise RuntimeError(
151-
"Already patched a different numpy module; call restore() first."
161+
"Already patched a different numpy module; "
162+
"call restore() first."
152163
)
153164
if names != self._requested_names:
154165
raise RuntimeError(
155-
"Already patched with a different names set; call restore() first."
166+
"Already patched with a different names set; "
167+
"call restore() first."
156168
)
157169
self._patch_count += 1
158170
self._tls.local_count = local_count + 1
@@ -163,7 +175,8 @@ class _GlobalPatch:
163175
if local_count <= 0:
164176
if verbose:
165177
print(
166-
"Warning: restore called more times than monkey_patch in this thread."
178+
"Warning: restore called more times than monkey_patch "
179+
"in this thread."
167180
)
168181
return
169182

@@ -192,7 +205,8 @@ _patch = _GlobalPatch()
192205

193206
def monkey_patch(numpy_module=None, names=None, strict=False, verbose=False):
194207
"""
195-
Enables using mkl_random in the given NumPy module by patching `numpy.random`.
208+
Enables using mkl_random in the given NumPy module by patching
209+
`numpy.random`.
196210
197211
Examples
198212
--------
@@ -229,7 +243,8 @@ def use_in_numpy(numpy_module=None, names=None, strict=False, verbose=False):
229243

230244
def restore(verbose=False):
231245
"""
232-
Disables using mkl_random in NumPy by restoring the original `numpy.random` symbols.
246+
Disables using mkl_random in NumPy by restoring the original
247+
`numpy.random` symbols.
233248
"""
234249
_patch.do_restore(verbose=bool(verbose))
235250

@@ -265,7 +280,11 @@ class mkl_random(ContextDecorator):
265280
self._strict = strict
266281

267282
def __enter__(self):
268-
monkey_patch(numpy_module=self._numpy_module, names=self._names, strict=self._strict)
283+
monkey_patch(
284+
numpy_module=self._numpy_module,
285+
names=self._names,
286+
strict=self._strict,
287+
)
269288
return self
270289

271290
def __exit__(self, *exc):

mkl_random/tests/test_patch.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,22 @@
2424
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2525

2626
import numpy as np
27-
import mkl_random
2827
import pytest
2928

29+
import mkl_random
30+
31+
3032
def test_is_patched():
31-
"""
32-
Test that is_patched() returns correct status.
33-
"""
33+
"""Test that is_patched() returns correct status."""
3434
assert not mkl_random.is_patched()
3535
mkl_random.monkey_patch(np)
3636
assert mkl_random.is_patched()
3737
mkl_random.restore()
3838
assert not mkl_random.is_patched()
3939

40+
4041
def test_monkey_patch_and_restore():
41-
"""
42-
Test that monkey_patch replaces and restore brings back original functions.
43-
"""
42+
"""Test monkey_patch replacement and restore of original functions."""
4443
# Store original functions
4544
orig_normal = np.random.normal
4645
orig_randint = np.random.randint
@@ -55,8 +54,8 @@ def test_monkey_patch_and_restore():
5554
assert np.random.RandomState is not orig_RandomState
5655

5756
# Check that they are from mkl_random
58-
assert np.random.normal is mkl_random.mklrand.normal
59-
assert np.random.RandomState is mkl_random.mklrand.RandomState
57+
assert np.random.normal is mkl_random.normal
58+
assert np.random.RandomState is mkl_random.RandomState
6059

6160
finally:
6261
mkl_random.restore()
@@ -67,10 +66,9 @@ def test_monkey_patch_and_restore():
6766
assert np.random.randint is orig_randint
6867
assert np.random.RandomState is orig_RandomState
6968

69+
7070
def test_context_manager():
71-
"""
72-
Test that the context manager patches and automatically restores.
73-
"""
71+
"""Test context manager patching and automatic restoration."""
7472
orig_uniform = np.random.uniform
7573
assert not mkl_random.is_patched()
7674

@@ -84,10 +82,9 @@ def test_context_manager():
8482
assert not mkl_random.is_patched()
8583
assert np.random.uniform is orig_uniform
8684

85+
8786
def test_patched_functions_callable():
88-
"""
89-
Smoke test to ensure some patched functions can be called without error.
90-
"""
87+
"""Smoke test that patched functions are callable without errors."""
9188
mkl_random.monkey_patch(np)
9289
try:
9390
# These calls should now be routed to mkl_random's implementations
@@ -105,10 +102,9 @@ def test_patched_functions_callable():
105102
finally:
106103
mkl_random.restore()
107104

105+
108106
def test_patched_names():
109-
"""
110-
Test that patched_names() returns a list of patched symbols.
111-
"""
107+
"""Test that patched_names() returns patched symbol names."""
112108
try:
113109
mkl_random.monkey_patch(np)
114110
names = mkl_random.patched_names()
@@ -119,21 +115,20 @@ def test_patched_names():
119115
finally:
120116
mkl_random.restore()
121117

118+
122119
def test_monkey_patch_strict_raises_attribute_error():
123-
"""
124-
Test that strict mode raises AttributeError when patching non-existent names.
125-
"""
120+
"""Test strict mode raises AttributeError for missing patch names."""
126121
# Attempt to patch a clearly non-existent symbol in strict mode.
127122
with pytest.raises(AttributeError):
128123
mkl_random.monkey_patch(np, strict=True, names=["nonexistent_symbol"])
129124

125+
130126
def test_use_in_numpy_is_alias_for_monkey_patch():
131-
"""
132-
Test that use_in_numpy is a backward-compatible alias for monkey_patch.
133-
"""
127+
"""Test use_in_numpy remains a backward-compatible alias."""
134128
assert hasattr(mkl_random, "use_in_numpy")
135129
assert mkl_random.use_in_numpy is mkl_random.monkey_patch
136130

131+
137132
def test_patch_redundant_patching():
138133
orig_normal = np.random.normal
139134
assert not mkl_random.is_patched()
@@ -142,11 +137,11 @@ def test_patch_redundant_patching():
142137
mkl_random.monkey_patch(np)
143138

144139
assert mkl_random.is_patched()
145-
assert np.random.normal is mkl_random.mklrand.normal
140+
assert np.random.normal is mkl_random.normal
146141

147142
mkl_random.restore()
148143
assert mkl_random.is_patched()
149-
assert np.random.normal is mkl_random.mklrand.normal
144+
assert np.random.normal is mkl_random.normal
150145

151146
mkl_random.restore()
152147
assert not mkl_random.is_patched()

setup.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,16 +91,15 @@ def extensions():
9191
library_dirs=lib_dirs,
9292
extra_compile_args=eca,
9393
define_macros=defs + [("NDEBUG", None)],
94-
language="c++"
94+
language="c++",
9595
),
96-
9796
Extension(
9897
"mkl_random._patch",
9998
sources=[join("mkl_random", "src", "_patch.pyx")],
10099
include_dirs=[np.get_include()],
101100
define_macros=defs + [("NDEBUG", None)],
102101
language="c",
103-
)
102+
),
104103
]
105104

106105
return exts

0 commit comments

Comments
 (0)