Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 175 additions & 71 deletions src/maxtext/checkpoint_conversion/utils/load_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def load_sharded_hf_state(path):
"""
t0 = time.time()
context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS)
context.safetensors.ignore_load_sharding = True
with context:
metadata = ocp_v1.pytree_metadata(path)
simple_abstract_state = metadata.metadata
Expand Down Expand Up @@ -196,29 +197,33 @@ def transform_hf_state_to_mt_state(hf_state, target_tree, param_map_mt_to_hf, ho
def tensor_getter(key):
  """Return the entry stored under `key` in `hf_state`, removing it from `hf_state`."""
  return hf_state.pop(key)

flat_target = flax.traverse_util.flatten_dict(target_tree, sep=".")
flat_target = flax.traverse_util.flatten_dict(target_tree, sep=None)
flat_restored = flat_target.copy()

# Create a lookup mapping from stringified/joined path to the original tuple path
path_str_to_tuple = {".".join(map(str, path)): path for path in flat_target}

mapped_count = 0
keys_missed = []
max_logging.log("Starting fast in-memory Distributed Transformations...")

for mt_key, hf_source in param_map_mt_to_hf.items():
mt_name = mt_key.replace("params-", "").replace("-", ".")

# Determine the correct key in flat_target
# Determine the correct key in path_str_to_tuple
check_name = mt_name
if check_name not in flat_target:
if f"params.{mt_name}" in flat_target:
if check_name not in path_str_to_tuple:
if f"params.{mt_name}" in path_str_to_tuple:
check_name = f"params.{mt_name}"
elif mt_key.replace("-", ".") in flat_target:
elif mt_key.replace("-", ".") in path_str_to_tuple:
check_name = mt_key.replace("-", ".")

if check_name not in flat_target:
if check_name not in path_str_to_tuple:
keys_missed.append(mt_name)
continue

target_leaf = flat_target[check_name]
target_path = path_str_to_tuple[check_name]
target_leaf = flat_target[target_path]
hook_fn = hook_fn_map_mt.get(mt_key)

load_fn = get_hf_loading_function(
Expand All @@ -231,17 +236,17 @@ def tensor_getter(key):

# Execute transformation and assign to flat_restored
t_layer = time.time()
flat_restored[check_name] = load_fn()
flat_restored[target_path] = load_fn()

max_logging.log(f"Transformed {check_name} from {hf_source} in {time.time() - t_layer:.4f}s")
mapped_count += 1

if mapped_count == 0:
max_logging.log(f"All transformations missed! Sample missed mt_names: {keys_missed[:5]}")
max_logging.log(f"Sample flat_target keys: {list(flat_target.keys())[:5]}")
max_logging.log(f"Sample flat_target keys: {list(path_str_to_tuple.keys())[:5]}")

max_logging.log(f"Successfully mapped {mapped_count} parameters.")
restored_params = flax.traverse_util.unflatten_dict(flat_restored, sep=".")
restored_params = flax.traverse_util.unflatten_dict(flat_restored, sep=None)

if "params" in restored_params:
restored_params = restored_params["params"]
Expand All @@ -251,6 +256,130 @@ def tensor_getter(key):
return {"params": restored_params}


def write_gcs_latch(gcs_cache_dir):
  """Upload the `download_complete.txt` latch blob that marks the GCS cache as ready.

  Args:
    gcs_cache_dir: `gs://<bucket>[/<prefix>]` location of the shared cache directory.
  """
  client = storage.Client()
  # Everything before the first "/" is the bucket; the remainder (possibly empty) is the blob prefix.
  bucket_name, _, blob_prefix = gcs_cache_dir.replace("gs://", "").partition("/")
  marker_name = os.path.join(blob_prefix, "download_complete.txt")
  client.bucket(bucket_name).blob(marker_name).upload_from_string("complete")
  max_logging.log(f"Host 0 wrote dynamic GCS download latch file: {gcs_cache_dir}/download_complete.txt")


def wait_on_gcs_latch(gcs_cache_dir, host_id, poll_interval_s=10, timeout_s=None):
  """Block until the `download_complete.txt` latch blob exists in the GCS cache directory.

  The wait is a plain `time.sleep` polling loop on the host CPU; it is intended to keep
  non-zero hosts out of a JAX collective while Host 0 is still downloading.

  Args:
    gcs_cache_dir: `gs://<bucket>[/<prefix>]` location of the shared cache directory.
    host_id: Index of the calling host; used only in log messages.
    poll_interval_s: Seconds to sleep between existence checks.
    timeout_s: Optional upper bound, in seconds, on the total wait. `None` (the default)
      waits indefinitely, matching the previous behavior.

  Raises:
    TimeoutError: If `timeout_s` is set and the latch has not appeared in time.
  """
  storage_client = storage.Client()
  # Everything before the first "/" is the bucket; the remainder (possibly empty) is the blob prefix.
  bucket_name, _, blob_prefix = gcs_cache_dir.replace("gs://", "").partition("/")
  latch_blob_name = os.path.join(blob_prefix, "download_complete.txt")
  latch_blob = storage_client.bucket(bucket_name).blob(latch_blob_name)

  max_logging.log(f"Host {host_id} polling GCS latch at {gcs_cache_dir}/download_complete.txt...")
  t_poll_start = time.time()
  last_logged_min = 0
  while not latch_blob.exists():
    if timeout_s is not None and time.time() - t_poll_start >= timeout_s:
      # Without a bound, a crashed Host 0 would leave every other host polling forever.
      raise TimeoutError(
          f"Host {host_id} timed out after {timeout_s}s waiting for GCS latch at"
          f" {gcs_cache_dir}/download_complete.txt"
      )
    time.sleep(poll_interval_s)
    elapsed_min = int(time.time() - t_poll_start) // 60
    if elapsed_min > last_logged_min:
      last_logged_min = elapsed_min
      # Only log once per elapsed minute to avoid spamming logs.
      max_logging.log(f"Host {host_id} still waiting for Host 0 download latch...")

  max_logging.log(f"Host {host_id} detected GCS download complete latch!")


def jax_devices_barrier(name="dynamic_hf_download_complete"):
  """Log the calling host's index, then wait on `sync_global_devices(name)`.

  Args:
    name: Label forwarded to `jax.experimental.multihost_utils.sync_global_devices`.
  """
  max_logging.log(f"Host {jax.process_index()} aligning device clocks via JAX sync_global_devices...")
  jax.experimental.multihost_utils.sync_global_devices(name)


def _execute_gcs_download(gcs_cache_dir, files, maxtext_config):
  """Populate the shared GCS cache with any of `files` that are not already present.

  Lists the blobs under the cache prefix, then runs `build_gcs_cache_worker` in a
  spawned process pool for every file whose basename is missing from that listing.

  Args:
    gcs_cache_dir: `gs://<bucket>[/<prefix>]` location of the shared cache directory.
    files: Iterable of source file paths; only each path's basename is used to form
      the expected blob name.
    maxtext_config: Config object; `hf_access_token` is forwarded to each worker.

  Raises:
    Exception: Whatever a worker raised; the first failure observed is re-raised.
  """
  t_gcs_start = time.time()
  max_logging.log(f"Dynamic HF Hub Fast DL: Host 0 is downloading to shared GCS Cache: {gcs_cache_dir}")

  # List existing blobs to avoid spawning processes for already cached files.
  storage_client = storage.Client()
  # Everything before the first "/" is the bucket; the remainder (possibly empty) is the blob prefix.
  bucket_name, _, blob_prefix = gcs_cache_dir.replace("gs://", "").partition("/")
  existing_blobs = {blob.name for blob in storage_client.list_blobs(bucket_name, prefix=blob_prefix)}

  files_to_download = [
      fpath for fpath in files if os.path.join(blob_prefix, os.path.basename(fpath)) not in existing_blobs
  ]

  if files_to_download:
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=32, mp_context=multiprocessing.get_context("spawn")
    ) as executor:
      pending = {
          executor.submit(
              build_gcs_cache_worker,
              fpath,
              gcs_cache_dir,
              maxtext_config.hf_access_token,
          )
          for fpath in files_to_download
      }
      try:
        while pending:
          done, pending = concurrent.futures.wait(pending, timeout=10)
          # Surface a worker failure as soon as it is observed.
          for finished in done:
            finished.result()
      finally:
        # On failure, drop downloads that have not started yet so the error is not
        # delayed until every queued file has been fetched. On success `pending` is
        # empty and this is a no-op.
        for leftover in pending:
          leftover.cancel()

  t_gcs_end = time.time()
  max_logging.log(
      f"GCS caching complete in {t_gcs_end - t_gcs_start:.2f}s."
      f" Downloaded {len(files_to_download)} missing files."
  )


def sync_dynamic_caching(gcs_cache_dir, files, host_id, maxtext_config):
  """Have Host 0 fill the GCS cache while every other host waits for it to finish.

  Two waiting strategies are supported, selected by the optional config attribute
  `safetensors_sync_via_jax` (treated as False when absent):
    * truthy: all hosts wait only inside the JAX barrier.
    * falsy: Host 0 writes a GCS latch file after downloading and the other hosts
      poll for it; a JAX barrier then follows on all hosts.

  Args:
    gcs_cache_dir: `gs://<bucket>[/<prefix>]` location of the shared cache directory.
    files: Source file paths that must be present in the cache.
    host_id: Index of the calling host; host 0 performs the download.
    maxtext_config: Config object forwarded to the download step.
  """
  barrier_only = getattr(maxtext_config, "safetensors_sync_via_jax", False)

  if host_id == 0:
    _execute_gcs_download(gcs_cache_dir, files, maxtext_config)
    if not barrier_only:
      # Latch mode: signal completion through GCS.
      write_gcs_latch(gcs_cache_dir)
  elif not barrier_only:
    # Latch mode: wait in a CPU sleep loop instead of inside a JAX collective.
    wait_on_gcs_latch(gcs_cache_dir, host_id)

  if barrier_only:
    max_logging.log(f"Host {host_id} waiting for GCS cache at {gcs_cache_dir} to be populated by Host 0...")

  # Both modes finish with a JAX barrier across all hosts.
  jax_devices_barrier()
  max_logging.log(f"Host {host_id} detected GCS cache is ready!")


def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_config):
"""Main entry point to dynamically build and load safetensors into MaxText format.

