diff --git a/libs/@local/hashql/compiletest/src/pipeline.rs b/libs/@local/hashql/compiletest/src/pipeline.rs index a0f7d427d39..49fe6b68dac 100644 --- a/libs/@local/hashql/compiletest/src/pipeline.rs +++ b/libs/@local/hashql/compiletest/src/pipeline.rs @@ -2,8 +2,10 @@ //! //! [`Pipeline`] drives the full HashQL compilation sequence: parsing J-Expr //! source into an AST, lowering through HIR and MIR, running optimization and -//! execution analysis passes, and finally compiling to -//! [`PreparedQueries`](hashql_eval::postgres::PreparedQueries) ready for PostgreSQL execution. +//! execution analysis passes, and finally compiling to [`PreparedQuery`] +//! ready for PostgreSQL execution. +//! +//! [`PreparedQuery`]: hashql_eval::postgres::PreparedQuery //! //! Each stage is exposed as a separate method so callers can inspect or test //! intermediate results. Diagnostics (warnings, advisories) accumulate in @@ -237,12 +239,14 @@ impl<'heap> Pipeline<'heap> { Ok(()) } - /// Runs execution analysis and compiles MIR bodies to prepared SQL queries. + /// Runs execution analysis on MIR bodies. + /// + /// Performs size estimation and execution island analysis, determining which + /// parts of each body run on PostgreSQL vs the interpreter. Returns + /// per-body residuals that downstream compilation stages use to produce + /// [`PreparedQuery`] instances. /// - /// Performs size estimation, execution island analysis (determining which - /// parts of each body run on PostgreSQL vs the interpreter), then compiles - /// the PostgreSQL islands into [`PreparedQueries`](hashql_eval::postgres::PreparedQueries) - /// containing the SQL statements, parameter bindings, and column descriptors. + /// [`PreparedQuery`]: hashql_eval::postgres::PreparedQuery /// /// # Errors /// diff --git a/libs/@local/hashql/core/src/id/bit_vec/finite.rs b/libs/@local/hashql/core/src/id/bit_vec/finite.rs index 9ccfddd22ad..1a9c66daac6 100644 --- a/libs/@local/hashql/core/src/id/bit_vec/finite.rs +++ b/libs/@local/hashql/core/src/id/bit_vec/finite.rs @@ -339,6 +339,28 @@ impl FiniteBitSet { Some(I::from_u32(self.store.trailing_zeros())) } + /// Returns `true` if `self` is a superset of `other` (contains all bits set in `other`). + #[inline] + #[must_use] + pub const fn is_superset(&self, other: &Self) -> bool + where + T: [const] FiniteBitSetIntegral, + { + // `other` is a subset of `self` iff `other & self == other` + other.store & self.store == other.store + } + + /// Returns `true` if `self` is a subset of `other` (all bits set in `self` are also set in + /// `other`). + #[inline] + #[must_use] + pub const fn is_subset(&self, other: &Self) -> bool + where + T: [const] FiniteBitSetIntegral, + { + other.is_superset(self) + } + /// Returns an iterator over the indices of set bits. #[inline] pub fn iter(&self) -> FiniteBitIter { @@ -863,6 +885,84 @@ mod tests { assert_eq!(set, original); } + #[test] + fn is_superset_of_subset() { + let mut a: FiniteBitSet = FiniteBitSet::new_empty(8); + a.insert_range(TestId::from_usize(0)..=TestId::from_usize(5), 8); + + let mut b: FiniteBitSet = FiniteBitSet::new_empty(8); + b.insert(TestId::from_usize(1)); + b.insert(TestId::from_usize(3)); + + assert!(a.is_superset(&b)); + assert!(!b.is_superset(&a)); + } + + #[test] + fn is_subset_of_superset() { + let mut a: FiniteBitSet = FiniteBitSet::new_empty(8); + a.insert(TestId::from_usize(2)); + a.insert(TestId::from_usize(4)); + + let mut b: FiniteBitSet = FiniteBitSet::new_empty(8); + b.insert_range(TestId::from_usize(0)..=TestId::from_usize(7), 8); + + assert!(a.is_subset(&b)); + assert!(!b.is_subset(&a)); + } + + #[test] + fn empty_is_subset_of_everything() { + let empty: FiniteBitSet = FiniteBitSet::new_empty(8); + + let mut full: FiniteBitSet = FiniteBitSet::new_empty(8); + full.insert_range(TestId::from_usize(0)..=TestId::from_usize(7), 8); + + assert!(empty.is_subset(&full)); + assert!(empty.is_subset(&empty)); + assert!(full.is_superset(&empty)); + } + + #[test] + fn equal_sets_are_both_subset_and_superset() { + let mut a: FiniteBitSet = FiniteBitSet::new_empty(8); + a.insert(TestId::from_usize(1)); + a.insert(TestId::from_usize(5)); + + let b = a; + + assert!(a.is_subset(&b)); + assert!(a.is_superset(&b)); + } + + #[test] + fn disjoint_sets_are_not_subsets() { + let mut a: FiniteBitSet = FiniteBitSet::new_empty(8); + a.insert(TestId::from_usize(0)); + a.insert(TestId::from_usize(1)); + + let mut b: FiniteBitSet = FiniteBitSet::new_empty(8); + b.insert(TestId::from_usize(6)); + b.insert(TestId::from_usize(7)); + + assert!(!a.is_subset(&b)); + assert!(!a.is_superset(&b)); + assert!(!b.is_subset(&a)); + assert!(!b.is_superset(&a)); + } + + #[test] + fn overlapping_sets_are_not_subsets() { + let mut a: FiniteBitSet = FiniteBitSet::new_empty(8); + a.insert_range(TestId::from_usize(0)..=TestId::from_usize(3), 8); + + let mut b: FiniteBitSet = FiniteBitSet::new_empty(8); + b.insert_range(TestId::from_usize(2)..=TestId::from_usize(5), 8); + + assert!(!a.is_subset(&b)); + assert!(!a.is_superset(&b)); + } + #[test] fn negate_full_width() { let mut set: FiniteBitSet = FiniteBitSet::new_empty(8); diff --git a/libs/@local/hashql/eval/src/orchestrator/codec/decode/mod.rs b/libs/@local/hashql/eval/src/orchestrator/codec/decode/mod.rs index 7a96a27c4d2..0c8fd4dcfd0 100644 --- a/libs/@local/hashql/eval/src/orchestrator/codec/decode/mod.rs +++ b/libs/@local/hashql/eval/src/orchestrator/codec/decode/mod.rs @@ -33,8 +33,8 @@ mod tests; /// variant in order, and opaque types wrap their inner representation. /// /// When the type is unknown ([`Param`], [`Infer`], [`Unknown`]), falls back to -/// `decode_unknown`, which uses JSON structure alone -/// (objects become structs or dicts, arrays become lists, etc.). +/// a structural decoder that uses JSON shape alone: objects become structs or +/// dicts, arrays become lists, etc. /// /// [`Value`]: hashql_mir::interpret::value::Value /// [`Param`]: hashql_core::type::kind::TypeKind::Param diff --git a/libs/@local/hashql/eval/src/orchestrator/codec/mod.rs b/libs/@local/hashql/eval/src/orchestrator/codec/mod.rs index e9ca9aec3ce..0ed9512d2b1 100644 --- a/libs/@local/hashql/eval/src/orchestrator/codec/mod.rs +++ b/libs/@local/hashql/eval/src/orchestrator/codec/mod.rs @@ -1,10 +1,10 @@ //! JSON codec for converting between interpreter [`Value`]s and the PostgreSQL //! wire format. //! -//! - `decode`: deserializes JSON column values (from `tokio_postgres` rows) into typed [`Value`]s, -//! guided by the HashQL type system. -//! - `encode`: serializes runtime [`Value`]s and query parameters into forms that `tokio_postgres` -//! can send to the database (via [`ToSql`]). +//! Decoding deserializes JSON column values (from `tokio_postgres` rows) into typed +//! [`Value`]s, guided by the HashQL type system. Encoding serializes runtime [`Value`]s +//! and query parameters into forms that `tokio_postgres` can send to the database +//! (via [`ToSql`]). //! //! The [`JsonValueRef`] type provides a borrowed view over `serde_json::Value` //! that avoids cloning during decode, while [`JsonValueKind`] is a data-free diff --git a/libs/@local/hashql/eval/src/postgres/continuation.rs b/libs/@local/hashql/eval/src/postgres/continuation.rs index 749d64ecc5e..d30b3f2a445 100644 --- a/libs/@local/hashql/eval/src/postgres/continuation.rs +++ b/libs/@local/hashql/eval/src/postgres/continuation.rs @@ -47,10 +47,10 @@ impl ContinuationAlias { /// Continuation fields returned to the bridge in the `SELECT` list. /// -/// A subset of `ContinuationColumn` that excludes internal-only columns -/// (`Entry` and `Filter`). -/// Each variant corresponds to a column the bridge must decode to reconstruct -/// island exit control flow and live-out locals. +/// Excludes internal-only columns (entry and filter) that are only used +/// within the generated SQL. Each variant corresponds to a column the +/// bridge must decode to reconstruct island exit control flow and live-out +/// locals. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum ContinuationField { /// The target basic block id for island exits. diff --git a/libs/@local/hashql/eval/tests/ui/orchestrator/.spec.toml b/libs/@local/hashql/eval/tests/ui/orchestrator/.spec.toml new file mode 100644 index 00000000000..ef02627aed9 --- /dev/null +++ b/libs/@local/hashql/eval/tests/ui/orchestrator/.spec.toml @@ -0,0 +1,2 @@ +skip = true +suite = "eval/orchestrator" diff --git a/libs/@local/hashql/eval/tests/ui/orchestrator/jsonc/filter-false.stdout b/libs/@local/hashql/eval/tests/ui/orchestrator/jsonc/filter-false.stdout index aac4e2f4a36..0a280baf2e7 100644 --- a/libs/@local/hashql/eval/tests/ui/orchestrator/jsonc/filter-false.stdout +++ b/libs/@local/hashql/eval/tests/ui/orchestrator/jsonc/filter-false.stdout @@ -1,28 +1,3 @@ [] --- query executed: body 2, block bb0 -row received -filter started: body 1 -island entered: body 1, island 0, target interpreter -filter rejected: body 1 -row rejected -row received -filter started: body 1 -island entered: body 1, island 0, target interpreter -filter rejected: body 1 -row rejected -row received -filter started: body 1 -island entered: body 1, island 0, target interpreter -filter rejected: body 1 -row rejected -row received -filter started: body 1 -island entered: body 1, island 0, target interpreter -filter rejected: body 1 -row rejected -row received -filter started: body 1 -island entered: body 1, island 0, target interpreter -filter rejected: body 1 -row rejected diff --git a/libs/@local/hashql/eval/tests/ui/orchestrator/jsonc/simple-read.stdout b/libs/@local/hashql/eval/tests/ui/orchestrator/jsonc/simple-read.stdout index de0a09441b1..410834f54a6 100644 --- a/libs/@local/hashql/eval/tests/ui/orchestrator/jsonc/simple-read.stdout +++ b/libs/@local/hashql/eval/tests/ui/orchestrator/jsonc/simple-read.stdout @@ -9,26 +9,31 @@ query executed: body 2, block bb0 row received filter started: body 1 -island entered: body 1, island 0, target interpreter +island entered: body 1, island 0, target postgres +continuation implicit true: body 1 filter accepted: body 1 row accepted row received filter started: body 1 -island entered: body 1, island 0, target interpreter +island entered: body 1, island 0, target postgres +continuation implicit true: body 1 filter accepted: body 1 row accepted row received filter started: body 1 -island entered: body 1, island 0, target interpreter +island entered: body 1, island 0, target postgres +continuation implicit true: body 1 filter accepted: body 1 row accepted row received filter started: body 1 -island entered: body 1, island 0, target interpreter +island entered: body 1, island 0, target postgres +continuation implicit true: body 1 filter accepted: body 1 row accepted row received filter started: body 1 -island entered: body 1, island 0, target interpreter +island entered: body 1, island 0, target postgres +continuation implicit true: body 1 filter accepted: body 1 row accepted diff --git a/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.aux.mir b/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.aux.mir index 7c924219d5d..86de485090d 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.aux.mir +++ b/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.aux.mir @@ -9,7 +9,7 @@ thunk {thunk#1}() -> ::graph::temporal::PinnedTransactionTimeTemporalAxes | ::gr } fn {graph::read::filter@7}(%0: (), %1: ::graph::types::knowledge::entity::Entity) -> Boolean { - bb0(): { // interpreter + bb0(): { // postgres return true } } diff --git a/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.stdout b/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.stdout index 841d0a8e621..0df73c4dc88 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.stdout +++ b/libs/@local/hashql/eval/tests/ui/postgres/constant-true-filter.stdout @@ -1,8 +1,10 @@ ════ SQL ═══════════════════════════════════════════════════════════════════════ -SELECT 1 AS "placeholder" +SELECT ("continuation_1_0"."row")."block" AS "continuation_1_0_block", ("continuation_1_0"."row")."locals" AS "continuation_1_0_locals", ("continuation_1_0"."row")."values" AS "continuation_1_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) +CROSS JOIN LATERAL (SELECT (ROW(COALESCE(((1)::boolean), FALSE), NULL, NULL, NULL)::continuation) AS "row" +OFFSET 0) AS "continuation_1_0" +WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_1_0"."row")."filter" IS NOT FALSE ════ Parameters ════════════════════════════════════════════════════════════════ diff --git a/libs/@local/hashql/eval/tests/ui/postgres/env-captured-variable.aux.mir b/libs/@local/hashql/eval/tests/ui/postgres/env-captured-variable.aux.mir index 349ef783c3c..c51df4ddc66 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/env-captured-variable.aux.mir +++ b/libs/@local/hashql/eval/tests/ui/postgres/env-captured-variable.aux.mir @@ -9,7 +9,7 @@ fn {graph::read::filter@11}(%0: (::graph::types::knowledge::entity::EntityUuid,) } fn {graph::read::filter@27}(%0: (), %1: ::graph::types::knowledge::entity::Entity) -> Boolean { - bb0(): { // interpreter + bb0(): { // postgres return true } } diff --git a/libs/@local/hashql/eval/tests/ui/postgres/filter/island_exit_switch_int.snap b/libs/@local/hashql/eval/tests/ui/postgres/filter/island_exit_switch_int.snap index 98e460d0204..c1eb6da0711 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/filter/island_exit_switch_int.snap +++ b/libs/@local/hashql/eval/tests/ui/postgres/filter/island_exit_switch_int.snap @@ -36,4 +36,4 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> ? { } ==================== Island (entry: bb0, target: postgres) ===================== -CASE WHEN ((($10::jsonb))::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($10::jsonb))::int) = 0 THEN (ROW(NULL, 2, ARRAY[]::int[], ARRAY[]::jsonb[])::continuation) WHEN ((($10::jsonb))::int) = 1 THEN (ROW(COALESCE(((1)::boolean), FALSE), NULL, NULL, NULL)::continuation) END +CASE WHEN ((($10::jsonb))::int) IS NULL THEN (ROW(COALESCE(((FALSE)::boolean), FALSE), NULL, NULL, NULL)::continuation) WHEN ((($10::jsonb))::int) = 0 THEN (ROW(NULL, 2, ARRAY[]::int[], ARRAY[]::jsonb[])::continuation) WHEN ((($10::jsonb))::int) = 1 THEN (ROW(NULL, 1, ARRAY[]::int[], ARRAY[]::jsonb[])::continuation) END diff --git a/libs/@local/hashql/eval/tests/ui/postgres/filter/temporal_decision_time_interval.snap b/libs/@local/hashql/eval/tests/ui/postgres/filter/temporal_decision_time_interval.snap index c21e275c7bf..de37c3d85d5 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/filter/temporal_decision_time_interval.snap +++ b/libs/@local/hashql/eval/tests/ui/postgres/filter/temporal_decision_time_interval.snap @@ -11,6 +11,11 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> ? { bb0(): { %2 = %1.metadata.temporal_versioning.decision_time + + goto -> bb1() + } + + bb1(): { %3 = ({def@99} as FnPtr) %4 = apply %3 @@ -19,9 +24,11 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> ? { } ===================================== SQL ====================================== -SELECT jsonb_build_object(($3::text), (extract(epoch from lower("entity_temporal_metadata_0_0_0"."decision_time")) * 1000)::int8, ($4::text), CASE WHEN upper_inf("entity_temporal_metadata_0_0_0"."decision_time") THEN NULL ELSE (extract(epoch from upper("entity_temporal_metadata_0_0_0"."decision_time")) * 1000)::int8 END) AS "decision_time" +SELECT ("continuation_0_0"."row")."block" AS "continuation_0_0_block", ("continuation_0_0"."row")."locals" AS "continuation_0_0_locals", ("continuation_0_0"."row")."values" AS "continuation_0_0_values" FROM "entity_temporal_metadata" AS "entity_temporal_metadata_0_0_0" -WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) +CROSS JOIN LATERAL (SELECT (ROW(NULL, 1, ARRAY[]::int[], ARRAY[]::jsonb[])::continuation) AS "row" +OFFSET 0) AS "continuation_0_0" +WHERE "entity_temporal_metadata_0_0_0"."transaction_time" && ($1::tstzrange) AND "entity_temporal_metadata_0_0_0"."decision_time" && ($2::tstzrange) AND ("continuation_0_0"."row")."filter" IS NOT FALSE ================================== Parameters ================================== $1: TemporalAxis(Transaction) diff --git a/libs/@local/hashql/eval/tests/ui/postgres/mixed-sources-filter.aux.mir b/libs/@local/hashql/eval/tests/ui/postgres/mixed-sources-filter.aux.mir index 430a0ff87b5..aa1758bc481 100644 --- a/libs/@local/hashql/eval/tests/ui/postgres/mixed-sources-filter.aux.mir +++ b/libs/@local/hashql/eval/tests/ui/postgres/mixed-sources-filter.aux.mir @@ -19,7 +19,7 @@ fn {graph::read::filter@20}(%0: (::graph::types::knowledge::entity::EntityUuid,) } fn {graph::read::filter@36}(%0: (), %1: ::graph::types::knowledge::entity::Entity) -> Boolean { - bb0(): { // interpreter + bb0(): { // postgres return true } } diff --git a/libs/@local/hashql/mir/src/pass/execution/cost/analysis.rs b/libs/@local/hashql/mir/src/pass/execution/cost/analysis.rs index 3850c247a85..72becc17c3d 100644 --- a/libs/@local/hashql/mir/src/pass/execution/cost/analysis.rs +++ b/libs/@local/hashql/mir/src/pass/execution/cost/analysis.rs @@ -1,6 +1,6 @@ use core::alloc::Allocator; -use super::{ApproxCost, StatementCostVec}; +use super::{ApproxCost, StatementCostVec, TerminatorCostVec}; use crate::{ body::basic_block::{BasicBlock, BasicBlockId, BasicBlockSlice, BasicBlockVec}, pass::{ @@ -98,7 +98,8 @@ impl BasicBlockCostVec { pub(crate) struct BasicBlockCostAnalysis<'ctx, A: Allocator> { pub vertex: VertexType, pub assignments: &'ctx BasicBlockSlice, - pub costs: &'ctx TargetArray>, + pub statement_costs: &'ctx TargetArray>, + pub terminator_costs: &'ctx TargetArray>, } impl BasicBlockCostAnalysis<'_, A> { @@ -109,7 +110,8 @@ impl BasicBlockCostAnalysis<'_, A> { target: TargetId, traversals: TraversalPathBitSet, ) -> BasicBlockTargetCost { - let base = self.costs[target].sum_approx(id); + let mut base = self.statement_costs[target].sum_approx(id); + base += self.terminator_costs[target].approx(id); let mut range = InformationRange::zero(); @@ -204,6 +206,12 @@ mod tests { TransferCostConfig::new(InformationRange::full()) } + fn make_zero_terminator_costs(block_count: usize) -> TargetArray> { + TargetArray::from_fn(|_| { + TerminatorCostVec::from_costs(&vec![Some(cost!(0)); block_count], Global) + }) + } + fn make_targets(body: &crate::body::Body<'_>, domain: TargetBitSet) -> Vec { body.basic_blocks.iter().map(|_| domain).collect() } @@ -231,10 +239,12 @@ mod tests { let costs: TargetArray> = TargetArray::from_fn(|_| StatementCostVec::from_iter([1].into_iter(), Global)); + let terminator_costs = make_zero_terminator_costs(body.basic_blocks.len()); let analysis = BasicBlockCostAnalysis { vertex: VertexType::Entity, assignments: targets, - costs: &costs, + statement_costs: &costs, + terminator_costs: &terminator_costs, }; let result = analysis.analyze_in(&default_config(), &body.basic_blocks, Global); @@ -272,10 +282,12 @@ mod tests { let costs: TargetArray> = TargetArray::from_fn(|_| StatementCostVec::from_iter([1].into_iter(), Global)); + let terminator_costs = make_zero_terminator_costs(body.basic_blocks.len()); let analysis = BasicBlockCostAnalysis { vertex: VertexType::Entity, assignments: targets, - costs: &costs, + statement_costs: &costs, + terminator_costs: &terminator_costs, }; let config = default_config(); @@ -327,10 +339,12 @@ mod tests { let costs: TargetArray> = TargetArray::from_fn(|_| StatementCostVec::from_iter([1].into_iter(), Global)); + let terminator_costs = make_zero_terminator_costs(body.basic_blocks.len()); let analysis = BasicBlockCostAnalysis { vertex: VertexType::Entity, assignments: targets, - costs: &costs, + statement_costs: &costs, + terminator_costs: &terminator_costs, }; let result = analysis.analyze_in(&default_config(), &body.basic_blocks, Global); @@ -384,10 +398,12 @@ mod tests { let costs: TargetArray> = TargetArray::from_fn(|_| StatementCostVec::from_iter([1].into_iter(), Global)); + let terminator_costs = make_zero_terminator_costs(body.basic_blocks.len()); let analysis = BasicBlockCostAnalysis { vertex: VertexType::Entity, assignments: targets, - costs: &costs, + statement_costs: &costs, + terminator_costs: &terminator_costs, }; // Use a bounded properties size so both premiums are finite and comparable. @@ -443,10 +459,12 @@ mod tests { // Use zero properties size so Properties path doesn't contribute noise let config = TransferCostConfig::new(InformationRange::zero()); + let terminator_costs = make_zero_terminator_costs(body.basic_blocks.len()); let analysis = BasicBlockCostAnalysis { vertex: VertexType::Entity, assignments: targets, - costs: &costs, + statement_costs: &costs, + terminator_costs: &terminator_costs, }; let result = analysis.analyze_in(&config, &body.basic_blocks, Global); @@ -493,10 +511,12 @@ mod tests { let costs: TargetArray> = TargetArray::from_fn(|_| StatementCostVec::from_iter([1].into_iter(), Global)); + let terminator_costs = make_zero_terminator_costs(body.basic_blocks.len()); let analysis = BasicBlockCostAnalysis { vertex: VertexType::Entity, assignments: targets, - costs: &costs, + statement_costs: &costs, + terminator_costs: &terminator_costs, }; let result = analysis.analyze_in(&default_config(), &body.basic_blocks, Global); @@ -546,10 +566,12 @@ mod tests { TargetArray::from_fn(|_| StatementCostVec::from_iter([2, 1, 0].into_iter(), Global)); let config = default_config(); + let terminator_costs = make_zero_terminator_costs(body.basic_blocks.len()); let analysis = BasicBlockCostAnalysis { vertex: VertexType::Entity, assignments: targets, - costs: &costs, + statement_costs: &costs, + terminator_costs: &terminator_costs, }; let result = analysis.analyze_in(&config, &body.basic_blocks, Global); @@ -583,4 +605,55 @@ mod tests { assert_eq!(cost, base, "bb2 target {target:?} should have zero load"); } } + + /// Terminator cost is included in the per-block total. + /// + /// Uses a nonzero terminator cost and verifies the total exceeds the statement-only + /// base. Removing the terminator cost addition would cause this test to fail. + #[test] + fn terminator_cost_included_in_total() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; [graph::read::filter]@0/2 -> Int { + decl env: (Int), vertex: [Opaque sym::path::Entity; ?], val: Int; + @proj env_0 = env.0: Int; + + bb0() { + val = load env_0; + return val; + } + }); + + let targets = make_targets(&body, all_targets()); + let targets = BasicBlockSlice::from_raw(&targets); + + let statement_costs: TargetArray> = + TargetArray::from_fn(|_| StatementCostVec::from_iter([1].into_iter(), Global)); + + let terminator_costs: TargetArray> = + TargetArray::from_fn(|_| TerminatorCostVec::from_costs(&[Some(cost!(10))], Global)); + + let analysis = BasicBlockCostAnalysis { + vertex: VertexType::Entity, + assignments: targets, + statement_costs: &statement_costs, + terminator_costs: &terminator_costs, + }; + + let result = analysis.analyze_in(&default_config(), &body.basic_blocks, Global); + let bb0 = BasicBlockId::new(0); + + for target in TargetId::all() { + let total = result.cost(bb0, target); + let statement_base = statement_costs[target].sum_approx(bb0); + + assert!( + total > statement_base, + "target {target:?}: total ({total}) should exceed statement base \ + ({statement_base}) because terminator cost is nonzero" + ); + } + } } diff --git a/libs/@local/hashql/mir/src/pass/execution/cost/mod.rs b/libs/@local/hashql/mir/src/pass/execution/cost/mod.rs index 7c971af0fa7..470f53f4259 100644 --- a/libs/@local/hashql/mir/src/pass/execution/cost/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/cost/mod.rs @@ -1,12 +1,15 @@ //! Cost tracking for execution planning. //! -//! Two levels of cost representation: +//! Three levels of cost representation: //! //! - **Per-statement**: [`StatementCostVec`] records the [`Cost`] of each statement on a given //! target. Produced by the statement placement pass and consumed by [`BasicBlockCostAnalysis`]. //! -//! - **Per-block**: [`BasicBlockCostVec`] aggregates statement costs and adds a path transfer -//! premium for non-origin backends. This is what the placement solver operates on. +//! - **Per-terminator**: [`TerminatorCostVec`] records the [`Cost`] of each block's terminator on a +//! given target. Produced alongside statement costs during placement analysis. +//! +//! - **Per-block**: [`BasicBlockCostVec`] aggregates statement and terminator costs and adds a path +//! transfer premium for non-origin backends. This is what the placement solver operates on. use alloc::alloc::Global; use core::{ @@ -17,10 +20,16 @@ use core::{ }; use std::f32; +use hashql_core::id::Id as _; + pub(crate) use self::analysis::{BasicBlockCostAnalysis, BasicBlockCostVec}; use super::block_partitioned_vec::BlockPartitionedVec; use crate::{ - body::{basic_block::BasicBlockId, basic_blocks::BasicBlocks, location::Location}, + body::{ + basic_block::{BasicBlockId, BasicBlockSlice, BasicBlockVec}, + basic_blocks::BasicBlocks, + location::Location, + }, macros::{forward_ref_binop, forward_ref_op_assign}, pass::analysis::size_estimation::InformationUnit, }; @@ -326,6 +335,93 @@ impl Sum for ApproxCost { } } +/// Per-block cost of executing the terminator on a given target. +/// +/// Each block has exactly one terminator. A `None` cost indicates the target cannot execute that +/// terminator (the terminator's operands are not supported on the target). One instance exists per +/// target inside `TargetArray`. +#[derive(Debug)] +pub(crate) struct TerminatorCostVec(BasicBlockVec, A>); + +impl TerminatorCostVec { + /// Creates an empty cost vector with capacity reserved for one slot per block. + pub(crate) fn new_in(blocks: &BasicBlocks, alloc: A) -> Self { + Self(BasicBlockVec::with_capacity_in(blocks.len(), alloc)) + } + + /// Creates a cost vector from a slice of optional cost values. + #[cfg(test)] + pub(crate) fn from_costs(costs: &[Option], alloc: A) -> Self { + let mut vec = BasicBlockVec::from_elem_in(None, costs.len(), alloc); + for (i, cost) in costs.iter().enumerate() { + vec[BasicBlockId::from_usize(i)] = *cost; + } + Self(vec) + } +} + +impl TerminatorCostVec { + /// Returns `true` if no terminators have assigned costs. + #[cfg(test)] + pub(crate) fn all_unassigned(&self) -> bool { + self.0.iter().all(Option::is_none) + } + + /// Returns the cost for the terminator in `block`, or `None` if the target cannot execute it. + pub(crate) fn of(&self, block: BasicBlockId) -> Option { + self.0.lookup(block).copied() + } + + pub(crate) fn insert(&mut self, block: BasicBlockId, cost: Cost) { + self.0.insert(block, cost); + } + + /// Remaps terminator costs after block splitting. + /// + /// For each original block, the original terminator cost is placed on the last block + /// of its region (which holds the original terminator after splitting). All preceding + /// blocks in the region received synthesized `Goto` terminators and get zero cost. + /// + /// Operates in-place by extending the vec, then shuffling entries from back to front + /// to avoid overwriting unprocessed entries. + pub(crate) fn remap(&mut self, regions: &BasicBlockSlice<(core::num::NonZero, bool)>) { + let mut new_length = BasicBlockId::START; + for (_, (region_len, _)) in regions.iter_enumerated() { + new_length.increment_by(region_len.get()); + } + + // Extend to the new size. New slots are initialized to `None`. + self.0.fill_until(new_length.minus(1), || None); + + // Walk regions in reverse so we never overwrite an unprocessed original entry. + let mut write = new_length; + for (original, (region_len, _)) in regions.iter_enumerated().rev() { + let original_cost = self.0[original]; + + // The last block in the region holds the original terminator. + write.decrement_by(1); + self.0[write] = original_cost; + + // Preceding blocks have synthesized Goto terminators: zero cost. + for _ in 1..region_len.get() { + write.decrement_by(1); + self.0[write] = Some(cost!(0)); + } + } + + debug_assert_eq!(write, BasicBlockId::START); + } + + /// Returns the approximate cost for the terminator in `block`, or zero if unassigned. + pub(crate) fn approx(&self, block: BasicBlockId) -> ApproxCost { + debug_assert!(self.0.contains(block)); + self.0 + .lookup(block) + .copied() + .map_or(ApproxCost::ZERO, ApproxCost::from) + } +} + /// Dense cost map for all statements in a body. /// /// Stores the execution cost for every statement, indexed by [`Location`]. A `None` cost diff --git a/libs/@local/hashql/mir/src/pass/execution/cost/tests.rs b/libs/@local/hashql/mir/src/pass/execution/cost/tests.rs index 6f9588f2f00..ba41a7ffa73 100644 --- a/libs/@local/hashql/mir/src/pass/execution/cost/tests.rs +++ b/libs/@local/hashql/mir/src/pass/execution/cost/tests.rs @@ -1,7 +1,11 @@ use alloc::alloc::Global; +use core::num::NonZero; -use super::{Cost, StatementCostVec}; -use crate::body::{basic_block::BasicBlockId, location::Location}; +use super::{Cost, StatementCostVec, TerminatorCostVec}; +use crate::body::{ + basic_block::{BasicBlockId, BasicBlockSlice}, + location::Location, +}; /// `Cost::new` succeeds for valid values (0 and 100). #[test] @@ -42,6 +46,76 @@ fn cost_new_unchecked_valid() { assert_eq!(Cost::new(100), Some(hundred)); } +macro_rules! nz { + ($value:expr) => { + const { NonZero::new($value).unwrap() } + }; +} + +fn bb(index: u32) -> BasicBlockId { + BasicBlockId::new(index) +} + +/// No splits: region lengths all 1. Output equals input. +#[test] +fn remap_no_splits() { + let mut costs = TerminatorCostVec::from_costs(&[Some(cost!(7)), Some(cost!(3))], Global); + + let regions = [(nz!(1), false), (nz!(1), false)]; + costs.remap(BasicBlockSlice::from_raw(®ions)); + + assert_eq!(costs.of(bb(0)), Some(cost!(7))); + assert_eq!(costs.of(bb(1)), Some(cost!(3))); +} + +/// Single block splits into 2 regions. +#[test] +fn remap_single_split() { + let mut costs = TerminatorCostVec::from_costs(&[Some(cost!(7))], Global); + + let regions = [(nz!(2), true)]; + costs.remap(BasicBlockSlice::from_raw(®ions)); + + // First block gets synthesized Goto: zero cost + assert_eq!(costs.of(bb(0)), Some(cost!(0))); + // Second block holds original terminator + assert_eq!(costs.of(bb(1)), Some(cost!(7))); +} + +/// Mixed splits with None: original None cost preserved on last block of region. +#[test] +fn remap_mixed_with_none() { + let mut costs = TerminatorCostVec::from_costs(&[Some(cost!(4)), None, Some(cost!(8))], Global); + + let regions = [(nz!(2), true), (nz!(1), false), (nz!(3), true)]; + costs.remap(BasicBlockSlice::from_raw(®ions)); + + // Region 0: split into 2 blocks + assert_eq!(costs.of(bb(0)), Some(cost!(0))); + assert_eq!(costs.of(bb(1)), Some(cost!(4))); + // Region 1: no split + assert_eq!(costs.of(bb(2)), None); + // Region 2: split into 3 blocks + assert_eq!(costs.of(bb(3)), Some(cost!(0))); + assert_eq!(costs.of(bb(4)), Some(cost!(0))); + assert_eq!(costs.of(bb(5)), Some(cost!(8))); +} + +/// All blocks split: every non-last block in each region gets zero cost. +#[test] +fn remap_all_split() { + let mut costs = TerminatorCostVec::from_costs(&[Some(cost!(10)), Some(cost!(20))], Global); + + let regions = [(nz!(3), true), (nz!(2), true)]; + costs.remap(BasicBlockSlice::from_raw(®ions)); + + assert_eq!(costs.of(bb(0)), Some(cost!(0))); + assert_eq!(costs.of(bb(1)), Some(cost!(0))); + assert_eq!(costs.of(bb(2)), Some(cost!(10))); + assert_eq!(costs.of(bb(3)), Some(cost!(0))); + assert_eq!(costs.of(bb(4)), Some(cost!(20))); +} + /// `StatementCostVec` uses 1-based `Location` indexing to address the underlying /// 0-based `BlockPartitionedVec`. #[test] diff --git a/libs/@local/hashql/mir/src/pass/execution/mod.rs b/libs/@local/hashql/mir/src/pass/execution/mod.rs index ff0a03efa7a..c1bd420e07b 100644 --- a/libs/@local/hashql/mir/src/pass/execution/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/mod.rs @@ -78,22 +78,27 @@ impl<'heap, S: BumpAllocator> ExecutionAnalysis<'_, 'heap, S> { }; let mut statement_costs: TargetArray<_> = TargetArray::from_fn(|_| None); + let mut terminator_costs: TargetArray<_> = TargetArray::from_fn(|_| None); for target in TargetId::all() { let mut statement = TargetPlacementStatement::new_in(target, &self.scratch); - let statement_cost = + let (statement_cost, terminator_cost) = statement.statement_placement_in(context, body, vertex, &self.scratch); statement_costs[target] = Some(statement_cost); + terminator_costs[target] = Some(terminator_cost); } let mut statement_costs = statement_costs.map(|cost| cost.unwrap_or_else(|| unreachable!())); + let mut terminator_costs = + terminator_costs.map(|cost| cost.unwrap_or_else(|| unreachable!())); let mut assignments = BasicBlockSplitting::new_in(&self.scratch).split_in( context, body, &mut statement_costs, + &mut terminator_costs, &self.scratch, ); @@ -101,7 +106,7 @@ impl<'heap, S: BumpAllocator> ExecutionAnalysis<'_, 'heap, S> { TransferCostConfig::new(InformationRange::full()), &self.scratch, ); - let mut terminator_costs = terminators.terminator_placement_in( + let mut transition_costs = terminators.terminator_placement_in( body, vertex, &self.footprints[body.id], @@ -111,14 +116,15 @@ impl<'heap, S: BumpAllocator> ExecutionAnalysis<'_, 'heap, S> { ArcConsistency { blocks: &mut assignments, - terminators: &mut terminator_costs, + terminators: &mut transition_costs, } .run_in(body, &self.scratch); let block_costs = BasicBlockCostAnalysis { vertex, assignments: &assignments, - costs: &statement_costs, + statement_costs: &statement_costs, + terminator_costs: &terminator_costs, } .analyze_in( &TransferCostConfig::new(InformationRange::full()), @@ -128,7 +134,7 @@ impl<'heap, S: BumpAllocator> ExecutionAnalysis<'_, 'heap, S> { let mut solver = PlacementSolverContext { blocks: &block_costs, - terminators: &terminator_costs, + terminators: &transition_costs, } .build_in(body, &self.scratch); diff --git a/libs/@local/hashql/mir/src/pass/execution/placement/arc/mod.rs b/libs/@local/hashql/mir/src/pass/execution/placement/arc/mod.rs index 240552db65d..08879dfa65e 100644 --- a/libs/@local/hashql/mir/src/pass/execution/placement/arc/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/placement/arc/mod.rs @@ -69,7 +69,7 @@ use crate::{ Body, basic_block::{BasicBlockId, BasicBlockSlice}, }, - pass::execution::{target::TargetBitSet, terminator_placement::TerminatorCostVec}, + pass::execution::{target::TargetBitSet, terminator_placement::TerminatorTransitionCostVec}, }; /// Deduplicated worklist of directed arcs `(x, y)`. @@ -111,13 +111,13 @@ impl PairWorkQueue { /// AC-3 arc consistency enforcer over per-block target domains. /// -/// Operates on mutable [`TargetBitSet`] domains and [`TerminatorCostVec`] transition matrices. -/// After [`run_in`](Self::run_in), every surviving target in a block's domain has at least one -/// compatible transition partner across each incident CFG edge, and the matrices are pruned to -/// match the narrowed domains. +/// Operates on mutable [`TargetBitSet`] domains and [`TerminatorTransitionCostVec`] transition +/// matrices. After [`run_in`](Self::run_in), every surviving target in a block's domain has at +/// least one compatible transition partner across each incident CFG edge, and the matrices are +/// pruned to match the narrowed domains. pub(crate) struct ArcConsistency<'ctx, A: Allocator> { pub blocks: &'ctx mut BasicBlockSlice, - pub terminators: &'ctx mut TerminatorCostVec, + pub terminators: &'ctx mut TerminatorTransitionCostVec, } impl ArcConsistency<'_, A> { diff --git a/libs/@local/hashql/mir/src/pass/execution/placement/arc/tests.rs b/libs/@local/hashql/mir/src/pass/execution/placement/arc/tests.rs index 1a9765b3ee0..e540acf2c80 100644 --- a/libs/@local/hashql/mir/src/pass/execution/placement/arc/tests.rs +++ b/libs/@local/hashql/mir/src/pass/execution/placement/arc/tests.rs @@ -15,7 +15,7 @@ use crate::{ intern::Interner, pass::execution::{ target::{TargetBitSet, TargetId}, - terminator_placement::{TerminatorCostVec, TransMatrix}, + terminator_placement::{TerminatorTransitionCostVec, TransMatrix}, }, }; @@ -42,7 +42,7 @@ fn bb(index: u32) -> BasicBlockId { fn run_ac3<'heap>( body: &Body<'heap>, domains: &mut [TargetBitSet], - terminators: &mut TerminatorCostVec<&'heap Heap>, + terminators: &mut TerminatorTransitionCostVec<&'heap Heap>, ) { let mut arc = ArcConsistency { blocks: BasicBlockSlice::from_raw_mut(domains), @@ -81,7 +81,7 @@ fn already_consistent_no_pruning() { }); let mut domains = [all_targets(), all_targets()]; - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators.of_mut(bb(0))[0] = full_matrix(); let before = domains; @@ -113,7 +113,7 @@ fn source_side_pruning() { let mut matrix = TransMatrix::new(); matrix.insert(TargetId::Interpreter, TargetId::Interpreter, cost!(0)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators.of_mut(bb(0))[0] = matrix; run_ac3(&body, &mut domains, &mut terminators); @@ -149,7 +149,7 @@ fn target_side_pruning() { matrix.insert(TargetId::Interpreter, TargetId::Interpreter, cost!(0)); matrix.insert(TargetId::Postgres, TargetId::Postgres, cost!(0)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators.of_mut(bb(0))[0] = matrix; run_ac3(&body, &mut domains, &mut terminators); @@ -190,7 +190,7 @@ fn mutual_pruning_both_sides() { let mut matrix = TransMatrix::new(); matrix.insert(TargetId::Interpreter, TargetId::Interpreter, cost!(0)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators.of_mut(bb(0))[0] = matrix; run_ac3(&body, &mut domains, &mut terminators); @@ -230,7 +230,7 @@ fn diamond_cfg_pruning() { let mut m_interp = TransMatrix::new(); m_interp.insert(TargetId::Interpreter, TargetId::Interpreter, cost!(0)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); let matrices = terminators.of_mut(bb(0)); matrices[0] = m_interp; matrices[1] = m_interp; @@ -264,7 +264,7 @@ fn self_loop_pruning() { matrix.insert(TargetId::Interpreter, TargetId::Interpreter, cost!(0)); matrix.insert(TargetId::Postgres, TargetId::Postgres, cost!(0)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators.of_mut(bb(0))[0] = matrix; run_ac3(&body, &mut domains, &mut terminators); @@ -303,7 +303,7 @@ fn bidirectional_edges_require_joint_support() { let mut reverse = TransMatrix::new(); reverse.insert(TargetId::Postgres, TargetId::Postgres, cost!(0)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators.of_mut(bb(0))[0] = forward; terminators.of_mut(bb(1))[0] = reverse; @@ -338,7 +338,7 @@ fn matrix_pruned_after_ac3() { matrix.insert(TargetId::Postgres, TargetId::Interpreter, cost!(10)); matrix.insert(TargetId::Embedding, TargetId::Interpreter, cost!(20)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators.of_mut(bb(0))[0] = matrix; run_ac3(&body, &mut domains, &mut terminators); @@ -368,7 +368,7 @@ fn single_block_no_edges() { }); let mut domains = [all_targets()]; - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); run_ac3(&body, &mut domains, &mut terminators); @@ -403,7 +403,7 @@ fn switchint_multiple_edges_to_same_block() { let mut m1 = TransMatrix::new(); m1.insert(TargetId::Interpreter, TargetId::Interpreter, cost!(0)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); let matrices = terminators.of_mut(bb(0)); matrices[0] = m0; matrices[1] = m1; diff --git a/libs/@local/hashql/mir/src/pass/execution/placement/solve/condensation.rs b/libs/@local/hashql/mir/src/pass/execution/placement/solve/condensation.rs index f1e12b5c3c0..a0e32f762cb 100644 --- a/libs/@local/hashql/mir/src/pass/execution/placement/solve/condensation.rs +++ b/libs/@local/hashql/mir/src/pass/execution/placement/solve/condensation.rs @@ -32,7 +32,7 @@ use super::{ }; use crate::{ body::{Body, basic_block::BasicBlockId}, - pass::execution::terminator_placement::{TerminatorCostVec, TransMatrix}, + pass::execution::terminator_placement::{TerminatorTransitionCostVec, TransMatrix}, }; /// A placement region containing a single basic block. @@ -105,7 +105,7 @@ impl<'alloc, S: BumpAllocator> Condensation<'alloc, S> { /// Builds the condensation from a [`Body`]'s CFG and terminator transition costs. pub(crate) fn new( body: &Body<'_>, - terminators: &TerminatorCostVec, + terminators: &TerminatorTransitionCostVec, alloc: &'alloc S, ) -> Self { let scc = Tarjan::new_in(&body.basic_blocks, alloc).run(); @@ -153,7 +153,7 @@ impl<'alloc, S: BumpAllocator> Condensation<'alloc, S> { } /// Populates the condensation graph with regions and boundary edges. - fn fill(&mut self, body: &Body<'_>, terminators: &TerminatorCostVec) { + fn fill(&mut self, body: &Body<'_>, terminators: &TerminatorTransitionCostVec) { for scc in self.scc.iter_nodes() { let members = self.scc_members.of(scc); diff --git a/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/tests.rs b/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/tests.rs index 57b38f9ae90..e993a2a9474 100644 --- a/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/tests.rs +++ b/libs/@local/hashql/mir/src/pass/execution/placement/solve/csp/tests.rs @@ -21,7 +21,7 @@ use crate::{ }, }, target::{TargetArray, TargetId}, - terminator_placement::{TerminatorCostVec, TransMatrix}, + terminator_placement::{TerminatorTransitionCostVec, TransMatrix}, }, }; @@ -70,7 +70,7 @@ fn narrow_restricts_successor_domain() { let domains = [all_targets(), all_targets(), all_targets(), all_targets()]; let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [I->I = 0, I->P = 0]; bb(1): [complete(1)]; @@ -120,7 +120,7 @@ fn narrow_restricts_predecessor_domain() { let domains = [all_targets(), all_targets(), all_targets(), all_targets()]; let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [complete(1)]; bb(1): [complete(1)]; @@ -171,7 +171,7 @@ fn narrow_to_empty_domain() { let domains = [target_set(&[I]), target_set(&[P]), all_targets()]; let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ complete(1); @@ -220,7 +220,7 @@ fn narrow_multiple_edges_intersect() { let domains = [all_targets(), all_targets(), all_targets(), all_targets()]; let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ I->I = 0, I->P = 0; @@ -282,7 +282,7 @@ fn replay_narrowing_resets_then_repropagates() { let domains = [all_targets(), all_targets(), all_targets(), all_targets()]; let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [I->I = 0, I->P = 0, P->E = 0]; bb(1): [complete(1)]; @@ -360,7 +360,7 @@ fn lower_bound_min_block_cost_per_block() { bb(2): I = 5, P = 15 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [diagonal(0)]; bb(1): [diagonal(0)]; @@ -415,7 +415,7 @@ fn lower_bound_min_transition_cost_per_edge() { let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [diagonal(0)]; bb(1): [I->P = 10, P->I = 3]; @@ -467,7 +467,7 @@ fn lower_bound_skips_self_loop_edges() { let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ diagonal(0); @@ -522,7 +522,7 @@ fn lower_bound_fixed_successor_uses_concrete_target() { let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [diagonal(0)]; bb(1): [I->P = 10, I->I = 0]; @@ -579,7 +579,7 @@ fn lower_bound_all_fixed_returns_zero() { bb(1): I = 5 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ I->I = 3; @@ -633,7 +633,7 @@ fn mrv_selects_smallest_domain() { ]; let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [complete(1)]; bb(1): [complete(1)]; @@ -682,7 +682,7 @@ fn mrv_tiebreak_by_constraint_degree() { let domains = [ip, ip, ip, all_targets()]; let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ complete(1); @@ -736,7 +736,7 @@ fn mrv_skips_fixed_blocks() { ]; let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [complete(1)]; bb(1): [complete(1)]; @@ -794,7 +794,7 @@ fn greedy_solves_two_block_loop() { bb(1): I = 8, P = 3 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [diagonal(0), I->P = 5, P->I = 5]; bb(1): [ @@ -847,7 +847,7 @@ fn greedy_rollback_finds_alternative() { let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [complete(1)]; bb(1): [I->P = 0]; @@ -905,7 +905,7 @@ fn greedy_fails_when_infeasible() { let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ I->I = 0; @@ -960,7 +960,7 @@ fn bnb_finds_optimal() { bb(2): I = 1, P = 50 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ diagonal(0), I->P = 20, P->I = 20; @@ -1017,7 +1017,7 @@ fn bnb_retains_ranked_solutions() { bb(1): I = 5, P = 10, E = 15 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ diagonal(0); @@ -1091,7 +1091,7 @@ fn bnb_pruning_preserves_optimal() { bb(3): I = 1, P = 1 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [diagonal(0), I->P = 100, P->I = 100]; bb(1): [diagonal(0), I->P = 100, P->I = 100]; @@ -1149,7 +1149,7 @@ fn retry_returns_ranked_solutions_in_order() { bb(1): I = 1, P = 2 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ diagonal(0), I->P = 5, P->I = 5; @@ -1219,7 +1219,7 @@ fn retry_exhausts_then_perturbs() { bb(1): I = 1, P = 2 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ diagonal(0); @@ -1283,7 +1283,7 @@ fn greedy_rollback_on_empty_heap() { bb(0): I = 0, P = 5 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); // arm0 (bb0→bb2): complete (exit edge, always feasible) // arm1 (bb0→bb1): swap-only transitions (I→P, P→I) // bb1→bb0: from I go to P or I @@ -1356,7 +1356,7 @@ fn retry_perturbation_after_ranked_exhaustion() { } // All transitions allowed → all combinations feasible - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ complete(0); diff --git a/libs/@local/hashql/mir/src/pass/execution/placement/solve/estimate/tests.rs b/libs/@local/hashql/mir/src/pass/execution/placement/solve/estimate/tests.rs index f773a30a004..4713a4945ba 100644 --- a/libs/@local/hashql/mir/src/pass/execution/placement/solve/estimate/tests.rs +++ b/libs/@local/hashql/mir/src/pass/execution/placement/solve/estimate/tests.rs @@ -16,7 +16,7 @@ use crate::{ pass::execution::{ cost::StatementCostVec, target::{TargetArray, TargetId}, - terminator_placement::{TerminatorCostVec, TransMatrix}, + terminator_placement::{TerminatorTransitionCostVec, TransMatrix}, }, }; @@ -155,7 +155,7 @@ fn self_loop_edges_excluded_from_cost() { stmt_costs! { statements; bb(0): I = 5, P = 5 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ I->I = 0, P->I = 0; @@ -221,7 +221,7 @@ fn boundary_multiplier_applied_to_cross_region_edges() { let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [diagonal(0), I->P = 20, P->I = 0]; bb(1): [diagonal(0), I->P = 0, P->I = 20] @@ -294,7 +294,7 @@ fn infeasible_transition_returns_none() { let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [I->I = 0] } @@ -359,7 +359,7 @@ fn unassigned_neighbor_uses_heuristic_minimum() { bb(1): I = 3, P = 7 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [diagonal(0), I->P = 10, P->I = 5] } diff --git a/libs/@local/hashql/mir/src/pass/execution/placement/solve/mod.rs b/libs/@local/hashql/mir/src/pass/execution/placement/solve/mod.rs index e0b685fd39e..690b2033352 100644 --- a/libs/@local/hashql/mir/src/pass/execution/placement/solve/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/placement/solve/mod.rs @@ -29,7 +29,7 @@ use crate::{ context::MirContext, pass::execution::{ ApproxCost, cost::BasicBlockCostVec, target::TargetId, - terminator_placement::TerminatorCostVec, + terminator_placement::TerminatorTransitionCostVec, }, }; @@ -82,7 +82,7 @@ fn back_edge_span(body: &Body<'_>, members: &[BasicBlockId]) -> SpanId { #[derive(Debug, Copy, Clone)] pub(crate) struct PlacementSolverContext<'ctx, A: Allocator> { pub blocks: &'ctx BasicBlockCostVec, - pub terminators: &'ctx TerminatorCostVec, + pub terminators: &'ctx TerminatorTransitionCostVec, } impl<'ctx, A: Allocator> PlacementSolverContext<'ctx, A> { diff --git a/libs/@local/hashql/mir/src/pass/execution/placement/solve/tests.rs b/libs/@local/hashql/mir/src/pass/execution/placement/solve/tests.rs index a4520cd2f82..ea875739fd1 100644 --- a/libs/@local/hashql/mir/src/pass/execution/placement/solve/tests.rs +++ b/libs/@local/hashql/mir/src/pass/execution/placement/solve/tests.rs @@ -29,10 +29,12 @@ use crate::{ analysis::size_estimation::{InformationRange, InformationUnit}, execution::{ ApproxCost, Cost, VertexType, - cost::{BasicBlockCostAnalysis, BasicBlockCostVec, StatementCostVec}, + cost::{ + BasicBlockCostAnalysis, BasicBlockCostVec, StatementCostVec, TerminatorCostVec, + }, placement::error::PlacementDiagnosticCategory, target::{TargetArray, TargetBitSet, TargetId}, - terminator_placement::{TerminatorCostVec, TransMatrix}, + terminator_placement::{TerminatorTransitionCostVec, TransMatrix}, traversal::TransferCostConfig, }, }, @@ -142,10 +144,14 @@ pub(crate) fn make_block_costs_with_config<'heap>( alloc: &'heap Heap, ) -> BasicBlockCostVec<&'heap Heap> { let assignments = BasicBlockSlice::from_raw(domains); + let terminator_costs: TargetArray> = TargetArray::from_fn(|_| { + TerminatorCostVec::from_costs(&vec![Some(cost!(0)); body.basic_blocks.len()], alloc) + }); BasicBlockCostAnalysis { vertex: VertexType::Entity, assignments, - costs: statements, + statement_costs: statements, + terminator_costs: &terminator_costs, } .analyze_in(config, &body.basic_blocks, alloc) } @@ -160,7 +166,7 @@ pub(crate) fn run_solver<'heap>( interner: &Interner<'heap>, domains: &[TargetBitSet], statements: &TargetArray>, - terminators: &TerminatorCostVec<&'heap Heap>, + terminators: &TerminatorTransitionCostVec<&'heap Heap>, ) -> BasicBlockVec { let mut context = MirContext::new(env, interner); let block_costs = make_block_costs(body, domains, statements, env.heap); @@ -246,7 +252,7 @@ fn forward_pass_assigns_all_blocks() { let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ complete(1); @@ -322,7 +328,7 @@ fn backward_pass_improves_suboptimal_forward() { // bb0: arm0=bb2(else), arm1=bb1(then). All transitions at cost 0. // bb1→bb3: P→P=0 (cheap), P→I=50 (expensive), I→I=0. // bb2→bb3: same-target only (diagonal). - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ complete(0); @@ -399,7 +405,7 @@ fn rewind_triggers_on_join_with_conflicting_predecessors() { // bb0: arm0=bb2(else), arm1=bb1(then). All transitions allowed. // bb1→bb3: same-target only. bb2→bb3: swap only. - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ complete(0); @@ -484,7 +490,7 @@ fn rewind_skips_exhausted_region() { bb(1): I = 0, P = 10 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [complete(0)]; bb(1): [ @@ -530,7 +536,7 @@ fn single_block_trivial_region() { bb(0): I = 10, P = 5 } - let terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); let result = run_solver(&body, &env, &interner, &domains, &statements, &terminators); @@ -583,7 +589,7 @@ fn cyclic_region_in_forward_backward() { bb(2): I = 3, P = 1 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [I->I = 0, I->P = 5]; bb(1): [diagonal(0), I->P = 5, P->I = 5]; @@ -668,7 +674,7 @@ fn rewind_retries_cyclic_region() { bb(2): I = 0, P = 1 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); // bb2→bb3 (arm0, else): diagonal — forces bb3 to match SCC target. // SCC solver sees this as feasible for both I and P (each has a matching // target in bb3's domain {I,P}). @@ -766,7 +772,7 @@ fn rewind_skips_exhausted_cyclic_region() { bb(0): I = 0, P = 5 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); // bb0→bb3 (arm0, else): swap only — I→P, P→I. // bb0→bb1 (arm1, then): complete — permissive SCC entry. // SCC internal bb1→bb2: diagonal. bb2→bb1 (arm1, then): diagonal. @@ -833,7 +839,7 @@ fn rewind_exhausts_all_regions() { let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [diagonal(0); diagonal(0)]; bb(1): [diagonal(0)]; @@ -904,7 +910,7 @@ fn forward_pass_rewinds_on_cyclic_failure() { bb(0): I = 0, P = 5 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); // bb0→bb1 diagonal forces bb1==bb0. SCC internals only allow P. terminators! { terminators; bb(0): [diagonal(0)]; @@ -979,7 +985,7 @@ fn backward_pass_keeps_assignment_when_csp_fails() { bb(2): I = 0, P = 10 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); // SCC internal diagonal. Exit bb2→bb3(arm0) only to I. terminators! { terminators; bb(0): [complete(0)]; @@ -1086,7 +1092,7 @@ fn backward_pass_adopts_better_cyclic_solution() { bb(2): I = 10, P = 0 } - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [complete(0)]; bb(1): [diagonal(0)]; @@ -1147,7 +1153,7 @@ fn trivial_failure_emits_diagnostic() { let statements: TargetArray> = IdArray::from_fn(|_: TargetId| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [diagonal(0); diagonal(0)]; bb(1): [diagonal(0)]; @@ -1214,7 +1220,7 @@ fn cyclic_failure_emits_diagnostic() { // bb1→bb0 (arm0, goto): only I→P — forces bb1=I, bb0=P. // Contradiction: bb1 must be both P and I. AC-3 wipes the domain. // bb0→bb2 (arm0, else): permissive, irrelevant to the SCC. - let mut terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let mut terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); terminators! { terminators; bb(0): [ complete(0); @@ -1279,7 +1285,7 @@ fn path_premiums_influence_placement() { // Equal base costs so the path premium is the deciding factor. stmt_costs! { statements; bb(0): I = 1, P = 1, E = 1 } - let terminators = TerminatorCostVec::new(&body.basic_blocks, &heap); + let terminators = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); let config = TransferCostConfig::new(InformationRange::value(InformationUnit::new(100))); let block_costs = make_block_costs_with_config(&body, &domains, &statements, &config, &heap); diff --git a/libs/@local/hashql/mir/src/pass/execution/splitting/mod.rs b/libs/@local/hashql/mir/src/pass/execution/splitting/mod.rs index 7dacfa90feb..c3fa52702a9 100644 --- a/libs/@local/hashql/mir/src/pass/execution/splitting/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/splitting/mod.rs @@ -18,7 +18,7 @@ use hashql_core::{ use super::{ Cost, - cost::StatementCostVec, + cost::{StatementCostVec, TerminatorCostVec}, target::{TargetArray, TargetBitSet, TargetId}, }; use crate::{ @@ -38,9 +38,8 @@ mod tests; /// Returns a [`TargetBitSet`] of execution targets that can cover the statement at `index`. /// /// A target is supported when its [`Cost`] entry is present for that statement. -#[expect(clippy::cast_possible_truncation)] -fn supported(costs: &TargetArray<&[Option]>, index: usize) -> TargetBitSet { - let mut output = FiniteBitSet::new_empty(TargetId::VARIANT_COUNT as u32); +fn supported_statement(costs: &TargetArray<&[Option]>, index: usize) -> TargetBitSet { + let mut output = FiniteBitSet::new_empty(TargetId::VARIANT_COUNT_U32); for (cost_index, cost) in costs.iter_enumerated() { output.set(cost_index, cost[index].is_some()); @@ -49,20 +48,36 @@ fn supported(costs: &TargetArray<&[Option]>, index: usize) -> TargetBitSet output } +fn supported_terminator( + costs: &TargetArray>, + block: BasicBlockId, +) -> TargetBitSet { + let mut output = FiniteBitSet::new_empty(TargetId::VARIANT_COUNT_U32); + + for (cost_index, cost) in costs.iter_enumerated() { + output.set(cost_index, cost.of(block).is_some()); + } + + output +} + /// Counts contiguous target regions per [`BasicBlock`]. /// -/// Returns a non-zero count for each block. Blocks with fewer than two statements -/// always yield one region. -#[expect(unsafe_code, clippy::cast_possible_truncation)] -fn count_regions( +/// Returns a `(region_count, has_separate_terminator_region)` pair for each block. +/// An extra region is added when the terminator's target support is not a superset +/// of the last statement region's support (including incomparable sets, not just +/// strict subsets). +#[expect(unsafe_code)] +fn count_regions( body: &Body<'_>, statement_costs: &TargetArray>, + terminator_costs: &TargetArray>, alloc: B, -) -> BasicBlockVec, B> { +) -> BasicBlockVec<(NonZero, bool), B> { // Start with one region per block and only grow when target support changes. let mut regions = BasicBlockVec::from_elem_in( // SAFETY: 1 is not 0 - unsafe { NonZero::new_unchecked(1) }, + (unsafe { NonZero::new_unchecked(1) }, false), body.basic_blocks.len(), alloc, ); @@ -70,16 +85,16 @@ fn count_regions( for (id, block) in body.basic_blocks.iter_enumerated() { let costs = statement_costs.each_ref().map(|costs| costs.of(id)); - if block.statements.len() < 2 { - // Zero or one statement cannot introduce a target boundary. + if block.statements.is_empty() { + // Zero statements cannot introduce a target boundary. continue; } let mut total = 0; - let mut current: TargetBitSet = FiniteBitSet::new_empty(TargetId::VARIANT_COUNT as u32); + let mut current: TargetBitSet = FiniteBitSet::new_empty(TargetId::VARIANT_COUNT_U32); for stmt_index in 0..block.statements.len() { - let next = supported(&costs, stmt_index); + let next = supported_statement(&costs, stmt_index); // Always count the first statement as a region start. This keeps the count non-zero // even if cost data is missing or malformed. @@ -89,9 +104,23 @@ fn count_regions( } } + let mut has_separate_terminator_region = false; + + // Check if the terminator narrows the target set of the last statement region. + // If the terminator supports a strict subset of backends, it needs its own region + // so that the preceding statements can still be assigned to the wider set. + let terminator_supported = supported_terminator(terminator_costs, id); + if !terminator_supported.is_superset(¤t) { + total += 1; + has_separate_terminator_region = true; + } + // SAFETY: The loop always counts the first statement for blocks with 2+ statements, so // total cannot be zero here. - regions[id] = unsafe { NonZero::new_unchecked(total) }; + regions[id] = ( + unsafe { NonZero::new_unchecked(total) }, + has_separate_terminator_region, + ); } regions @@ -124,12 +153,13 @@ impl<'heap> VisitorMut<'heap> for RemapBasicBlockId<'_> { /// /// Remaps all [`BasicBlockId`] references, connects split blocks with [`Goto`] chains, /// and updates [`StatementCostVec`] to reflect the new layout. -#[expect(clippy::cast_possible_truncation)] +#[expect(clippy::too_many_lines)] fn offset_basic_blocks<'heap, A: Allocator, S: Allocator + Clone>( context: &MirContext<'_, 'heap>, body: &mut Body<'heap>, - regions: &BasicBlockSlice>, + regions: &BasicBlockSlice<(NonZero, bool)>, statement_costs: &mut TargetArray>, + terminator_costs: &mut TargetArray>, scratch: S, alloc: A, ) -> BasicBlockVec { @@ -143,13 +173,13 @@ fn offset_basic_blocks<'heap, A: Allocator, S: Allocator + Clone>( let mut indices = BasicBlockVec::from_elem_in(BasicBlockId::MIN, body.basic_blocks.len(), scratch); - for (id, regions) in regions.iter_enumerated() { + for (id, (regions, _)) in regions.iter_enumerated() { indices[id] = length; length.increment_by(regions.get()); } let mut targets = BasicBlockVec::from_elem_in( - FiniteBitSet::new_empty(TargetId::VARIANT_COUNT as u32), + FiniteBitSet::new_empty(TargetId::VARIANT_COUNT_U32), length.as_usize(), alloc, ); @@ -176,19 +206,25 @@ fn offset_basic_blocks<'heap, A: Allocator, S: Allocator + Clone>( let mut index = BasicBlockId::START; for &[start_id, end_id] in indices.windows() { let region = &mut body.basic_blocks.as_mut()[start_id..end_id]; - debug_assert_eq!(region.len(), regions[index].get()); + let (region_len, has_separate_terminator_region) = regions[index]; + + debug_assert_eq!(region.len(), region_len.get()); let costs = statement_costs.each_ref().map(|cost| cost.of(index)); if region.len() < 2 { debug_assert_eq!(region.len(), 1); - // Unlike other regions, these may be empty. Mark empty blocks as supported everywhere. if costs[TargetId::Interpreter].is_empty() { - targets[start_id] - .insert_range(TargetId::MIN..=TargetId::MAX, TargetId::VARIANT_COUNT); + // No statements: the block's target affinity comes from its terminator. + targets[start_id] = supported_terminator(terminator_costs, index); } else { - targets[start_id] = supported(&costs, 0); + // `count_regions` only produces a single region (no split) when all statements + // share uniform target support AND the terminator's support is a superset of + // that. The terminator can run everywhere the statements can, so the statement + // support is the binding constraint. Index 0 is representative of all statements + // because uniformity is what made this a single region. + targets[start_id] = supported_statement(&costs, 0); } index.increment_by(1); @@ -226,14 +262,32 @@ fn offset_basic_blocks<'heap, A: Allocator, S: Allocator + Clone>( }; let mut rest = rest; + let mut runs = 0; + + // If the terminator narrows the target set, peel off the last block for it. + // That block is already empty (placeholder) and already holds the original terminator + // (from the `mem::swap` above). We just need to record its target affinity and exclude + // it from the statement-peeling loop. + if has_separate_terminator_region { + let [statements @ .., _] = rest else { + unreachable!() + }; + + rest = statements; + + // Write the target before incrementing `runs`, matching the convention in the + // statement-peeling loop below. `terminator_costs` is indexed by original (pre-split) + // block IDs, so we use `index` rather than a post-split ID. + targets[end_id.minus(runs + 1)] = supported_terminator(terminator_costs, index); + runs += 1; + } + // Peel off runs and move them into recipient blocks counted from the end. - let mut current = supported(&costs, start.statements.len() - 1); + let mut current = supported_statement(&costs, start.statements.len() - 1); let mut ptr = start.statements.len() - 1; - let mut runs = 0; - while let [remaining @ .., recipient] = rest { - while supported(&costs, ptr) == current { + while supported_statement(&costs, ptr) == current { ptr -= 1; } @@ -247,13 +301,13 @@ fn offset_basic_blocks<'heap, A: Allocator, S: Allocator + Clone>( "Each run contains at least one statement" ); - current = supported(&costs, ptr); + current = supported_statement(&costs, ptr); recipient.statements = statements; rest = remaining; runs += 1; } - debug_assert_eq!(runs, regions[index].get() - 1); + debug_assert_eq!(runs, region_len.get() - 1); // The first block holds the remaining run. targets[start_id] = current; @@ -265,6 +319,10 @@ fn offset_basic_blocks<'heap, A: Allocator, S: Allocator + Clone>( cost.remap(&body.basic_blocks); } + for cost in terminator_costs.iter_mut() { + cost.remap(regions); + } + targets } @@ -299,36 +357,43 @@ impl BasicBlockSplitting { context: &MirContext<'_, 'heap>, body: &mut Body<'heap>, statement_costs: &mut TargetArray>, + terminator_costs: &mut TargetArray>, ) -> BasicBlockVec where S: Clone, { - self.split_in(context, body, statement_costs, Global) + self.split_in(context, body, statement_costs, terminator_costs, Global) } - /// Splits [`Body`] blocks and returns per-block [`TargetBitSet`] affinities along with - /// the per-block region counts used during splitting. + /// Splits [`Body`] blocks and returns per-block [`TargetBitSet`] affinities. /// - /// The first element is indexed by the new [`BasicBlockId`]s. The second element maps - /// each original block to the number of blocks it was split into, which callers can use - /// to redistribute parallel data structures. + /// Partitions blocks so each resulting block's statements share the same target support, + /// with an additional split when the terminator narrows the target set. Updates both + /// `statement_costs` and `terminator_costs` to reflect the new block layout. pub(crate) fn split_in<'heap, A: Allocator>( &self, context: &MirContext<'_, 'heap>, body: &mut Body<'heap>, statement_costs: &mut TargetArray>, + terminator_costs: &mut TargetArray>, alloc: A, ) -> BasicBlockVec where S: Clone, { - let regions = count_regions(body, statement_costs, self.scratch.clone()); + let regions = count_regions( + body, + statement_costs, + terminator_costs, + self.scratch.clone(), + ); offset_basic_blocks( context, body, ®ions, statement_costs, + terminator_costs, self.scratch.clone(), alloc, ) diff --git a/libs/@local/hashql/mir/src/pass/execution/splitting/tests.rs b/libs/@local/hashql/mir/src/pass/execution/splitting/tests.rs index ec24c5fcfc9..4d033486809 100644 --- a/libs/@local/hashql/mir/src/pass/execution/splitting/tests.rs +++ b/libs/@local/hashql/mir/src/pass/execution/splitting/tests.rs @@ -17,7 +17,7 @@ use hashql_core::{ use hashql_diagnostics::DiagnosticIssues; use insta::{Settings, assert_snapshot}; -use super::{BasicBlockSplitting, count_regions, offset_basic_blocks, supported}; +use super::{BasicBlockSplitting, count_regions, offset_basic_blocks, supported_statement}; use crate::{ body::{ Body, @@ -30,7 +30,7 @@ use crate::{ context::MirContext, intern::Interner, pass::execution::{ - cost::{Cost, StatementCostVec}, + cost::{Cost, StatementCostVec, TerminatorCostVec}, target::{TargetArray, TargetBitSet, TargetId}, }, pretty::{TextFormatAnnotations, TextFormatOptions}, @@ -92,6 +92,42 @@ fn make_target_costs<'heap, const N: usize>( costs } +/// Creates terminator costs where every target supports every block's terminator. +fn make_all_supported_terminator_costs<'heap>( + body: &Body<'heap>, + heap: &'heap Heap, +) -> TargetArray> { + TargetArray::from_fn(|_| { + let mut costs = TerminatorCostVec::new_in(&body.basic_blocks, heap); + for (id, _) in body.basic_blocks.iter_enumerated() { + costs.insert(id, cost!(1)); + } + costs + }) +} + +/// Creates terminator costs with per-target, per-block support patterns. +/// +/// `patterns[target][block]` is `true` if the target supports the terminator of that block. +fn make_terminator_costs<'heap, const N: usize>( + body: &Body<'heap>, + patterns: TargetArray<[bool; N]>, + heap: &'heap Heap, +) -> TargetArray> { + let mut costs = TargetArray::from_fn(|_| TerminatorCostVec::new_in(&body.basic_blocks, heap)); + + for (target_id, block_patterns) in patterns.iter_enumerated() { + for (block_index, &supported) in block_patterns.iter().enumerate() { + if supported { + let block = BasicBlockId::new(block_index as u32); + costs[target_id].insert(block, cost!(1)); + } + } + } + + costs +} + fn assert_assignment_locals(body: &Body<'_>, block_id: BasicBlockId, expected: &[&str]) { let block = &body.basic_blocks[block_id]; assert_eq!(block.statements.len(), expected.len()); @@ -135,7 +171,7 @@ fn supported_all_targets() { let costs: TargetArray<&[Option]> = TargetArray::from_raw([&[Some(cost!(1))], &[Some(cost!(2))], &[Some(cost!(3))]]); - let result = supported(&costs, 0); + let result = supported_statement(&costs, 0); let expected = Targets { interpreter: true, postgres: true, @@ -150,7 +186,7 @@ fn supported_all_targets() { fn supported_no_targets() { let costs: TargetArray<&[Option]> = TargetArray::from_raw([&[None], &[None], &[None]]); - let result = supported(&costs, 0); + let result = supported_statement(&costs, 0); assert!(result.is_empty()); } @@ -159,7 +195,7 @@ fn supported_no_targets() { fn supported_single_target() { let costs: TargetArray<&[_]> = TargetArray::from_raw([&[Some(cost!(1))], &[None], &[None]]); - let result = supported(&costs, 0); + let result = supported_statement(&costs, 0); assert!(result.contains(TargetId::Interpreter)); assert!(!result.contains(TargetId::Postgres)); @@ -171,7 +207,7 @@ fn supported_mixed_targets() { let costs: TargetArray<&[_]> = TargetArray::from_raw([&[Some(cost!(1))], &[Some(cost!(2))], &[None]]); - let result = supported(&costs, 0); + let result = supported_statement(&costs, 0); let expected = Targets { interpreter: true, postgres: true, @@ -201,9 +237,10 @@ fn count_regions_empty_block() { }); let costs = TargetArray::from_fn(|_| StatementCostVec::new_in(&body.basic_blocks, &heap)); - let regions = count_regions(&body, &costs, Global); + let terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - assert_eq!(regions[BasicBlockId::new(0)].get(), 1); + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 1); } #[test] @@ -223,9 +260,10 @@ fn count_regions_single_statement() { let patterns = TargetArray::from_raw([[[true]], [[true]], [[false]]]); let costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - assert_eq!(regions[BasicBlockId::new(0)].get(), 1); + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 1); } #[test] @@ -251,9 +289,10 @@ fn count_regions_uniform_support() { [[false, false, false]], ]); let costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - assert_eq!(regions[BasicBlockId::new(0)].get(), 1); + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 1); } #[test] @@ -274,9 +313,10 @@ fn count_regions_two_regions() { let patterns = TargetArray::from_raw([[[true, true]], [[true, false]], [[false, false]]]); let costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - assert_eq!(regions[BasicBlockId::new(0)].get(), 2); + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 2); } #[test] @@ -302,9 +342,10 @@ fn count_regions_three_regions() { [[false, false, false]], ]); let costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - assert_eq!(regions[BasicBlockId::new(0)].get(), 3); + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 3); } #[test] @@ -331,9 +372,10 @@ fn count_regions_alternating() { [[false, false, false, false]], ]); let costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - assert_eq!(regions[BasicBlockId::new(0)].get(), 4); + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 4); } // ============================================================================= @@ -364,9 +406,18 @@ fn offset_single_block_no_split() { let patterns = TargetArray::from_raw([[[true]], [[true]], [[false]]]); let mut costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - let targets = offset_basic_blocks(&context, &mut body, ®ions, &mut costs, Global, Global); + let targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); assert_eq!(body.basic_blocks.len(), 1); assert_eq!(targets.len(), 1); @@ -408,9 +459,18 @@ fn offset_single_block_splits() { let patterns = TargetArray::from_raw([[[true, true]], [[true, false]], [[false, false]]]); let mut costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - let targets = offset_basic_blocks(&context, &mut body, ®ions, &mut costs, Global, Global); + let targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); assert_eq!(body.basic_blocks.len(), 2); assert_eq!(targets.len(), 2); @@ -464,9 +524,18 @@ fn offset_multiple_blocks_no_splits() { let patterns = TargetArray::from_raw([[[true], [true]], [[true], [true]], [[false], [false]]]); let mut costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - let targets = offset_basic_blocks(&context, &mut body, ®ions, &mut costs, Global, Global); + let targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); assert_eq!(body.basic_blocks.len(), 2); assert_eq!(targets.len(), 2); @@ -519,9 +588,18 @@ fn offset_multiple_blocks_mixed() { [&[false, false], &[false]], ]); let mut costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - let targets = offset_basic_blocks(&context, &mut body, ®ions, &mut costs, Global, Global); + let targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); assert_eq!(body.basic_blocks.len(), 3); assert_eq!(targets.len(), 3); @@ -580,9 +658,18 @@ fn offset_terminator_moves_to_last() { [[false, false]], ]); let mut costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - let _targets = offset_basic_blocks(&context, &mut body, ®ions, &mut costs, Global, Global); + let _targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); assert_return_terminator(&body, BasicBlockId::new(1)); } @@ -617,9 +704,18 @@ fn offset_goto_chain_created() { [[false, false, false]], ]); let mut costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - let _targets = offset_basic_blocks(&context, &mut body, ®ions, &mut costs, Global, Global); + let _targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); assert_eq!(body.basic_blocks.len(), 3); assert_goto_terminator(&body, BasicBlockId::new(0)); @@ -657,9 +753,18 @@ fn offset_goto_targets_correct() { [[false, false, false]], ]); let mut costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - let _targets = offset_basic_blocks(&context, &mut body, ®ions, &mut costs, Global, Global); + let _targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); assert_eq!(body.basic_blocks.len(), 3); assert_goto_target(&body, BasicBlockId::new(0), BasicBlockId::new(1)); @@ -696,9 +801,18 @@ fn offset_statements_split_correctly() { [[false, false]], ]); let mut costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - let _targets = offset_basic_blocks(&context, &mut body, ®ions, &mut costs, Global, Global); + let _targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); assert_assignment_locals(&body, BasicBlockId::new(0), &["x"]); assert_assignment_locals(&body, BasicBlockId::new(1), &["y"]); @@ -734,9 +848,18 @@ fn offset_statement_order_preserved() { [[false, false, false]], ]); let mut costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - let _targets = offset_basic_blocks(&context, &mut body, ®ions, &mut costs, Global, Global); + let _targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); assert_assignment_locals(&body, BasicBlockId::new(0), &["a"]); assert_assignment_locals(&body, BasicBlockId::new(1), &["b", "c"]); @@ -771,9 +894,18 @@ fn offset_targets_populated() { [[false, false]], ]); let mut costs = make_target_costs(&body, patterns, &heap); - let regions = count_regions(&body, &costs, Global); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); - let targets = offset_basic_blocks(&context, &mut body, ®ions, &mut costs, Global, Global); + let targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); let expected_first = Targets { interpreter: true, @@ -930,7 +1062,8 @@ fn split_no_changes_needed() { let mut costs = make_target_costs(&body, patterns, &heap); let splitting = BasicBlockSplitting::new(); - let targets = splitting.split(&context, &mut body, &mut costs); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let targets = splitting.split(&context, &mut body, &mut costs, &mut terminator_costs); assert_split("split_no_changes_needed", &context, &body, &costs, &targets); } @@ -966,7 +1099,8 @@ fn split_basic_two_regions() { let mut costs = make_target_costs(&body, patterns, &heap); let splitting = BasicBlockSplitting::new(); - let targets = splitting.split(&context, &mut body, &mut costs); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let targets = splitting.split(&context, &mut body, &mut costs, &mut terminator_costs); assert_split("split_basic_two_regions", &context, &body, &costs, &targets); } @@ -1007,7 +1141,8 @@ fn split_multi_block_complex() { let mut costs = make_target_costs(&body, patterns, &heap); let splitting = BasicBlockSplitting::new(); - let targets = splitting.split(&context, &mut body, &mut costs); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let targets = splitting.split(&context, &mut body, &mut costs, &mut terminator_costs); assert_split( "split_multi_block_complex", @@ -1050,7 +1185,8 @@ fn split_cost_remap() { let mut costs = make_target_costs(&body, patterns, &heap); let splitting = BasicBlockSplitting::new(); - let targets = splitting.split(&context, &mut body, &mut costs); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let targets = splitting.split(&context, &mut body, &mut costs, &mut terminator_costs); assert_split("split_cost_remap", &context, &body, &costs, &targets); } @@ -1093,7 +1229,8 @@ fn split_block_references_updated() { let mut costs = make_target_costs(&body, patterns, &heap); let splitting = BasicBlockSplitting::new(); - let targets = splitting.split(&context, &mut body, &mut costs); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let targets = splitting.split(&context, &mut body, &mut costs, &mut terminator_costs); assert_split( "split_block_references_updated", @@ -1103,3 +1240,486 @@ fn split_block_references_updated() { &targets, ); } + +/// One statement, terminator is superset of statement support: no extra region. +#[test] +fn count_regions_terminator_superset_no_split() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl x: Int; + + bb0() { + x = load 42; + return x; + } + }); + + // Statement supported on {I, P} + let patterns = TargetArray::from_raw([[[true]], [[true]], [[false]]]); + let costs = make_target_costs(&body, patterns, &heap); + // Terminator supported on {I, P, E} (superset) + let terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); + + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 1); + assert!(!regions[BasicBlockId::new(0)].1); +} + +/// One statement, terminator is strict subset: extra region needed. +#[test] +fn count_regions_terminator_subset_splits() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl x: Int; + + bb0() { + x = load 42; + return x; + } + }); + + // Statement supported on {I, P} + let patterns = TargetArray::from_raw([[[true]], [[true]], [[false]]]); + let costs = make_target_costs(&body, patterns, &heap); + // Terminator supported only on {I} + let terminator_costs = make_terminator_costs( + &body, + TargetArray::from_raw([[true], [false], [false]]), + &heap, + ); + let regions = count_regions(&body, &costs, &terminator_costs, Global); + + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 2); + assert!(regions[BasicBlockId::new(0)].1); +} + +/// One statement, terminator is incomparable with statement support: extra region needed. +#[test] +fn count_regions_terminator_incomparable_splits() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl x: Int; + + bb0() { + x = load 42; + return x; + } + }); + + // Statement supported on {P, I} + let patterns = TargetArray::from_raw([[[true]], [[true]], [[false]]]); + let costs = make_target_costs(&body, patterns, &heap); + // Terminator supported on {E, I} (incomparable with {P, I}) + let terminator_costs = make_terminator_costs( + &body, + TargetArray::from_raw([[true], [false], [true]]), + &heap, + ); + let regions = count_regions(&body, &costs, &terminator_costs, Global); + + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 2); + assert!(regions[BasicBlockId::new(0)].1); +} + +/// Zero statements: always 1 region, no terminator split (block IS the terminator). +#[test] +fn count_regions_zero_statements_no_split() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl x: Int; + + bb0() { + return x; + } + }); + + let costs = TargetArray::from_fn(|_| StatementCostVec::new_in(&body.basic_blocks, &heap)); + // Terminator only on {I} + let terminator_costs = make_terminator_costs( + &body, + TargetArray::from_raw([[true], [false], [false]]), + &heap, + ); + let regions = count_regions(&body, &costs, &terminator_costs, Global); + + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 1); + assert!(!regions[BasicBlockId::new(0)].1); +} + +/// Multiple statement runs, terminator doesn't narrow last run: no extra region. +#[test] +fn count_regions_multiple_runs_terminator_superset() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl x: Int, y: Int; + + bb0() { + x = load 1; + y = load 2; + return y; + } + }); + + // First stmt: {I, P}, second stmt: {I} + let patterns = TargetArray::from_raw([[[true, true]], [[true, false]], [[false, false]]]); + let costs = make_target_costs(&body, patterns, &heap); + // Terminator on {I, P, E} (superset of last run {I}) + let terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); + + // 2 statement runs, no terminator split + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 2); + assert!(!regions[BasicBlockId::new(0)].1); +} + +/// Multiple statement runs, terminator narrows last run: extra region. +#[test] +fn count_regions_multiple_runs_terminator_narrows() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let body = body!(interner, env; fn@0/0 -> Int { + decl x: Int, y: Int; + + bb0() { + x = load 1; + y = load 2; + return y; + } + }); + + // First stmt: {I, P}, second stmt: {I, P} + let patterns = TargetArray::from_raw([[[true, true]], [[true, true]], [[false, false]]]); + let costs = make_target_costs(&body, patterns, &heap); + // Terminator only on {I} + let terminator_costs = make_terminator_costs( + &body, + TargetArray::from_raw([[true], [false], [false]]), + &heap, + ); + let regions = count_regions(&body, &costs, &terminator_costs, Global); + + // 1 statement run + 1 terminator region + assert_eq!(regions[BasicBlockId::new(0)].0.get(), 2); + assert!(regions[BasicBlockId::new(0)].1); +} + +/// One statement + narrowed terminator produces 2 blocks with correct affinities. +#[test] +fn offset_terminator_narrowing_creates_split() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let mut body = body!(interner, env; fn@0/0 -> Int { + decl x: Int; + + bb0() { + x = load 42; + return x; + } + }); + + let context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + // Statement supported on {I, P} + let patterns = TargetArray::from_raw([[[true]], [[true]], [[false]]]); + let mut costs = make_target_costs(&body, patterns, &heap); + // Terminator only on {I} + let mut terminator_costs = make_terminator_costs( + &body, + TargetArray::from_raw([[true], [false], [false]]), + &heap, + ); + let regions = count_regions(&body, &costs, &terminator_costs, Global); + + let targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); + + // Should produce 2 blocks + assert_eq!(body.basic_blocks.len(), 2); + + // First block has the statement, second is empty (terminator only) + assert_assignment_locals(&body, BasicBlockId::new(0), &["x"]); + assert_eq!(body.basic_blocks[BasicBlockId::new(1)].statements.len(), 0); + + // First block has Goto to second, second has the original Return + assert_goto_target(&body, BasicBlockId::new(0), BasicBlockId::new(1)); + assert_return_terminator(&body, BasicBlockId::new(1)); + + // Affinities: first {I, P}, second {I} + let expected_statements = Targets { + interpreter: true, + postgres: true, + embedding: false, + } + .compile(); + let expected_terminator = Targets { + interpreter: true, + postgres: false, + embedding: false, + } + .compile(); + assert_eq!(targets[BasicBlockId::new(0)], expected_statements); + assert_eq!(targets[BasicBlockId::new(1)], expected_terminator); +} + +/// Empty block with restricted terminator gets its affinity from the terminator. +#[test] +fn offset_empty_block_uses_terminator_affinity() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let mut body = body!(interner, env; fn@0/0 -> Int { + decl x: Int; + + bb0() { + return x; + } + }); + + let context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + let costs = TargetArray::from_fn(|_| StatementCostVec::new_in(&body.basic_blocks, &heap)); + // Only interpreter supports the terminator + let mut terminator_costs = make_terminator_costs( + &body, + TargetArray::from_raw([[true], [false], [false]]), + &heap, + ); + let regions = count_regions(&body, &costs, &terminator_costs, Global); + + let mut costs = costs; + let targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); + + // Still 1 block, no split + assert_eq!(body.basic_blocks.len(), 1); + + let expected = Targets { + interpreter: true, + postgres: false, + embedding: false, + } + .compile(); + assert_eq!(targets[BasicBlockId::new(0)], expected); +} + +/// Two statement runs + disjoint terminator produces 3 blocks. +#[test] +fn offset_two_runs_plus_terminator_split() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let mut body = body!(interner, env; fn@0/0 -> Int { + decl x: Int, y: Int; + + bb0() { + x = load 1; + y = load 2; + return y; + } + }); + + let context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + // First stmt: {P} only, second stmt: {E} only + let patterns = TargetArray::from_raw([[[false, false]], [[true, false]], [[false, true]]]); + let mut costs = make_target_costs(&body, patterns, &heap); + // Terminator only on {I} + let mut terminator_costs = make_terminator_costs( + &body, + TargetArray::from_raw([[true], [false], [false]]), + &heap, + ); + let regions = count_regions(&body, &costs, &terminator_costs, Global); + + let targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); + + assert_eq!(body.basic_blocks.len(), 3); + + // Affinities: {P}, {E}, {I} + let expected_p = Targets { + interpreter: false, + postgres: true, + embedding: false, + } + .compile(); + let expected_e = Targets { + interpreter: false, + postgres: false, + embedding: true, + } + .compile(); + let expected_i = Targets { + interpreter: true, + postgres: false, + embedding: false, + } + .compile(); + assert_eq!(targets[BasicBlockId::new(0)], expected_p); + assert_eq!(targets[BasicBlockId::new(1)], expected_e); + assert_eq!(targets[BasicBlockId::new(2)], expected_i); + + // Last block has the original return, others have gotos + assert_goto_terminator(&body, BasicBlockId::new(0)); + assert_goto_terminator(&body, BasicBlockId::new(1)); + assert_return_terminator(&body, BasicBlockId::new(2)); +} + +/// Terminator cost remap after split: last block gets original cost, Goto blocks get zero. +#[test] +fn terminator_cost_remap_after_split() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let mut body = body!(interner, env; fn@0/0 -> Int { + decl x: Int, y: Int; + + bb0() { + x = load 1; + y = load 2; + return y; + } + }); + + let context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + // Two statement runs: {I,P} then {I} + let patterns = TargetArray::from_raw([[[true, true]], [[true, false]], [[false, false]]]); + let mut costs = make_target_costs(&body, patterns, &heap); + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + + let splitting = BasicBlockSplitting::new(); + let _targets = splitting.split(&context, &mut body, &mut costs, &mut terminator_costs); + + // After split: 2 blocks. First has Goto (zero cost), second has original Return. + assert_eq!(body.basic_blocks.len(), 2); + + for target in TargetId::all() { + // Goto block gets zero cost + assert_eq!( + terminator_costs[target].of(BasicBlockId::new(0)), + Some(cost!(0)) + ); + // Original terminator block keeps the original cost + assert_eq!( + terminator_costs[target].of(BasicBlockId::new(1)), + Some(cost!(1)) + ); + } +} + +/// Terminator is strict superset of statement support: no split occurs. +#[test] +fn offset_terminator_superset_no_split() { + let heap = Heap::new(); + let interner = Interner::new(&heap); + let env = Environment::new(&heap); + + let mut body = body!(interner, env; fn@0/0 -> Int { + decl x: Int; + + bb0() { + x = load 42; + return x; + } + }); + + let context = MirContext { + heap: &heap, + env: &env, + interner: &interner, + diagnostics: DiagnosticIssues::new(), + }; + + // Statement: {P} only + let patterns = TargetArray::from_raw([[[false]], [[true]], [[false]]]); + let mut costs = make_target_costs(&body, patterns, &heap); + // Terminator: {I, P, E} (superset) + let mut terminator_costs = make_all_supported_terminator_costs(&body, &heap); + let regions = count_regions(&body, &costs, &terminator_costs, Global); + + let targets = offset_basic_blocks( + &context, + &mut body, + ®ions, + &mut costs, + &mut terminator_costs, + Global, + Global, + ); + + // No split + assert_eq!(body.basic_blocks.len(), 1); + + // Affinity comes from statements: {P} + let expected = Targets { + interpreter: false, + postgres: true, + embedding: false, + } + .compile(); + assert_eq!(targets[BasicBlockId::new(0)], expected); +} diff --git a/libs/@local/hashql/mir/src/pass/execution/statement_placement/common.rs b/libs/@local/hashql/mir/src/pass/execution/statement_placement/common.rs index aaf098cf66b..2ea7cdda04b 100644 --- a/libs/@local/hashql/mir/src/pass/execution/statement_placement/common.rs +++ b/libs/@local/hashql/mir/src/pass/execution/statement_placement/common.rs @@ -19,7 +19,7 @@ use crate::{ place::Projection, rvalue::RValue, statement::{Assign, Statement, StatementKind}, - terminator::TerminatorKind, + terminator::{Terminator, TerminatorKind}, }, context::MirContext, pass::{ @@ -29,7 +29,7 @@ use crate::{ }, execution::{ Cost, - cost::StatementCostVec, + cost::{StatementCostVec, TerminatorCostVec}, traversal::{Access, EntityPath}, }, }, @@ -73,6 +73,14 @@ pub(crate) trait Supported<'heap> { operand: &Operand<'heap>, ) -> bool; + fn is_supported_terminator( + &self, + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + domain: &DenseBitSet, + terminator: &Terminator<'heap>, + ) -> bool; + /// Checks whether a type can be unambiguously deserialized after crossing a backend boundary. /// /// Returns `true` by default. Targets that serialize values to a lossy format (e.g., jsonb) @@ -97,6 +105,16 @@ where T::is_supported_rvalue(self, context, body, domain, rvalue) } + fn is_supported_terminator( + &self, + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + domain: &DenseBitSet, + terminator: &Terminator<'heap>, + ) -> bool { + T::is_supported_terminator(self, context, body, domain, terminator) + } + fn is_supported_operand( &self, context: &MirContext<'_, 'heap>, @@ -276,6 +294,7 @@ pub(crate) struct CostVisitor<'ctx, 'env, 'heap, S, A: Allocator> { pub cost: Cost, pub statement_costs: StatementCostVec, + pub terminator_costs: TerminatorCostVec, pub supported: S, } @@ -313,6 +332,25 @@ where Ok(()) } + + fn visit_terminator( + &mut self, + location: Location, + terminator: &Terminator<'heap>, + ) -> Self::Result { + let is_supported = self.supported.is_supported_terminator( + self.context, + self.body, + self.dispatchable, + terminator, + ); + + if is_supported { + self.terminator_costs.insert(location.block, self.cost); + } + + Ok(()) + } } /// Determines which backend can access an entity field projection. diff --git a/libs/@local/hashql/mir/src/pass/execution/statement_placement/embedding/mod.rs b/libs/@local/hashql/mir/src/pass/execution/statement_placement/embedding/mod.rs index 90a4da7c7bf..cc0e6ed0543 100644 --- a/libs/@local/hashql/mir/src/pass/execution/statement_placement/embedding/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/statement_placement/embedding/mod.rs @@ -7,11 +7,20 @@ use super::{ common::{CostVisitor, OnceValue, Supported, SupportedAnalysis}, }; use crate::{ - body::{Body, Source, local::Local, operand::Operand, place::Place, rvalue::RValue}, + body::{ + Body, Source, + local::Local, + operand::Operand, + place::Place, + rvalue::RValue, + terminator::{Goto, Return, SwitchInt, Terminator, TerminatorKind}, + }, context::MirContext, pass::execution::{ - Cost, VertexType, cost::StatementCostVec, - statement_placement::common::entity_projection_access, traversal::Access, + Cost, VertexType, + cost::{StatementCostVec, TerminatorCostVec}, + statement_placement::common::entity_projection_access, + traversal::Access, }, visit::Visitor as _, }; @@ -65,6 +74,38 @@ impl<'heap> Supported<'heap> for EmbeddingSupported { } } + fn is_supported_terminator( + &self, + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + domain: &DenseBitSet, + terminator: &Terminator<'heap>, + ) -> bool { + match &terminator.kind { + TerminatorKind::Goto(Goto { target }) => target + .args + .iter() + .all(|arg| self.is_supported_operand(context, body, domain, arg)), + TerminatorKind::SwitchInt(SwitchInt { + discriminant, + targets, + }) => { + self.is_supported_operand(context, body, domain, discriminant) + && targets.targets().iter().all(|target| { + target + .args + .iter() + .all(|arg| self.is_supported_operand(context, body, domain, arg)) + }) + } + TerminatorKind::Return(Return { value }) => { + self.is_supported_operand(context, body, domain, value) + } + TerminatorKind::GraphRead(_) => false, + TerminatorKind::Unreachable => true, + } + } + fn is_supported_operand( &self, _: &MirContext<'_, 'heap>, @@ -106,13 +147,14 @@ impl<'heap, A: Allocator + Clone, S: Allocator> StatementPlacement<'heap, A> body: &Body<'heap>, vertex: VertexType, alloc: A, - ) -> StatementCostVec { - let statement_costs = StatementCostVec::new_in(&body.basic_blocks, alloc); + ) -> (StatementCostVec, TerminatorCostVec) { + let statement_costs = StatementCostVec::new_in(&body.basic_blocks, alloc.clone()); + let terminator_costs = TerminatorCostVec::new_in(&body.basic_blocks, alloc); match body.source { Source::GraphReadFilter(_) => {} Source::Ctor(_) | Source::Closure(..) | Source::Thunk(..) | Source::Intrinsic(_) => { - return statement_costs; + return (statement_costs, terminator_costs); } } @@ -148,11 +190,12 @@ impl<'heap, A: Allocator + Clone, S: Allocator> StatementPlacement<'heap, A> cost: self.statement_cost, statement_costs, + terminator_costs, supported: &EmbeddingSupported { vertex }, }; visitor.visit_body(body); - visitor.statement_costs + (visitor.statement_costs, visitor.terminator_costs) } } diff --git a/libs/@local/hashql/mir/src/pass/execution/statement_placement/interpret/mod.rs b/libs/@local/hashql/mir/src/pass/execution/statement_placement/interpret/mod.rs index 419cccded41..eebbb5dbefb 100644 --- a/libs/@local/hashql/mir/src/pass/execution/statement_placement/interpret/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/statement_placement/interpret/mod.rs @@ -6,11 +6,12 @@ use crate::{ Body, Source, location::Location, statement::{Statement, StatementKind}, + terminator::Terminator, }, context::MirContext, pass::execution::{ VertexType, - cost::{Cost, StatementCostVec}, + cost::{Cost, StatementCostVec, TerminatorCostVec}, }, visit::Visitor, }; @@ -22,6 +23,7 @@ struct CostVisitor { cost: Cost, statement_costs: StatementCostVec, + terminator_costs: TerminatorCostVec, } impl<'heap, A: Allocator> Visitor<'heap> for CostVisitor { @@ -43,6 +45,17 @@ impl<'heap, A: Allocator> Visitor<'heap> for CostVisitor { Ok(()) } + + fn visit_terminator(&mut self, location: Location, _: &Terminator<'heap>) -> Self::Result { + // Because interpreter is our base case, every terminator is supported, via the default base + // cost. + // Because this is done *before* basic block splitting, we assign the same cost to as well, + // splitting, then assigns a cumulative cost of `0` for generated GOTOs to not distort the + // cost distribution. + self.terminator_costs.insert(location.block, self.cost); + + Ok(()) + } } /// Statement placement for the [`Interpreter`](super::super::TargetId::Interpreter) execution @@ -68,22 +81,24 @@ impl<'heap, A: Allocator + Clone> StatementPlacement<'heap, A> for InterpreterSt body: &Body<'heap>, _: VertexType, alloc: A, - ) -> StatementCostVec { - let statement_costs = StatementCostVec::new_in(&body.basic_blocks, alloc); + ) -> (StatementCostVec, TerminatorCostVec) { + let statement_costs = StatementCostVec::new_in(&body.basic_blocks, alloc.clone()); + let terminator_costs = TerminatorCostVec::new_in(&body.basic_blocks, alloc); match body.source { Source::GraphReadFilter(_) => {} Source::Ctor(_) | Source::Closure(..) | Source::Thunk(..) | Source::Intrinsic(_) => { - return statement_costs; + return (statement_costs, terminator_costs); } } let mut visitor = CostVisitor { cost: self.statement_cost, statement_costs, + terminator_costs, }; visitor.visit_body(body); - visitor.statement_costs + (visitor.statement_costs, visitor.terminator_costs) } } diff --git a/libs/@local/hashql/mir/src/pass/execution/statement_placement/mod.rs b/libs/@local/hashql/mir/src/pass/execution/statement_placement/mod.rs index 3b9ca7852be..49a64593990 100644 --- a/libs/@local/hashql/mir/src/pass/execution/statement_placement/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/statement_placement/mod.rs @@ -23,7 +23,7 @@ pub(crate) use self::{ embedding::EmbeddingStatementPlacement, interpret::InterpreterStatementPlacement, postgres::PostgresStatementPlacement, }; -use super::{VertexType, target::TargetId}; +use super::{VertexType, cost::TerminatorCostVec, target::TargetId}; use crate::{body::Body, context::MirContext, pass::execution::cost::StatementCostVec}; /// Computes statement placement costs for a specific execution target. @@ -50,7 +50,7 @@ pub(crate) trait StatementPlacement<'heap, A: Allocator> { body: &Body<'heap>, vertex: VertexType, alloc: A, - ) -> StatementCostVec; + ) -> (StatementCostVec, TerminatorCostVec); } pub(crate) enum TargetPlacementStatement<'heap, S: Allocator> { @@ -80,7 +80,7 @@ impl<'heap, A: Allocator + Clone, S: Allocator> StatementPlacement<'heap, A> body: &Body<'heap>, vertex: VertexType, alloc: A, - ) -> StatementCostVec { + ) -> (StatementCostVec, TerminatorCostVec) { match self { TargetPlacementStatement::Interpreter(placement) => { placement.statement_placement_in(context, body, vertex, alloc) diff --git a/libs/@local/hashql/mir/src/pass/execution/statement_placement/postgres/mod.rs b/libs/@local/hashql/mir/src/pass/execution/statement_placement/postgres/mod.rs index 10be791ff9c..637d6cb3aa9 100644 --- a/libs/@local/hashql/mir/src/pass/execution/statement_placement/postgres/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/statement_placement/postgres/mod.rs @@ -25,11 +25,12 @@ use crate::{ operand::Operand, place::{FieldIndex, Place, ProjectionKind}, rvalue::{Aggregate, AggregateKind, BinOp, Binary, RValue, Unary}, + terminator::{Goto, Return, SwitchInt, Terminator, TerminatorKind}, }, context::MirContext, pass::execution::{ VertexType, - cost::{Cost, StatementCostVec}, + cost::{Cost, StatementCostVec, TerminatorCostVec}, statement_placement::common::entity_projection_access, traversal::Access, }, @@ -446,6 +447,38 @@ impl<'heap, A: Allocator> Supported<'heap> for PostgresSupported<'_, 'heap, A> { } } + fn is_supported_terminator( + &self, + context: &MirContext<'_, 'heap>, + body: &Body<'heap>, + domain: &DenseBitSet, + terminator: &Terminator<'heap>, + ) -> bool { + match &terminator.kind { + TerminatorKind::Goto(Goto { target }) => target + .args + .iter() + .all(|arg| self.is_supported_operand(context, body, domain, arg)), + TerminatorKind::SwitchInt(SwitchInt { + discriminant, + targets, + }) => { + self.is_supported_operand(context, body, domain, discriminant) + && targets.targets().iter().all(|target| { + target + .args + .iter() + .all(|arg| self.is_supported_operand(context, body, domain, arg)) + }) + } + TerminatorKind::Return(Return { value }) => { + self.is_supported_operand(context, body, domain, value) + } + TerminatorKind::GraphRead(_) => false, + TerminatorKind::Unreachable => true, + } + } + fn is_supported_operand( &self, context: &MirContext<'_, 'heap>, @@ -698,13 +731,14 @@ impl<'heap, A: Allocator + Clone, S: Allocator> StatementPlacement<'heap, A> body: &Body<'heap>, vertex: VertexType, alloc: A, - ) -> StatementCostVec { - let statement_costs = StatementCostVec::new_in(&body.basic_blocks, alloc); + ) -> (StatementCostVec, TerminatorCostVec) { + let statement_costs = StatementCostVec::new_in(&body.basic_blocks, alloc.clone()); + let terminator_costs = TerminatorCostVec::new_in(&body.basic_blocks, alloc); match body.source { Source::GraphReadFilter(_) => {} Source::Ctor(_) | Source::Closure(..) | Source::Thunk(..) | Source::Intrinsic(_) => { - return statement_costs; + return (statement_costs, terminator_costs); } } @@ -740,11 +774,12 @@ impl<'heap, A: Allocator + Clone, S: Allocator> StatementPlacement<'heap, A> cost: self.statement_cost, statement_costs, + terminator_costs, supported: &supported, }; visitor.visit_body(body); - visitor.statement_costs + (visitor.statement_costs, visitor.terminator_costs) } } diff --git a/libs/@local/hashql/mir/src/pass/execution/statement_placement/tests.rs b/libs/@local/hashql/mir/src/pass/execution/statement_placement/tests.rs index cf254b59d4b..2bcf0abff13 100644 --- a/libs/@local/hashql/mir/src/pass/execution/statement_placement/tests.rs +++ b/libs/@local/hashql/mir/src/pass/execution/statement_placement/tests.rs @@ -15,13 +15,13 @@ use insta::{Settings, assert_snapshot}; use super::StatementPlacement; use crate::{ - body::{Body, local::Local, location::Location, statement::Statement}, + body::{Body, local::Local, location::Location, statement::Statement, terminator::Terminator}, builder::body, context::MirContext, intern::Interner, pass::execution::{ VertexType, - cost::StatementCostVec, + cost::{StatementCostVec, TerminatorCostVec}, statement_placement::{ EmbeddingStatementPlacement, InterpreterStatementPlacement, PostgresStatementPlacement, }, @@ -31,7 +31,8 @@ use crate::{ /// Annotation provider that displays statement costs as trailing comments. struct CostAnnotations<'costs, A: Allocator> { - costs: &'costs StatementCostVec, + statement_costs: &'costs StatementCostVec, + terminator_costs: &'costs TerminatorCostVec, } impl TextFormatAnnotations for CostAnnotations<'_, A> { @@ -39,13 +40,27 @@ impl TextFormatAnnotations for CostAnnotations<'_, A> { = impl Display where Self: 'this; + type TerminatorAnnotation<'this, 'heap> + = impl Display + where + Self: 'this; fn annotate_statement<'heap>( &self, location: Location, - _statement: &Statement<'heap>, + _: &Statement<'heap>, ) -> Option> { - let cost = self.costs.get(location)?; + let cost = self.statement_costs.get(location)?; + + Some(core::fmt::from_fn(move |fmt| write!(fmt, "cost: {cost}"))) + } + + fn annotate_terminator<'heap>( + &self, + location: Location, + _: &Terminator<'heap>, + ) -> Option> { + let cost = self.terminator_costs.of(location.block)?; Some(core::fmt::from_fn(move |fmt| write!(fmt, "cost: {cost}"))) } @@ -58,13 +73,14 @@ pub(crate) fn assert_placement<'heap, A: Allocator>( snapshot_subdir: &str, body: &Body<'heap>, context: &MirContext<'_, 'heap>, - statement_costs: &StatementCostVec, + (statement_costs, terminator_costs): &(StatementCostVec, TerminatorCostVec), ) { let formatter = Formatter::new(context.heap); let type_formatter = TypeFormatter::new(&formatter, context.env, TypeFormatterOptions::terse()); let annotations = CostAnnotations { - costs: statement_costs, + statement_costs, + terminator_costs, }; let mut text_format = TextFormatOptions { @@ -100,13 +116,19 @@ pub(crate) fn run_placement<'heap>( context: &MirContext<'_, 'heap>, placement: &mut impl StatementPlacement<'heap, &'heap Heap>, body: Body<'heap>, -) -> (Body<'heap>, StatementCostVec<&'heap Heap>) { +) -> ( + Body<'heap>, + ( + StatementCostVec<&'heap Heap>, + TerminatorCostVec<&'heap Heap>, + ), +) { let vertex = VertexType::from_local(context.env, &body.local_decls[Local::VERTEX]) .unwrap_or_else(|| unimplemented!("lookup for declared type")); - let statement_costs = placement.statement_placement_in(context, &body, vertex, context.heap); + let costs = placement.statement_placement_in(context, &body, vertex, context.heap); - (body, statement_costs) + (body, costs) } // ============================================================================= @@ -148,11 +170,17 @@ fn non_graph_read_filter_returns_empty() { let mut embedding = EmbeddingStatementPlacement::new_in(Global); let vertex = VertexType::Entity; - let postgres_statement = postgres.statement_placement_in(&context, &body, vertex, &heap); - let interpreter_statement = interpreter.statement_placement_in(&context, &body, vertex, &heap); - let embedding_statement = embedding.statement_placement_in(&context, &body, vertex, &heap); - - assert!(postgres_statement.all_unassigned()); - assert!(interpreter_statement.all_unassigned()); - assert!(embedding_statement.all_unassigned()); + let (postgres_statements, postgres_terminators) = + postgres.statement_placement_in(&context, &body, vertex, &heap); + let (interpreter_statements, interpreter_terminators) = + interpreter.statement_placement_in(&context, &body, vertex, &heap); + let (embedding_statements, embedding_terminators) = + embedding.statement_placement_in(&context, &body, vertex, &heap); + + assert!(postgres_statements.all_unassigned()); + assert!(postgres_terminators.all_unassigned()); + assert!(interpreter_statements.all_unassigned()); + assert!(interpreter_terminators.all_unassigned()); + assert!(embedding_statements.all_unassigned()); + assert!(embedding_terminators.all_unassigned()); } diff --git a/libs/@local/hashql/mir/src/pass/execution/terminator_placement/mod.rs b/libs/@local/hashql/mir/src/pass/execution/terminator_placement/mod.rs index b0d41247077..a260ca21170 100644 --- a/libs/@local/hashql/mir/src/pass/execution/terminator_placement/mod.rs +++ b/libs/@local/hashql/mir/src/pass/execution/terminator_placement/mod.rs @@ -10,7 +10,7 @@ //! # Main Types //! //! - [`TransMatrix`]: Per-edge transition costs indexed by (source, destination) target pairs -//! - [`TerminatorCostVec`]: Collection of transition matrices for all edges in a body +//! - [`TerminatorTransitionCostVec`]: Collection of transition matrices for all edges in a body //! - [`TerminatorPlacement`]: Analysis driver that computes placement for a body //! //! # Transition Rules @@ -251,9 +251,11 @@ forward_ref_op_assign!(impl AddAssign::add_assign for TransMatrix); /// [`Return`]: TerminatorKind::Return /// [`Unreachable`]: TerminatorKind::Unreachable #[derive(Debug)] -pub(crate) struct TerminatorCostVec(BlockPartitionedVec); +pub(crate) struct TerminatorTransitionCostVec( + BlockPartitionedVec, +); -impl TerminatorCostVec { +impl TerminatorTransitionCostVec { /// Creates a cost vector sized for `blocks`, with all transitions initially disallowed. pub(crate) fn new(blocks: &BasicBlocks, alloc: A) -> Self { Self(BlockPartitionedVec::new_in( @@ -273,7 +275,7 @@ impl TerminatorCostVec { } } -impl TerminatorCostVec { +impl TerminatorTransitionCostVec { pub(crate) const fn len(&self) -> usize { self.0.len() } @@ -444,8 +446,8 @@ impl PopulateEdgeMatrix { /// Computes terminator placement for a [`Body`]. /// /// Analyzes control flow edges to determine valid backend transitions and their costs. The -/// resulting [`TerminatorCostVec`] is used by the execution planner alongside statement placement -/// to select optimal execution targets. +/// resulting [`TerminatorTransitionCostVec`] is used by the execution planner alongside statement +/// placement to select optimal execution targets. /// /// # Usage /// @@ -505,7 +507,7 @@ impl TerminatorPlacement { vertex: VertexType, footprint: &BodyFootprint<&'heap Heap>, targets: &BasicBlockSlice, - ) -> TerminatorCostVec { + ) -> TerminatorTransitionCostVec { self.terminator_placement_in(body, vertex, footprint, targets, Global) } @@ -516,8 +518,8 @@ impl TerminatorPlacement { /// execute on (from statement placement), and `footprint` provides size estimates for /// computing transfer costs. /// - /// The returned [`TerminatorCostVec`] can be indexed by block ID to get the transition - /// matrices for that block's successor edges. + /// The returned [`TerminatorTransitionCostVec`] can be indexed by block ID to get the + /// transition matrices for that block's successor edges. pub(crate) fn terminator_placement_in<'heap, A: Allocator + Clone>( &self, body: &Body<'heap>, @@ -525,12 +527,12 @@ impl TerminatorPlacement { footprint: &BodyFootprint<&'heap Heap>, targets: &BasicBlockSlice, alloc: A, - ) -> TerminatorCostVec { + ) -> TerminatorTransitionCostVec { let live_in = self.compute_liveness(body, vertex); let scc = self.compute_scc(body); let switch_cost = backend_switch_cost(); - let mut output = TerminatorCostVec::new(&body.basic_blocks, alloc); + let mut output = TerminatorTransitionCostVec::new(&body.basic_blocks, alloc); let mut required_locals = DenseBitSet::new_empty(body.local_decls.len()); for (block_id, block) in body.basic_blocks.iter_enumerated() { diff --git a/libs/@local/hashql/mir/src/pass/execution/terminator_placement/tests.rs b/libs/@local/hashql/mir/src/pass/execution/terminator_placement/tests.rs index d76a0c5058c..7fcdbae7058 100644 --- a/libs/@local/hashql/mir/src/pass/execution/terminator_placement/tests.rs +++ b/libs/@local/hashql/mir/src/pass/execution/terminator_placement/tests.rs @@ -18,7 +18,7 @@ use hashql_core::{ use hashql_diagnostics::DiagnosticIssues; use insta::{Settings, assert_snapshot}; -use super::{Cost, TerminatorCostVec, TerminatorPlacement, TransMatrix}; +use super::{Cost, TerminatorPlacement, TerminatorTransitionCostVec, TransMatrix}; use crate::{ body::{ Body, @@ -91,7 +91,7 @@ fn assert_snapshot<'heap>( name: &'static str, context: &MirContext<'_, 'heap>, body: &Body<'heap>, - edges: &TerminatorCostVec, + edges: &TerminatorTransitionCostVec, ) { let formatter = Formatter::new(context.heap); let type_formatter = TypeFormatter::new(&formatter, context.env, TypeFormatterOptions::terse()); @@ -123,7 +123,7 @@ fn assert_snapshot<'heap>( } fn format_edge_summary( - edges: &TerminatorCostVec, + edges: &TerminatorTransitionCostVec, ) -> impl Display + '_ { fmt::from_fn(move |fmt| { for block in 0..edges.block_count() { @@ -176,7 +176,7 @@ fn terminator_cost_vec_successor_counts() { } }); - let costs = TerminatorCostVec::new(&body.basic_blocks, &heap); + let costs = TerminatorTransitionCostVec::new(&body.basic_blocks, &heap); assert_eq!(costs.of(BasicBlockId::new(0)).len(), 1); assert_eq!(costs.of(BasicBlockId::new(1)).len(), 3); diff --git a/libs/@local/hashql/mir/src/pretty/text.rs b/libs/@local/hashql/mir/src/pretty/text.rs index 96544f722c8..3b90a785ef4 100644 --- a/libs/@local/hashql/mir/src/pretty/text.rs +++ b/libs/@local/hashql/mir/src/pretty/text.rs @@ -86,6 +86,12 @@ pub trait TextFormatAnnotations { where Self: 'this; + /// The type of annotation displayed after terminators. + type TerminatorAnnotation<'this, 'heap>: Display + = ! + where + Self: 'this; + /// Returns an optional annotation for the given statement at `location`. #[expect(unused_variables, reason = "trait definition")] fn annotate_statement<'heap>( @@ -96,6 +102,15 @@ pub trait TextFormatAnnotations { None } + #[expect(unused_variables, reason = "trait definition")] + fn annotate_terminator<'heap>( + &self, + location: Location, + terminator: &Terminator<'heap>, + ) -> Option> { + None + } + /// Returns an optional annotation for the given local declaration. #[expect(unused_variables, reason = "trait definition")] fn annotate_local_decl<'heap>( @@ -132,6 +147,10 @@ impl TextFormatAnnotations for &mut T { = T::StatementAnnotation<'this, 'heap> where Self: 'this; + type TerminatorAnnotation<'this, 'heap> + = T::TerminatorAnnotation<'this, 'heap> + where + Self: 'this; fn annotate_statement<'heap>( &self, @@ -141,6 +160,14 @@ impl TextFormatAnnotations for &mut T { (**self).annotate_statement(location, statement) } + fn annotate_terminator<'heap>( + &self, + location: Location, + terminator: &Terminator<'heap>, + ) -> Option> { + (**self).annotate_terminator(location, terminator) + } + fn annotate_local_decl<'heap>( &self, local: Local, @@ -503,8 +530,9 @@ where self.newline()?; } + location.statement_index += 1; self.indent(2)?; - self.format_part(&block.terminator)?; + self.format_part((location, &block.terminator))?; self.newline()?; self.indent(1)?; @@ -780,14 +808,31 @@ where } } -impl<'heap, W, S, T, A> FormatPart<&Terminator<'heap>> for TextFormat +impl<'heap, W, S, T, A> FormatPart<(Location, &Terminator<'heap>)> for TextFormat where W: io::Write, S: SourceLookup<'heap>, + A: TextFormatAnnotations, { - fn format_part(&mut self, Terminator { span: _, kind }: &Terminator<'heap>) -> io::Result<()> { + fn format_part( + &mut self, + (location, terminator @ Terminator { span: _, kind }): (Location, &Terminator<'heap>), + ) -> io::Result<()> { self.format_part(TerminatorHead(kind))?; - self.format_part(TerminatorTail(kind)) + self.format_part(TerminatorTail(kind))?; + + let Some(annotation) = self.annotations.annotate_terminator(location, terminator) else { + return Ok(()); + }; + + // We estimate that we never exceed 80 columns, calculate the remaining width, if we don't + // have enough space, we add 4 spaces breathing room. + let remaining_width = 80_usize.checked_sub(self.line_buffer.len()).unwrap_or(4); + self.line_buffer + .resize(self.line_buffer.len() + remaining_width, b' '); + write!(self.line_buffer, "// {annotation}")?; + + Ok(()) } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/only_vectors_projection_supported.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/only_vectors_projection_supported.snap index d27a7c6e3c9..643d044182b 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/only_vectors_projection_supported.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/only_vectors_projection_supported.snap @@ -8,6 +8,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> ? { bb0(): { %2 = %1.encodings.vectors // cost: 4 - return %2 + return %2 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/storage_statements_zero_cost.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/storage_statements_zero_cost.snap index 44a6d7cc273..2590b51e724 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/storage_statements_zero_cost.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/embedding/storage_statements_zero_cost.snap @@ -10,6 +10,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> ? { %2 = %1.encodings.vectors // cost: 4 drop %2 // cost: 0 - return %2 + return %2 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/all_statements_supported.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/all_statements_supported.snap index b9de54f4bb3..319e93a6956 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/all_statements_supported.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/all_statements_supported.snap @@ -28,6 +28,6 @@ fn {graph::read::filter@4294967040}(%0: (Integer,), %1: Entity) -> Boolean { %11 = input LOAD param // cost: 8 %12 = true // cost: 8 - return %12 + return %12 // cost: 8 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/eq_opaque_entity_uuid.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/eq_opaque_entity_uuid.snap index 7ff2c1f6db5..f780244f5c4 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/eq_opaque_entity_uuid.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/eq_opaque_entity_uuid.snap @@ -1,6 +1,5 @@ --- source: libs/@local/hashql/mir/src/pass/execution/statement_placement/tests.rs -assertion_line: 92 expression: output --- fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { @@ -13,6 +12,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { %4 = opaque(::graph::types::knowledge::entity::EntityUuid, %3) // cost: 8 %2 = %1.id.entity_id.entity_uuid == %4 // cost: 8 - return %2 + return %2 // cost: 8 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/non_traversal_unaffected_by_costs.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/non_traversal_unaffected_by_costs.snap index 7fd988b687d..8da5e122828 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/non_traversal_unaffected_by_costs.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/non_traversal_unaffected_by_costs.snap @@ -12,6 +12,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { %3 = 42 // cost: 8 %4 = %3 > 10 // cost: 8 - return %4 + return %4 // cost: 8 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/storage_statements_zero_cost.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/storage_statements_zero_cost.snap index 0086a9abef8..a1034f365bb 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/storage_statements_zero_cost.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/storage_statements_zero_cost.snap @@ -16,6 +16,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Integer { drop %2 // cost: 0 drop %3 // cost: 0 - return %4 + return %4 // cost: 8 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_multiple_paths_cost.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_multiple_paths_cost.snap index 664620f8ba2..3dd4c5d7609 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_multiple_paths_cost.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_multiple_paths_cost.snap @@ -10,6 +10,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> (?, Boolean) { %2 = %1.properties // cost: 8 %3 = (%1.properties, %1.metadata.archived) // cost: 8 - return %3 + return %3 // cost: 8 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_single_path_cost.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_single_path_cost.snap index d51e6ceae90..6bb51b8e1d5 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_single_path_cost.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_single_path_cost.snap @@ -10,6 +10,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { %2 = %1.metadata.archived // cost: 8 %3 = !%2 // cost: 8 - return %3 + return %3 // cost: 8 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_swallowing_reduces_cost.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_swallowing_reduces_cost.snap index c43e55a04bd..01aa44dc79a 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_swallowing_reduces_cost.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/interpret/traversal_swallowing_reduces_cost.snap @@ -8,6 +8,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> (?, ?) { bb0(): { %2 = (%1.metadata.record_id.entity_id.web_id, %1.metadata.record_id) // cost: 8 - return %2 + return %2 // cost: 8 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_closure_rejected.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_closure_rejected.snap index 1f22a24cd1b..2b845759e6d 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_closure_rejected.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_closure_rejected.snap @@ -12,6 +12,6 @@ fn {graph::read::filter@4294967040}(%0: (Integer,), %1: Entity) -> Boolean { %3 = closure(({def@42} as FnPtr), %2) %4 = true // cost: 4 - return %4 + return %4 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_tuple_supported.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_tuple_supported.snap index ca2fefd4973..c1e7c357804 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_tuple_supported.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/aggregate_tuple_supported.snap @@ -12,6 +12,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { %3 = (a: 10, b: 20) // cost: 4 %4 = true // cost: 4 - return %4 + return %4 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/binary_unary_ops_supported.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/binary_unary_ops_supported.snap index 3bff03d7cf2..838217f2e16 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/binary_unary_ops_supported.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/binary_unary_ops_supported.snap @@ -16,6 +16,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { %5 = %4 > 15 // cost: 4 %6 = !%5 // cost: 4 - return %6 + return %6 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/diamond_must_analysis.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/diamond_must_analysis.snap index 8db2da9a59a..632cdf6cecd 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/diamond_must_analysis.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/diamond_must_analysis.snap @@ -16,7 +16,7 @@ fn {graph::read::filter@4294967040}(%0: (Integer,), %1: Entity) -> Boolean { %4 = (%3) // cost: 4 %5 = closure(({def@77} as FnPtr), %4) - switchInt(%2) -> [0: bb2(), 1: bb1()] + switchInt(%2) -> [0: bb2(), 1: bb1()] // cost: 4 } bb1(): { diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_column.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_column.snap index a18c0b1b663..e2ecac8d9a2 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_column.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_column.snap @@ -4,6 +4,6 @@ expression: output --- fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { bb0(): { - return %1.metadata.archived + return %1.metadata.archived // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_jsonb.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_jsonb.snap index d49c49e03cd..8c9dc6e3e0e 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_jsonb.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/entity_projection_jsonb.snap @@ -8,6 +8,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> ? { bb0(): { %2 = %1.properties // cost: 4 - return %2 + return %2 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_closure_field_rejected_other_accepted.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_closure_field_rejected_other_accepted.snap index c60090e59c5..d759926c979 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_closure_field_rejected_other_accepted.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_closure_field_rejected_other_accepted.snap @@ -12,6 +12,6 @@ fn {graph::read::filter@4294967040}(%0: (Integer, (Integer) -> Integer), %1: Ent %3 = %0.1 %4 = %2 == 42 // cost: 4 - return %4 + return %4 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_non_string_key_rejected.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_non_string_key_rejected.snap index bcceac667eb..a45b4dff025 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_non_string_key_rejected.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_non_string_key_rejected.snap @@ -10,6 +10,6 @@ fn {graph::read::filter@4294967040}(%0: (Dict,), %1: Entity) - %2 = %0.0 %3 = true // cost: 4 - return %3 + return %3 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_opaque_string_key_accepted.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_opaque_string_key_accepted.snap index 2344681e1dd..d6506127e72 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_opaque_string_key_accepted.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_opaque_string_key_accepted.snap @@ -10,6 +10,6 @@ fn {graph::read::filter@4294967040}(%0: (Dict,), %1: Entity) -> %2 = %0.0 // cost: 4 %3 = true // cost: 4 - return %3 + return %3 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_string_key_accepted.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_string_key_accepted.snap index 7f68af1af40..2f61ca5cdc2 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_string_key_accepted.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_dict_string_key_accepted.snap @@ -10,6 +10,6 @@ fn {graph::read::filter@4294967040}(%0: (Dict,), %1: Entity) -> %2 = %0.0 // cost: 4 %3 = true // cost: 4 - return %3 + return %3 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_with_closure_type_rejected.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_with_closure_type_rejected.snap index b8a99b82d85..584c758b71b 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_with_closure_type_rejected.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_with_closure_type_rejected.snap @@ -10,6 +10,6 @@ fn {graph::read::filter@4294967040}(%0: (Integer, (Integer) -> Integer), %1: Ent %2 = %0.0 // cost: 4 %3 = %2 == 42 // cost: 4 - return %3 + return %3 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_without_closure_accepted.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_without_closure_accepted.snap index 250a2822c48..bfd25139c57 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_without_closure_accepted.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/env_without_closure_accepted.snap @@ -10,6 +10,6 @@ fn {graph::read::filter@4294967040}(%0: (Integer, Boolean), %1: Entity) -> Boole %2 = %0.0 // cost: 4 %3 = %2 == 42 // cost: 4 - return %3 + return %3 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/eq_place_vs_constant_accepted.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/eq_place_vs_constant_accepted.snap index d2479908f24..a154d08e48f 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/eq_place_vs_constant_accepted.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/eq_place_vs_constant_accepted.snap @@ -10,6 +10,6 @@ fn {graph::read::filter@4294967040}(%0: (Dict,), %1: Entity) -> %2 = %0.0 // cost: 4 %3 = %2 == 42 // cost: 4 - return %3 + return %3 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/eq_same_type_accepted.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/eq_same_type_accepted.snap index 1cf9b7ba0b5..bbe4060e792 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/eq_same_type_accepted.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/eq_same_type_accepted.snap @@ -12,6 +12,6 @@ fn {graph::read::filter@4294967040}(%0: (Integer, Integer), %1: Entity) -> Boole %3 = %0.1 // cost: 4 %4 = %2 == %3 // cost: 4 - return %4 + return %4 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/fnptr_constant_rejected.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/fnptr_constant_rejected.snap index 2d8920c74a4..2295838d4cf 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/fnptr_constant_rejected.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/fnptr_constant_rejected.snap @@ -10,6 +10,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { %2 = ({def@99} as FnPtr) %3 = true // cost: 4 - return %3 + return %3 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/input_supported.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/input_supported.snap index 6494ccad9ef..1fc775ccfc9 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/input_supported.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/input_supported.snap @@ -10,6 +10,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Boolean { %2 = input LOAD threshold // cost: 4 %3 = %2 > 100 // cost: 4 - return %3 + return %3 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/serialization_unsafe_edge_propagates.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/serialization_unsafe_edge_propagates.snap index 1c479bb79cc..b2076f709c6 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/serialization_unsafe_edge_propagates.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/serialization_unsafe_edge_propagates.snap @@ -19,6 +19,6 @@ fn {graph::read::filter@4294967040}(%0: (Uuid | String, Integer), %1: Entity) -> %4 = %2 %5 = %3 > 42 // cost: 4 - return %5 + return %5 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/serialization_unsafe_statement_no_cost.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/serialization_unsafe_statement_no_cost.snap index faab5b071f2..3b0be3ed775 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/serialization_unsafe_statement_no_cost.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/serialization_unsafe_statement_no_cost.snap @@ -14,6 +14,6 @@ fn {graph::read::filter@4294967040}(%0: (Uuid | String, Integer), %1: Entity) -> %4 = %2 %5 = %3 > 42 // cost: 4 - return %5 + return %5 // cost: 4 } } diff --git a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/storage_statements_zero_cost.snap b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/storage_statements_zero_cost.snap index 4773d151019..d40228522e5 100644 --- a/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/storage_statements_zero_cost.snap +++ b/libs/@local/hashql/mir/tests/ui/pass/execution/statement_placement/postgres/storage_statements_zero_cost.snap @@ -16,6 +16,6 @@ fn {graph::read::filter@4294967040}(%0: (), %1: Entity) -> Integer { drop %2 // cost: 0 drop %3 // cost: 0 - return %4 + return %4 // cost: 4 } }