Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions export/orbax/export/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,6 @@ class ExportModelType(enum.Enum):
# name kwarg was not provided in the jax2obm_kwargs.
DEFAULT_WEIGHTS_NAME = 'weights'

# Jax2obm_kwargs key that triggers loading all checkpoint weights
# for the exported functions. By default, only weights used by the function
# are loaded. If this key is set to True, all weights in the checkpoint
# will be loaded. This may result in argument mismatches if the checkpoint
# contains more weights than required by the function.
LOAD_ALL_CHECKPOINT_WEIGHTS = 'load_all_checkpoint_weights'

DEFAULT_PRE_PROCESSOR_NAME = 'pre_processor'

DEFAULT_POST_PROCESSOR_NAME = 'post_processor'
Expand Down
8 changes: 0 additions & 8 deletions export/orbax/export/modules/obm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,6 @@ def __init__(
self._checkpoint_path: str = None
# Set the Orbax checkpoint path if provided in the jax2obm_kwargs.
self._maybe_set_orbax_checkpoint_path()
self._load_all_checkpoint_weights = _get_shared_value(
self._jax2obm_options,
self._apply_fn_keys,
constants.LOAD_ALL_CHECKPOINT_WEIGHTS,
)

def _jax2obm_kwargs_to_options(
self, jax2obm_kwargs: Mapping[str, Any]
Expand All @@ -227,9 +222,6 @@ def _jax2obm_kwargs_to_options(
polymorphic_constraints=jax2obm_kwargs.get(
constants.POLYMORPHIC_CONSTRAINTS
),
load_all_checkpoint_weights=jax2obm_kwargs.get(
constants.LOAD_ALL_CHECKPOINT_WEIGHTS, False
),
xla_flags_per_platform=jax2obm_kwargs.get(
constants.XLA_FLAGS_PER_PLATFORM
),
Expand Down
7 changes: 0 additions & 7 deletions export/orbax/export/obm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,6 @@ class Jax2ObmOptions:
identify a specific set of weights that will be used by the exported JAX
function.
polymorphic_constraints: Input polymorphic constraints.
load_all_checkpoint_weights: If set to True, all weights from the checkpoint
will be loaded, including those not used by the exported function(s).
Defaults to False, which only loads necessary weights to save memory
during serving.
xla_flags_per_platform: XLA flags per platform for the model.
jax_mesh: Mesh for the model.
persist_xla_flags: Whether to persist XLA flags in the exported model. If
Expand All @@ -321,9 +317,6 @@ class Jax2ObmOptions:
polymorphic_constraints: (
Mapping[str, Sequence[str]] | Sequence[str] | None
) = None
# TODO: b/448900820 - Remove this variable, we should always load necessary
# weights only.
load_all_checkpoint_weights: bool = False
xla_flags_per_platform: Mapping[str, Sequence[str]] | None = None
jax_mesh: jax.sharding.Mesh | None = None
persist_xla_flags: bool = True
Expand Down
Loading