Skip to content

Commit ae363d5

Browse files
Icohedron and farzonl authored
[HLSL][Matrix] Make Matrix InitListExprs and AST row-major order, and respect /Zpr and /Zpc in codegen (llvm#182904)
Fixes llvm#166410 and llvm#181902. This PR keeps matrix initializer lists in row-major order in the InitListExpr and the AST for HLSL by no longer reordering the element indices in `InitListChecker::CheckMatrixType` in `clang/lib/Sema/SemaInit.cpp`. It also makes codegen respect /Zpr and /Zpc for matrix initializer lists by adding a vector shuffle to `VisitInitListExpr` in `clang/lib/CodeGen/CGExprScalar.cpp`. Assisted-by: claude-opus-4.6 --------- Co-authored-by: Farzon Lotfi <farzonl@gmail.com>
1 parent 28d294e commit ae363d5

10 files changed

Lines changed: 291 additions & 129 deletions

clang/include/clang/AST/TypeBase.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4418,6 +4418,45 @@ class ConstantMatrixType final : public MatrixType {
44184418
return getNumRows() * getNumColumns();
44194419
}
44204420

4421+
/// Returns the row-major flattened index of a matrix element located at row
4422+
/// \p Row, and column \p Column
4423+
unsigned getRowMajorFlattenedIndex(unsigned Row, unsigned Column) const {
4424+
return Row * NumColumns + Column;
4425+
}
4426+
4427+
/// Returns the column-major flattened index of a matrix element located at
4428+
/// row \p Row, and column \p Column
4429+
unsigned getColumnMajorFlattenedIndex(unsigned Row, unsigned Column) const {
4430+
return Column * NumRows + Row;
4431+
}
4432+
4433+
/// Returns the flattened index of a matrix element located at
4434+
/// row \p Row, and column \p Column. If \p IsRowMajor is true, returns the
4435+
/// row-major order flattened index. Otherwise, returns the column-major order
4436+
/// flattened index.
4437+
unsigned getFlattenedIndex(unsigned Row, unsigned Column,
4438+
bool IsRowMajor = false) {
4439+
return IsRowMajor ? getRowMajorFlattenedIndex(Row, Column)
4440+
: getColumnMajorFlattenedIndex(Row, Column);
4441+
}
4442+
4443+
/// Given a column-major flattened index \p ColumnMajorIdx, return the
4444+
/// equivalent row-major flattened index.
4445+
unsigned
4446+
mapColumnMajorToRowMajorFlattenedIndex(unsigned ColumnMajorIdx) const {
4447+
unsigned Column = ColumnMajorIdx / NumRows;
4448+
unsigned Row = ColumnMajorIdx % NumRows;
4449+
return Row * NumColumns + Column;
4450+
}
4451+
4452+
/// Given a row-major flattened index \p RowMajorIdx, return the equivalent
4453+
/// column-major flattened index.
4454+
unsigned mapRowMajorToColumnMajorFlattenedIndex(unsigned RowMajorIdx) const {
4455+
unsigned Row = RowMajorIdx / NumColumns;
4456+
unsigned Column = RowMajorIdx % NumColumns;
4457+
return Column * NumRows + Row;
4458+
}
4459+
44214460
void Profile(llvm::FoldingSetNodeID &ID) {
44224461
Profile(ID, getElementType(), getNumRows(), getNumColumns(),
44234462
getTypeClass());

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2450,6 +2450,20 @@ Value *ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) {
24502450
llvm::Value *Init = llvm::Constant::getNullValue(EltTy);
24512451
V = Builder.CreateInsertElement(V, Init, Idx, "vecinit");
24522452
}
2453+
2454+
// Matrix initializer lists are in row-major order but the memory layout for
2455+
// codegen is determined by the -fmatrix-memory-layout flag (default:
2456+
// column-major). When the memory layout is column-major, we need to shuffle
2457+
// the elements from row-major to column-major order.
2458+
if (const auto *MT = E->getType()->getAs<ConstantMatrixType>();
2459+
MT && CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
2460+
LangOptions::MatrixMemoryLayout::MatrixColMajor) {
2461+
SmallVector<int, 16> Mask;
2462+
for (unsigned I = 0, N = MT->getNumElementsFlattened(); I < N; ++I)
2463+
Mask.push_back(MT->mapColumnMajorToRowMajorFlattenedIndex(I));
2464+
V = Builder.CreateShuffleVector(V, Mask, "matrix.rowmajor2colmajor");
2465+
}
2466+
24532467
return V;
24542468
}
24552469

