Skip to content

Commit d6211b1

Browse files
[mlir][spirv] Add SPV_EXT_float8 support (llvm#179246)
Reference: https://github.khronos.org/SPIRV-Registry/extensions/EXT/SPV_EXT_float8.html --------- Signed-off-by: Davide Grohmann <davide.grohmann@arm.com> Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
1 parent 0d5e58d commit d6211b1

13 files changed

Lines changed: 155 additions & 33 deletions

File tree

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ def SPV_EXT_shader_image_int64 : I32EnumAttrCase<"SPV_EXT_shader_image
360360
def SPV_EXT_shader_atomic_float16_add : I32EnumAttrCase<"SPV_EXT_shader_atomic_float16_add", 1011>;
361361
def SPV_EXT_mesh_shader : I32EnumAttrCase<"SPV_EXT_mesh_shader", 1012>;
362362
def SPV_EXT_replicated_composites : I32EnumAttrCase<"SPV_EXT_replicated_composites", 1013>;
363+
def SPV_EXT_float8 : I32EnumAttrCase<"SPV_EXT_float8", 1014>;
363364

364365
def SPV_AMD_gpu_shader_half_float_fetch : I32EnumAttrCase<"SPV_AMD_gpu_shader_half_float_fetch", 2000>;
365366
def SPV_AMD_shader_ballot : I32EnumAttrCase<"SPV_AMD_shader_ballot", 2001>;
@@ -449,7 +450,7 @@ def SPIRV_ExtensionAttr :
449450
SPV_EXT_shader_stencil_export, SPV_EXT_shader_viewport_index_layer,
450451
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
451452
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
452-
SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
453+
SPV_EXT_mesh_shader, SPV_EXT_replicated_composites, SPV_EXT_float8,
453454
SPV_ARM_tensors, SPV_ARM_graph,
454455
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
455456
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
@@ -1486,6 +1487,12 @@ def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
14861487
];
14871488
}
14881489

1490+
def SPIRV_C_Float8EXT : I32EnumAttrCase<"Float8EXT", 4212> {
1491+
list<Availability> availability = [
1492+
Extension<[SPV_EXT_float8]>
1493+
];
1494+
}
1495+
14891496
def SPIRV_CapabilityAttr :
14901497
SPIRV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
14911498
SPIRV_C_Matrix, SPIRV_C_Addresses, SPIRV_C_Linkage, SPIRV_C_Kernel, SPIRV_C_Float16,
@@ -1583,7 +1590,7 @@ def SPIRV_CapabilityAttr :
15831590
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
15841591
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
15851592
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
1586-
SPIRV_C_TensorFloat32RoundingINTEL
1593+
SPIRV_C_TensorFloat32RoundingINTEL, SPIRV_C_Float8EXT
15871594
]>;
15881595

15891596
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -3287,9 +3294,24 @@ def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
32873294
Capability<[SPIRV_C_BFloat16TypeKHR]>
32883295
];
32893296
}
3297+
3298+
def SPIRV_FPE_Float8E4M3EXT : I32EnumAttrCase<"Float8E4M3EXT", 4214> {
3299+
list<Availability> availability = [
3300+
Capability<[SPIRV_C_Float8EXT]>
3301+
];
3302+
}
3303+
3304+
def SPIRV_FPE_Float8E5M2EXT : I32EnumAttrCase<"Float8E5M2EXT", 4215> {
3305+
list<Availability> availability = [
3306+
Capability<[SPIRV_C_Float8EXT]>
3307+
];
3308+
}
3309+
32903310
def SPIRV_FPEncodingAttr :
32913311
SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
3292-
SPIRV_FPE_BFloat16KHR
3312+
SPIRV_FPE_BFloat16KHR,
3313+
SPIRV_FPE_Float8E4M3EXT,
3314+
SPIRV_FPE_Float8E5M2EXT,
32933315
]>;
32943316

