[Bugfix] Fix loss missing from logs when context parallelism is enabled#9380

Merged
Jintao-Huang merged 1 commit into
modelscope:mainfrom
Zhikaiiii:fix/loss_missing_with_cp
May 19, 2026
@Zhikaiiii Zhikaiiii commented May 18, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

Summary

  • Fix loss key missing from training logs when context parallelism (CP) is enabled
  • Root cause: reporting_loss was all-reduced only over get_data_parallel_group(), which excludes CP ranks. A CP rank whose sequence chunk has no valid tokens (all labels=-100) therefore reports loss_mask.sum()=0, and the loss key is skipped in _aggregated_metrics
  • Also fix the same issue in _compute_channel_loss, which used the DP-only group for both key synchronization and metric reduction

Root Cause

This issue is caused by how Swift splits sequence data across CP ranks. Swift uses a zigzag pattern to distribute sequence chunks (e.g., with CP=2: GPU0 gets chunk_0+chunk_3, GPU1 gets chunk_1+chunk_2). In SFT scenarios where prompts are long and responses are short, the loss-bearing tokens (response part) tend to concentrate at one end of the sequence. As a result, certain CP ranks may receive chunks where all labels are -100 (prompt-only), making loss_mask.sum()=0 on those ranks. Since reporting_loss was only all-reduced within the DP group (not across CP ranks), the aggregated token count remains zero on these ranks, and the loss key gets skipped entirely during metric aggregation.
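The split described above can be reproduced in a few lines of plain Python. This is a minimal sketch, not Swift's actual implementation: the helper name `zigzag_split`, the 16-token sample, and the label values are all made up for illustration. It only assumes what the paragraph states, namely that the sequence is cut into 2×CP chunks and that rank r receives chunk r plus its mirror chunk from the other end.

```python
IGNORE_INDEX = -100  # label value that is excluded from the loss


def zigzag_split(seq, cp_size):
    """Split `seq` into 2 * cp_size equal chunks (hypothetical helper).

    Rank r receives chunk r plus chunk (2 * cp_size - 1 - r), which gives
    the zigzag layout from the description: with CP=2, rank 0 holds
    chunk_0 + chunk_3 and rank 1 holds chunk_1 + chunk_2.
    """
    num_chunks = 2 * cp_size
    assert len(seq) % num_chunks == 0, "sequence must divide evenly into chunks"
    chunk_len = len(seq) // num_chunks
    chunks = [seq[i * chunk_len:(i + 1) * chunk_len] for i in range(num_chunks)]
    return [chunks[r] + chunks[num_chunks - 1 - r] for r in range(cp_size)]


# Illustrative SFT sample: 14 prompt tokens (ignored) followed by 2 response tokens.
labels = [IGNORE_INDEX] * 14 + [101, 102]

per_rank = zigzag_split(labels, cp_size=2)

# Equivalent of loss_mask.sum() on each CP rank.
valid_tokens = [sum(1 for t in shard if t != IGNORE_INDEX) for shard in per_rank]
print(valid_tokens)  # [2, 0]
```

In this sample the response lands entirely in the last chunk, so rank 1 (chunk_1 + chunk_2) holds only prompt tokens and its local valid-token count is zero.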

Changes

  • trainer.py:69: Use get_data_parallel_group(with_context_parallel=True) for reporting_loss all-reduce
  • trainer.py:98-104: Use DP+CP group for channel loss key synchronization and metric reduction
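The effect of widening the reduction group can be illustrated without any distributed setup. The following is a simulation sketch only: it does not call Megatron or torch.distributed, and the DP=2 × CP=2 layout, the per-rank numbers, and the helpers `all_reduce_sum` and `logged_loss` are hypothetical. The "skip the key when the token count is zero" rule is modeled from the description above rather than copied from trainer.py.

```python
def all_reduce_sum(values, groups):
    """Simulate all-reduce(SUM): every rank ends up with the sum over its group."""
    out = {}
    for group in groups:
        total = [sum(col) for col in zip(*(values[rank] for rank in group))]
        for rank in group:
            out[rank] = total
    return out


def logged_loss(reduced):
    """Model of the logging step: None stands for a skipped loss key."""
    return {rank: (s / n if n > 0 else None) for rank, (s, n) in reduced.items()}


# Hypothetical layout: ranks keyed as (dp_rank, cp_rank).
# Each value is [loss_sum, valid_token_count] for the rank's local shard.
# CP rank 1 holds prompt-only chunks on both DP replicas.
local = {
    (0, 0): [4.0, 2], (0, 1): [0.0, 0],
    (1, 0): [6.0, 3], (1, 1): [0.0, 0],
}

# DP-only groups: ranks that share the same cp_rank (old behaviour).
dp_only_groups = [[(0, c), (1, c)] for c in range(2)]
# DP+CP group: every rank (behaviour with with_context_parallel=True).
dp_cp_groups = [list(local)]

before = logged_loss(all_reduce_sum(local, dp_only_groups))
after = logged_loss(all_reduce_sum(local, dp_cp_groups))

print(before)  # CP rank 1 entries are None: the loss key would be skipped
print(after)   # every rank sees 10.0 / 5 = 2.0
```

With the DP-only groups, the ranks at cp_rank=1 still see a token count of zero after the reduction. Once the group spans both DP and CP, every rank receives the same non-zero totals.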

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates the Megatron trainer to include context parallel groups in data parallel operations for loss reduction and metric synchronization. Specifically, it modifies loss_func and _compute_channel_loss to use with_context_parallel=True when retrieving the data parallel group. A potential RuntimeError was identified in _compute_channel_loss where _all_reduce_metric could be called with an empty dictionary, and a code suggestion was provided to add a safety check.

  for key in sorted(set().union(*all_keys)):
      new_metrics[key] = metrics[key]
- new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM)
+ new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM, group=dp_cp_group)
high

The _all_reduce_metric method in the base class uses torch.stack on the dictionary values. If new_metrics is empty (which can happen if no ranks in the group have any channel-specific tokens in the current micro-batch), torch.stack will raise a RuntimeError. It is safer to check if the dictionary is non-empty before attempting the reduction.

Suggested change
- new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM, group=dp_cp_group)
+ if new_metrics:
+     new_metrics = self._all_reduce_metric(new_metrics, torch.distributed.ReduceOp.SUM, group=dp_cp_group)

The reporting_loss all-reduce only used get_data_parallel_group() which
excludes CP ranks. When a CP rank's sequence chunk has no valid tokens
(all labels=-100, common in SFT with long prompts), loss_mask.sum()=0
causing the loss key to be skipped entirely in _aggregated_metrics.

Fix by using get_data_parallel_group(with_context_parallel=True) so
loss and token counts are aggregated across both DP and CP dimensions.
@Zhikaiiii Zhikaiiii force-pushed the fix/loss_missing_with_cp branch from f1c9256 to 89f2f09 Compare May 18, 2026 15:31
@Zhikaiiii Zhikaiiii changed the title [Draft][fix] Fix loss missing from logs when context parallelism is enabled [Bugfix] Fix loss missing from logs when context parallelism is enabled May 19, 2026
@Zhikaiiii Zhikaiiii marked this pull request as ready for review May 19, 2026 03:00
@Jintao-Huang

thanks!

@Jintao-Huang Jintao-Huang merged commit 9e95ea1 into modelscope:main May 19, 2026
1 of 3 checks passed