Skip to content

Commit aa0e218

Browse files
authored
Add JAX array to registry (#20)
1 parent 6ff956b commit aa0e218

11 files changed

Lines changed: 149 additions & 27 deletions

File tree

.github/workflows/main.yml

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
strategy:
2323
fail-fast: false
2424
matrix:
25-
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
25+
os: ['ubuntu-latest', 'macos-latest']
2626
python-version: ['3.7', '3.8', '3.9', '3.10']
2727

2828
steps:
@@ -46,6 +46,32 @@ jobs:
4646
with:
4747
token: ${{ secrets.CODECOV_TOKEN }}
4848

49+
run-tests-windows:
50+
51+
name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }}
52+
runs-on: ${{ matrix.os }}
53+
54+
strategy:
55+
fail-fast: false
56+
matrix:
57+
os: ['windows-latest']
58+
python-version: ['3.7', '3.8', '3.9', '3.10']
59+
60+
steps:
61+
- uses: actions/checkout@v2
62+
- uses: conda-incubator/setup-miniconda@v2
63+
with:
64+
auto-update-conda: true
65+
python-version: ${{ matrix.python-version }}
66+
67+
- name: Install core dependencies.
68+
shell: bash -l {0}
69+
run: conda install -c conda-forge tox-conda
70+
71+
- name: Run pytest.
72+
shell: bash -l {0}
73+
run: tox -e pytest-windows -- -m "not slow"
74+
4975
docs:
5076

5177
name: Run documentation.

.pre-commit-config.yaml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/pre-commit/pre-commit-hooks
3-
rev: v4.2.0
3+
rev: v4.3.0
44
hooks:
55
- id: check-merge-conflict
66
- id: debug-statements
@@ -11,7 +11,7 @@ repos:
1111
- id: reorder-python-imports
1212
types: [python]
1313
- repo: https://github.com/pre-commit/pre-commit-hooks
14-
rev: v4.2.0
14+
rev: v4.3.0
1515
hooks:
1616
- id: check-added-large-files
1717
args: ['--maxkb=100']
@@ -42,13 +42,13 @@ repos:
4242
rev: v1.12.1
4343
hooks:
4444
- id: blacken-docs
45-
additional_dependencies: [black]
45+
additional_dependencies: [black==22.3.0]
4646
types: [rst]
4747
- repo: https://github.com/psf/black
4848
rev: 22.3.0
4949
hooks:
5050
- id: black
51-
types: [python]
51+
language_version: python3.9
5252
- repo: https://github.com/PyCQA/flake8
5353
rev: 4.0.1
5454
hooks:
@@ -71,7 +71,7 @@ repos:
7171
Pygments,
7272
]
7373
- repo: https://github.com/PyCQA/doc8
74-
rev: 0.11.1
74+
rev: 0.11.2
7575
hooks:
7676
- id: doc8
7777
- repo: meta
@@ -84,7 +84,7 @@ repos:
8484
hooks:
8585
- id: check-manifest
8686
- repo: https://github.com/PyCQA/doc8
87-
rev: 0.11.1
87+
rev: 0.11.2
8888
hooks:
8989
- id: doc8
9090
- repo: https://github.com/asottile/setup-cfg-fmt
@@ -102,7 +102,7 @@ repos:
102102
hooks:
103103
- id: codespell
104104
- repo: https://github.com/asottile/pyupgrade
105-
rev: v2.32.1
105+
rev: v2.34.0
106106
hooks:
107107
- id: pyupgrade
108108
args: [--py37-plus]

environment.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,5 @@ dependencies:
3030
- pdbpp
3131
- numpy
3232
- pandas
33+
- jax
34+
- jaxlib

src/pybaum/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,12 @@
1212
IS_PANDAS_INSTALLED = False
1313
else:
1414
IS_PANDAS_INSTALLED = True
15+
16+
17+
try:
18+
import jax # noqa: F401
19+
import jaxlib # noqa: F401
20+
except ImportError:
21+
IS_JAX_INSTALLED = False
22+
else:
23+
IS_JAX_INSTALLED = True

