@@ -149,13 +149,16 @@ def compile(self, kernel_instance):
149149 # Format kernel string
150150 kernel_string = kernel_instance .kernel_string
151151 kernel_name = kernel_instance .name
152- if 'extern "C"' not in kernel_string :
153- kernel_string = 'extern "C" {\n ' + kernel_string + "\n }"
152+ expression_name = kernel_name .encode ()
154153
155154 # Create program
156155 prog = hip_check (hiprtc .hiprtcCreateProgram (kernel_string .encode (), kernel_name .encode (), 0 , [], []))
157156
158157 try :
158+ # Add the kernel as an expression. This forces hiprtc to instantiate the kernel if it
159+ # is templated or if it is in a namespace.
160+ hip_check (hiprtc .hiprtcAddNameExpression (prog , expression_name ))
161+
159162 # Get device properties
160163 props = hip .hipDeviceProp_t ()
161164 hip_check (hip .hipGetDeviceProperties (props , 0 ))
@@ -174,6 +177,10 @@ def compile(self, kernel_instance):
174177 hip_check (hiprtc .hiprtcGetProgramLog (prog , log ))
175178 raise RuntimeError (log .decode ())
176179
180+ # Get the lowered name. This is the name that can be used in hipModuleGetFunction to
181+ # get the kernel. For templated kernels, this differs from the original kernel name.
182+ lowered_name = hip_check (hiprtc .hiprtcGetLoweredName (prog , expression_name ))
183+
177184 # Get compiled code
178185 code_size = hip_check (hiprtc .hiprtcGetCodeSize (prog ))
179186 code = bytearray (code_size )
@@ -182,7 +189,7 @@ def compile(self, kernel_instance):
182189 # Load module and get function
183190 module = hip_check (hip .hipModuleLoadData (code ))
184191 self .current_module = module
185- kernel = hip_check (hip .hipModuleGetFunction (module , kernel_name . encode () ))
192+ kernel = hip_check (hip .hipModuleGetFunction (module , lowered_name ))
186193
187194 except Exception as e :
188195 # Cleanup
0 commit comments