@@ -61,20 +61,28 @@ def setup_scatter_data(init_nvshmem_backend):
6161 torch .manual_seed (0 )
6262
6363 num_features = 8
64+
6465 all_rank_input_data = torch .randn (1 , 8 , num_features )
6566
6667 all_edge_coo = torch .tensor ([[0 , 0 , 0 , 1 , 1 , 2 , 2 , 3 ], [1 , 2 , 3 , 0 , 3 , 0 , 3 , 0 ]])
6768 rank_mappings = torch .tensor ([[0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 ], [0 , 1 , 1 , 0 , 1 , 0 , 1 , 0 ]])
6869
69- all_rank_output = torch .zeros (2 , 4 , num_features )
70+ num_global_output_rows = 4
71+ all_rank_output = torch .zeros (2 , num_global_output_rows , num_features )
7072
7173 for k in range (2 ):
7274 _indices = all_edge_coo [k ].view (1 , - 1 , 1 ).expand (1 , - 1 , num_features )
7375 output_data = torch .zeros_like (all_rank_output [[k ]])
7476 output_data .scatter_add_ (1 , _indices , all_rank_input_data )
7577 all_rank_output [k ] = output_data
7678
77- return all_rank_input_data , all_edge_coo , rank_mappings , all_rank_output
79+ return (
80+ all_rank_input_data ,
81+ all_edge_coo ,
82+ rank_mappings ,
83+ all_rank_output ,
84+ num_global_output_rows ,
85+ )
7886
7987
8088def test_nvshmem_backend_init (init_nvshmem_backend ):
@@ -131,12 +139,13 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data):
131139 comm = init_nvshmem_backend
132140 rank = comm .get_rank ()
133141 world_size = comm .get_world_size ()
134- all_rank_input_data , all_edge_coo , rank_mappings , all_rank_output = (
135- setup_scatter_data
136- )
137-
138- all_edge_coo = all_edge_coo .T
139- rank_mappings = rank_mappings .T
142+ (
143+ all_rank_input_data ,
144+ all_edge_coo ,
145+ rank_mappings ,
146+ all_rank_output ,
147+ num_global_output_rows ,
148+ ) = setup_scatter_data
140149
141150 input_slice_start = (all_rank_input_data .shape [1 ] // world_size ) * rank
142151 input_slice_end = (all_rank_input_data .shape [1 ] // world_size ) * (rank + 1 )
@@ -147,7 +156,11 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data):
147156 local_input_data_gt = all_rank_input_data [:, input_slice_start :input_slice_end , :]
148157 local_edge_coo = all_edge_coo [:, edge_slice_start :edge_slice_end ]
149158 local_rank_mappings_gt = rank_mappings [:, edge_slice_start :edge_slice_end ]
150- local_output_data_gt = all_rank_output [:, edge_slice_start :edge_slice_end , :]
159+
160+ output_slice_start = (num_global_output_rows // world_size ) * rank
161+ output_slice_end = (num_global_output_rows // world_size ) * (rank + 1 )
162+ local_output_data_gt = all_rank_output [:, output_slice_start :output_slice_end , :]
163+ num_output_rows = local_output_data_gt .shape [1 ]
151164
152165 for i in range (2 ):
153166 local_indices_gt = local_edge_coo [[i ], :]
@@ -161,3 +174,12 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data):
161174
162175 local_input_data = comm .get_local_rank_slice (all_rank_input_data , dim = 1 )
163176 assert torch .allclose (local_input_data , local_input_data_gt )
177+
178+ scattered_tensor = comm .scatter (
179+ local_input_data .cuda (),
180+ local_indices .cuda (),
181+ local_rank_mapping .cuda (),
182+ num_output_rows ,
183+ )
184+
185+ assert torch .allclose (scattered_tensor , local_output_data_gt [[i ]].cuda ())
0 commit comments