src/pybaum/equality.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Functions to check equality of pytree leaves."""
2+
from pybaum.config import IS_JAX_INSTALLED
23
from pybaum.config import IS_NUMPY_INSTALLED
34
from pybaum.config import IS_PANDAS_INSTALLED
45

@@ -10,6 +11,9 @@
1011
if IS_PANDAS_INSTALLED:
1112
import pandas as pd
1213

14+
if IS_JAX_INSTALLED:
15+
import jaxlib
16+
1317

1418
EQUALITY_CHECKERS = {}
1519

@@ -21,3 +25,9 @@
2125
if IS_PANDAS_INSTALLED:
2226
EQUALITY_CHECKERS[pd.Series] = lambda a, b: a.equals(b)
2327
EQUALITY_CHECKERS[pd.DataFrame] = lambda a, b: a.equals(b)
28+
29+
30+
if IS_JAX_INSTALLED:
31+
EQUALITY_CHECKERS[jaxlib.xla_extension.DeviceArray] = lambda a, b: bool(
32+
(a == b).all()
33+
)

src/pybaum/registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def get_registry(types=None, include_defaults=True):
1515
- :obj:`None`
1616
- :class:`collections.OrderedDict`
1717
- "numpy.ndarray"
18+
- "jax.numpy.ndarray"
1819
- "pandas.Series"
1920
- "pandas.DataFrame"
2021
include_defaults (bool): Whether the default pytree containers "tuple", "dict"

src/pybaum/registry_entries.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import OrderedDict
44
from itertools import product
55

6+
from pybaum.config import IS_JAX_INSTALLED
67
from pybaum.config import IS_NUMPY_INSTALLED
78
from pybaum.config import IS_PANDAS_INSTALLED
89

@@ -12,6 +13,10 @@
1213
if IS_PANDAS_INSTALLED:
1314
import pandas as pd
1415

16+
if IS_JAX_INSTALLED:
17+
import jax
18+
import jaxlib
19+
1520

1621
def _none():
1722
"""Create registry entry for NoneType."""
@@ -117,6 +122,22 @@ def _array_element_names(arr):
117122
return names
118123

119124

125+
def _jax_array():
126+
if IS_JAX_INSTALLED:
127+
entry = {
128+
jaxlib.xla_extension.DeviceArray: {
129+
"flatten": lambda arr: (arr.flatten().tolist(), arr.shape),
130+
"unflatten": lambda aux_data, leaves: jax.numpy.array(leaves).reshape(
131+
aux_data
132+
),
133+
"names": _array_element_names,
134+
},
135+
}
136+
else:
137+
entry = {}
138+
return entry
139+
140+
120141
def _pandas_series():
121142
"""Create registry entry for pandas.Series."""
122143
if IS_PANDAS_INSTALLED:
@@ -186,6 +207,7 @@ def _index_element_to_string(element):
186207
"tuple": _tuple,
187208
"dict": _dict,
188209
"numpy.ndarray": _numpy_array,
210+
"jax.numpy.ndarray": _jax_array,
189211
"pandas.Series": _pandas_series,
190212
"pandas.DataFrame": _pandas_dataframe,
191213
"None": _none,

src/pybaum/tree_util.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
- The treedef containing information to unflatten pytrees is implemented differently.
77
88
"""
9-
import itertools
10-
119
from pybaum.equality import EQUALITY_CHECKERS
1210
from pybaum.registry import get_registry
1311
from pybaum.typecheck import get_type
@@ -42,8 +40,8 @@ def tree_flatten(tree, is_leaf=None, registry=None):
4240
is_leaf = _process_is_leaf(is_leaf)
4341

