Conversation
4ce5903 to
94516da
Compare
6dfe919 to
6f03b12
Compare
ba61259 to
0e76b38
Compare
| This method must only be called once as each namespace must only be registered | ||
| one time. | ||
| """ | ||
| global _are_namespaces_registered # noqa: PLW0603 |
There was a problem hiding this comment.
Registering optree.register_pytree_node with the same namespace again raises an error so I added this global boolean variable so the optree namespace registration is only executed once for each namespace. Should I directly loop over the namespaces and register the nodes directly inside the file without wrapping it in a method?
84dddb6 to
651fa5e
Compare
Codecov Report✅ All modified and coverable lines are covered by tests.
... and 2 files with indirect coverage changes 🚀 New features to boost your workflow:
|
| if not is_fast_path: | ||
| x, params_treedef = tree_flatten(params, registry=registry) | ||
| x = np.array(x, dtype=np.float64) | ||
| params_leaves, params_treedef = tree_flatten(params, namespace=VALUE_NAMESPACE) |
There was a problem hiding this comment.
Changed the name from x to params_leaves to fix type hint error.
|
|
||
| # generate parameter vectors at which func has to be evaluated as numpy arrays | ||
| evaluation_points = [] | ||
| evaluation_points: list[float | np.ndarray] = [] |
There was a problem hiding this comment.
Add this type hint because I was getting mypy type error.
| if not is_fast_path: | ||
| x, params_treedef = tree_flatten(params, registry=registry) | ||
| x = np.array(x, dtype=np.float64) | ||
| params_leaves, params_treedef = tree_flatten(params, namespace=VALUE_NAMESPACE) |
There was a problem hiding this comment.
Changed the name from x to params_leaves to fix type hint error.
| registry = get_registry(extended=True) | ||
| flat = tree_just_flatten(value, registry=registry) | ||
| value_leaves = tree_leaves(value, namespace=VALUE_NAMESPACE) | ||
| flat = np.asarray(value_leaves, dtype=np.float64) |
There was a problem hiding this comment.
Converting to numpy array after flattening to fix mypy type error.
| return equality_checkers | ||
|
|
||
|
|
||
| def _check_namespace(namespace: str) -> None: |
There was a problem hiding this comment.
Should I update this method to change the namespace to DEFAULT_NAMESPACE and raise a warning if an unregistered namespace is used?
Summary
tree_unflatten,tree_flatten,tree_leaves(Previouslytree_just_flatten),tree_map,tree_equal,leaf_names)OPTREE_NAMESPACEStuple andDEFAULT_NAMESPACE) in typing.pyoptree.register_pytree_node, including support for numpy arrays, pandas DataFrame/Series, and JAX arrays.get_registrymethod. So now, when extending the set of types that are considered internal nodes is desired, one would pass a namespace that exists inOPTREE_NAMESPACESas an argument to the tree methods.Extending the set of types that are considered internal nodes for numpy arrays, pandas DataFrame/Series, and JAX arrays.
Previously
Now
When custom node method parsing is not desired.
Previously
Now
Why I added wrapper methods on top of optree tree methods instead of directly using them throughout the project.
OPTREE_NAMESPACESandDEFAULT_NAMESPACE. (The following tests in the screenshot below fail if the order is not maintained for a dictionary)tree_unflattenwrapper maintains backward compatibility for methods that use pybaum's tree structure.Remaining
Test plan