@@ -132,6 +132,7 @@ static int setup_lib(void) {
132132 const char * ver ;
133133 CUresult err ;
134134 int res , tmp ;
135+ int search_version = 0 ;
135136
136137 if (!setup_done ) {
137138 res = load_libcuda (global_err );
@@ -147,13 +148,49 @@ static int setup_lib(void) {
147148 return error_set (global_err , GA_IMPL_ERROR , "cuDriverGetVersion failed" );
148149 major = tmp / 1000 ;
149150 minor = (tmp / 10 ) % 10 ;
151+ #if defined(WIN32 ) || defined(_WIN32 ) || defined(WIN64 ) || defined(_WIN64 ) || defined(__APPLE__ )
152+ /* We will dynamically search the right CUDA version only on Windows and Macintosh systems,
153+ and only if the user has not explicitly specified GPUARRAY_CUDA_VERSION. */
154+ search_version = 1 ;
155+ #endif
150156 } else {
151157 major = ver [0 ] - '0' ;
152158 minor = ver [1 ] - '0' ;
153159 }
160+ /* NB: next line will cause problems if a CUDA 10.0 (or 9.11) is released in the future. */
154161 if (major > 9 || major < 0 || minor > 9 || minor < 0 )
155162 return error_fmt (global_err , GA_VALUE_ERROR , "Invalid cuda version: %d.%d" , major , minor );
156- res = load_libnvrtc (major , minor , global_err );
163+ if (!search_version ) {
164+ res = load_libnvrtc (major , minor , global_err );
165+ } else {
166+ /* The first slot of the following array is reserved to receive, if needed, the version returned by cuDriverGetVersion(). */
167+ int versions [] = {-1 , 80 , 75 };
168+ int versions_length = sizeof (versions ) / sizeof (int );
169+ int current_version = major * 10 + minor ;
170+ int i = 0 ;
171+ for (i = 1 ; i < versions_length && versions [i ] != current_version ; ++ i );
172+ if (i == versions_length ) {
173+ /* Current version not found in the list of versions. We add it at top of the list. */
174+ versions [0 ] = current_version ;
175+ /* We will iterate on versions from the first. */
176+ i = 0 ;
177+ } else {
178+ /* Current version found in the list of known versions. No need to add it to the list. */
179+ i = 1 ;
180+ };
181+ do {
182+ major = versions [i ] / 10 ;
183+ minor = versions [i ] % 10 ;
184+ res = load_libnvrtc (major , minor , global_err );
185+ ++ i ;
186+ } while (res != GA_NO_ERROR && i < versions_length );
187+ #ifdef DEBUG
188+ if (res == GA_NO_ERROR )
189+ fprintf (stderr , "Detected CUDA %d.%d.\n" , major , minor );
190+ else
191+ fprintf (stderr , "Unable to detect a CUDA version.\n" );
192+ #endif
193+ }
157194 if (res != GA_NO_ERROR )
158195 return res ;
159196 setup_done = 1 ;
0 commit comments