@@ -7488,3 +7488,121 @@ fn test_cross_type_bitcast_does_not_add_int_identity() {
74887488 "Cross-type bitcast must not be simplified to identity"
74897489 ) ;
74907490}
7491+
7492+ #[ test]
7493+ fn test_vec_fadd_accepts_expr_operands ( ) {
7494+ // VecFAdd uses Expr sort, so operands from Sym (Expr) are accepted.
7495+ // This is the vector dispatch path for OpFAdd on vector types.
7496+ let mut egraph = create_spirv_egraph ( ) . unwrap ( ) ;
7497+
7498+ egraph
7499+ . parse_and_run_program (
7500+ None ,
7501+ r#"
7502+ (let a (Sym "vec_a"))
7503+ (let b (Sym "vec_b"))
7504+ (let result (VecFAdd a b))
7505+ "# ,
7506+ )
7507+ . unwrap ( ) ;
7508+ egraph
7509+ . parse_and_run_program ( None , "(run-schedule (repeat 10 (run)))" )
7510+ . unwrap ( ) ;
7511+
7512+ // Verify the term was created (not rejected by sort mismatch)
7513+ let check = egraph. parse_and_run_program ( None , "(check (= result (VecFAdd a b)))" ) ;
7514+ assert ! (
7515+ check. is_ok( ) ,
7516+ "VecFAdd should accept Expr operands without sort mismatch"
7517+ ) ;
7518+ }
7519+
7520+ #[ test]
7521+ fn test_scalar_fadd_rejects_expr_operands ( ) {
7522+ // FAdd uses FloatExpr sort, so Sym (Expr) operands must be rejected.
7523+ // This ensures scalar operations enforce type safety.
7524+ let mut egraph = create_spirv_egraph ( ) . unwrap ( ) ;
7525+
7526+ let result = egraph. parse_and_run_program (
7527+ None ,
7528+ r#"
7529+ (let a (Sym "val_a"))
7530+ (let b (Sym "val_b"))
7531+ (let result (FAdd a b))
7532+ "# ,
7533+ ) ;
7534+ assert ! (
7535+ result. is_err( ) ,
7536+ "FAdd must reject Expr operands — it requires FloatExpr"
7537+ ) ;
7538+ }
7539+
7540+ #[ test]
7541+ fn test_vec_add_accepts_expr_operands ( ) {
7542+ // VecAdd uses Expr sort for integer vector operations.
7543+ let mut egraph = create_spirv_egraph ( ) . unwrap ( ) ;
7544+
7545+ egraph
7546+ . parse_and_run_program (
7547+ None ,
7548+ r#"
7549+ (let a (Sym "ivec_a"))
7550+ (let b (Sym "ivec_b"))
7551+ (let result (VecAdd a b))
7552+ "# ,
7553+ )
7554+ . unwrap ( ) ;
7555+ egraph
7556+ . parse_and_run_program ( None , "(run-schedule (repeat 10 (run)))" )
7557+ . unwrap ( ) ;
7558+
7559+ let check = egraph. parse_and_run_program ( None , "(check (= result (VecAdd a b)))" ) ;
7560+ assert ! (
7561+ check. is_ok( ) ,
7562+ "VecAdd should accept Expr operands without sort mismatch"
7563+ ) ;
7564+ }
7565+
7566+ #[ test]
7567+ fn test_vec_fneg_accepts_expr_operand ( ) {
7568+ // VecFNeg uses Expr sort for vector float negation.
7569+ let mut egraph = create_spirv_egraph ( ) . unwrap ( ) ;
7570+
7571+ egraph
7572+ . parse_and_run_program (
7573+ None ,
7574+ r#"
7575+ (let a (Sym "vec_a"))
7576+ (let result (VecFNeg a))
7577+ "# ,
7578+ )
7579+ . unwrap ( ) ;
7580+ egraph
7581+ . parse_and_run_program ( None , "(run-schedule (repeat 10 (run)))" )
7582+ . unwrap ( ) ;
7583+
7584+ let check = egraph. parse_and_run_program ( None , "(check (= result (VecFNeg a)))" ) ;
7585+ assert ! (
7586+ check. is_ok( ) ,
7587+ "VecFNeg should accept Expr operand without sort mismatch"
7588+ ) ;
7589+ }
7590+
7591+ #[ test]
7592+ fn test_scalar_add_rejects_expr_operands ( ) {
7593+ // Add uses IntExpr sort, so Sym (Expr) operands must be rejected.
7594+ let mut egraph = create_spirv_egraph ( ) . unwrap ( ) ;
7595+
7596+ let result = egraph. parse_and_run_program (
7597+ None ,
7598+ r#"
7599+ (let a (Sym "val_a"))
7600+ (let b (Sym "val_b"))
7601+ (let result (Add a b))
7602+ "# ,
7603+ ) ;
7604+ assert ! (
7605+ result. is_err( ) ,
7606+ "Add must reject Expr operands — it requires IntExpr"
7607+ ) ;
7608+ }
0 commit comments