Skip to content

Commit d4975e9

Browse files
committed
Re-add simplified cache for compiled kernel code.
1 parent 40ee8ef commit d4975e9

6 files changed

Lines changed: 234 additions & 50 deletions

File tree

src/loch/_platforms/_base.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -190,6 +190,18 @@ def platform_name(self) -> str:
190190
"""
191191
pass
192192

193+
@property
194+
def cache_hit(self) -> bool:
195+
"""
196+
Whether the last compile_kernels() call was a cache hit.
197+
198+
Returns
199+
-------
200+
bool
201+
True if kernels were loaded from cache, False if freshly compiled.
202+
"""
203+
return getattr(self, "_cache_hit", False)
204+
193205
@property
194206
def compiler_log(self) -> str:
195207
"""

src/loch/_platforms/_cuda.py

Lines changed: 51 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,12 @@
3535
from .._kernels import code as _kernel_code
3636
from ._base import PlatformBackend as _PlatformBackend
3737

38+
# Module-level kernel compilation cache. Keyed on
39+
# (device_index, compiler_optimisations). Since the kernel source no longer
40+
# depends on system-specific parameters, the same compiled binary can be
41+
# reused across all samplers on a given device.
42+
_kernel_cache = {}
43+
3844

3945
class CUDAPlatform(_PlatformBackend):
4046
"""
@@ -123,38 +129,51 @@ def compile_kernels(self) -> _Dict[str, _Callable]:
123129
"""
124130
Compile CUDA kernels and return callable functions.
125131
132+
Uses a module-level cache so that only the first sampler on a given
133+
device pays the nvcc compilation cost.
134+
126135
Returns
127136
-------
128137
dict
129138
Dictionary mapping kernel names to callable kernel functions.
130139
"""
131-
# Compile kernel source.
132-
# Suppress stderr but capture it for error reporting.
133-
stderr_capture = _io.StringIO()
134-
old_stderr = _sys.stderr
135-
136-
options = []
137-
if self._compiler_optimisations:
138-
options.append("--use_fast_math")
139-
140-
try:
141-
_sys.stderr = stderr_capture
142-
cubin = _compile(
143-
_kernel_code,
144-
no_extern_c=True,
145-
nvcc=self._nvcc,
146-
options=options,
147-
)
148-
except Exception as e:
149-
stderr_output = stderr_capture.getvalue().strip()
150-
error_msg = f"CUDA kernel compilation failed: {e}"
151-
if stderr_output:
152-
error_msg += f"\n{stderr_output}"
153-
raise RuntimeError(error_msg)
154-
finally:
155-
_sys.stderr = old_stderr
156-
157-
self._compiler_log = stderr_capture.getvalue().strip()
140+
cache_key = (self._device_index, self._compiler_optimisations)
141+
142+
if cache_key in _kernel_cache:
143+
cubin = _kernel_cache[cache_key]
144+
self._compiler_log = ""
145+
self._cache_hit = True
146+
else:
147+
# Compile kernel source.
148+
# Suppress stderr but capture it for error reporting.
149+
stderr_capture = _io.StringIO()
150+
old_stderr = _sys.stderr
151+
152+
options = []
153+
if self._compiler_optimisations:
154+
options.append("--use_fast_math")
155+
156+
try:
157+
_sys.stderr = stderr_capture
158+
cubin = _compile(
159+
_kernel_code,
160+
no_extern_c=True,
161+
nvcc=self._nvcc,
162+
options=options,
163+
)
164+
except Exception as e:
165+
stderr_output = stderr_capture.getvalue().strip()
166+
error_msg = f"CUDA kernel compilation failed: {e}"
167+
if stderr_output:
168+
error_msg += f"\n{stderr_output}"
169+
raise RuntimeError(error_msg)
170+
finally:
171+
_sys.stderr = old_stderr
172+
173+
self._compiler_log = stderr_capture.getvalue().strip()
174+
self._cache_hit = False
175+
_kernel_cache[cache_key] = cubin
176+
158177
mod = _cuda.module_from_buffer(cubin)
159178

160179
# Extract kernel functions
@@ -168,6 +187,11 @@ def compile_kernels(self) -> _Dict[str, _Callable]:
168187

169188
return kernels
170189

190+
@staticmethod
191+
def clear_cache():
192+
"""Clear the kernel compilation cache."""
193+
_kernel_cache.clear()
194+
171195
def to_gpu(self, array: _np.ndarray) -> _Any:
172196
"""
173197
Transfer a NumPy array to GPU memory.

src/loch/_platforms/_opencl.py

Lines changed: 68 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,10 @@
3535
from .._kernels import code as _kernel_code
3636
from ._base import PlatformBackend as _PlatformBackend
3737