Expand Down Expand Up @@ -286,61 +415,8 @@ def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_con
gcs_cache_dir = f"{maxtext_config.base_output_directory}/hf_cache/{repo_id.replace('/', '_')}"
path = gcs_cache_dir

# Only Host 0 downloads to the shared GCS cache
if host_id == 0:
max_logging.log("Dynamic HF Hub Fast DL: Host 0 is downloading to shared GCS" f" Cache: {gcs_cache_dir}")
t_gcs_start = time.time()

# List existing blobs to avoid spawning processes for already cached
# files
storage_client = storage.Client()
bucket_name = gcs_cache_dir.replace("gs://", "").split("/", maxsplit=1)[0]
blob_prefix = (
gcs_cache_dir.replace("gs://", "").split("/", maxsplit=1)[1]
if "/" in gcs_cache_dir.replace("gs://", "")
else ""
)

existing_blobs = {blob.name for blob in storage_client.list_blobs(bucket_name, prefix=blob_prefix)}

files_to_download = []
for fpath in files:
expected_blob_name = os.path.join(blob_prefix, os.path.basename(fpath))
if expected_blob_name not in existing_blobs:
files_to_download.append(fpath)

if files_to_download:
with concurrent.futures.ProcessPoolExecutor(
max_workers=32, mp_context=multiprocessing.get_context("spawn")
) as executor:
futures = [
executor.submit(
build_gcs_cache_worker,
fpath,
gcs_cache_dir,
maxtext_config.hf_access_token,
)
for fpath in files_to_download
]

