From 89f2f09b042a32bcdda99de17b54705b7c525a9c Mon Sep 17 00:00:00 2001 From: Zhikaiiii <1658973216@qq.com> Date: Mon, 18 May 2026 23:21:26 +0800 Subject: [PATCH] [fix] Fix loss missing from logs when context parallelism is enabled The reporting_loss all-reduce only used get_data_parallel_group() which excludes CP ranks. When a CP rank's sequence chunk has no valid tokens (all labels=-100, common in SFT with long prompts), loss_mask.sum()=0 causing the loss key to be skipped entirely in _aggregated_metrics. Fix by using get_data_parallel_group(with_context_parallel=True) so loss and token counts are aggregated across both DP and CP dimensions. --- swift/megatron/trainers/trainer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index 01106467a9..4cc46900d1 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -66,7 +66,7 @@ def loss_func(self, # Reduce loss for logging. reporting_loss = loss.detach().clone() - torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group(with_context_parallel=True)) lm_loss = loss[0] lm_loss = lm_loss.clone() @@ -96,12 +96,13 @@ def _compute_channel_loss(self, losses, loss_mask, channels, packed_seq_params=N metrics[f'loss_{channel}'][1] += c_loss.shape[0] # Synchronize keys to avoid getting stuck. - all_keys = [None] * mpu.get_data_parallel_world_size() - dist.all_gather_object(all_keys, list(metrics.keys()), group=mpu.get_data_parallel_group()) + dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True) + all_keys = [None] * torch.distributed.get_world_size(group=dp_cp_group) + dist.all_gather_object(all_keys, list(metrics.keys()), group=dp_cp_group) new_metrics = {} for key in sorted(set().union(*all_keys)): new_metrics[key] = metrics[key] - new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM) + new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM, group=dp_cp_group) return new_metrics def forward_step(self, data_iterator, model):