1- from collections import namedtuple
1+ from pybaum .config import IS_JAX_INSTALLED
2+ from pybaum .config import IS_NUMPY_INSTALLED
3+
4+ if IS_JAX_INSTALLED :
5+ import jax .numpy as jnp
6+
7+ if IS_NUMPY_INSTALLED :
8+ import numpy as np
29
310
411def get_type (obj ):
5- """namdetuple aware type check.
12+ """Get type of candidate objects in a pytree.
13+
14+ This function allows us to reliably identify namedtuples, NamedTuples and jax arrays
15+ for which standard ``type`` function does not work.
16+
17+ Args:
18+ obj: The object to be checked
19+
20+ Returns:
21+ type or str: The type of the object or a string with the type name.
22+
23+ """
24+ if _is_namedtuple (obj ):
25+ out = "namedtuple"
26+ elif _is_jax_array (obj ):
27+ out = "jax.numpy.ndarray"
28+ else :
29+ out = type (obj )
30+ return out
31+
32+
33+ def _is_namedtuple (obj ):
34+ """Check if an object is a namedtuple.
635
736 As in JAX we treat collections.namedtuple and typing.NamedTuple both as
837 namedtuple but the exact type is preserved in the unflatten function.
@@ -24,8 +53,41 @@ def get_type(obj):
2453 bool
2554
2655 """
27- if isinstance (obj , tuple ) and hasattr (obj , "_fields" ) and hasattr (obj , "_replace" ):
28- out = namedtuple
56+ out = (
57+ isinstance (obj , tuple ) and hasattr (obj , "_fields" ) and hasattr (obj , "_replace" )
58+ )
59+ return out
60+
61+
62+ def _is_jax_array (obj ):
63+ """Check if an object is a jax array.
64+
65+ The exact type of jax arrays has changed over time and is an implementation detail.
66+
67+ Instead we rely on isinstance checks which will likely be more stable in the future.
68+ However, the behavior of isinstance for jax arrays has also changed over time. For
69+ jax versions before 0.2.21, standard numpy arrays were instances of jax arrays,
70+ now they are not.
71+
72+ Resources:
73+ ----------
74+
75+ - https://github.com/google/jax/issues/2115
76+ - https://github.com/google/jax/issues/2014
77+ - https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0221-sept-23-2021
78+ - https://github.com/google/jax/blob/main/CHANGELOG.md#jax-0318-sep-26-2022
79+
80+ Args:
81+ obj: The object to be checked
82+
83+ Returns:
84+ bool
85+
86+ """
87+ if not IS_JAX_INSTALLED :
88+ out = False
89+ elif IS_NUMPY_INSTALLED :
90+ out = isinstance (obj , jnp .ndarray ) and not isinstance (obj , np .ndarray )
2991 else :
30- out = type (obj )
92+ out = isinstance (obj , jnp . ndarray )
3193 return out
0 commit comments