@@ -98,6 +98,64 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
         dist.destroy_process_group()


+def _context_parallel_backward_worker(
+    rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict
+):
+    """Worker function for context parallel backward pass testing."""
+    try:
+        # Set up distributed environment
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(master_port)
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(world_size)
+
+        # Get device configuration
+        device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
+        backend = device_config["backend"]
+        device_module = device_config["module"]
+
+        # Initialize process group
+        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
+
+        # Set device for this process
+        device_module.set_device(rank)
+        device = torch.device(f"{torch_device}:{rank}")
+
+        # Create model in training mode
+        model = model_class(**init_dict)
+        model.to(device)
+        model.train()
+
+        # Move inputs to device
+        inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
+
+        # Enable context parallelism
+        cp_config = ContextParallelConfig(**cp_dict)
+        model.enable_parallelism(config=cp_config)
+
+        # Run forward and backward pass
+        output = model(**inputs_on_device, return_dict=False)[0]
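+        # Reduce the output to a scalar so backward() can be called without an explicit gradient argument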
+        loss = output.sum()
+        loss.backward()
+
+        # Check that backward actually produced at least one valid gradient
+        grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
+        has_valid_grads = len(grads) > 0 and all(torch.isfinite(g).all() for g in grads)
+
+        # Only rank 0 reports results
+        if rank == 0:
+            return_dict["status"] = "success"
+            return_dict["has_valid_grads"] = bool(has_valid_grads)
+
+    except Exception as e:
+        if rank == 0:
+            return_dict["status"] = "error"
+            return_dict["error"] = str(e)
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+
 def _custom_mesh_worker(
     rank,
     world_size,
@@ -204,6 +262,51 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
     def test_context_parallel_batch_inputs(self, cp_type):
         self.test_context_parallel_inference(cp_type, batch_size=2)

+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_backward(self, cp_type, batch_size: int = 1):
+        if not torch.distributed.is_available():
+            pytest.skip("torch.distributed is not available.")
+
+        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
+            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
+
+        if cp_type == "ring_degree":
+            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
+            if active_backend == AttentionBackendName.NATIVE:
+                pytest.skip("Ring attention is not supported with the native attention backend.")
+
+        world_size = 2
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
+
+        # Move all tensors to CPU for multiprocessing
+        inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
+        cp_dict = {cp_type: world_size}
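+        # The parallel degree is set to world_size so the sequence is sharded across every spawned rank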
+
+        # Find a free port for distributed communication
+        master_port = _find_free_port()
+
+        # Use multiprocessing manager for cross-process communication
+        manager = mp.Manager()
+        return_dict = manager.dict()
+
+        # Spawn worker processes
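+        # mp.spawn passes the process index (the rank) as the first positional argument to the worker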
+        mp.spawn(
+            _context_parallel_backward_worker,
+            args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
+            nprocs=world_size,
+            join=True,
+        )
+
+        assert return_dict.get("status") == "success", (
+            f"Context parallel backward pass failed: {return_dict.get('error', 'Unknown error')}"
+        )
+        assert return_dict.get("has_valid_grads"), "Context parallel backward pass did not produce valid gradients."
+
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_backward_batch_inputs(self, cp_type):
+        self.test_context_parallel_backward(cp_type, batch_size=2)
+
     @pytest.mark.parametrize(
         "cp_type,mesh_shape,mesh_dim_names",
         [