Skip to content

Commit af8b5e1

Browse files
committed
update tf32 function names to be closer to the final versions
1 parent eff9d19 commit af8b5e1

4 files changed

Lines changed: 83 additions & 85 deletions

File tree

samples/99_matrixexperiments/matrix_kernels.cl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global usho
396396
float sum = 0;
397397
for (int k = 0; k < K; k += tK) {
398398
short aData = as_short(intel_sub_group_block_read_16b_1r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m)));
399-
int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
399+
int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
400400
sum = mat_mul_sg16(aData, bData, sum);
401401
}
402402

@@ -418,7 +418,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global usho
418418
float2 sum = 0;
419419
for (int k = 0; k < K; k += tK) {
420420
short2 aData = as_short2(intel_sub_group_block_read_16b_2r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m)));
421-
int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
421+
int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
422422
sum = mat_mul_sg16(aData, bData, sum);
423423
}
424424

@@ -440,7 +440,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global usho
440440
float4 sum = 0;
441441
for (int k = 0; k < K; k += tK) {
442442
short4 aData = as_short4(intel_sub_group_block_read_16b_4r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m)));
443-
int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
443+
int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
444444
sum = mat_mul_sg16(aData, bData, sum);
445445
}
446446

@@ -462,7 +462,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho
462462
float8 sum = 0;
463463
for (int k = 0; k < K; k += tK) {
464464
short8 aData = as_short8(intel_sub_group_block_read_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m)));
465-
int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
465+
int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k)));
466466
sum = mat_mul_sg16(aData, bData, sum);
467467
}
468468

samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ float8 emu_sub_group_tf32_tf32_matrix_mad_k8(float4 a, float8 b, float8 acc)
134134
}
135135

