TorchComms Integration for MSCCL++#771
Conversation
- python/mscclpp_torchcomm/: TorchComms integration for MSCCL++
- CMakeLists.txt: FetchContent torchcomms, links mscclpp + PyTorch
- TorchCommMSCCLPP: backend class with init/finalize lifecycle,
algorithm selection via AlgorithmCollection, GPU event-based
async work tracking
- TorchCommMSCCLPPBootstrap: rank discovery via c10d::Store
- TorchWorkMSCCLPP: GPU event pool + async completion handles
- TorchCommMSCCLPPPy: pybind11 module + dynamic loader interface
- CMakeLists.txt: add MSCCLPP_BUILD_EXT_TORCHCOMMS option (OFF default)
- Supported: allreduce (10 native algorithms), allgather (2 algorithms)
- Uses same algorithm selector as NCCL extension
- Links mscclpp shared lib (not static) to avoid dual-singleton crashes
- test_correctness.py: allreduce/allgather with --sweep mode for multi-size/dtype coverage, in-place and repeated variants - test_sizes.py: message size sweep from 1 element to 32MB - test_error_handling.py: unsupported ops, invalid reduce ops, metadata - test_training_loop.py: simulated multi-iteration training loop - test_multicomm.py: multiple communicators (known limitation) - test_user_algorithms.py: DSL algorithm registration via builder
- bench_torchcomms.py: allreduce/allgather benchmark with CUDA event timing, curated sizes per native algorithm, JSON output - bench_report.py: generates report + latency/bandwidth figures with algorithm region annotations - run_benchmarks.sh: orchestrator script
- docs/quickstart.md: build instructions, usage example, supported collectives table, environment variables, test/benchmark commands - Consistent with existing doc style (dollar prompts, MSCCLPP_BUILD var)
|
@microsoft-github-policy-service agree company="Microsoft" |
There was a problem hiding this comment.
Pull request overview
Adds a TorchComms backend module (_comms_mscclpp*.so) that adapts TorchComms collective calls onto MSCCL++’s AlgorithmCollection (native + DSL), enabling PyTorch users to select MSCCL++ as a TorchComms backend at runtime.
Changes:
- Introduces a new C++ TorchComms backend implementation (bootstrap via
c10d::Store, algorithm selection, executor + scratch, CUDA-event-based work tracking). - Adds TorchComms-focused tests/benchmarks and a quickstart section documenting build/run steps.
- Adds a build option
MSCCLPP_BUILD_EXT_TORCHCOMMSand a dedicated CMake target for building the backend module.
Reviewed changes
Copilot reviewed 21 out of 21 changed files in this pull request and generated 9 comments.
Show a summary per file
| File | Description |
|---|---|
CMakeLists.txt |
Adds MSCCLPP_BUILD_EXT_TORCHCOMMS option and conditionally builds the TorchComms backend. |
docs/quickstart.md |
Documents how to build, use, test, and benchmark TorchComms support. |
python/mscclpp_torchcomm/CMakeLists.txt |
Fetches TorchComms sources/headers, builds _comms_mscclpp pybind module, links MSCCL++ + Torch. |
python/mscclpp_torchcomm/__init__.py |
Package stub for TorchComms backend directory. |
python/mscclpp_torchcomm/requirements_cuda12.txt |
Optional pip requirements for TorchComms backend environment. |
python/mscclpp_torchcomm/csrc/TorchCommMSCCLPP.hpp |
Declares the TorchComms backend class and supported/unsupported collectives. |
python/mscclpp_torchcomm/csrc/TorchCommMSCCLPP.cpp |
Implements init/finalize, algorithm selection wiring, and collective dispatch. |
python/mscclpp_torchcomm/csrc/TorchCommMSCCLPPBootstrap.hpp |
Declares bootstrap helper for rank/size + UniqueId exchange via store. |
python/mscclpp_torchcomm/csrc/TorchCommMSCCLPPBootstrap.cpp |
Implements UniqueId exchange and MSCCL++ communicator creation via TcpBootstrap. |
python/mscclpp_torchcomm/csrc/TorchCommMSCCLPPPy.cpp |
Exposes minimal pybind module + TorchComms dynamic loader entrypoint. |
python/mscclpp_torchcomm/csrc/TorchWorkMSCCLPP.hpp |
Declares CUDA event pool + TorchWork implementation. |
python/mscclpp_torchcomm/csrc/TorchWorkMSCCLPP.cpp |
Implements pooled CUDA events and async work completion tracking. |
test/torchcomms/test_correctness.py |
Correctness coverage for allreduce/allgather/reducescatter via TorchComms. |
test/torchcomms/test_error_handling.py |
Verifies clear runtime errors for unsupported ops and invalid usage patterns. |
test/torchcomms/test_sizes.py |
Sweeps message sizes to exercise selection boundaries and correctness. |
test/torchcomms/test_training_loop.py |
Simulates multi-iteration training-loop allreduce pattern. |
test/torchcomms/test_user_algorithms.py |
Validates user algorithm/selector registration via AlgorithmCollectionBuilder. |
test/torchcomms/test_multicomm.py |
Documents current multi-communicator limitation (expected skip). |
test/torchcomms/bench_torchcomms.py |
TorchComms benchmark driver for allreduce/allgather. |
test/torchcomms/bench_report.py |
Generates a report/figures from benchmark JSON output. |
test/torchcomms/run_benchmarks.sh |
Convenience runner to produce benchmark JSON + report + plots. |
| store_ = createPrefixStore("mscclpp", timeout_); | ||
| } | ||
|
|
||
| std::string key = "mscclpp_uniqueid_" + name + std::to_string(counter_++); |
There was a problem hiding this comment.
The store key concatenates name and counter_ without a delimiter ("mscclpp_uniqueid_" + name + std::to_string(counter_++)), which can collide for different (name, counter) pairs (e.g., name="foo1"/counter=23 vs name="foo12"/counter=3). Add a clear separator (e.g., ... + name + "_" + ...) or use a structured/hashed key. Also consider making counter_ atomic if multiple communicators can be created concurrently in one process.
| std::string key = "mscclpp_uniqueid_" + name + std::to_string(counter_++); | |
| std::string key = "mscclpp_uniqueid_" + name + "_" + std::to_string(counter_++); |
| checkInitialized(); | ||
| auto mscclppOp = torchReduceOpToMscclpp(op, "all_reduce"); | ||
| tensor = tensor.contiguous(); | ||
|
|
||
| return executeCollective("allreduce", tensor.data_ptr(), tensor.data_ptr(), tensor.nbytes(), tensor.nbytes(), | ||
| torchDtypeToMscclpp(tensor.scalar_type()), mscclppOp, async_op, options.timeout); |
There was a problem hiding this comment.
Reassigning tensor = tensor.contiguous() can silently drop results for non-contiguous inputs: the collective runs on a new contiguous tensor, but the caller’s original tensor/view won’t be updated. Prefer either (1) enforcing contiguity with an explicit error, or (2) running on a contiguous temporary and copying the result back into the original tensor/view.
| checkInitialized(); | ||
| auto input_contig = input.contiguous(); | ||
| output = output.contiguous(); | ||
|
|
||
| const size_t chunk_bytes = static_cast<size_t>(input_contig.nbytes()); | ||
|
|
||
| return executeCollective("allgather", input_contig.data_ptr(), output.data_ptr(), chunk_bytes, | ||
| static_cast<size_t>(output.nbytes()), torchDtypeToMscclpp(input_contig.scalar_type()), | ||
| mscclpp::NOP, async_op, options.timeout); |
There was a problem hiding this comment.
output = output.contiguous() can break semantics for non-contiguous outputs: the collective writes into a new contiguous tensor, but the caller’s original output tensor isn’t updated. Either require output.is_contiguous() (and raise a clear error) or write into a contiguous temporary and copy back to output after completion.
| checkInitialized(); | ||
| auto mscclppOp = torchReduceOpToMscclpp(op, "reduce_scatter_single"); | ||
| auto input_contig = input.contiguous(); | ||
| output = output.contiguous(); | ||
|
|
||
| return executeCollective("reducescatter", input_contig.data_ptr(), output.data_ptr(), | ||
| static_cast<size_t>(input_contig.nbytes()), static_cast<size_t>(output.nbytes()), | ||
| torchDtypeToMscclpp(input_contig.scalar_type()), mscclppOp, async_op, options.timeout); |
There was a problem hiding this comment.
output = output.contiguous() / input = input.contiguous() style reassignment can lose results for non-contiguous tensors/views because it changes only the local C++ handle. Prefer enforcing contiguity or copying results back into the caller-provided output tensor.
| checkInitialized(); | ||
| auto input_contig = input.contiguous(); | ||
| output = output.contiguous(); | ||
|
|
||
| return executeCollective("alltoall", input_contig.data_ptr(), output.data_ptr(), | ||
| static_cast<size_t>(input_contig.nbytes()), static_cast<size_t>(output.nbytes()), | ||
| torchDtypeToMscclpp(input_contig.scalar_type()), mscclpp::NOP, async_op, options.timeout); | ||
| } |
There was a problem hiding this comment.
output = output.contiguous() can silently redirect writes to a new tensor if the caller passes a non-contiguous output view. Either enforce contiguity or copy results back into the original output tensor after execution.
| cudaStreamSynchronize(internal_stream_); | ||
| } | ||
| cudaStreamSynchronize(at::cuda::getCurrentCUDAStream(device_.index()).stream()); |
There was a problem hiding this comment.
CUDA API calls here ignore return codes (cudaStreamSynchronize). For consistency with the rest of the file (which uses MSCCLPP_CUDATHROW), wrap these calls so failures surface as exceptions instead of being silently ignored.
| cudaStreamSynchronize(internal_stream_); | |
| } | |
| cudaStreamSynchronize(at::cuda::getCurrentCUDAStream(device_.index()).stream()); | |
| MSCCLPP_CUDATHROW(cudaStreamSynchronize(internal_stream_)); | |
| } | |
| MSCCLPP_CUDATHROW(cudaStreamSynchronize(at::cuda::getCurrentCUDAStream(device_.index()).stream())); |
| event_pool_.reset(); | ||
|
|
||
| if (internal_stream_) { | ||
| cudaStreamDestroy(internal_stream_); |
There was a problem hiding this comment.
cudaStreamDestroy is called without error checking. Consider wrapping it with MSCCLPP_CUDATHROW (or at least checking the return value) so resource teardown failures don’t get silently ignored.
| cudaStreamDestroy(internal_stream_); | |
| MSCCLPP_CUDATHROW(cudaStreamDestroy(internal_stream_)); |
| # Install with: pip install -r python/mscclpp_torchcomm/requirements_cuda12.txt | ||
|
|
||
| torch>=2.0.0 | ||
| pybind11 |
There was a problem hiding this comment.
This requirements file documents dependencies for the TorchComms backend, but it doesn’t include torchcomms even though the docs/tests require it. Consider adding an explicit torchcomms>=0.2.0 (or the exact version you’re targeting) so a pip install -r ... environment is actually runnable.
| pybind11 | |
| pybind11 | |
| torchcomms>=0.2.0 |
| # Find glog (required by torchcomms framework sources via Logging.hpp). | ||
| # Derive the conda env prefix from the Python executable path so we can | ||
| # locate glog headers and libraries installed in the same environment. | ||
| get_filename_component(CONDA_PREFIX "${Python_EXECUTABLE}" DIRECTORY) | ||
| get_filename_component(CONDA_PREFIX "${CONDA_PREFIX}" DIRECTORY) | ||
| find_library(GLOG_LIBRARY glog HINTS "${CONDA_PREFIX}/lib") | ||
| find_path(GLOG_INCLUDE_DIR glog/logging.h HINTS "${CONDA_PREFIX}/include") | ||
|
|
||
| target_include_directories(_comms_mscclpp SYSTEM PRIVATE | ||
| # torchcomms headers: resolves #include <comms/torchcomms/...> | ||
| ${torchcomms_SOURCE_DIR} | ||
| ${GPU_INCLUDE_DIRS} | ||
| ) | ||
| # MSCCL++ internal headers (for algorithm_selector.hpp and debug.h) | ||
| target_include_directories(_comms_mscclpp PRIVATE | ||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ext/nccl | ||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/core/include | ||
| ) | ||
| if(GLOG_INCLUDE_DIR) | ||
| target_include_directories(_comms_mscclpp SYSTEM PRIVATE ${GLOG_INCLUDE_DIR}) | ||
| endif() |
There was a problem hiding this comment.
The build attempts to locate glog but doesn’t fail fast if it can’t be found. Since torchcomms framework sources include glog headers, this will likely fail later with a confusing compile error. Prefer using find_package(glog REQUIRED) (if available) or add an explicit if(NOT GLOG_LIBRARY OR NOT GLOG_INCLUDE_DIR) message(FATAL_ERROR ...) with installation guidance.
|
|
||
| #### Building | ||
|
|
||
| Prerequisites: PyTorch, pybind11, and [torchcomms](https://github.com/meta-pytorch/torchcomms) (`pip install --pre torchcomms`). |
There was a problem hiding this comment.
Please add the tested version info here
| ```bash | ||
| $ mkdir -p build && cd build | ||
| $ cmake -DCMAKE_BUILD_TYPE=Release \ | ||
| -DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON \ | ||
| .. | ||
| $ make -j$(nproc) | ||
| $ cd .. | ||
| ``` |
There was a problem hiding this comment.
Can we support pip installation? We can build this extension automatically when we install mscclpp package and put the .so in the mscclpp's Python installation directory. Then we can default TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP to the installation path.
| throw std::runtime_error("[TorchCommMSCCLPP] Unsupported tensor dtype: " + std::string(at::toString(dtype)) + | ||
| ". Supported: float32, float16, bfloat16, int32, uint32."); |
There was a problem hiding this comment.
Please check if this complies with the torch/torchcomm guidelines for throwing an error from custom operators.
| comm_ = bootstrap->createCommunicator(name, options); | ||
|
|
||
| // 2. Select GPU device | ||
| MSCCLPP_CUDATHROW(cudaSetDevice(device_.index())); |
There was a problem hiding this comment.
We'd better use mscclpp::CudaDeviceGuard to set the device back to the original one when the function returns
| // Detect hardware capabilities for algorithm selection | ||
| static const bool isNvlsSupported = mscclpp::isNvlsSupported(); | ||
| int cudaDevice; | ||
| MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice)); |
There was a problem hiding this comment.
Don't we already know cudaDevice at this point (same as device_.index())?
What This PR Does
This PR adds TorchComms support to MSCCL++, allowing PyTorch users to use MSCCL++ collectives through the TorchComms API with a single line:
This is valuable because it gives PyTorch training frameworks (torchtitan, FSDP2, etc.) a clean way to use MSCCL++ for high-performance collectives without LD_PRELOAD hacks or custom CUDA kernel code. Users can run MSCCL++ for the hot-path collectives (allreduce, allgather) and NCCL for everything else — mixed-backend training with no code changes.
Architecture
Communicator Lifecycle
When a user calls
torchcomms.new_comm("mscclpp", device), TorchComms dlopen's our_comms_mscclpp.*.somodule and callsinit(), which:UniqueIdthrough c10d::Store (rank 0 generates, others read), creates the MSCCL++Communicatorwith aTcpBootstrapGpuBuffer(cuMemMap) for native algorithms that need intermediate storageAlgorithmCollectionBuilder::buildDefaultAlgorithms()which registers 12 native algorithms + 2 DSL plans, then wires up the topology-aware algorithm selectorWhat Happens When You Call a Collective
Component Diagram
The backend is a thin adapter. It does not implement any collective algorithms — it delegates entirely to MSCCL++'s
AlgorithmCollection, which selects the optimal native algorithm based on message size, topology, NVLS support, and compute capability.Files to Review
Core Backend (focus here)
TorchCommMSCCLPP.hppTorchCommMSCCLPP.cppinit()bootstraps and builds the AlgorithmCollection.executeCollective()is the central dispatch — builds aCollectiveRequest, callsselectAlgorithm(), executes. Unsupported ops throw with NCCL/RCCL guidance.TorchCommMSCCLPPBootstrap.hpp/cppUniqueId, writes to c10d::Store, other ranks read it. Same pattern as TorchComms' NCCL backend.TorchWorkMSCCLPP.hpp/cppwait()usescudaStreamWaitEventfor GPU-side sync — no CPU blocking.TorchCommMSCCLPPPy.cppDynamicLoaderInterfacefor TorchComms' dlopen discovery.CMakeLists.txt(torchcomm)Build System
CMakeLists.txt(root)MSCCLPP_BUILD_EXT_TORCHCOMMSoption (OFF by default) andadd_subdirectory()Tests, Benchmarks, Docs
Tests (6 files, ~960 lines), benchmarks (3 files, ~500 lines), and docs (quickstart.md) are straightforward and lower review priority.
Supported Collectives
Key Design Decisions
1. Thin adapter, not a reimplementation.
The backend calls
AlgorithmCollection::selectAlgorithm()andAlgorithm::execute(). It does not contain any collective kernel code. Algorithm registration, selection logic, and kernel implementations all live in MSCCL++ core.2. Same algorithm selector as the NCCL extension.
We reuse
algorithm_selector.hppfromsrc/ext/nccl/so the TorchComms path selects the same algorithms as the LD_PRELOAD NCCL shim. This avoids divergence and ensures consistent behavior.3. Shared library linking (not static).
The module links against
libmscclpp.so(notmscclpp_static.a) to avoid dual-singleton crashes.mscclpp_collectives.solinks against the shared lib, so if we statically linked, there would be two copies of singletons likeUnixSocketServer::instance().4. GpuBuffer for scratch allocation.
Scratch memory is allocated via
mscclpp::GpuBuffer(cuMemMap) instead of plaincudaMalloc. This registers POSIX file descriptors in the unix socket server, which is required for cross-rank IPC sharing. Plain cudaMalloc causes "Requested fd not found" crashes.5. Build-gated behind
MSCCLPP_BUILD_EXT_TORCHCOMMS=OFF.No impact on existing builds. TorchComms headers are fetched on-demand via CMake FetchContent only when the option is enabled.
6. GPU event pooling.
Every collective call needs 2 CUDA events (start + end) for async tracking. Creating/destroying events costs ~5-10μs each. The pool amortizes this across thousands of collective calls in a training loop.
7. User-defined algorithms via AlgorithmCollectionBuilder singleton.
Custom algorithms (DSL or native) are configured on the builder before creating the TorchComms communicator. The backend picks them up during
init(). No algorithm registration API lives on the backend itself.Limitations
Algorithm::execute()operates on contiguous buffers (one input pointer, one output pointer), so the backend implementsall_gather_singleandreduce_scatter_singlebut not the tensor-list variants. The tensor-list variants throw with guidance to use the single-tensor variant instead.RuntimeErrorwith an explicit message naming the operation and suggesting the caller use a separate NCCL/RCCL communicator. This is the expected pattern for mixed-backend training.algorithm_selector.hppfromsrc/ext/nccl/rather than sharing it through a common path. A TODO in the code notes this should be consolidated intoAlgorithmCollectionBuilderso all consumers get a default selector automatically.How to Build and Test