@@ -575,46 +575,59 @@ def _check_device_memory(device_index=0):
575575 index: int
576576 The index of the GPU device.
577577 """
578- import pyopencl as cl
579578
580- # Get the device.
581- platforms = cl .get_platforms ()
582- all_devices = []
583- for platform in platforms :
584- try :
585- devices = platform .get_devices (device_type = cl .device_type .GPU )
586- all_devices .extend (devices )
587- except :
588- continue
589-
590- if device_index >= len (all_devices ):
591- msg = f"Device index { device_index } out of range. Found { len (all_devices )} GPU(s)."
592- _logger .error (msg )
593- raise IndexError (msg )
579+ # Try to use pyopencl to detect the GPU vendor.
580+ vendor = None
581+ ocl_device = None
582+ try :
583+ import pyopencl as cl
594584
595- device = all_devices [device_index ]
596- total = device .global_mem_size
585+ platforms = cl .get_platforms ()
586+ all_devices = []
587+ for platform in platforms :
588+ try :
589+ devices = platform .get_devices (device_type = cl .device_type .GPU )
590+ all_devices .extend (devices )
591+ except Exception :
592+ continue
593+
594+ if device_index < len (all_devices ):
595+ ocl_device = all_devices [device_index ]
596+ vendor = ocl_device .vendor
597+ else :
598+ msg = f"Device index { device_index } out of range. Found { len (all_devices )} GPU(s)."
599+ _logger .error (msg )
600+ raise IndexError (msg )
601+ except IndexError :
602+ raise
603+ except Exception :
604+ _logger .warning (
605+ "Could not query GPU platform via OpenCL; falling back to pynvml for NVIDIA detection."
606+ )
597607
598- # NVIDIA: Use pynvml
599- if "NVIDIA" in device . vendor :
608+ # NVIDIA: Use pynvml (also used as fallback when OpenCL is unavailable).
609+ if vendor is None or "NVIDIA" in vendor :
600610 try :
601611 import pynvml
602612
603613 pynvml .nvmlInit ()
604-
605614 handle = pynvml .nvmlDeviceGetHandleByIndex (device_index )
606615 memory = pynvml .nvmlDeviceGetMemoryInfo (handle )
607616 pynvml .nvmlShutdown ()
608617 return (memory .used , memory .free , memory .total )
609618 except Exception as e :
610- msg = f"Could not get NVIDIA GPU memory info for device { device_index } : { e } "
619+ if vendor is None :
620+ msg = f"Could not get GPU memory info for device { device_index } via OpenCL or pynvml: { e } "
621+ else :
622+ msg = f"Could not get NVIDIA GPU memory info for device { device_index } : { e } "
611623 _logger .error (msg )
612624 raise RuntimeError (msg ) from e
613625
614- # AMD: Use OpenCL extension
615- elif "AMD" in device . vendor or "Advanced Micro Devices" in device . vendor :
626+ # AMD: Use OpenCL extension.
627+ elif "AMD" in vendor or "Advanced Micro Devices" in vendor :
616628 try :
617- free_memory_info = device .get_info (0x4038 )
629+ total = ocl_device .global_mem_size
630+ free_memory_info = ocl_device .get_info (0x4038 )
618631 free_kb = (
619632 free_memory_info [0 ]
620633 if isinstance (free_memory_info , list )
0 commit comments