while futures:
done, futures = concurrent.futures.wait(futures, timeout=10)

# Raise any exceptions if a worker failed
for f in done:
f.result()

t_gcs_end = time.time()
max_logging.log(
f"GCS caching complete in {t_gcs_end - t_gcs_start:.2f}s."
f" Downloaded {len(files_to_download)} missing files."
)

# Global barrier: all hosts wait for Host 0 to finish downloading to the
# shared GCS bucket
max_logging.log(f"Host {host_id} waiting for GCS cache at {gcs_cache_dir} to be" " populated by Host 0...")
jax.experimental.multihost_utils.sync_global_devices("dynamic_hf_download_complete")
max_logging.log(f"Host {host_id} detected GCS cache is ready!")
# Only Host 0 downloads; the remaining hosts wait via a GCS latch file or a JAX barrier (see sync_dynamic_caching)
sync_dynamic_caching(gcs_cache_dir, files, host_id, maxtext_config)

else:
raise ValueError("base_output_directory with gs:// prefix is required for " "huggingface downloads.")
Expand All @@ -349,11 +425,21 @@ def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_con
param_map_mt_to_hf, hook_fn_map_mt = get_hf_config_and_mappings(maxtext_config)
max_logging.log(f"[1/3] Mappings derived in {time.time() - t_total:.2f}s")