38+
# Module-level kernel compilation cache. Keyed on
39+
# (device_index, compiler_optimisations). Stores compiled program binaries.
40+
_kernel_cache = {}
41+
3842

3943
class OpenCLPlatform(_PlatformBackend):
4044
"""
@@ -122,39 +126,75 @@ def compile_kernels(self) -> _Dict[str, _Callable]:
122126
"""
123127
Compile OpenCL kernels and return callable functions.
124128
129+
Uses a module-level cache so that only the first sampler on a given
130+
device pays the compilation cost.
131+
125132
Returns
126133
-------
127134
dict
128135
Dictionary mapping kernel names to callable kernel functions.
129136
"""
137+
cache_key = (self._device_index, self._compiler_optimisations)
138+
130139
# Build compiler options
131140
build_options = []
132141
if self._compiler_optimisations:
133142
build_options.extend(["-cl-mad-enable", "-cl-no-signed-zeros"])
134143

135-
# Compile program from source, suppressing stderr and warnings.
136-
stderr_capture = _io.StringIO()
137-
old_stderr = _sys.stderr
138-
try:
139-
_sys.stderr = stderr_capture
140-
with _warnings.catch_warnings():
141-
_warnings.simplefilter("ignore")
142-
program = _cl.Program(self._context, _kernel_code).build(
143-
options=build_options
144-
)
145-
except _cl.RuntimeError as e:
146-
stderr_output = stderr_capture.getvalue().strip()
147-
error_msg = f"OpenCL kernel compilation failed: {e}"
148-
if stderr_output:
149-
error_msg += f"\n{stderr_output}"
150-
raise RuntimeError(error_msg)
151-
finally:
152-
_sys.stderr = old_stderr
153-
154-
# Capture the compiler log (including any warnings).
155-
self._compiler_log = program.get_build_info(
156-
self._device, _cl.program_build_info.LOG
157-
).strip()
144+
if cache_key in _kernel_cache:
145+
cached_binary = _kernel_cache[cache_key]
146+
147+
# Create program from cached binary.
148+
stderr_capture = _io.StringIO()
149+
old_stderr = _sys.stderr
150+
try:
151+
_sys.stderr = stderr_capture
152+
with _warnings.catch_warnings():
153+
_warnings.simplefilter("ignore")
154+
program = _cl.Program(
155+
self._context, [self._device], [cached_binary]
156+
)
157+
program.build(options=build_options)
158+
except _cl.RuntimeError as e:
159+
stderr_output = stderr_capture.getvalue().strip()
160+
error_msg = f"OpenCL kernel build from cached binary failed: {e}"
161+
if stderr_output:
162+
error_msg += f"\n{stderr_output}"
163+
raise RuntimeError(error_msg)
164+
finally:
165+
_sys.stderr = old_stderr
166+
167+
self._compiler_log = ""
168+
self._cache_hit = True
169+
else:
170+
# Compile program from source, suppressing stderr and warnings.
171+
stderr_capture = _io.StringIO()
172+
old_stderr = _sys.stderr
173+
try:
174+
_sys.stderr = stderr_capture
175+
with _warnings.catch_warnings():
176+
_warnings.simplefilter("ignore")
177+
program = _cl.Program(self._context, _kernel_code).build(
178+
options=build_options
179+
)
180+
except _cl.RuntimeError as e:
181+
stderr_output = stderr_capture.getvalue().strip()
182+
error_msg = f"OpenCL kernel compilation failed: {e}"
183+
if stderr_output:
184+
error_msg += f"\n{stderr_output}"
185+
raise RuntimeError(error_msg)
186+
finally:
187+
_sys.stderr = old_stderr
188+
189+
# Capture the compiler log (including any warnings).
190+
self._compiler_log = program.get_build_info(
191+
self._device, _cl.program_build_info.LOG
192+
).strip()
193+
194+
self._cache_hit = False
195+
196+
# Cache the compiled binary.
197+
_kernel_cache[cache_key] = program.get_info(_cl.program_info.BINARIES)[0]
158198

159199
# Create kernel wrappers that match PyCUDA calling convention.
160200
# OpenCL kernels need (queue, global_size, local_size, *args)
@@ -189,6 +229,11 @@ def wrapper(*args, **kwargs):
189229

190230
return kernels
191231

232+
@staticmethod
233+
def clear_cache():
234+
"""Clear the kernel compilation cache."""
235+
_kernel_cache.clear()
236+
192237
def to_gpu(self, array: _np.ndarray) -> _Any:
193238
"""
194239
Transfer a NumPy array to GPU memory.

