|
35 | 35 | from .._kernels import code as _kernel_code |
36 | 36 | from ._base import PlatformBackend as _PlatformBackend |
37 | 37 |
|
| 38 | +# Module-level kernel compilation cache. Keyed on |
| 39 | +# (device_index, compiler_optimisations). Stores compiled program binaries. |
| 40 | +_kernel_cache = {} |
| 41 | + |
38 | 42 |
|
39 | 43 | class OpenCLPlatform(_PlatformBackend): |
40 | 44 | """ |
@@ -122,39 +126,75 @@ def compile_kernels(self) -> _Dict[str, _Callable]: |
122 | 126 | """ |
123 | 127 | Compile OpenCL kernels and return callable functions. |
124 | 128 |
|
| 129 | + Uses a module-level cache so that only the first sampler on a given |
| 130 | + device pays the compilation cost. |
| 131 | +
|
125 | 132 | Returns |
126 | 133 | ------- |
127 | 134 | dict |
128 | 135 | Dictionary mapping kernel names to callable kernel functions. |
129 | 136 | """ |
| 137 | + cache_key = (self._device_index, self._compiler_optimisations) |
| 138 | + |
130 | 139 | # Build compiler options |
131 | 140 | build_options = [] |
132 | 141 | if self._compiler_optimisations: |
133 | 142 | build_options.extend(["-cl-mad-enable", "-cl-no-signed-zeros"]) |
134 | 143 |
|
135 | | - # Compile program from source, suppressing stderr and warnings. |
136 | | - stderr_capture = _io.StringIO() |
137 | | - old_stderr = _sys.stderr |
138 | | - try: |
139 | | - _sys.stderr = stderr_capture |
140 | | - with _warnings.catch_warnings(): |
141 | | - _warnings.simplefilter("ignore") |
142 | | - program = _cl.Program(self._context, _kernel_code).build( |
143 | | - options=build_options |
144 | | - ) |
145 | | - except _cl.RuntimeError as e: |
146 | | - stderr_output = stderr_capture.getvalue().strip() |
147 | | - error_msg = f"OpenCL kernel compilation failed: {e}" |
148 | | - if stderr_output: |
149 | | - error_msg += f"\n{stderr_output}" |
150 | | - raise RuntimeError(error_msg) |
151 | | - finally: |
152 | | - _sys.stderr = old_stderr |
153 | | - |
154 | | - # Capture the compiler log (including any warnings). |
155 | | - self._compiler_log = program.get_build_info( |
156 | | - self._device, _cl.program_build_info.LOG |
157 | | - ).strip() |
| 144 | + if cache_key in _kernel_cache: |
| 145 | + cached_binary = _kernel_cache[cache_key] |
| 146 | + |
| 147 | + # Create program from cached binary. |
| 148 | + stderr_capture = _io.StringIO() |
| 149 | + old_stderr = _sys.stderr |
| 150 | + try: |
| 151 | + _sys.stderr = stderr_capture |
| 152 | + with _warnings.catch_warnings(): |
| 153 | + _warnings.simplefilter("ignore") |
| 154 | + program = _cl.Program( |
| 155 | + self._context, [self._device], [cached_binary] |
| 156 | + ) |
| 157 | + program.build(options=build_options) |
| 158 | + except _cl.RuntimeError as e: |
| 159 | + stderr_output = stderr_capture.getvalue().strip() |
| 160 | + error_msg = f"OpenCL kernel build from cached binary failed: {e}" |
| 161 | + if stderr_output: |
| 162 | + error_msg += f"\n{stderr_output}" |
| 163 | + raise RuntimeError(error_msg) |
| 164 | + finally: |
| 165 | + _sys.stderr = old_stderr |
| 166 | + |
| 167 | + self._compiler_log = "" |
| 168 | + self._cache_hit = True |
| 169 | + else: |
| 170 | + # Compile program from source, suppressing stderr and warnings. |
| 171 | + stderr_capture = _io.StringIO() |
| 172 | + old_stderr = _sys.stderr |
| 173 | + try: |
| 174 | + _sys.stderr = stderr_capture |
| 175 | + with _warnings.catch_warnings(): |
| 176 | + _warnings.simplefilter("ignore") |
| 177 | + program = _cl.Program(self._context, _kernel_code).build( |
| 178 | + options=build_options |
| 179 | + ) |
| 180 | + except _cl.RuntimeError as e: |
| 181 | + stderr_output = stderr_capture.getvalue().strip() |
| 182 | + error_msg = f"OpenCL kernel compilation failed: {e}" |
| 183 | + if stderr_output: |
| 184 | + error_msg += f"\n{stderr_output}" |
| 185 | + raise RuntimeError(error_msg) |
| 186 | + finally: |
| 187 | + _sys.stderr = old_stderr |
| 188 | + |
| 189 | + # Capture the compiler log (including any warnings). |
| 190 | + self._compiler_log = program.get_build_info( |
| 191 | + self._device, _cl.program_build_info.LOG |
| 192 | + ).strip() |
| 193 | + |
| 194 | + self._cache_hit = False |
| 195 | + |
| 196 | + # Cache the compiled binary. |
| 197 | + _kernel_cache[cache_key] = program.get_info(_cl.program_info.BINARIES)[0] |
158 | 198 |
|
159 | 199 | # Create kernel wrappers that match PyCUDA calling convention. |
160 | 200 | # OpenCL kernels need (queue, global_size, local_size, *args) |
@@ -189,6 +229,11 @@ def wrapper(*args, **kwargs): |
189 | 229 |
|
190 | 230 | return kernels |
191 | 231 |
|
| 232 | + @staticmethod |
| 233 | + def clear_cache(): |
| 234 | + """Clear the kernel compilation cache.""" |
| 235 | + _kernel_cache.clear() |
| 236 | + |
192 | 237 | def to_gpu(self, array: _np.ndarray) -> _Any: |
193 | 238 | """ |
194 | 239 | Transfer a NumPy array to GPU memory. |
|
0 commit comments