
Commit b8a1ce6

Add optimized masked local Gather kernel (#12)
* Setting up individual benchmarks for Graphcast data and layers
* Updated local kernel API
* Add local kernel impl
* Add missing header file and unit test
* Add python side bindings
* Add unit test for opt gather
* Updated build script to add pyproject (intermediate fix before moving to scikit-build)
  - Change Macro to be MPI-free
  - Passing tests
* Add support for older compute arches
* Patch to allow half-precision operators in extensions
1 parent 38ee37b commit b8a1ce6

16 files changed

Lines changed: 822 additions & 84 deletions

DGraph/distributed/RankLocalOps.py

Lines changed: 50 additions & 0 deletions
@@ -17,6 +17,14 @@
 
 import torch
 
+try:
+    from torch_local import local_masked_gather, local_masked_scatter
+
+    _LOCAL_OPT_KERNELS_AVAILABLE = True
+except ImportError:
+    _LOCAL_OPT_KERNELS_AVAILABLE = False
+    import warnings
+
 
 def RankLocalMaskedGather(
     _src: torch.Tensor, indices: torch.Tensor, rank_mapping: torch.Tensor, rank: int
@@ -31,6 +39,48 @@ def RankLocalMaskedGather(
     return local_gathered_data
 
 
+def __Local_Gather_impl(_src_tensor, local_indices):
+    num_features = _src_tensor.shape[-1]
+    bs = _src_tensor.shape[0]
+    local_indices = local_indices.view(bs, -1, 1).expand(bs, -1, num_features)
+    local_gathered_data = torch.gather(_src_tensor, 1, local_indices)
+    return local_gathered_data
+
+
+def OptimizedRankLocalMaskedGather(
+    src: torch.Tensor,
+    indices: torch.Tensor,
+    rank_mapping: torch.Tensor,
+    output: torch.Tensor,
+    rank: int,
+) -> torch.Tensor:
+    """
+    This function gathers the indices from the source rank to the destination rank.
+    """
+    if not _LOCAL_OPT_KERNELS_AVAILABLE:
+        warnings.warn(
+            "Optimized local kernels are not available. Falling back to the default implementation."
+        )
+        return RankLocalMaskedGather(src, indices, rank_mapping, rank)
+    bs = src.shape[0]
+    indices = indices.view(bs, -1, 1)
+    num_output_rows = indices.shape[1]
+    num_src_rows = src.shape[1]
+    num_features = src.shape[-1]
+    local_masked_gather(
+        src,
+        indices,
+        rank_mapping,
+        output,
+        bs,
+        num_src_rows,
+        num_features,
+        num_output_rows,
+        rank,
+    )
+    return output
+
+
 def OutOfPlaceRankLocalMaskedGather(
     _src: torch.Tensor, indices: torch.Tensor, rank_mapping: torch.Tensor, rank: int
 ) -> torch.Tensor:
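Note (not part of the commit): a rough usage sketch of the new entry point. The tensor shapes and the rank-mapping layout below are assumptions inferred from the kernel arguments; the sketch also assumes the torch_local extension is built and a CUDA device is available.

# Hypothetical usage of OptimizedRankLocalMaskedGather; shapes are assumptions.
import torch
from DGraph.distributed.RankLocalOps import OptimizedRankLocalMaskedGather

bs, num_src_rows, num_output_rows, num_features = 2, 1024, 256, 64
rank = 0

src = torch.randn(bs, num_src_rows, num_features, device="cuda")
indices = torch.randint(0, num_src_rows, (bs, num_output_rows), device="cuda")
# Assumed layout: one owning rank per output row; rows owned by other ranks are left untouched.
rank_mapping = torch.zeros(bs, num_output_rows, dtype=torch.long, device="cuda")
# The optimized path writes into a caller-provided buffer instead of allocating one.
output = torch.zeros(bs, num_output_rows, num_features, device="cuda")

gathered = OptimizedRankLocalMaskedGather(src, indices, rank_mapping, output, rank)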

DGraph/distributed/csrc/local_data_kernels.cuh

Lines changed: 102 additions & 11 deletions
@@ -15,16 +15,13 @@
  */
 #pragma once
 #include <cuda.h>
-#include <thrust/pair.h>
-#include <cub/cub.cuh>
-
 
 /**
  *
  * This file houses all the kernels that we use for local data communication.
  * Currently all the kernels are in the Local namespace and in the same file, but
  * we can split this up in the future if needed for better organization.
  *
  */
 namespace Local
 {
@@ -36,7 +33,7 @@ namespace Local
 
     __global__ void Fused_ReLU_Scatter_Kernel(
         const float *__restrict__ values,
-        const float *__restrict__ indices,
+        const long *__restrict__ indices,
         float *__restrict__ output,
         const int mini_batch_size,
         const int num_values_rows,
@@ -60,7 +57,7 @@ namespace Local
 
             for (size_t row = gidy; row < num_values_rows; row += nthreadsy)
             {
-                const int ind = __float2int_rd(indices[ind_offset + row]);
+                const int ind = indices[ind_offset + row];
 
                 for (size_t i = gidx; i < num_cols; i += nthreadsx)
                 {
@@ -79,7 +76,7 @@ namespace Local
         const float *__restrict__ values_2,
         const float *__restrict__ means,
         const float *__restrict__ inv_var,
-        const float *__restrict__ indices,
+        const long *__restrict__ indices,
         float *__restrict__ output,
         const int mini_batch_size,
         const int num_values_rows,
@@ -103,7 +100,7 @@ namespace Local
 
             for (size_t row = gidy; row < num_values_rows; row += nthreadsy)
            {
-                const int ind = __float2int_rd(indices[ind_offset + row]);
+                const int ind = indices[ind_offset + row];
 
                 for (size_t i = gidx; i < num_cols; i += nthreadsx)
                 {
@@ -119,7 +116,7 @@ namespace Local
 
     __global__ void Sparse_Scatter_Kernel(
         const float *__restrict__ values,
-        const float *__restrict__ indices,
+        const long *__restrict__ indices,
         float *__restrict__ output,
         const int mini_batch_size,
         const int num_values_rows,
@@ -143,7 +140,7 @@ namespace Local
 
            for (size_t row = gidy; row < num_values_rows; row += nthreadsy)
            {
-                const int ind = __float2int_rd(indices[ind_offset + row]);
+                const int ind = indices[ind_offset + row];
 
                 for (size_t i = gidx; i < num_cols; i += nthreadsx)
                 {
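Note (not part of the commit): because the kernels now read the indices directly as long * instead of converting from float, index tensors passed from the Python side presumably have to be int64. A minimal sketch of that assumption:

# Assumption based on the `const long *__restrict__ indices` signatures above:
# index tensors handed to these kernels should be int64 (torch.long), not float.
import torch

indices = torch.tensor([[3.0, 0.0, 2.0]], device="cuda")  # e.g. produced as float upstream
indices = indices.to(torch.long)                           # matches data_ptr<long>() on the C++ side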
@@ -160,4 +157,98 @@ namespace Local
             }
         }
     }
 
+    __global__ void Rank_Local_Gather_Kernel(
+        const float *__restrict__ values,
+        const long *__restrict__ indices,
+        const long *__restrict__ rank_placement,
+        float *__restrict__ output,
+        const int mini_batch_size,
+        const int num_values_rows,
+        const int num_cols,
+        const int num_output_rows,
+        const int local_rank)
+    {
+        const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x;
+        const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y;
+        const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z;
+
+        const size_t nthreadsx = gridDim.x * blockDim.x;
+        const size_t nthreadsy = gridDim.y * blockDim.y;
+        const size_t nthreadsz = gridDim.z * blockDim.z;
+
+        for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz)
+        {
+            const auto values_offset = mb_i * num_cols * num_values_rows;
+            const auto output_offset = mb_i * num_cols * num_output_rows;
+            const auto ind_offset = mb_i * num_output_rows;
+            const auto rank_placement_offset = mb_i * num_output_rows;
+
+            for (size_t row = gidy; row < num_output_rows; row += nthreadsy)
+            {
+                const int ind = indices[ind_offset + row];
+                const int row_rank = rank_placement[rank_placement_offset + row];
+                // Only gather the values if the rank is the same as the local rank
+                if (row_rank == local_rank)
+                {
+                    // Probably not needed, but just in case
+                    if (ind > -1 && ind < num_values_rows)
+                    {
+                        for (size_t i = gidx; i < num_cols; i += nthreadsx)
+                        {
+                            const auto val = values[values_offset + ind * num_cols + i];
+                            output[output_offset + row * num_cols + i] = val;
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    __global__ void Rank_Local_Scatter_Kernel(
+        const float *__restrict__ values,
+        const long *__restrict__ indices,
+        const long *__restrict__ rank_placement,
+        float *__restrict__ output,
+        const int mini_batch_size,
+        const int num_values_rows,
+        const int num_cols,
+        const int num_output_rows,
+        const int local_rank)
+    {
+        const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x;
+        const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y;
+        const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z;
+
+        const size_t nthreadsx = gridDim.x * blockDim.x;
+        const size_t nthreadsy = gridDim.y * blockDim.y;
+        const size_t nthreadsz = gridDim.z * blockDim.z;
+
+        for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz)
+        {
+            const auto values_offset = mb_i * num_cols * num_values_rows;
+            const auto output_offset = mb_i * num_cols * num_output_rows;
+            const auto ind_offset = mb_i * num_values_rows;
+            const auto rank_placement_offset = mb_i * num_output_rows;
+
+            for (size_t row = gidy; row < num_values_rows; row += nthreadsy)
+            {
+                const int ind = indices[ind_offset + row];
+                const int row_rank = rank_placement[rank_placement_offset + row];
+                // Only scatter the values if the rank is the same as the local rank
+                if (row_rank == local_rank)
+                {
+                    // Probably not needed, but just in case
+                    if (ind > -1 && ind < num_output_rows)
+                    {
+                        for (size_t i = gidx; i < num_cols; i += nthreadsx)
+                        {
+                            const auto val = values[values_offset + row * num_cols + i];
+                            atomicAdd(&output[output_offset + ind * num_cols + i], Max(val, 0.0));
+                        }
+                    }
+                }
+            }
+        }
+    }
 } // namespace Local
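Note (not part of the commit): for intuition, and as a possible cross-check against the CUDA path, the same masked-gather semantics can be sketched in plain PyTorch. The function name masked_gather_reference is hypothetical, not something this commit adds.

import torch

def masked_gather_reference(values, indices, rank_placement, output, local_rank):
    # Pure-PyTorch mirror of Rank_Local_Gather_Kernel (reference only, not performant).
    #   values:         (bs, num_values_rows, num_cols)
    #   indices:        (bs, num_output_rows), int64
    #   rank_placement: (bs, num_output_rows), int64 rank owning each output row
    #   output:         (bs, num_output_rows, num_cols), written in place
    bs, num_values_rows, _ = values.shape
    for mb in range(bs):
        for row in range(indices.shape[1]):
            ind = int(indices[mb, row])
            # Same guards as the kernel: only rows owned by this rank, with an in-range index.
            if rank_placement[mb, row] == local_rank and 0 <= ind < num_values_rows:
                output[mb, row] = values[mb, ind]
    return output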
Lines changed: 24 additions & 0 deletions (new file)
@@ -0,0 +1,24 @@
+/**
+ * Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
+ * Produced at the Lawrence Livermore National Laboratory.
+ * Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+ * the CONTRIBUTORS file. See the top-level LICENSE file for details.
+ *
+ * LLNL-CODE-697807.
+ * All rights reserved.
+ *
+ * This file is part of LBANN: Livermore Big Artificial Neural Network
+ * Toolkit. For details, see http://software.llnl.gov/LBANN or
+ * https://github.com/LBANN and https://github.com/LLNL/LBANN.
+ *
+ * SPDX-License-Identifier: (Apache-2.0)
+ */
+
+#include <torch/extension.h>
+#include "torch_local.hpp"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("local_masked_gather", &local_masked_gather, "Masked Gather");
+    m.def("local_masked_scatter", &local_masked_scatter, "Masked Scatter");
+}
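Note (not part of the commit): the module above has to be compiled into a torch_local extension before the Python-side import succeeds. The commit message mentions a pyproject-based build script; as an ad hoc alternative, torch.utils.cpp_extension.load could be used. The source file names below are assumptions, not paths taken from this commit.

# Sketch only: source names/paths are hypothetical.
from torch.utils.cpp_extension import load

torch_local = load(
    name="torch_local",
    sources=[
        "DGraph/distributed/csrc/torch_local_bindings.cpp",  # hypothetical path to the PYBIND11_MODULE file
        "DGraph/distributed/csrc/torch_local.cu",            # hypothetical path to the launcher file below
    ],
    extra_include_paths=["DGraph/distributed/csrc"],
    verbose=True,
)
# After loading, the functions registered via m.def(...) are available:
# torch_local.local_masked_gather, torch_local.local_masked_scatter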
Lines changed: 117 additions & 0 deletions (new file)
@@ -0,0 +1,117 @@
+/**
+ * Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC.
+ * Produced at the Lawrence Livermore National Laboratory.
+ * Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+ * the CONTRIBUTORS file. See the top-level LICENSE file for details.
+ *
+ * LLNL-CODE-697807.
+ * All rights reserved.
+ *
+ * This file is part of LBANN: Livermore Big Artificial Neural Network
+ * Toolkit. For details, see http://software.llnl.gov/LBANN or
+ * https://github.com/LBANN and https://github.com/LLNL/LBANN.
+ *
+ * SPDX-License-Identifier: (Apache-2.0)
+ */
+#include <torch/extension.h>
+#include <c10/cuda/CUDAStream.h>
+#include "torch_local.hpp"
+#include "local_data_kernels.cuh"
+#include "macros.hpp"
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+    CHECK_CUDA(x);     \
+    CHECK_CONTIGUOUS(x)
+
+torch::Tensor local_masked_gather(torch::Tensor input,
+                                  torch::Tensor indices,
+                                  torch::Tensor rank_local_placement,
+                                  torch::Tensor output,
+                                  const int num_batches,
+                                  const int num_values_rows,
+                                  const int num_cols,
+                                  const int num_output_rows,
+                                  const int local_rank)
+{
+    CHECK_INPUT(input);
+    CHECK_INPUT(indices);
+    CHECK_INPUT(rank_local_placement);
+    CHECK_INPUT(output);
+
+    const float *input_ptr = input.data_ptr<float>();
+    const long *indices_ptr = indices.data_ptr<long>();
+    const long *rank_local_placement_ptr = rank_local_placement.data_ptr<long>();
+    float *output_ptr = output.data_ptr<float>();
+
+    dim3 block_dims, grid_dims;
+    block_dims.x = 32;
+    block_dims.y = 32;
+    block_dims.z = 1;
+
+    const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y;
+    const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x;
+    grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535;
+    grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535;
+    grid_dims.z = 1;
+
+    // Get the default stream for the current device
+    at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index());
+    Local::Rank_Local_Gather_Kernel<<<grid_dims, block_dims>>>(input_ptr,
+                                                               indices_ptr,
+                                                               rank_local_placement_ptr,
+                                                               output_ptr,
+                                                               num_batches,
+                                                               num_values_rows,
+                                                               num_cols,
+                                                               num_output_rows,
+                                                               local_rank);
+    CUDACHECK(cudaGetLastError());
+    return output;
+}
+
+torch::Tensor local_masked_scatter(torch::Tensor input,
+                                   torch::Tensor indices,
+                                   torch::Tensor rank_local_placement,
+                                   torch::Tensor output,
+                                   const int num_batches,
+                                   const int num_values_rows,
+                                   const int num_cols,
+                                   const int num_output_rows,
+                                   const int rank)
+{
+    CHECK_INPUT(input);
+    CHECK_INPUT(indices);
+    CHECK_INPUT(rank_local_placement);
+    CHECK_INPUT(output);
+
+    const float *input_ptr = input.data_ptr<float>();
+    const long *indices_ptr = indices.data_ptr<long>();
+    const long *rank_local_placement_ptr = rank_local_placement.data_ptr<long>();
+    float *output_ptr = output.data_ptr<float>();
+
+    dim3 block_dims, grid_dims;
+    block_dims.x = 32;
+    block_dims.y = 32;
+    block_dims.z = 1;
+
+    const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y;
+    const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x;
+    grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535;
+    grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535;
+    grid_dims.z = 1;
+    // Get the default stream for the current device
+    at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index());
+    Local::Rank_Local_Scatter_Kernel<<<grid_dims, block_dims>>>(input_ptr,
+                                                                indices_ptr,
+                                                                rank_local_placement_ptr,
+                                                                output_ptr,
+                                                                num_batches,
+                                                                num_values_rows,
+                                                                num_cols,
+                                                                num_output_rows,
+                                                                rank);
+    CUDACHECK(cudaGetLastError());
+    return output;
+}
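Note (not part of the commit): both launchers use a 32x32 thread block and cap each grid dimension at 65535; the grid-stride loops in the kernels keep the results correct even when the cap is hit. A small worked example of the ceil-division grid sizing:

# Worked example of the grid sizing used above (ceil division, capped at 65535).
block_x, block_y = 32, 32
num_cols, num_output_rows = 100, 70000

num_col_grids_needed = (num_cols + block_x - 1) // block_x          # ceil(100 / 32)   = 4
num_row_grids_needed = (num_output_rows + block_y - 1) // block_y   # ceil(70000 / 32) = 2188
grid_x = min(num_col_grids_needed, 65535)                           # 4
grid_y = min(num_row_grids_needed, 65535)                           # 2188
print(grid_x, grid_y)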
