# test_parallel_seed_initialization.py
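"""End-to-end test that parameter initialization follows the parallelism layout.

Spawns 8 CUDA processes on a 2x2x2 (DP shard x TP x PP) device mesh, builds an
FSDP2 model from a patched config, and checks that weights are replicated across
data-parallel ranks but seeded differently per pipeline/tensor-parallel combination.
"""
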
import logging
import multiprocessing as py_mp
import os
import traceback
from pathlib import Path
from typing import Any

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import yaml
from pydantic import BaseModel
from torch.distributed._tensor.placement_types import Replicate

from modalities.__main__ import Main
from modalities.config.config import ProcessGroupBackendType
from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticFSDP2ModuleType
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_parallel_rank
from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
from tests.utility import monitor_child_processes

working_dir = Path(os.path.dirname(__file__))
tmp_folder = working_dir / "../tmp/fsdp2_warmstart_pp_tp"
working_dir = working_dir / "configs"
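
# The device mesh exercised below is 2 (DP shard) x 2 (TP) x 2 (PP) = 8 ranks,
# which is why the test is skipped on machines with fewer than 8 GPUs.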
@pytest.mark.skipif(
    torch.cuda.device_count() < 8,
    reason="This e2e test requires 8 GPUs.",
)
class TestParallelSeedInitialization:
    WORLD_SIZE = 8
    RDVZ_PORT = 24574

    def test_parameters_follow_parallelism(self, tmp_path: Path):
        manager = py_mp.Manager()
        error_queue = manager.Queue()
        proc_ctx = mp.spawn(
            self._seed_distribution_impl_wrapper,
            args=(self.WORLD_SIZE, tmp_path, error_queue),
            nprocs=self.WORLD_SIZE,
            join=False,
        )
        monitor_child_processes(manager, error_queue, proc_ctx)
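
    # Runs once per spawned process; mp.spawn passes the process index as the first
    # argument. Failures are forwarded to the parent through error_queue so that
    # monitor_child_processes can fail the test with the child's traceback.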
    def _seed_distribution_impl_wrapper(self, process_id: int, world_size: int, tmp_path: Path, error_queue: Any):
        with MultiProcessingCudaEnv(
            process_group_backend=ProcessGroupBackendType.nccl,
            global_rank=process_id,
            local_rank=process_id,
            world_size=world_size,
            rdvz_port=TestParallelSeedInitialization.RDVZ_PORT,
        ):
            try:
                self._seed_distribution_impl(world_size=world_size, tmp_path=tmp_path)
            except Exception as exc:
                tb = traceback.format_exc()
                logging.error(f"Process {process_id} (seed distribution test) encountered an error:\n{exc}")
                logging.error(tb)
                try:
                    error_queue.put((process_id, tb))
                except Exception:
                    logging.error("Failed to put exception info into error queue (seed distribution test).")
                os._exit(1)
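
    # Builds the model from a config patched to dp_shard=2, tp=2, pp=2, then gathers
    # each rank's local weight shard and fully replicated weight on rank 0, where the
    # seed-distribution invariants are asserted.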
    def _seed_distribution_impl(self, world_size: int, tmp_path: Path):
        # initialize components
        class ComponentsInstantiationModel(BaseModel):
            fsdp_model: PydanticFSDP2ModuleType
            device_mesh: PydanticDeviceMeshIFType

        config_file_path = self._get_tmp_sharding_config_path(dp_degree=2, tp_degree=2, pp_degree=2, tmp_path=tmp_path)
        main_obj = Main(config_file_path)
        components = main_obj.build_components(components_model_type=ComponentsInstantiationModel)
        model = components.fsdp_model
        device_mesh = components.device_mesh

        # For each PP stage, grab the first transformer block's MLP weight, both as the
        # local shard and as the full (unsharded) tensor.
        block_key = next(iter(model.transformer.h.keys()))
        block = model.transformer.h[block_key]
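        # Redistributing the DTensor with Replicate() on every mesh dimension
        # materializes the full, unsharded weight on each rank.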
        placements = [Replicate()] * len(block.mlp.W.weight.device_mesh.mesh.shape)
        full_weight = block.mlp.W.weight.redistribute(placements=placements).to_local().cpu()
        payload = {
            "tensor_full": full_weight,
            "tensor_shard": block.mlp.W.weight.to_local().cpu(),
            "tp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.TP),
            "pp_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.PP),
            "dp_shard_rank": get_parallel_rank(device_mesh=device_mesh, parallelism_method=ParallelismDegrees.DP_SHARD),
            "block_key": block_key,
        }
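        # dist.gather_object collects every rank's payload into gather_list on rank 0;
        # all other ranks pass None and merely participate in the collective.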
        gather_list: list[dict[str, Any]] | None = [None] * world_size if dist.get_rank() == 0 else None
        dist.gather_object(payload, gather_list, dst=0)
        if dist.get_rank() == 0:
            assert gather_list is not None
            TestParallelSeedInitialization._assert_parameter_distribution(gather_list)
        dist.barrier()
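
    # Invariants checked below: (1) exactly expected_combo_count distinct (PP, TP)
    # combinations appear, each with one record per DP-shard rank; (2) within a
    # combination the full weights are identical across DP ranks; (3) across different
    # combinations the concatenated shards differ, i.e. each PP/TP slice got its own seed.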
    @staticmethod
    def _assert_parameter_distribution(records: list[dict[str, Any]]):
        combos: dict[tuple[int, int], list[dict[str, Any]]] = {}
        for record in records:
            key = (record["pp_rank"], record["tp_rank"])
            combos.setdefault(key, []).append(record)

        expected_combo_count = 4
        assert (
            len(combos) == expected_combo_count
        ), f"Expected {expected_combo_count} PP/TP combinations, got {len(combos)}"

        combo_tensors: dict[tuple[int, int], torch.Tensor] = {}
        for (pp_rank, tp_rank), entries in combos.items():
            # check that full tensors are the same across data parallel processes
            reference = entries[0]["tensor_full"]
            seen_dp_ranks: set[int] = set()
            for entry in entries:
                dp_rank = entry["dp_shard_rank"]
                assert dp_rank not in seen_dp_ranks, f"Duplicate DP rank {dp_rank} for combo PP={pp_rank}, TP={tp_rank}"
                seen_dp_ranks.add(dp_rank)
                assert torch.equal(reference, entry["tensor_full"]), (
                    "Tensors within the same TP/PP combo must be identical across DP ranks; "
                    f"mismatch at DP rank {dp_rank} for (PP={pp_rank}, TP={tp_rank})"
                )
            # concatenate all shards for this pp/tp combo
            shards = sorted(entries, key=lambda e: e["dp_shard_rank"])
            combo_tensors[(pp_rank, tp_rank)] = torch.cat(
                [e["tensor_shard"] for e in shards],
                dim=0,
            )

        # check that tensor shards differ across different pp/tp combos
        combo_items = list(combo_tensors.items())
        for idx, ((pp_rank, tp_rank), base_tensor) in enumerate(combo_items):
            for other_key, other_tensor in combo_items[idx + 1 :]:
                tensors_equal = torch.equal(base_tensor, other_tensor)
                assert not tensors_equal, (
                    "Distinct TP/PP combinations should initialize with different weights; "
                    f"found match between (PP={pp_rank}, TP={tp_rank}) and (PP={other_key[0]}, TP={other_key[1]})"
                )
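
    # Copies the base PP/TP config and overrides the parallelism degrees, so the test
    # controls the mesh shape without mutating the checked-in YAML.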
    def _get_tmp_sharding_config_path(self, dp_degree: int, tp_degree: int, pp_degree: int, tmp_path: Path) -> Path:
        temp_file_path = tmp_path / "pp_tp_sharding_config.yaml"
        working_dir = Path(os.path.dirname(__file__))
        config_file_path = (
            working_dir / "pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml"
        )
        with open(config_file_path, "r") as file:
            config_string = file.read()
        config_dict = yaml.safe_load(config_string)
        config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = dp_degree
        config_dict["device_mesh"]["config"]["tensor_parallel_degree"] = tp_degree
        config_dict["device_mesh"]["config"]["pipeline_parallel_degree"] = pp_degree
        # save to temporary file
        with open(temp_file_path, "w") as file:
            yaml.dump(config_dict, file)
        return temp_file_path
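
# A typical invocation (assuming the file lives under tests/ and at least 8 GPUs
# are available) would be:
#     pytest tests/test_parallel_seed_initialization.py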