Skip to content

Refactor drjax primitives to use jnp.broadcast_to and improve explicit sharding support.#37

Merged
copybara-service[bot] merged 1 commit into
mainfrom
cl/820412444
Apr 3, 2026
Merged

Refactor drjax primitives to use jnp.broadcast_to and improve explicit sharding support.#37
copybara-service[bot] merged 1 commit into
mainfrom
cl/820412444

Conversation

@copybara-service

Copy link
Copy Markdown

Refactor drjax primitives to use jnp.broadcast_to and improve explicit sharding support.

  • Replace usage of jnp.tile with jnp.broadcast_to in single_arg_broadcast for more standard broadcasting.
  • Remove jax.jit from single_arg_broadcast to allow it to be trace-time executed or jitted by the caller.
  • Migrate jax.tree_util.tree_map to jax.tree.map (deprecated alias cleanup).
  • Update map_to_placement when using explicit axes to remove axis_name from vmap (allowing nested shard_map) and use jax.sharding.reshard to ensure output sharding constraints.

…cit sharding support.

-   Replace usage of `jnp.tile` with `jnp.broadcast_to` in `single_arg_broadcast` for more standard broadcasting.
-   Remove `jax.jit` from `single_arg_broadcast` to allow it to be trace-time executed or jitted by the caller.
-   Migrate `jax.tree_util.tree_map` to `jax.tree.map` (deprecated alias cleanup).
-   Update `map_to_placement` when using explicit axes to remove `axis_name` from `vmap` (allowing nested `shard_map`) and use `jax.sharding.reshard` to ensure output sharding constraints.

PiperOrigin-RevId: 894053779
@copybara-service copybara-service Bot merged commit 8c609af into main Apr 3, 2026
@copybara-service copybara-service Bot deleted the cl/820412444 branch April 3, 2026 14:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant