@@ -225,65 +225,76 @@ impl EgglogContext {
225225 Op :: ConstantTrue => "(BoolConst 1)" . to_string ( ) ,
226226 Op :: ConstantFalse => "(BoolConst 0)" . to_string ( ) ,
227227
228- Op :: IAdd => self . typed_binary_op ( "Add" , "VecAdd" , inst) ?,
229- Op :: ISub => self . typed_binary_op ( "Sub" , "VecSub" , inst) ?,
230- Op :: IMul => self . typed_binary_op ( "Mul" , "VecMul" , inst) ?,
231- Op :: SDiv => self . typed_binary_op ( "SDiv" , "VecSDiv" , inst) ?,
232- Op :: UDiv => self . typed_binary_op ( "UDiv" , "VecUDiv" , inst) ?,
233- Op :: SRem => self . typed_binary_op ( "SRem" , "VecSRem" , inst) ?,
234- Op :: SMod => self . typed_binary_op ( "SMod" , "VecSMod" , inst) ?,
235- Op :: UMod => self . typed_binary_op ( "UMod" , "VecUMod" , inst) ?,
236- Op :: SNegate => self . typed_unary_op ( "Neg" , "VecNeg" , inst) ?,
237- Op :: ShiftLeftLogical => self . binary_op ( "Shl" , inst) ?,
238- Op :: ShiftRightLogical => self . binary_op ( "ShrU" , inst) ?,
239- Op :: ShiftRightArithmetic => self . binary_op ( "ShrS" , inst) ?,
240- Op :: BitwiseAnd => self . binary_op ( "BitAnd" , inst) ?,
241- Op :: BitwiseOr => self . binary_op ( "BitOr" , inst) ?,
242- Op :: BitwiseXor => self . binary_op ( "BitXor" , inst) ?,
243- Op :: Not => self . unary_op ( "BitNot" , inst) ?,
244- Op :: BitReverse => self . unary_op ( "BitReverse" , inst) ?,
245- Op :: IEqual => self . binary_op ( "Eq" , inst) ?,
246- Op :: INotEqual => self . binary_op ( "Ne" , inst) ?,
247- Op :: SLessThan => self . binary_op ( "SLt" , inst) ?,
248- Op :: SLessThanEqual => self . binary_op ( "SLe" , inst) ?,
249- Op :: SGreaterThan => self . binary_op ( "SGt" , inst) ?,
250- Op :: SGreaterThanEqual => self . binary_op ( "SGe" , inst) ?,
251- Op :: ULessThan => self . binary_op ( "ULt" , inst) ?,
252- Op :: ULessThanEqual => self . binary_op ( "ULe" , inst) ?,
253- Op :: UGreaterThan => self . binary_op ( "UGt" , inst) ?,
254- Op :: UGreaterThanEqual => self . binary_op ( "UGe" , inst) ?,
255- Op :: LogicalNot => self . unary_op ( "LogNot" , inst) ?,
256- Op :: LogicalAnd => self . binary_op ( "LogAnd" , inst) ?,
257- Op :: LogicalOr => self . binary_op ( "LogOr" , inst) ?,
258- Op :: LogicalEqual => self . binary_op ( "LogEq" , inst) ?,
259- Op :: LogicalNotEqual => self . binary_op ( "LogNe" , inst) ?,
228+ // Integer arithmetic (operands: IntExpr, result: IntExpr)
229+ Op :: IAdd => self . typed_binary_op ( "Add" , "VecAdd" , TypeClass :: Int , inst) ?,
230+ Op :: ISub => self . typed_binary_op ( "Sub" , "VecSub" , TypeClass :: Int , inst) ?,
231+ Op :: IMul => self . typed_binary_op ( "Mul" , "VecMul" , TypeClass :: Int , inst) ?,
232+ Op :: SDiv => self . typed_binary_op ( "SDiv" , "VecSDiv" , TypeClass :: Int , inst) ?,
233+ Op :: UDiv => self . typed_binary_op ( "UDiv" , "VecUDiv" , TypeClass :: Int , inst) ?,
234+ Op :: SRem => self . typed_binary_op ( "SRem" , "VecSRem" , TypeClass :: Int , inst) ?,
235+ Op :: SMod => self . typed_binary_op ( "SMod" , "VecSMod" , TypeClass :: Int , inst) ?,
236+ Op :: UMod => self . typed_binary_op ( "UMod" , "VecUMod" , TypeClass :: Int , inst) ?,
237+ Op :: SNegate => self . typed_unary_op ( "Neg" , "VecNeg" , TypeClass :: Int , inst) ?,
238+ // Bitwise/shift (operands: IntExpr, result: IntExpr)
239+ Op :: ShiftLeftLogical => self . checked_binary_op ( "Shl" , TypeClass :: Int , inst) ?,
240+ Op :: ShiftRightLogical => self . checked_binary_op ( "ShrU" , TypeClass :: Int , inst) ?,
241+ Op :: ShiftRightArithmetic => self . checked_binary_op ( "ShrS" , TypeClass :: Int , inst) ?,
242+ Op :: BitwiseAnd => self . checked_binary_op ( "BitAnd" , TypeClass :: Int , inst) ?,
243+ Op :: BitwiseOr => self . checked_binary_op ( "BitOr" , TypeClass :: Int , inst) ?,
244+ Op :: BitwiseXor => self . checked_binary_op ( "BitXor" , TypeClass :: Int , inst) ?,
245+ Op :: Not => self . checked_unary_op ( "BitNot" , TypeClass :: Int , inst) ?,
246+ Op :: BitReverse => self . checked_unary_op ( "BitReverse" , TypeClass :: Int , inst) ?,
247+ // Integer comparisons (operands: IntExpr, result: BoolExpr)
248+ // Some compilers use OpIEqual/OpINotEqual on boolean values instead of
249+ // OpLogicalEqual/OpLogicalNotEqual. Redirect to logical equivalents
250+ // to prevent sort mismatches (BoolExpr operand in IntExpr position).
251+ Op :: IEqual => self . int_comparison_op ( "Eq" , inst) ?,
252+ Op :: INotEqual => self . int_comparison_op ( "Ne" , inst) ?,
253+ Op :: SLessThan => self . checked_binary_op ( "SLt" , TypeClass :: Int , inst) ?,
254+ Op :: SLessThanEqual => self . checked_binary_op ( "SLe" , TypeClass :: Int , inst) ?,
255+ Op :: SGreaterThan => self . checked_binary_op ( "SGt" , TypeClass :: Int , inst) ?,
256+ Op :: SGreaterThanEqual => self . checked_binary_op ( "SGe" , TypeClass :: Int , inst) ?,
257+ Op :: ULessThan => self . checked_binary_op ( "ULt" , TypeClass :: Int , inst) ?,
258+ Op :: ULessThanEqual => self . checked_binary_op ( "ULe" , TypeClass :: Int , inst) ?,
259+ Op :: UGreaterThan => self . checked_binary_op ( "UGt" , TypeClass :: Int , inst) ?,
260+ Op :: UGreaterThanEqual => self . checked_binary_op ( "UGe" , TypeClass :: Int , inst) ?,
261+ // Logical (operands: BoolExpr, result: BoolExpr)
262+ Op :: LogicalNot => self . checked_unary_op ( "LogNot" , TypeClass :: Bool , inst) ?,
263+ Op :: LogicalAnd => self . checked_binary_op ( "LogAnd" , TypeClass :: Bool , inst) ?,
264+ Op :: LogicalOr => self . checked_binary_op ( "LogOr" , TypeClass :: Bool , inst) ?,
265+ Op :: LogicalEqual => self . checked_binary_op ( "LogEq" , TypeClass :: Bool , inst) ?,
266+ Op :: LogicalNotEqual => self . checked_binary_op ( "LogNe" , TypeClass :: Bool , inst) ?,
260267 // Floating-point operations (scalar or vector)
261- Op :: FAdd => self . typed_binary_op ( "FAdd" , "VecFAdd" , inst) ?,
262- Op :: FSub => self . typed_binary_op ( "FSub" , "VecFSub" , inst) ?,
263- Op :: FMul => self . typed_binary_op ( "FMul" , "VecFMul" , inst) ?,
264- Op :: FDiv => self . typed_binary_op ( "FDiv" , "VecFDiv" , inst) ?,
265- Op :: FRem => self . typed_binary_op ( "FRem" , "VecFRem" , inst) ?,
266- Op :: FMod => self . typed_binary_op ( "FMod" , "VecFMod" , inst) ?,
267- Op :: FNegate => self . typed_unary_op ( "FNeg" , "VecFNeg" , inst) ?,
268- // Floating-point comparisons (ordered )
269- Op :: FOrdEqual => self . binary_op ( "FOrdEq" , inst) ?,
270- Op :: FOrdNotEqual => self . binary_op ( "FOrdNe" , inst) ?,
271- Op :: FOrdLessThan => self . binary_op ( "FOrdLt" , inst) ?,
272- Op :: FOrdLessThanEqual => self . binary_op ( "FOrdLe" , inst) ?,
273- Op :: FOrdGreaterThan => self . binary_op ( "FOrdGt" , inst) ?,
274- Op :: FOrdGreaterThanEqual => self . binary_op ( "FOrdGe" , inst) ?,
268+ Op :: FAdd => self . typed_binary_op ( "FAdd" , "VecFAdd" , TypeClass :: Float , inst) ?,
269+ Op :: FSub => self . typed_binary_op ( "FSub" , "VecFSub" , TypeClass :: Float , inst) ?,
270+ Op :: FMul => self . typed_binary_op ( "FMul" , "VecFMul" , TypeClass :: Float , inst) ?,
271+ Op :: FDiv => self . typed_binary_op ( "FDiv" , "VecFDiv" , TypeClass :: Float , inst) ?,
272+ Op :: FRem => self . typed_binary_op ( "FRem" , "VecFRem" , TypeClass :: Float , inst) ?,
273+ Op :: FMod => self . typed_binary_op ( "FMod" , "VecFMod" , TypeClass :: Float , inst) ?,
274+ Op :: FNegate => self . typed_unary_op ( "FNeg" , "VecFNeg" , TypeClass :: Float , inst) ?,
275+ // Floating-point comparisons (operands: FloatExpr, result: BoolExpr )
276+ Op :: FOrdEqual => self . checked_binary_op ( "FOrdEq" , TypeClass :: Float , inst) ?,
277+ Op :: FOrdNotEqual => self . checked_binary_op ( "FOrdNe" , TypeClass :: Float , inst) ?,
278+ Op :: FOrdLessThan => self . checked_binary_op ( "FOrdLt" , TypeClass :: Float , inst) ?,
279+ Op :: FOrdLessThanEqual => self . checked_binary_op ( "FOrdLe" , TypeClass :: Float , inst) ?,
280+ Op :: FOrdGreaterThan => self . checked_binary_op ( "FOrdGt" , TypeClass :: Float , inst) ?,
281+ Op :: FOrdGreaterThanEqual => self . checked_binary_op ( "FOrdGe" , TypeClass :: Float , inst) ?,
275282 // Floating-point comparisons (unordered)
276- Op :: FUnordEqual => self . binary_op ( "FUnordEq" , inst) ?,
277- Op :: FUnordNotEqual => self . binary_op ( "FUnordNe" , inst) ?,
278- Op :: FUnordLessThan => self . binary_op ( "FUnordLt" , inst) ?,
279- Op :: FUnordLessThanEqual => self . binary_op ( "FUnordLe" , inst) ?,
280- Op :: FUnordGreaterThan => self . binary_op ( "FUnordGt" , inst) ?,
281- Op :: FUnordGreaterThanEqual => self . binary_op ( "FUnordGe" , inst) ?,
282- // Conversion operations
283- Op :: ConvertFToU => self . unary_op ( "ConvertFToU" , inst) ?,
284- Op :: ConvertFToS => self . unary_op ( "ConvertFToS" , inst) ?,
285- Op :: ConvertSToF => self . unary_op ( "ConvertSToF" , inst) ?,
286- Op :: ConvertUToF => self . unary_op ( "ConvertUToF" , inst) ?,
283+ Op :: FUnordEqual => self . checked_binary_op ( "FUnordEq" , TypeClass :: Float , inst) ?,
284+ Op :: FUnordNotEqual => self . checked_binary_op ( "FUnordNe" , TypeClass :: Float , inst) ?,
285+ Op :: FUnordLessThan => self . checked_binary_op ( "FUnordLt" , TypeClass :: Float , inst) ?,
286+ Op :: FUnordLessThanEqual => {
287+ self . checked_binary_op ( "FUnordLe" , TypeClass :: Float , inst) ?
288+ }
289+ Op :: FUnordGreaterThan => self . checked_binary_op ( "FUnordGt" , TypeClass :: Float , inst) ?,
290+ Op :: FUnordGreaterThanEqual => {
291+ self . checked_binary_op ( "FUnordGe" , TypeClass :: Float , inst) ?
292+ }
293+ // Conversion operations (validated operand sorts)
294+ Op :: ConvertFToU => self . checked_unary_op ( "ConvertFToU" , TypeClass :: Float , inst) ?,
295+ Op :: ConvertFToS => self . checked_unary_op ( "ConvertFToS" , TypeClass :: Float , inst) ?,
296+ Op :: ConvertSToF => self . checked_unary_op ( "ConvertSToF" , TypeClass :: Int , inst) ?,
297+ Op :: ConvertUToF => self . checked_unary_op ( "ConvertUToF" , TypeClass :: Int , inst) ?,
287298 Op :: Select => {
288299 let ops: Vec < Word > = inst
289300 . operands
@@ -917,6 +928,34 @@ impl EgglogContext {
917928 Some ( term)
918929 }
919930
931+ /// Handle OpIEqual/OpINotEqual which may have boolean operands.
932+ /// Some compilers use integer equality on bools instead of LogicalEqual.
933+ /// Redirect to the logical equivalent when operands are BoolExpr.
934+ fn int_comparison_op ( & mut self , op : & str , inst : & Instruction ) -> Option < String > {
935+ let ops: Vec < Word > = inst
936+ . operands
937+ . iter ( )
938+ . filter_map ( |op| op. id_ref_any ( ) )
939+ . collect ( ) ;
940+ if ops. len ( ) < 2 {
941+ return None ;
942+ }
943+ let first_class = self . type_class_of ( ops[ 0 ] ) ;
944+ if first_class == TypeClass :: Bool {
945+ // Redirect to logical equivalents for boolean operands
946+ let logical_op = match op {
947+ "Eq" => "LogEq" ,
948+ "Ne" => "LogNe" ,
949+ _ => return None ,
950+ } ;
951+ let lhs = self . get_or_create_term ( ops[ 0 ] ) ;
952+ let rhs = self . get_or_create_term ( ops[ 1 ] ) ;
953+ Some ( format ! ( "({} {} {})" , logical_op, lhs, rhs) )
954+ } else {
955+ self . checked_binary_op ( op, TypeClass :: Int , inst)
956+ }
957+ }
958+
920959 fn binary_op ( & mut self , op : & str , inst : & Instruction ) -> Option < String > {
921960 let ops: Vec < Word > = inst
922961 . operands
@@ -938,41 +977,92 @@ impl EgglogContext {
938977 Some ( format ! ( "({} {})" , op, operand) )
939978 }
940979
980+ /// Sort-validated binary op: checks that operand SPIR-V types are compatible
981+ /// with the expected egraph sort. Returns None on cross-sort scalar mismatches
982+ /// so the instruction stays as-is rather than crashing the egraph.
983+ fn checked_binary_op (
984+ & mut self ,
985+ op : & str ,
986+ expected_operand : TypeClass ,
987+ inst : & Instruction ,
988+ ) -> Option < String > {
989+ let ops: Vec < Word > = inst
990+ . operands
991+ . iter ( )
992+ . filter_map ( |op| op. id_ref_any ( ) )
993+ . collect ( ) ;
994+ if ops. len ( ) < 2 {
995+ return None ;
996+ }
997+ // Verify operand types are compatible with the expected sort.
998+ // Other (vector/composite) is always accepted — vectors in Expr sort
999+ // won't reach scalar constructors because typed_binary_op dispatches
1000+ // to Vec* variants for vector results.
1001+ for & id in & ops[ ..2 ] {
1002+ let actual = self . type_class_of ( id) ;
1003+ if actual != expected_operand && actual != TypeClass :: Other {
1004+ return None ;
1005+ }
1006+ }
1007+ let lhs = self . get_or_create_term ( ops[ 0 ] ) ;
1008+ let rhs = self . get_or_create_term ( ops[ 1 ] ) ;
1009+ Some ( format ! ( "({} {} {})" , op, lhs, rhs) )
1010+ }
1011+
1012+ /// Sort-validated unary op: checks that the operand's SPIR-V type is
1013+ /// compatible with the expected egraph sort.
1014+ fn checked_unary_op (
1015+ & mut self ,
1016+ op : & str ,
1017+ expected_operand : TypeClass ,
1018+ inst : & Instruction ,
1019+ ) -> Option < String > {
1020+ let operand_id = inst. operands . iter ( ) . find_map ( |op| op. id_ref_any ( ) ) ?;
1021+ let actual = self . type_class_of ( operand_id) ;
1022+ if actual != expected_operand && actual != TypeClass :: Other {
1023+ return None ;
1024+ }
1025+ let operand = self . get_or_create_term ( operand_id) ;
1026+ Some ( format ! ( "({} {})" , op, operand) )
1027+ }
1028+
9411029 /// Type-dispatched binary op: uses `scalar_op` (typed sort) when the result
9421030 /// type is a scalar, and `vec_op` (Expr sort) when it is a vector/other type.
943- /// This prevents sort mismatches when SPIR-V opcodes like OpFAdd/OpIAdd
944- /// operate on vectors whose operands are in the generic Expr sort.
1031+ /// Validates operand sorts for scalar ops to prevent egraph type errors.
9451032 fn typed_binary_op (
9461033 & mut self ,
9471034 scalar_op : & str ,
9481035 vec_op : & str ,
1036+ expected_operand : TypeClass ,
9491037 inst : & Instruction ,
9501038 ) -> Option < String > {
9511039 let is_scalar = inst
9521040 . result_type
9531041 . map ( |ty| self . type_class_of_type ( ty) != TypeClass :: Other )
9541042 . unwrap_or ( false ) ;
9551043 if is_scalar {
956- self . binary_op ( scalar_op, inst)
1044+ self . checked_binary_op ( scalar_op, expected_operand , inst)
9571045 } else {
9581046 self . binary_op ( vec_op, inst)
9591047 }
9601048 }
9611049
9621050 /// Type-dispatched unary op: uses `scalar_op` (typed sort) when the result
9631051 /// type is a scalar, and `vec_op` (Expr sort) when it is a vector/other type.
1052+ /// Validates operand sorts for scalar ops to prevent egraph type errors.
9641053 fn typed_unary_op (
9651054 & mut self ,
9661055 scalar_op : & str ,
9671056 vec_op : & str ,
1057+ expected_operand : TypeClass ,
9681058 inst : & Instruction ,
9691059 ) -> Option < String > {
9701060 let is_scalar = inst
9711061 . result_type
9721062 . map ( |ty| self . type_class_of_type ( ty) != TypeClass :: Other )
9731063 . unwrap_or ( false ) ;
9741064 if is_scalar {
975- self . unary_op ( scalar_op, inst)
1065+ self . checked_unary_op ( scalar_op, expected_operand , inst)
9761066 } else {
9771067 self . unary_op ( vec_op, inst)
9781068 }
0 commit comments