target_tree = (
abstract_unboxed_pre_state.to_pure_dict()
if isinstance(abstract_unboxed_pre_state, nnx.State)
else abstract_unboxed_pre_state.params
)
if isinstance(abstract_unboxed_pre_state, nnx.State):
# In NNX, abstract_unboxed_pre_state contains both model and optimizer states.
# We only want to target and transform the model variables.
model_state = getattr(abstract_unboxed_pre_state, "model", None)
if model_state is not None:
target_tree = (
model_state.to_pure_dict()
if hasattr(model_state, "to_pure_dict")
else model_state
)
else:
target_tree = abstract_unboxed_pre_state.to_pure_dict()
else:
# In Linen, params is only the model parameters.
target_tree = abstract_unboxed_pre_state.params

t1 = time.time()
hf_state = load_sharded_hf_state(path)
Expand All @@ -368,4 +454,22 @@ def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_con
max_logging.log(f"[3/3] CPU Transformations completed in {time.time() - t2:.2f}s")
max_logging.log(f"Total safetensors_dynamic duration: {time.time() - t_total:.2f}s")

return None, restored_params
if restored_params and "params" in restored_params:
restored_params = restored_params["params"]

def _filter_shape_dtype_structs(d):
if not isinstance(d, dict):
return d
res = {}
for k, v in d.items():
if isinstance(v, dict):
sub = _filter_shape_dtype_structs(v)
if sub:
res[k] = sub
elif not isinstance(v, jax.ShapeDtypeStruct):
res[k] = v
return res

restored_params = _filter_shape_dtype_structs(restored_params)

return None, restored_params
34 changes: 28 additions & 6 deletions src/maxtext/checkpoint_conversion/utils/tensor_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,28 @@ def _binary_chunked_stack(tensors: List[np.ndarray], axis: int) -> np.ndarray:
return np.concatenate([left, right], axis=axis)


def _reshard_identity(x):
  """Identity function; jitted with `out_shardings` so the compiled call performs the reshard."""
  return x


