Conversation
Code Review
This pull request introduces support for pipeline parallelism (PP) in the SGLang inference backend. Key changes include updating the Megatron engine to handle per-PP-rank NCCL groups for weight synchronization, implementing monkey-patches for SGLang to support pp_rank in weight update requests, and adding documentation and tests. Review feedback focuses on two robustness issues in weight update initialization: releasing the distributed lock in a finally block, and avoiding redundant port allocation on non-head ranks in the PP==1 case.
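The "monkey-patch SGLang to support pp_rank" change mentioned above follows a common pattern: replace a method on a third-party class at import time so requests carry an extra field. The sketch below is illustrative only; `WeightUpdateRequest` and its fields are hypothetical stand-ins, not SGLang's actual classes.

```python
class WeightUpdateRequest:
    # Hypothetical stand-in for the third-party request class being
    # patched; the real SGLang class and fields differ.
    def __init__(self, name):
        self.name = name


def patched_init(self, name, pp_rank=0):
    # Patched constructor that additionally records which pipeline
    # stage the weight update targets. Defaulting pp_rank=0 keeps
    # existing call sites working unchanged.
    self.name = name
    self.pp_rank = pp_rank


# Monkey-patch: swap in the new constructor before any requests are built.
WeightUpdateRequest.__init__ = patched_init

req = WeightUpdateRequest("layers.0.weight", pp_rank=2)
assert req.pp_rank == 2
```

Because the patch only adds an optional keyword argument, callers that predate the PP feature continue to construct requests exactly as before.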
```python
self.engine_lock.acquire()
# ...
gen_world_size = meta.gen_allocation.parallel.world_size
init_method = f"tcp://{format_host_for_url(meta.nccl_master_address)}:{meta.nccl_master_port}"
self.logger.info(
    f"Initializing weight update group: type={meta.type} "
    f"init_method={init_method} "
    f"group={self.weight_update_group_name}"
)
self.weight_update_group = init_custom_process_group(
    backend=current_platform.communication_backend,
    world_size=gen_world_size + 1,
    init_method=init_method,
    rank=0,
    group_name=self.weight_update_group_name,
    timeout=DIST_GROUP_DEFAULT_TIMEOUT,
)
# Assign address/port *after* acquiring the lock so that two
# PP-head ranks on the same node cannot race on port selection
meta.nccl_master_address = self.weight_update_master_addr = gethostip()
meta.nccl_master_port = self.weight_update_master_port = find_free_ports(1)[0]
meta.nccl_group_name = self.weight_update_group_name
# ...
fut.result()
fut = self.rollout_engine.init_weights_update_group(meta)
# ...
per_pp_world_size = (
    meta.gen_allocation.parallel.world_size // gen_pp_size
)
init_method = (
    f"tcp://{format_host_for_url(meta.nccl_master_address)}"
    f":{meta.nccl_master_port}"
)
self.logger.info(
    f"Initializing per-PP-rank weight update group: "
    f"type={meta.type} init_method={init_method} "
    f"group={self.weight_update_group_name} "
    f"per_pp_world_size={per_pp_world_size}"
)
self.weight_update_group = init_custom_process_group(
    backend=current_platform.communication_backend,
    world_size=per_pp_world_size + 1,
    init_method=init_method,
    rank=0,
    group_name=self.weight_update_group_name,
    timeout=DIST_GROUP_DEFAULT_TIMEOUT,
)
# ...
fut.result()
# ...
self.engine_lock.release()
```
The DistributedLock is acquired but not released in a finally block. If an exception occurs during group initialization (e.g., network timeout or port allocation failure), the lock will be leaked, potentially hanging the entire experiment. Please wrap the logic in a try...finally block or use a context manager if supported.
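A minimal runnable sketch of the suggested fix. `threading.Lock` stands in for the review's `DistributedLock`, and `_init_group` is a placeholder for the real `init_custom_process_group` call, which may raise on a network timeout or port allocation failure:

```python
import threading


class WeightUpdateInitializer:
    def __init__(self):
        # Stand-in for the DistributedLock discussed in the review.
        self.engine_lock = threading.Lock()
        self.weight_update_group = None

    def _init_group(self):
        # Placeholder for init_custom_process_group(...); here it always
        # fails to demonstrate the leak the reviewer is worried about.
        raise TimeoutError("simulated NCCL rendezvous timeout")

    def init_weight_update_group(self):
        self.engine_lock.acquire()
        try:
            self.weight_update_group = self._init_group()
        finally:
            # Released even when _init_group raises, so other PP-head
            # ranks on the same node are never blocked forever.
            self.engine_lock.release()


w = WeightUpdateInitializer()
try:
    w.init_weight_update_group()
except TimeoutError:
    pass
# The lock is free again despite the failure:
assert w.engine_lock.acquire(blocking=False)
w.engine_lock.release()
```

If the real `DistributedLock` supports the context-manager protocol, `with self.engine_lock:` expresses the same guarantee more concisely.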
```python
self.engine_lock.acquire()
# ...
fut = self.rollout_engine.init_weights_update_group(meta)
# ...
gen_world_size = meta.gen_allocation.parallel.world_size
init_method = (
    f"tcp://{format_host_for_url(meta.nccl_master_address)}"
    f":{meta.nccl_master_port}"
)
self.logger.info(
    f"Initializing weight update group: type={meta.type} "
    f"init_method={init_method} "
    f"group={self.weight_update_group_name}"
)
self.weight_update_group = init_custom_process_group(
    backend=current_platform.communication_backend,
    world_size=gen_world_size + 1,
    init_method=init_method,
    rank=0,
    group_name=self.weight_update_group_name,
    timeout=DIST_GROUP_DEFAULT_TIMEOUT,
)
# ...
fut.result()
# ...
self.engine_lock.release()
```
```python
meta.nccl_master_address = self.weight_update_master_addr = gethostip()
meta.nccl_master_port = self.weight_update_master_port = find_free_ports(1)[0]
meta.nccl_group_name = self.weight_update_group_name
```
In the PP==1 case, nccl_master_address and nccl_master_port are assigned by all ranks, causing every rank to call find_free_ports(1). This is wasteful as only the head rank actually uses these values to initialize the group. These assignments should be moved inside the if self.is_pipeline_parallel_head(): block, similar to the PP>1 implementation.
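A minimal sketch of the suggested restructuring, assuming hypothetical stand-ins (`find_free_port`, `Meta`, the `is_pp_head` flag) for `find_free_ports`, the real metadata object, and `self.is_pipeline_parallel_head()`:

```python
import socket


def find_free_port():
    # Stand-in for find_free_ports(1)[0]: ask the OS for an ephemeral port.
    with socket.socket() as s:
        s.bind(("", 0))
        return s.getsockname()[1]


class Meta:
    # Hypothetical stand-in for the weight-update metadata object.
    nccl_master_address = None
    nccl_master_port = None


def assign_rendezvous_info(meta, is_pp_head):
    # Only the PP-head rank allocates the rendezvous address/port;
    # non-head ranks leave meta untouched instead of wastefully
    # grabbing a port they never use.
    if is_pp_head:
        meta.nccl_master_address = socket.gethostname()
        meta.nccl_master_port = find_free_port()
    return meta


head = assign_rendezvous_info(Meta(), is_pp_head=True)
other = assign_rendezvous_info(Meta(), is_pp_head=False)
assert head.nccl_master_port is not None
assert other.nccl_master_port is None
```

This mirrors the PP>1 path, where the assignments already live inside the `is_pipeline_parallel_head()` guard.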
Description
Related Issue
Fixes #(issue)
Type of Change
Checklist
- Ran `pre-commit run --all-files`
- Built the docs (`./docs/build_all.sh`)
- Branch is up to date with `main`
- Used the `/review-pr` / `/create-pr` commands

Breaking Change Details (if applicable):
Additional Context
Need help? Check the Contributing Guide or ask in
GitHub Discussions!