77
88namespace op ::acos::cuda {
99
10- // ----------------------
11- // Fast acos approximation
12- // ----------------------
13- __device__ __forceinline__ float fast_acosf (float x) {
14- // 高性能多项式近似 acos(x)
15- float ax = fabsf (x);
16- float t = sqrtf (1 .0f - ax);
17- float r = ((-0 .0187293f * ax + 0 .0742610f ) * ax - 0 .2121144f ) * ax + 1 .5707288f ;
18- return (x >= 0 .0f ? t * r : 3 .14159265358979323846f - t * r);
19- }
20-
2110// ----------------------
2211// float kernel (F32)
2312// ----------------------
@@ -26,39 +15,27 @@ __device__ __forceinline__ T acos_impl(T val);
2615
2716template <>
2817__device__ __forceinline__ float acos_impl<float >(float val) {
29- return fast_acosf (val);
18+ return :: acosf (val);
3019}
3120
3221// ----------------------
3322// half kernel (F16)
3423// ----------------------
3524template <>
3625__device__ __forceinline__ half acos_impl<half>(half val) {
37- #if (__CUDA_ARCH__ >= 530)
38- float f = __half2float (val);
39- return __float2half (fast_acosf (f));
40- #else
4126 float f = __half2float (val);
42- return __float2half (fast_acosf (f));
43- #endif
27+ return __float2half (::acosf (f));
4428}
4529
4630// ----------------------
4731// half2 kernel (F16x2 vectorized)
4832// ----------------------
4933template <>
5034__device__ __forceinline__ half2 acos_impl<half2>(half2 val) {
51- #if (__CUDA_ARCH__ >= 530)
5235 float2 f = __half22float2 (val);
53- f.x = fast_acosf (f.x );
54- f.y = fast_acosf (f.y );
36+ f.x = :: acosf (f.x );
37+ f.y = :: acosf (f.y );
5538 return __float22half2_rn (f);
56- #else
57- float2 f = __half22float2 (val);
58- f.x = fast_acosf (f.x );
59- f.y = fast_acosf (f.y );
60- return __float22half2_rn (f);
61- #endif
6239}
6340
6441// ----------------------
@@ -67,15 +44,20 @@ __device__ __forceinline__ half2 acos_impl<half2>(half2 val) {
6744template <>
6845__device__ __forceinline__ cuda_bfloat16 acos_impl<cuda_bfloat16>(cuda_bfloat16 val) {
6946 float f = __bfloat162float (val);
70- return __float2bfloat16 (fast_acosf (f));
47+ return __float2bfloat16 (::acosf (f));
48+ }
49+
50+ template <>
51+ __device__ __forceinline__ double acos_impl<double >(double val) {
52+ return ::acos (val);
7153}
7254
7355// ----------------------
7456// Fallback kernel
7557// ----------------------
7658template <typename T>
7759__device__ __forceinline__ T acos_impl (T val) {
78- return static_cast <T>(fast_acosf (static_cast <float >(val)));
60+ return static_cast <T>(:: acos (static_cast <double >(val)));
7961}
8062
8163// ----------------------
0 commit comments