Skip to content

Commit fcdd2fa

Browse files
committed
Add sort validation for typed operations in SPIR-V parser
Validate operand SPIR-V types against expected egraph sorts before constructing terms. This prevents cross-sort mismatches that crash the egraph when compilers use unexpected type combinations (e.g. OpIEqual on boolean operands, OpBitwiseAnd on non-integer types). - Add checked_binary_op/checked_unary_op that verify operand types match the expected sort (Int, Float, Bool) and return None on mismatch so the instruction is kept as-is - Add int_comparison_op that redirects OpIEqual/OpINotEqual to LogEq/LogNe when operands are boolean-typed - Update typed_binary_op/typed_unary_op to use sort validation for scalar ops while keeping unchecked dispatch for vector ops - Update all callers with explicit expected operand TypeClass
1 parent 78f67c7 commit fcdd2fa

1 file changed

Lines changed: 151 additions & 61 deletions

File tree

rust/spirv-tools-opt/src/direct/context.rs

Lines changed: 151 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -225,65 +225,76 @@ impl EgglogContext {
225225
Op::ConstantTrue => "(BoolConst 1)".to_string(),
226226
Op::ConstantFalse => "(BoolConst 0)".to_string(),
227227

228-
Op::IAdd => self.typed_binary_op("Add", "VecAdd", inst)?,
229-
Op::ISub => self.typed_binary_op("Sub", "VecSub", inst)?,
230-
Op::IMul => self.typed_binary_op("Mul", "VecMul", inst)?,
231-
Op::SDiv => self.typed_binary_op("SDiv", "VecSDiv", inst)?,
232-
Op::UDiv => self.typed_binary_op("UDiv", "VecUDiv", inst)?,
233-
Op::SRem => self.typed_binary_op("SRem", "VecSRem", inst)?,
234-
Op::SMod => self.typed_binary_op("SMod", "VecSMod", inst)?,
235-
Op::UMod => self.typed_binary_op("UMod", "VecUMod", inst)?,
236-
Op::SNegate => self.typed_unary_op("Neg", "VecNeg", inst)?,
237-
Op::ShiftLeftLogical => self.binary_op("Shl", inst)?,
238-
Op::ShiftRightLogical => self.binary_op("ShrU", inst)?,
239-
Op::ShiftRightArithmetic => self.binary_op("ShrS", inst)?,
240-
Op::BitwiseAnd => self.binary_op("BitAnd", inst)?,
241-
Op::BitwiseOr => self.binary_op("BitOr", inst)?,
242-
Op::BitwiseXor => self.binary_op("BitXor", inst)?,
243-
Op::Not => self.unary_op("BitNot", inst)?,
244-
Op::BitReverse => self.unary_op("BitReverse", inst)?,
245-
Op::IEqual => self.binary_op("Eq", inst)?,
246-
Op::INotEqual => self.binary_op("Ne", inst)?,
247-
Op::SLessThan => self.binary_op("SLt", inst)?,
248-
Op::SLessThanEqual => self.binary_op("SLe", inst)?,
249-
Op::SGreaterThan => self.binary_op("SGt", inst)?,
250-
Op::SGreaterThanEqual => self.binary_op("SGe", inst)?,
251-
Op::ULessThan => self.binary_op("ULt", inst)?,
252-
Op::ULessThanEqual => self.binary_op("ULe", inst)?,
253-
Op::UGreaterThan => self.binary_op("UGt", inst)?,
254-
Op::UGreaterThanEqual => self.binary_op("UGe", inst)?,
255-
Op::LogicalNot => self.unary_op("LogNot", inst)?,
256-
Op::LogicalAnd => self.binary_op("LogAnd", inst)?,
257-
Op::LogicalOr => self.binary_op("LogOr", inst)?,
258-
Op::LogicalEqual => self.binary_op("LogEq", inst)?,
259-
Op::LogicalNotEqual => self.binary_op("LogNe", inst)?,
228+
// Integer arithmetic (operands: IntExpr, result: IntExpr)
229+
Op::IAdd => self.typed_binary_op("Add", "VecAdd", TypeClass::Int, inst)?,
230+
Op::ISub => self.typed_binary_op("Sub", "VecSub", TypeClass::Int, inst)?,
231+
Op::IMul => self.typed_binary_op("Mul", "VecMul", TypeClass::Int, inst)?,
232+
Op::SDiv => self.typed_binary_op("SDiv", "VecSDiv", TypeClass::Int, inst)?,
233+
Op::UDiv => self.typed_binary_op("UDiv", "VecUDiv", TypeClass::Int, inst)?,
234+
Op::SRem => self.typed_binary_op("SRem", "VecSRem", TypeClass::Int, inst)?,
235+
Op::SMod => self.typed_binary_op("SMod", "VecSMod", TypeClass::Int, inst)?,
236+
Op::UMod => self.typed_binary_op("UMod", "VecUMod", TypeClass::Int, inst)?,
237+
Op::SNegate => self.typed_unary_op("Neg", "VecNeg", TypeClass::Int, inst)?,
238+
// Bitwise/shift (operands: IntExpr, result: IntExpr)
239+
Op::ShiftLeftLogical => self.checked_binary_op("Shl", TypeClass::Int, inst)?,
240+
Op::ShiftRightLogical => self.checked_binary_op("ShrU", TypeClass::Int, inst)?,
241+
Op::ShiftRightArithmetic => self.checked_binary_op("ShrS", TypeClass::Int, inst)?,
242+
Op::BitwiseAnd => self.checked_binary_op("BitAnd", TypeClass::Int, inst)?,
243+
Op::BitwiseOr => self.checked_binary_op("BitOr", TypeClass::Int, inst)?,
244+
Op::BitwiseXor => self.checked_binary_op("BitXor", TypeClass::Int, inst)?,
245+
Op::Not => self.checked_unary_op("BitNot", TypeClass::Int, inst)?,
246+
Op::BitReverse => self.checked_unary_op("BitReverse", TypeClass::Int, inst)?,
247+
// Integer comparisons (operands: IntExpr, result: BoolExpr)
248+
// Some compilers use OpIEqual/OpINotEqual on boolean values instead of
249+
// OpLogicalEqual/OpLogicalNotEqual. Redirect to logical equivalents
250+
// to prevent sort mismatches (BoolExpr operand in IntExpr position).
251+
Op::IEqual => self.int_comparison_op("Eq", inst)?,
252+
Op::INotEqual => self.int_comparison_op("Ne", inst)?,
253+
Op::SLessThan => self.checked_binary_op("SLt", TypeClass::Int, inst)?,
254+
Op::SLessThanEqual => self.checked_binary_op("SLe", TypeClass::Int, inst)?,
255+
Op::SGreaterThan => self.checked_binary_op("SGt", TypeClass::Int, inst)?,
256+
Op::SGreaterThanEqual => self.checked_binary_op("SGe", TypeClass::Int, inst)?,
257+
Op::ULessThan => self.checked_binary_op("ULt", TypeClass::Int, inst)?,
258+
Op::ULessThanEqual => self.checked_binary_op("ULe", TypeClass::Int, inst)?,
259+
Op::UGreaterThan => self.checked_binary_op("UGt", TypeClass::Int, inst)?,
260+
Op::UGreaterThanEqual => self.checked_binary_op("UGe", TypeClass::Int, inst)?,
261+
// Logical (operands: BoolExpr, result: BoolExpr)
262+
Op::LogicalNot => self.checked_unary_op("LogNot", TypeClass::Bool, inst)?,
263+
Op::LogicalAnd => self.checked_binary_op("LogAnd", TypeClass::Bool, inst)?,
264+
Op::LogicalOr => self.checked_binary_op("LogOr", TypeClass::Bool, inst)?,
265+
Op::LogicalEqual => self.checked_binary_op("LogEq", TypeClass::Bool, inst)?,
266+
Op::LogicalNotEqual => self.checked_binary_op("LogNe", TypeClass::Bool, inst)?,
260267
// Floating-point operations (scalar or vector)
261-
Op::FAdd => self.typed_binary_op("FAdd", "VecFAdd", inst)?,
262-
Op::FSub => self.typed_binary_op("FSub", "VecFSub", inst)?,
263-
Op::FMul => self.typed_binary_op("FMul", "VecFMul", inst)?,
264-
Op::FDiv => self.typed_binary_op("FDiv", "VecFDiv", inst)?,
265-
Op::FRem => self.typed_binary_op("FRem", "VecFRem", inst)?,
266-
Op::FMod => self.typed_binary_op("FMod", "VecFMod", inst)?,
267-
Op::FNegate => self.typed_unary_op("FNeg", "VecFNeg", inst)?,
268-
// Floating-point comparisons (ordered)
269-
Op::FOrdEqual => self.binary_op("FOrdEq", inst)?,
270-
Op::FOrdNotEqual => self.binary_op("FOrdNe", inst)?,
271-
Op::FOrdLessThan => self.binary_op("FOrdLt", inst)?,
272-
Op::FOrdLessThanEqual => self.binary_op("FOrdLe", inst)?,
273-
Op::FOrdGreaterThan => self.binary_op("FOrdGt", inst)?,
274-
Op::FOrdGreaterThanEqual => self.binary_op("FOrdGe", inst)?,
268+
Op::FAdd => self.typed_binary_op("FAdd", "VecFAdd", TypeClass::Float, inst)?,
269+
Op::FSub => self.typed_binary_op("FSub", "VecFSub", TypeClass::Float, inst)?,
270+
Op::FMul => self.typed_binary_op("FMul", "VecFMul", TypeClass::Float, inst)?,
271+
Op::FDiv => self.typed_binary_op("FDiv", "VecFDiv", TypeClass::Float, inst)?,
272+
Op::FRem => self.typed_binary_op("FRem", "VecFRem", TypeClass::Float, inst)?,
273+
Op::FMod => self.typed_binary_op("FMod", "VecFMod", TypeClass::Float, inst)?,
274+
Op::FNegate => self.typed_unary_op("FNeg", "VecFNeg", TypeClass::Float, inst)?,
275+
// Floating-point comparisons (operands: FloatExpr, result: BoolExpr)
276+
Op::FOrdEqual => self.checked_binary_op("FOrdEq", TypeClass::Float, inst)?,
277+
Op::FOrdNotEqual => self.checked_binary_op("FOrdNe", TypeClass::Float, inst)?,
278+
Op::FOrdLessThan => self.checked_binary_op("FOrdLt", TypeClass::Float, inst)?,
279+
Op::FOrdLessThanEqual => self.checked_binary_op("FOrdLe", TypeClass::Float, inst)?,
280+
Op::FOrdGreaterThan => self.checked_binary_op("FOrdGt", TypeClass::Float, inst)?,
281+
Op::FOrdGreaterThanEqual => self.checked_binary_op("FOrdGe", TypeClass::Float, inst)?,
275282
// Floating-point comparisons (unordered)
276-
Op::FUnordEqual => self.binary_op("FUnordEq", inst)?,
277-
Op::FUnordNotEqual => self.binary_op("FUnordNe", inst)?,
278-
Op::FUnordLessThan => self.binary_op("FUnordLt", inst)?,
279-
Op::FUnordLessThanEqual => self.binary_op("FUnordLe", inst)?,
280-
Op::FUnordGreaterThan => self.binary_op("FUnordGt", inst)?,
281-
Op::FUnordGreaterThanEqual => self.binary_op("FUnordGe", inst)?,
282-
// Conversion operations
283-
Op::ConvertFToU => self.unary_op("ConvertFToU", inst)?,
284-
Op::ConvertFToS => self.unary_op("ConvertFToS", inst)?,
285-
Op::ConvertSToF => self.unary_op("ConvertSToF", inst)?,
286-
Op::ConvertUToF => self.unary_op("ConvertUToF", inst)?,
283+
Op::FUnordEqual => self.checked_binary_op("FUnordEq", TypeClass::Float, inst)?,
284+
Op::FUnordNotEqual => self.checked_binary_op("FUnordNe", TypeClass::Float, inst)?,
285+
Op::FUnordLessThan => self.checked_binary_op("FUnordLt", TypeClass::Float, inst)?,
286+
Op::FUnordLessThanEqual => {
287+
self.checked_binary_op("FUnordLe", TypeClass::Float, inst)?
288+
}
289+
Op::FUnordGreaterThan => self.checked_binary_op("FUnordGt", TypeClass::Float, inst)?,
290+
Op::FUnordGreaterThanEqual => {
291+
self.checked_binary_op("FUnordGe", TypeClass::Float, inst)?
292+
}
293+
// Conversion operations (validated operand sorts)
294+
Op::ConvertFToU => self.checked_unary_op("ConvertFToU", TypeClass::Float, inst)?,
295+
Op::ConvertFToS => self.checked_unary_op("ConvertFToS", TypeClass::Float, inst)?,
296+
Op::ConvertSToF => self.checked_unary_op("ConvertSToF", TypeClass::Int, inst)?,
297+
Op::ConvertUToF => self.checked_unary_op("ConvertUToF", TypeClass::Int, inst)?,
287298
Op::Select => {
288299
let ops: Vec<Word> = inst
289300
.operands
@@ -917,6 +928,34 @@ impl EgglogContext {
917928
Some(term)
918929
}
919930

931+
/// Handle OpIEqual/OpINotEqual which may have boolean operands.
932+
/// Some compilers use integer equality on bools instead of LogicalEqual.
933+
/// Redirect to the logical equivalent when operands are BoolExpr.
934+
fn int_comparison_op(&mut self, op: &str, inst: &Instruction) -> Option<String> {
935+
let ops: Vec<Word> = inst
936+
.operands
937+
.iter()
938+
.filter_map(|op| op.id_ref_any())
939+
.collect();
940+
if ops.len() < 2 {
941+
return None;
942+
}
943+
let first_class = self.type_class_of(ops[0]);
944+
if first_class == TypeClass::Bool {
945+
// Redirect to logical equivalents for boolean operands
946+
let logical_op = match op {
947+
"Eq" => "LogEq",
948+
"Ne" => "LogNe",
949+
_ => return None,
950+
};
951+
let lhs = self.get_or_create_term(ops[0]);
952+
let rhs = self.get_or_create_term(ops[1]);
953+
Some(format!("({} {} {})", logical_op, lhs, rhs))
954+
} else {
955+
self.checked_binary_op(op, TypeClass::Int, inst)
956+
}
957+
}
958+
920959
fn binary_op(&mut self, op: &str, inst: &Instruction) -> Option<String> {
921960
let ops: Vec<Word> = inst
922961
.operands
@@ -938,41 +977,92 @@ impl EgglogContext {
938977
Some(format!("({} {})", op, operand))
939978
}
940979

980+
/// Sort-validated binary op: checks that operand SPIR-V types are compatible
981+
/// with the expected egraph sort. Returns None on cross-sort scalar mismatches
982+
/// so the instruction stays as-is rather than crashing the egraph.
983+
fn checked_binary_op(
984+
&mut self,
985+
op: &str,
986+
expected_operand: TypeClass,
987+
inst: &Instruction,
988+
) -> Option<String> {
989+
let ops: Vec<Word> = inst
990+
.operands
991+
.iter()
992+
.filter_map(|op| op.id_ref_any())
993+
.collect();
994+
if ops.len() < 2 {
995+
return None;
996+
}
997+
// Verify operand types are compatible with the expected sort.
998+
// Other (vector/composite) is always accepted — vectors in Expr sort
999+
// won't reach scalar constructors because typed_binary_op dispatches
1000+
// to Vec* variants for vector results.
1001+
for &id in &ops[..2] {
1002+
let actual = self.type_class_of(id);
1003+
if actual != expected_operand && actual != TypeClass::Other {
1004+
return None;
1005+
}
1006+
}
1007+
let lhs = self.get_or_create_term(ops[0]);
1008+
let rhs = self.get_or_create_term(ops[1]);
1009+
Some(format!("({} {} {})", op, lhs, rhs))
1010+
}
1011+
1012+
/// Sort-validated unary op: checks that the operand's SPIR-V type is
1013+
/// compatible with the expected egraph sort.
1014+
fn checked_unary_op(
1015+
&mut self,
1016+
op: &str,
1017+
expected_operand: TypeClass,
1018+
inst: &Instruction,
1019+
) -> Option<String> {
1020+
let operand_id = inst.operands.iter().find_map(|op| op.id_ref_any())?;
1021+
let actual = self.type_class_of(operand_id);
1022+
if actual != expected_operand && actual != TypeClass::Other {
1023+
return None;
1024+
}
1025+
let operand = self.get_or_create_term(operand_id);
1026+
Some(format!("({} {})", op, operand))
1027+
}
1028+
9411029
/// Type-dispatched binary op: uses `scalar_op` (typed sort) when the result
9421030
/// type is a scalar, and `vec_op` (Expr sort) when it is a vector/other type.
943-
/// This prevents sort mismatches when SPIR-V opcodes like OpFAdd/OpIAdd
944-
/// operate on vectors whose operands are in the generic Expr sort.
1031+
/// Validates operand sorts for scalar ops to prevent egraph type errors.
9451032
fn typed_binary_op(
9461033
&mut self,
9471034
scalar_op: &str,
9481035
vec_op: &str,
1036+
expected_operand: TypeClass,
9491037
inst: &Instruction,
9501038
) -> Option<String> {
9511039
let is_scalar = inst
9521040
.result_type
9531041
.map(|ty| self.type_class_of_type(ty) != TypeClass::Other)
9541042
.unwrap_or(false);
9551043
if is_scalar {
956-
self.binary_op(scalar_op, inst)
1044+
self.checked_binary_op(scalar_op, expected_operand, inst)
9571045
} else {
9581046
self.binary_op(vec_op, inst)
9591047
}
9601048
}
9611049

9621050
/// Type-dispatched unary op: uses `scalar_op` (typed sort) when the result
9631051
/// type is a scalar, and `vec_op` (Expr sort) when it is a vector/other type.
1052+
/// Validates operand sorts for scalar ops to prevent egraph type errors.
9641053
fn typed_unary_op(
9651054
&mut self,
9661055
scalar_op: &str,
9671056
vec_op: &str,
1057+
expected_operand: TypeClass,
9681058
inst: &Instruction,
9691059
) -> Option<String> {
9701060
let is_scalar = inst
9711061
.result_type
9721062
.map(|ty| self.type_class_of_type(ty) != TypeClass::Other)
9731063
.unwrap_or(false);
9741064
if is_scalar {
975-
self.unary_op(scalar_op, inst)
1065+
self.checked_unary_op(scalar_op, expected_operand, inst)
9761066
} else {
9771067
self.unary_op(vec_op, inst)
9781068
}

0 commit comments

Comments
 (0)