|
| 1 | +:data-uri: |
| 2 | +:sectanchors: |
| 3 | +:icons: font |
| 4 | +:source-highlighter: coderay |
| 5 | +// TODO: try rouge? |
| 6 | + |
| 7 | += cl_intel_subgroup_matrix_multiply_accumulate |
| 8 | + |
| 9 | +== Name Strings |
| 10 | + |
| 11 | +`cl_intel_subgroup_matrix_multiply_accumulate` |
| 12 | + |
| 13 | +== Contact |
| 14 | + |
| 15 | +Ben Ashbaugh, Intel (ben 'dot' ashbaugh 'at' intel 'dot' com) |
| 16 | + |
| 17 | +== Contributors |
| 18 | + |
| 19 | +// spell-checker: disable |
| 20 | +Ben Ashbaugh, Intel + |
| 21 | +Eugene Chereshnev, Intel + |
| 22 | +Junjie Gu, Intel + |
| 23 | +Bartosz Koscielak, Intel + |
| 24 | +Mike MacPherson, Intel + |
| 25 | +Ritesh Patel, Intel + |
| 26 | +Lukasz Towarek, Intel |
| 27 | +// spell-checker: enable |
| 28 | + |
| 29 | +== Notice |
| 30 | + |
| 31 | +Copyright (c) 2022-2023 Intel Corporation. All rights reserved. |
| 32 | + |
| 33 | +== Status |
| 34 | + |
| 35 | +Complete |
| 36 | + |
| 37 | +== Version |
| 38 | + |
| 39 | +Built On: {docdate} + |
| 40 | +Revision: 1.0.0 |
| 41 | + |
| 42 | +== Dependencies |
| 43 | + |
| 44 | +This extension is written against the OpenCL 3.0 C Language specification, V3.0.10. |
| 45 | + |
| 46 | +This extension requires support for subgroups. |
| 47 | + |
| 48 | +This extension depends on `cl_intel_required_subgroup_size` to query the subgroup sizes supported by a device or to require a subgroup size for a kernel. |
| 49 | + |
| 50 | +== Overview |
| 51 | + |
| 52 | +The goal of this extension is to allow programmers to access specialized hardware to compute the product of an M x K matrix with a K x N matrix and then add an M x N matrix accumulation value. |
| 53 | +This is a commonly used building block to compute the product of two large matrices. |
| 54 | +When used in an OpenCL kernel, all work items in the subgroup cooperate to perform this operation. |
| 55 | + |
| 56 | +This is a low-level extension for expert programmers seeking to access this functionality directly in custom kernels. |
| 57 | +Most users will access this functionality via high-level libraries or frameworks. |
| 58 | + |
| 59 | +== New API Functions |
| 60 | + |
| 61 | +None. |
| 62 | + |
| 63 | +== New API Enums |
| 64 | + |
| 65 | +None. |
| 66 | + |
| 67 | +== New OpenCL C Functions |
| 68 | + |
| 69 | +[source] |
| 70 | +---- |
| 71 | +// These functions are available to devices where the minimum subgroup |
| 72 | +// size is 8. For these devices, the subgroup size must be 8 (the |
| 73 | +// minimum supported subgroup size). Calling these functions on other |
| 74 | +// devices or from kernels with a different subgroup size is undefined |
| 75 | +// behavior: |
| 76 | +
|
| 77 | +// 8-bit matrices: |
| 78 | +int intel_sub_group_i8_i8_matrix_mad_k32(int a, int8 b, int acc); // M = 1 |
| 79 | +int2 intel_sub_group_i8_i8_matrix_mad_k32(int2 a, int8 b, int2 acc); // M = 2 |
| 80 | +int4 intel_sub_group_i8_i8_matrix_mad_k32(int4 a, int8 b, int4 acc); // M = 4 |
| 81 | +int8 intel_sub_group_i8_i8_matrix_mad_k32(int8 a, int8 b, int8 acc); // M = 8 |
| 82 | +
|
| 83 | +int intel_sub_group_i8_u8_matrix_mad_k32(int a, uint8 b, int acc); // ... |
| 84 | +int2 intel_sub_group_i8_u8_matrix_mad_k32(int2 a, uint8 b, int2 acc); |
| 85 | +int4 intel_sub_group_i8_u8_matrix_mad_k32(int4 a, uint8 b, int4 acc); |
| 86 | +int8 intel_sub_group_i8_u8_matrix_mad_k32(int8 a, uint8 b, int8 acc); |
| 87 | +
|
| 88 | +int intel_sub_group_u8_i8_matrix_mad_k32(uint a, int8 b, int acc); |
| 89 | +int2 intel_sub_group_u8_i8_matrix_mad_k32(uint2 a, int8 b, int2 acc); |
| 90 | +int4 intel_sub_group_u8_i8_matrix_mad_k32(uint4 a, int8 b, int4 acc); |
| 91 | +int8 intel_sub_group_u8_i8_matrix_mad_k32(uint8 a, int8 b, int8 acc); |
| 92 | +
|
| 93 | +int intel_sub_group_u8_u8_matrix_mad_k32(uint a, uint8 b, int acc); |
| 94 | +int2 intel_sub_group_u8_u8_matrix_mad_k32(uint2 a, uint8 b, int2 acc); |
| 95 | +int4 intel_sub_group_u8_u8_matrix_mad_k32(uint4 a, uint8 b, int4 acc); |
| 96 | +int8 intel_sub_group_u8_u8_matrix_mad_k32(uint8 a, uint8 b, int8 acc); |
| 97 | +
|
| 98 | +// bfloat16 matrices: |
| 99 | +float intel_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc); |
| 100 | +float2 intel_sub_group_bf16_bf16_matrix_mad_k16(int2 a, int8 b, float2 acc); |
| 101 | +float4 intel_sub_group_bf16_bf16_matrix_mad_k16(int4 a, int8 b, float4 acc); |
| 102 | +float8 intel_sub_group_bf16_bf16_matrix_mad_k16(int8 a, int8 b, float8 acc); |
| 103 | +
|
| 104 | +// fp16 matrices: |
| 105 | +float intel_sub_group_f16_f16_matrix_mad_k16(int a, int8 b, float acc); |
| 106 | +float2 intel_sub_group_f16_f16_matrix_mad_k16(int2 a, int8 b, float2 acc); |
| 107 | +float4 intel_sub_group_f16_f16_matrix_mad_k16(int4 a, int8 b, float4 acc); |
| 108 | +float8 intel_sub_group_f16_f16_matrix_mad_k16(int8 a, int8 b, float8 acc); |
| 109 | +
|
| 110 | +// These functions are available to devices where the minimum subgroup |
| 111 | +// size is 16. For these devices, the subgroup size must be 16 (the |
| 112 | +// minimum supported subgroup size). Calling these functions on other |
| 113 | +// devices or from kernels with a different subgroup size is undefined |
| 114 | +// behavior: |
| 115 | +
|
| 116 | +// 8-bit matrices: |
| 117 | +int intel_sub_group_i8_i8_matrix_mad_k32(short a, int8 b, int acc); // M = 1 |
| 118 | +int2 intel_sub_group_i8_i8_matrix_mad_k32(short2 a, int8 b, int2 acc); // M = 2 |
| 119 | +int4 intel_sub_group_i8_i8_matrix_mad_k32(short4 a, int8 b, int4 acc); // M = 4 |
| 120 | +int8 intel_sub_group_i8_i8_matrix_mad_k32(short8 a, int8 b, int8 acc); // M = 8 |
| 121 | +
|
| 122 | +int intel_sub_group_i8_u8_matrix_mad_k32(short a, uint8 b, int acc); // ... |
| 123 | +int2 intel_sub_group_i8_u8_matrix_mad_k32(short2 a, uint8 b, int2 acc); |
| 124 | +int4 intel_sub_group_i8_u8_matrix_mad_k32(short4 a, uint8 b, int4 acc); |
| 125 | +int8 intel_sub_group_i8_u8_matrix_mad_k32(short8 a, uint8 b, int8 acc); |
| 126 | +
|
| 127 | +int intel_sub_group_u8_i8_matrix_mad_k32(ushort a, int8 b, int acc); |
| 128 | +int2 intel_sub_group_u8_i8_matrix_mad_k32(ushort2 a, int8 b, int2 acc); |
| 129 | +int4 intel_sub_group_u8_i8_matrix_mad_k32(ushort4 a, int8 b, int4 acc); |
| 130 | +int8 intel_sub_group_u8_i8_matrix_mad_k32(ushort8 a, int8 b, int8 acc); |
| 131 | +
|
| 132 | +int intel_sub_group_u8_u8_matrix_mad_k32(ushort a, uint8 b, int acc); |
| 133 | +int2 intel_sub_group_u8_u8_matrix_mad_k32(ushort2 a, uint8 b, int2 acc); |
| 134 | +int4 intel_sub_group_u8_u8_matrix_mad_k32(ushort4 a, uint8 b, int4 acc); |
| 135 | +int8 intel_sub_group_u8_u8_matrix_mad_k32(ushort8 a, uint8 b, int8 acc); |
| 136 | +
|
| 137 | +// bfloat16 matrices: |
| 138 | +float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc); |
| 139 | +float2 intel_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, float2 acc); |
| 140 | +float4 intel_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, float4 acc); |
| 141 | +float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc); |
| 142 | +
|
| 143 | +// fp16 matrices: |
| 144 | +float intel_sub_group_f16_f16_matrix_mad_k16(short a, int8 b, float acc); |
| 145 | +float2 intel_sub_group_f16_f16_matrix_mad_k16(short2 a, int8 b, float2 acc); |
| 146 | +float4 intel_sub_group_f16_f16_matrix_mad_k16(short4 a, int8 b, float4 acc); |
| 147 | +float8 intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, float8 acc); |
| 148 | +---- |
| 149 | + |
| 150 | +== Modifications to the OpenCL C Specification |
| 151 | + |
| 152 | +=== Add a new Section 6.13.X - Subgroup Matrix Multiply Accumulate Instructions |
| 153 | + |
| 154 | +This section describes a family of built-in functions that multiply two matrix sources `a` and `b` and then add a matrix accumulation value to produce a matrix result value. |
| 155 | +`a` is the first matrix operand and has M rows and K columns. |
| 156 | +`b` is the second matrix operand and has K rows and N columns. |
| 157 | +`acc` is the matrix accumulation value and has M rows and N columns. |
| 158 | +The result value also has M rows and N columns. |
| 159 | +All work items in the subgroup cooperate to perform this operation. |
| 160 | +These functions must be encountered by all work items in the subgroup executing the kernel. |
| 161 | + |
| 162 | +The dimensions of the two source matrices and the elements of each source matrix are described by the built-in function name and its arguments. |
| 163 | + |
| 164 | +As an example, given the function: |
| 165 | + |
| 166 | +[source] |
| 167 | +---- |
| 168 | +int2 intel_sub_group_u8_i8_matrix_mad_k32(uint2 a, int8 b, int2 acc); |
| 169 | +---- |
| 170 | + |
| 171 | +* `a` is the first source matrix operand and has `M` rows and `K` columns. |
| 172 | +** The value for `M` is determined by the number of vector components in the source operand `a`. |
| 173 | +In the example above, `a` is a `uint2` argument, therefore the matrix `a` operand has `M` equal to 2 rows. |
| 174 | +** The value of `K` is described by the function name. |
| 175 | +In this case, the value of `K` is 32, therefore the matrix `a` operand has `K` equal to 32 columns. |
| 176 | +** The matrix component data type is also described by the function name. |
| 177 | +In this case, the matrix `a` component data type is `u8`, indicating that the elements of the matrix `a` operand are unsigned 8-bit integers. |
| 178 | +** Each work item contributes part of this matrix. |
| 179 | +In this case, since the elements of the matrix `a` are 8-bit integers, and since each work item is contributing 32 bits (the size of a `uint`) of data per row of this matrix, each work item is contributing four 8-bit integer values per row. |
| 180 | +** Since `K` is 32, and each work item is contributing four 8-bit values per row, the number of work items in the subgroup must be equal to 8. |
| 181 | + |
| 182 | +* `b` is the second source matrix operand and has `K` rows and `N` columns. |
| 183 | +** Each work item contributes one column of this matrix. |
| 184 | +Therefore, the number of columns `N` is equivalent to the subgroup size. |
| 185 | +** As above, the value of `K` is described by the function name. |
| 186 | +In this case, the value of `K` is 32, therefore the matrix `b` operand has `K` equal to 32 rows. |
| 187 | +** As above, the matrix component data type is described by the function name. |
| 188 | +In this case, the matrix `b` component data type is `i8`, indicating that the elements of the matrix `b` operand are signed 8-bit integers. |
| 189 | +** Since `K` is 32 and the elements of the matrix `b` are 8-bit integers, each work item must contribute 256 bits of source data to contribute `K` values. |
| 190 | +The 256 bits of source data are packed and passed as the `int8` argument `b`. |
| 191 | + |
| 192 | +* `acc` specifies the accumulation value and has `M` rows and `N` columns. |
| 193 | +** As above, the value of `M` is determined by the number of components in the source operand `acc`. |
| 194 | +In the example above, `acc` is an `int2` argument, therefore the accumulation value operand has `M` equal to 2 rows. |
| 195 | +** Since both `a` and `acc` specify operands with `M` rows, and since the value of `M` is determined by the number of components in the source operand, both the `a` and `acc` operands will be vector operands with the same number of components. |
| 196 | +** As above, each work item contributes one column of accumulation values. |
| 197 | +Therefore, the number of columns `N` is equivalent to the subgroup size. |
| 198 | +** The `acc` operand is a "full precision" accumulation value. |
| 199 | +In the example above, the matrices contain integer data, therefore the `acc` operand is a vector of `int` data. |
| 200 | + |
| 201 | +* The result value returned by the function also has `M` rows and `N` columns. |
| 202 | +** As above, the value of `M` is determined by the number of components in the return type. |
| 203 | +In the example above, the return type is `int2`, therefore the result value has `M` equal to 2 rows. |
| 204 | +** Since the result value, `a`, and `acc` all specify values with `M` rows, and since the value of `M` is determined by the number of components in the source operand or return type, the return type, `a`, and `acc` will all be vectors with the same number of components.
| 205 | +** As above, each work item will receive one column of result values. |
| 206 | +Therefore, the number of columns `N` is equivalent to the subgroup size. |
| 207 | +** Similar to the `acc` operand, the return value is a "full precision" result value. |
| 208 | +In the example above, the matrices contain integer data, therefore the return type is a vector of `int` data. |
| 209 | + |
| 210 | +The full list of supported functions is described in the overview, above. |
| 211 | +For this list of functions: |
| 212 | + |
| 213 | +* `M` may be equal to 1, 2, 4, or 8. |
| 214 | +* `N` must be equal to 8 for some devices or 16 for other devices. |
| 215 | +In other words, the only supported subgroup sizes are 8 and 16. |
| 216 | +* Supported integer matrix types for `a` and `b` are any combination of signed or unsigned 8-bit integers. |
| 217 | +For these integer matrix types, the accumulation value `acc` and result value are signed 32-bit integers, and `K` must be equal to 32. |
| 218 | +* The supported floating-point matrix types for `a` and `b` are fp16 (half) or bfloat16. |
| 219 | +For these floating-point matrix types, the accumulation value `acc` and result value are 32-bit floating-point values, and `K` must be equal to 16.
| 220 | + |
| 221 | +== Coding Sample |
| 222 | + |
| 223 | +[source] |
| 224 | +---- |
| 225 | +// The code below shows a functional implementation of one of the |
| 226 | +// built-in functions added by this extension. For this built-in |
| 227 | +// function: |
| 228 | +// * M = 2, since the result value, a operand, and acc operand |
| 229 | +// are all vectors with two components. |
| 230 | +// * N = 8, and is equal to the subgroup size. |
| 231 | +// * K = 32, as described by the function name. |
| 232 | +// * The elements of both matrix a and matrix b are signed 8-bit |
| 233 | +// integers. |
| 234 | +
|
| 235 | +// This is a helper function that performs the dot product of |
| 236 | +// two vectors of four components of 8-bit integer data, and then |
| 237 | +// adds a 32-bit integer accumulation value. |
| 238 | +static int __intel_dot_product_accumulate( char4 a, char4 b, int acc ) |
| 239 | +{ |
| 240 | + return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w + acc; |
| 241 | +} |
| 242 | +
|
| 243 | +// This is a helper function that computes the product of a |
| 244 | +// 1 x 32 row vector value shared across the subgroup and a 32 x 1 |
| 245 | +// column vector, that is added to a full precision accumulation |
| 246 | +// value. |
| 247 | +static int __intel_vector_matrix_multiply_accumulate_k32( int v, int8 b, int acc ) |
| 248 | +{ |
| 249 | + // Note: 8 is the size of the subgroup. |
| 250 | + // As K is 32, and the size of the subgroup is 8, each |
| 251 | + // work item contributes 4 elements of the 1 x K vector. |
| 252 | + // as_char4() is used to reinterpret 32-bits of data |
| 253 | + // as four components of 8-bit data. |
| 254 | +
|
| 255 | + int result = acc; |
| 256 | +
|
| 257 | + result = __intel_dot_product_accumulate( |
| 258 | + as_char4( sub_group_broadcast( v, 0 ) ), as_char4( b.s0 ), result ); |
| 259 | + result = __intel_dot_product_accumulate( |
| 260 | + as_char4( sub_group_broadcast( v, 1 ) ), as_char4( b.s1 ), result ); |
| 261 | + result = __intel_dot_product_accumulate( |
| 262 | + as_char4( sub_group_broadcast( v, 2 ) ), as_char4( b.s2 ), result ); |
| 263 | + result = __intel_dot_product_accumulate( |
| 264 | + as_char4( sub_group_broadcast( v, 3 ) ), as_char4( b.s3 ), result ); |
| 265 | +
|
| 266 | + result = __intel_dot_product_accumulate( |
| 267 | + as_char4( sub_group_broadcast( v, 4 ) ), as_char4( b.s4 ), result ); |
| 268 | + result = __intel_dot_product_accumulate( |
| 269 | + as_char4( sub_group_broadcast( v, 5 ) ), as_char4( b.s5 ), result ); |
| 270 | + result = __intel_dot_product_accumulate( |
| 271 | + as_char4( sub_group_broadcast( v, 6 ) ), as_char4( b.s6 ), result ); |
| 272 | + result = __intel_dot_product_accumulate( |
| 273 | + as_char4( sub_group_broadcast( v, 7 ) ), as_char4( b.s7 ), result ); |
| 274 | +
|
| 275 | + return result; |
| 276 | +} |
| 277 | +
|
| 278 | +int2 intel_sub_group_i8_i8_matrix_mad_k32(int2 a, int8 b, int2 acc) |
| 279 | +{ |
| 280 | + int2 result; |
| 281 | +
|
| 282 | + result.x = __intel_vector_matrix_multiply_accumulate_k32( a.x, b, acc.x ); |
| 283 | + result.y = __intel_vector_matrix_multiply_accumulate_k32( a.y, b, acc.y ); |
| 284 | +
|
| 285 | + return result; |
| 286 | +} |
| 287 | +---- |
| 288 | + |
| 289 | +== Issues |
| 290 | + |
| 291 | +
| 292 | + |
| 293 | +. Should this extension use signed or unsigned types to represent fp16 and bf16 data? |
| 294 | ++ |
| 295 | +-- |
| 296 | +`RESOLVED`: This extension will use signed types to represent fp16 and bf16 data even though this is inconsistent with other extensions such as `cl_intel_bfloat16_conversions`.
| 297 | +This inconsistency may be addressed in a future extension or in a future version of this extension. |
| 298 | +Applications are encouraged to use `as_type` to reinterpret unsigned data as signed data as needed to use the functions added by this extension. |
| 299 | +-- |
| 300 | + |
| 301 | +== Revision History |
| 302 | + |
| 303 | +[cols="5,15,15,70"] |
| 304 | +[grid="rows"] |
| 305 | +[options="header"] |
| 306 | +|======================================== |
| 307 | +|Rev|Date|Author|Changes |
| 308 | +|1.0.0|2022-05-18|Ben Ashbaugh|*Initial public revision* |
| 309 | +|======================================== |
| 310 | + |
| 311 | +//************************************************************************ |
| 312 | +//Other formatting suggestions: |
| 313 | +// |
| 314 | +//* Use *bold* text for host APIs, or [source] syntax highlighting. |
| 315 | +//* Use `mono` text for device APIs, or [source] syntax highlighting. |
| 316 | +//* Use `mono` text for extension names, types, or enum values. |
| 317 | +//* Use _italics_ for parameters. |
| 318 | +//************************************************************************ |
0 commit comments