32953317
def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
@@ -4248,9 +4270,11 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
42484270
def SPIRV_Float16 : TypeAlias<F16, "Float16">;
42494271
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
42504272
def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
4273+
def SPIRV_Float8E4M3EXT : TypeAlias<F8E4M3FN, "Float8E4M3">;
4274+
def SPIRV_Float8E5M2EXT : TypeAlias<F8E5M2, "Float8E5M2">;
42514275
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
42524276
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
4253-
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
4277+
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR, SPIRV_Float8E4M3EXT, SPIRV_Float8E5M2EXT]>;
42544278
def SPIRV_Vector : VectorOfRankAndLengthAndType<[1], [2, 3, 4, 8, 16],
42554279
[SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
42564280
// Component type check is done in the type parser for the following SPIR-V

mlir/include/mlir/IR/Builders.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class Builder {
6262

6363
// Types.
6464
FloatType getF8E8M0Type();
65+
FloatType getF8E4M3FNType();
66+
FloatType getF8E5M2Type();
6567
FloatType getBF16Type();
6668
FloatType getF16Type();
6769
FloatType getTF32Type();

mlir/include/mlir/IR/Types.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ class Type {
116116
bool isF64() const;
117117
bool isF80() const;
118118
bool isF128() const;
119+
bool isF8E4M3FN() const;
120+
bool isF8E5M2() const;
121+
119122
/// Return true if this is an float type (with the specified width).
120123
bool isFloat() const;
121124
bool isFloat(unsigned width) const;

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,11 @@ void TypeExtensionVisitor::addConcrete(ScalarType type) {
551551
extensions.push_back(ext);
552552
}
553553

554+
if (isa<Float8E4M3FNType, Float8E5M2Type>(type)) {
555+
static constexpr auto ext = Extension::SPV_EXT_float8;
556+
extensions.push_back(ext);
557+
}
558+
554559
// 8- or 16-bit integer/floating-point numbers will require extra extensions
555560
// to appear in interface storage classes. See SPV_KHR_16bit_storage and
556561
// SPV_KHR_8bit_storage for more details.
@@ -648,6 +653,15 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
648653
} else {
649654
assert(isa<FloatType>(type));
650655
switch (bitwidth) {
656+
case 8: {
657+
if (isa<Float8E4M3FNType, Float8E5M2Type>(type)) {
658+
static constexpr auto cap = Capability::Float8EXT;
659+
capabilities.push_back(cap);
660+
} else {
661+
llvm_unreachable("invalid 8-bit float type to getCapabilities");
662+
}
663+
break;
664+
}
651665
case 16: {
652666
if (isa<BFloat16Type>(type)) {
653667
static constexpr auto cap = Capability::BFloat16TypeKHR;

mlir/lib/IR/Builders.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
3434

3535
FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }
3636

37+
FloatType Builder::getF8E4M3FNType() { return Float8E4M3FNType::get(context); }
38+
39+
FloatType Builder::getF8E5M2Type() { return Float8E5M2Type::get(context); }
40+
3741
FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }
3842

3943
FloatType Builder::getF16Type() { return Float16Type::get(context); }

mlir/lib/IR/Types.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ bool Type::isF32() const { return llvm::isa<Float32Type>(*this); }
4141
bool Type::isF64() const { return llvm::isa<Float64Type>(*this); }
4242
bool Type::isF80() const { return llvm::isa<Float80Type>(*this); }
4343
bool Type::isF128() const { return llvm::isa<Float128Type>(*this); }
44+
bool Type::isF8E4M3FN() const { return llvm::isa<Float8E4M3FNType>(*this); }
45+
bool Type::isF8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
4446

4547
bool Type::isFloat() const { return llvm::isa<FloatType>(*this); }
4648

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,30 +1094,38 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
10941094
uint32_t bitWidth = operands[1];
10951095

10961096
Type floatTy;
1097-
switch (bitWidth) {
1098-
case 16:
1099-
floatTy = opBuilder.getF16Type();
1100-
break;
1101-
case 32:
1102-
floatTy = opBuilder.getF32Type();
1103-
break;
1104-
case 64:
1105-
floatTy = opBuilder.getF64Type();
1106-
break;
1107-
default:
1108-
return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
1109-
<< bitWidth;
1097+
if (operands.size() == 2) {
1098+
switch (bitWidth) {
1099+
case 16:
1100+
floatTy = opBuilder.getF16Type();
1101+
break;
1102+
case 32:
1103+
floatTy = opBuilder.getF32Type();
1104+
break;
1105+
case 64:
1106+
floatTy = opBuilder.getF64Type();
1107+
break;
1108+
default:
1109+
return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
1110+
<< bitWidth;
1111+
}
11101112
}
11111113

11121114
if (operands.size() == 3) {
1113-
if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
1115+
if (spirv::FPEncoding(operands[2]) == spirv::FPEncoding::BFloat16KHR &&
1116+
bitWidth == 16)
1117+
floatTy = opBuilder.getBF16Type();
1118+
else if (spirv::FPEncoding(operands[2]) ==
1119+
spirv::FPEncoding::Float8E4M3EXT &&
1120+
bitWidth == 8)
1121+
floatTy = opBuilder.getF8E4M3FNType();
1122+
else if (spirv::FPEncoding(operands[2]) ==
1123+
spirv::FPEncoding::Float8E5M2EXT &&
1124+
bitWidth == 8)
1125+
floatTy = opBuilder.getF8E5M2Type();
1126+
else
11141127
return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
1115-
<< operands[2];
1116-
if (bitWidth != 16)
1117-
return emitError(unknownLoc,
1118-
"invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
1119-
<< bitWidth << " (expected 16)";
1120-
floatTy = opBuilder.getBF16Type();
1128+
<< operands[2] << " and bitWidth " << bitWidth;
11211129
}
11221130

11231131
typeMap[operands[0]] = floatTy;
@@ -1734,6 +1742,12 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
17341742
} else if (floatType.isBF16()) {
17351743
APInt data(16, operands[2]);
17361744
value = APFloat(APFloat::BFloat(), data);
1745+
} else if (floatType.isF8E4M3FN()) {
1746+
APInt data(8, operands[2]);
1747+
value = APFloat(APFloat::Float8E4M3FN(), data);
1748+
} else if (floatType.isF8E5M2()) {
1749+
APInt data(8, operands[2]);
1750+
value = APFloat(APFloat::Float8E5M2(), data);
17371751
}
17381752

17391753
auto attr = opBuilder.getFloatAttr(floatType, value);

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,15 @@ LogicalResult Serializer::prepareBasicType(
599599
if (floatType.isBF16()) {
600600
operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
601601
}
602+
if (floatType.isF8E4M3FN()) {
603+
operands.push_back(
604+
static_cast<uint32_t>(spirv::FPEncoding::Float8E4M3EXT));
605+
}
606+
if (floatType.isF8E5M2()) {
607+
operands.push_back(
608+
static_cast<uint32_t>(spirv::FPEncoding::Float8E5M2EXT));
609+
}
610+
602611
return success();
603612
}
604613

@@ -1253,8 +1262,10 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
12531262
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
12541263
encodeInstructionInto(typesGlobalValues, opcode,
12551264
{typeID, resultID, words.word1, words.word2});
1256-
} else if (semantics == &APFloat::IEEEhalf() ||
1257-
semantics == &APFloat::BFloat()) {
1265+
} else if (llvm::is_contained({&APFloat::IEEEhalf(), &APFloat::BFloat(),
1266+
&APFloat::Float8E4M3FN(),
1267+
&APFloat::Float8E5M2()},
1268+
semantics)) {
12581269
uint32_t word =
12591270
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
12601271
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});

mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
348348
// -----
349349

350350
func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
351-
// expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16}}
351+
// expected-error @+1 {{'spirv.Dot' op operand #0 must be fixed-length vector of 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16}}
352352
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
353353
return %0 : i32
354354
}

mlir/test/Dialect/SPIRV/IR/composite-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2
100100
// -----
101101

102102
func.func @composite_construct_vector_rank_two(%arg0: vector<2x2xi1>, %arg1: vector<2x2xi1>) -> vector<4x2xi1> {
103-
// expected-error @+1 {{op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
103+
// expected-error @+1 {{ op operand #0 must be variadic of void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float or BFloat16 or Float8E4M3 or Float8E5M2 values of length 2/3/4/8/16 of ranks 1 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V matrix type or any SPIR-V sampled image type or any SPIR-V image type or any SPIR-V tensorArm type, but got 'vector<2x2xi1>'}}
104104
%0 = spirv.CompositeConstruct %arg0, %arg1 : (vector<2x2xi1>, vector<2x2xi1>) -> vector<4x2xi1>
105105
return %0: vector<4x2xi1>
106106
}

0 commit comments

Comments
 (0)