@@ -98,6 +98,90 @@ GPUARRAY_PUBLIC void gpuarray_elemwise_collapse(unsigned int n,
9898 unsigned int * nd ,
9999 size_t * dim , ssize_t * * strs );
100100
101+
/* Storage for the raw bit pattern of a half precision (binary16) value. */
typedef uint16_t ga_half_t;

/* Code strongly inspired by
   https://github.com/numpy/numpy/blob/master/numpy/core/src/npymath/halffloat.c#L246 */

/*
 * Convert a single precision float to the bit pattern of the nearest
 * half precision value.
 *
 * The bit masks below assume that `float` uses the IEEE 754 binary32
 * layout (1 sign bit, 8 exponent bits, 23 significand bits).
 *
 *  - NaN stays NaN (the top significand bits and the sign are kept).
 *  - +/-inf and any magnitude too large for a half become signed inf.
 *  - Magnitudes below 2^-25 become signed zero.
 *  - Everything else is rounded to nearest, with exact ties going to the
 *    value whose last significand bit is zero (round-half-to-even).
 */
static inline ga_half_t ga_float2half(float f) {
  union {
    float f;
    uint32_t bits;
  } bf;
  union {
    ga_half_t h;
    uint16_t bits;
  } bh;

  uint32_t f_exp, f_sig, shift, sticky;
  uint16_t h_sgn, h_exp, h_sig;

  bf.f = f;

  h_sgn = (uint16_t)((bf.bits & 0x80000000u) >> 16);
  f_exp = (bf.bits & 0x7f800000u);
  f_sig = (bf.bits & 0x007fffffu);

  /* Exponent overflow/NaN converts to signed inf/NaN */
  if (f_exp >= 0x47800000u) {
    if (f_exp == 0x7f800000u && f_sig != 0) {
      /* NaN - propagate the flag in the significand... */
      bh.bits = (uint16_t)(0x7c00u + (f_sig >> 13));
      /* ...but make sure it stays a NaN */
      if (bh.bits == 0x7c00u) {
        bh.bits++;
      }
      bh.bits = (uint16_t)(bh.bits + h_sgn);
      return bh.h;
    }
    /* inf, or a finite value too large for a half: signed inf */
    bh.bits = (uint16_t)(h_sgn + 0x7c00u);
    return bh.h;
  }

  /* Exponent underflow converts to a subnormal half or signed zero */
  if (f_exp <= 0x38000000u) {
    /*
     * Signed zeros, subnormal floats, and floats with small
     * exponents all convert to signed zero halfs.
     */
    if (f_exp < 0x33000000u) {
      bh.bits = h_sgn;
      return bh.h;
    }
    /*
     * Make the subnormal significand: restore the implicit leading one,
     * then shift right by 1 (largest subnormal exponent) to 11 bits.
     * Remember whether any set bit falls off so that a value slightly
     * above a tie is not mistaken for an exact tie.
     */
    shift = 113u - (f_exp >> 23);
    f_sig += 0x00800000u;
    sticky = f_sig & ((((uint32_t)1) << shift) - 1u);
    f_sig >>= shift;
    /*
     * Round to nearest by adding 1 to the bit beyond half precision,
     * except on an exact tie whose kept significand is already even.
     */
    if ((f_sig & 0x00003fffu) != 0x00001000u || sticky != 0) {
      f_sig += 0x00001000u;
    }
    h_sig = (uint16_t)(f_sig >> 13);
    /*
     * If the rounding causes a bit to spill into h_exp, it will
     * increment h_exp from zero to one and h_sig will be zero.
     * This is the correct result.
     */
    bh.bits = (uint16_t)(h_sgn + h_sig);
    return bh.h;
  }

  /* Regular case with no overflow or underflow */
  h_exp = (uint16_t)((f_exp - 0x38000000u) >> 13);
  /*
   * Round to nearest by adding 1 to the bit beyond half precision,
   * except on an exact tie whose kept significand is already even.
   */
  if ((f_sig & 0x00003fffu) != 0x00001000u) {
    f_sig += 0x00001000u;
  }
  h_sig = (uint16_t)(f_sig >> 13);
  /*
   * A carry out of the significand increments the exponent field; if
   * that reaches the all-ones exponent the result is a signed inf.
   */
  bh.bits = (uint16_t)(h_sgn + h_exp + h_sig);
  return bh.h;
}
184+
101185#ifdef __cplusplus
102186}
103187#endif
0 commit comments