From 1f9514ebddaa2149331fb0672f0920e3c209770a Mon Sep 17 00:00:00 2001 From: Bruce Wu Date: Tue, 5 May 2026 10:40:13 -0700 Subject: [PATCH] Fix misleading comments in commsOverlapBench overlap-pair-pgs (#222) Summary: Fix three comments in commsOverlapBench.py that incorrectly described the --overlap-pair-pgs feature as creating only "two" process groups. The code actually creates 1 + len(pair_collectives_list) PGs, supporting N concurrent collectives on separate streams. Updated comments to accurately reflect this behavior. Reviewed By: dsjohns2 Differential Revision: D103730559 --- et_replay/pyproject.toml | 2 +- train/comms/pt/commsOverlapBench.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/et_replay/pyproject.toml b/et_replay/pyproject.toml index 383f76ca..658e9acb 100644 --- a/et_replay/pyproject.toml +++ b/et_replay/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ ] [tool.setuptools.package-dir] -"et_replay" = "et_replay" +"et_replay" = "." [project.scripts] comm_replay = "et_replay.tools.comm_replay:main" diff --git a/train/comms/pt/commsOverlapBench.py b/train/comms/pt/commsOverlapBench.py index e6765c26..03d5a8b6 100644 --- a/train/comms/pt/commsOverlapBench.py +++ b/train/comms/pt/commsOverlapBench.py @@ -92,8 +92,8 @@ def readArgs(self, parser): "--overlap-pair-pgs", action="store_true", default=False, - help="Toggle to enable overlapping collective pair with two pgs", - ) # overlap collective pair with two pgs + help="Toggle to enable overlapping collectives with separate pgs, one per collective (main + pairs)", + ) # Check arguments that may be custmized per benchmark in a single run # does not depend on data type @@ -225,7 +225,7 @@ def runColl(self, comm_fn=None, comm_fn_pair_list=None, dcheck=False): is_blocking=False, timer=self.collectiveArgs.comm_dev_time, ): - # post another collecitve if on comms pair mode, otherwise it's noop + # post pair collective on a separate stream for overlapping evaluation self.collectiveArgs.group = self.collectiveArgs.groups[ self.collectiveArgs.pairPgId[pairIdx] ] @@ -890,7 +890,7 @@ def genMultiCommGroups( self.backendFuncs.initialize_groups(backend) elif pair and overlap_pair_pgs: - # create two communicators each including all ranks + # create 1 + len(pair_collectives_list) communicators each including all ranks num_pgs = 1 + len(pair_collectives_list) for pgId in range(0, num_pgs): if pgId > 0: