Skip to content

Commit 1675270

Browse files
authored
Align default containers with jax (#14)
1 parent 7108cf7 commit 1675270

7 files changed

Lines changed: 185 additions & 75 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@ repos:
9696
hooks:
9797
- id: interrogate
9898
args: [-v, --fail-under=20]
99-
exclude: ^(docs|setup\.py)
99+
exclude: ^(tests|docs|setup\.py)
100100
- repo: https://github.com/codespell-project/codespell
101101
rev: v2.1.0
102102
hooks:

src/pybaum/registry.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from pybaum.registry_entries import FUNC_DICT
22

33

4-
def get_registry(types=None, options=None, include_defaults=True):
4+
def get_registry(types=None, include_defaults=True):
55
"""Create a pytree registry.
66
77
Args:
@@ -11,14 +11,15 @@ def get_registry(types=None, options=None, include_defaults=True):
1111
- "tuple"
1212
- "dict"
1313
- "list"
14+
- :class:`collections.namedtuple` or :class:`typing.NamedTuple`
15+
- :obj:`None`
16+
- :class:`collections.OrderedDict`
1417
- "numpy.ndarray"
1518
- "pandas.Series"
1619
- "pandas.DataFrame"
17-
options (dict): Option dictionary where the keys are names of types and the
18-
values are keyword arguments that influence how containers are flattened
19-
and unflattened.
2020
include_defaults (bool): Whether the default pytree containers "tuple", "dict",
21-
and "list" should be included even if not specified in `types`.
21+
"list", "None", "namedtuple" and "OrderedDict" should be included even if
22+
not specified in `types`.
2223
2324
Returns:
2425
dict: A pytree registry.
@@ -27,13 +28,12 @@ def get_registry(types=None, options=None, include_defaults=True):
2728
types = [] if types is None else types
2829

2930
if include_defaults:
30-
types = list(set(types) | {"list", "tuple", "dict"})
31-
32-
options = {} if options is None else options
31+
default_types = {"list", "tuple", "dict", "None", "namedtuple", "OrderedDict"}
32+
types = list(set(types) | default_types)
3333

3434
registry = {}
3535
for typ in types:
36-
new_entry = FUNC_DICT[typ](**options.get(typ, {}))
36+
new_entry = FUNC_DICT[typ]()
3737
registry = {**registry, **new_entry}
3838

3939
return registry

src/pybaum/registry_entries.py

Lines changed: 73 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
import itertools
2-
from functools import partial
2+
from collections import namedtuple
3+
from collections import OrderedDict
4+
from itertools import product
35

46
from pybaum.config import IS_NUMPY_INSTALLED
57
from pybaum.config import IS_PANDAS_INSTALLED
@@ -11,7 +13,20 @@
1113
import pandas as pd
1214

1315

16+
def _none():
17+
"""Create registry entry for NoneType."""
18+
entry = {
19+
type(None): {
20+
"flatten": lambda tree: ([], None), # noqa: U100
21+
"unflatten": lambda aux_data, children: None, # noqa: U100
22+
"names": lambda tree: [], # noqa: U100
23+
}
24+
}
25+
return entry
26+
27+
1428
def _list():
29+
"""Create registry entry for list."""
1530
entry = {
1631
list: {
1732
"flatten": lambda tree: (tree, None),
@@ -23,6 +38,7 @@ def _list():
2338

2439

2540
def _dict():
41+
"""Create registry entry for dict."""
2642
entry = {
2743
dict: {
2844
"flatten": lambda tree: (list(tree.values()), list(tree)),
@@ -34,6 +50,7 @@ def _dict():
3450

3551

3652
def _tuple():
53+
"""Create registry entry for tuple."""
3754
entry = {
3855
tuple: {
3956
"flatten": lambda tree: (list(tree), None),
@@ -44,12 +61,41 @@ def _tuple():
4461
return entry
4562

4663

47-
def _numpy_array():
48-
"""Create a pytree declaration for numpy arrays.
64+
def _namedtuple():
65+
"""Create registry entry for namedtuple and NamedTuple."""
66+
entry = {
67+
namedtuple: {
68+
"flatten": lambda tree: (list(tree), tree),
69+
"unflatten": _unflatten_namedtuple,
70+
"names": lambda tree: list(tree._fields),
71+
},
72+
}
73+
return entry
74+
75+
76+
def _unflatten_namedtuple(aux_data, leaves):
77+
replacements = dict(zip(aux_data._fields, leaves))
78+
out = aux_data._replace(**replacements)
79+
return out
4980

50-
To-Do: Add optional axis argument.
5181

52-
"""
82+
def _ordereddict():
83+
"""Create registry entry for OrderedDict."""
84+
entry = {
85+
OrderedDict: {
86+
"flatten": lambda tree: (list(tree.values()), list(tree)),
87+
"unflatten": lambda aux_data, children: OrderedDict(
88+
zip(aux_data, children)
89+
),
90+
"names": lambda tree: list(map(str, list(tree))),
91+
},
92+
}
93+
return entry
94+
95+
96+
def _numpy_array():
97+
"""Create registry entry for numpy.ndarray."""
98+
5399
if IS_NUMPY_INSTALLED:
54100
entry = {
55101
np.ndarray: {
@@ -72,6 +118,7 @@ def _array_element_names(arr):
72118

73119

74120
def _pandas_series():
121+
"""Create registry entry for pandas.Series."""
75122
if IS_PANDAS_INSTALLED:
76123
entry = {
77124
pd.Series: {
@@ -88,69 +135,49 @@ def _pandas_series():
88135
return entry
89136

90137

91-
def _pandas_dataframe(columns=None):
138+
def _pandas_dataframe():
139+
"""Create registry entry for pandas.DataFrame."""
92140
if IS_PANDAS_INSTALLED:
93141
entry = {
94142
pd.DataFrame: {
95-
"flatten": partial(_flatten_pandas_dataframe, columns=columns),
96-
"unflatten": partial(_unflatten_pandas_dataframe),
97-
"names": partial(_get_names_pandas_dataframe, columns=columns),
143+
"flatten": _flatten_pandas_dataframe,
144+
"unflatten": _unflatten_pandas_dataframe,
145+
"names": _get_names_pandas_dataframe,
98146
}
99147
}
100148
else:
101149
entry = {}
102150
return entry
103151

104152

105-
def _flatten_pandas_dataframe(df, columns):
106-
columns = _process_columns(df, columns)
107-
flat = []
108-
for col in columns:
109-
flat += df[col].tolist()
110-
111-
aux_data = (columns, df.drop(columns=columns))
153+
def _flatten_pandas_dataframe(df):
154+
flat = df.to_numpy().flatten().tolist()
155+
aux_data = {"columns": df.columns, "index": df.index, "shape": df.shape}
112156
return flat, aux_data
113157

114158

115159
def _unflatten_pandas_dataframe(aux_data, leaves):
116-
columns, empty_df = aux_data
117-
out = empty_df.copy()
118-
remaining_leaves = leaves
119-
for col in columns:
120-
out[col] = leaves[: len(empty_df)]
121-
remaining_leaves = remaining_leaves[len(empty_df) :]
160+
out = pd.DataFrame(
161+
data=np.array(leaves).reshape(aux_data["shape"]),
162+
columns=aux_data["columns"],
163+
index=aux_data["index"],
164+
)
122165
return out
123166

124167

125-
def _get_names_pandas_dataframe(df, columns):
126-
columns = _process_columns(df, columns)
127-
if len(columns) == 1:
128-
out = list(df.index.map(_index_element_to_string))
129-
else:
130-
out = []
131-
for col in df.columns:
132-
out += list(df.index.map(partial(_index_element_to_string, prefix=col)))
168+
def _get_names_pandas_dataframe(df):
169+
index_strings = list(df.index.map(_index_element_to_string))
170+
out = ["_".join([loc, col]) for loc, col in product(index_strings, df.columns)]
133171
return out
134172

135173

136-
def _process_columns(df, columns):
137-
if columns is None:
138-
columns = df.columns
139-
elif not isinstance(columns, list):
140-
columns = [columns]
141-
return columns
142-
143-
144-
def _index_element_to_string(element, prefix=None):
145-
separator = "_"
174+
def _index_element_to_string(element):
146175
if isinstance(element, (tuple, list)):
147176
as_strings = [str(entry) for entry in element]
148-
res_string = separator.join(as_strings)
177+
res_string = "_".join(as_strings)
149178
else:
150179
res_string = str(element)
151180

152-
if prefix is not None:
153-
res_string = separator.join([prefix, res_string])
154181
return res_string
155182

156183

@@ -161,4 +188,7 @@ def _index_element_to_string(element, prefix=None):
161188
"numpy.ndarray": _numpy_array,
162189
"pandas.Series": _pandas_series,
163190
"pandas.DataFrame": _pandas_dataframe,
191+
"None": _none,
192+
"namedtuple": _namedtuple,
193+
"OrderedDict": _ordereddict,
164194
}

src/pybaum/tree_util.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010

1111
from pybaum.equality import EQUALITY_CHECKERS
1212
from pybaum.registry import get_registry
13+
from pybaum.typecheck import get_type
1314

1415

1516
def tree_flatten(tree, is_leaf=None, registry=None):
@@ -80,14 +81,14 @@ def tree_just_flatten(tree, is_leaf=None, registry=None):
8081

8182
def _tree_flatten(tree, is_leaf, registry):
8283
out = []
83-
tree_type = type(tree)
84+
tree_type = get_type(tree)
8485

8586
if tree_type not in registry or is_leaf(tree):
8687
out.append(tree)
8788
else:
8889
subtrees, _ = registry[tree_type]["flatten"](tree)
8990
for subtree in subtrees:
90-
if type(subtree) in registry:
91+
if get_type(subtree) in registry:
9192
out += _tree_flatten(subtree, is_leaf, registry)
9293
else:
9394
out.append(subtree)
@@ -161,14 +162,14 @@ def tree_just_yield(tree, is_leaf=None, registry=None):
161162

162163
def _tree_yield(tree, is_leaf, registry):
163164
out = []
164-
tree_type = type(tree)
165+
tree_type = get_type(tree)
165166

166167
if tree_type not in registry or is_leaf(tree):
167168
yield tree
168169
else:
169170
subtrees, _ = registry[tree_type]["flatten"](tree)
170171
for subtree in subtrees:
171-
if type(subtree) in registry:
172+
if get_type(subtree) in registry:
172173
yield from _tree_yield(subtree, is_leaf, registry)
173174
else:
174175
yield subtree
@@ -211,15 +212,15 @@ def tree_unflatten(treedef, leaves, is_leaf=None, registry=None):
211212

212213
def _tree_unflatten(treedef, leaves, is_leaf, registry):
213214
leaves = iter(leaves)
214-
tree_type = type(treedef)
215+
tree_type = get_type(treedef)
215216

216217
if tree_type not in registry or is_leaf(treedef):
217218
return next(leaves)
218219
else:
219220
items, info = registry[tree_type]["flatten"](treedef)
220221
unflattened_items = []
221222
for item in items:
222-
if type(item) in registry:
223+
if get_type(item) in registry:
223224
unflattened_items.append(
224225
_tree_unflatten(item, leaves, is_leaf=is_leaf, registry=registry)
225226
)
@@ -336,15 +337,15 @@ def leaf_names(tree, is_leaf=None, registry=None, separator="_"):
336337

337338
def _leaf_names(tree, is_leaf, registry, separator, prefix=None):
338339
out = []
339-
tree_type = type(tree)
340+
tree_type = get_type(tree)
340341

341342
if tree_type not in registry or is_leaf(tree):
342343
out.append(prefix)
343344
else:
344345
subtrees, info = registry[tree_type]["flatten"](tree)
345346
names = registry[tree_type]["names"](tree)
346347
for name, subtree in zip(names, subtrees):
347-
if type(subtree) in registry:
348+
if get_type(subtree) in registry:
348349
out += _leaf_names(
349350
subtree,
350351
is_leaf=is_leaf,
@@ -424,7 +425,7 @@ def tree_equal(tree, other, is_leaf=None, registry=None, equality_checkers=None)
424425

425426
if equal:
426427
for first, second in zip(first_flat, second_flat):
427-
check_func = equality_checkers.get(type(first), lambda a, b: a == b)
428+
check_func = equality_checkers.get(get_type(first), lambda a, b: a == b)
428429
equal = equal and check_func(first, second)
429430
if not equal:
430431
break

src/pybaum/typecheck.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,31 @@
1+
from collections import namedtuple
2+
3+
4+
def get_type(obj):
5+
"""namedtuple aware type check.
6+
7+
As in JAX we treat collections.namedtuple and typing.NamedTuple both as
8+
namedtuple but the exact type is preserved in the unflatten function.
9+
10+
namedtuples are discovered by being instances of tuple and having a
11+
``_fields`` attribute as suggested by Raymond Hettinger
12+
`here <https://bugs.python.org/issue7796>`_.
13+
14+
Moreover we check for the presence of a ``_replace`` method because we need it when
15+
unflattening pytrees.
16+
17+
This can produce false positives but in most cases would still result in desired
18+
behavior.
19+
20+
Args:
21+
obj: The object to be checked
22+
23+
Returns:
24+
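type: ``collections.namedtuple`` (used as a marker) if ``obj`` looks like a namedtuple, otherwise ``type(obj)``.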
25+
26+
"""
27+
if isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_replace"):
28+
out = namedtuple
29+
else:
30+
out = type(obj)
31+
return out

0 commit comments

Comments
 (0)