Skip to content

Commit 7635452

Browse files
committed
Validate new parameters and store as class attributes.
1 parent 9408100 commit 7635452

2 files changed

Lines changed: 41 additions & 5 deletions

File tree

src/loch/_platforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def create_backend(
126126
PlatformBackend
127127
Initialized platform backend (CUDAPlatform or OpenCLPlatform).
128128
"""
129+
129130
# Detect platform if auto
130131
if platform == "auto":
131132
platform = detect_platform()

src/loch/_sampler.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ def __init__(
417417
from shutil import which
418418

419419
nvcc = _os.environ.get("PYCUDA_NVCC", which("nvcc"))
420+
self._nvcc = nvcc
420421

421422
# Set the tolerance.
422423
try:
@@ -531,17 +532,41 @@ def __init__(
531532
except Exception as e:
532533
raise ValueError(f"Could not prepare the system for GCMC sampling: {e}")
533534

535+
# Validate the platform parameter.
536+
valid_platforms = {"auto", "cuda", "opencl"}
537+
538+
if not isinstance(platform, str):
539+
raise TypeError("'platform' mut be of type 'str'.")
540+
541+
# Convert to lower case and strip whitespace.
542+
platform = platform.lower().replace(" ", "")
543+
544+
if platform not in valid_platforms:
545+
raise ValueError(
546+
f"Invalid platform '{platform}'. Must be one of {valid_platforms}."
547+
)
548+
self._platform = platform
549+
550+
if device is not None:
551+
if not isinstance(device, int):
552+
raise ValueError("'device' must be of type 'int'")
553+
self._device = device
554+
555+
if not isinstance(compiler_optimisations, bool):
556+
raise ValueError("'compiler_optimisations' must be of type 'bool'")
557+
self._compiler_optimisations = compiler_optimisations
558+
534559
# Create platform backend
535560
self._backend = _create_backend(
536-
platform=platform,
537-
device=device if device is not None else 0,
561+
platform=self._platform,
562+
device=self._device if self._device is not None else 0,
538563
num_points=self._num_points,
539564
num_batch=self._batch_size,
540565
num_waters=self._num_waters,
541566
num_atoms=self._num_atoms,
542567
num_threads=self._num_threads,
543-
nvcc=nvcc,
544-
compiler_optimisations=compiler_optimisations,
568+
nvcc=self._nvcc,
569+
compiler_optimisations=self._compiler_optimisations,
545570
)
546571

547572
# Compile kernels
@@ -657,7 +682,7 @@ def __str__(self) -> str:
657682
"""
658683

659684
return (
660-
f"GCMCSampler(system={self._system}, "
685+
f"GCMCSampler(self._system, "
661686
f"reference={self._reference}, "
662687
f"radius={self._radius}, "
663688
f"cutoff_type={self._cutoff_type}, "
@@ -672,10 +697,20 @@ def __str__(self) -> str:
672697
f"num_threads={self._num_threads}), "
673698
f"bulk_sampling_probability={self._bulk_sampling_probability}, "
674699
f"water_template={self._water_template}, "
700+
f"platform={self._platform}, "
675701
f"device={self._device}, "
676702
f"tolerance={self._tolerance}, "
703+
f"nvcc={self._nvcc}, "
704+
f"compiler_optimisations={self._compiler_optimisations}, "
677705
f"lambda_schedule={self._lambda_schedule}, "
678706
f"lambda_value={self._lambda_value}, "
707+
f"swap_end_states={self._swap_end_states}, "
708+
f"restart={self._restart}, "
709+
f"rest2_scale={self._rest2_scale}, "
710+
f"rest2_selection={self._rest2_selection}, "
711+
f"coulomb_power={self._coulomb_power}, "
712+
f"shift_coulomb={self._shift_coulomb}, "
713+
f"shift_delta={self._shift_delta}, "
679714
f"overwrite={self._overwrite}, "
680715
f"ghost_file={self._ghost_file}, "
681716
f"log_file={self._log_file}, "

0 commit comments

Comments
 (0)