Skip to content

Commit 6fef24a

Browse files
committed
Skip CopyObject emission when operand and result types differ
When the egraph unifies values of different SPIR-V types (e.g. Vec4 and Vec2 are both Expr), extraction may resolve to an ID with a different type. Previously a CopyObject was emitted regardless, causing "operand type not matching result type" validation errors. Now only emit CopyObject and create an alias when the SPIR-V types actually match. When they don't, the original instruction is preserved.
1 parent eee67f1 commit 6fef24a

1 file changed

Lines changed: 15 additions & 11 deletions

File tree

  • rust/spirv-tools-opt/src/direct

rust/spirv-tools-opt/src/direct/mod.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,22 +1076,26 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
10761076
// emit_term resolved to an existing ID — emit CopyObject
10771077
// if it's a different ID, or skip if same
10781078
if final_id != id {
1079-
// Only alias when SPIR-V types match (see above)
1079+
// Only emit CopyObject when SPIR-V types match.
1080+
// The egraph may unify values of different types
1081+
// (e.g. Vec4 and Vec2 both as Expr). Emitting a
1082+
// CopyObject with mismatched types would cause
1083+
// validation errors.
10801084
let type_matches = ctx.id_to_type.get(&id)
10811085
== ctx.id_to_type.get(&final_id);
10821086
if type_matches {
10831087
id_aliases.insert(id, final_id);
1088+
used_ids.insert(final_id);
1089+
optimized_instructions.insert(
1090+
id,
1091+
Instruction::new(
1092+
Op::CopyObject,
1093+
Some(corrected_type),
1094+
Some(id),
1095+
vec![rspirv::dr::Operand::IdRef(final_id)],
1096+
),
1097+
);
10841098
}
1085-
used_ids.insert(final_id);
1086-
optimized_instructions.insert(
1087-
id,
1088-
Instruction::new(
1089-
Op::CopyObject,
1090-
Some(corrected_type),
1091-
Some(id),
1092-
vec![rspirv::dr::Operand::IdRef(final_id)],
1093-
),
1094-
);
10951099
}
10961100
} else {
10971101
for (i, mut inst) in new_insts.into_iter().enumerate() {

0 commit comments

Comments
 (0)