4442
flat = _tree_flatten(tree, is_leaf=is_leaf, registry=registry)
45-
dummy_flat = ["*"] * len(flat)
46-
treedef = tree_unflatten(tree, dummy_flat, is_leaf=is_leaf, registry=registry)
43+
# unflatten the flat tree to make a copy
44+
treedef = tree_unflatten(tree, flat, is_leaf=is_leaf, registry=registry)
4745
return flat, treedef
4846

4947

@@ -124,9 +122,7 @@ def tree_yield(tree, is_leaf=None, registry=None):
124122
is_leaf = _process_is_leaf(is_leaf)
125123

126124
flat = _tree_yield(tree, is_leaf=is_leaf, registry=registry)
127-
dummy_flat = itertools.repeat("*")
128-
treedef = tree_unflatten(tree, dummy_flat, is_leaf=is_leaf, registry=registry)
129-
return flat, treedef
125+
return flat, tree
130126

131127

132128
def tree_just_yield(tree, is_leaf=None, registry=None):

tests/test_tree_util.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,13 @@ def example_flat():
3131

3232

3333
@pytest.fixture
34-
def example_treedef():
35-
return (["*", "*", {"a": "*", "b": "*"}], "*")
34+
def example_treedef(example_tree):
35+
return example_tree
3636

3737

3838
@pytest.fixture
39-
def extended_treedef():
40-
return (
41-
[
42-
"*",
43-
np.array(["*", "*"]),
44-
{"a": pd.Series(["*", "*"], index=["c", "d"]), "b": "*"},
45-
],
46-
"*",
47-
)
39+
def extended_treedef(example_tree):
40+
return example_tree
4841

4942

5043
@pytest.fixture
@@ -195,7 +188,7 @@ def test_flatten_df_all_columns():
195188
def test_tree_yield(example_tree, example_treedef, example_flat):
196189
generator, treedef = tree_yield(example_tree)
197190

198-
assert treedef == example_treedef
191+
assert tree_equal(treedef, example_treedef)
199192
assert inspect.isgenerator(generator)
200193
for a, b in zip(generator, example_flat):
201194
if isinstance(a, (np.ndarray, pd.Series)):

tests/test_with_jax.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import pytest
2+
from pybaum.config import IS_JAX_INSTALLED
3+
from pybaum.registry import get_registry
4+
from pybaum.tree_util import leaf_names
5+
from pybaum.tree_util import tree_equal
6+
from pybaum.tree_util import tree_flatten
7+
from pybaum.tree_util import tree_just_flatten
8+
9+
if IS_JAX_INSTALLED:
10+
import jax.numpy as jnp
11+
else:
12+
# run the tests with normal numpy instead
13+
import numpy as jnp
14+
15+
16+
@pytest.fixture
17+
def tree():
18+
return {"a": {"b": jnp.arange(4).reshape(2, 2)}, "c": jnp.ones(2)}
19+
20+
21+
@pytest.fixture
22+
def flat():
23+
return [0, 1, 2, 3, 1, 1]
24+
25+
26+
@pytest.fixture
27+
def registry():
28+
return get_registry(types=["jax.numpy.ndarray", "numpy.ndarray"])
29+
30+
31+
def test_tree_just_flatten_with_jax(tree, registry, flat):
32+
got = tree_just_flatten(tree, registry=registry)
33+
assert got == flat
34+
35+
36+
def test_tree_flatten_with_jax(tree, registry, flat):
37+
got_flat, got_treedef = tree_flatten(tree, registry=registry)
38+
assert got_flat == flat
39+
assert tree_equal(got_treedef, tree)
40+
41+
42+
def test_leaf_names_with_jax(tree, registry):
43+
got = leaf_names(tree, registry=registry)
44+
expected = ["a_b_0_0", "a_b_0_1", "a_b_1_0", "a_b_1_1", "c_0", "c_1"]
45+
assert got == expected

0 commit comments

Comments
 (0)