Skip to content

Commit 9456cf1

Browse files
committed
Loop over supported CUDA versions to find installed CUDA on Windows and Mac.
1 parent 2820925 commit 9456cf1

1 file changed

Lines changed: 38 additions & 1 deletion

File tree

src/gpuarray_buffer_cuda.c

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ static int setup_lib(void) {
132132
const char *ver;
133133
CUresult err;
134134
int res, tmp;
135+
int search_version = 0;
135136

136137
if (!setup_done) {
137138
res = load_libcuda(global_err);
@@ -147,13 +148,49 @@ static int setup_lib(void) {
147148
return error_set(global_err, GA_IMPL_ERROR, "cuDriverGetVersion failed");
148149
major = tmp / 1000;
149150
minor = (tmp / 10) % 10;
151+
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) || defined(__APPLE__)
152+
/* We will dynamically search the right CUDA version only on Windows and Macintosh systems,
153+
and only if user has not explicitely specified GPUARRAY_CUDA_VERSION. */
154+
search_version = 1;
155+
#endif
150156
} else {
151157
major = ver[0] - '0';
152158
minor = ver[1] - '0';
153159
}
160+
/* NB: next line will cause problems if a CUDA 10.0 (or 9.11) is released in the future. */
154161
if (major > 9 || major < 0 || minor > 9 || minor < 0)
155162
return error_fmt(global_err, GA_VALUE_ERROR, "Invalid cuda version: %d.%d", major, minor);
156-
res = load_libnvrtc(major, minor, global_err);
163+
if (!search_version) {
164+
res = load_libnvrtc(major, minor, global_err);
165+
} else {
166+
/* First case in next array is reserved to eventually receive the version returned by cuDriverGetVersion(). */
167+
int versions[] = {-1, 80, 75};
168+
int versions_length = sizeof(versions) / sizeof(int);
169+
int current_version = major * 10 + minor;
170+
int i = 0;
171+
for (i = 1; i < versions_length && versions[i] != current_version; ++i);
172+
if (i == versions_length) {
173+
/* Current version not found in the list of versions. We add it at top of the list. */
174+
versions[0] = current_version;
175+
/* We will iterate on versions from the first. */
176+
i = 0;
177+
} else {
178+
/* Current version found in the list of known versions. No need to add it to the list. */
179+
i = 1;
180+
};
181+
do {
182+
major = versions[i] / 10;
183+
minor = versions[i] % 10;
184+
res = load_libnvrtc(major, minor, global_err);
185+
++i;
186+
} while(res != GA_NO_ERROR && i < versions_length);
187+
#ifdef DEBUG
188+
if (res == GA_NO_ERROR)
189+
fprintf(stderr, "Detected CUDA %d.%d.\n", major, minor);
190+
else
191+
fprintf(stderr, "Unable to detect a CUDA version.\n");
192+
#endif
193+
}
157194
if (res != GA_NO_ERROR)
158195
return res;
159196
setup_done = 1;

0 commit comments

Comments
 (0)