src/loch/_sampler.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -759,6 +759,19 @@ def pop(self) -> None:
759759
"""Pop the GPU context from the calling thread's context stack."""
760760
self._backend.pop_context()
761761

762+
@property
763+
def kernel_cache_hit(self) -> bool:
764+
"""
765+
Whether kernel compilation was satisfied from cache.
766+
767+
Returns
768+
-------
769+
770+
cache_hit: bool
771+
True if kernels were loaded from cache, False if freshly compiled.
772+
"""
773+
return self._backend.cache_hit
774+
762775
def system(self) -> _Any:
763776
"""
764777
Return the GCMC system.

tests/test_compiler.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,9 @@ def test_compilation_error_raises_exception(self):
113113
nvcc=_get_nvcc(),
114114
)
115115

116+
# Clear the cache so the patched code is actually compiled.
117+
CUDAPlatform.clear_cache()
118+
116119
# Patch kernel code directly in the cuda module (not the kernels module,
117120
# since it's already imported as _kernel_code at module load time).
118121
original_code = cuda_module._kernel_code

tests/test_energy.py

Lines changed: 87 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -357,3 +357,90 @@ def test_energy_regression(fixture, platform, request):
357357
assert math.isclose(
358358
energy_lj, ref["energy_lj"], abs_tol=1e-4
359359
), f"LJ energy changed: {energy_lj!r} != {ref['energy_lj']!r}"
360+
361+
362+
@pytest.mark.skipif(
363+
"CUDA_VISIBLE_DEVICES" not in os.environ,
364+
reason="Requires CUDA enabled GPU.",
365+
)
366+
@pytest.mark.parametrize("platform", ["cuda", "opencl"])
367+
def test_cached_kernel_correctness(platform, water_box):
368+
"""
369+
A second sampler using cached kernels must produce the same energies
370+
as the first.
371+
"""
372+
373+
mols, reference = water_box
374+
375+
schedule = sr.cas.LambdaSchedule.standard_morph()
376+
377+
def _create_and_run(seed):
378+
sampler = GCMCSampler(
379+
mols,
380+
cutoff_type="rf",
381+
cutoff="10 A",
382+
reference=reference,
383+
lambda_schedule=schedule,
384+
lambda_value=0.5,
385+
log_level="debug",
386+
ghost_file=None,
387+
log_file=None,
388+
test=True,
389+
platform=platform,
390+
seed=seed,
391+
)
392+
393+
d = sampler.system().dynamics(
394+
cutoff_type="rf",
395+
cutoff="10 A",
396+
temperature="298 K",
397+
pressure=None,
398+
constraint="h_bonds",
399+
timestep="2 fs",
400+
schedule=schedule,
401+
lambda_value=0.5,
402+
coulomb_power=sampler._coulomb_power,
403+
shift_coulomb=str(sampler._shift_coulomb),
404+
shift_delta=str(sampler._shift_delta),
405+
platform=platform,
406+
)
407+
408+
is_accepted = False
409+
while not is_accepted:
410+
moves = sampler.move(d.context())
411+
if len(moves) > 0 and moves[0] == 0:
412+
is_accepted = True
413+
414+
return sampler
415+
416+
# Clear the cache so the first sampler compiles from source.
417+
if platform == "cuda":
418+
from loch._platforms._cuda import CUDAPlatform
419+
420+
CUDAPlatform.clear_cache()
421+
else:
422+
from loch._platforms._opencl import OpenCLPlatform
423+
424+
OpenCLPlatform.clear_cache()
425+
426+
# First sampler compiles kernels, second uses the cache.
427+
# Both use the same seed so random water positions are identical.
428+
sampler1 = _create_and_run(seed=42)
429+
sampler2 = _create_and_run(seed=42)
430+
431+
# Verify cache behaviour.
432+
assert not sampler1.kernel_cache_hit, "First sampler should compile from source"
433+
assert sampler2.kernel_cache_hit, "Second sampler should use cached kernels"
434+
435+
# Verify energy consistency.
436+
energy1_coul = sampler1._debug["energy_coul"]
437+
energy1_lj = sampler1._debug["energy_lj"]
438+
energy2_coul = sampler2._debug["energy_coul"]
439+
energy2_lj = sampler2._debug["energy_lj"]
440+
441+
assert math.isclose(
442+
energy1_coul, energy2_coul, abs_tol=1e-4
443+
), f"Coulomb energy mismatch: {energy1_coul!r} vs {energy2_coul!r}"
444+
assert math.isclose(
445+
energy1_lj, energy2_lj, abs_tol=1e-4
446+
), f"LJ energy mismatch: {energy1_lj!r} vs {energy2_lj!r}"

0 commit comments

Comments
 (0)