Skip to content

Commit eb06849

Browse files
committed
Per-operation signed int type resolution replaces global preference
The previous signed int preference in find_spirv_type caused regressions for unsigned operations (ULessThan, ConvertUToF, etc.) by making the int32_type fallback always signed. This produced type mismatches when one operand kept its original unsigned type and another used the signed fallback. Replace with targeted approach: - Track signed_int32_type separately from int32_type - For explicitly signed ops (ConvertSToF, SLessThan, etc.), use the signed int type - For other cross-class ops, infer operand type from Sym atoms to ensure consistency between existing and synthesized operands
1 parent 78aee01 commit eb06849

2 files changed

Lines changed: 97 additions & 20 deletions

File tree

rust/spirv-tools-opt/src/direct/emit.rs

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ pub struct EmitCtx<'a> {
122122
pub next_id: &'a mut Word,
123123
pub int32_type: Option<Word>,
124124
pub int64_type: Option<Word>,
125+
/// Signed int32 type (OpTypeInt 32 1) for signed operations like ConvertSToF.
126+
/// Separate from int32_type which may be unsigned depending on module order.
127+
pub signed_int32_type: Option<Word>,
125128
pub float32_type: Option<Word>,
126129
pub float64_type: Option<Word>,
127130
pub bool_type: Option<Word>,
@@ -1107,6 +1110,39 @@ fn resolve_operand_type(class: TypeClass, result_type: Word, ctx: &EmitCtx) -> W
11071110
}
11081111
}
11091112

