Skip to content

Commit 95d01db

Browse files
committed
Add support for template kernels in HIP
1 parent 2f3b397 commit 95d01db

2 files changed

Lines changed: 11 additions & 4 deletions

File tree

kernel_tuner/backends/hip/hip.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,16 @@ def compile(self, kernel_instance):
149149
# Format kernel string
150150
kernel_string = kernel_instance.kernel_string
151151
kernel_name = kernel_instance.name
152-
if 'extern "C"' not in kernel_string:
153-
kernel_string = 'extern "C" {\n' + kernel_string + "\n}"
152+
expression_name = kernel_name.encode()
154153

155154
# Create program
156155
prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string.encode(), kernel_name.encode(), 0, [], []))
157156

158157
try:
158+
# Add the kernel as an expression. This forces hiprtc to instantiate the kernel if it
159+
# is templated or if it is in a namespace.
160+
hip_check(hiprtc.hiprtcAddNameExpression(prog, expression_name))
161+
159162
# Get device properties
160163
props = hip.hipDeviceProp_t()
161164
hip_check(hip.hipGetDeviceProperties(props, 0))
@@ -174,6 +177,10 @@ def compile(self, kernel_instance):
174177
hip_check(hiprtc.hiprtcGetProgramLog(prog, log))
175178
raise RuntimeError(log.decode())
176179

180+
# Get the lowered name. This is the name that can be used in hipModuleGetFunction to
181+
# get the kernel. For templated kernels, this differs from the original kernel name.
182+
lowered_name = hip_check(hiprtc.hiprtcGetLoweredName(prog, expression_name))
183+
177184
# Get compiled code
178185
code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog))
179186
code = bytearray(code_size)
@@ -182,7 +189,7 @@ def compile(self, kernel_instance):
182189
# Load module and get function
183190
module = hip_check(hip.hipModuleLoadData(code))
184191
self.current_module = module
185-
kernel = hip_check(hip.hipModuleGetFunction(module, kernel_name.encode()))
192+
kernel = hip_check(hip.hipModuleGetFunction(module, lowered_name))
186193

187194
except Exception as e:
188195
# Cleanup

kernel_tuner/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose)
707707
)
708708

709709
# check for templated kernel
710-
if kernel_source.lang in ["CUDA", "NVCUDA", "HIP"] and "<" in name and ">" in name:
710+
if kernel_source.lang in ["CUDA", "NVCUDA"] and "<" in name and ">" in name:
711711
kernel_string, name = wrap_templated_kernel(kernel_string, name)
712712

713713
# Preprocess GPU arguments. Required for handling `Tunable` arguments

0 commit comments

Comments
 (0)