# Patch summary for swift/megatron/trainers/trainer.py (commentary placed before the first "diff --git" header, where patch tools skip leading text).
# - loss_func: the all_reduce of the detached reporting loss now runs over mpu.get_data_parallel_group(with_context_parallel=True) instead of the plain data-parallel group.
# - _compute_channel_loss: the group is fetched once into dp_cp_group and reused for (a) sizing the all_gather_object receive list via torch.distributed.get_world_size(group=dp_cp_group), (b) the all_gather_object that synchronizes metric keys, and (c) the SUM reduction done by self._all_reduce_metric.
# NOTE(review): _all_reduce_metric is not part of this patch; confirm it accepts a `group` keyword argument.
# NOTE(review): the hunks below have lost their line breaks (the whole patch sits on one line); restore them before trying to apply it.
diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index 01106467a9..4cc46900d1 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -66,7 +66,7 @@ def loss_func(self, # Reduce loss for logging. reporting_loss = loss.detach().clone() - torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group(with_context_parallel=True)) lm_loss = loss[0] lm_loss = lm_loss.clone() @@ -96,12 +96,13 @@ def _compute_channel_loss(self, losses, loss_mask, channels, packed_seq_params=N metrics[f'loss_{channel}'][1] += c_loss.shape[0] # Synchronize keys to avoid getting stuck. - all_keys = [None] * mpu.get_data_parallel_world_size() - dist.all_gather_object(all_keys, list(metrics.keys()), group=mpu.get_data_parallel_group()) + dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) + all_keys = [None] * torch.distributed.get_world_size(group=dp_cp_group) + dist.all_gather_object(all_keys, list(metrics.keys()), group=dp_cp_group) new_metrics = {} for key in sorted(set().union(*all_keys)): new_metrics[key] = metrics[key] - new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM) + new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM, group=dp_cp_group) return new_metrics def forward_step(self, data_iterator, model):