Skip to content

Commit 27ec53d

Browse files
author
SamoraHunter
committed
improved gpu detection for cupy
1 parent 49c3444 commit 27ec53d

1 file changed

Lines changed: 33 additions & 17 deletions

File tree

ml_grid/pipeline/data_correlation_matrix.py

Lines changed: 33 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -41,26 +41,42 @@ def handle_correlation_matrix(
4141
# Convert data to float32
4242
data = df_numeric.values.astype(np.float32)
4343

44-
# --- GPU DETECTION & SAFETY ---
44+
# --- IMPROVED GPU DETECTION & SAFETY ---
4545
use_gpu = False
4646
try:
4747
import cupy as cp
48-
49-
if cp.cuda.is_available():
50-
free_mem = cp.cuda.Device().mem_info[0]
51-
req_mem = (data.shape[1] ** 2) * 4 # 4 bytes per float32
52-
53-
if free_mem > req_mem * 1.2:
54-
use_gpu = True
55-
logger.info(
56-
f"GPU Detected: {cp.cuda.Device().name}. Free VRAM: {free_mem/1e9:.2f} GB."
57-
)
58-
else:
59-
logger.warning(
60-
"GPU detected but insufficient VRAM. Falling back to CPU."
61-
)
48+
49+
# Check if CUDA is available first (before trying to access device)
50+
if not cp.cuda.is_available():
51+
logger.info("No CUDA-capable GPU detected. Using CPU.")
52+
else:
53+
# Now safe to access device properties
54+
try:
55+
device = cp.cuda.Device()
56+
free_mem, total_mem = device.mem_info
57+
req_mem = (data.shape[1] ** 2) * 4 # 4 bytes per float32
58+
59+
if free_mem > req_mem * 1.2:
60+
use_gpu = True
61+
logger.info(
62+
f"GPU Enabled: {device.compute_capability}. "
63+
f"Free VRAM: {free_mem/1e9:.2f} GB / {total_mem/1e9:.2f} GB"
64+
)
65+
else:
66+
logger.warning(
67+
f"GPU detected but insufficient VRAM. "
68+
f"Required: {req_mem/1e9:.2f} GB, Available: {free_mem/1e9:.2f} GB. "
69+
f"Falling back to CPU."
70+
)
71+
except cp.cuda.runtime.CUDARuntimeError as cuda_err:
72+
logger.info(f"CUDA runtime error (using CPU): {cuda_err}")
73+
except Exception as device_err:
74+
logger.info(f"Could not access GPU device (using CPU): {device_err}")
75+
76+
except ImportError:
77+
logger.info("CuPy not installed. Using CPU-only mode.")
6278
except Exception as e:
63-
logger.warning(f"GPU acceleration unavailable (falling back to CPU): {e}")
79+
logger.info(f"GPU initialization failed (using CPU): {e}")
6480
use_gpu = False
6581
# -----------------------------
6682

@@ -197,4 +213,4 @@ def _process_on_cpu(
197213
final_drop_set = existing_drops.union(newly_identified_drops)
198214

199215
logger.info(f"CPU complete. Total columns to drop: {len(final_drop_set)}")
200-
return sorted(list(final_drop_set))
216+
return sorted(list(final_drop_set))

0 commit comments

Comments
 (0)