Skip to content

Commit b8a5f5d

Browse files
committed
Add vector dispatch tests and fix formatting
1 parent a9b807b commit b8a5f5d

3 files changed

Lines changed: 120 additions & 5 deletions

File tree

rust/spirv-tools-opt/src/direct/parse/vector.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,8 @@ const VEC_ARITHMETIC_BINARY_OPS: &[(&str, Op)] = &[
4747
];
4848

4949
/// Vector arithmetic unary operations (component-wise).
50-
const VEC_ARITHMETIC_UNARY_OPS: &[(&str, Op)] = &[
51-
("VecNeg", Op::SNegate),
52-
("VecFNeg", Op::FNegate),
53-
];
50+
const VEC_ARITHMETIC_UNARY_OPS: &[(&str, Op)] =
51+
&[("VecNeg", Op::SNegate), ("VecFNeg", Op::FNegate)];
5452

5553
/// Try to parse a vector or composite operation.
5654
pub fn try_parse_vector(

rust/spirv-tools-opt/src/egglog_opt/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,6 @@ fn float_neg_bits(a: i64) -> i64 {
495495
}
496496
}
497497

498-
499498
/// SConvert constant fold: sign-extend or truncate to target width.
500499
fn sconvert_fold(v: i64, dst_width: i64) -> i64 {
501500
let dw = dst_width as u32;

rust/spirv-tools-opt/src/egglog_opt/tests.rs

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7488,3 +7488,121 @@ fn test_cross_type_bitcast_does_not_add_int_identity() {
74887488
"Cross-type bitcast must not be simplified to identity"
74897489
);
74907490
}
7491+
7492+
#[test]
7493+
fn test_vec_fadd_accepts_expr_operands() {
7494+
// VecFAdd uses Expr sort, so operands from Sym (Expr) are accepted.
7495+
// This is the vector dispatch path for OpFAdd on vector types.
7496+
let mut egraph = create_spirv_egraph().unwrap();
7497+
7498+
egraph
7499+
.parse_and_run_program(
7500+
None,
7501+
r#"
7502+
(let a (Sym "vec_a"))
7503+
(let b (Sym "vec_b"))
7504+
(let result (VecFAdd a b))
7505+
"#,
7506+
)
7507+
.unwrap();
7508+
egraph
7509+
.parse_and_run_program(None, "(run-schedule (repeat 10 (run)))")
7510+
.unwrap();
7511+
7512+
// Verify the term was created (not rejected by sort mismatch)
7513+
let check = egraph.parse_and_run_program(None, "(check (= result (VecFAdd a b)))");
7514+
assert!(
7515+
check.is_ok(),
7516+
"VecFAdd should accept Expr operands without sort mismatch"
7517+
);
7518+
}
7519+
7520+
#[test]
7521+
fn test_scalar_fadd_rejects_expr_operands() {
7522+
// FAdd uses FloatExpr sort, so Sym (Expr) operands must be rejected.
7523+
// This ensures scalar operations enforce type safety.
7524+
let mut egraph = create_spirv_egraph().unwrap();
7525+
7526+
let result = egraph.parse_and_run_program(
7527+
None,
7528+
r#"
7529+
(let a (Sym "val_a"))
7530+
(let b (Sym "val_b"))
7531+
(let result (FAdd a b))
7532+
"#,
7533+
);
7534+
assert!(
7535+
result.is_err(),
7536+
"FAdd must reject Expr operands — it requires FloatExpr"
7537+
);
7538+
}
7539+
7540+
#[test]
7541+
fn test_vec_add_accepts_expr_operands() {
7542+
// VecAdd uses Expr sort for integer vector operations.
7543+
let mut egraph = create_spirv_egraph().unwrap();
7544+
7545+
egraph
7546+
.parse_and_run_program(
7547+
None,
7548+
r#"
7549+
(let a (Sym "ivec_a"))
7550+
(let b (Sym "ivec_b"))
7551+
(let result (VecAdd a b))
7552+
"#,
7553+
)
7554+
.unwrap();
7555+
egraph
7556+
.parse_and_run_program(None, "(run-schedule (repeat 10 (run)))")
7557+
.unwrap();
7558+
7559+
let check = egraph.parse_and_run_program(None, "(check (= result (VecAdd a b)))");
7560+
assert!(
7561+
check.is_ok(),
7562+
"VecAdd should accept Expr operands without sort mismatch"
7563+
);
7564+
}
7565+
7566+
#[test]
7567+
fn test_vec_fneg_accepts_expr_operand() {
7568+
// VecFNeg uses Expr sort for vector float negation.
7569+
let mut egraph = create_spirv_egraph().unwrap();
7570+
7571+
egraph
7572+
.parse_and_run_program(
7573+
None,
7574+
r#"
7575+
(let a (Sym "vec_a"))
7576+
(let result (VecFNeg a))
7577+
"#,
7578+
)
7579+
.unwrap();
7580+
egraph
7581+
.parse_and_run_program(None, "(run-schedule (repeat 10 (run)))")
7582+
.unwrap();
7583+
7584+
let check = egraph.parse_and_run_program(None, "(check (= result (VecFNeg a)))");
7585+
assert!(
7586+
check.is_ok(),
7587+
"VecFNeg should accept Expr operand without sort mismatch"
7588+
);
7589+
}
7590+
7591+
#[test]
7592+
fn test_scalar_add_rejects_expr_operands() {
7593+
// Add uses IntExpr sort, so Sym (Expr) operands must be rejected.
7594+
let mut egraph = create_spirv_egraph().unwrap();
7595+
7596+
let result = egraph.parse_and_run_program(
7597+
None,
7598+
r#"
7599+
(let a (Sym "val_a"))
7600+
(let b (Sym "val_b"))
7601+
(let result (Add a b))
7602+
"#,
7603+
);
7604+
assert!(
7605+
result.is_err(),
7606+
"Add must reject Expr operands — it requires IntExpr"
7607+
);
7608+
}

0 commit comments

Comments
 (0)