diff --git a/et_replay/pyproject.toml b/et_replay/pyproject.toml index 383f76ca..658e9acb 100644 --- a/et_replay/pyproject.toml +++ b/et_replay/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ ] [tool.setuptools.package-dir] -"et_replay" = "et_replay" +"et_replay" = "." [project.scripts] comm_replay = "et_replay.tools.comm_replay:main" diff --git a/train/comms/pt/commsOverlapBench.py b/train/comms/pt/commsOverlapBench.py index e6765c26..03d5a8b6 100644 --- a/train/comms/pt/commsOverlapBench.py +++ b/train/comms/pt/commsOverlapBench.py @@ -92,8 +92,8 @@ def readArgs(self, parser): "--overlap-pair-pgs", action="store_true", default=False, - help="Toggle to enable overlapping collective pair with two pgs", - ) # overlap collective pair with two pgs + help="Toggle to enable overlapping collectives with separate pgs, one per collective (main + pairs)", + ) # Check arguments that may be custmized per benchmark in a single run # does not depend on data type @@ -225,7 +225,7 @@ def runColl(self, comm_fn=None, comm_fn_pair_list=None, dcheck=False): is_blocking=False, timer=self.collectiveArgs.comm_dev_time, ): - # post another collecitve if on comms pair mode, otherwise it's noop + # post pair collective on a separate stream for overlapping evaluation self.collectiveArgs.group = self.collectiveArgs.groups[ self.collectiveArgs.pairPgId[pairIdx] ] @@ -890,7 +890,7 @@ def genMultiCommGroups( self.backendFuncs.initialize_groups(backend) elif pair and overlap_pair_pgs: - # create two communicators each including all ranks + # create 1 + len(pair_collectives_list) communicators each including all ranks num_pgs = 1 + len(pair_collectives_list) for pgId in range(0, num_pgs): if pgId > 0: