Skip to content

Commit a9b807b

Browse files
committed
Dispatch vector arithmetic to Vec* constructors in the egraph
SPIR-V OpFAdd/OpIAdd/etc. work on both scalars and vectors, but the egglog typed sorts (FloatExpr, IntExpr) are scalar-only. When an operand was vector-typed, the sort mismatch caused egglog failures. Following the existing Select/VecSelect and FMix/VecFMix pattern, dispatch to Vec* constructors (Expr sort) for vector operations while keeping typed sorts for scalars. This keeps everything in the egraph for one big saturation pass. Add VecFRem, VecFMod, VecSDiv, VecUDiv, VecSRem, VecSMod, VecUMod constructors with DCE, VecSize, LICM rules, and reconstruction.
1 parent 22bcfb1 commit a9b807b

5 files changed

Lines changed: 160 additions & 21 deletions

File tree

rust/spirv-tools-opt/src/direct/context.rs

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -217,15 +217,15 @@ impl EgglogContext {
217217
Op::ConstantTrue => "(BoolConst 1)".to_string(),
218218
Op::ConstantFalse => "(BoolConst 0)".to_string(),
219219

220-
Op::IAdd => self.binary_op("Add", inst)?,
221-
Op::ISub => self.binary_op("Sub", inst)?,
222-
Op::IMul => self.binary_op("Mul", inst)?,
223-
Op::SDiv => self.binary_op("SDiv", inst)?,
224-
Op::UDiv => self.binary_op("UDiv", inst)?,
225-
Op::SRem => self.binary_op("SRem", inst)?,
226-
Op::SMod => self.binary_op("SMod", inst)?,
227-
Op::UMod => self.binary_op("UMod", inst)?,
228-
Op::SNegate => self.unary_op("Neg", inst)?,
220+
Op::IAdd => self.typed_binary_op("Add", "VecAdd", inst)?,
221+
Op::ISub => self.typed_binary_op("Sub", "VecSub", inst)?,
222+
Op::IMul => self.typed_binary_op("Mul", "VecMul", inst)?,
223+
Op::SDiv => self.typed_binary_op("SDiv", "VecSDiv", inst)?,
224+
Op::UDiv => self.typed_binary_op("UDiv", "VecUDiv", inst)?,
225+
Op::SRem => self.typed_binary_op("SRem", "VecSRem", inst)?,
226+
Op::SMod => self.typed_binary_op("SMod", "VecSMod", inst)?,
227+
Op::UMod => self.typed_binary_op("UMod", "VecUMod", inst)?,
228+
Op::SNegate => self.typed_unary_op("Neg", "VecNeg", inst)?,
229229
Op::ShiftLeftLogical => self.binary_op("Shl", inst)?,
230230
Op::ShiftRightLogical => self.binary_op("ShrU", inst)?,
231231
Op::ShiftRightArithmetic => self.binary_op("ShrS", inst)?,
@@ -249,14 +249,14 @@ impl EgglogContext {
249249
Op::LogicalOr => self.binary_op("LogOr", inst)?,
250250
Op::LogicalEqual => self.binary_op("LogEq", inst)?,
251251
Op::LogicalNotEqual => self.binary_op("LogNe", inst)?,
252-
// Floating-point operations
253-
Op::FAdd => self.binary_op("FAdd", inst)?,
254-
Op::FSub => self.binary_op("FSub", inst)?,
255-
Op::FMul => self.binary_op("FMul", inst)?,
256-
Op::FDiv => self.binary_op("FDiv", inst)?,
257-
Op::FRem => self.binary_op("FRem", inst)?,
258-
Op::FMod => self.binary_op("FMod", inst)?,
259-
Op::FNegate => self.unary_op("FNeg", inst)?,
252+
// Floating-point operations (scalar or vector)
253+
Op::FAdd => self.typed_binary_op("FAdd", "VecFAdd", inst)?,
254+
Op::FSub => self.typed_binary_op("FSub", "VecFSub", inst)?,
255+
Op::FMul => self.typed_binary_op("FMul", "VecFMul", inst)?,
256+
Op::FDiv => self.typed_binary_op("FDiv", "VecFDiv", inst)?,
257+
Op::FRem => self.typed_binary_op("FRem", "VecFRem", inst)?,
258+
Op::FMod => self.typed_binary_op("FMod", "VecFMod", inst)?,
259+
Op::FNegate => self.typed_unary_op("FNeg", "VecFNeg", inst)?,
260260
// Floating-point comparisons (ordered)
261261
Op::FOrdEqual => self.binary_op("FOrdEq", inst)?,
262262
Op::FOrdNotEqual => self.binary_op("FOrdNe", inst)?,
@@ -930,6 +930,46 @@ impl EgglogContext {
930930
Some(format!("({} {})", op, operand))
931931
}
932932

933+
/// Type-dispatched binary op: uses `scalar_op` (typed sort) when the result
934+
/// type is a scalar, and `vec_op` (Expr sort) when it is a vector/other type.
935+
/// This prevents sort mismatches when SPIR-V opcodes like OpFAdd/OpIAdd
936+
/// operate on vectors whose operands are in the generic Expr sort.
937+
fn typed_binary_op(
938+
&mut self,
939+
scalar_op: &str,
940+
vec_op: &str,
941+
inst: &Instruction,
942+
) -> Option<String> {
943+
let is_scalar = inst
944+
.result_type
945+
.map(|ty| self.type_class_of_type(ty) != TypeClass::Other)
946+
.unwrap_or(false);
947+
if is_scalar {
948+
self.binary_op(scalar_op, inst)
949+
} else {
950+
self.binary_op(vec_op, inst)
951+
}
952+
}
953+
954+
/// Type-dispatched unary op: uses `scalar_op` (typed sort) when the result
955+
/// type is a scalar, and `vec_op` (Expr sort) when it is a vector/other type.
956+
fn typed_unary_op(
957+
&mut self,
958+
scalar_op: &str,
959+
vec_op: &str,
960+
inst: &Instruction,
961+
) -> Option<String> {
962+
let is_scalar = inst
963+
.result_type
964+
.map(|ty| self.type_class_of_type(ty) != TypeClass::Other)
965+
.unwrap_or(false);
966+
if is_scalar {
967+
self.unary_op(scalar_op, inst)
968+
} else {
969+
self.unary_op(vec_op, inst)
970+
}
971+
}
972+
933973
/// Convert GLSL.std.450 extended instruction to egglog term.
934974
fn extended_instruction_to_term(&mut self, inst: &Instruction) -> Option<String> {
935975
// ExtInst operands: %set %instruction operands...

rust/spirv-tools-opt/src/direct/mod.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2454,19 +2454,24 @@ fn instruction_has_valid_types(
24542454
let op = inst.class.opcode;
24552455

24562456
// Check result type: if the opcode requires a specific type class,
2457-
// the result type MUST match (no TypeClass::Other escape).
2457+
// the result type MUST match OR be TypeClass::Other (vectors/matrices).
2458+
// SPIR-V arithmetic ops (FAdd, IAdd, etc.) work on both scalars and vectors.
2459+
// TypeClass::Other means the type is a vector/matrix/struct which is valid
2460+
// for these component-wise operations.
24582461
if let (Some(required), Some(result_type)) = (required_result_type_class(op), inst.result_type)
24592462
{
24602463
let actual = type_classes
24612464
.get(&result_type)
24622465
.copied()
24632466
.unwrap_or(TypeClass::Other);
2464-
if actual != required {
2467+
if actual != required && actual != TypeClass::Other {
24652468
return false;
24662469
}
24672470
}
24682471

24692472
// Check operand types for comparisons
2473+
// Again, TypeClass::Other (vector operands) is always acceptable since
2474+
// SPIR-V comparison ops work component-wise on vectors.
24702475
if let Some(required_op_class) = required_operand_type_class(op) {
24712476
for operand in &inst.operands {
24722477
if let Some(operand_id) = operand.id_ref_any() {
@@ -2475,7 +2480,7 @@ fn instruction_has_valid_types(
24752480
.get(&operand_type)
24762481
.copied()
24772482
.unwrap_or(TypeClass::Other);
2478-
if actual != required_op_class {
2483+
if actual != required_op_class && actual != TypeClass::Other {
24792484
return false;
24802485
}
24812486
}

rust/spirv-tools-opt/src/direct/parse/vector.rs

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use rspirv::spirv::{Op, Word};
55
use std::collections::HashMap;
66

77
use super::util::{
8-
parse_binary_args, parse_expr_list, parse_ternary_args, resolve_term_to_id, split_terms,
8+
parse_binary_args, parse_expr_list, parse_ternary_args, parse_unary_arg, resolve_term_to_id,
9+
split_terms,
910
};
1011

1112
/// Matrix binary operations.
@@ -21,6 +22,36 @@ const MATRIX_BINARY_OPS: &[(&str, Op)] = &[
2122
/// Matrix unary operations.
2223
const MATRIX_UNARY_OPS: &[(&str, Op)] = &[("Transpose", Op::Transpose)];
2324

25+
/// Vector arithmetic binary operations (component-wise).
26+
/// These are the Expr-sort constructors for ops that work on both scalars and vectors.
27+
const VEC_ARITHMETIC_BINARY_OPS: &[(&str, Op)] = &[
28+
// Integer vector arithmetic
29+
("VecAdd", Op::IAdd),
30+
("VecSub", Op::ISub),
31+
("VecMul", Op::IMul),
32+
("VecDiv", Op::SDiv),
33+
("VecSDiv", Op::SDiv),
34+
("VecUDiv", Op::UDiv),
35+
("VecSRem", Op::SRem),
36+
("VecSMod", Op::SMod),
37+
("VecUMod", Op::UMod),
38+
// Float vector arithmetic
39+
("VecFAdd", Op::FAdd),
40+
("VecFSub", Op::FSub),
41+
("VecFMul", Op::FMul),
42+
("VecFDiv", Op::FDiv),
43+
("VecFRem", Op::FRem),
44+
("VecFMod", Op::FMod),
45+
// Scalar-vector multiply
46+
("VecTimesScalar", Op::VectorTimesScalar),
47+
];
48+
49+
/// Vector arithmetic unary operations (component-wise).
50+
const VEC_ARITHMETIC_UNARY_OPS: &[(&str, Op)] = &[
51+
("VecNeg", Op::SNegate),
52+
("VecFNeg", Op::FNegate),
53+
];
54+
2455
/// Try to parse a vector or composite operation.
2556
pub fn try_parse_vector(
2657
term: &str,
@@ -63,6 +94,39 @@ pub fn try_parse_vector(
6394
}
6495
}
6596

97+
// Vector arithmetic binary operations (component-wise)
98+
for (name, opcode) in VEC_ARITHMETIC_BINARY_OPS {
99+
let prefix = format!("({} ", name);
100+
if let Some(rest) = term.strip_prefix(&prefix) {
101+
if let Some((lhs, rhs)) = parse_binary_args(rest, id_map) {
102+
return Some(Instruction::new(
103+
*opcode,
104+
Some(result_type),
105+
Some(result_id),
106+
vec![
107+
rspirv::dr::Operand::IdRef(lhs),
108+
rspirv::dr::Operand::IdRef(rhs),
109+
],
110+
));
111+
}
112+
}
113+
}
114+
115+
// Vector arithmetic unary operations (component-wise)
116+
for (name, opcode) in VEC_ARITHMETIC_UNARY_OPS {
117+
let prefix = format!("({} ", name);
118+
if let Some(rest) = term.strip_prefix(&prefix) {
119+
if let Some(operand) = parse_unary_arg(rest, id_map) {
120+
return Some(Instruction::new(
121+
*opcode,
122+
Some(result_type),
123+
Some(result_id),
124+
vec![rspirv::dr::Operand::IdRef(operand)],
125+
));
126+
}
127+
}
128+
}
129+
66130
// VectorInsertDynamic (ternary: vector, component, index)
67131
if let Some(rest) = term.strip_prefix("(VectorInsertDynamic ") {
68132
if let Some((vec, component, idx)) = parse_ternary_args(rest, id_map) {

rust/spirv-tools-opt/src/rules/datatypes.egg

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,13 @@
393393
(constructor VecFMul (Expr Expr) Expr)
394394
(constructor VecFDiv (Expr Expr) Expr)
395395
(constructor VecFNeg (Expr) Expr)
396+
(constructor VecFRem (Expr Expr) Expr)
397+
(constructor VecFMod (Expr Expr) Expr)
398+
(constructor VecSDiv (Expr Expr) Expr)
399+
(constructor VecUDiv (Expr Expr) Expr)
400+
(constructor VecSRem (Expr Expr) Expr)
401+
(constructor VecSMod (Expr Expr) Expr)
402+
(constructor VecUMod (Expr Expr) Expr)
396403
(constructor VecTimesScalar (Expr Expr) Expr)
397404

398405
; Vector-level GLSL ops (vector variants of scalar FloatExpr constructors)
@@ -1103,6 +1110,13 @@
11031110
(rule ((Live e) (= e (VecFMul a b))) ((Live a) (Live b)))
11041111
(rule ((Live e) (= e (VecFDiv a b))) ((Live a) (Live b)))
11051112
(rule ((Live e) (= e (VecFNeg a))) ((Live a)))
1113+
(rule ((Live e) (= e (VecFRem a b))) ((Live a) (Live b)))
1114+
(rule ((Live e) (= e (VecFMod a b))) ((Live a) (Live b)))
1115+
(rule ((Live e) (= e (VecSDiv a b))) ((Live a) (Live b)))
1116+
(rule ((Live e) (= e (VecUDiv a b))) ((Live a) (Live b)))
1117+
(rule ((Live e) (= e (VecSRem a b))) ((Live a) (Live b)))
1118+
(rule ((Live e) (= e (VecSMod a b))) ((Live a) (Live b)))
1119+
(rule ((Live e) (= e (VecUMod a b))) ((Live a) (Live b)))
11061120
(rule ((Live e) (= e (VecTimesScalar v s))) ((Live v) (Live s)))
11071121
(rule ((Live e) (= e (VecFMix a b c))) ((Live a) (Live b) (Live c)))
11081122
(rule ((Live e) (= e (VecSelect mask t f))) ((Live mask) (Live t) (Live f)))

rust/spirv-tools-opt/src/rules/vector.egg

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,14 @@
619619
(rule ((= e (VecFSub (LoopInvariant a) (LoopInvariant b)))) ((union e (LoopInvariant (VecFSub a b)))))
620620
(rule ((= e (VecFMul (LoopInvariant a) (LoopInvariant b)))) ((union e (LoopInvariant (VecFMul a b)))))
621621
(rule ((= e (VecFNeg (LoopInvariant a)))) ((union e (LoopInvariant (VecFNeg a)))))
622+
(rule ((= e (VecFRem (LoopInvariant a) (LoopInvariant b)))) ((union e (LoopInvariant (VecFRem a b)))))
623+
(rule ((= e (VecFMod (LoopInvariant a) (LoopInvariant b)))) ((union e (LoopInvariant (VecFMod a b)))))
624+
; Note: VecDiv and VecFDiv loop invariant rules are in licm.egg
625+
(rule ((= e (VecSDiv (LoopInvariant a) (LoopInvariant b)))) ((union e (LoopInvariant (VecSDiv a b)))))
626+
(rule ((= e (VecUDiv (LoopInvariant a) (LoopInvariant b)))) ((union e (LoopInvariant (VecUDiv a b)))))
627+
(rule ((= e (VecSRem (LoopInvariant a) (LoopInvariant b)))) ((union e (LoopInvariant (VecSRem a b)))))
628+
(rule ((= e (VecSMod (LoopInvariant a) (LoopInvariant b)))) ((union e (LoopInvariant (VecSMod a b)))))
629+
(rule ((= e (VecUMod (LoopInvariant a) (LoopInvariant b)))) ((union e (LoopInvariant (VecUMod a b)))))
622630
; Note: Dot returns FloatExpr, use LoopInvariantF (FloatExpr -> FloatExpr)
623631
(rule ((= e (Dot (LoopInvariant a) (LoopInvariant b)))) ((union e (LoopInvariantF (Dot a b)))))
624632
(rule ((= e (Cross (LoopInvariant a) (LoopInvariant b)))) ((union e (LoopInvariant (Cross a b)))))
@@ -649,6 +657,14 @@
649657
(rule ((= e (VecFSub a _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
650658
(rule ((= e (VecFMul a _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
651659
(rule ((= e (VecFNeg a)) (= sz (VecSize a))) ((set (VecSize e) sz)))
660+
(rule ((= e (VecFRem a _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
661+
(rule ((= e (VecFMod a _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
662+
(rule ((= e (VecDiv a _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
663+
(rule ((= e (VecSDiv a _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
664+
(rule ((= e (VecUDiv a _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
665+
(rule ((= e (VecSRem a _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
666+
(rule ((= e (VecSMod a _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
667+
(rule ((= e (VecUMod a _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
652668
(rule ((= e (VecTimesScalar a _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
653669
(rule ((= e (VecFMix a _ _)) (= sz (VecSize a))) ((set (VecSize e) sz)))
654670

0 commit comments

Comments
 (0)