clang/lib/Sema/SemaInit.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,17 +1910,16 @@ void InitListChecker::CheckMatrixType(const InitializedEntity &Entity,
19101910
QualType ElemTy = MT->getElementType();
19111911

19121912
Index = 0;
1913-
InitializedEntity ElemEnt =
1913+
InitializedEntity Element =
19141914
InitializedEntity::InitializeElement(SemaRef.Context, 0, Entity);
19151915

19161916
while (Index < IList->getNumInits()) {
19171917
// Not a sublist: just consume directly.
1918-
unsigned ColMajorIndex = (Index % MT->getNumRows()) * MT->getNumColumns() +
1919-
(Index / MT->getNumRows());
1920-
ElemEnt.setElementIndex(ColMajorIndex);
1921-
CheckSubElementType(ElemEnt, IList, ElemTy, ColMajorIndex, StructuredList,
1918+
// Note: In HLSL, elements of the InitListExpr are in row-major order, so no
1919+
// change is needed to the Index.
1920+
Element.setElementIndex(Index);
1921+
CheckSubElementType(Element, IList, ElemTy, Index, StructuredList,
19221922
StructuredIndex);
1923-
++Index;
19241923
}
19251924
}
19261925

clang/test/AST/HLSL/matrix-constructors.hlsl

Lines changed: 46 additions & 45 deletions
Large diffs are not rendered by default.

clang/test/AST/HLSL/matrix-general-initializer.hlsl

Lines changed: 45 additions & 45 deletions
Large diffs are not rendered by default.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -finclude-default-header -o - %s | FileCheck %s
2+
3+
// This test verifies that matrix initializer lists in HLSL use row-major
4+
// element ordering. The elements in the AST InitListExpr remain in
5+
// row-major order as written in the source code.
6+
7+
// The AST InitListExpr preserves this row-major source order.
8+
// CHECK: VarDecl {{.*}} m2x2 'float2x2':'matrix<float, 2, 2>' cinit
9+
// CHECK-NEXT: InitListExpr {{.*}} 'float2x2':'matrix<float, 2, 2>'
10+
// CHECK-NEXT: FloatingLiteral {{.*}} 'float' 1.000000e+00
11+
// CHECK-NEXT: FloatingLiteral {{.*}} 'float' 2.000000e+00
12+
// CHECK-NEXT: FloatingLiteral {{.*}} 'float' 3.000000e+00
13+
// CHECK-NEXT: FloatingLiteral {{.*}} 'float' 4.000000e+00
14+
export void test_2x2() {
15+
float2x2 m2x2 = {1.0, 2.0, 3.0, 4.0};
16+
}
17+
18+
// CHECK: VarDecl {{.*}} m2x3 'int2x3':'matrix<int, 2, 3>' cinit
19+
// CHECK-NEXT: InitListExpr {{.*}} 'int2x3':'matrix<int, 2, 3>'
20+
// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 1
21+
// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 2
22+
// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 3
23+
// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 4
24+
// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 5
25+
// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 6
26+
export void test_2x3() {
27+
int2x3 m2x3 = {1, 2, 3, 4, 5, 6};
28+
}
29+
30+
// CHECK: VarDecl {{.*}} m3x2 'bool3x2':'matrix<bool, 3, 2>' cinit
31+
// CHECK-NEXT: InitListExpr {{.*}} 'bool3x2':'matrix<bool, 3, 2>'
32+
// CHECK-NEXT: CXXBoolLiteralExpr {{.*}} 'bool' true
33+
// CHECK-NEXT: CXXBoolLiteralExpr {{.*}} 'bool' false
34+
// CHECK-NEXT: CXXBoolLiteralExpr {{.*}} 'bool' false
35+
// CHECK-NEXT: CXXBoolLiteralExpr {{.*}} 'bool' true
36+
// CHECK-NEXT: CXXBoolLiteralExpr {{.*}} 'bool' true
37+
// CHECK-NEXT: CXXBoolLiteralExpr {{.*}} 'bool' true
38+
export void test_3x2() {
39+
bool3x2 m3x2 = {true, false, false, true, true, true};
40+
}

clang/test/CodeGenHLSL/BasicFeatures/MatrixConstructor.hlsl

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ float3x2 case1() {
1313
// vec[3] = 1
1414
// vec[4] = 3
1515
// vec[5] = 5
16-
return float3x2(0, 1,
16+
return float3x2(0, 1,
1717
2, 3,
1818
4, 5);
1919
}
@@ -24,25 +24,26 @@ RWStructuredBuffer<float> In;
2424
// CHECK-LABEL: define hidden noundef nofpclass(nan inf) <6 x float> @_Z5case2v(
2525
// CHECK-SAME: ) #[[ATTR0]] {
2626
// CHECK-NEXT: [[ENTRY:.*:]]
27-
// CHECK-NEXT: [[CALL:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 0) #[[ATTR3:[0-9]+]]
28-
// CHECK-NEXT: [[CALL1:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 1) #[[ATTR3]]
29-
// CHECK-NEXT: [[CALL2:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 2) #[[ATTR3]]
30-
// CHECK-NEXT: [[CALL3:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 3) #[[ATTR3]]
31-
// CHECK-NEXT: [[CALL4:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 4) #[[ATTR3]]
32-
// CHECK-NEXT: [[CALL5:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 5) #[[ATTR3]]
27+
// CHECK-NEXT: [[CALL:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 0) #[[ATTR4:[0-9]+]]
28+
// CHECK-NEXT: [[CALL1:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 1) #[[ATTR4]]
29+
// CHECK-NEXT: [[CALL2:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 2) #[[ATTR4]]
30+
// CHECK-NEXT: [[CALL3:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 3) #[[ATTR4]]
31+
// CHECK-NEXT: [[CALL4:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 4) #[[ATTR4]]
32+
// CHECK-NEXT: [[CALL5:%.*]] = call noundef nonnull align 4 dereferenceable(4) ptr @_ZN4hlsl18RWStructuredBufferIfEixEj(ptr noundef nonnull align 4 dereferenceable(8) @_ZL2In, i32 noundef 5) #[[ATTR4]]
3333
// CHECK-NEXT: [[TMP0:%.*]] = load float, ptr [[CALL]], align 4
3434
// CHECK-NEXT: [[VECINIT:%.*]] = insertelement <6 x float> poison, float [[TMP0]], i32 0
35-
// CHECK-NEXT: [[TMP1:%.*]] = load float, ptr [[CALL2]], align 4
35+
// CHECK-NEXT: [[TMP1:%.*]] = load float, ptr [[CALL1]], align 4
3636
// CHECK-NEXT: [[VECINIT6:%.*]] = insertelement <6 x float> [[VECINIT]], float [[TMP1]], i32 1
37-
// CHECK-NEXT: [[TMP2:%.*]] = load float, ptr [[CALL4]], align 4
37+
// CHECK-NEXT: [[TMP2:%.*]] = load float, ptr [[CALL2]], align 4
3838
// CHECK-NEXT: [[VECINIT7:%.*]] = insertelement <6 x float> [[VECINIT6]], float [[TMP2]], i32 2
39-
// CHECK-NEXT: [[TMP3:%.*]] = load float, ptr [[CALL1]], align 4
39+
// CHECK-NEXT: [[TMP3:%.*]] = load float, ptr [[CALL3]], align 4
4040
// CHECK-NEXT: [[VECINIT8:%.*]] = insertelement <6 x float> [[VECINIT7]], float [[TMP3]], i32 3
41-
// CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[CALL3]], align 4
41+
// CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[CALL4]], align 4
4242
// CHECK-NEXT: [[VECINIT9:%.*]] = insertelement <6 x float> [[VECINIT8]], float [[TMP4]], i32 4
4343
// CHECK-NEXT: [[TMP5:%.*]] = load float, ptr [[CALL5]], align 4
4444
// CHECK-NEXT: [[VECINIT10:%.*]] = insertelement <6 x float> [[VECINIT9]], float [[TMP5]], i32 5
45-
// CHECK-NEXT: ret <6 x float> [[VECINIT10]]
45+
// CHECK-NEXT: [[MATRIX_ROWMAJOR2COLMAJOR:%.*]] = shufflevector <6 x float> [[VECINIT10]], <6 x float> poison, <6 x i32> <i32 0, i32 2, i32 4, i32 1, i32 3, i32 5>
46+
// CHECK-NEXT: ret <6 x float> [[MATRIX_ROWMAJOR2COLMAJOR]]
4647
//
4748
float3x2 case2() {
4849
// vec[0] = Call
@@ -51,7 +52,7 @@ float3x2 case2() {
5152
// vec[3] = Call1
5253
// vec[4] = Call3
5354
// vec[5] = Call5
54-
return float3x2(In[0], In[1],
55+
return float3x2(In[0], In[1],
5556
In[2], In[3],
5657
In[4], In[5]);
5758
}
@@ -68,28 +69,29 @@ float3x2 case2() {
6869
// CHECK-NEXT: [[VECEXT:%.*]] = extractelement <3 x float> [[TMP0]], i64 0
6970
// CHECK-NEXT: [[VECINIT:%.*]] = insertelement <6 x float> poison, float [[VECEXT]], i32 0
7071
// CHECK-NEXT: [[TMP1:%.*]] = load <3 x float>, ptr [[A_ADDR]], align 16
71-
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <3 x float> [[TMP1]], i64 2
72+
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <3 x float> [[TMP1]], i64 1
7273
// CHECK-NEXT: [[VECINIT2:%.*]] = insertelement <6 x float> [[VECINIT]], float [[VECEXT1]], i32 1
73-
// CHECK-NEXT: [[TMP2:%.*]] = load <3 x float>, ptr [[B_ADDR]], align 16
74-
// CHECK-NEXT: [[VECEXT3:%.*]] = extractelement <3 x float> [[TMP2]], i64 1
74+
// CHECK-NEXT: [[TMP2:%.*]] = load <3 x float>, ptr [[A_ADDR]], align 16
75+
// CHECK-NEXT: [[VECEXT3:%.*]] = extractelement <3 x float> [[TMP2]], i64 2
7576
// CHECK-NEXT: [[VECINIT4:%.*]] = insertelement <6 x float> [[VECINIT2]], float [[VECEXT3]], i32 2
76-
// CHECK-NEXT: [[TMP3:%.*]] = load <3 x float>, ptr [[A_ADDR]], align 16
77-
// CHECK-NEXT: [[VECEXT5:%.*]] = extractelement <3 x float> [[TMP3]], i64 1
77+
// CHECK-NEXT: [[TMP3:%.*]] = load <3 x float>, ptr [[B_ADDR]], align 16
78+
// CHECK-NEXT: [[VECEXT5:%.*]] = extractelement <3 x float> [[TMP3]], i64 0
7879
// CHECK-NEXT: [[VECINIT6:%.*]] = insertelement <6 x float> [[VECINIT4]], float [[VECEXT5]], i32 3
7980
// CHECK-NEXT: [[TMP4:%.*]] = load <3 x float>, ptr [[B_ADDR]], align 16
80-
// CHECK-NEXT: [[VECEXT7:%.*]] = extractelement <3 x float> [[TMP4]], i64 0
81+
// CHECK-NEXT: [[VECEXT7:%.*]] = extractelement <3 x float> [[TMP4]], i64 1
8182
// CHECK-NEXT: [[VECINIT8:%.*]] = insertelement <6 x float> [[VECINIT6]], float [[VECEXT7]], i32 4
8283
// CHECK-NEXT: [[TMP5:%.*]] = load <3 x float>, ptr [[B_ADDR]], align 16
8384
// CHECK-NEXT: [[VECEXT9:%.*]] = extractelement <3 x float> [[TMP5]], i64 2
8485
// CHECK-NEXT: [[VECINIT10:%.*]] = insertelement <6 x float> [[VECINIT8]], float [[VECEXT9]], i32 5
85-
// CHECK-NEXT: ret <6 x float> [[VECINIT10]]
86+
// CHECK-NEXT: [[MATRIX_ROWMAJOR2COLMAJOR:%.*]] = shufflevector <6 x float> [[VECINIT10]], <6 x float> poison, <6 x i32> <i32 0, i32 2, i32 4, i32 1, i32 3, i32 5>
87+
// CHECK-NEXT: ret <6 x float> [[MATRIX_ROWMAJOR2COLMAJOR]]
8688
//
8789
float3x2 case3(float3 a, float3 b) {
8890
// vec[0] = A[0]
89-
// vec[1] = A[2]
90-
// vec[2] = B[1]
91-
// vec[3] = A[1]
92-
// vec[4] = B[0]
91+
// vec[1] = A[1]
92+
// vec[2] = A[2]
93+
// vec[3] = B[0]
94+
// vec[4] = B[1]
9395
// vec[5] = B[2]
9496
return float3x2(a,b);
9597
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -disable-llvm-passes \
2+
// RUN: -emit-llvm -finclude-default-header -o - %s | FileCheck %s --check-prefix=CHECK,COL-CHECK
3+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -disable-llvm-passes \
4+
// RUN: -emit-llvm -finclude-default-header -fmatrix-memory-layout=row-major -o - %s \
5+
// RUN: | FileCheck %s --check-prefix=CHECK,ROW-CHECK
6+
7+
// Verify that matrix initializer lists store elements in the correct memory
8+
// layout. The initializer list {1,2,3,4,5,6} for a float2x3 (2 rows, 3 cols)
9+
// is in row-major order: row0=[1,2,3], row1=[4,5,6].
10+
//
11+
// With column-major (default) memory layout, the stored vector should be
12+
// reordered to: col0=[1,4], col1=[2,5], col2=[3,6] = <1,4,2,5,3,6>.
13+
//
14+
// With row-major memory layout, the stored vector stays as-is: <1,2,3,4,5,6>.
15+
16+
export float test_row0_col2() {
17+
// CHECK-LABEL: define {{.*}} float @_Z14test_row0_col2v
18+
// COL-CHECK: store <6 x float> <float 1.000000e+00, float 4.000000e+00, float 2.000000e+00, float 5.000000e+00, float 3.000000e+00, float 6.000000e+00>
19+
// COL-CHECK: extractelement <6 x float> %{{.*}}, i32 4
20+
// ROW-CHECK: store <6 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 4.000000e+00, float 5.000000e+00, float 6.000000e+00>
21+
// ROW-CHECK: extractelement <6 x float> %{{.*}}, i32 2
22+
float2x3 M = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
23+
// Row 0, Col 2 in row-major is the 3rd element = 3.0
24+
return M[0][2];
25+
}
26+
27+
export float test_row1_col0() {
28+
// CHECK-LABEL: define {{.*}} float @_Z14test_row1_col0v
29+
// COL-CHECK: store <6 x float> <float 1.000000e+00, float 4.000000e+00, float 2.000000e+00, float 5.000000e+00, float 3.000000e+00, float 6.000000e+00>
30+
// COL-CHECK: extractelement <6 x float> %{{.*}}, i32 1
31+
// ROW-CHECK: store <6 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 4.000000e+00, float 5.000000e+00, float 6.000000e+00>
32+
// ROW-CHECK: extractelement <6 x float> %{{.*}}, i32 3
33+
float2x3 M = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
34+
// Row 1, Col 0 in row-major is the 4th element = 4.0
35+
return M[1][0];
36+
}
37+
38+
// Verify the shuffle is emitted for non-constant init lists when the memory
39+
// layout is column-major, and not emitted when it is row-major.
40+
41+
export float2x3 test_dynamic(float a, float b, float c,
42+
float d, float e, float f) {
43+
// CHECK-LABEL: define {{.*}} <6 x float> @_Z12test_dynamicffffff
44+
// CHECK: [[A:%.*]] = load float, ptr %a.addr
45+
// CHECK: [[VECINIT0:%.*]] = insertelement <6 x float> poison, float [[A]], i32 0
46+
// CHECK: [[B:%.*]] = load float, ptr %b.addr
47+
// CHECK: [[VECINIT1:%.*]] = insertelement <6 x float> [[VECINIT0]], float [[B]], i32 1
48+
// CHECK: [[C:%.*]] = load float, ptr %c.addr
49+
// CHECK: [[VECINIT2:%.*]] = insertelement <6 x float> [[VECINIT1]], float [[C]], i32 2
50+
// CHECK: [[D:%.*]] = load float, ptr %d.addr
51+
// CHECK: [[VECINIT3:%.*]] = insertelement <6 x float> [[VECINIT2]], float [[D]], i32 3
52+
// CHECK: [[E:%.*]] = load float, ptr %e.addr
53+
// CHECK: [[VECINIT4:%.*]] = insertelement <6 x float> [[VECINIT3]], float [[E]], i32 4
54+
// CHECK: [[F:%.*]] = load float, ptr %f.addr
55+
// CHECK: [[VECINIT5:%.*]] = insertelement <6 x float> [[VECINIT4]], float [[F]], i32 5
56+
// COL-CHECK: shufflevector <6 x float> [[VECINIT5]], <6 x float> poison, <6 x i32> <i32 0, i32 3, i32 1, i32 4, i32 2, i32 5>
57+
// ROW-CHECK-NOT: shufflevector
58+
// ROW-CHECK: store <6 x float> [[VECINIT5]], ptr
59+
return (float2x3){a, b, c, d, e, f};
60+
}

clang/test/CodeGenHLSL/BasicFeatures/MatrixToAndFromVectorConstructors.hlsl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,17 @@ float4 fn(float2x2 m) {
4040
// CHECK-NEXT: [[VECEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i64 0
4141
// CHECK-NEXT: [[VECINIT:%.*]] = insertelement <4 x i32> poison, i32 [[VECEXT]], i32 0
4242
// CHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr [[V_ADDR]], align 16
43-
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <4 x i32> [[TMP1]], i64 2
43+
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <4 x i32> [[TMP1]], i64 1
4444
// CHECK-NEXT: [[VECINIT2:%.*]] = insertelement <4 x i32> [[VECINIT]], i32 [[VECEXT1]], i32 1
4545
// CHECK-NEXT: [[TMP2:%.*]] = load <4 x i32>, ptr [[V_ADDR]], align 16
46-
// CHECK-NEXT: [[VECEXT3:%.*]] = extractelement <4 x i32> [[TMP2]], i64 1
46+
// CHECK-NEXT: [[VECEXT3:%.*]] = extractelement <4 x i32> [[TMP2]], i64 2
4747
// CHECK-NEXT: [[VECINIT4:%.*]] = insertelement <4 x i32> [[VECINIT2]], i32 [[VECEXT3]], i32 2
4848
// CHECK-NEXT: [[TMP3:%.*]] = load <4 x i32>, ptr [[V_ADDR]], align 16
4949
// CHECK-NEXT: [[VECEXT5:%.*]] = extractelement <4 x i32> [[TMP3]], i64 3
5050
// CHECK-NEXT: [[VECINIT6:%.*]] = insertelement <4 x i32> [[VECINIT4]], i32 [[VECEXT5]], i32 3
51-
// CHECK-NEXT: store <4 x i32> [[VECINIT6]], ptr [[M]], align 4
51+
// COL-CHECK-NEXT: [[MATRIX_ROWMAJOR2COLMAJOR:%.*]] = shufflevector <4 x i32> [[VECINIT6]], <4 x i32> poison, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
52+
// COL-CHECK-NEXT: store <4 x i32> [[MATRIX_ROWMAJOR2COLMAJOR]], ptr [[M]], align 4
53+
// ROW-CHECK-NEXT: store <4 x i32> [[VECINIT6]], ptr [[M]], align 4
5254
// CHECK-NEXT: [[TMP4:%.*]] = load <4 x i32>, ptr [[M]], align 4
5355
// CHECK-NEXT: ret <4 x i32> [[TMP4]]
5456
//
@@ -68,7 +70,9 @@ int2x2 fn(int4 v) {
6870
// CHECK-NEXT: [[TMP1:%.*]] = load <2 x i32>, ptr [[V_ADDR]], align 8
6971
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <2 x i32> [[TMP1]], i64 1
7072
// CHECK-NEXT: [[VECINIT2:%.*]] = insertelement <2 x i32> [[VECINIT]], i32 [[VECEXT1]], i32 1
71-
// CHECK-NEXT: ret <2 x i32> [[VECINIT2]]
73+
// COL-CHECK-NEXT: [[MATRIX_ROWMAJOR2COLMAJOR:%.*]] = shufflevector <2 x i32> [[VECINIT2]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
74+
// COL-CHECK-NEXT: ret <2 x i32> [[MATRIX_ROWMAJOR2COLMAJOR]]
75+
// ROW-CHECK-NEXT: ret <2 x i32> [[VECINIT2]]
7276
//
7377
int1x2 fn1(int2 v) {
7478
return v;
@@ -92,7 +96,9 @@ int1x2 fn1(int2 v) {
9296
// CHECK-NEXT: [[LOADEDV4:%.*]] = trunc <3 x i32> [[TMP3]] to <3 x i1>
9397
// CHECK-NEXT: [[VECEXT5:%.*]] = extractelement <3 x i1> [[LOADEDV4]], i64 2
9498
// CHECK-NEXT: [[VECINIT6:%.*]] = insertelement <3 x i1> [[VECINIT3]], i1 [[VECEXT5]], i32 2
95-
// CHECK-NEXT: ret <3 x i1> [[VECINIT6]]
99+
// COL-CHECK-NEXT: [[MATRIX_ROWMAJOR2COLMAJOR:%.*]] = shufflevector <3 x i1> [[VECINIT6]], <3 x i1> poison, <3 x i32> <i32 0, i32 1, i32 2>
100+
// COL-CHECK-NEXT: ret <3 x i1> [[MATRIX_ROWMAJOR2COLMAJOR]]
101+
// ROW-CHECK-NEXT: ret <3 x i1> [[VECINIT6]]
96102
//
97103
bool3x1 fn2(bool3 b) {
98104
return b;

0 commit comments

Comments
 (0)