-
Notifications
You must be signed in to change notification settings - Fork 93
Expand file tree
/
Copy pathexecution_kernel.cu
More file actions
74 lines (70 loc) · 2.86 KB
/
execution_kernel.cu
File metadata and controls
74 lines (70 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include "execution_kernel.hpp"
#if defined(MSCCLPP_DEVICE_CUDA)
namespace mscclpp {
// Launches `executionKernel` for the element type selected by `dataType`.
//
// The untyped `src`, `dst` and `scratch` pointers are cast to the element type that matches `dataType`:
//   INT32 -> int32_t, UINT32 -> uint32_t, FLOAT16 -> half, FLOAT32 -> float, BFLOAT16 -> __bfloat16.
// The kernel is then launched on `stream` with:
//   - `nthreadblocks` blocks of `nthreads` threads;
//   - `sharedMemSize` bytes of dynamic shared memory.
// `rank`, `plan` and `flag` are forwarded unchanged. When the build defines ENABLE_NPKIT, two extra trailing
// arguments are passed: the NpKit GPU event-collection contexts and the current NpKit CPU timestamp.
//
// Template parameter:
//   PacketType - packet format used by every executionKernel instantiation launched from here. This file
//                explicitly instantiates LL16Packet and LL8Packet.
//
// NOTE(review): a `dataType` value not handled by the switch launches nothing and reports nothing. Launch errors
// are also not checked here (no cudaGetLastError()), so callers must detect them on `stream` if needed.
template <typename PacketType>
void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
                                   DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
                                   cudaStream_t stream, uint32_t flag) {
  // Every case passes PacketType explicitly. If it were omitted, the launch would use executionKernel's own
  // default packet type (see its declaration) and ignore the PacketType this function was instantiated with.
  switch (dataType) {
    case DataType::INT32:
      executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
          rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
          ,
          NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()
#endif
      );
      break;
    case DataType::UINT32:
      executionKernel<uint32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
          rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
          ,
          NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()
#endif
      );
      break;
    case DataType::FLOAT16:
      executionKernel<half, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
          rank, (half*)src, (half*)dst, (half*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
          ,
          NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()
#endif
      );
      break;
    case DataType::FLOAT32:
      executionKernel<float, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
          rank, (float*)src, (float*)dst, (float*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
          ,
          NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()
#endif
      );
      break;
    case DataType::BFLOAT16:
      executionKernel<__bfloat16, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
          rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
          ,
          NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()
#endif
      );
      break;
  }
}

// Explicit instantiations for the two packet formats supported by the executor.
template void ExecutionKernel::launchKernel<LL16Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
                                                        void* scratch, DataType dataType, DeviceExecutionPlan* plan,
                                                        size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
template void ExecutionKernel::launchKernel<LL8Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
                                                       void* scratch, DataType dataType, DeviceExecutionPlan* plan,
                                                       size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
}  // namespace mscclpp
#endif