
Commit e41c1ee

Merge pull request #24 from OpenSQZ/yg/fix_merge
yg/fix merge
2 parents: 47992ad + c67c5de

4 files changed: 210 additions & 26 deletions


megatron/Controller.py

Lines changed: 14 additions & 0 deletions
@@ -1,3 +1,17 @@
+# Copyright 2025 Suanzhi Future Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions
+# and limitations under the License.
+
 import torch
 import torch.distributed as dist
 import time

megatron/core/pipeline_parallel/p2p_communication.py

Lines changed: 52 additions & 13 deletions
@@ -16,6 +16,9 @@
     get_forward_backward_parallel_group,
     get_forward_backward_parallel_dual_rank,
 )
+from megatron.training import get_args
+from megatron.training.trace import get_tensor_bytes
+from megatron.training.global_vars import get_tracer

 # Types
 Shape = Union[List[int], torch.Size]
@@ -131,28 +134,53 @@ def _batched_p2p_ops(
     bypass_controller: bool = False,
 ):
     ops = []
+    args = get_args()
+    if args.trace:
+        group_rank = set()
+        group_rank.add(torch.distributed.get_rank())
     if tensor_send_prev is not None:
         send_prev_op = dist.P2POp(
             dist.isend, tensor_send_prev, prev_pipeline_rank, group
         )
+        if args.trace:
+            tracers = get_tracer()
+            if tracers.get("group") is not None:
+                group_rank.add(prev_pipeline_rank)
         ops.append(send_prev_op)
     if tensor_recv_prev is not None:
         # print('#', dist.get_rank(), prev_pipeline_rank, tensor_recv_prev)
         recv_prev_op = dist.P2POp(
             dist.irecv, tensor_recv_prev, prev_pipeline_rank, group
         )
+        if args.trace:
+            tracers = get_tracer()
+            if tracers.get("group") is not None:
+                group_rank.add(prev_pipeline_rank)
         ops.append(recv_prev_op)
     if tensor_send_next is not None:
         # print(tensor_send_next)
         send_next_op = dist.P2POp(
             dist.isend, tensor_send_next, next_pipeline_rank, group
         )
+        if args.trace:
+            tracers = get_tracer()
+            if tracers.get("group") is not None:
+                group_rank.add(next_pipeline_rank)
         ops.append(send_next_op)
     if tensor_recv_next is not None:
         recv_next_op = dist.P2POp(
             dist.irecv, tensor_recv_next, next_pipeline_rank, group
         )
+        if args.trace:
+            tracers = get_tracer()
+            if tracers.get("group") is not None:
+                group_rank.add(next_pipeline_rank)
         ops.append(recv_next_op)
+
+    if args.trace:
+        tracers = get_tracer()
+        if tracers.get("group") is not None:
+            tracers.set_group(list(group_rank))
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops, bypass_controller)
     else:
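
Reviewer note: the hunk above collects every rank touched by this batched P2P call into a set (the local rank plus each peer that is sent to or received from) and hands the de-duplicated list to the tracer once, after all four optional ops are queued. The tracer object comes from megatron.training.global_vars and its class is not part of this diff; the sketch below is a hypothetical stand-in whose get/set/set_group methods are assumptions inferred from the call sites in this hunk:

# Hypothetical stand-in for the tracer used above; the real class is not in this diff.
class _SketchTracer:
    def __init__(self):
        # "group" must be non-None for group tracing to be recorded at all.
        self._fields = {"group": [], "data": 0, "trace_p2p_recv": None}

    def get(self, key):
        # Mirrors tracers.get("group") / tracers.get("data") in the diff.
        return self._fields.get(key)

    def set(self, key, value):
        # Mirrors tracers.set("data", data_bytes) in the diff.
        self._fields[key] = value

    def set_group(self, ranks):
        # Mirrors tracers.set_group(list(group_rank)): record the de-duplicated
        # set of ranks that participate in this batched P2P call.
        self._fields["group"] = sorted(ranks)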
@@ -315,7 +343,7 @@ def _communicate(
     - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.

     """
-    # print(dist.get_rank(), torch.cuda.current_device())
+
     tensor_recv_prev_func = None
     tensor_recv_next_func = None

@@ -383,8 +411,6 @@ def _ring_exchange_wrapper(**kwargs):
     # several different decoder ranks. We therefore have to receive or send tensors
     # from several groups. For convenience, I wrap everything into lists.
     pp_group = get_pipeline_model_parallel_group(extracted=True)
-    # pp = get_pipeline_model_parallel_group()
-    # print(dist.get_rank(), ':', pp[1])
     next_rank = get_pipeline_model_parallel_next_rank()
     prev_rank = get_pipeline_model_parallel_prev_rank()
     if not isinstance(pp_group, list):
@@ -413,7 +439,7 @@ def _ring_exchange_wrapper(**kwargs):
             tensor_recv_next_list.append(tensor_recv_next)
         else:
             tensor_recv_next = None
-    # print('cm req:', dist.get_rank())
+
     if p2p_func is _batched_p2p_ops:
         p2p_reqs = p2p_func(
             tensor_send_prev=tensor_send_prev,
@@ -439,14 +465,32 @@ def _ring_exchange_wrapper(**kwargs):
             reqs.extend(p2p_reqs)
         else:
             reqs.update(p2p_reqs)
-
+
+    args = get_args()
+    if args.trace:
+        tracers = get_tracer()
+        if tracers.get("data") is not None:
+            data_bytes = 0
+            for t in [tensor_send_prev, tensor_recv_prev, tensor_send_next, tensor_recv_next]:
+                if t is not None:
+                    data_bytes += get_tensor_bytes(t)
+            tracers.set("data", data_bytes)
+
+    if args.trace:
+        tracers = get_tracer()
+        trace_p2p_recv = tracers.get("trace_p2p_recv")
+
+        # For simplicity, we now only support tracing for batched p2p communication.
+        if trace_p2p_recv is not None:
+            assert not config.use_ring_exchange_p2p
+            assert config.batch_p2p_comm
+            assert wait_on_reqs
+
     if wait_on_reqs and len(reqs) > 0:
         for req in reqs if isinstance(reqs, list) else reqs.values():
             req.wait()
         reqs = None
-
-    import time
-    start_time = time.time()
+
     if (
         (config.batch_p2p_comm and config.batch_p2p_sync)
         # The lists below have a size > 1 only when ETP ≠ DTP,
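
Reviewer note: this hunk sums the payload of the up-to-four tensors exchanged per call via get_tensor_bytes and records the total under the tracer's "data" key. get_tensor_bytes is imported from megatron.training.trace but its body is not in this diff; for the dense activation tensors exchanged here it presumably reduces to element count times bytes per element. A minimal sketch under that assumption:

import torch

def get_tensor_bytes_sketch(t: torch.Tensor) -> int:
    # Payload of a dense tensor in bytes: element count times element size.
    # (Assumption: the real get_tensor_bytes in megatron/training/trace.py is
    # equivalent for the dense tensors passed in here.)
    return t.numel() * t.element_size()

# Example: a (4, 1024) fp16 activation is 4 * 1024 * 2 = 8192 bytes.
assert get_tensor_bytes_sketch(torch.empty(4, 1024, dtype=torch.float16)) == 8192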
@@ -457,10 +501,6 @@ def _ring_exchange_wrapper(**kwargs):
         # To protect against race condition when using batch_isend_irecv().
        # User should assert that we have a modern enough PyTorch to not need this
         torch.cuda.synchronize()
-
-    end_time = time.time()
-    if dist.get_rank() == 7:
-        print('communicate', end_time-start_time)

 def _handle_tensor_list(x):
     """This basically handles all the cases that we expect to see. Either the list None,
@@ -568,7 +608,6 @@ def _forward_backward_communicate(

     p2p_func = _forward_backward_p2p_ops

-    # print('fbd req:', dist.get_rank())
     reqs = p2p_func(
         tensor_recv_prev=tensor_recv_prev,
         tensor_send_next=tensor_send_next,
