39 changes: 35 additions & 4 deletions apps/nccl/src/allreduce.hpp
@@ -963,8 +963,20 @@ __global__ void __launch_bounds__(1024, 1)
   int nBlocksForReduce = nRanksPerNode;
   int copyReduceRatio = nBlocksForCopy / nBlocksForReduce;
   size_t scratchSizePerRank = scratchBufferSize / nRanksPerNode;
-  size_t sizePerRank = size / nRanksPerNode;
-  assert(sizePerRank % alignment == 0);
+
+  // Pad size to be divisible by (nRanksPerNode * alignment)
+  size_t paddingNeeded =
+      (nRanksPerNode * alignment - (size % (nRanksPerNode * alignment))) % (nRanksPerNode * alignment);
+  size_t paddedSize = size + paddingNeeded;
+  size_t sizePerRank = paddedSize / nRanksPerNode;
+
+  // Calculate the actual size this rank should process (without padding)
+  size_t actualSizeThisRank = sizePerRank;
+  if (rank == nRanksPerNode - 1) {
+    // The last rank might have less actual data due to padding
+    actualSizeThisRank = size - (sizePerRank * (nRanksPerNode - 1));
+  }
+
   uint32_t sizePerBlock =
       ((sizePerRank + (nBlocksForCopy - 1)) / nBlocksForCopy + alignment - 1) / alignment * alignment;
   uint32_t lastBlockSize = sizePerRank - (nBlocksForCopy - 1) * sizePerBlock;
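
For reference, a minimal host-side sketch of the padding arithmetic above, run on hypothetical sample values (8 ranks, 16-byte alignment, a 1000-byte buffer) standing in for the kernel's actual parameters:

#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical sample values standing in for the kernel parameters.
  size_t nRanksPerNode = 8, alignment = 16, size = 1000;

  size_t unit = nRanksPerNode * alignment;               // 128
  size_t paddingNeeded = (unit - (size % unit)) % unit;  // (128 - 104) % 128 = 24
  size_t paddedSize = size + paddingNeeded;              // 1024
  size_t sizePerRank = paddedSize / nRanksPerNode;       // 128, divisible by alignment

  // Every rank owns a full padded slice; only the last rank's slice
  // extends past the real data.
  for (size_t rank = 0; rank < nRanksPerNode; ++rank) {
    size_t actualSizeThisRank =
        (rank == nRanksPerNode - 1) ? size - sizePerRank * (nRanksPerNode - 1)  // 104
                                    : sizePerRank;
    printf("rank %zu: slice %zu bytes, actual %zu bytes\n", rank, sizePerRank, actualSizeThisRank);
  }
}

This also shows why the old assert(sizePerRank % alignment == 0) can be dropped: the padded total is divisible by nRanksPerNode * alignment by construction, so sizePerRank is always alignment-divisible.
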
@@ -1008,7 +1020,17 @@ __global__ void __launch_bounds__(1024, 1)
     uint32_t scratchOffset = scratchIt * unitSize + bid * scratchSizePerBlock + i * scratchSizePerRank;
     char* srcData = (char*)src + blockOffset;
     char* dstData = (char*)scratch + scratchOffset;
-    mscclpp::copy(dstData, srcData, iterSize, tid, blockDim.x);
+    // Calculate the actual copy size - don't copy beyond the actual data on the last rank
+    size_t actualCopySize = iterSize;
+    if (i == nRanksPerNode - 1 && blockOffset + iterSize > i * sizePerRank + actualSizeThisRank) {
+      // On the last rank, clamp to the actual data size
+      actualCopySize = (i * sizePerRank + actualSizeThisRank > blockOffset)
+                           ? (i * sizePerRank + actualSizeThisRank - blockOffset)
+                           : 0;
+    }
+    if (actualCopySize > 0) {
+      mscclpp::copy(dstData, srcData, actualCopySize, tid, blockDim.x);
+    }
   }
   __syncthreads();
   if (tid < nPeers) {
@@ -1067,7 +1089,16 @@ __global__ void __launch_bounds__(1024, 1)
                           i * scratchSizePerRank;
     char* srcData = (char*)scratch + scratchOffset;
     char* dstData = (char*)dst + blockOffset;
-    mscclpp::copy(dstData, srcData, iterSize, tid, blockDim.x);
+
+    size_t actualCopySize = iterSize;
+    if (i == nRanksPerNode - 1 && blockOffset + iterSize > i * sizePerRank + actualSizeThisRank) {
+      actualCopySize = (i * sizePerRank + actualSizeThisRank > blockOffset)
+                           ? (i * sizePerRank + actualSizeThisRank - blockOffset)
+                           : 0;
+    }
+    if (actualCopySize > 0) {
+      mscclpp::copy(dstData, srcData, actualCopySize, tid, blockDim.x);
+    }
   }
   __syncthreads();
   if (tid == 0) {
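
Both copy loops (input-to-scratch and scratch-to-output) apply the same clamp. A standalone sketch of that logic, with a hypothetical helper name and scalar parameters replacing the kernel's loop state:

#include <cstddef>

// Hypothetical helper mirroring the clamp in both copy loops: rank i's real
// data spans [i * sizePerRank, i * sizePerRank + actualSizeThisRank), and the
// chunk [blockOffset, blockOffset + iterSize) is cut back to stay inside it.
size_t clampedCopySize(size_t i, size_t nRanksPerNode, size_t blockOffset, size_t iterSize,
                       size_t sizePerRank, size_t actualSizeThisRank) {
  size_t dataEnd = i * sizePerRank + actualSizeThisRank;
  if (i == nRanksPerNode - 1 && blockOffset + iterSize > dataEnd) {
    // Only the last rank can run past the real data; return 0 when the
    // whole chunk lies in the padded tail.
    return (dataEnd > blockOffset) ? dataEnd - blockOffset : 0;
  }
  return iterSize;
}

Guarding the mscclpp::copy call with actualCopySize > 0 then skips chunks that fall entirely within the padding.
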
1 change: 1 addition & 0 deletions src/include/logger.hpp
@@ -6,6 +6,7 @@

 #include <unistd.h>

+#include <array>
 #include <bitset>
 #include <fstream>
 #include <iomanip>
4 changes: 3 additions & 1 deletion test/torch/correctness_test.py
@@ -65,7 +65,9 @@ def _init_dist():
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
     local_rank = int(os.environ.get("LOCAL_RANK", os.environ["RANK"]))
-    dist.init_process_group(backend=backend, rank=rank, world_size=world_size, device_id=local_rank)
+    dist.init_process_group(
+        backend=backend, rank=rank, world_size=world_size, device_id=torch.device(f"cuda:{local_rank}")
+    )
     torch.cuda.set_device(local_rank)

