Skip to content

Commit e5e6cfd

Browse files
authored
Merge pull request #125 from OpenBioSim/fix_missing_ocl_icd
Handle missing OpenCL ICD loader during GPU platform detection
2 parents 8e44677 + aecfa95 commit e5e6cfd

1 file changed

Lines changed: 37 additions & 24 deletions

File tree

src/somd2/runner/_repex.py

Lines changed: 37 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -575,46 +575,59 @@ def _check_device_memory(device_index=0):
575575
index: int
576576
The index of the GPU device.
577577
"""
578-
import pyopencl as cl
579578

580-
# Get the device.
581-
platforms = cl.get_platforms()
582-
all_devices = []
583-
for platform in platforms:
584-
try:
585-
devices = platform.get_devices(device_type=cl.device_type.GPU)
586-
all_devices.extend(devices)
587-
except:
588-
continue
589-
590-
if device_index >= len(all_devices):
591-
msg = f"Device index {device_index} out of range. Found {len(all_devices)} GPU(s)."
592-
_logger.error(msg)
593-
raise IndexError(msg)
579+
# Try to use pyopencl to detect the GPU vendor.
580+
vendor = None
581+
ocl_device = None
582+
try:
583+
import pyopencl as cl
594584

595-
device = all_devices[device_index]
596-
total = device.global_mem_size
585+
platforms = cl.get_platforms()
586+
all_devices = []
587+
for platform in platforms:
588+
try:
589+
devices = platform.get_devices(device_type=cl.device_type.GPU)
590+
all_devices.extend(devices)
591+
except Exception:
592+
continue
593+
594+
if device_index < len(all_devices):
595+
ocl_device = all_devices[device_index]
596+
vendor = ocl_device.vendor
597+
else:
598+
msg = f"Device index {device_index} out of range. Found {len(all_devices)} GPU(s)."
599+
_logger.error(msg)
600+
raise IndexError(msg)
601+
except IndexError:
602+
raise
603+
except Exception:
604+
_logger.warning(
605+
"Could not query GPU platform via OpenCL; falling back to pynvml for NVIDIA detection."
606+
)
597607

598-
# NVIDIA: Use pynvml
599-
if "NVIDIA" in device.vendor:
608+
# NVIDIA: Use pynvml (also used as fallback when OpenCL is unavailable).
609+
if vendor is None or "NVIDIA" in vendor:
600610
try:
601611
import pynvml
602612

603613
pynvml.nvmlInit()
604-
605614
handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
606615
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
607616
pynvml.nvmlShutdown()
608617
return (memory.used, memory.free, memory.total)
609618
except Exception as e:
610-
msg = f"Could not get NVIDIA GPU memory info for device {device_index}: {e}"
619+
if vendor is None:
620+
msg = f"Could not get GPU memory info for device {device_index} via OpenCL or pynvml: {e}"
621+
else:
622+
msg = f"Could not get NVIDIA GPU memory info for device {device_index}: {e}"
611623
_logger.error(msg)
612624
raise RuntimeError(msg) from e
613625

614-
# AMD: Use OpenCL extension
615-
elif "AMD" in device.vendor or "Advanced Micro Devices" in device.vendor:
626+
# AMD: Use OpenCL extension.
627+
elif "AMD" in vendor or "Advanced Micro Devices" in vendor:
616628
try:
617-
free_memory_info = device.get_info(0x4038)
629+
total = ocl_device.global_mem_size
630+
free_memory_info = ocl_device.get_info(0x4038)
618631
free_kb = (
619632
free_memory_info[0]
620633
if isinstance(free_memory_info, list)

0 commit comments

Comments
 (0)