Skip to content

Commit 5f3fdab

Browse files
authored
Fix for new jax version. (#22)
1 parent aa0e218 commit 5f3fdab

6 files changed

Lines changed: 84 additions & 34 deletions

File tree

.pre-commit-config.yaml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
repos:
22
- repo: https://github.com/pre-commit/pre-commit-hooks
3-
rev: v4.3.0
3+
rev: v4.4.0
44
hooks:
55
- id: check-merge-conflict
66
- id: debug-statements
77
- id: end-of-file-fixer
88
- repo: https://github.com/asottile/reorder_python_imports
9-
rev: v3.1.0
9+
rev: v3.9.0
1010
hooks:
1111
- id: reorder-python-imports
1212
types: [python]
1313
- repo: https://github.com/pre-commit/pre-commit-hooks
14-
rev: v4.3.0
14+
rev: v4.4.0
1515
hooks:
1616
- id: check-added-large-files
1717
args: ['--maxkb=100']
@@ -45,12 +45,12 @@ repos:
4545
additional_dependencies: [black==22.3.0]
4646
types: [rst]
4747
- repo: https://github.com/psf/black
48-
rev: 22.3.0
48+
rev: 22.12.0
4949
hooks:
5050
- id: black
51-
language_version: python3.9
51+
language_version: python3.10
5252
- repo: https://github.com/PyCQA/flake8
53-
rev: 4.0.1
53+
rev: 5.0.4
5454
hooks:
5555
- id: flake8
5656
types: [python]
@@ -71,7 +71,7 @@ repos:
7171
Pygments,
7272
]
7373
- repo: https://github.com/PyCQA/doc8
74-
rev: 0.11.2
74+
rev: v1.0.0
7575
hooks:
7676
- id: doc8
7777
- repo: meta
@@ -80,15 +80,15 @@ repos:
8080
- id: check-useless-excludes
8181
# - id: identity # Prints all files passed to pre-commits. Debugging.
8282
- repo: https://github.com/mgedmin/check-manifest
83-
rev: "0.48"
83+
rev: "0.49"
8484
hooks:
8585
- id: check-manifest
8686
- repo: https://github.com/PyCQA/doc8
87-
rev: 0.11.2
87+
rev: v1.0.0
8888
hooks:
8989
- id: doc8
9090
- repo: https://github.com/asottile/setup-cfg-fmt
91-
rev: v1.20.1
91+
rev: v2.2.0
9292
hooks:
9393
- id: setup-cfg-fmt
9494
- repo: https://github.com/econchick/interrogate
@@ -98,11 +98,11 @@ repos:
9898
args: [-v, --fail-under=20]
9999
exclude: ^(tests|docs|setup\.py)
100100
- repo: https://github.com/codespell-project/codespell
101-
rev: v2.1.0
101+
rev: v2.2.2
102102
hooks:
103103
- id: codespell
104104
- repo: https://github.com/asottile/pyupgrade
105-
rev: v2.34.0
105+
rev: v3.3.1
106106
hooks:
107107
- id: pyupgrade
108108
args: [--py37-plus]

setup.cfg

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@ classifiers =
1717
Operating System :: POSIX
1818
Programming Language :: Python :: 3
1919
Programming Language :: Python :: 3 :: Only
20-
Programming Language :: Python :: 3.7
21-
Programming Language :: Python :: 3.8
22-
Programming Language :: Python :: 3.9
23-
Programming Language :: Python :: 3.10
2420
Topic :: Scientific/Engineering
2521
Topic :: Utilities
2622

src/pybaum/equality.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@
1111
if IS_PANDAS_INSTALLED:
1212
import pandas as pd
1313

14-
if IS_JAX_INSTALLED:
15-
import jaxlib
16-
17-
1814
EQUALITY_CHECKERS = {}
1915

2016

@@ -28,6 +24,4 @@
2824

2925

3026
if IS_JAX_INSTALLED:
31-
EQUALITY_CHECKERS[jaxlib.xla_extension.DeviceArray] = lambda a, b: bool(
32-
(a == b).all()
33-
)
27+
EQUALITY_CHECKERS["jax.numpy.ndarray"] = lambda a, b: bool((a == b).all())

src/pybaum/registry_entries.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import itertools
2-
from collections import namedtuple
32
from collections import OrderedDict
43
from itertools import product
54

@@ -15,7 +14,6 @@
1514