1113+
/// For cross-class operations, try to infer the operand type from Sym atoms.
1114+
///
1115+
/// When one operand is a Sym (existing ID) and another is a new expression,
1116+
/// the Sym's actual type (from id_to_type) should be used for both operands
1117+
/// to avoid signed/unsigned or width mismatches.
1118+
fn infer_operand_type_from_args(
1119+
args: &[Term],
1120+
arity: usize,
1121+
operand_class: TypeClass,
1122+
fallback: Word,
1123+
ctx: &EmitCtx,
1124+
) -> Word {
1125+
for arg in &args[..arity.min(args.len())] {
1126+
if let Term::Atom(a) = arg {
1127+
if let Some(stripped) = a.strip_prefix("id") {
1128+
if let Ok(id) = stripped.parse::<Word>() {
1129+
if let Some(&ty) = ctx.id_to_type.get(&id) {
1130+
let tc = ctx
1131+
.type_classes
1132+
.get(&ty)
1133+
.copied()
1134+
.unwrap_or(TypeClass::Other);
1135+
if tc == operand_class {
1136+
return ty;
1137+
}
1138+
}
1139+
}
1140+
}
1141+
}
1142+
}
1143+
fallback
1144+
}
1145+
11101146
// ---------------------------------------------------------------------------
11111147
// Core emission
11121148
// ---------------------------------------------------------------------------
@@ -1260,6 +1296,35 @@ fn emit_pattern(
12601296
}
12611297
let op_result_type = resolve_result_type(result_class, result_type, ctx);
12621298
let operand_type = resolve_operand_type(operand_class, result_type, ctx);
1299+
// For cross-class operations where the operand class differs from
1300+
// result class, the operand_type is a fallback (e.g. int32_type).
1301+
// Refine it to avoid signed/unsigned mismatches:
1302+
// 1. Signed int ops (ConvertSToF, SLessThan, etc.) → signed int type
1303+
// 2. Other cross-class ops → infer from Sym operand's actual type
1304+
let operand_type = if operand_class == TypeClass::Int
1305+
&& result_class != operand_class
1306+
{
1307+
match opcode {
1308+
Op::ConvertSToF
1309+
| Op::ConvertFToS
1310+
| Op::SLessThan
1311+
| Op::SLessThanEqual
1312+
| Op::SGreaterThan
1313+
| Op::SGreaterThanEqual
1314+
| Op::SDiv
1315+
| Op::SRem
1316+
| Op::SMod
1317+
| Op::SConvert => ctx.signed_int32_type.unwrap_or(operand_type),
1318+
_ => infer_operand_type_from_args(
1319+
args, arity, operand_class, operand_type, ctx,
1320+
),
1321+
}
1322+
} else if result_class != operand_class {
1323+
// Non-int cross-class: infer from Sym args for consistency
1324+
infer_operand_type_from_args(args, arity, operand_class, operand_type, ctx)
1325+
} else {
1326+
operand_type
1327+
};
12631328
let mut synth = Vec::new();
12641329
let mut operand_ids = Vec::with_capacity(arity);
12651330
for arg in &args[..arity] {
@@ -1886,6 +1951,7 @@ mod tests {
18861951
next_id: &mut next_id,
18871952
int32_type: Some(10),
18881953
int64_type: None,
1954+
signed_int32_type: None,
18891955
float32_type: Some(11),
18901956
float64_type: None,
18911957
bool_type: Some(12),
@@ -1912,6 +1978,7 @@ mod tests {
19121978
next_id: &mut next_id,
19131979
int32_type: Some(10),
19141980
int64_type: None,
1981+
signed_int32_type: None,
19151982
float32_type: Some(11),
19161983
float64_type: None,
19171984
bool_type: Some(12),
@@ -1940,6 +2007,7 @@ mod tests {
19402007
next_id: &mut next_id,
19412008
int32_type: Some(10),
19422009
int64_type: None,
2010+
signed_int32_type: None,
19432011
float32_type: Some(11),
19442012
float64_type: None,
19452013
bool_type: Some(12),
@@ -1972,6 +2040,7 @@ mod tests {
19722040
next_id: &mut next_id,
19732041
int32_type: Some(10),
19742042
int64_type: None,
2043+
signed_int32_type: None,
19752044
float32_type: Some(11),
19762045
float64_type: None,
19772046
bool_type: Some(12),
@@ -2001,6 +2070,7 @@ mod tests {
20012070
next_id,
20022071
int32_type: Some(10),
20032072
int64_type: None,
2073+
signed_int32_type: None,
20042074
float32_type: Some(11),
20052075
float64_type: None,
20062076
bool_type: Some(12),

rust/spirv-tools-opt/src/direct/mod.rs

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,8 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
954954
let bool_type = find_spirv_type(module, Op::TypeBool, None);
955955
let float32_type = find_spirv_type(module, Op::TypeFloat, Some(32));
956956
let float64_type = find_spirv_type(module, Op::TypeFloat, Some(64));
957+
// Signed int types for operations like ConvertSToF where signedness matters
958+
let signed_int32_type = find_signed_int_type(module, 32);
957959

958960
// Build composite → element type mapping for CompositeConstruct emission
959961
let mut composite_element_types: HashMap<Word, Word> = HashMap::new();
@@ -1056,6 +1058,7 @@ pub fn optimize_module_direct(module: &Module) -> Result<Module, EgglogOptError>
10561058
next_id: &mut next_id,
10571059
int32_type,
10581060
int64_type,
1061+
signed_int32_type,
10591062
float32_type,
10601063
float64_type,
10611064
bool_type,
@@ -2261,29 +2264,33 @@ impl TypeClass {
22612264

22622265
/// Find a SPIR-V type declaration's result ID by opcode and optional bit-width.
22632266
fn find_spirv_type(module: &Module, opcode: Op, width: Option<u32>) -> Option<Word> {
2264-
let matches_width = |inst: &Instruction| {
2265-
inst.class.opcode == opcode
2266-
&& match width {
2267-
Some(w) => inst.operands.first() == Some(&rspirv::dr::Operand::LiteralBit32(w)),
2268-
None => true,
2269-
}
2270-
};
2271-
// For OpTypeInt, prefer signed (signedness=1). The int32_type/int64_type
2272-
// fallback is used in resolve_operand_type for cross-class operations like
2273-
// ConvertSToF. If the fallback is unsigned, cross-compilers (SPIRV-Cross)
2274-
// may generate unsigned conversion code, turning negative values to zero.
2275-
if opcode == Op::TypeInt {
2276-
if let Some(inst) = module.types_global_values.iter().find(|inst| {
2277-
matches_width(inst)
2278-
&& inst.operands.get(1) == Some(&rspirv::dr::Operand::LiteralBit32(1))
2279-
}) {
2280-
return inst.result_id;
2281-
}
2282-
}
22832267
module
22842268
.types_global_values
22852269
.iter()
2286-
.find(|inst| matches_width(inst))
2270+
.find(|inst| {
2271+
inst.class.opcode == opcode
2272+
&& match width {
2273+
Some(w) => {
2274+
inst.operands.first() == Some(&rspirv::dr::Operand::LiteralBit32(w))
2275+
}
2276+
None => true,
2277+
}
2278+
})
2279+
.and_then(|inst| inst.result_id)
2280+
}
2281+
2282+
/// Find the signed variant of an integer type (signedness=1).
2283+
/// Used for signed operations like ConvertSToF where cross-compilers
2284+
/// may generate unsigned conversion code if the type is unsigned.
2285+
fn find_signed_int_type(module: &Module, width: u32) -> Option<Word> {
2286+
module
2287+
.types_global_values
2288+
.iter()
2289+
.find(|inst| {
2290+
inst.class.opcode == Op::TypeInt
2291+
&& inst.operands.first() == Some(&rspirv::dr::Operand::LiteralBit32(width))
2292+
&& inst.operands.get(1) == Some(&rspirv::dr::Operand::LiteralBit32(1))
2293+
})
22872294
.and_then(|inst| inst.result_id)
22882295
}
22892296

0 commit comments

Comments
 (0)