Commit 41fefb4

add intel subgroup matrix extensions (#907)
1 parent ea6e7d5 commit 41fefb4

2 files changed

Lines changed: 532 additions & 0 deletions

Lines changed: 318 additions & 0 deletions
@@ -0,0 +1,318 @@
:data-uri:
:sectanchors:
:icons: font
:source-highlighter: coderay
// TODO: try rouge?

= cl_intel_subgroup_matrix_multiply_accumulate

== Name Strings

`cl_intel_subgroup_matrix_multiply_accumulate`

== Contact

Ben Ashbaugh, Intel (ben 'dot' ashbaugh 'at' intel 'dot' com)

== Contributors

// spell-checker: disable
Ben Ashbaugh, Intel +
Eugene Chereshnev, Intel +
Junjie Gu, Intel +
Bartosz Koscielak, Intel +
Mike MacPherson, Intel +
Ritesh Patel, Intel +
Lukasz Towarek, Intel
// spell-checker: enable

== Notice

Copyright (c) 2022-2023 Intel Corporation. All rights reserved.

== Status

Complete

== Version

Built On: {docdate} +
Revision: 1.0.0

== Dependencies

This extension is written against the OpenCL 3.0 C Language specification, V3.0.10.

This extension requires support for subgroups.

This extension depends on `cl_intel_required_subgroup_size` to query the subgroup sizes supported by a device or to require a subgroup size for a kernel.

== Overview

The goal of this extension is to allow programmers to access specialized hardware to compute the product of an M x K matrix with a K x N matrix and then add an M x N matrix accumulation value.
This is a commonly used building block to compute the product of two large matrices.
When used in an OpenCL kernel, all work items in the subgroup cooperate to perform this operation.

This is a low-level extension for expert programmers seeking to access this functionality directly in custom kernels.
Most users will access this functionality via high-level libraries or frameworks.

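As a point of reference, the operation described above can be sketched on the host in plain C.
This is only an illustrative scalar model, not part of the extension: the function name `matrix_mad_ref_i8`, the row-major layout, and the choice of signed 8-bit inputs with a 32-bit accumulator are assumptions made for the example.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Reference model only: result = a * b + acc for row-major matrices.
 * a is M x K, b is K x N, acc and result are M x N.
 * Elements are signed 8-bit integers accumulated as 32-bit integers. */
static void matrix_mad_ref_i8(size_t M, size_t N, size_t K,
                              const int8_t *a, const int8_t *b,
                              const int32_t *acc, int32_t *result)
{
    assert(a && b && acc && result);
    for (size_t m = 0; m < M; ++m) {
        for (size_t n = 0; n < N; ++n) {
            int32_t sum = acc[m * N + n];
            for (size_t k = 0; k < K; ++k) {
                sum += (int32_t)a[m * K + k] * (int32_t)b[k * N + n];
            }
            result[m * N + n] = sum;
        }
    }
}
```

A tiled product of two large matrices would call such a routine once per K-sized slice, feeding each result back in as the next `acc`.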

== New API Functions

None.

== New API Enums

None.

== New OpenCL C Functions

[source]
----
// These functions are available to devices where the minimum subgroup
// size is 8.  For these devices, the subgroup size must be 8 (the
// minimum supported subgroup size).  Calling these functions on other
// devices or from kernels with a different subgroup size is undefined
// behavior:

// 8-bit matrices:
int     intel_sub_group_i8_i8_matrix_mad_k32(int   a, int8  b, int  acc);   // M = 1
int2    intel_sub_group_i8_i8_matrix_mad_k32(int2  a, int8  b, int2 acc);   // M = 2
int4    intel_sub_group_i8_i8_matrix_mad_k32(int4  a, int8  b, int4 acc);   // M = 4
int8    intel_sub_group_i8_i8_matrix_mad_k32(int8  a, int8  b, int8 acc);   // M = 8

int     intel_sub_group_i8_u8_matrix_mad_k32(int   a, uint8 b, int  acc);   // ...
int2    intel_sub_group_i8_u8_matrix_mad_k32(int2  a, uint8 b, int2 acc);
int4    intel_sub_group_i8_u8_matrix_mad_k32(int4  a, uint8 b, int4 acc);
int8    intel_sub_group_i8_u8_matrix_mad_k32(int8  a, uint8 b, int8 acc);

int     intel_sub_group_u8_i8_matrix_mad_k32(uint  a, int8  b, int  acc);
int2    intel_sub_group_u8_i8_matrix_mad_k32(uint2 a, int8  b, int2 acc);
int4    intel_sub_group_u8_i8_matrix_mad_k32(uint4 a, int8  b, int4 acc);
int8    intel_sub_group_u8_i8_matrix_mad_k32(uint8 a, int8  b, int8 acc);

int     intel_sub_group_u8_u8_matrix_mad_k32(uint  a, uint8 b, int  acc);
int2    intel_sub_group_u8_u8_matrix_mad_k32(uint2 a, uint8 b, int2 acc);
int4    intel_sub_group_u8_u8_matrix_mad_k32(uint4 a, uint8 b, int4 acc);
int8    intel_sub_group_u8_u8_matrix_mad_k32(uint8 a, uint8 b, int8 acc);

// bfloat16 matrices:
float   intel_sub_group_bf16_bf16_matrix_mad_k16(int  a, int8 b, float  acc);
float2  intel_sub_group_bf16_bf16_matrix_mad_k16(int2 a, int8 b, float2 acc);
float4  intel_sub_group_bf16_bf16_matrix_mad_k16(int4 a, int8 b, float4 acc);
float8  intel_sub_group_bf16_bf16_matrix_mad_k16(int8 a, int8 b, float8 acc);

// fp16 matrices:
float   intel_sub_group_f16_f16_matrix_mad_k16(int  a, int8 b, float  acc);
float2  intel_sub_group_f16_f16_matrix_mad_k16(int2 a, int8 b, float2 acc);
float4  intel_sub_group_f16_f16_matrix_mad_k16(int4 a, int8 b, float4 acc);
float8  intel_sub_group_f16_f16_matrix_mad_k16(int8 a, int8 b, float8 acc);

// These functions are available to devices where the minimum subgroup
// size is 16.  For these devices, the subgroup size must be 16 (the
// minimum supported subgroup size).  Calling these functions on other
// devices or from kernels with a different subgroup size is undefined
// behavior:

// 8-bit matrices:
int     intel_sub_group_i8_i8_matrix_mad_k32(short   a, int8  b, int  acc);   // M = 1
int2    intel_sub_group_i8_i8_matrix_mad_k32(short2  a, int8  b, int2 acc);   // M = 2
int4    intel_sub_group_i8_i8_matrix_mad_k32(short4  a, int8  b, int4 acc);   // M = 4
int8    intel_sub_group_i8_i8_matrix_mad_k32(short8  a, int8  b, int8 acc);   // M = 8

int     intel_sub_group_i8_u8_matrix_mad_k32(short   a, uint8 b, int  acc);   // ...
int2    intel_sub_group_i8_u8_matrix_mad_k32(short2  a, uint8 b, int2 acc);
int4    intel_sub_group_i8_u8_matrix_mad_k32(short4  a, uint8 b, int4 acc);
int8    intel_sub_group_i8_u8_matrix_mad_k32(short8  a, uint8 b, int8 acc);

int     intel_sub_group_u8_i8_matrix_mad_k32(ushort  a, int8  b, int  acc);
int2    intel_sub_group_u8_i8_matrix_mad_k32(ushort2 a, int8  b, int2 acc);
int4    intel_sub_group_u8_i8_matrix_mad_k32(ushort4 a, int8  b, int4 acc);
int8    intel_sub_group_u8_i8_matrix_mad_k32(ushort8 a, int8  b, int8 acc);

int     intel_sub_group_u8_u8_matrix_mad_k32(ushort  a, uint8 b, int  acc);
int2    intel_sub_group_u8_u8_matrix_mad_k32(ushort2 a, uint8 b, int2 acc);
int4    intel_sub_group_u8_u8_matrix_mad_k32(ushort4 a, uint8 b, int4 acc);
int8    intel_sub_group_u8_u8_matrix_mad_k32(ushort8 a, uint8 b, int8 acc);

// bfloat16 matrices:
float   intel_sub_group_bf16_bf16_matrix_mad_k16(short  a, int8 b, float  acc);
float2  intel_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, float2 acc);
float4  intel_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, float4 acc);
float8  intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc);

// fp16 matrices:
float   intel_sub_group_f16_f16_matrix_mad_k16(short  a, int8 b, float  acc);
float2  intel_sub_group_f16_f16_matrix_mad_k16(short2 a, int8 b, float2 acc);
float4  intel_sub_group_f16_f16_matrix_mad_k16(short4 a, int8 b, float4 acc);
float8  intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, float8 acc);
----

== Modifications to the OpenCL C Specification

=== Add a new Section 6.13.X - Subgroup Matrix Multiply Accumulate Instructions

This section describes a family of built-in functions that multiply two matrix sources `a` and `b` and then add a matrix accumulation value to produce a matrix result value.
`a` is the first matrix operand and has M rows and K columns.
`b` is the second matrix operand and has K rows and N columns.
`acc` is the matrix accumulation value and has M rows and N columns.
The result value also has M rows and N columns.
All work items in the subgroup cooperate to perform this operation.
These functions must be encountered by all work items in the subgroup executing the kernel.

The dimensions of the two source matrices and the elements of each source matrix are described by the built-in function name and its arguments.

As an example, given the function:

[source]
----
int2 intel_sub_group_u8_i8_matrix_mad_k32(uint2 a, int8 b, int2 acc);
----

* `a` is the first source matrix operand and has `M` rows and `K` columns.
** The value for `M` is determined by the number of vector components in the source operand `a`.
In the example above, `a` is a `uint2` argument, therefore the matrix `a` operand has `M` equal to 2 rows.
** The value of `K` is described by the function name.
In this case, the value of `K` is 32, therefore the matrix `a` operand has `K` equal to 32 columns.
** The matrix component data type is also described by the function name.
In this case, the matrix `a` component data type is `u8`, indicating that the elements of the matrix `a` operand are unsigned 8-bit integers.
** Each work item contributes part of this matrix.
In this case, since the elements of the matrix `a` are 8-bit integers, and since each work item is contributing 32 bits (the size of a `uint`) of data per row of this matrix, each work item is contributing four 8-bit integer values per row.
** Since `K` is 32, and each work item is contributing four 8-bit values per row, the number of work items in the subgroup must be equal to 8.

* `b` is the second source matrix operand and has `K` rows and `N` columns.
** Each work item contributes one column of this matrix.
Therefore, the number of columns `N` is equivalent to the subgroup size.
** As above, the value of `K` is described by the function name.
In this case, the value of `K` is 32, therefore the matrix `b` operand has `K` equal to 32 rows.
** As above, the matrix component data type is described by the function name.
In this case, the matrix `b` component data type is `i8`, indicating that the elements of the matrix `b` operand are signed 8-bit integers.
** Since `K` is 32 and the elements of the matrix `b` are 8-bit integers, each work item must contribute 256 bits of source data to contribute `K` values.
The 256 bits of source data are packed and passed as the `int8` argument `b`.

* `acc` specifies the accumulation value and has `M` rows and `N` columns.
** As above, the value of `M` is determined by the number of components in the source operand `acc`.
In the example above, `acc` is an `int2` argument, therefore the accumulation value operand has `M` equal to 2 rows.
** Since both `a` and `acc` specify operands with `M` rows, and since the value of `M` is determined by the number of components in the source operand, both the `a` and `acc` operands will be vector operands with the same number of components.
** As above, each work item contributes one column of accumulation values.
Therefore, the number of columns `N` is equivalent to the subgroup size.
** The `acc` operand is a "full precision" accumulation value.
In the example above, the matrices contain integer data, therefore the `acc` operand is a vector of `int` data.

* The result value returned by the function also has `M` rows and `N` columns.
** As above, the value of `M` is determined by the number of components in the return type.
In the example above, the return type is `int2`, therefore the result value has `M` equal to 2 rows.
** Since the result value, `a`, and `acc` all specify values with `M` rows, and since the value of `M` is determined by the number of components in the source operand or return type, the return type, `a`, and `acc` will all be vectors with the same number of components.
** As above, each work item will receive one column of result values.
Therefore, the number of columns `N` is equivalent to the subgroup size.
** Similar to the `acc` operand, the return value is a "full precision" result value.
In the example above, the matrices contain integer data, therefore the return type is a vector of `int` data.

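The operand layout described above can be sketched as host-side packing code in plain C.
This is a minimal illustration for the `u8` x `i8`, `K` = 32, subgroup size 8 example only.
The helper names `pack_a_row_slice_u8` and `pack_b_column_i8` are hypothetical, and the byte order (lowest `K` index in the least significant byte) is an assumption inferred from the use of `as_char4` in the coding sample on a little-endian device, not a normative statement.

```c
#include <assert.h>
#include <stdint.h>

enum { SUBGROUP_SIZE = 8, K = 32, VALUES_PER_LANE = K / SUBGROUP_SIZE };

/* Pack the four u8 elements of one row of matrix a that belong to `lane`
 * into one 32-bit value.  Byte j holds a_row[4 * lane + j]
 * (assumed order: lowest K index in the least significant byte). */
static uint32_t pack_a_row_slice_u8(const uint8_t a_row[K], int lane)
{
    assert(lane >= 0 && lane < SUBGROUP_SIZE);
    uint32_t packed = 0;
    for (int j = 0; j < VALUES_PER_LANE; ++j) {
        packed |= (uint32_t)a_row[VALUES_PER_LANE * lane + j] << (8 * j);
    }
    return packed;
}

/* Pack one column of matrix b (K signed 8-bit values, 256 bits in total)
 * into eight 32-bit words, standing in for components s0..s7 of the
 * int8 argument b. */
static void pack_b_column_i8(const int8_t b_col[K], uint32_t packed[8])
{
    for (int w = 0; w < 8; ++w) {
        packed[w] = 0;
        for (int j = 0; j < 4; ++j) {
            packed[w] |= (uint32_t)(uint8_t)b_col[4 * w + j] << (8 * j);
        }
    }
}
```

With `M` = 2, each work item would hold two such packed slices of `a` (one per row) in its `uint2` argument and one packed column of `b` in its `int8` argument.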

The full list of supported functions is given in the New OpenCL C Functions section, above.
For this list of functions:

* `M` may be equal to 1, 2, 4, or 8.
* `N` must be equal to 8 for some devices or 16 for other devices.
In other words, the only supported subgroup sizes are 8 and 16.
* Supported integer matrix types for `a` and `b` are any combination of signed or unsigned 8-bit integers.
For these integer matrix types, the accumulation value `acc` and result value are signed 32-bit integers, and `K` must be equal to 32.
* The supported floating-point matrix types for `a` and `b` are fp16 (half) or bfloat16.
For these floating-point matrix types, the accumulation value `acc` and result value are 32-bit floating-point values, and `K` must be equal to 16.

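The constraints above also explain the argument types in the function list: the per-row data for `a` is split evenly across the subgroup, while each work item supplies a whole column of `b`.
The arithmetic can be sketched in plain C as follows; the helper names are hypothetical and exist only for this illustration.

```c
#include <assert.h>

/* Bits of matrix a data that each work item supplies per row:
 * K elements of element_bits each, split evenly across the subgroup. */
static int a_bits_per_work_item(int k, int element_bits, int subgroup_size)
{
    assert(subgroup_size == 8 || subgroup_size == 16);
    assert((k * element_bits) % subgroup_size == 0);
    return k * element_bits / subgroup_size;
}

/* Bits of matrix b data per work item: one full column of K elements. */
static int b_bits_per_work_item(int k, int element_bits)
{
    return k * element_bits;
}
```

For 8-bit data with `K` = 32, this gives 32 bits per row for a subgroup size of 8 (an `int` or `uint` component) and 16 bits for a subgroup size of 16 (a `short` or `ushort` component).
The same values result for 16-bit data with `K` = 16, and in every case the `b` column occupies 256 bits, the size of an `int8` or `uint8` vector.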

== Coding Sample

[source]
----
// The code below shows a functional implementation of one of the
// built-in functions added by this extension.  For this built-in
// function:
//  * M = 2, since the result value, a operand, and acc operand
//    are all vectors with two components.
//  * N = 8, and is equal to the subgroup size.
//  * K = 32, as described by the function name.
//  * The elements of both matrix a and matrix b are signed 8-bit
//    integers.

// This is a helper function that performs the dot product of
// two vectors of four components of 8-bit integer data, and then
// adds a 32-bit integer accumulation value.
static int __intel_dot_product_accumulate( char4 a, char4 b, int acc )
{
    return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w + acc;
}

// This is a helper function that computes the product of a
// 1 x 32 row vector value shared across the subgroup and a 32 x 1
// column vector, that is added to a full precision accumulation
// value.
static int __intel_vector_matrix_multiply_accumulate_k32( int v, int8 b, int acc )
{
    // Note: 8 is the size of the subgroup.
    // As K is 32, and the size of the subgroup is 8, each
    // work item contributes 4 elements of the 1 x K vector.
    // as_char4() is used to reinterpret 32-bits of data
    // as four components of 8-bit data.

    int result = acc;

    result = __intel_dot_product_accumulate(
        as_char4( sub_group_broadcast( v, 0 ) ), as_char4( b.s0 ), result );
    result = __intel_dot_product_accumulate(
        as_char4( sub_group_broadcast( v, 1 ) ), as_char4( b.s1 ), result );
    result = __intel_dot_product_accumulate(
        as_char4( sub_group_broadcast( v, 2 ) ), as_char4( b.s2 ), result );
    result = __intel_dot_product_accumulate(
        as_char4( sub_group_broadcast( v, 3 ) ), as_char4( b.s3 ), result );

    result = __intel_dot_product_accumulate(
        as_char4( sub_group_broadcast( v, 4 ) ), as_char4( b.s4 ), result );
    result = __intel_dot_product_accumulate(
        as_char4( sub_group_broadcast( v, 5 ) ), as_char4( b.s5 ), result );
    result = __intel_dot_product_accumulate(
        as_char4( sub_group_broadcast( v, 6 ) ), as_char4( b.s6 ), result );
    result = __intel_dot_product_accumulate(
        as_char4( sub_group_broadcast( v, 7 ) ), as_char4( b.s7 ), result );

    return result;
}

int2 intel_sub_group_i8_i8_matrix_mad_k32(int2 a, int8 b, int2 acc)
{
    int2 result;

    result.x = __intel_vector_matrix_multiply_accumulate_k32( a.x, b, acc.x );
    result.y = __intel_vector_matrix_multiply_accumulate_k32( a.y, b, acc.y );

    return result;
}
----

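For testing on the host, the coding sample above can be mirrored by a plain C emulation that loops over all eight work items instead of relying on `sub_group_broadcast`.
This is only a sketch under stated assumptions: the names `byte_as_i8`, `dot4_i8`, and `emulate_i8_i8_mad_k32_m2` are hypothetical, the flat array layout is chosen for the example, and byte `j` of each packed word is assumed to hold the element with the lowest remaining `K` index first.

```c
#include <assert.h>
#include <stdint.h>

enum { LANES = 8, ROWS = 2, WORDS = 8 };  /* subgroup size, M, K / 4 */

/* Sign-extend byte j (0 = least significant) of a packed 32-bit word. */
static int32_t byte_as_i8(uint32_t word, int j)
{
    int32_t v = (int32_t)((word >> (8 * j)) & 0xffu);
    return v >= 128 ? v - 256 : v;
}

/* Dot product of four packed signed bytes, plus an accumulation value. */
static int32_t dot4_i8(uint32_t x, uint32_t y, int32_t acc)
{
    for (int j = 0; j < 4; ++j) {
        acc += byte_as_i8(x, j) * byte_as_i8(y, j);
    }
    return acc;
}

/* Emulates the M = 2, i8 x i8, K = 32 built-in for all lanes at once.
 *   a[lane * ROWS + m]   packed slice of row m held by `lane`
 *   b[lane * WORDS + w]  component s<w> of the int8 argument held by `lane`
 *   acc, result          [lane * ROWS + m], one column per lane */
static void emulate_i8_i8_mad_k32_m2(const uint32_t *a, const uint32_t *b,
                                     const int32_t *acc, int32_t *result)
{
    assert(a && b && acc && result);
    for (int lane = 0; lane < LANES; ++lane) {
        for (int m = 0; m < ROWS; ++m) {
            int32_t sum = acc[lane * ROWS + m];
            /* src stands in for sub_group_broadcast(v, src). */
            for (int src = 0; src < LANES; ++src) {
                sum = dot4_i8(a[src * ROWS + m], b[lane * WORDS + src], sum);
            }
            result[lane * ROWS + m] = sum;
        }
    }
}
```

Comparing the output of such an emulation against an unpacked matrix multiply is one way to check that data has been packed in the order the built-in expects.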

== Issues

. Should this extension use signed or unsigned types to represent fp16 and bf16 data?
+
--
`RESOLVED`: This extension will use signed types to represent fp16 and bf16 data even though this is inconsistent with other extensions such as `cl_intel_bfloat16_conversions`.
This inconsistency may be addressed in a future extension or in a future version of this extension.
Applications are encouraged to use `as_type` to reinterpret unsigned data as signed data as needed to use the functions added by this extension.
--

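The reinterpretation mentioned in the resolution changes only the type, never the bits.
A host-side analogue in plain C is sketched below; `bits_as_short` is a hypothetical helper standing in for the OpenCL C `as_short` conversion, and `memcpy` is used to copy the bit pattern unchanged.

```c
#include <stdint.h>
#include <string.h>

/* Host-side analogue of as_short(): reinterpret 16 unsigned bits as a
 * signed 16-bit value without changing the bit pattern. */
static int16_t bits_as_short(uint16_t bits)
{
    int16_t out;
    memcpy(&out, &bits, sizeof out);
    return out;
}
```

For example, a 16-bit pattern with the top bit set becomes a negative `int16_t`, and converting it back to `uint16_t` recovers the original pattern.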

== Revision History

[cols="5,15,15,70"]
[grid="rows"]
[options="header"]
|========================================
|Rev|Date|Author|Changes
|1.0.0|2022-05-18|Ben Ashbaugh|*Initial public revision*
|========================================

//************************************************************************
//Other formatting suggestions:
//
//* Use *bold* text for host APIs, or [source] syntax highlighting.
//* Use `mono` text for device APIs, or [source] syntax highlighting.
//* Use `mono` text for extension names, types, or enum values.
//* Use _italics_ for parameters.
//************************************************************************