1615
if IS_JAX_INSTALLED:
1716
import jax
18-
import jaxlib
1917

2018

2119
def _none():
@@ -69,7 +67,7 @@ def _tuple():
6967
def _namedtuple():
7068
"""Create registry entry for namedtuple and NamedTuple."""
7169
entry = {
72-
namedtuple: {
70+
"namedtuple": {
7371
"flatten": lambda tree: (list(tree), tree),
7472
"unflatten": _unflatten_namedtuple,
7573
"names": lambda tree: list(tree._fields),
@@ -125,7 +123,7 @@ def _array_element_names(arr):
125123
def _jax_array():
126124
if IS_JAX_INSTALLED:
127125
entry = {
128-
jaxlib.xla_extension.DeviceArray: {
126+
"jax.numpy.ndarray": {
129127
"flatten": lambda arr: (arr.flatten().tolist(), arr.shape),
130128
"unflatten": lambda aux_data, leaves: jax.numpy.array(leaves).reshape(
131129
aux_data

src/pybaum/typecheck.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,37 @@
1-
from collections import namedtuple
1+
from pybaum.config import IS_JAX_INSTALLED
2+
from pybaum.config import IS_NUMPY_INSTALLED
3+
4+
if IS_JAX_INSTALLED:
5+
import jax.numpy as jnp
6+
7+
if IS_NUMPY_INSTALLED:
8+
import numpy as np
29

310

411
def get_type(obj):
5-
"""namdetuple aware type check.
12+
"""Get type of candidate objects in a pytree.
13+
14+
This function allows us to reliably identify namedtuples, NamedTuples and jax arrays
15+
for which standard ``type`` function does not work.
16+
17+
Args:
18+
obj: The object to be checked
19+
20+
Returns:
21+
type or str: The type of the object or a string with the type name.
22+
23+
"""
24+
if _is_namedtuple(obj):
25+
out = "namedtuple"
26+
elif _is_jax_array(obj):
27+
out = "jax.numpy.ndarray"
28+
else:
29+
out = type(obj)
30+
return out
31+
32+
33+
def _is_namedtuple(obj):
34+
"""Check if an object is a namedtuple.
635
736
As in JAX we treat collections.namedtuple and typing.NamedTuple both as
837
namedtuple but the exact type is preserved in the unflatten function.
@@ -24,8 +53,41 @@ def get_type(obj):
2453
bool
2554
2655
"""
27-
if isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_replace"):
28-
out = namedtuple
56+
out = (
57+
isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_replace")
58+
)
59+
return out
60+
61+
62+
def _is_jax_array(obj):
63+
"""Check if an object is a jax array.
64+
65+
The exact type of jax arrays has changed over time and is an implementation detail.
66+
67+
Instead we rely on isinstance checks which will likely be more stable in the future.
68+
However, the behavior of isinstance for jax arrays has also changed over time. For
69+
jax versions before 0.2.21, standard numpy arrays were instances of jax arrays,
70+
now they are not.
71+
72+
Resources:
73+
----------
74+
75+
- https://github.com/google/jax/issues/2115
76+
- https://github.com/google/jax/issues/2014
77+
- https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0221-sept-23-2021
78+
- https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0318-sep-26-2022
79+
80+
Args:
81+
obj: The object to be checked
82+
83+
Returns:
84+
bool
85+
86+
"""
87+
if not IS_JAX_INSTALLED:
88+
out = False
89+
elif IS_NUMPY_INSTALLED:
90+
out = isinstance(obj, jnp.ndarray) and not isinstance(obj, np.ndarray)
2991
else:
30-
out = type(obj)
92+
out = isinstance(obj, jnp.ndarray)
3193
return out

tests/test_typecheck.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def test_namedtuple_is_discovered():
88
bla = namedtuple("bla", ["a", "b"])(1, 2)
9-
assert get_type(bla) == namedtuple
9+
assert get_type(bla) == "namedtuple"
1010

1111

1212
def test_typed_namedtuple_is_discovered():
@@ -15,7 +15,7 @@ class Blubb(NamedTuple):
1515
b: int
1616

1717
blubb = Blubb(1, 2)
18-
assert get_type(blubb) == namedtuple
18+
assert get_type(blubb) == "namedtuple"
1919

2020

2121
def test_standard_tuple_is_not_discovered():

0 commit comments

Comments
 (0)