Skip to content

Commit 0a9bd5d

Browse files
committed
Fix sort mismatches in Theta and Gamma term construction
Theta construction used (Const 1) as condition (IntExpr) but all Theta variants require BoolExpr. Also used untyped Theta with typed body terms and hardcoded (Const 0) init for all types. Fix: use typed Theta (ThetaI/ThetaF/ThetaB/Theta) matching the value's SPIR-V type class, with (BoolConst 1) condition and type-appropriate init value, referencing id{N} instead of raw term. Gamma construction determined type class from only the then-branch value but created Gamma for all pairs of then/else block IDs. When values had different types, the typed Gamma would crash the egraph. Fix: check both branch value type classes and skip mismatched pairs. Move branch_value_pairs.push after the type check to avoid trying to extract gamma variables that were never created.
1 parent b11a706 commit 0a9bd5d

1 file changed

Lines changed: 39 additions & 18 deletions

File tree

  • rust/spirv-tools-opt/src/direct

rust/spirv-tools-opt/src/direct/mod.rs

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -341,29 +341,40 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
341341
None => continue,
342342
};
343343

344-
// Record this pair even if terms differ - egraph handles equivalence
344+
// Add a Gamma term representing the selection between these values
345+
// The egraph will unify equivalent computations via (GammaX c x x) => x
346+
// Use typed Gamma based on branch type class
347+
let then_type_class = ctx
348+
.id_to_type
349+
.get(&then_id)
350+
.and_then(|ty| type_classes.get(ty))
351+
.copied()
352+
.unwrap_or(TypeClass::Other);
353+
let else_type_class = ctx
354+
.id_to_type
355+
.get(&else_id)
356+
.and_then(|ty| type_classes.get(ty))
357+
.copied()
358+
.unwrap_or(TypeClass::Other);
359+
// Skip if branch values have different type classes — creating
360+
// a typed Gamma with mismatched sorts would crash the egraph.
361+
if then_type_class != else_type_class {
362+
continue;
363+
}
364+
345365
branch_value_pairs.push(BranchValuePair {
346366
then_id,
347367
else_id,
348368
condition_id: sel.condition_id,
349369
header_block_label: header_label,
350370
});
351371

352-
// Add a Gamma term representing the selection between these values
353-
// The egraph will unify equivalent computations via (GammaX c x x) => x
354-
// Use typed Gamma based on branch type class
355372
let cond_term = if ctx.id_to_term.contains_key(&sel.condition_id) {
356373
format!("id{}", sel.condition_id)
357374
} else {
358375
format!("(BSym \"id{}\")", sel.condition_id)
359376
};
360-
let branch_type_class = ctx
361-
.id_to_type
362-
.get(&then_id)
363-
.and_then(|ty| type_classes.get(ty))
364-
.copied()
365-
.unwrap_or(TypeClass::Other);
366-
let gamma_ctor = match branch_type_class {
377+
let gamma_ctor = match then_type_class {
367378
TypeClass::Int => "GammaI",
368379
TypeClass::Float => "GammaF",
369380
TypeClass::Bool => "GammaB",
@@ -423,13 +434,23 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
423434
if let Some(&block_label) = id_to_block.get(&id) {
424435
if body_labels.contains(&block_label) {
425436
// This value is defined inside the loop
426-
if let Some(term) = ctx.id_to_term.get(&id) {
427-
// Create a Theta node representing this loop computation
428-
// Theta(cond, body, init) where:
429-
// - cond: (Const 1) for infinite loops
430-
// - body: the expression computed in the loop
431-
// - init: (Const 0) placeholder for loop-carried state
432-
let theta_term = format!("(Theta (Const 1) {} (Const 0))", term);
437+
if ctx.id_to_term.contains_key(&id) {
438+
// Use typed Theta matching the value's sort to avoid
439+
// sort mismatches (IntExpr/FloatExpr/BoolExpr vs Expr)
440+
let value_type_class = ctx
441+
.id_to_type
442+
.get(&id)
443+
.and_then(|ty| type_classes.get(ty))
444+
.copied()
445+
.unwrap_or(TypeClass::Other);
446+
let (theta_ctor, init_val) = match value_type_class {
447+
TypeClass::Int => ("ThetaI", "(Const 0)".to_string()),
448+
TypeClass::Float => ("ThetaF", "(FConst 0.0)".to_string()),
449+
TypeClass::Bool => ("ThetaB", "(BoolConst 0)".to_string()),
450+
TypeClass::Other => ("Theta", format!("(Sym \"theta_init_{}\")", id)),
451+
};
452+
let theta_term =
453+
format!("({} (BoolConst 1) id{} {})", theta_ctor, id, init_val);
433454
let theta_binding = format!("(let theta_{} {})", id, theta_term);
434455
egraph
435456
.parse_and_run_program(None, &theta_binding)

0 commit comments

Comments
 (0)