@@ -42,46 +42,46 @@ namespace detail {
4242// Internal dispatch namespace — use numpy::exp() etc., not numpy::detail::exp().
4343// All math functions are resolved at runtime from numpy's _multiarray_umath.so.
4444//
45- // The shared library handle is auto-discovered on first use by scanning
46- // /proc/self/maps — no explicit bridge_init() call needed.
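// No-static design: every call resolves its symbol through dlsym directly, so
// that a child process created by a Python multiprocessing fork does not
// inherit a stale handle or function pointer from its parent.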
4747
48- inline void* g_svml_handle = nullptr;
48+ #include <unistd.h> // getpid
4949
50- // Auto-discover numpy's _multiarray_umath shared library path via /proc/self/maps.
51- // Called lazily from resolve_svml() on first use.
52- inline const char* find_umath_path() {
53- static std::string path;
54- static bool tried = false;
55- if (tried) return path.empty() ? nullptr : path.c_str();
56- tried = true;
/// Cached dlopen() handle for numpy's _multiarray_umath shared object.
/// Stays nullptr until resolve_svml() succeeds in opening the library.
inline void* g_svml_handle = nullptr;
/// PID recorded when the handle state was last reset. resolve_svml() compares
/// it with getpid() and drops the cached handle when the two differ
/// (fork detection).
inline pid_t g_svml_pid = 0;
5752
/// Locates numpy's _multiarray_umath shared object among the files mapped into
/// the current process.
///
/// /proc/self/maps is re-read on every call; nothing is cached in a static, so
/// a forked child never sees a value computed by its parent.
///
/// @return The pathname of the first mapping whose line mentions both
///         "_multiarray_umath" and ".so", or an empty string when no such
///         mapping exists (or /proc/self/maps cannot be read).
inline std::string find_umath_path() {
    std::ifstream maps("/proc/self/maps");
    std::string line;
    while (std::getline(maps, line)) {
        if (line.find("_multiarray_umath") == std::string::npos ||
            line.find(".so") == std::string::npos) {
            continue;
        }
        // The pathname is the last column of a maps line. The columns before
        // it (address range, perms, offset, dev, inode) never contain '/', so
        // the first '/' marks the start of the path. Searching backwards for a
        // space instead would truncate paths whose directories contain spaces.
        const auto path_start = line.find('/');
        if (path_start != std::string::npos) {
            return line.substr(path_start);
        }
    }
    return {};
}
7369
74- // DEPRECATED — kept for backward compatibility with code that still calls it.
75- // resolve_svml() now auto-discovers the .so path; explicit init is unnecessary.
/// No-op initialisation hook.
///
/// The argument is ignored and nothing is initialised here; the function only
/// exists so that code which still calls it keeps compiling.
///
/// @param numpy_so_path Unused.
inline void bridge_init(const char* numpy_so_path) {
    static_cast<void>(numpy_so_path);
}
7973
74+ /// 解析符号。fork 安全: 如果 pid 变化则重新 dlopen。
8075inline void* resolve_svml(const char* name) {
81- // Lazy init: auto-discover numpy's shared library on first call
76+ pid_t pid = getpid();
77+ if (pid != g_svml_pid) {
78+ g_svml_handle = nullptr; // fork 后父进程 handle 在子进程无效
79+ g_svml_pid = pid;
80+ }
8281 if (!g_svml_handle) {
83- const char* path = find_umath_path();
84- if (path) g_svml_handle = dlopen(path, RTLD_NOLOAD | RTLD_LAZY);
82+ std::string path = find_umath_path();
83+ if (!path.empty())
84+ g_svml_handle = dlopen(path.c_str(), RTLD_NOLOAD | RTLD_LAZY);
8585 }
8686 if (g_svml_handle) return dlsym(g_svml_handle, name);
8787 return nullptr;
@@ -109,15 +109,15 @@ inline bool cpu_has_avx512f() {
// Defines `double name##_svml_f64(double x)`.
//
// The generated function looks up `svml_sym` through resolve_svml() on every
// call (no static cache). When the symbol is found, x is broadcast to all
// lanes of a __m512d, the vector routine is called, and lane 0 is returned.
// When the lookup fails, the call falls back to std::name(x).
// `npy_sym` is accepted for signature compatibility but not used by the body.
#define NUMPY_SVML_F64(name, svml_sym, npy_sym)                               \
    __attribute__((target("avx512f")))                                        \
    inline double name##_svml_f64(double x) {                                 \
        using svml_fn_t = __m512d (*)(__m512d);                               \
        const auto fn = reinterpret_cast<svml_fn_t>(resolve_svml(svml_sym));  \
        if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x)));               \
        return std::name(x); /* fallback if SVML resolution fails */          \
    }
116116
// Defines `float name##_svml_f32(float x)`.
//
// The generated function looks up `svml_sym` through resolve_svml() on every
// call (no static cache). When the symbol is found, x is broadcast to all
// lanes of a __m512, the vector routine is called, and lane 0 is returned.
// When the lookup fails, the call falls back to std::name(x).
// `npy_sym` is accepted for signature compatibility but not used by the body.
#define NUMPY_SVML_F32(name, svml_sym, npy_sym)                               \
    __attribute__((target("avx512f")))                                        \
    inline float name##_svml_f32(float x) {                                   \
        using svml_fn_t = __m512 (*)(__m512);                                 \
        const auto fn = reinterpret_cast<svml_fn_t>(resolve_svml(svml_sym));  \
        if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x)));               \
        return std::name(x);                                                  \
    }
@@ -153,33 +153,33 @@ NUMPY_SVML_F32(log1p, "__svml_log1pf16","npy_log1pf")
153153// SVML 实现位级一致(npy_pow / npy_atan2 是标量 libm 回退,会差 1 ULP)。
154154__attribute__((target(" avx512f" )))
155155inline double pow_svml_f64(double x, double e) {
156- static auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml(" __svml_pow8" );
156+ auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml(" __svml_pow8" );
157157 if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x), _mm512_set1_pd(e)));
158- static auto scalar_fn = (double (*)(double, double))resolve_svml(" npy_pow" );
158+ auto scalar_fn = (double (*)(double, double))resolve_svml(" npy_pow" );
159159 if (scalar_fn) return scalar_fn(x, e);
160160 return std::pow(x, e);
161161}
162162__attribute__((target(" avx512f" )))
163163inline float pow_svml_f32(float x, float e) {
164- static auto fn = (__m512 (*)(__m512, __m512))resolve_svml(" __svml_powf16" );
164+ auto fn = (__m512 (*)(__m512, __m512))resolve_svml(" __svml_powf16" );
165165 if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x), _mm512_set1_ps(e)));
166- static auto scalar_fn = (float (*)(float, float))resolve_svml(" npy_powf" );
166+ auto scalar_fn = (float (*)(float, float))resolve_svml(" npy_powf" );
167167 if (scalar_fn) return scalar_fn(x, e);
168168 return std::pow(x, e);
169169}
170170__attribute__((target(" avx512f" )))
171171inline double atan2_svml_f64(double y, double x) {
172- static auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml(" __svml_atan28" );
172+ auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml(" __svml_atan28" );
173173 if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(y), _mm512_set1_pd(x)));
174- static auto scalar_fn = (double (*)(double, double))resolve_svml(" npy_atan2" );
174+ auto scalar_fn = (double (*)(double, double))resolve_svml(" npy_atan2" );
175175 if (scalar_fn) return scalar_fn(y, x);
176176 return std::atan2(y, x);
177177}
178178__attribute__((target(" avx512f" )))
179179inline float atan2_svml_f32(float y, float x) {
180- static auto fn = (__m512 (*)(__m512, __m512))resolve_svml(" __svml_atan2f16" );
180+ auto fn = (__m512 (*)(__m512, __m512))resolve_svml(" __svml_atan2f16" );
181181 if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(y), _mm512_set1_ps(x)));
182- static auto scalar_fn = (float (*)(float, float))resolve_svml(" npy_atan2f" );
182+ auto scalar_fn = (float (*)(float, float))resolve_svml(" npy_atan2f" );
183183 if (scalar_fn) return scalar_fn(y, x);
184184 return std::atan2(y, x);
185185}
@@ -196,14 +196,14 @@ inline float atan2_svml_f32(float y, float x) {
196196
// Defines `double name##_npy_f64(double x)`.
//
// The generated function resolves the scalar symbol "npy_<name>" through
// resolve_svml() on every call (no static cache) and calls it when found;
// otherwise it evaluates `fallback_expr`, which may refer to `x`.
#define NUMPY_NPY_F64(name, fallback_expr)                                       \
    inline double name##_npy_f64(double x) {                                     \
        using npy_fn_t = double (*)(double);                                     \
        const auto fn = reinterpret_cast<npy_fn_t>(resolve_svml("npy_" #name));  \
        if (fn) return fn(x);                                                    \
        return (fallback_expr);                                                  \
    }
203203
// Defines `float name##_npy_f32(float x)`.
//
// The generated function resolves the scalar symbol "npy_<name>f" through
// resolve_svml() on every call (no static cache) and calls it when found;
// otherwise it evaluates `fallback_expr`, which may refer to `x`.
#define NUMPY_NPY_F32(name, fallback_expr)                                           \
    inline float name##_npy_f32(float x) {                                           \
        using npy_fn_t = float (*)(float);                                           \
        const auto fn = reinterpret_cast<npy_fn_t>(resolve_svml("npy_" #name "f"));  \
        if (fn) return fn(x);                                                        \
        return (fallback_expr);                                                      \
    }
@@ -247,22 +247,22 @@ inline double hypot_f64(double x, double y) { return std::hypot(x, y); }
/// Single-precision hypot; forwards directly to std::hypot without any
/// runtime symbol lookup.
inline float hypot_f32(float x, float y) { return std::hypot(x, y); }
248248
249249inline double pow_npy_f64(double x, double e) {
250- static auto fn = (double (*)(double, double))resolve_svml(" npy_pow" );
250+ auto fn = (double (*)(double, double))resolve_svml(" npy_pow" );
251251 if (fn) return fn(x, e);
252252 return std::pow(x, e);
253253}
254254inline float pow_npy_f32(float x, float e) {
255- static auto fn = (float (*)(float, float))resolve_svml(" npy_powf" );
255+ auto fn = (float (*)(float, float))resolve_svml(" npy_powf" );
256256 if (fn) return fn(x, e);
257257 return std::pow(x, e);
258258}
259259inline double atan2_npy_f64(double y, double x) {
260- static auto fn = (double (*)(double, double))resolve_svml(" npy_atan2" );
260+ auto fn = (double (*)(double, double))resolve_svml(" npy_atan2" );
261261 if (fn) return fn(y, x);
262262 return std::atan2(y, x);
263263}
264264inline float atan2_npy_f32(float y, float x) {
265- static auto fn = (float (*)(float, float))resolve_svml(" npy_atan2f" );
265+ auto fn = (float (*)(float, float))resolve_svml(" npy_atan2f" );
266266 if (fn) return fn(y, x);
267267 return std::atan2(y, x);
268268}
@@ -347,27 +347,15 @@ inline float cos_f32(float x) { return cos_npy_f32(x); }
347347
348348// pow / atan2 dispatchers
349349inline double pow_f64(double x, double e) {
350- #ifdef __AVX512F__
351- if (cpu_has_avx512f()) return pow_svml_f64(x, e);
352- #endif
353- return pow_npy_f64(x, e);
350+ return pow_npy_f64(x, e); // npy_pow 已验证位级一致
354351}
355352inline float pow_f32(float x, float e) {
356- #ifdef __AVX512F__
357- if (cpu_has_avx512f()) return pow_svml_f32(x, e);
358- #endif
359353 return pow_npy_f32(x, e);
360354}
361355inline double atan2_f64(double y, double x) {
362- #ifdef __AVX512F__
363- if (cpu_has_avx512f()) return atan2_svml_f64(y, x);
364- #endif
365- return atan2_npy_f64(y, x);
356+ return atan2_npy_f64(y, x); // npy_atan2 已验证位级一致,SVML broadcast 有 1 ULP 差
366357}
367358inline float atan2_f32(float y, float x) {
368- #ifdef __AVX512F__
369- if (cpu_has_avx512f()) return atan2_svml_f32(y, x);
370- #endif
371359 return atan2_npy_f32(y, x);
372360}
373361
@@ -379,60 +367,37 @@ inline float sqrt_f32(float x) { return std::sqrt(x); }
379367#undef DISPATCH_F32
380368
381369// ============================================================================
382- // Template dispatchers — svml_impl<T> + free function templates
370+ // 1-arg dispatchers — inline overloads, 零 static, 零 template struct
383371// ============================================================================
384372
385- #define NUMPY_SVML_METHODS(T, suff) \
386- template<> struct svml_impl<T> { \
387- static T exp(T x) { return exp_##suff(x); } \
388- static T log(T x) { return log_##suff(x); } \
389- static T sin(T x) { return sin_##suff(x); } \
390- static T cos(T x) { return cos_##suff(x); } \
391- static T tan(T x) { return tan_##suff(x); } \
392- static T asin(T x) { return asin_##suff(x); } \
393- static T acos(T x) { return acos_##suff(x); } \
394- static T atan(T x) { return atan_##suff(x); } \
395- static T log10(T x){ return log10_##suff(x); } \
396- static T log2(T x) { return log2_##suff(x); } \
397- static T exp2(T x) { return exp2_##suff(x); } \
398- static T cbrt(T x) { return cbrt_##suff(x); } \
399- static T expm1(T x){ return expm1_##suff(x); } \
400- static T log1p(T x){ return log1p_##suff(x); } \
401- static T sqrt(T x) { return sqrt_##suff(x); } \
402- static T pow(T x, T e) { return pow_##suff(x, e); } \
403- static T atan2(T y, T x) { return atan2_##suff(y, x); } \
404- static T hypot(T x, T y) { return hypot_##suff(x, y); } \
405- };
406-
407- template<typename T> struct svml_impl;
408- NUMPY_SVML_METHODS(double, f64)
409- NUMPY_SVML_METHODS(float, f32)
410- #undef NUMPY_SVML_METHODS
411-
412- // 1-arg dispatchers
413- #define NUMPY_SVML_D1(name) \
414- template<typename T> inline T name(T x) { return svml_impl<T>::name(x); }
415- NUMPY_SVML_D1(exp)
416- NUMPY_SVML_D1(log)
417- NUMPY_SVML_D1(sin)
418- NUMPY_SVML_D1(cos)
419- NUMPY_SVML_D1(tan)
420- NUMPY_SVML_D1(asin)
421- NUMPY_SVML_D1(acos)
422- NUMPY_SVML_D1(atan)
423- NUMPY_SVML_D1(log10)
424- NUMPY_SVML_D1(log2)
425- NUMPY_SVML_D1(exp2)
426- NUMPY_SVML_D1(cbrt)
427- NUMPY_SVML_D1(expm1)
428- NUMPY_SVML_D1(log1p)
429- NUMPY_SVML_D1(sqrt)
430- #undef NUMPY_SVML_D1
373+ #define NUMPY_D1(name) \
374+ inline double name(double x) { return name##_f64(x); } \
375+ inline float name(float x) { return name##_f32(x); }
376+
377+ NUMPY_D1(exp)
378+ NUMPY_D1(log)
379+ NUMPY_D1(sin)
380+ NUMPY_D1(cos)
381+ NUMPY_D1(tan)
382+ NUMPY_D1(asin)
383+ NUMPY_D1(acos)
384+ NUMPY_D1(atan)
385+ NUMPY_D1(log10)
386+ NUMPY_D1(log2)
387+ NUMPY_D1(exp2)
388+ NUMPY_D1(cbrt)
389+ NUMPY_D1(expm1)
390+ NUMPY_D1(log1p)
391+ NUMPY_D1(sqrt)
392+ #undef NUMPY_D1
431393
432394// 2-arg dispatchers
433- template<typename T> inline T pow(T x, T e) { return svml_impl<T>::pow(x, e); }
434- template<typename T> inline T atan2(T y, T x) { return svml_impl<T>::atan2(y, x); }
435- template<typename T> inline T hypot(T x, T y) { return svml_impl<T>::hypot(x, y); }
395+ inline double pow(double x, double e) { return pow_f64(x, e); }
396+ inline float pow(float x, float e) { return pow_f32(x, e); }
397+ inline double atan2(double y, double x) { return atan2_f64(y, x); }
398+ inline float atan2(float y, float x) { return atan2_f32(y, x); }
399+ inline double hypot(double x, double y) { return hypot_f64(x, y); }
400+ inline float hypot(float x, float y) { return hypot_f32(x, y); }
436401
437402} // namespace detail
438403} // namespace numpy
0 commit comments