Skip to content

Commit eff9d19

Browse files
committed
enable support for the native tf32 dpas
1 parent d3d42f0 commit eff9d19

2 files changed

Lines changed: 2 additions & 12 deletions

File tree

samples/99_matrixexperimentstf32/main.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -517,17 +517,14 @@ int main(int argc, char** argv)
517517
auto minSubGroupSize = findMinSubGroupSize(device);
518518

519519
bool emulate_tN16 = true;
520-
if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) {
521-
printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize);
520+
if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate_tf32")) {
521+
printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate_tf32, min sub-group size is: %zu\n", minSubGroupSize);
522522
switch(minSubGroupSize) {
523523
case 16: emulate_tN16 = false; break;
524524
default: break;
525525
}
526526
}
527527

528-
printf("NOTE: dpas is unconditionally emulated, currently!\n");
529-
emulate_tN16 = true;
530-
531528
buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16);
532529

533530
printf("Config:\n");

samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ float emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float acc)
7171
{
7272
float res = acc;
7373

74-
#if 1
7574
res = fma(sub_group_broadcast(a, 0), b.s0, res);
7675
res = fma(sub_group_broadcast(a, 1), b.s1, res);
7776
res = fma(sub_group_broadcast(a, 2), b.s2, res);
@@ -80,12 +79,6 @@ float emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float acc)
8079
res = fma(sub_group_broadcast(a, 5), b.s5, res);
8180
res = fma(sub_group_broadcast(a, 6), b.s6, res);
8281
res = fma(sub_group_broadcast(a, 7), b.s7, res);
83-
#else
84-
float __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(short a, int8 b, float acc);
85-
uint a_ui = as_uint(sub_group_shuffle(a, get_sub_group_local_id() / 2));
86-
short aData = get_sub_group_local_id() % 2 ? as_short2(a_ui).hi : as_short2(a_ui).lo;
87-
res = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(aData, as_int8(b), res);
88-
#endif
8982

9083
return res;
9184
}

0 commit comments

Comments
 (0)