|
| 1 | +import pytest |
| 2 | +from pybaum.config import IS_JAX_INSTALLED |
| 3 | +from pybaum.registry import get_registry |
| 4 | +from pybaum.tree_util import leaf_names |
| 5 | +from pybaum.tree_util import tree_equal |
| 6 | +from pybaum.tree_util import tree_flatten |
| 7 | +from pybaum.tree_util import tree_just_flatten |
| 8 | + |
| 9 | +if IS_JAX_INSTALLED: |
| 10 | + import jax.numpy as jnp |
| 11 | +else: |
| 12 | + # run the tests with normal numpy instead |
| 13 | + import numpy as jnp |
| 14 | + |
| 15 | + |
| 16 | +@pytest.fixture |
| 17 | +def tree(): |
| 18 | + return {"a": {"b": jnp.arange(4).reshape(2, 2)}, "c": jnp.ones(2)} |
| 19 | + |
| 20 | + |
| 21 | +@pytest.fixture |
| 22 | +def flat(): |
| 23 | + return [0, 1, 2, 3, 1, 1] |
| 24 | + |
| 25 | + |
| 26 | +@pytest.fixture |
| 27 | +def registry(): |
| 28 | + return get_registry(types=["jax.numpy.ndarray", "numpy.ndarray"]) |
| 29 | + |
| 30 | + |
| 31 | +def test_tree_just_flatten_with_jax(tree, registry, flat): |
| 32 | + got = tree_just_flatten(tree, registry=registry) |
| 33 | + assert got == flat |
| 34 | + |
| 35 | + |
| 36 | +def test_tree_flatten_with_jax(tree, registry, flat): |
| 37 | + got_flat, got_treedef = tree_flatten(tree, registry=registry) |
| 38 | + assert got_flat == flat |
| 39 | + assert tree_equal(got_treedef, tree) |
| 40 | + |
| 41 | + |
| 42 | +def test_leaf_names_with_jax(tree, registry): |
| 43 | + got = leaf_names(tree, registry=registry) |
| 44 | + expected = ["a_b_0_0", "a_b_0_1", "a_b_1_0", "a_b_1_1", "c_0", "c_1"] |
| 45 | + assert got == expected |
0 commit comments