11import itertools
2- from functools import partial
2+ from collections import namedtuple
3+ from collections import OrderedDict
4+ from itertools import product
35
46from pybaum .config import IS_NUMPY_INSTALLED
57from pybaum .config import IS_PANDAS_INSTALLED
1113 import pandas as pd
1214
1315
16+ def _none ():
17+ """Create registry entry for NoneType."""
18+ entry = {
19+ type (None ): {
20+ "flatten" : lambda tree : ([], None ), # noqa: U100
21+ "unflatten" : lambda aux_data , children : None , # noqa: U100
22+ "names" : lambda tree : [], # noqa: U100
23+ }
24+ }
25+ return entry
26+
27+
1428def _list ():
29+ """Create registry entry for list."""
1530 entry = {
1631 list : {
1732 "flatten" : lambda tree : (tree , None ),
@@ -23,6 +38,7 @@ def _list():
2338
2439
2540def _dict ():
41+ """Create registry entry for dict."""
2642 entry = {
2743 dict : {
2844 "flatten" : lambda tree : (list (tree .values ()), list (tree )),
@@ -34,6 +50,7 @@ def _dict():
3450
3551
3652def _tuple ():
53+ """Create registry entry for tuple."""
3754 entry = {
3855 tuple : {
3956 "flatten" : lambda tree : (list (tree ), None ),
@@ -44,12 +61,41 @@ def _tuple():
4461 return entry
4562
4663
47- def _numpy_array ():
48- """Create a pytree declaration for numpy arrays.
def _namedtuple():
    """Build the registry entry for namedtuple and NamedTuple.

    The entry is keyed by the ``collections.namedtuple`` factory. Flattening
    returns the field values as children and passes the original tuple along
    as aux data, which ``_unflatten_namedtuple`` uses to rebuild the result.

    Returns:
        dict: Maps ``namedtuple`` to a dict holding the "flatten", "unflatten"
            and "names" functions.

    """
    handlers = {
        "flatten": lambda tree: (list(tree), tree),
        "unflatten": _unflatten_namedtuple,
        "names": lambda tree: list(tree._fields),
    }
    return {namedtuple: handlers}
74+
75+
76+ def _unflatten_namedtuple (aux_data , leaves ):
77+ replacements = dict (zip (aux_data ._fields , leaves ))
78+ out = aux_data ._replace (** replacements )
79+ return out
4980
50- To-Do: Add optional axis argument.
5181
52- """
82+ def _ordereddict ():
83+ """Create registry entry for OrderedDict."""
84+ entry = {
85+ OrderedDict : {
86+ "flatten" : lambda tree : (list (tree .values ()), list (tree )),
87+ "unflatten" : lambda aux_data , children : OrderedDict (
88+ zip (aux_data , children )
89+ ),
90+ "names" : lambda tree : list (map (str , list (tree ))),
91+ },
92+ }
93+ return entry
94+
95+
96+ def _numpy_array ():
97+ """Create registry entry for numpy.ndarray."""
98+
5399 if IS_NUMPY_INSTALLED :
54100 entry = {
55101 np .ndarray : {
@@ -72,6 +118,7 @@ def _array_element_names(arr):
72118
73119
74120def _pandas_series ():
121+ """Create registry entry for pandas.Series."""
75122 if IS_PANDAS_INSTALLED :
76123 entry = {
77124 pd .Series : {
@@ -88,69 +135,49 @@ def _pandas_series():
88135 return entry
89136
90137
def _pandas_dataframe():
    """Build the registry entry for ``pandas.DataFrame``.

    Returns:
        dict: Maps ``pd.DataFrame`` to a dict holding the "flatten", "unflatten"
            and "names" functions if ``IS_PANDAS_INSTALLED`` is true, otherwise
            an empty dict.

    """
    if not IS_PANDAS_INSTALLED:
        return {}

    handlers = {
        "flatten": _flatten_pandas_dataframe,
        "unflatten": _unflatten_pandas_dataframe,
        "names": _get_names_pandas_dataframe,
    }
    return {pd.DataFrame: handlers}
103151
104152
105- def _flatten_pandas_dataframe (df , columns ):
106- columns = _process_columns (df , columns )
107- flat = []
108- for col in columns :
109- flat += df [col ].tolist ()
110-
111- aux_data = (columns , df .drop (columns = columns ))
153+ def _flatten_pandas_dataframe (df ):
154+ flat = df .to_numpy ().flatten ().tolist ()
155+ aux_data = {"columns" : df .columns , "index" : df .index , "shape" : df .shape }
112156 return flat , aux_data
113157
114158
115159def _unflatten_pandas_dataframe (aux_data , leaves ):
116- columns , empty_df = aux_data
117- out = empty_df .copy ()
118- remaining_leaves = leaves
119- for col in columns :
120- out [col ] = leaves [: len (empty_df )]
121- remaining_leaves = remaining_leaves [len (empty_df ) :]
160+ out = pd .DataFrame (
161+ data = np .array (leaves ).reshape (aux_data ["shape" ]),
162+ columns = aux_data ["columns" ],
163+ index = aux_data ["index" ],
164+ )
122165 return out
123166
124167
def _get_names_pandas_dataframe(df):
    """Create one name per DataFrame cell, in row-major order.

    Each name is the stringified index label and the stringified column label
    joined by an underscore.

    Args:
        df (pandas.DataFrame): The DataFrame whose cells are named.

    Returns:
        list: Strings of the form "<index>_<column>", one per cell.

    """
    index_strings = list(df.index.map(_index_element_to_string))
    # Column labels can be integers or tuples (MultiIndex); str.join only accepts
    # strings, so convert them the same way as the index labels.
    column_strings = [_index_element_to_string(col) for col in df.columns]
    return ["_".join(pair) for pair in product(index_strings, column_strings)]
134172
135173
136- def _process_columns (df , columns ):
137- if columns is None :
138- columns = df .columns
139- elif not isinstance (columns , list ):
140- columns = [columns ]
141- return columns
142-
143-
144- def _index_element_to_string (element , prefix = None ):
145- separator = "_"
174+ def _index_element_to_string (element ):
146175 if isinstance (element , (tuple , list )):
147176 as_strings = [str (entry ) for entry in element ]
148- res_string = separator .join (as_strings )
177+ res_string = "_" .join (as_strings )
149178 else :
150179 res_string = str (element )
151180
152- if prefix is not None :
153- res_string = separator .join ([prefix , res_string ])
154181 return res_string
155182
156183
@@ -161,4 +188,7 @@ def _index_element_to_string(element, prefix=None):
161188 "numpy.ndarray" : _numpy_array ,
162189 "pandas.Series" : _pandas_series ,
163190 "pandas.DataFrame" : _pandas_dataframe ,
191+ "None" : _none ,
192+ "namedtuple" : _namedtuple ,
193+ "OrderedDict" : _ordereddict ,
164194}
0 commit comments