From 6623aa0d2d68a9ef3124b44f6c1a62e2ac211f74 Mon Sep 17 00:00:00 2001 From: Zachary Garrett Date: Mon, 15 Jun 2026 12:34:40 -0700 Subject: [PATCH] Increment the version to 0.2.0. This release includes the following accumulated changes since v0.1.4: - **JAX Explicit Sharding Support**: Added compatibility with the explicit sharding model. - Primitives now propagate sharding annotations during abstract evaluation. - Replaced internal use of `jax.lax.with_sharding_constraints` with `jax.sharding.reshard`. - Added strict placement axis check for explicit sharding meshes, raising a ValueError on mismatch. - **API and Dependency Improvements**: - Replaced `jnp.tile` with `jnp.broadcast_to` in `single_arg_broadcast`. - Removed `jax.jit` from `single_arg_broadcast` to allow callers to control compilation. - Introduced support for JAX abstract meshes (`use_mesh(...)` pattern). - Migrated from deprecated `jax.tree_util.tree_map` to `jax.tree.map`. - Removed JAX version constraints from `pyproject.toml`. PiperOrigin-RevId: 932608109 --- drjax/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/drjax/__init__.py b/drjax/__init__.py index 566ebbd..7aa796d 100644 --- a/drjax/__init__.py +++ b/drjax/__init__.py @@ -19,7 +19,7 @@ from drjax._src import api as _api -__version__ = '0.1.4' +__version__ = '0.2.0' # Import the public API. broadcast = _api.broadcast