Skip to content

Migrate pybaum to optree#679

Open
abelaba wants to merge 33 commits intooptimagic-dev:mainfrom
abelaba:migrate-pybaum-to-optree
Open

Migrate pybaum to optree#679
abelaba wants to merge 33 commits intooptimagic-dev:mainfrom
abelaba:migrate-pybaum-to-optree

Conversation

@abelaba
Copy link
Copy Markdown
Collaborator

@abelaba abelaba commented Apr 1, 2026

Summary

  • Added wrapper methods that use optree for all pytree operations used in optimagic (tree_unflatten, tree_flatten, tree_leaves(Previously tree_just_flatten), tree_map, tree_equal, leaf_names)
  • Removed pybaum as a dependency and added optree>=0.19.0.
  • Added namespace constants (OPTREE_NAMESPACES tuple and DEFAULT_NAMESPACE) in typing.py
  • tree_registry.py: Rewritten to use optree.register_pytree_node, including support for numpy arrays, pandas DataFrame/Series, and JAX arrays.
  • Removed the get_registry method. So now, when extending the set of types that are considered internal nodes is desired, one would pass a namespace that exists in OPTREE_NAMESPACES as an argument to the tree methods.

Extending the set of types that are considered internal nodes for numpy arrays, pandas DataFrame/Series, and JAX arrays.

Previously

registry = get_registry(extended=True) # data_col="value" is a default parameter inside this method
tree_flatten(tree, registry=registry)

Now

tree_flatten(tree, namespace=<namespace>) # <namespace> is a string in OPTREE_NAMESPACES

When custom node method parsing is not desired.

Previously

tree_flatten(tree)

Now

tree_flatten(tree)

Why I added wrapper methods on top of optree tree methods instead of directly using them throughout the project.

  1. Flattening: optree does not preserve insertion order of dictionaries by default (documentation). The wrapper ensures insertion order for dictionaries is maintained for the registered namespaces in OPTREE_NAMESPACES and DEFAULT_NAMESPACE. (The following tests in the screenshot below fail if the order is not maintained for a dictionary)
Screenshot 2026-04-02 at 18 26 42
  1. Unflattening: pybaum and optree have different tree structures. The tree_unflatten wrapper maintains backward compatibility for methods that use pybaum's tree structure.

Remaining

  • Flattening and unflattening for torch tensors.

Test plan

  • All existing fast tests pass with the new optree backend.
  • Added additional tests for verifying the outputs of the tree methods.
    • Check tree method outputs for default namespace and registered namespaces.
    • Verify warning is raised for an unregistered namespace.
    • Tree method outputs when tree is None.
    • Verify dictionary insertion ordering is respected for all registered namespaces and the default namespace.

@abelaba abelaba force-pushed the migrate-pybaum-to-optree branch from 4ce5903 to 94516da Compare April 2, 2026 12:30
@abelaba abelaba force-pushed the migrate-pybaum-to-optree branch from 6dfe919 to 6f03b12 Compare April 2, 2026 14:52
@abelaba abelaba changed the title Migrate pybaum to optree feat: migrate pybaum to optree Apr 2, 2026
@abelaba abelaba marked this pull request as ready for review April 2, 2026 17:33
@abelaba abelaba force-pushed the migrate-pybaum-to-optree branch from ba61259 to 0e76b38 Compare April 7, 2026 16:10
This method must only be called once as each namespace must only be registered
one time.
"""
global _are_namespaces_registered # noqa: PLW0603
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registering optree.register_pytree_node with the same namespace again raises an error so I added this global boolean variable so the optree namespace registration is only executed once for each namespace. Should I directly loop over the namespaces and register the nodes directly inside the file without wrapping it in a method?

@abelaba abelaba changed the title feat: migrate pybaum to optree Migrate pybaum to optree Apr 8, 2026
@abelaba abelaba force-pushed the migrate-pybaum-to-optree branch from 84dddb6 to 651fa5e Compare April 9, 2026 14:33
@codecov
Copy link
Copy Markdown

codecov bot commented Apr 9, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

Files with missing lines Coverage Δ
src/estimagic/bootstrap.py 99.00% <100.00%> (-0.06%) ⬇️
src/estimagic/estimate_msm.py 89.55% <100.00%> (-0.13%) ⬇️
src/estimagic/msm_weighting.py 98.24% <100.00%> (-0.04%) ⬇️
src/estimagic/shared_covs.py 99.35% <100.00%> (-0.02%) ⬇️
src/optimagic/benchmarking/run_benchmark.py 92.22% <100.00%> (-0.09%) ⬇️
src/optimagic/differentiation/derivatives.py 96.14% <100.00%> (-0.04%) ⬇️
src/optimagic/examples/criterion_functions.py 100.00% <100.00%> (ø)
src/optimagic/optimization/fun_value.py 97.81% <100.00%> (-0.02%) ⬇️
src/optimagic/optimization/history.py 93.54% <100.00%> (-0.09%) ⬇️
src/optimagic/parameters/block_trees.py 85.54% <100.00%> (-0.25%) ⬇️
... and 9 more

... and 2 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

if not is_fast_path:
x, params_treedef = tree_flatten(params, registry=registry)
x = np.array(x, dtype=np.float64)
params_leaves, params_treedef = tree_flatten(params, namespace=VALUE_NAMESPACE)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the name from x to params_leaves to fix type hint error.


# generate parameter vectors at which func has to be evaluated as numpy arrays
evaluation_points = []
evaluation_points: list[float | np.ndarray] = []
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add this type hint because I was getting mypy type error.

if not is_fast_path:
x, params_treedef = tree_flatten(params, registry=registry)
x = np.array(x, dtype=np.float64)
params_leaves, params_treedef = tree_flatten(params, namespace=VALUE_NAMESPACE)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the name from x to params_leaves to fix type hint error.

registry = get_registry(extended=True)
flat = tree_just_flatten(value, registry=registry)
value_leaves = tree_leaves(value, namespace=VALUE_NAMESPACE)
flat = np.asarray(value_leaves, dtype=np.float64)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Converting to numpy array after flattening to fix mypy type error.

@abelaba abelaba requested a review from janosg April 10, 2026 10:06
return equality_checkers


def _check_namespace(namespace: str) -> None:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I update this method to change the namespace to DEFAULT_NAMESPACE and raise a warning if an unregistered namespace is used?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant