Skip to content

Commit 53a25f6

Browse files
committed
Add conversion function from float to half on the host.
1 parent 9c2e317 commit 53a25f6

1 file changed

Lines changed: 84 additions & 0 deletions

File tree

src/gpuarray/util.h

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,90 @@ GPUARRAY_PUBLIC void gpuarray_elemwise_collapse(unsigned int n,
9898
unsigned int *nd,
9999
size_t *dim, ssize_t **strs);
100100

101+
102+
typedef uint16_t ga_half_t;
103+
104+
/* code strongly inspired from
105+
https://github.com/numpy/numpy/blob/master/numpy/core/src/npymath/halffloat.c#L246 */
106+
107+
static inline ga_half_t ga_float2half(float f) {
108+
union {
109+
float f;
110+
uint32_t bits;
111+
} bf;
112+
union {
113+
ga_half_t h;
114+
uint16_t bits;
115+
} bh;
116+
117+
uint32_t f_exp, f_sig;
118+
uint16_t h_sgn, h_exp, h_sig;
119+
120+
bf.f = f;
121+
122+
h_sgn = (bf.bits&0x80000000u) >> 16;
123+
f_exp = (bf.bits&0x7f800000u);
124+
125+
/* Exponent overflow/NaN converts to signed inf/NaN */
126+
if (f_exp >= 0x47800000u) {
127+
if (f_exp == 0x7f800000u) {
128+
/* Inf or NaN */
129+
f_sig = (bf.bits&0x007fffffu);
130+
if (f_sig != 0) {
131+
/* NaN - propagate the flag in the significand... */
132+
bh.bits = (uint16_t) (0x7c00u + (f_sig >> 13));
133+
/* ...but make sure it stays a NaN */
134+
if (bh.bits == 0x7c00u) {
135+
bh.bits++;
136+
}
137+
bh.bits += h_sgn;
138+
return bh.h;
139+
} else {
140+
/* signed inf */
141+
bh.bits = h_sgn + 0x7c00u;
142+
return bh.h;
143+
}
144+
} else {
145+
bh.bits = h_sgn + 0x7c00u;
146+
return bh.h;
147+
}
148+
}
149+
150+
if (f_exp <= 0x38000000u) {
151+
/*
152+
* Signed zeros, subnormal floats, and floats with small
153+
* exponents all convert to signed zero halfs.
154+
*/
155+
if (f_exp < 0x33000000u) {
156+
bh.bits = h_sgn;
157+
return bh.h;
158+
}
159+
/* Make the subnormal significand */
160+
f_exp >>= 23;
161+
f_sig = (0x00800000u + (bf.bits&0x007fffffu));
162+
f_sig >>= (113 - f_exp);
163+
/* Handle rounding by adding 1 to the bit beyond half precision */
164+
f_sig += 0x00001000u;
165+
h_sig = (uint16_t) (f_sig >> 13);
166+
/*
167+
* If the rounding causes a bit to spill into h_exp, it will
168+
* increment h_exp from zero to one and h_sig will be zero.
169+
* This is the correct result.
170+
*/
171+
bh.bits = h_sgn + h_sig;
172+
return bh.h;
173+
}
174+
175+
/* Regular case with no overflow or underflow */
176+
h_exp = (uint16_t) ((f_exp - 0x38000000u) >> 13);
177+
/* Handle rounding by adding 1 to the bit beyond half precision */
178+
f_sig = (bf.bits&0x007fffffu);
179+
f_sig += 0x00001000u;
180+
h_sig = (uint16_t) (f_sig >> 13);
181+
bh.bits = h_sgn + h_exp + h_sig;
182+
return bh.h;
183+
}
184+
101185
#ifdef __cplusplus
102186
}
103187
#endif

0 commit comments

Comments
 (0)