Commit 128e11a

refactor: rename thread_count to files_per_rank
Renames all occurrences of thread_count (and write_thread_count at the user-facing API level) to files_per_rank / write_files_per_rank for clarity: the parameter controls how many files each rank writes to, not a concurrency thread count, which is a separate concept in the codebase. Closes #66
1 parent: 6f36c9c

11 files changed: 124 additions & 124 deletions
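For downstream callers this is a pure rename. A minimal migration sketch (hypothetical caller code; unrelated arguments elided as `...`, and the import path taken from the wrapper_util diff below):

```python
# Hypothetical caller migration; only the kwarg name changes.
# wrap_trainer_and_auto_resume_with_mlflashpoint is the nemo wrap API
# diffed below in src/ml_flashpoint/adapter/nemo/wrapper_util.py.
from ml_flashpoint.adapter.nemo.wrapper_util import (
    wrap_trainer_and_auto_resume_with_mlflashpoint,
)

# Before: wrap_trainer_and_auto_resume_with_mlflashpoint(..., write_thread_count=2)
# After:  wrap_trainer_and_auto_resume_with_mlflashpoint(..., write_files_per_rank=2)
```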

docs/changelog.md

Lines changed: 2 additions & 2 deletions
@@ -77,7 +77,7 @@ _Release Notes: BEGINNING -> 194b781e75807afaba682f9eef2826464fcc120e_
 * (9d1316a) replication_manager: Implement sync_bulk_retrieve method.
 * (d0b8770) replication/transfer_service: Save received data to tmp object before finalizing.
 * (60dd014) replication_manager: Implement async_replicate of replication_manager.
-* (3cb4c38) adapter/pytorch: make writer thread_count and buffer size configurable with defaults
+* (3cb4c38) adapter/pytorch: make writer files_per_rank and buffer size configurable with defaults
 * (af05afb) replication/transfer_service: Implement async_get method.
 * (6c4fe99) replication/transfer_service: Implement async_put method.
 * (e62cd71) replication/transfer_service: Implement transfer_service initialize and shutdown.
@@ -106,7 +106,7 @@ _Release Notes: BEGINNING -> 194b781e75807afaba682f9eef2826464fcc120e_
 * (6d54996) adapter/nemo: implement MLFlashpointCheckpointCallback; add CheckpointContainerId.from_parent() helper

 ### :white_check_mark: Bug Fixes
-* (b33ddfa) wrapper_util: Expose write_thread_count and initial_write_buffer_size_bytes to user.
+* (b33ddfa) wrapper_util: Expose write_files_per_rank and initial_write_buffer_size_bytes to user.
 * (21cc19e) adapter/nemo: make CheckpointObjectManager a param to wrapper_util; passthrough kwargs in MLFlashpointAutoResume to parent
 * (23daae9) Fix implementation of PairwiseReplicationStrategy and add more tests.
 * (4620be7) core/saver: ensure writer can overwrite unfinished checkpoint data after recovery

docs/troubleshooting.md

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 1. Ensure you have sufficient space on your base container mount.
 1. If you have enough memory, but are running out of buffer space during writes, you can:
     1. Increase the default initial buffer size via `initial_write_buffer_size_bytes` in the `wrap` API you are using (the default is 16 GB).
-    1. Increase the write thread count, so that each rank writes to multiple buffers, effectively cutting the size of each buffer proportionally, via `write_thread_count` in the `wrap` API you are using (the default is 1).
+    1. Increase the number of files per rank, so that each rank writes to multiple buffers, effectively cutting the size of each buffer proportionally, via `write_files_per_rank` in the `wrap` API you are using (the default is 1).

 ### How can I clean up ML Flashpoint checkpoints after job completion?
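To make the proportional split above concrete, a small illustrative calculation. This is a sketch assuming the per-rank budget divides evenly across files, as the troubleshooting entry describes; the 16 GB figure is the documented default:

```python
# Illustrative only: how write_files_per_rank divides a rank's buffer budget,
# per the troubleshooting guidance above.
DEFAULT_BUFFER_BYTES = 16 * 1024**3  # documented 16 GB default

for files_per_rank in (1, 2, 4):
    per_buffer_gb = DEFAULT_BUFFER_BYTES / files_per_rank / 1024**3
    print(f"write_files_per_rank={files_per_rank}: ~{per_buffer_gb:.0f} GB per buffer")
# write_files_per_rank=1: ~16 GB per buffer
# write_files_per_rank=2: ~8 GB per buffer
# write_files_per_rank=4: ~4 GB per buffer
```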

docs/user-guide.md

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ auto_resume = wrap_trainer_and_auto_resume_with_mlflashpoint(
     async_save=not args.sync_save,
     default_auto_resume=auto_resume, # Optional
     # always_save_context=False, # Optional, defaults to False
-    # write_thread_count=1, # Optional, defaults to 1
+    # write_files_per_rank=1, # Optional, defaults to 1
     # initial_write_buffer_size_bytes=DESIRED_NUM_BYTES, # Optional, defaults to 16 GB
 )
 ```

src/ml_flashpoint/adapter/megatron/save_strategies.py

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
         self._storage_writer = MemoryStorageWriter(
             checkpoint_saver=self._checkpoint_saver,
             mp_manager=self._storage_writer._mp_manager,
-            thread_count=self._storage_writer._thread_count,
+            files_per_rank=self._storage_writer._files_per_rank,
         )
         # 1c. Reset the StorageWriter for this checkpoint version.
         self._storage_writer.reset(checkpoint_id.data)

src/ml_flashpoint/adapter/nemo/wrapper_util.py

Lines changed: 8 additions & 8 deletions
@@ -42,7 +42,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
     async_save: bool,
     default_auto_resume: nl.AutoResume = None,
     always_save_context: bool = False,
-    write_thread_count: int = 1,
+    write_files_per_rank: int = 1,
     initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
     use_optimized_save: bool = True,
 ) -> MLFlashpointAutoResume:
@@ -59,7 +59,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
         async_save: Whether to enable asynchronous saving for checkpoints.
         default_auto_resume: The default AutoResume configuration to inherit from.
         always_save_context: Whether to always save the context. Defaults to `False`.
-        write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
+        write_files_per_rank: Optional. The number of files each rank writes to for checkpoint data. Defaults to 1.
         initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
             in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.
     Returns:
@@ -87,7 +87,7 @@ def wrap_trainer_and_auto_resume_with_mlflashpoint(
         async_save=async_save,
         checkpoint_loader=ckpt_loader,
         always_save_context=always_save_context,
-        write_thread_count=write_thread_count,
+        write_files_per_rank=write_files_per_rank,
         initial_write_buffer_size_bytes=initial_write_buffer_size_bytes,
         use_optimized_save=use_optimized_save,
     )
@@ -108,7 +108,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
     async_save: bool,
     checkpoint_loader: DefaultMLFlashpointCheckpointLoader,
     always_save_context: bool = False,
-    write_thread_count: int = 1,
+    write_files_per_rank: int = 1,
     initial_write_buffer_size_bytes: int = DEFAULT_INITIAL_BUFFER_SIZE_BYTES,
     use_optimized_save: bool = True,
 ):
@@ -135,7 +135,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
         async_save: Whether to enable asynchronous saving.
         checkpoint_loader: The checkpoint loader to use.
         always_save_context: Whether to always save the context. Defaults to `False`.
-        write_thread_count: Optional. The number of threads to use for writing checkpoint data. Defaults to 1.
+        write_files_per_rank: Optional. The number of files each rank writes to for checkpoint data. Defaults to 1.
         initial_write_buffer_size_bytes: Optional. The initial size of the buffer for writing checkpoint data
             in bytes. Defaults to `DEFAULT_INITIAL_BUFFER_SIZE_BYTES`.

@@ -152,8 +152,8 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
         raise ValueError("The 'ckpt_obj_manager' argument cannot be None.")
     if replication_manager is None:
         raise ValueError("The 'replication_manager' argument cannot be None.")
-    if write_thread_count < 1:
-        raise ValueError(f"write_thread_count must be >= 1, got {write_thread_count}.")
+    if write_files_per_rank < 1:
+        raise ValueError(f"write_files_per_rank must be >= 1, got {write_files_per_rank}.")
     if initial_write_buffer_size_bytes <= 0:
         raise ValueError(f"initial_write_buffer_size_bytes must be > 0, got {initial_write_buffer_size_bytes}.")

@@ -209,7 +209,7 @@ def wrap_trainer_checkpoint_io_with_mlflashpoint(
                 use_optimized_save=use_optimized_save,
             ),
             mp_manager=torch_mp.Manager(),
-            thread_count=write_thread_count,
+            files_per_rank=write_files_per_rank,
         )
     )
     load_strategy = MLFlashpointMegatronLoadStrategy(
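The validation in the `@@ -152,8` hunk above fails fast rather than correcting input. A standalone sketch of that contract (the helper name is hypothetical and re-implemented here for illustration; the error text is copied from the diff):

```python
# Sketch of wrapper_util's fail-fast check; not the module itself.
def validate_write_files_per_rank(write_files_per_rank: int) -> int:
    """Rejects invalid values outright, mirroring the check in the hunk above."""
    if write_files_per_rank < 1:
        raise ValueError(f"write_files_per_rank must be >= 1, got {write_files_per_rank}.")
    return write_files_per_rank
```

MemoryStorageWriter, diffed next, takes the opposite approach and clamps instead of raising.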

src/ml_flashpoint/adapter/pytorch/memory_storage_writer.py

Lines changed: 8 additions & 8 deletions
@@ -88,7 +88,7 @@ def __init__(
         self,
         checkpoint_saver: MLFlashpointCheckpointSaver,
         mp_manager: torch_mp.Manager,
-        thread_count: int = 1,
+        files_per_rank: int = 1,
     ):
         """Initializes the MemoryStorageWriter.

@@ -97,18 +97,18 @@ def __init__(
                 handling the actual checkpoint saving logic.
             mp_manager: A `torch.multiprocessing.Manager` instance for managing
                 shared state across processes, particularly for write results and events.
-            thread_count: Optional. The number of threads to use for writing checkpoint data.
+            files_per_rank: Optional. The number of files each rank writes to for checkpoint data.
                 Defaults to 1. If a value less than 1 is provided, it will be reset to 1,
                 and a warning will be logged.
         """
         super().__init__()
         self._current_checkpoint_id: CheckpointContainerId | None = None
         self._current_save_id: str | None = None
         self._checkpoint_saver: MLFlashpointCheckpointSaver = checkpoint_saver
-        if thread_count < 1:
-            _LOGGER.warning("thread_count must be >= 1, but was %d. Setting to 1.", thread_count)
-            thread_count = 1
-        self._thread_count = thread_count
+        if files_per_rank < 1:
+            _LOGGER.warning("files_per_rank must be >= 1, but was %d. Setting to 1.", files_per_rank)
+            files_per_rank = 1
+        self._files_per_rank = files_per_rank
         self._mp_manager = mp_manager
         self._write_events_per_checkpoint_id: dict[CheckpointContainerId, torch_mp.Event] = mp_manager.dict()
         self._write_results_per_checkpoint_id: dict[CheckpointContainerId, list[WriteResult]] = mp_manager.dict()
@@ -180,7 +180,7 @@ def prepare_write_data_buckets(
         self._write_events_per_checkpoint_id[checkpoint_id] = self._mp_manager.Event()

         write_buckets = self.checkpoint_saver.prepare_write_data(
-            checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=self._thread_count
+            checkpoint_id, plan.items, planner, plan.storage_data.prefix, bucket_count=self._files_per_rank
         )
         return write_buckets
         # self._write_buckets_per_checkpoint_id[checkpoint_id] = write_buckets
@@ -220,7 +220,7 @@ def write_staged_data_buckets(
         write_results = self._checkpoint_saver.write_data(
             checkpoint_id,
             write_buckets=staged_write_buckets,
-            thread_count=self._thread_count,
+            files_per_rank=self._files_per_rank,
             replicate_after_write=replicate_after_write,
         )
         end_time = time.perf_counter()
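In contrast to wrapper_util's ValueError, the constructor above coerces bad input. A standalone sketch of that clamp (the helper name is hypothetical; the warning text is copied from the diff):

```python
import logging

_LOGGER = logging.getLogger(__name__)

# Sketch of MemoryStorageWriter's lenient handling; not the class itself.
def clamp_files_per_rank(files_per_rank: int) -> int:
    """Coerces values below 1 up to 1, logging a warning instead of raising."""
    if files_per_rank < 1:
        _LOGGER.warning("files_per_rank must be >= 1, but was %d. Setting to 1.", files_per_rank)
        files_per_rank = 1
    return files_per_rank
```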

src/ml_flashpoint/core/checkpoint_saver.py

Lines changed: 11 additions & 11 deletions
@@ -210,7 +210,7 @@ def write_data(
         self,
         checkpoint_id: CheckpointContainerId,
         write_buckets: list[ObjectWriteBucket],
-        thread_count: int,
+        files_per_rank: int,
         replicate_after_write: bool,
     ) -> list[WriteResult]:
         """Performs the core write logic for the given write items and checkpoint_id.
@@ -225,7 +225,7 @@ def write_data(
             checkpoint_id: Unique hierarchical ID representing this checkpoint container.
                 This typically follows a directory path structure.
             write_buckets: A list of `ObjectWriteBucket` objects, each containing resolved data ready for writing.
-            thread_count: The number of threads to use for writing data.
+            files_per_rank: The number of files each rank writes to.
             replicate_after_write: Whether to trigger async replication of each object after it is written.

         Returns:
@@ -357,7 +357,7 @@ def prepare_write_data(
     ) -> list[ObjectWriteBucket]:
         bucket_count = max(bucket_count, 1)
         _LOGGER.debug(
-            "%s prepare_write_data with prefix: '%s', thread_count: %d",
+            "%s prepare_write_data with prefix: '%s', files_per_rank: %d",
             self.__class__.__name__,
             object_name_prefix,
             bucket_count,
@@ -389,7 +389,7 @@ def _clone_if_needed(tensor: torch.Tensor):
         # NOTE: There is support for multiple threads, to simplify modifying that setting, but we typically
         # only use 1 thread.

-        # Group items into buckets, one bucket per file, up to thread_count files
+        # Group items into buckets, one bucket per file, up to files_per_rank files
         buckets = _split_by_size_and_type(bucket_count, write_items)
         for bucket in buckets:
             if not bucket:
@@ -423,22 +423,22 @@ def write_data(
         checkpoint_id: CheckpointContainerId,
         write_buckets: list[ObjectWriteBucket],
         replicate_after_write: bool,
-        thread_count: int = 1,
+        files_per_rank: int = 1,
     ) -> list[WriteResult]:
-        thread_count = max(thread_count, 1)
+        files_per_rank = max(files_per_rank, 1)
         num_cpus = os.cpu_count() or 1
         num_ranks = max(get_accelerator_count(), 1)
         # Use 50% of available CPU cores for PyTorch intra-op threads and evenly distribute them across ranks.
-        torch_thread_count = max(1, num_cpus // 2 // num_ranks // thread_count)
+        torch_thread_count = max(1, num_cpus // 2 // num_ranks // files_per_rank)
         original_num_threads = torch.get_num_threads()
         # Explicitly set PyTorch intra-op threads to optimize for performance.
         # This also avoids potential runtime errors in tensor.copy_() with concurrent writers
         torch.set_num_threads(torch_thread_count)
         _LOGGER.debug(
-            "%s starting multi-threaded write_data. thread_count: %d, original_num_threads: %d, "
+            "%s starting multi-threaded write_data. files_per_rank: %d, original_num_threads: %d, "
             "num_cpus: %d, num_ranks: %d, torch_thread_count: %d",
             self.__class__.__name__,
-            thread_count,
+            files_per_rank,
             original_num_threads,
             num_cpus,
             num_ranks,
@@ -457,8 +457,8 @@ def write_data(
         threads = []

         # Kick off additional threads to main thread, if any.
-        _LOGGER.debug("Spawning %d extra writer threads (in addition to the main thread).", thread_count - 1)
-        for i in range(1, thread_count):
+        _LOGGER.debug("Spawning %d extra writer threads (in addition to the main thread).", files_per_rank - 1)
+        for i in range(1, files_per_rank):
             thread = threading.Thread(
                 target=self._write_to_buffer_from_queue_worker,
                 args=(object_items_queue, results_from_threads, replicate_after_write, self._use_optimized_save),
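The intra-op budget in the `@@ -423,22` hunk shows why the old name misled: files_per_rank feeds the thread math but is not itself a thread count. A worked example of that formula under an assumed machine shape (64 CPUs and 8 ranks are illustrative values, not project recommendations):

```python
# Worked example of torch_thread_count = max(1, num_cpus // 2 // num_ranks // files_per_rank)
# using assumed, illustrative hardware values.
num_cpus = 64   # assumed CPU core count
num_ranks = 8   # assumed accelerator count

for files_per_rank in (1, 2, 4, 8):
    torch_thread_count = max(1, num_cpus // 2 // num_ranks // files_per_rank)
    print(f"files_per_rank={files_per_rank} -> torch_thread_count={torch_thread_count}")
# files_per_rank=1 -> torch_thread_count=4
# files_per_rank=2 -> torch_thread_count=2
# files_per_rank=4 -> torch_thread_count=1
# files_per_rank=8 -> torch_thread_count=1  (floored at 1)
```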

tests/adapter/megatron/test_save_strategies.py

Lines changed: 8 additions & 8 deletions
@@ -190,18 +190,18 @@ def test_async_save_initialization_calls_success(
         mock_memory_storage_writer_cls.assert_called_once_with(
             checkpoint_saver=checkpoint_saver,
             mp_manager=storage_writer._mp_manager,
-            thread_count=storage_writer._thread_count,
+            files_per_rank=storage_writer._files_per_rank,
         )
         mock_new_storage_writer_instance.reset.assert_called_once_with(checkpoint_id.data)
         mock_new_storage_writer_instance.stage_write_data_buckets.assert_called_once_with(
             checkpoint_id, dummy_write_buckets, non_blocking=True
         )

-    @pytest.mark.parametrize("expected_thread_count", [1, 2, 3, 5])
-    def test_async_save_reinitializes_storage_writer_with_thread_count(
-        self, mocker, async_save_setup, storage_writer, checkpoint_saver, dummy_write_buckets, expected_thread_count
+    @pytest.mark.parametrize("expected_files_per_rank", [1, 2, 3, 5])
+    def test_async_save_reinitializes_storage_writer_with_files_per_rank(
+        self, mocker, async_save_setup, storage_writer, checkpoint_saver, dummy_write_buckets, expected_files_per_rank
     ):
-        """Tests that the StorageWriter is re-initialized with the correct thread_count."""
+        """Tests that the StorageWriter is re-initialized with the correct files_per_rank."""
         # Given
         mock_statedictsaver = mocker.patch("ml_flashpoint.adapter.megatron.save_strategies.statedictsaver")
         (
@@ -216,8 +216,8 @@ def test_async_save_reinitializes_storage_writer_with_thread_count(
             mocker.MagicMock(),
         )

-        # Set a specific thread_count on the original storage_writer
-        storage_writer._thread_count = expected_thread_count
+        # Set a specific files_per_rank on the original storage_writer
+        storage_writer._files_per_rank = expected_files_per_rank

         mock_memory_storage_writer_cls = mocker.patch(
             "ml_flashpoint.adapter.megatron.save_strategies.MemoryStorageWriter"
@@ -230,7 +230,7 @@ def test_async_save_reinitializes_storage_writer_with_thread_count(
         mock_memory_storage_writer_cls.assert_called_once_with(
             checkpoint_saver=checkpoint_saver,
             mp_manager=storage_writer._mp_manager,
-            thread_count=expected_thread_count,
+            files_per_rank=expected_files_per_rank,
         )

     def test_initialize_checkpoint_failure(self, mocker, async_save_setup, checkpoint_saver):
