@@ -417,6 +417,7 @@ def __init__(
417417 from shutil import which
418418
419419 nvcc = _os .environ .get ("PYCUDA_NVCC" , which ("nvcc" ))
420+ self ._nvcc = nvcc
420421
421422 # Set the tolerance.
422423 try :
@@ -531,17 +532,41 @@ def __init__(
531532 except Exception as e :
532533 raise ValueError (f"Could not prepare the system for GCMC sampling: { e } " )
533534
535+ # Validate the platform parameter.
536+ valid_platforms = {"auto" , "cuda" , "opencl" }
537+
538+ if not isinstance (platform , str ):
539+ raise TypeError ("'platform' mut be of type 'str'." )
540+
541+ # Convert to lower case and strip whitespace.
542+ platform = platform .lower ().replace (" " , "" )
543+
544+ if platform not in valid_platforms :
545+ raise ValueError (
546+ f"Invalid platform '{ platform } '. Must be one of { valid_platforms } ."
547+ )
548+ self ._platform = platform
549+
550+ if device is not None :
551+ if not isinstance (device , int ):
552+ raise ValueError ("'device' must be of type 'int'" )
553+ self ._device = device
554+
555+ if not isinstance (compiler_optimisations , bool ):
556+ raise ValueError ("'compiler_optimisations' must be of type 'bool'" )
557+ self ._compiler_optimisations = compiler_optimisations
558+
534559 # Create platform backend
535560 self ._backend = _create_backend (
536- platform = platform ,
537- device = device if device is not None else 0 ,
561+ platform = self . _platform ,
562+ device = self . _device if self . _device is not None else 0 ,
538563 num_points = self ._num_points ,
539564 num_batch = self ._batch_size ,
540565 num_waters = self ._num_waters ,
541566 num_atoms = self ._num_atoms ,
542567 num_threads = self ._num_threads ,
543- nvcc = nvcc ,
544- compiler_optimisations = compiler_optimisations ,
568+ nvcc = self . _nvcc ,
569+ compiler_optimisations = self . _compiler_optimisations ,
545570 )
546571
547572 # Compile kernels
@@ -657,7 +682,7 @@ def __str__(self) -> str:
657682 """
658683
659684 return (
660- f"GCMCSampler(system= { self ._system } , "
685+ f"GCMCSampler(self._system, "
661686 f"reference={ self ._reference } , "
662687 f"radius={ self ._radius } , "
663688 f"cutoff_type={ self ._cutoff_type } , "
@@ -672,10 +697,20 @@ def __str__(self) -> str:
672697 f"num_threads={ self ._num_threads } ), "
673698 f"bulk_sampling_probability={ self ._bulk_sampling_probability } , "
674699 f"water_template={ self ._water_template } , "
700+ f"platform={ self ._platform } , "
675701 f"device={ self ._device } , "
676702 f"tolerance={ self ._tolerance } , "
703+ f"nvcc={ self ._nvcc } , "
704+ f"compiler_optimisations={ self ._compiler_optimisations } , "
677705 f"lambda_schedule={ self ._lambda_schedule } , "
678706 f"lambda_value={ self ._lambda_value } , "
707+ f"swap_end_states={ self ._swap_end_states } , "
708+ f"restart={ self ._restart } , "
709+ f"rest2_scale={ self ._rest2_scale } , "
710+ f"rest2_selection={ self ._rest2_selection } , "
711+ f"coulomb_power={ self ._coulomb_power } , "
712+ f"shift_coulomb={ self ._shift_coulomb } , "
713+ f"shift_delta={ self ._shift_delta } , "
679714 f"overwrite={ self ._overwrite } , "
680715 f"ghost_file={ self ._ghost_file } , "
681716 f"log_file={ self ._log_file } , "
0 commit comments