136136
// M rows x K columns
137-
float load_a_rowmajor_d32_m1_k8_sg16(global float* A, int rowStart, int colStart, int stride)
137+
float load_a_rowmajor_32b_1r8c_sg16(global float* A, int rowStart, int colStart, int stride)
138138
{
139139
float ret;
140140

@@ -148,7 +148,7 @@ float load_a_rowmajor_d32_m1_k8_sg16(global float* A, int rowStart, int colStart
148148
}
149149

150150
// M rows x K columns
151-
float load_a_rowmajor_d32_m2_k8_sg16(global float* A, int rowStart, int colStart, int stride)
151+
float load_a_rowmajor_32b_2r8c_sg16(global float* A, int rowStart, int colStart, int stride)
152152
{
153153
float ret;
154154

@@ -162,7 +162,7 @@ float load_a_rowmajor_d32_m2_k8_sg16(global float* A, int rowStart, int colStart
162162
}
163163

164164
// M rows x K columns
165-
float2 load_a_rowmajor_d32_m4_k8_sg16(global float* A, int rowStart, int colStart, int stride)
165+
float2 load_a_rowmajor_32b_4r8c_sg16(global float* A, int rowStart, int colStart, int stride)
166166
{
167167
float2 ret;
168168

@@ -177,7 +177,7 @@ float2 load_a_rowmajor_d32_m4_k8_sg16(global float* A, int rowStart, int colStar
177177
}
178178

179179
// M rows x K columns
180-
float4 load_a_rowmajor_d32_m8_k8_sg16(global float* A, int rowStart, int colStart, int stride)
180+
float4 load_a_rowmajor_32b_8r8c_sg16(global float* A, int rowStart, int colStart, int stride)
181181
{
182182
float4 ret;
183183

@@ -194,7 +194,7 @@ float4 load_a_rowmajor_d32_m8_k8_sg16(global float* A, int rowStart, int colStar
194194
}
195195

196196
// M rows x K columns x V tiles (in the M and K dimensions)
197-
void prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(global float* A, int rowStart, int colStart, int stride)
197+
void prefetch_a_rowmajor_32b_8x2r8x2c_sg16(global float* A, int rowStart, int colStart, int stride)
198198
{
199199
#if defined(PREFETCH_DEFAULT)
200200
uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
@@ -205,7 +205,7 @@ void prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(global float* A, int rowStart, int c
205205
// K rows x N columns:
206206
// Each work-item loads K values.
207207
// Stride is in units of elements.
208-
float8 load_b_rowmajor_d32_k8_nx(global float* B, int rowStart, int colStart, int stride)
208+
float8 load_b_rowmajor_32b_8rNc(global float* B, int rowStart, int colStart, int stride)
209209
{
210210
float8 ret;
211211

@@ -224,15 +224,15 @@ float8 load_b_rowmajor_d32_k8_nx(global float* B, int rowStart, int colStart, in
224224
}
225225

226226
// K rows x N columns x V tiles (in the K and N dimensions)
227-
void prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(global float* B, int rowStart, int colStart, int stride)
227+
void prefetch_b_rowmajor_32b_8x2r8x2c_sg16(global float* B, int rowStart, int colStart, int stride)
228228
{
229229
#if defined(PREFETCH_DEFAULT)
230230
uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
231231
prefetch(B + offset, 1);
232232
#endif // defined(PREFETCH_DEFAULT)
233233
}
234234

235-
void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int colStart, int stride)
235+
void store_c_rowmajor_fp32_1rNc(global float* C, float v, int rowStart, int colStart, int stride)
236236
{
237237
global uint* C_ui = (global uint*)C;
238238
uint v_ui = as_uint(v);
@@ -242,7 +242,7 @@ void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int col
242242
intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride;
243243
}
244244

245-
void store_c_rowmajor_fp32_m2_nx(global float* C, float2 v, int rowStart, int colStart, int stride)
245+
void store_c_rowmajor_fp32_2rNc(global float* C, float2 v, int rowStart, int colStart, int stride)
246246
{
247247
global uint* C_ui = (global uint*)C;
248248
uint2 v_ui = as_uint2(v);
@@ -253,7 +253,7 @@ void store_c_rowmajor_fp32_m2_nx(global float* C, float2 v, int rowStart, int co
253253
intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride;
254254
}
255255

256-
void store_c_rowmajor_fp32_m4_nx(global float* C, float4 v, int rowStart, int colStart, int stride)
256+
void store_c_rowmajor_fp32_4rNc(global float* C, float4 v, int rowStart, int colStart, int stride)
257257
{
258258
global uint* C_ui = (global uint*)C;
259259
uint4 v_ui = as_uint4(v);
@@ -266,7 +266,7 @@ void store_c_rowmajor_fp32_m4_nx(global float* C, float4 v, int rowStart, int co
266266
intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride;
267267
}
268268

269-
void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int colStart, int stride)
269+
void store_c_rowmajor_fp32_8rNc(global float* C, float8 v, int rowStart, int colStart, int stride)
270270
{
271271
global uint* C_ui = (global uint*)C;
272272
uint8 v_ui = as_uint8(v);
@@ -295,24 +295,6 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co
295295
// - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes.
296296
// - coord is the number of elements (x coord) and row (y coord) to read from. X coord must be multiple 4 for for 1B data and 2 for 2B data.
297297

298-
// Built-in functions are:
299-
300-
// #ifdef cl_intel_subgroup_extended_block_read
301-
// ushort2 intel_subgroup_block_read_u8_m1k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
302-
// ushort4 intel_subgroup_block_read_u8_m2k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
303-
// ushort8 intel_subgroup_block_read_u8_m4k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
304-
// ushort16 intel_subgroup_block_read_u8_m8k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
305-
// ushort2 intel_subgroup_block_read_u16_m1k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
306-
// ushort4 intel_subgroup_block_read_u16_m2k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
307-
// ushort8 intel_subgroup_block_read_u16_m4k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
308-
// ushort16 intel_subgroup_block_read_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
309-
// uint8 intel_subgroup_block_read_transform_u8_k32(__global void *base_address, int width, int height, int pitch, int2 coord);
310-
// uint8 intel_subgroup_block_read_transform_u16_k16(__global void *base_address, int width, int height, int pitch, int2 coord);
311-
// uint8 intel_subgroup_block_read_transpose_u32_k8(__global void *base_address, int width, int height, int pitch, int2 coord);
312-
// ulong4 intel_subgroup_block_read_transpose_u64_k4(__global void *base_address, int width, int height, int pitch, int2 coord);
313-
// #endif //defined(cl_intel_subgroup_extended_block_read)
314-
315-
316298
// For intrinsics, the pattern is:
317299
// - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat
318300
// - operation (optional): _transpose or _transform
@@ -332,7 +314,18 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co
332314
// - tile width: subgroup size (16)
333315
// - number of tiles: 1
334316

335-
// Define additional "non-vector" block read and writes. These are supported by the hardware but are not in the headers:
317+
enum LSC_LDCC {
318+
LSC_LDCC_DEFAULT = 0,
319+
LSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached
320+
LSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached
321+
LSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached
322+
LSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached
323+
LSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached
324+
LSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached
325+
LSC_LDCC_L1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached
326+
};
327+
328+
// Define block reads, prefetches, and writes. These are supported by the hardware but are not in the headers:
336329

337330
uint __builtin_IB_subgroup_block_read_flat_u32_m1k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
338331
uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
@@ -344,65 +337,70 @@ uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k16v1(long baseoffset, int wi
344337
uint4 __builtin_IB_subgroup_block_read_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
345338
uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
346339

347-
uint intel_subgroup_block_read_u32_m1k8(const __global void *base_address, int width, int height, int pitch, int2 coord)
340+
uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k8v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
341+
342+
void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data);
343+
void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data);
344+
void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data);
345+
void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data);
346+
347+
uint intel_sub_group_block_read_32b_1r8c(const __global void *base_address, int width, int height, int pitch, int2 coord)
348348
{
349349
return __builtin_IB_subgroup_block_read_flat_u32_m1k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
350350
}
351-
uint intel_subgroup_block_read_u32_m2k8(const __global void *base_address, int width, int height, int pitch, int2 coord)
351+
uint intel_sub_group_block_read_32b_2r8c(const __global void *base_address, int width, int height, int pitch, int2 coord)
352352
{
353353
return __builtin_IB_subgroup_block_read_flat_u32_m2k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo;
354354
}
355-
uint2 intel_subgroup_block_read_u32_m4k8(const __global void *base_address, int width, int height, int pitch, int2 coord)
355+
uint2 intel_sub_group_block_read_32b_4r8c(const __global void *base_address, int width, int height, int pitch, int2 coord)
356356
{
357357
return __builtin_IB_subgroup_block_read_flat_u32_m4k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo;
358358
}
359-
uint4 intel_subgroup_block_read_u32_m8k8(const __global void *base_address, int width, int height, int pitch, int2 coord)
359+
uint4 intel_sub_group_block_read_32b_8r8c(const __global void *base_address, int width, int height, int pitch, int2 coord)
360360
{
361361
return __builtin_IB_subgroup_block_read_flat_u32_m8k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo;
362362
}
363363

364-
uint intel_subgroup_block_read_u32_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
364+
uint intel_sub_group_block_read_32b_1r16c(const __global void *base_address, int width, int height, int pitch, int2 coord)
365365
{
366366
return __builtin_IB_subgroup_block_read_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
367367
}
368-
uint2 intel_subgroup_block_read_u32_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
368+
uint2 intel_sub_group_block_read_32b_2r16c(const __global void *base_address, int width, int height, int pitch, int2 coord)
369369
{
370370
return __builtin_IB_subgroup_block_read_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
371371
}
372-
uint4 intel_subgroup_block_read_u32_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
372+
uint4 intel_sub_group_block_read_32b_4r16c(const __global void *base_address, int width, int height, int pitch, int2 coord)
373373
{
374374
return __builtin_IB_subgroup_block_read_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
375375
}
376-
uint8 intel_subgroup_block_read_u32_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
376+
uint8 intel_sub_group_block_read_32b_8r16c(const __global void *base_address, int width, int height, int pitch, int2 coord)
377377
{
378378
return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
379379
}
380380

381-
uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k8v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
382-
383-
uint8 intel_subgroup_block_read_u32_m8k8v2(const __global void* base_address, int width, int height, int pitch, int2 coord)
381+
uint8 intel_sub_group_block_read_32b_8r8x2c(const __global void* base_address, int width, int height, int pitch, int2 coord)
384382
{
385383
return __builtin_IB_subgroup_block_read_flat_u32_m8k8v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
386384
}
387385

388-
void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data);
389-
void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data);
390-
void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data);
391-
void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data);
392386

393-
void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data)
387+
#if !defined(BLOCK_PREFETCH_CACHE_TYPE)
388+
#define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C
389+
#endif
390+
391+
void intel_sub_group_block_write_32b_1r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint data)
394392
{
395393
__builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
396394
}
397-
void intel_subgroup_block_write_u32_m2k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data)
395+
void intel_sub_group_block_write_32b_2r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data)
398396
{
399397
__builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
400398
}
401-
void intel_subgroup_block_write_u32_m4k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data)
399+
void intel_sub_group_block_write_32b_4r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data)
402400
{
403401
__builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
404402
}
405-
void intel_subgroup_block_write_u32_m8k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data)
403+
void intel_sub_group_block_write_32b_8r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data)
406404
{
407405
__builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
408406
}

0 commit comments

Comments
 (0)