Skip to content

Commit ed3d3bf

Browse files
[mlir][spirv] Add first 3 data layout ops in TOSA Ext Inst Set (llvm#187714)
This patch introduces the following reduction operators: spirv.Tosa.Concat spirv.Tosa.Pad spirv.Tosa.Reshape Also dialect and serialization round-trip tests have been added. Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
1 parent be31cff commit ed3d3bf

5 files changed

Lines changed: 444 additions & 0 deletions

File tree

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2125,4 +2125,150 @@ def SPIRV_TosaReduceSumOp : SPIRV_TosaReductionOp<"ReduceSum", 53, [NoMemoryEffe
21252125
}
21262126

21272127

2128+
def SPIRV_TosaConcatOp : SPIRV_TosaOpWithResult<"Concat", 54, [Pure,
2129+
VariadicInputWithMinSize<"input1", 1>,
2130+
VariadicInputAllSameElementType<"output", "input1">,
2131+
VariadicInputAllSameRank<"output", "input1">,
2132+
AxisValueLessThanRankOf<"output">]> {
2133+
let summary = "Concatenates tensors along one dimension.";
2134+
2135+
let description = [{
2136+
Concatenates a list of tensors along a given axis.
2137+
No data conversion happens during a concat operation.
2138+
2139+
References:
2140+
* https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_concat
2141+
* https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_concat
2142+
2143+
#### Example:
2144+
```mlir
2145+
%1 = spirv.Tosa.Concat axis = 2, %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8> -> !spirv.arm.tensor<12x13x12x14xi8>
2146+
%1 = spirv.Tosa.Concat axis = 1, %arg0, %arg1, %arg2 : !spirv.arm.tensor<40x31x19xf32>, !spirv.arm.tensor<40x15x19xf32>, !spirv.arm.tensor<40x16x19xf32> -> !spirv.arm.tensor<40x62x19xf32>
2147+
```
2148+
}];
2149+
2150+
let arguments = (ins
2151+
SPIRV_TensorArmAxisAttr: $axis,
2152+
Variadic<SPIRV_TosaAny_TensorArm>: $input1
2153+
);
2154+
2155+
let results = (outs
2156+
SPIRV_TosaAny_TensorArm: $output
2157+
);
2158+
2159+
let assemblyFormat = [{
2160+
`axis` `=` $axis `,`
2161+
$input1
2162+
attr-dict `:` type(operands) `->` type(results)
2163+
}];
2164+
2165+
let extraClassDeclaration = extraBaseClassDeclaration#[{
2166+
::mlir::TypeRange getInput1Types() {
2167+
return getInput1().getTypes();
2168+
}
2169+
}];
2170+
}
2171+
2172+
2173+
def SPIRV_TosaPadOp : SPIRV_TosaOpWithResult<"Pad", 55, [Pure,
2174+
AllElementTypesMatch<["input1", "pad_const", "output"]>,
2175+
AllRanksMatch<["input1", "output"]>,
2176+
ShapeConstraintFromInputRank<"input1", "padding", 2>]> {
2177+
let summary = "Pads a tensor with value specified.";
2178+
2179+
let description = [{
2180+
Pads a tensor along the borders of each dimension with a supplied value.
2181+
Returns a new tensor with the padding included. The pad_const value includes
2182+
the zero point if the tensor uses a zero point.
2183+
2184+
References:
2185+
* https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_pad
2186+
* https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_pad
2187+
2188+
#### Example:
2189+
```mlir
2190+
%2 = spirv.Tosa.Pad %arg0, %0, %1 : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<21x19xi8>
2191+
%2 = spirv.Tosa.Pad %arg0, %0, %1 : !spirv.arm.tensor<2x9x2x3xf32>, !spirv.arm.tensor<8xi32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<4x9x4x4xf32>
2192+
```
2193+
}];
2194+
2195+
let arguments = (ins
2196+
SPIRV_TosaAny_TensorArm: $input1,
2197+
SPIRV_Int32_1DTensorArmOfEvenLength2To12: $padding,
2198+
SPIRV_TosaAny_1DTensorArmOfLength1: $pad_const
2199+
);
2200+
2201+
let results = (outs
2202+
SPIRV_TosaAny_TensorArm: $output
2203+
);
2204+
2205+
let assemblyFormat = [{
2206+
$input1 `,`
2207+
$padding `,`
2208+
$pad_const
2209+
attr-dict `:` type(operands) `->` type(results)
2210+
}];
2211+
2212+
let extraClassDeclaration = extraBaseClassDeclaration#[{
2213+
::mlir::spirv::TensorArmType getInput1Type() {
2214+
return cast<::mlir::spirv::TensorArmType>(getInput1().getType());
2215+
}
2216+
::mlir::spirv::TensorArmType getPaddingType() {
2217+
return cast<::mlir::spirv::TensorArmType>(getPadding().getType());
2218+
}
2219+
::mlir::spirv::TensorArmType getPadConstType() {
2220+
return cast<::mlir::spirv::TensorArmType>(getPadConst().getType());
2221+
}
2222+
}];
2223+
}
2224+
2225+
2226+
def SPIRV_TosaReshapeOp : SPIRV_TosaOpWithResult<"Reshape", 56, [Pure,
2227+
AllElementTypesMatch<["input1", "output"]>,
2228+
AllElementCountsMatch<["input1", "output"]>,
2229+
ShapeConstraintFromInputRank<"output", "shape">]> {
2230+
let summary = "Reshape operator.";
2231+
2232+
let description = [{
2233+
Returns a tensor with the same type/values as the input, with a new shape
2234+
specified by the shape argument. Reshape may operate on tensors of any rank.
2235+
No data conversion happens during a reshape operation.
2236+
2237+
References:
2238+
* https://github.khronos.org/SPIRV-Registry/extended/TOSA.001000.1.html#_reshape
2239+
* https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_reshape
2240+
2241+
#### Example:
2242+
```mlir
2243+
%1 = spirv.Tosa.Reshape %arg0, %newShape : !spirv.arm.tensor<25x6x29x35xi16>, !spirv.arm.tensor<4xi32> -> !spirv.arm.tensor<125x6x7x29xi16>
2244+
%1 = spirv.Tosa.Reshape %arg0, %newShape : !spirv.arm.tensor<1x2x7x2xf32>, !spirv.arm.tensor<3xi32> -> !spirv.arm.tensor<2x1x14xf32>
2245+
```
2246+
}];
2247+
2248+
let arguments = (ins
2249+
SPIRV_TosaAny_TensorArm: $input1,
2250+
SPIRV_Int32_1DTensorArmOfLength1To6: $shape
2251+
);
2252+
2253+
let results = (outs
2254+
SPIRV_TosaAny_TensorArm: $output
2255+
);
2256+
2257+
let assemblyFormat = [{
2258+
$input1 `,`
2259+
$shape
2260+
attr-dict `:` type(operands) `->` type(results)
2261+
}];
2262+
2263+
let extraClassDeclaration = extraBaseClassDeclaration#[{
2264+
::mlir::spirv::TensorArmType getInput1Type() {
2265+
return cast<::mlir::spirv::TensorArmType>(getInput1().getType());
2266+
}
2267+
::mlir::spirv::TensorArmType getShapeType() {
2268+
return cast<::mlir::spirv::TensorArmType>(getShape().getType());
2269+
}
2270+
}];
2271+
}
2272+
2273+
21282274
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_OPS

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class SPIRV_1DTensorArmOfLengthAndType<list<int> allowedLengths, list<Type> allo
6565
"rank 1 tensorArm of length " # !interleave(allowedLengths, "/"),
6666
"::mlir::spirv::TensorArmType">;
6767

