Skip to content

Commit 0b935e8

Browse files
committed
Adjust things so we can load cuda libraries on macOS.
1 parent ee12e30 commit 0b935e8

4 files changed

Lines changed: 37 additions & 6 deletions

File tree

src/loaders/libcublas.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ int load_libcublas(int major, int minor) {
4141

4242
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
4343
{
44+
static const char DIGITS[] = "0123456789";
4445
char libname[] = "cublas64_??.dll";
4546

4647
libname[9] = DIGITS[major];
@@ -49,7 +50,17 @@ int load_libcublas(int major, int minor) {
4950
lib = ga_load_library(libname);
5051
}
5152
#else /* Unix */
53+
#ifdef __APPLE__
54+
{
55+
static const char DIGITS[] = "0123456789";
56+
char libname[] = "/Developer/NVIDIA/CUDA-?.?/lib/libcublas.dylib";
57+
libname[23] = DIGITS[major];
58+
libname[25] = DIGITS[minor];
59+
lib = ga_load_library(libname);
60+
}
61+
#else
5262
lib = ga_load_library("libcublas.so");
63+
#endif
5364
#endif
5465
if (lib == NULL)
5566
return GA_LOAD_ERROR;

src/loaders/libcuda.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
99
static char libname[] = "nvcuda.dll";
1010
#else /* Unix */
11+
#ifdef __APPLE__
12+
static char libname[] = "CUDA.framework/CUDA";
13+
#else
1114
static char libname[] = "libcuda.so";
1215
#endif
16+
#endif
1317

1418
#define DEF_PROC(name, args) t##name *name
1519
#define DEF_PROC_V2(name, args) DEF_PROC(name, args)

src/loaders/libnccl.c

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,21 @@
44
#include "dyn_load.h"
55
#include "gpuarray/error.h"
66

7-
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
8-
static const char libname[] = "nccl.dll";
9-
#else /* Unix */
10-
static const char libname[] = "libnccl.so";
11-
#endif
12-
137
#define DEF_PROC(ret, name, args) t##name *name
148

159
#include "libnccl.fn"
1610

1711
#undef DEF_PROC
1812

13+
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) || defined(__APPLE__)
14+
/* As far as we know, nccl is not available or buildable on platforms
15+
other than linux */
16+
int load_libnccl(void) {
17+
return GA_UNSUPPORTED_ERROR;
18+
}
19+
#else /* Unix */
20+
static const char libname[] = "libnccl.so";
21+
1922
#define DEF_PROC(ret, name, args) \
2023
name = (t##name *)ga_func_ptr(lib, #name); \
2124
if (name == NULL) { \
@@ -39,3 +42,4 @@ int load_libnccl(void) {
3942
loaded = 1;
4043
return GA_NO_ERROR;
4144
}
45+
#endif

src/loaders/libnvrtc.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ int load_libnvrtc(int major, int minor) {
2727

2828
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
2929
{
30+
static const char DIGITS[] = "0123456789";
3031
char libname[] = "nvrtc64_??.dll";
3132

3233
libname[8] = DIGITS[major];
@@ -35,7 +36,18 @@ int load_libnvrtc(int major, int minor) {
3536
lib = ga_load_library(libname);
3637
}
3738
#else /* Unix */
39+
#ifdef __APPLE__
40+
{
41+
static const char DIGITS[] = "0123456789";
42+
/* Try the usual fullpath first */
43+
char libname[] = "/Developer/NVIDIA/CUDA-?.?/lib/libnvrtc.dylib";
44+
libname[23] = DIGITS[major];
45+
libname[25] = DIGITS[minor];
46+
lib = ga_load_library(libname);
47+
}
48+
#else
3849
lib = ga_load_library("libnvrtc.so");
50+
#endif
3951
#endif
4052
if (lib == NULL)
4153
return GA_LOAD_ERROR;

0 commit comments

Comments
 (0)