Skip to content

Commit 62a1fd8

Browse files
committed
switch to production 2d block io functions
1 parent ace054e commit 62a1fd8

6 files changed

Lines changed: 103 additions & 392 deletions

File tree

samples/99_matrixexperiments/matrix_helpers.cl

Lines changed: 0 additions & 307 deletions
Large diffs are not rendered by default.

samples/99_matrixexperiments/matrix_kernel_tiled.cl

Lines changed: 55 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
405405
}
406406
}
407407

408-
#ifdef cl_intel_subgroup_extended_block_read
408+
#ifdef cl_intel_subgroup_2d_block_io
409409

410410
void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM])
411411
{
@@ -415,49 +415,52 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in
415415
//if (get_sub_group_local_id() == 0) {
416416
// printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * tK, m + mm * tM);
417417
//}
418-
ushort8 tmp[2][4];
419-
intel_sub_group_block_read_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
418+
short8 aTemp[2][4];
419+
intel_sub_group_2d_block_read_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp);
420420
for (int tkk = 0; tkk < 2; tkk++) {
421421
for (int tmm = 0; tmm < 4; tmm++) {
422-
aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]);
422+
aData[kk + tkk][mm + tmm] = aTemp[tkk][tmm];
423423
}
424424
}
425425
}
426426
}
427427
} else if (KK % 2 == 0 & MM % 2 == 0) {
428428
for (int kk = 0; kk < KK; kk+=2) {
429429
for (int mm = 0; mm < MM; mm+=2) {
430-
ushort8 tmp[2][2];
431-
intel_sub_group_block_read_16b_16r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
430+
short8 aTemp[2][2];
431+
intel_sub_group_2d_block_read_16b_16r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp);
432432
for (int tkk = 0; tkk < 2; tkk++) {
433433
for (int tmm = 0; tmm < 2; tmm++) {
434-
aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]);
434+
aData[kk + tkk][mm + tmm] = aTemp[tkk][tmm];
435435
}
436436
}
437437
}
438438
}
439439
} else if (KK % 2 == 0) {
440440
for (int kk = 0; kk < KK; kk+=2) {
441441
for (int mm = 0; mm < MM; mm++) {
442-
short16 aTemp = as_short16(intel_sub_group_block_read_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
443-
aData[kk + 0][mm] = aTemp.lo;
444-
aData[kk + 1][mm] = aTemp.hi;
442+
short8 aTemp[2];
443+
intel_sub_group_2d_block_read_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp);
444+
aData[kk + 0][mm] = aTemp[0];
445+
aData[kk + 1][mm] = aTemp[1];
445446
}
446447
}
447448
} else if (MM % 4 == 0) {
448449
for (int kk = 0; kk < KK; kk++) {
449450
for (int mm = 0; mm < MM; mm+=4) {
450-
ushort8 tmp[4];
451-
intel_sub_group_block_read_16b_32r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
451+
short8 aTemp[4];
452+
intel_sub_group_2d_block_read_16b_32r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp);
452453
for (int tmm = 0; tmm < 4; tmm++) {
453-
aData[kk][mm + tmm] = as_short8(tmp[tmm]);
454+
aData[kk][mm + tmm] = aTemp[tmm];
454455
}
455456
}
456457
}
457458
} else {
458459
for (int kk = 0; kk < KK; kk++) {
459460
for (int mm = 0; mm < MM; mm++) {
460-
aData[kk][mm] = as_short8(intel_sub_group_block_read_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
461+
short8 aTemp[1];
462+
intel_sub_group_2d_block_read_16b_8r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp);
463+
aData[kk][mm] = aTemp[0];
461464
}
462465
}
463466
}
@@ -471,35 +474,39 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, in
471474
//if (get_sub_group_local_id() == 0) {
472475
// printf("btile block load: %d, %d, %2d: n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), n, k, nn, kk, n + nn * tN, k + kk * tK);
473476
//}
474-
int8 tmp[2][2];
475-
intel_sub_group_block_read_transform_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp);
477+
int8 bTemp[2][2];
478+
intel_sub_group_2d_block_read_transform_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), (uint*)bTemp);
476479
for (int tnn = 0; tnn < 2; tnn++) {
477480
for (int tkk = 0; tkk < 2; tkk++) {
478-
bData[nn + tnn][kk + tkk] = tmp[tnn][tkk];
481+
bData[nn + tnn][kk + tkk] = bTemp[tnn][tkk];
479482
}
480483
}
481484
}
482485
}
483486
} else if (NN % 2 == 0) {
484487
for (int kk = 0; kk < KK; kk++) {
485488
for (int nn = 0; nn < NN; nn+=2) {
486-
int16 bTemp = intel_sub_group_block_read_transform_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
487-
bData[nn + 0][kk] = bTemp.lo;
488-
bData[nn + 1][kk] = bTemp.hi;
489+
int8 bTemp[2];
490+
intel_sub_group_2d_block_read_transform_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), (uint*)bTemp);
491+
bData[nn + 0][kk] = bTemp[0];
492+
bData[nn + 1][kk] = bTemp[1];
489493
}
490494
}
491495
} else if (KK % 2 == 0) {
492496
for (int kk = 0; kk < KK; kk+=2) {
493497
for (int nn = 0; nn < NN; nn++) {
494-
int16 bTemp = intel_sub_group_block_read_transform_16b_32r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
495-
bData[nn][kk + 0] = bTemp.lo;
496-
bData[nn][kk + 1] = bTemp.hi;
498+
int8 bTemp[2];
499+
intel_sub_group_2d_block_read_transform_16b_32r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), (uint*)bTemp);
500+
bData[nn][kk + 0] = bTemp[0];
501+
bData[nn][kk + 1] = bTemp[1];
497502
}
498503
}
499504
} else {
500505
for (int kk = 0; kk < KK; kk++) {
501506
for (int nn = 0; nn < NN; nn++) {
502-
bData[nn][kk] = intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
507+
int8 bTemp[1];
508+
intel_sub_group_2d_block_read_transform_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), (uint*)bTemp);
509+
bData[nn][kk] = bTemp[0];
503510
}
504511
}
505512
}
@@ -510,15 +517,18 @@ void HELPER_NAME(btile_block_load_packed, MM, NN)(global ushort* B, int tN, int
510517
if (KK % 2 == 0) {
511518
for (int kk = 0; kk < KK; kk+=2) {
512519
for (int nn = 0; nn < NN; nn++) {
513-
int16 bTemp = as_int16(intel_sub_group_block_read_32b_16r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)));
514-
bData[nn][kk + 0] = bTemp.lo;
515-
bData[nn][kk + 1] = bTemp.hi;
520+
int8 bTemp[2];
521+
intel_sub_group_2d_block_read_32b_16r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2), (uint*)bTemp);
522+
bData[nn][kk + 0] = bTemp[0];
523+
bData[nn][kk + 1] = bTemp[1];
516524
}
517525
}
518526
} else {
519527
for (int kk = 0; kk < KK; kk++) {
520528
for (int nn = 0; nn < NN; nn++) {
521-
bData[nn][kk] = as_int8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)));
529+
int8 bTemp[1];
530+
intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2), (uint*)bTemp);
531+
bData[nn][kk] = bTemp[0];
522532
}
523533
}
524534
}
@@ -533,39 +543,35 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM
533543
//if (get_sub_group_local_id() == 0) {
534544
// printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM);
535545
//}
536-
#ifdef USE_32C
537-
intel_sub_group_block_prefetch_16b_8r32c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
538-
#else
539-
intel_sub_group_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
540-
#endif
546+
intel_sub_group_2d_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
541547
} else if (KK % 2 == 0 & MM % 4 == 0) {
542548
for (int kk = 0; kk < KK; kk+=2) {
543549
for (int mm = 0; mm < MM; mm+=4) {
544-
intel_sub_group_block_prefetch_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
550+
intel_sub_group_2d_block_prefetch_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
545551
}
546552
}
547553
} else if (KK % 2 == 0 & MM % 2 == 0) {
548554
for (int kk = 0; kk < KK; kk+=2) {
549555
for (int mm = 0; mm < MM; mm+=2) {
550-
intel_sub_group_block_prefetch_16b_16r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
556+
intel_sub_group_2d_block_prefetch_16b_16r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
551557
}
552558
}
553559
} else if (KK % 2 == 0) {
554560
for (int kk = 0; kk < KK; kk+=2) {
555561
for (int mm = 0; mm < MM; mm++) {
556-
intel_sub_group_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
562+
intel_sub_group_2d_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
557563
}
558564
}
559565
} else if (MM % 4 == 0) {
560566
for (int kk = 0; kk < KK; kk++) {
561567
for (int mm = 0; mm < MM; mm+=4) {
562-
intel_sub_group_block_prefetch_16b_32r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
568+
intel_sub_group_2d_block_prefetch_16b_32r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
563569
}
564570
}
565571
} else {
566572
for (int kk = 0; kk < KK; kk++) {
567573
for (int mm = 0; mm < MM; mm++) {
568-
intel_sub_group_block_prefetch_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
574+
intel_sub_group_2d_block_prefetch_16b_8r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
569575
}
570576
}
571577
}
@@ -580,33 +586,29 @@ void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN
580586
//if (get_sub_group_local_id() == 0) {
581587
// printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n + nn * tN, k + kk * tK);
582588
//}
583-
#ifdef USE_32C
584-
intel_sub_group_block_prefetch_16b_16r32c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
585-
#else
586-
intel_sub_group_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
587-
#endif
589+
intel_sub_group_2d_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
588590
} else if (KK % 2 == 0 & NN % 2 == 0) {
589591
for (int kk = 0; kk < KK; kk+=2) {
590592
for (int nn = 0; nn < NN; nn += 2) {
591-
intel_sub_group_block_prefetch_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
593+
intel_sub_group_2d_block_prefetch_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
592594
}
593595
}
594596
} else if (NN % 2 == 0) {
595597
for (int kk = 0; kk < KK; kk++) {
596598
for (int nn = 0; nn < NN; nn+=2) {
597-
intel_sub_group_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
599+
intel_sub_group_2d_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
598600
}
599601
}
600602
} else if (KK % 2 == 0) {
601603
for (int kk = 0; kk < KK; kk+=2) {
602604
for (int nn = 0; nn < NN; nn++) {
603-
intel_sub_group_block_prefetch_16b_32r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
605+
intel_sub_group_2d_block_prefetch_16b_32r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
604606
}
605607
}
606608
} else {
607609
for (int kk = 0; kk < KK; kk++) {
608610
for (int nn = 0; nn < NN; nn++) {
609-
intel_sub_group_block_prefetch_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
611+
intel_sub_group_2d_block_prefetch_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK));
610612
}
611613
}
612614
}
@@ -618,17 +620,17 @@ void HELPER_NAME(btile_block_prefetch_packed, MM, NN)(global ushort* B, int tN,
618620
const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y)
619621
const int nn = sg_index_y % 4; // nn(sg_index_y) == 0, 1, 2, 3, 0, 1, 2, 3
620622
const int kk = 0; // kk(sg_index_y) == 0, 0, 0, 0, 0, 0, 0, 0
621-
intel_sub_group_block_prefetch_32b_16r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2));
623+
intel_sub_group_2d_block_prefetch_32b_16r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2));
622624
} else if (KK % 2 == 0) {
623625
for (int kk = 0; kk < KK; kk+=2) {
624626
for (int nn = 0; nn < NN; nn++) {
625-
intel_sub_group_block_prefetch_32b_16r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2));
627+
intel_sub_group_2d_block_prefetch_32b_16r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2));
626628
}
627629
}
628630
} else {
629631
for (int kk = 0; kk < KK; kk++) {
630632
for (int nn = 0; nn < NN; nn++) {
631-
intel_sub_group_block_prefetch_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2));
633+
intel_sub_group_2d_block_prefetch_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2));
632634
}
633635
}
634636
}
@@ -689,7 +691,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
689691
for (int mm = 0; mm < MM; mm++) {
690692
for (int nn = 0; nn < NN; nn++) {
691693
sum[nn][mm] = activation(sum[nn][mm]);
692-
intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm]));
694+
intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), (uint*)&sum[nn][mm]);
693695
}
694696
}
695697
}
@@ -750,9 +752,9 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
750752
for (int mm = 0; mm < MM; mm++) {
751753
for (int nn = 0; nn < NN; nn++) {
752754
sum[nn][mm] = activation(sum[nn][mm]);
753-
intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm]));
755+
intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), (uint*)&sum[nn][mm]);
754756
}
755757
}
756758
}
757759

758-
#endif // cl_intel_subgroup_extended_block_read
760+
#endif // cl_intel_subgroup_2d_block_io

0 commit comments

Comments
 (0)