Skip to content

Commit e035630

Browse files
author
peng.li24
committed
refactor(numpycpp): update svml_bridge.h
1 parent b3100f8 commit e035630

1 file changed

Lines changed: 65 additions & 100 deletions

File tree

numpycpp/detail/svml_bridge.h

Lines changed: 65 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -42,46 +42,46 @@ namespace detail {
4242
// Internal dispatch namespace — use numpy::exp() etc., not numpy::detail::exp().
4343
// All math functions are resolved at runtime from numpy's _multiarray_umath.so.
4444
//
45-
// The shared library handle is auto-discovered on first use by scanning
46-
// /proc/self/maps — no explicit bridge_init() call needed.
45+
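// Zero-static design: every call resolves the symbol directly via dlsym, so that after a
46+
// Python multiprocessing fork the child does not inherit an invalid handle or function pointers from the parent.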
4747
48-
inline void* g_svml_handle = nullptr;
48+
#include <unistd.h> // getpid
4949
50-
// Auto-discover numpy's _multiarray_umath shared library path via /proc/self/maps.
51-
// Called lazily from resolve_svml() on first use.
52-
inline const char* find_umath_path() {
53-
static std::string path;
54-
static bool tried = false;
55-
if (tried) return path.empty() ? nullptr : path.c_str();
56-
tried = true;
50+
// Shared-library handle; nullptr until resolve_svml() assigns the result of dlopen() to it.
inline void* g_svml_handle = nullptr;
51+
inline pid_t g_svml_pid = 0; // fork detection: re-initialize when the pid changes
5752
53+
/// Returns the path of _multiarray_umath.so. Rescans /proc/self/maps on every call,
54+
/// with no static cache — fork-safe.
/// Returns an empty string when no line mentions both "_multiarray_umath" and ".so".
/// NOTE(review): the path is cut at the last space that precedes the last '/', so a
/// mapped path containing a space would come back truncated — confirm this cannot occur.
55+
inline std::string find_umath_path() {
5856
std::ifstream maps("/proc/self/maps");
5957
std::string line;
6058
while (std::getline(maps, line)) {
6159
if (line.find("_multiarray_umath") != std::string::npos &&
6260
line.find(".so") != std::string::npos) {
6361
auto pos = line.rfind('/');
6462
auto start = line.rfind(' ', pos);
65-
if (start != std::string::npos && pos != std::string::npos) {
66-
path = line.substr(start + 1);
67-
break;
68-
}
63+
if (start != std::string::npos && pos != std::string::npos)
64+
return line.substr(start + 1);
6965
}
7066
}
71-
return path.empty() ? nullptr : path.c_str();
67+
return "";
7268
}
7369
74-
// DEPRECATED — kept for backward compatibility with code that still calls it.
75-
// resolve_svml() now auto-discovers the .so path; explicit init is unnecessary.
7670
// No-op: numpy_so_path is ignored (it is only cast to void to silence unused-parameter warnings).
inline void bridge_init(const char* numpy_so_path) {
7771
(void)numpy_so_path;
7872
}
7973
74+
/// Resolves a symbol. Fork-safe: re-dlopens if the pid has changed.
8075
inline void* resolve_svml(const char* name) {
81-
// Lazy init: auto-discover numpy's shared library on first call
76+
pid_t pid = getpid();
77+
if (pid != g_svml_pid) {
78+
g_svml_handle = nullptr; // fork 后父进程 handle 在子进程无效
79+
g_svml_pid = pid;
80+
}
8281
if (!g_svml_handle) {
83-
const char* path = find_umath_path();
84-
if (path) g_svml_handle = dlopen(path, RTLD_NOLOAD | RTLD_LAZY);
82+
std::string path = find_umath_path();
83+
if (!path.empty())
84+
g_svml_handle = dlopen(path.c_str(), RTLD_NOLOAD | RTLD_LAZY);
8585
}
8686
if (g_svml_handle) return dlsym(g_svml_handle, name);
8787
return nullptr;
@@ -109,15 +109,15 @@ inline bool cpu_has_avx512f() {
109109
// Defines name##_svml_f64(double x): looks up svml_sym through resolve_svml(), and if found
// broadcasts x into a 512-bit vector, calls the resolved function and returns the lowest lane.
// Falls back to std::name(x) when the lookup returns null. npy_sym is not used in the expansion.
// In the added ("+") line the function pointer is no longer static, so the lookup runs on every call.
#define NUMPY_SVML_F64(name, svml_sym, npy_sym) \
110110
__attribute__((target("avx512f"))) \
111111
inline double name##_svml_f64(double x) { \
112-
static auto fn = (__m512d (*)(__m512d))resolve_svml(svml_sym); \
112+
auto fn = (__m512d (*)(__m512d))resolve_svml(svml_sym); \
113113
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x))); \
114114
return std::name(x); /* fallback if SVML resolution fails */ \
115115
}
116116
117117
// Defines name##_svml_f32(float x): looks up svml_sym through resolve_svml(), and if found
// broadcasts x into a 512-bit vector, calls the resolved function and returns the lowest lane.
// Falls back to std::name(x) when the lookup returns null. npy_sym is not used in the expansion.
#define NUMPY_SVML_F32(name, svml_sym, npy_sym) \
118118
__attribute__((target("avx512f"))) \
119119
inline float name##_svml_f32(float x) { \
120-
static auto fn = (__m512 (*)(__m512))resolve_svml(svml_sym); \
120+
auto fn = (__m512 (*)(__m512))resolve_svml(svml_sym); \
121121
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x))); \
122122
return std::name(x); \
123123
}
@@ -153,33 +153,33 @@ NUMPY_SVML_F32(log1p, "__svml_log1pf16","npy_log1pf")
153153
// SVML implementation is bit-identical (npy_pow / npy_atan2 are scalar libm fallbacks and differ by 1 ULP).
154154
// pow for double. Tries, in order: "__svml_pow8" (x and e broadcast to 512-bit vectors,
// lowest lane returned), then the scalar "npy_pow", then std::pow. Each symbol is looked up
// through resolve_svml(); a null result moves on to the next option.
__attribute__((target("avx512f")))
155155
inline double pow_svml_f64(double x, double e) {
156-
static auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml("__svml_pow8");
156+
auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml("__svml_pow8");
157157
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(x), _mm512_set1_pd(e)));
158-
static auto scalar_fn = (double (*)(double, double))resolve_svml("npy_pow");
158+
auto scalar_fn = (double (*)(double, double))resolve_svml("npy_pow");
159159
if (scalar_fn) return scalar_fn(x, e);
160160
return std::pow(x, e);
161161
}
162162
// pow for float. Tries, in order: "__svml_powf16" (x and e broadcast to 512-bit vectors,
// lowest lane returned), then the scalar "npy_powf", then std::pow.
__attribute__((target("avx512f")))
163163
inline float pow_svml_f32(float x, float e) {
164-
static auto fn = (__m512 (*)(__m512, __m512))resolve_svml("__svml_powf16");
164+
auto fn = (__m512 (*)(__m512, __m512))resolve_svml("__svml_powf16");
165165
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(x), _mm512_set1_ps(e)));
166-
static auto scalar_fn = (float (*)(float, float))resolve_svml("npy_powf");
166+
auto scalar_fn = (float (*)(float, float))resolve_svml("npy_powf");
167167
if (scalar_fn) return scalar_fn(x, e);
168168
return std::pow(x, e);
169169
}
170170
// atan2 for double. Tries, in order: "__svml_atan28" (y and x broadcast to 512-bit vectors,
// lowest lane returned), then the scalar "npy_atan2", then std::atan2.
__attribute__((target("avx512f")))
171171
inline double atan2_svml_f64(double y, double x) {
172-
static auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml("__svml_atan28");
172+
auto fn = (__m512d (*)(__m512d, __m512d))resolve_svml("__svml_atan28");
173173
if (fn) return _mm512_cvtsd_f64(fn(_mm512_set1_pd(y), _mm512_set1_pd(x)));
174-
static auto scalar_fn = (double (*)(double, double))resolve_svml("npy_atan2");
174+
auto scalar_fn = (double (*)(double, double))resolve_svml("npy_atan2");
175175
if (scalar_fn) return scalar_fn(y, x);
176176
return std::atan2(y, x);
177177
}
178178
// atan2 for float. Tries, in order: "__svml_atan2f16" (y and x broadcast to 512-bit vectors,
// lowest lane returned), then the scalar "npy_atan2f", then std::atan2.
__attribute__((target("avx512f")))
179179
inline float atan2_svml_f32(float y, float x) {
180-
static auto fn = (__m512 (*)(__m512, __m512))resolve_svml("__svml_atan2f16");
180+
auto fn = (__m512 (*)(__m512, __m512))resolve_svml("__svml_atan2f16");
181181
if (fn) return _mm512_cvtss_f32(fn(_mm512_set1_ps(y), _mm512_set1_ps(x)));
182-
static auto scalar_fn = (float (*)(float, float))resolve_svml("npy_atan2f");
182+
auto scalar_fn = (float (*)(float, float))resolve_svml("npy_atan2f");
183183
if (scalar_fn) return scalar_fn(y, x);
184184
return std::atan2(y, x);
185185
}
@@ -196,14 +196,14 @@ inline float atan2_svml_f32(float y, float x) {
196196
197197
// Defines name##_npy_f64(double x): looks up the symbol "npy_<name>" through resolve_svml()
// and calls it; evaluates fallback_expr instead when the lookup returns null.
#define NUMPY_NPY_F64(name, fallback_expr) \
198198
inline double name##_npy_f64(double x) { \
199-
static auto fn = (double (*)(double))resolve_svml("npy_" #name); \
199+
auto fn = (double (*)(double))resolve_svml("npy_" #name); \
200200
if (fn) return fn(x); \
201201
return (fallback_expr); \
202202
}
203203
204204
// Defines name##_npy_f32(float x): looks up the symbol "npy_<name>f" through resolve_svml()
// and calls it; evaluates fallback_expr instead when the lookup returns null.
#define NUMPY_NPY_F32(name, fallback_expr) \
205205
inline float name##_npy_f32(float x) { \
206-
static auto fn = (float (*)(float))resolve_svml("npy_" #name "f"); \
206+
auto fn = (float (*)(float))resolve_svml("npy_" #name "f"); \
207207
if (fn) return fn(x); \
208208
return (fallback_expr); \
209209
}
@@ -247,22 +247,22 @@ inline double hypot_f64(double x, double y) { return std::hypot(x, y); }
247247
// hypot for float: forwards directly to std::hypot (no runtime symbol lookup).
inline float hypot_f32(float x, float y) { return std::hypot(x, y); }
248248
249249
// pow for double via the "npy_pow" symbol looked up through resolve_svml();
// uses std::pow when the lookup returns null.
inline double pow_npy_f64(double x, double e) {
250-
static auto fn = (double (*)(double, double))resolve_svml("npy_pow");
250+
auto fn = (double (*)(double, double))resolve_svml("npy_pow");
251251
if (fn) return fn(x, e);
252252
return std::pow(x, e);
253253
}
254254
// pow for float via the "npy_powf" symbol looked up through resolve_svml();
// uses std::pow when the lookup returns null.
inline float pow_npy_f32(float x, float e) {
255-
static auto fn = (float (*)(float, float))resolve_svml("npy_powf");
255+
auto fn = (float (*)(float, float))resolve_svml("npy_powf");
256256
if (fn) return fn(x, e);
257257
return std::pow(x, e);
258258
}
259259
// atan2 for double via the "npy_atan2" symbol looked up through resolve_svml();
// uses std::atan2 when the lookup returns null.
inline double atan2_npy_f64(double y, double x) {
260-
static auto fn = (double (*)(double, double))resolve_svml("npy_atan2");
260+
auto fn = (double (*)(double, double))resolve_svml("npy_atan2");
261261
if (fn) return fn(y, x);
262262
return std::atan2(y, x);
263263
}
264264
// atan2 for float via the "npy_atan2f" symbol looked up through resolve_svml();
// uses std::atan2 when the lookup returns null.
inline float atan2_npy_f32(float y, float x) {
265-
static auto fn = (float (*)(float, float))resolve_svml("npy_atan2f");
265+
auto fn = (float (*)(float, float))resolve_svml("npy_atan2f");
266266
if (fn) return fn(y, x);
267267
return std::atan2(y, x);
268268
}
@@ -347,27 +347,15 @@ inline float cos_f32(float x) { return cos_npy_f32(x); }
347347
348348
// pow / atan2 dispatchers
349349
// pow dispatcher for double. The removed ("-") lines routed to pow_svml_f64 when
// cpu_has_avx512f() was true; the added ("+") line always calls pow_npy_f64.
// NOTE(review): the comment above pow_svml_f64 in this file says npy_pow differs from the
// SVML result by 1 ULP, while the comment below says it is bit-identical — confirm which holds.
inline double pow_f64(double x, double e) {
350-
#ifdef __AVX512F__
351-
if (cpu_has_avx512f()) return pow_svml_f64(x, e);
352-
#endif
353-
return pow_npy_f64(x, e);
350+
return pow_npy_f64(x, e); // npy_pow verified to be bit-identical
354351
}
355352
// pow dispatcher for float. With the AVX-512 branch removed ("-" lines), it always
// calls pow_npy_f32.
inline float pow_f32(float x, float e) {
356-
#ifdef __AVX512F__
357-
if (cpu_has_avx512f()) return pow_svml_f32(x, e);
358-
#endif
359353
return pow_npy_f32(x, e);
360354
}
361355
// atan2 dispatcher for double. The removed ("-") lines routed to atan2_svml_f64 when
// cpu_has_avx512f() was true; the added ("+") line always calls atan2_npy_f64.
inline double atan2_f64(double y, double x) {
362-
#ifdef __AVX512F__
363-
if (cpu_has_avx512f()) return atan2_svml_f64(y, x);
364-
#endif
365-
return atan2_npy_f64(y, x);
356+
return atan2_npy_f64(y, x); // npy_atan2 verified to be bit-identical; the SVML broadcast path differs by 1 ULP
366357
}
367358
// atan2 dispatcher for float. With the AVX-512 branch removed ("-" lines), it always
// calls atan2_npy_f32.
inline float atan2_f32(float y, float x) {
368-
#ifdef __AVX512F__
369-
if (cpu_has_avx512f()) return atan2_svml_f32(y, x);
370-
#endif
371359
return atan2_npy_f32(y, x);
372360
}
373361
@@ -379,60 +367,37 @@ inline float sqrt_f32(float x) { return std::sqrt(x); }
379367
#undef DISPATCH_F32
380368
381369
// ============================================================================
382-
// Template dispatchers — svml_impl<T> + free function templates
370+
// 1-arg dispatchers — inline overloads, no statics, no template struct
383371
// ============================================================================
384372
385-
#define NUMPY_SVML_METHODS(T, suff) \
386-
template<> struct svml_impl<T> { \
387-
static T exp(T x) { return exp_##suff(x); } \
388-
static T log(T x) { return log_##suff(x); } \
389-
static T sin(T x) { return sin_##suff(x); } \
390-
static T cos(T x) { return cos_##suff(x); } \
391-
static T tan(T x) { return tan_##suff(x); } \
392-
static T asin(T x) { return asin_##suff(x); } \
393-
static T acos(T x) { return acos_##suff(x); } \
394-
static T atan(T x) { return atan_##suff(x); } \
395-
static T log10(T x){ return log10_##suff(x); } \
396-
static T log2(T x) { return log2_##suff(x); } \
397-
static T exp2(T x) { return exp2_##suff(x); } \
398-
static T cbrt(T x) { return cbrt_##suff(x); } \
399-
static T expm1(T x){ return expm1_##suff(x); } \
400-
static T log1p(T x){ return log1p_##suff(x); } \
401-
static T sqrt(T x) { return sqrt_##suff(x); } \
402-
static T pow(T x, T e) { return pow_##suff(x, e); } \
403-
static T atan2(T y, T x) { return atan2_##suff(y, x); } \
404-
static T hypot(T x, T y) { return hypot_##suff(x, y); } \
405-
};
406-
407-
template<typename T> struct svml_impl;
408-
NUMPY_SVML_METHODS(double, f64)
409-
NUMPY_SVML_METHODS(float, f32)
410-
#undef NUMPY_SVML_METHODS
411-
412-
// 1-arg dispatchers
413-
#define NUMPY_SVML_D1(name) \
414-
template<typename T> inline T name(T x) { return svml_impl<T>::name(x); }
415-
NUMPY_SVML_D1(exp)
416-
NUMPY_SVML_D1(log)
417-
NUMPY_SVML_D1(sin)
418-
NUMPY_SVML_D1(cos)
419-
NUMPY_SVML_D1(tan)
420-
NUMPY_SVML_D1(asin)
421-
NUMPY_SVML_D1(acos)
422-
NUMPY_SVML_D1(atan)
423-
NUMPY_SVML_D1(log10)
424-
NUMPY_SVML_D1(log2)
425-
NUMPY_SVML_D1(exp2)
426-
NUMPY_SVML_D1(cbrt)
427-
NUMPY_SVML_D1(expm1)
428-
NUMPY_SVML_D1(log1p)
429-
NUMPY_SVML_D1(sqrt)
430-
#undef NUMPY_SVML_D1
373+
// NUMPY_D1(name) defines two overloads of name(): the double overload forwards to
// name##_f64 and the float overload forwards to name##_f32. The macro is instantiated
// for each one-argument math function listed below and then undefined.
#define NUMPY_D1(name) \
374+
inline double name(double x) { return name##_f64(x); } \
375+
inline float name(float x) { return name##_f32(x); }
376+
377+
NUMPY_D1(exp)
378+
NUMPY_D1(log)
379+
NUMPY_D1(sin)
380+
NUMPY_D1(cos)
381+
NUMPY_D1(tan)
382+
NUMPY_D1(asin)
383+
NUMPY_D1(acos)
384+
NUMPY_D1(atan)
385+
NUMPY_D1(log10)
386+
NUMPY_D1(log2)
387+
NUMPY_D1(exp2)
388+
NUMPY_D1(cbrt)
389+
NUMPY_D1(expm1)
390+
NUMPY_D1(log1p)
391+
NUMPY_D1(sqrt)
392+
#undef NUMPY_D1
431393
432394
// 2-arg dispatchers
433-
template<typename T> inline T pow(T x, T e) { return svml_impl<T>::pow(x, e); }
434-
template<typename T> inline T atan2(T y, T x) { return svml_impl<T>::atan2(y, x); }
435-
template<typename T> inline T hypot(T x, T y) { return svml_impl<T>::hypot(x, y); }
395+
// Two-argument overloads: each forwards to the matching *_f64 / *_f32 dispatcher.
// These replace the removed svml_impl<T>-based function templates with plain overloads.
inline double pow(double x, double e) { return pow_f64(x, e); }
396+
inline float pow(float x, float e) { return pow_f32(x, e); }
397+
inline double atan2(double y, double x) { return atan2_f64(y, x); }
398+
inline float atan2(float y, float x) { return atan2_f32(y, x); }
399+
inline double hypot(double x, double y) { return hypot_f64(x, y); }
400+
inline float hypot(float x, float y) { return hypot_f32(x, y); }
436401
437402
} // namespace detail
438403
} // namespace numpy

0 commit comments

Comments
 (0)