68+
def SPIRV_Int32_1DTensorArmOfLength1To6 : SPIRV_1DTensorArmOfLengthAndType<[1, 2, 3, 4, 5, 6], [SPIRV_Int32]>;
69+
def SPIRV_Int32_1DTensorArmOfEvenLength2To12 : SPIRV_1DTensorArmOfLengthAndType<[2, 4, 6, 8, 10, 12], [SPIRV_Int32]>;
70+
6871
def SPIRV_DenseElementAttrsWithTensorArmType : AttrConstraint<
6972
CPred<"::llvm::isa<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType())">,
7073
"Attr with type = spirv::TensorArmType">;
@@ -77,6 +80,7 @@ def SPIRV_Int32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6
7780

7881
def SPIRV_Int8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_Int8]>;
7982
def SPIRV_TosaNumerical_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaNumerical]>;
83+
def SPIRV_TosaAny_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaAny]>;
8084

8185
// Struct type
8286

@@ -139,4 +143,29 @@ class TableSizeConstraint<string input, Type type, int size>:
139143
Implies<ElementTypeIsPred<input, type>, [CPred<"::llvm::cast<::mlir::ShapedType>(getTable().getType()).getShape()[0] == " # size>]>
140144
>;
141145

146+
class ShapeConstraintFromInputRank<string input, string other, int mul=1>:
147+
PredOpTrait< "the number of elements of " # other # " must be rank(" # input # ")" # !if(!eq(mul, 1), "", " * " # mul),
148+
Implies<CPred<HasRank<input>.result>,
149+
[CPred<ElementCount<other>.result # " == " # mul # " * " # Rank<input>.result>]>
150+
>;
151+
152+
class VariadicInputWithMinSize<string input, int min_size>:
153+
PredOpTrait<"variadic " # input # " must has at least " # min_size # " elements",
154+
CPred<"static_cast<int64_t>($" # input # ".getTypes().size()) >= " # min_size>>;
155+
156+
class VariadicInputAllSameElementType<string reference, string input>:
157+
PredOpTrait<"all elements of variadic " # input # " must have same element type",
158+
CPred<"::llvm::all_of($" # input # ".getTypes(), "
159+
"[&](::mlir::Type t) { return ::llvm::cast<::mlir::ShapedType>(t).getElementType() == "
160+
# ElementType<reference>.result # "; })">>;
161+
162+
class VariadicInputAllSameRank<string reference, string input>:
163+
PredOpTrait<"all elements of variadic " # input # " must have same element type",
164+
CPred<"::llvm::all_of($" # input # ".getTypes(), "
165+
"[&](::mlir::Type t) { return ::llvm::cast<::mlir::ShapedType>(t).hasRank() && "
166+
# HasRank<reference>.result #
167+
" && ::llvm::cast<::mlir::ShapedType>(t).getRank() == "
168+
# Rank<reference>.result # "; })">>;
169+
170+
142171
#endif // MLIR_DIALECT_SPIRV_IR_TOSA_TYPES

mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1684,3 +1684,75 @@ spirv.ARM.Graph @reducesum_axis_value_not_in_input_rank_range(%arg0: !spirv.arm.
16841684
%0 = spirv.Tosa.ReduceSum axis = 3, %arg0 : !spirv.arm.tensor<20x24x22xi32> -> !spirv.arm.tensor<20x24x22xi32>
16851685
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<20x24x22xi32>
16861686
}
1687+
1688+
//===----------------------------------------------------------------------===//
1689+
// spirv.TOSA.Concat
1690+
//===----------------------------------------------------------------------===//
1691+
1692+
spirv.ARM.Graph @concat_must_have_at_least_one_input() -> (!spirv.arm.tensor<4x12xi8>) {
1693+
// expected-error @+1 {{op failed to verify that variadic input1 must has at least 1 elements}}
1694+
%0 = "spirv.Tosa.Concat"() <{axis = 0 : i32}> : () -> !spirv.arm.tensor<4x12xi8>
1695+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x12xi8>
1696+
}
1697+
1698+
spirv.ARM.Graph @concat_input_output_element_types_not_matching(%arg0: !spirv.arm.tensor<4x5xi8>, %arg1: !spirv.arm.tensor<4x7xi8>) -> (!spirv.arm.tensor<4x12xi16>) {
1699+
// expected-error @+1 {{op failed to verify that all elements of variadic input1 must have same element type}}
1700+
%0 = spirv.Tosa.Concat axis = 1, %arg0, %arg1 : !spirv.arm.tensor<4x5xi8>, !spirv.arm.tensor<4x7xi8> -> !spirv.arm.tensor<4x12xi16>
1701+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x12xi16>
1702+
}
1703+
1704+
spirv.ARM.Graph @concat_input_output_ranks_not_matching(%arg0: !spirv.arm.tensor<4x5xi8>, %arg1: !spirv.arm.tensor<4x7xi8>) -> (!spirv.arm.tensor<4x12x1xi8>) {
1705+
// expected-error @+1 {{op failed to verify that all elements of variadic input1 must have same element type}}
1706+
%0 = spirv.Tosa.Concat axis = 1, %arg0, %arg1 : !spirv.arm.tensor<4x5xi8>, !spirv.arm.tensor<4x7xi8> -> !spirv.arm.tensor<4x12x1xi8>
1707+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x12x1xi8>
1708+
}
1709+
1710+
spirv.ARM.Graph @concat_axis_value_not_in_output_rank_range(%arg0: !spirv.arm.tensor<4x5xi8>, %arg1: !spirv.arm.tensor<4x7xi8>) -> (!spirv.arm.tensor<4x12xi8>) {
1711+
// expected-error @+1 {{op failed to verify that axis attribute value should be lower than rank(output)}}
1712+
%0 = spirv.Tosa.Concat axis = 2, %arg0, %arg1 : !spirv.arm.tensor<4x5xi8>, !spirv.arm.tensor<4x7xi8> -> !spirv.arm.tensor<4x12xi8>
1713+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<4x12xi8>
1714+
}
1715+
1716+
//===----------------------------------------------------------------------===//
1717+
// spirv.TOSA.Pad
1718+
//===----------------------------------------------------------------------===//
1719+
1720+
spirv.ARM.Graph @pad_input_pad_const_output_element_types_not_matching(%arg0: !spirv.arm.tensor<4x7xi8>, %arg1: !spirv.arm.tensor<4xi32>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<5x8xi8>) {
1721+
// expected-error @+1 {{op failed to verify that all of {input1, pad_const, output} have same element type}}
1722+
%0 = spirv.Tosa.Pad %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<5x8xi8>
1723+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<5x8xi8>
1724+
}
1725+
1726+
spirv.ARM.Graph @pad_input_output_ranks_not_matching(%arg0: !spirv.arm.tensor<4x7xi8>, %arg1: !spirv.arm.tensor<4xi32>, %arg2: !spirv.arm.tensor<1xi8>) -> (!spirv.arm.tensor<1x5x8xi8>) {
1727+
// expected-error @+1 {{op failed to verify that all of {input1, output} have same rank}}
1728+
%0 = spirv.Tosa.Pad %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x5x8xi8>
1729+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x5x8xi8>
1730+
}
1731+
1732+
spirv.ARM.Graph @pad_padding_element_count_not_twice_input_rank(%arg0: !spirv.arm.tensor<4x7xi8>, %arg1: !spirv.arm.tensor<6xi32>, %arg2: !spirv.arm.tensor<1xi8>) -> (!spirv.arm.tensor<5x8xi8>) {
1733+
// expected-error @+1 {{op failed to verify that the number of elements of padding must be rank(input1) * 2}}
1734+
%0 = spirv.Tosa.Pad %arg0, %arg1, %arg2 : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<6xi32>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<5x8xi8>
1735+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<5x8xi8>
1736+
}
1737+
1738+
//===----------------------------------------------------------------------===//
1739+
// spirv.TOSA.Reshape
1740+
//===----------------------------------------------------------------------===//
1741+
1742+
spirv.ARM.Graph @reshape_input_output_element_types_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi8>, %arg1: !spirv.arm.tensor<2xi32>) -> (!spirv.arm.tensor<6x4xi16>) {
1743+
// expected-error @+1 {{op failed to verify that all of {input1, output} have same element type}}
1744+
%0 = spirv.Tosa.Reshape %arg0, %arg1 : !spirv.arm.tensor<2x3x4xi8>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<6x4xi16>
1745+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x4xi16>
1746+
}
1747+
1748+
spirv.ARM.Graph @reshape_input_output_element_counts_not_matching(%arg0: !spirv.arm.tensor<2x3x4xi8>, %arg1: !spirv.arm.tensor<2xi32>) -> (!spirv.arm.tensor<5x4xi8>) {
1749+
// expected-error @+1 {{op failed to verify that all of {input1, output} have same element count}}
1750+
%0 = spirv.Tosa.Reshape %arg0, %arg1 : !spirv.arm.tensor<2x3x4xi8>, !spirv.arm.tensor<2xi32> -> !spirv.arm.tensor<5x4xi8>
1751+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<5x4xi8>
1752+
}
1753+
1754+
spirv.ARM.Graph @reshape_shape_element_count_not_output_rank(%arg0: !spirv.arm.tensor<2x3x4xi8>, %arg1: !spirv.arm.tensor<4xi32>) -> (!spirv.arm.tensor<6x4xi8>) {
1755+
// expected-error @+1 {{op failed to verify that the number of elements of shape must be rank(output)}}
1756+
%0 = spirv.Tosa.Reshape %arg0, %arg1 : !spirv.arm.tensor<2x3x4xi8>, !spirv.arm.tensor<4xi32> -> !spirv.arm.tensor<6x4xi8>
1757+
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x4xi8>
1758+
}

mlir/test/Dialect/SPIRV/IR/tosa-ops.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,3 +885,75 @@ spirv.ARM.Graph @reducesum_fp(%arg0: !spirv.arm.tensor<32x32x33xf32>) -> (!spirv
885885
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<32x1x33xf32>
886886
spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<32x1x33xf32>
887887
}
888+
889+
//===----------------------------------------------------------------------===//
890+
// spirv.TOSA.Concat - PRO-INT
891+
//===----------------------------------------------------------------------===//
892+
893+
spirv.ARM.Graph @concat_int(%arg0: !spirv.arm.tensor<12x13x3x14xi8>, %arg1: !spirv.arm.tensor<12x13x3x14xi8>, %arg2: !spirv.arm.tensor<12x13x3x14xi8>, %arg3: !spirv.arm.tensor<12x13x3x14xi8>) -> (!spirv.arm.tensor<12x13x12x14xi8>) {
894+
// CHECK: {{%.*}} = spirv.Tosa.Concat axis = 2, %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8> -> !spirv.arm.tensor<12x13x12x14xi8>
895+
%1 = spirv.Tosa.Concat axis = 2, %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8>, !spirv.arm.tensor<12x13x3x14xi8> -> !spirv.arm.tensor<12x13x12x14xi8>
896+
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<12x13x12x14xi8>
897+
spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<12x13x12x14xi8>
898+
}
899+
900+
//===----------------------------------------------------------------------===//
901+
// spirv.TOSA.Concat - PRO-FP
902+
//===----------------------------------------------------------------------===//
903+
904+
spirv.ARM.Graph @concat_fp(%arg0: !spirv.arm.tensor<40x31x19xf32>, %arg1: !spirv.arm.tensor<40x15x19xf32>, %arg2: !spirv.arm.tensor<40x16x19xf32>) -> (!spirv.arm.tensor<40x62x19xf32>) {
905+
// CHECK: {{%.*}} = spirv.Tosa.Concat axis = 1, %arg0, %arg1, %arg2 : !spirv.arm.tensor<40x31x19xf32>, !spirv.arm.tensor<40x15x19xf32>, !spirv.arm.tensor<40x16x19xf32> -> !spirv.arm.tensor<40x62x19xf32>
906+
%1 = spirv.Tosa.Concat axis = 1, %arg0, %arg1, %arg2 : !spirv.arm.tensor<40x31x19xf32>, !spirv.arm.tensor<40x15x19xf32>, !spirv.arm.tensor<40x16x19xf32> -> !spirv.arm.tensor<40x62x19xf32>
907+
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<40x62x19xf32>
908+
spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<40x62x19xf32>
909+
}
910+
911+
//===----------------------------------------------------------------------===//
912+
// spirv.TOSA.Pad - PRO-INT
913+
//===----------------------------------------------------------------------===//
914+
915+
spirv.ARM.Graph @pad_int(%arg0: !spirv.arm.tensor<4x7xi8>) -> (!spirv.arm.tensor<21x19xi8>) {
916+
%0 = spirv.Constant dense<[10, 7, 6, 6]> : !spirv.arm.tensor<4xi32>
917+
%1 = spirv.Constant dense<-76> : !spirv.arm.tensor<1xi8>
918+
// CHECK: {{%.*}} = spirv.Tosa.Pad %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<21x19xi8>
919+
%2 = spirv.Tosa.Pad %arg0, %0, %1 : !spirv.arm.tensor<4x7xi8>, !spirv.arm.tensor<4xi32>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<21x19xi8>
920+
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<21x19xi8>
921+
spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<21x19xi8>
922+
}
923+
924+
//===----------------------------------------------------------------------===//
925+
// spirv.TOSA.Pad - PRO-FP
926+
//===----------------------------------------------------------------------===//
927+
928+
spirv.ARM.Graph @pad_fp(%arg0: !spirv.arm.tensor<2x9x2x3xf32>) -> (!spirv.arm.tensor<4x9x4x4xf32>) {
929+
%0 = spirv.Constant dense<[1, 1, 0, 0, 1, 1, 0, 1]> : !spirv.arm.tensor<8xi32>
930+
%1 = spirv.Constant dense<1.21630913E+38> : !spirv.arm.tensor<1xf32>
931+
// CHECK: {{%.*}} = spirv.Tosa.Pad %arg0, {{%.*}}, {{%.*}} : !spirv.arm.tensor<2x9x2x3xf32>, !spirv.arm.tensor<8xi32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<4x9x4x4xf32>
932+
%2 = spirv.Tosa.Pad %arg0, %0, %1 : !spirv.arm.tensor<2x9x2x3xf32>, !spirv.arm.tensor<8xi32>, !spirv.arm.tensor<1xf32> -> !spirv.arm.tensor<4x9x4x4xf32>
933+
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<4x9x4x4xf32>
934+
spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<4x9x4x4xf32>
935+
}
936+
937+
//===----------------------------------------------------------------------===//
938+
// spirv.TOSA.Reshape - PRO-INT
939+
//===----------------------------------------------------------------------===//
940+
941+
spirv.ARM.Graph @reshape_int(%arg0: !spirv.arm.tensor<25x6x29x35xi16>) -> (!spirv.arm.tensor<125x6x7x29xi16>) {
942+
%0 = spirv.Constant dense<[125, 6, 7, 29]> : !spirv.arm.tensor<4xi32>
943+
// CHECK: {{%.*}} = spirv.Tosa.Reshape %arg0, {{%.*}} : !spirv.arm.tensor<25x6x29x35xi16>, !spirv.arm.tensor<4xi32> -> !spirv.arm.tensor<125x6x7x29xi16>
944+
%1 = spirv.Tosa.Reshape %arg0, %0 : !spirv.arm.tensor<25x6x29x35xi16>, !spirv.arm.tensor<4xi32> -> !spirv.arm.tensor<125x6x7x29xi16>
945+
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<125x6x7x29xi16>
946+
spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<125x6x7x29xi16>
947+
}
948+
949+
//===----------------------------------------------------------------------===//
950+
// spirv.TOSA.Reshape - PRO-FP
951+
//===----------------------------------------------------------------------===//
952+
953+
spirv.ARM.Graph @reshape_fp(%arg0: !spirv.arm.tensor<1x2x7x2xf32>) -> (!spirv.arm.tensor<2x1x14xf32>) {
954+
%0 = spirv.Constant dense<[2, 1, 14]> : !spirv.arm.tensor<3xi32>
955+
// CHECK: {{%.*}} = spirv.Tosa.Reshape %arg0, {{%.*}} : !spirv.arm.tensor<1x2x7x2xf32>, !spirv.arm.tensor<3xi32> -> !spirv.arm.tensor<2x1x14xf32>
956+
%1 = spirv.Tosa.Reshape %arg0, %0 : !spirv.arm.tensor<1x2x7x2xf32>, !spirv.arm.tensor<3xi32> -> !spirv.arm.tensor<2x1x14xf32>
957+
// CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<2x1x14xf32>
958+
spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x1x14xf32>
959+
}

0 commit comments

Comments
 (0)