def reshard_to_target(array, sharding):
  """Broadcast `array` from its owning process to all processes, then lay it out as `sharding`.

  A process counts as an owner when `array` is a NumPy array, lacks an
  `is_fully_addressable` attribute, or reports `is_fully_addressable == True`.
  Non-owners contribute a zero-filled placeholder of the same shape and dtype.

  Args:
    array: NumPy array or JAX array to distribute.
    sharding: Target sharding, passed as `out_shardings` to `jax.jit`.

  Returns:
    The result of the jitted identity call with `out_shardings=sharding`.
  """
  is_np = isinstance(array, np.ndarray) or not hasattr(array, "is_fully_addressable")
  is_owner = is_np or array.is_fully_addressable
  # NOTE(review): with a NumPy/host-local input every process sets is_owner=True.
  # Confirm `broadcast_one_to_all` still yields the intended values when more than one
  # process is a source - TODO verify on a multi-host run.
  # NOTE(review): a JAX array that is not fully addressable on any process makes every
  # process send zeros - presumably callers never pass one; verify.

  # Place the payload (or a zero placeholder) on this process's first local device.
  first_local_device = jax.local_devices()[0]
  if is_owner:
    local_arr = jax.device_put(array, first_local_device)
  else:
    local_arr = jax.device_put(np.zeros(array.shape, dtype=array.dtype), first_local_device)

  # Broadcast from the owner so every process holds the same full value.
  global_replicated = jax.experimental.multihost_utils.broadcast_one_to_all(local_arr, is_source=is_owner)

  # Jit a module-level function rather than a fresh lambda: jit's cache is keyed on the
  # function object, so a per-call lambda would force a retrace for every tensor.
  return jax.jit(_reshard_identity, out_shardings=sharding)(global_replicated)


def _build_multi_axis_stacked_tensor(
hf_source_keys: List[List[str]],
tensor_getter_fn: Callable[[str], np.ndarray],
Expand Down Expand Up @@ -89,17 +111,17 @@ def _build_multi_axis_stacked_tensor(
processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)

if target_sharding is not None:
processed_hf_tensor = jax.device_put(processed_hf_tensor, slice_sharding)
processed_hf_tensor = reshard_to_target(processed_hf_tensor, slice_sharding)
layer_tensors_for_expert.append(processed_hf_tensor)

expert_tensor = _binary_chunked_stack(layer_tensors_for_expert, axis=0)
if target_sharding is not None:
expert_tensor = jax.device_put(expert_tensor, layer_sharding)
expert_tensor = reshard_to_target(expert_tensor, layer_sharding)
all_expert_tensors.append(expert_tensor)

stacked_array = _binary_chunked_stack(all_expert_tensors, axis=0).astype(target_dtype)
if target_sharding is not None:
stacked_array = jax.device_put(stacked_array, target_sharding)
stacked_array = reshard_to_target(stacked_array, target_sharding)
return stacked_array


Expand Down Expand Up @@ -146,12 +168,12 @@ def _build_single_axis_stacked_tensor(
processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)

if target_sharding is not None:
processed_hf_tensor = jax.device_put(processed_hf_tensor, slice_sharding)
processed_hf_tensor = reshard_to_target(processed_hf_tensor, slice_sharding)
tensors_to_stack.append(processed_hf_tensor)

stacked_array = _binary_chunked_stack(tensors_to_stack, axis=axis_to_stack).astype(target_dtype)
if target_sharding is not None:
stacked_array = jax.device_put(stacked_array, target_sharding)
stacked_array = reshard_to_target(stacked_array, target_sharding)
return stacked_array


Expand All @@ -162,7 +184,7 @@ def get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_ta
def _loader(getter, key, leaf, hook):
if hasattr(leaf, "sharding"):
array = apply_hook_fns(getter(key), leaf.shape, hook)
return jax.device_put(array, device=leaf.sharding)
return reshard_to_target(array, leaf.sharding)
else:
return apply_hook_fns(getter(key), leaf, hook)

Expand Down
Loading
Loading