     get_forward_backward_parallel_group,
     get_forward_backward_parallel_dual_rank,
 )
+from megatron.training import get_args
+from megatron.training.trace import get_tensor_bytes
+from megatron.training.global_vars import get_tracer
 
 # Types
 Shape = Union[List[int], torch.Size]
@@ -131,28 +134,53 @@ def _batched_p2p_ops(
     bypass_controller: bool = False,
 ):
     ops = []
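+    # When tracing, collect every rank touched by this batched exchange so the
+    # full set can be reported to the "group" tracer once all ops are queued.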
+    args = get_args()
+    if args.trace:
+        group_rank = set()
+        group_rank.add(torch.distributed.get_rank())
     if tensor_send_prev is not None:
         send_prev_op = dist.P2POp(
             dist.isend, tensor_send_prev, prev_pipeline_rank, group
         )
+        if args.trace:
+            tracers = get_tracer()
+            if tracers.get("group") is not None:
+                group_rank.add(prev_pipeline_rank)
         ops.append(send_prev_op)
     if tensor_recv_prev is not None:
         # print('#', dist.get_rank(), prev_pipeline_rank, tensor_recv_prev)
         recv_prev_op = dist.P2POp(
             dist.irecv, tensor_recv_prev, prev_pipeline_rank, group
         )
+        if args.trace:
+            tracers = get_tracer()
+            if tracers.get("group") is not None:
+                group_rank.add(prev_pipeline_rank)
         ops.append(recv_prev_op)
     if tensor_send_next is not None:
         # print(tensor_send_next)
         send_next_op = dist.P2POp(
             dist.isend, tensor_send_next, next_pipeline_rank, group
         )
+        if args.trace:
+            tracers = get_tracer()
+            if tracers.get("group") is not None:
+                group_rank.add(next_pipeline_rank)
         ops.append(send_next_op)
     if tensor_recv_next is not None:
         recv_next_op = dist.P2POp(
             dist.irecv, tensor_recv_next, next_pipeline_rank, group
         )
+        if args.trace:
+            tracers = get_tracer()
+            if tracers.get("group") is not None:
+                group_rank.add(next_pipeline_rank)
         ops.append(recv_next_op)
+
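+    # Flush the collected rank set to the "group" tracer in a single call.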
+    if args.trace:
+        tracers = get_tracer()
+        if tracers.get("group") is not None:
+            tracers.set_group(list(group_rank))
     if len(ops) > 0:
         reqs = dist.batch_isend_irecv(ops, bypass_controller)
     else:
@@ -315,7 +343,7 @@ def _communicate(
         - tensor_recv_next: torch.Tensor if recv_next is True, None otherwise.
 
     """
-    # print(dist.get_rank(), torch.cuda.current_device())
+
     tensor_recv_prev_func = None
     tensor_recv_next_func = None
 
@@ -383,8 +411,6 @@ def _ring_exchange_wrapper(**kwargs):
     # several different decoder ranks. We therefore have to receive or send tensors
     # from several groups. For convenience, I wrap everything into lists.
     pp_group = get_pipeline_model_parallel_group(extracted=True)
-    # pp = get_pipeline_model_parallel_group()
-    # print(dist.get_rank(), ':', pp[1])
     next_rank = get_pipeline_model_parallel_next_rank()
     prev_rank = get_pipeline_model_parallel_prev_rank()
     if not isinstance(pp_group, list):
@@ -413,7 +439,7 @@ def _ring_exchange_wrapper(**kwargs):
         tensor_recv_next_list.append(tensor_recv_next)
     else:
         tensor_recv_next = None
-    # print('cm req:', dist.get_rank())
+
     if p2p_func is _batched_p2p_ops:
         p2p_reqs = p2p_func(
             tensor_send_prev=tensor_send_prev,
@@ -439,14 +465,32 @@ def _ring_exchange_wrapper(**kwargs):
         reqs.extend(p2p_reqs)
     else:
         reqs.update(p2p_reqs)
-
+
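+    # Account for the payload moved by this exchange when the "data" tracer is
+    # active; e.g. a (1024, 8, 4096) bf16 activation contributes
+    # 1024 * 8 * 4096 * 2 bytes = 64 MiB.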
+    args = get_args()
+    if args.trace:
+        tracers = get_tracer()
+        if tracers.get("data") is not None:
+            data_bytes = 0
+            for t in [tensor_send_prev, tensor_recv_prev, tensor_send_next, tensor_recv_next]:
+                if t is not None:
+                    data_bytes += get_tensor_bytes(t)
+            tracers.set("data", data_bytes)
+
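+    # Look up the optional "trace_p2p_recv" channel; it stays None when
+    # tracing is disabled or the channel is not registered.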
+    trace_p2p_recv = None  # default, so the check below is safe when tracing is off
+    if args.trace:
+        tracers = get_tracer()
+        trace_p2p_recv = tracers.get("trace_p2p_recv")
+
+    # For simplicity, we currently only support tracing for batched p2p communication.
+    if trace_p2p_recv is not None:
+        assert not config.use_ring_exchange_p2p
+        assert config.batch_p2p_comm
+        assert wait_on_reqs
+
     if wait_on_reqs and len(reqs) > 0:
         for req in reqs if isinstance(reqs, list) else reqs.values():
             req.wait()
         reqs = None
-
-    import time
-    start_time = time.time()
+
     if (
         (config.batch_p2p_comm and config.batch_p2p_sync)
         # The lists below have a size > 1 only when ETP ≠ DTP,
@@ -457,10 +501,6 @@ def _ring_exchange_wrapper(**kwargs):
         # To protect against race condition when using batch_isend_irecv().
         # User should assert that we have a modern enough PyTorch to not need this
         torch.cuda.synchronize()
-
-        end_time = time.time()
-        if dist.get_rank() == 7:
-            print('communicate', end_time - start_time)
 
     def _handle_tensor_list(x):
466506 """This basically handles all the cases that we expect to see. Either the list None,
@@ -568,7 +608,6 @@ def _forward_backward_communicate(
 
     p2p_func = _forward_backward_p2p_ops
 
-    # print('fbd req:', dist.get_rank())
     reqs = p2p_func(
         tensor_recv_prev=tensor_recv_prev,
         tensor_send_next=tensor_send_next,
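
Note on the tracer interface: the diff only shows the call sites, which assume a tracer object exposing get, set, and set_group, plus a get_tensor_bytes helper. Below is a minimal sketch consistent with that usage; it is hypothetical, and the fork's actual implementation in megatron.training.trace and megatron.training.global_vars may differ.

    import torch

    def get_tensor_bytes(t: torch.Tensor) -> int:
        # Payload size in bytes: element count times per-element size.
        return t.numel() * t.element_size()

    class Tracer:
        """Records per-exchange trace events in named channels."""

        def __init__(self, channels=("group", "data")):
            # A channel that was never registered reads back as None, matching
            # the `tracers.get(...) is not None` guards in the diff above.
            self._channels = {name: [] for name in channels}

        def get(self, name):
            return self._channels.get(name)

        def set(self, name, value):
            if name in self._channels:
                self._channels[name].append(value)

        def set_group(self, ranks):
            if "group" in self._channels:
                self._channels["group"].append(sorted(ranks))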