From 5ae528f6c8de3e7c4c9bfda503909f919a8cc0b4 Mon Sep 17 00:00:00 2001 From: xuanyili Date: Sun, 7 Jun 2026 18:19:35 +0000 Subject: [PATCH] fix: track untracked memory allocations in spill-capable operators Add memory reservation tracking for previously unaccounted allocations in NestedLoopJoin, GroupValuesRows emit, External Sort merge, and SpillManager IPC buffers to close the gap between actual allocator usage and MemoryPool-reported usage, preventing OOM kills. Key changes: SpillManager centralized accounting: - SpillManager::new() requires caller-owned MemoryReservation - Write-side: append_batch tracks IPC buffer overhead (grow/shrink) - Read-side: read_spill_as_stream accepts optional reservation with capacity pre-reservation before spawn_buffered; unbuffered reads pre-reserve one batch slot with try_grow before each blocking decode - ReservationGuard RAII utility (try_grow_guard/grow_guard) NestedLoopJoin probe accounting: - probe_reservation field for tracking probe-phase intermediates - RAII guards for Cartesian indices, filter take, output take - push_output_batch with best-effort try_grow and error-path cleanup - 4 reservation balance tests (inner/left/right/full joins) Aggregation emit accounting: - transient_reservation tracks emitted arrays until yielded downstream - estimated_emit_size() on GroupValues trait (6 implementations) - 2 reservation balance tests (primitive + GroupValuesRows path) External sort: - Sort merge single-file path transfers merge reservation via take() - Grow-to-at-least semantics for transferred reservation reuse - 1 sort spill reservation balance test HEADROOM_FACTOR reduced from 8.0 to 6.0. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- datafusion/execution/src/memory_pool/mod.rs | 113 +++++ datafusion/physical-plan/benches/spill_io.rs | 28 +- .../src/aggregates/group_values/mod.rs | 8 + .../group_values/multi_group_by/mod.rs | 13 + .../src/aggregates/group_values/row.rs | 15 + .../group_values/single_group_by/boolean.rs | 8 + .../group_values/single_group_by/bytes.rs | 15 + .../single_group_by/bytes_view.rs | 11 + .../group_values/single_group_by/primitive.rs | 8 + .../physical-plan/src/aggregates/row_hash.rs | 161 ++++++- .../src/joins/nested_loop_join.rs | 313 +++++++++++- .../src/joins/sort_merge_join/exec.rs | 1 + .../src/joins/sort_merge_join/tests.rs | 4 +- .../physical-plan/src/repartition/mod.rs | 3 + .../src/sorts/multi_level_merge.rs | 10 +- datafusion/physical-plan/src/sorts/sort.rs | 56 +++ .../src/spill/in_progress_spill_file.rs | 40 +- datafusion/physical-plan/src/spill/mod.rs | 195 ++++++-- .../src/spill/replayable_spill_input.rs | 4 +- .../physical-plan/src/spill/spill_manager.rs | 446 +++++++++++++++++- .../physical-plan/src/spill/spill_pool.rs | 21 +- .../sqllogictest/src/accounting_pool.rs | 5 +- 22 files changed, 1370 insertions(+), 108 deletions(-) diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index e50f72632b3f2..696bf3acf8483 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -521,6 +521,68 @@ impl MemoryReservation { pub fn take(&mut self) -> MemoryReservation { self.split(self.size.load(atomic::Ordering::Relaxed)) } + + /// Attempts to grow the reservation by `capacity` bytes and returns + /// a [`ReservationGuard`] that will automatically shrink the reservation + /// when dropped. + /// + /// This is useful for tracking transient allocations that are freed + /// when a scope exits (including error paths via `?`). 
+ /// + /// Call [`ReservationGuard::release`] to prevent the automatic shrink + /// when ownership of the allocated memory is transferred elsewhere. + pub fn try_grow_guard(&self, capacity: usize) -> Result> { + self.try_grow(capacity)?; + Ok(ReservationGuard { + reservation: self, + size: capacity, + }) + } + + /// Grows the reservation by `capacity` bytes (infallible) and returns + /// a [`ReservationGuard`] that will automatically shrink on drop. + /// + /// Use only for named exceptions where the allocation is required for + /// correctness/progress and cannot be deferred (e.g., join probe-side + /// index arrays that must be allocated to produce results). + pub fn grow_guard(&self, capacity: usize) -> ReservationGuard<'_> { + self.grow(capacity); + ReservationGuard { + reservation: self, + size: capacity, + } + } +} + +/// RAII guard that automatically shrinks a [`MemoryReservation`] on drop. +/// +/// Created by [`MemoryReservation::try_grow_guard`] or +/// [`MemoryReservation::grow_guard`]. Dropping the guard shrinks the +/// reservation by the guarded size; call [`Self::release`] to keep the bytes. +pub struct ReservationGuard<'a> { + reservation: &'a MemoryReservation, + size: usize, +} + +impl ReservationGuard<'_> { + /// Prevents the automatic shrink on drop: the bytes stay accounted in the + /// underlying reservation until that reservation is shrunk or freed. + pub fn release(mut self) { + self.size = 0; + } + + /// Returns the guarded size in bytes. 
+ pub fn size(&self) -> usize { + self.size + } +} + +impl Drop for ReservationGuard<'_> { + fn drop(&mut self) { + if self.size > 0 { + self.reservation.shrink(self.size); + } + } } impl Drop for MemoryReservation { @@ -670,4 +732,55 @@ mod tests { assert_eq!(r1.size(), 0); assert_eq!(pool.reserved(), 80); } + + #[test] + fn test_try_grow_guard_auto_shrinks() { + let pool = Arc::new(GreedyMemoryPool::new(1000)) as _; + let r1 = MemoryConsumer::new("test").register(&pool); + + { + let _guard = r1.try_grow_guard(100).unwrap(); + assert_eq!(r1.size(), 100); + assert_eq!(pool.reserved(), 100); + } + assert_eq!(r1.size(), 0); + assert_eq!(pool.reserved(), 0); + } + + #[test] + fn test_try_grow_guard_release_prevents_shrink() { + let pool = Arc::new(GreedyMemoryPool::new(1000)) as _; + let r1 = MemoryConsumer::new("test").register(&pool); + + { + let guard = r1.try_grow_guard(100).unwrap(); + guard.release(); + } + assert_eq!(r1.size(), 100); + assert_eq!(pool.reserved(), 100); + } + + #[test] + fn test_grow_guard_auto_shrinks() { + let pool = Arc::new(GreedyMemoryPool::new(1000)) as _; + let r1 = MemoryConsumer::new("test").register(&pool); + + { + let _guard = r1.grow_guard(200); + assert_eq!(r1.size(), 200); + } + assert_eq!(r1.size(), 0); + assert_eq!(pool.reserved(), 0); + } + + #[test] + fn test_try_grow_guard_error_path() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let r1 = MemoryConsumer::new("test").register(&pool); + + let result = r1.try_grow_guard(100); + assert!(result.is_err()); + assert_eq!(r1.size(), 0); + assert_eq!(pool.reserved(), 0); + } } diff --git a/datafusion/physical-plan/benches/spill_io.rs b/datafusion/physical-plan/benches/spill_io.rs index fac2547a131b4..e34c24ee32fdd 100644 --- a/datafusion/physical-plan/benches/spill_io.rs +++ b/datafusion/physical-plan/benches/spill_io.rs @@ -27,6 +27,7 @@ use criterion::{ use datafusion_common::config::SpillCompression; use datafusion_common::human_readable_size; use 
datafusion_common::instant::Instant; +use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_physical_plan::SpillManager; use datafusion_physical_plan::common::collect; @@ -90,7 +91,14 @@ fn bench_spill_io(c: &mut Criterion) { Field::new("c2", DataType::Date32, true), Field::new("c3", DataType::Decimal128(11, 2), true), ])); - let spill_manager = SpillManager::new(env, metrics, schema); + let spill_manager = SpillManager::new( + Arc::clone(&env), + metrics, + Arc::clone(&schema), + MemoryConsumer::new("bench") + .with_can_spill(true) + .register(&env.memory_pool), + ); let mut group = c.benchmark_group("spill_io"); let rt = Runtime::new().unwrap(); @@ -116,7 +124,7 @@ fn bench_spill_io(c: &mut Criterion) { |spill_file| { rt.block_on(async { let stream = spill_manager - .read_spill_as_stream(spill_file, None) + .read_spill_as_stream(spill_file, None, None) .unwrap(); let _ = collect(stream).await.unwrap(); }) @@ -504,9 +512,15 @@ fn benchmark_spill_batches_for_all_codec( for &compression in compressions { let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = - SpillManager::new(Arc::clone(&env), metrics.clone(), Arc::clone(&schema)) - .with_compression_type(compression); + let spill_manager = SpillManager::new( + Arc::clone(&env), + metrics.clone(), + Arc::clone(&schema), + MemoryConsumer::new("bench") + .with_can_spill(true) + .register(&env.memory_pool), + ) + .with_compression_type(compression); let bench_id = BenchmarkId::new(batch_label, compression.to_string()); group.bench_with_input(bench_id, &spill_manager, |b, spill_manager| { @@ -522,7 +536,7 @@ fn benchmark_spill_batches_for_all_codec( .unwrap() .unwrap(); let stream = spill_manager - .read_spill_as_stream(spill_file, None) + .read_spill_as_stream(spill_file, None, None) .unwrap(); let _ = collect(stream).await.unwrap(); }) @@ -557,7 +571,7 @@ fn benchmark_spill_batches_for_all_codec( let start = 
Instant::now(); rt.block_on(async { let stream = spill_manager - .read_spill_as_stream(spill_file, None) + .read_spill_as_stream(spill_file, None, None) .unwrap(); let _ = collect(stream).await.unwrap(); }); diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index ee253e5d7afdd..b9fdaf4b330fd 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -111,6 +111,14 @@ pub trait GroupValues: Send { /// Emits the group values fn emit(&mut self, emit_to: EmitTo) -> Result>; + /// Returns an estimate of the memory that will be allocated by [`Self::emit`] + /// for the decode/output buffers. + /// + /// This is used by the aggregation operator to pre-reserve memory before + /// calling `emit()`, ensuring the memory pool is aware of transient + /// decode buffer allocations. + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize; + /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage) fn clear_shrink(&mut self, num_rows: usize); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index ee2d300d9bff8..47fa7efc36673 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -1088,6 +1088,19 @@ impl GroupValues for GroupValuesColumn { self.group_values[0].len() } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let total = self.len(); + let emit_count = match emit_to { + EmitTo::All => total, + EmitTo::First(n) => (*n).min(total), + }; + if total == 0 { + return 0; + } + let group_values_size: usize = self.group_values.iter().map(|v| v.size()).sum(); + group_values_size * emit_count / total + } + fn emit(&mut self, emit_to: 
EmitTo) -> Result> { let mut output = match emit_to { EmitTo::All => { diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index a3bd31f76c233..5463c88c06580 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -195,6 +195,21 @@ impl GroupValues for GroupValuesRows { .unwrap_or(0) } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let total_rows = self.len(); + if total_rows == 0 { + return 0; + } + let rows_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0); + match emit_to { + EmitTo::All => rows_size, + EmitTo::First(n) => { + let n = (*n).min(total_rows); + rows_size * n / total_rows + } + } + } + fn emit(&mut self, emit_to: EmitTo) -> Result> { let mut group_values = self .group_values diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs index e993c0c53d199..c71e318741ea2 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs @@ -97,6 +97,14 @@ impl GroupValues for GroupValuesBoolean { + self.null_group.is_some() as usize } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let emit_count = match emit_to { + EmitTo::All => self.len(), + EmitTo::First(n) => (*n).min(self.len()), + }; + emit_count.div_ceil(8) * 2 + } + fn emit(&mut self, emit_to: EmitTo) -> Result> { let len = self.len(); let mut builder = BooleanBufferBuilder::new(len); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs index b881a51b25474..1f19e6e4f5b75 100644 --- 
a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs @@ -84,6 +84,21 @@ impl GroupValues for GroupValuesBytes { self.num_groups } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let total = self.len(); + let emit_count = match emit_to { + EmitTo::All => total, + EmitTo::First(n) => (*n).min(total), + }; + if total == 0 { + return 0; + } + // Offsets + data bytes (proportional) + null bitmap + (emit_count + 1) * std::mem::size_of::() + + self.size() * emit_count / total + + emit_count.div_ceil(8) + } + fn emit(&mut self, emit_to: EmitTo) -> Result> { // Reset the map to default, and convert it into a single array let map_contents = self.map.take().into_state(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs index 7a56f7c52c11a..04a69cf270bb0 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs @@ -86,6 +86,17 @@ impl GroupValues for GroupValuesBytesView { self.num_groups } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let emit_count = match emit_to { + EmitTo::All => self.len(), + EmitTo::First(n) => (*n).min(self.len()), + }; + // Views (16 bytes each) + out-of-line data estimate + null bitmap + emit_count * 16 + + self.size() * emit_count / self.len().max(1) + + emit_count.div_ceil(8) + } + fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { // Reset the map to default, and convert it into a single array let map_contents = self.map.take().into_state(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index 
07535cfdaa6de..96e168c23f31e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -165,6 +165,14 @@ where self.values.len() } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let emit_count = match emit_to { + EmitTo::All => self.len(), + EmitTo::First(n) => (*n).min(self.len()), + }; + emit_count * std::mem::size_of::() + emit_count.div_ceil(8) + } + fn emit(&mut self, emit_to: EmitTo) -> Result> { fn build_primitive( values: Vec, diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index c3f73976c721a..3f3dc100a727a 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -447,6 +447,12 @@ pub(crate) struct GroupedHashAggregateStream { /// The memory reservation for this grouping reservation: MemoryReservation, + /// Tracks memory for emitted output arrays that are live between + /// emit() and poll_next yielding them downstream. Ensures the pool + /// knows about in-flight batches that have left internal state but + /// haven't been consumed yet. + transient_reservation: MemoryReservation, + /// The behavior to trigger when out of memory occurs oom_mode: OutOfMemoryMode, @@ -601,6 +607,7 @@ impl GroupedHashAggregateStream { // to ensure fair application of back pressure amongst the memory consumers. 
.with_can_spill(oom_mode != OutOfMemoryMode::ReportError) .register(context.memory_pool()); + let transient_reservation = reservation.new_empty(); timer.done(); let exec_state = ExecutionState::ReadingInput; @@ -609,6 +616,7 @@ impl GroupedHashAggregateStream { context.runtime_env(), metrics::SpillMetrics::new(&agg.metrics, partition), Arc::clone(&spill_schema), + reservation.new_empty(), ) .with_compression_type(context.session_config().spill_compression()); @@ -682,6 +690,7 @@ impl GroupedHashAggregateStream { filter_expressions, group_by: agg_group_by, reservation, + transient_reservation, oom_mode, group_values, current_group_indices: Default::default(), @@ -862,6 +871,8 @@ impl Stream for GroupedHashAggregateStream { let output_batch; let size = self.batch_size; (self.exec_state, output_batch) = if batch.num_rows() <= size { + // Entire batch yielded — release transient reservation + self.transient_reservation.free(); ( if self.input_done { ExecutionState::Done @@ -1116,6 +1127,16 @@ impl GroupedHashAggregateStream { return Ok(None); } + // RAII guard pre-reserves memory for decode buffers allocated + // inside group_values.emit(). Guard automatically releases on + // drop (including error paths via ?). + let emit_estimate = self.group_values.estimated_emit_size(&emit_to); + let emit_guard = if emit_estimate > 0 { + Some(self.reservation.grow_guard(emit_estimate)) + } else { + None + }; + let timer = self.group_by_metrics.emitting_time.timer(); let mut output = self.group_values.emit(emit_to)?; if let EmitTo::First(n) = emit_to { @@ -1127,19 +1148,24 @@ impl GroupedHashAggregateStream { if self.mode.output_mode() == AggregateOutputMode::Final && !spilling { output.push(acc.evaluate(emit_to)?) } else { - // Output partial state: either because we're in a non-final mode, - // or because we're spilling and will merge/re-evaluate later. output.extend(acc.state(emit_to)?) } } drop(timer); - // emit reduces the memory usage. 
Ignore Err from update_memory_reservation. Even if it is - // over the target memory size after emission, we can emit again rather than returning Err. + // Release the pre-reserved emit estimate (decode buffers consumed) + drop(emit_guard); + + // emit reduces the memory usage. Ignore Err from update_memory_reservation. let _ = self.update_memory_reservation(); let batch = RecordBatch::try_new(schema, output)?; debug_assert!(batch.num_rows() > 0); + // Track emitted batch in transient_reservation until yielded downstream. + // This ensures the pool knows about in-flight output arrays. + let batch_size = batch.get_array_memory_size(); + self.transient_reservation.grow(batch_size); + Ok(Some(batch)) } @@ -1844,4 +1870,131 @@ mod tests { "ratio == threshold should not trigger skip (boundary is exclusive)" ); } + + #[tokio::test] + async fn test_aggregation_reservation_balanced() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("group_col", DataType::Int32, false), + Field::new("value_col", DataType::Int64, false), + ])); + + let num_rows = 500; + let group_ids: Vec = (0..num_rows).map(|i| i % 50).collect(); + let values: Vec = vec![1; num_rows as usize]; + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(group_ids)), + Arc::new(Int64Array::from(values)), + ], + )?; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(8192, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + + let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime)); + + let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())]; + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count_value") + .build()?, + )]; + + let exec = TestMemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + let 
aggregate_exec = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(group_expr), + aggr_expr, + vec![None], + exec, + Arc::clone(&schema), + )?; + + let stream = aggregate_exec.execute(0, task_ctx)?; + let batches: Vec = crate::common::collect(stream).await?; + assert!(!batches.is_empty(), "Should produce output"); + + drop(batches); + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after aggregation completes" + ); + Ok(()) + } + + #[tokio::test] + async fn test_aggregation_reservation_balanced_group_values_rows() -> Result<()> { + use arrow::datatypes::TimeUnit; + let schema = Arc::new(Schema::new(vec![ + Field::new("dur_key", DataType::Duration(TimeUnit::Microsecond), false), + Field::new("str_key", DataType::Utf8, false), + Field::new("value_col", DataType::Int64, false), + ])); + + let num_rows = 400; + let dur_keys: Vec = (0..num_rows as i64).map(|i| (i % 40) * 1000).collect(); + let str_keys: Vec = (0..num_rows) + .map(|i| format!("group_{:04}", i % 10)) + .collect(); + let values: Vec = vec![1; num_rows]; + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(DurationMicrosecondArray::from(dur_keys)), + Arc::new(StringArray::from(str_keys)), + Arc::new(Int64Array::from(values)), + ], + )?; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(16384, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + + let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime)); + + let group_expr = vec![ + (col("dur_key", &schema)?, "dur_key".to_string()), + (col("str_key", &schema)?, "str_key".to_string()), + ]; + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count_value") + .build()?, + )]; + + let exec = TestMemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + let 
aggregate_exec = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(group_expr), + aggr_expr, + vec![None], + exec, + Arc::clone(&schema), + )?; + + let stream = aggregate_exec.execute(0, task_ctx)?; + let batches: Vec = crate::common::collect(stream).await?; + assert!(!batches.is_empty(), "Should produce output"); + + drop(batches); + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after GroupValuesRows aggregation completes" + ); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index a18ec0cbe4504..2d2698ff58843 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -61,7 +61,7 @@ use arrow::record_batch::RecordBatch; use arrow_schema::DataType; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ - JoinSide, NullEquality, Result, ScalarValue, Statistics, arrow_err, + DataFusionError, JoinSide, NullEquality, Result, ScalarValue, Statistics, arrow_err, assert_eq_or_internal_err, internal_datafusion_err, internal_err, project_schema, unwrap_or_internal_err, }; @@ -673,6 +673,11 @@ impl ExecutionPlan for NestedLoopJoinExec { SpillState::Disabled }; + let probe_reservation = + MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]")) + .with_can_spill(true) + .register(context.memory_pool()); + Ok(Box::pin(NestedLoopJoinStream::new( self.schema(), self.filter.clone(), @@ -683,6 +688,7 @@ impl ExecutionPlan for NestedLoopJoinExec { metrics, batch_size, spill_state, + probe_reservation, ))) } @@ -1037,6 +1043,10 @@ pub(crate) struct NestedLoopJoinStream { /// Output buffer holds the join result to output. It will emit eagerly when /// the threshold is reached. output_buffer: Box, + /// Best-effort tracked size of output buffer data via try_grow. 
+ /// Only the successfully reserved portion is tracked; if try_grow fails + /// (pool full), the batch is still pushed but not counted. + output_buffer_reserved: usize, /// See comments in [`NLJState::Done`] for its purpose handled_empty_output: bool, @@ -1064,6 +1074,10 @@ pub(crate) struct NestedLoopJoinStream { /// Memory-limited spill fallback state. See [`SpillState`] for details. spill_state: SpillState, + + /// Memory reservation for transient probe-side allocations + /// (Cartesian product indices, take() intermediates, output buffering). + probe_reservation: MemoryReservation, } pub(crate) struct NestedLoopJoinMetrics { @@ -1315,6 +1329,7 @@ impl NestedLoopJoinStream { metrics: NestedLoopJoinMetrics, batch_size: usize, spill_state: SpillState, + probe_reservation: MemoryReservation, ) -> Self { Self { output_schema: Arc::clone(&schema), @@ -1326,6 +1341,7 @@ impl NestedLoopJoinStream { metrics, buffered_left_data: None, output_buffer: Box::new(BatchCoalescer::new(schema, batch_size)), + output_buffer_reserved: 0, batch_size, current_right_batch: None, current_right_batch_matched: None, @@ -1337,6 +1353,7 @@ impl NestedLoopJoinStream { handled_empty_output: false, should_track_unmatched_right: need_produce_right_in_final(join_type), spill_state, + probe_reservation, } } @@ -1346,12 +1363,9 @@ impl NestedLoopJoinStream { } /// Check if we can fall back to memory-limited mode on this error. - fn can_fallback_to_spill(&self, error: &datafusion_common::DataFusionError) -> bool { + fn can_fallback_to_spill(&self, error: &DataFusionError) -> bool { matches!(self.spill_state, SpillState::Pending { .. }) - && matches!( - error.find_root(), - datafusion_common::DataFusionError::ResourcesExhausted(_) - ) + && matches!(error.find_root(), DataFusionError::ResourcesExhausted(_)) } /// Switch from the standard OnceFut path to memory-limited mode. 
@@ -1389,6 +1403,9 @@ impl NestedLoopJoinStream { ctx.runtime_env(), spill_metrics, Arc::clone(&schema), + MemoryConsumer::new("NestedLoopJoinLeftSpill") + .with_can_spill(true) + .register(ctx.memory_pool()), ) .with_compression_type(ctx.session_config().spill_compression()); @@ -1438,6 +1455,9 @@ impl NestedLoopJoinStream { context.runtime_env(), self.metrics.spill_metrics.clone(), right_schema, + MemoryConsumer::new("NestedLoopJoinRightSpill") + .with_can_spill(true) + .register(context.memory_pool()), ) .with_compression_type(context.session_config().spill_compression()); @@ -1528,10 +1548,11 @@ impl NestedLoopJoinStream { if active.left_stream.is_none() { match active.left_spill_fut.get_shared(cx) { Poll::Ready(Ok(spill_data)) => { - match spill_data - .spill_manager - .read_spill_as_stream(spill_data.spill_file.clone(), None) - { + match spill_data.spill_manager.read_spill_as_stream( + spill_data.spill_file.clone(), + None, + None, + ) { Ok(stream) => { active.left_schema = Some(Arc::clone(&spill_data.schema)); active.left_stream = Some(stream); @@ -1570,8 +1591,10 @@ impl NestedLoopJoinStream { if !can_grow && !active.pending_batches.is_empty() { // Memory limit reached and we already have data. - // Push this batch into pending (it's already in memory) - // and stop buffering for this chunk. + // This batch is already in memory and can't be returned + // to the source. It's the documented make-progress + // exception — one over-budget batch is accepted without + // reservation to avoid blocking downstream operators. active.pending_batches.push(batch); self.left_exhausted = false; self.left_buffered_in_one_pass = false; @@ -1614,6 +1637,18 @@ impl NestedLoopJoinStream { return ControlFlow::Continue(()); } + // RAII guard for concat dual-live overhead: both pending batches + // and concat result coexist during concat_batches. 
+ // Named exception (infallible grow): concat is required to create + // JoinLeftData for probe processing, and pending batches are already + // in memory — the total footprint doesn't actually increase. + let pending_size: usize = active + .pending_batches + .iter() + .map(|b| b.get_array_memory_size()) + .sum(); + let _concat_guard = active.reservation.grow_guard(pending_size); + let merged_batch = match concat_batches( active .left_schema @@ -1808,13 +1843,13 @@ impl NestedLoopJoinStream { "This state is yielding output for unmatched rows in the current right batch, so both the right batch and the bitmap must be present" ); match self.process_right_unmatched() { - Ok(Some(batch)) => match self.output_buffer.push_batch(batch) { + Ok(Some(batch)) => match self.push_output_batch(batch) { Ok(()) => { debug_assert!(self.current_right_batch.is_none()); self.state = NLJState::FetchingRight; ControlFlow::Continue(()) } - Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))), + Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), }, Ok(None) => { debug_assert!(self.current_right_batch.is_none()); @@ -1960,9 +1995,9 @@ impl NestedLoopJoinStream { self.join_type, JoinSide::Right, ) { - Ok(Some(batch)) => match self.output_buffer.push_batch(batch) { + Ok(Some(batch)) => match self.push_output_batch(batch) { Ok(()) => ControlFlow::Continue(()), - Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))), + Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), }, Ok(None) => ControlFlow::Continue(()), Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), @@ -2066,7 +2101,7 @@ impl NestedLoopJoinStream { )?; if let Some(batch) = joined_batch { - self.output_buffer.push_batch(batch)?; + self.push_output_batch(batch)?; } self.left_probe_idx += l_row_count; @@ -2079,7 +2114,7 @@ impl NestedLoopJoinStream { self.process_single_left_row_join(&left_data, &right_batch, l_idx)?; if let Some(batch) = joined_batch { - self.output_buffer.push_batch(batch)?; + 
self.push_output_batch(batch)?; } // ==== Prepare for the next iteration ==== @@ -2111,6 +2146,14 @@ impl NestedLoopJoinStream { let right_rows = right_batch.num_rows(); let total_rows = l_row_count * right_rows; + // Named exception (infallible grow): Cartesian product index arrays + // are required for join correctness. The probe phase cannot be deferred + // or spilled once started — the left chunk is already loaded. These + // arrays are function-scoped and freed when _indices_guard drops. + // Bounded by: 2 * l_row_count * right_rows * sizeof(u32). + let indices_size = 2 * total_rows * size_of::(); + let _indices_guard = self.probe_reservation.grow_guard(indices_size); + // Build index arrays for cartesian product: left_range X right_batch let left_indices: UInt32Array = UInt32Array::from_iter_values((0..l_row_count).flat_map(|i| { @@ -2129,9 +2172,31 @@ impl NestedLoopJoinStream { // Evaluate the join filter (if any) over an intermediate batch built // using the filter's own schema/column indices. let bitmap_combined = if let Some(filter) = &self.join_filter { - // Build the intermediate batch for filter evaluation + // Pre-estimate filter take sizes before allocation. + // Guard declared here (outside the if/else) so it covers the + // intermediate_batch through filter evaluation at line ~2213. + let filter_estimate = if !filter.schema.fields().is_empty() { + let mut est = 0usize; + for ci in filter.column_indices() { + let col = if ci.side == JoinSide::Left { + left_data.batch().column(ci.index) + } else { + right_batch.column(ci.index) + }; + let col_len = col.len().max(1); + est += total_rows * col.get_array_memory_size() / col_len; + } + est + } else { + 0 + }; + // Named exception (infallible grow): filter intermediate batch is + // built from take() on source columns for filter evaluation. The + // probe phase cannot be deferred. Function-scoped, freed on drop. + // Bounded by: sum(total_rows * col_size / col_len) per filter column. 
+ let _filter_guard = self.probe_reservation.grow_guard(filter_estimate); + let intermediate_batch = if filter.schema.fields().is_empty() { - // Constant predicate (e.g., TRUE/FALSE). Use an empty schema with row_count create_record_batch_with_empty_schema( Arc::new((*filter.schema).clone()), total_rows, @@ -2149,7 +2214,6 @@ impl NestedLoopJoinStream { }; filter_columns.push(array); } - RecordBatch::try_new(Arc::new((*filter.schema).clone()), filter_columns)? }; @@ -2159,7 +2223,6 @@ impl NestedLoopJoinStream { .into_array(intermediate_batch.num_rows())?; let filter_arr = as_boolean_array(&filter_result)?; - // Combine with null bitmap to get a unified mask boolean_mask_from_filter(filter_arr) } else { // No filter: all pairs match @@ -2258,6 +2321,24 @@ impl NestedLoopJoinStream { )?)); } + // Pre-estimate output take sizes before allocation. + let mut output_estimate = 0usize; + for column_index in &self.column_indices { + let col = if column_index.side == JoinSide::Left { + left_data.batch().column(column_index.index) + } else { + right_batch.column(column_index.index) + }; + let col_len = col.len().max(1); + output_estimate += total_rows * col.get_array_memory_size() / col_len; + } + // Named exception (infallible grow): output batch columns are built + // from take() on source arrays. The probe phase cannot be deferred. + // Function-scoped, freed on drop. The output batch is then moved to + // push_output_batch which tracks it separately via best-effort try_grow. + // Bounded by: sum(total_rows * col_size / col_len) per output column. + let _output_guard = self.probe_reservation.grow_guard(output_estimate); + let mut out_columns: Vec> = Vec::with_capacity(self.output_schema.fields().len()); for column_index in &self.column_indices { @@ -2370,7 +2451,7 @@ impl NestedLoopJoinStream { if let Some(batch) = self.process_left_unmatched_range(left_data, start_idx, end_idx)? 
{ - self.output_buffer.push_batch(batch)?; + self.push_output_batch(batch)?; } // ==== Prepare for the next iteration ==== @@ -2476,13 +2557,43 @@ impl NestedLoopJoinStream { .ok_or_else(|| internal_datafusion_err!("LeftData should be available")) } + /// Push a batch into the output buffer with best-effort memory tracking. + /// + /// Named exception (best-effort): the output batch is already allocated + /// by Arrow take() before reaching this point. Rejecting it would discard + /// valid join results without freeing the underlying allocation. Uses + /// try_grow so the pool counter reflects tracked output when capacity + /// allows, but accepts the batch regardless to avoid aborting mid-probe. + fn push_output_batch(&mut self, batch: RecordBatch) -> Result<()> { + let size = batch.get_array_memory_size(); + let reserved = self.probe_reservation.try_grow(size).is_ok(); + if reserved { + self.output_buffer_reserved += size; + } + if let Err(e) = self.output_buffer.push_batch(batch) { + if reserved { + self.probe_reservation.shrink(size); + self.output_buffer_reserved -= size; + } + return Err(DataFusionError::ArrowError(Box::new(e), None)); + } + Ok(()) + } + /// Flush the `output_buffer` if there are batches ready to output /// None if no result batch ready. 
fn maybe_flush_ready_batch(&mut self) -> Option>>> { if self.output_buffer.has_completed_batch() && let Some(batch) = self.output_buffer.next_completed_batch() { - // Update output rows for selectivity metric + // Release tracked output buffer reservation + let batch_size = batch.get_array_memory_size(); + let shrink = batch_size.min(self.output_buffer_reserved); + if shrink > 0 { + self.probe_reservation.shrink(shrink); + self.output_buffer_reserved -= shrink; + } + let output_rows = batch.num_rows(); self.metrics.selectivity.add_part(output_rows); @@ -3931,4 +4042,158 @@ pub(crate) mod tests { ")); Ok(()) } + + /// Verify that NLJ probe-side reservation returns to zero after + /// join completion for both spilling (inner) and non-spilling paths. + #[tokio::test] + async fn test_nlj_reservation_balanced_after_inner_join() -> Result<()> { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(50, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + let cfg = TaskContext::default() + .session_config() + .clone() + .with_batch_size(16); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(cfg) + .with_runtime(runtime), + ); + + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (_columns, batches, metrics) = + join_collect(left, right, &JoinType::Inner, Some(filter), task_ctx).await?; + + assert!(!batches.is_empty(), "Should produce output"); + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to verify spill-path accounting" + ); + + // After join completes and all streams/operators are dropped, + // pool should have zero reserved (all grow/shrink balanced). + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after NLJ inner join completes" + ); + Ok(()) + } + + /// Verify reservation balance for RIGHT JOIN which exercises the + /// global right bitmap accumulation and unmatched-row emission paths. 
+ #[tokio::test] + async fn test_nlj_reservation_balanced_after_right_join() -> Result<()> { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(50, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + let cfg = TaskContext::default() + .session_config() + .clone() + .with_batch_size(16); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(cfg) + .with_runtime(runtime), + ); + + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (_columns, batches, metrics) = + join_collect(left, right, &JoinType::Right, Some(filter), task_ctx).await?; + + assert!(!batches.is_empty(), "Should produce output"); + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to verify spill-path accounting" + ); + + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after NLJ right join completes" + ); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_reservation_balanced_after_left_join() -> Result<()> { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(50, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + let cfg = TaskContext::default() + .session_config() + .clone() + .with_batch_size(16); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(cfg) + .with_runtime(runtime), + ); + + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (_columns, batches, metrics) = + join_collect(left, right, &JoinType::Left, Some(filter), task_ctx).await?; + + assert!(!batches.is_empty(), "Should produce output"); + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to verify spill-path accounting" + ); + + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after NLJ left join completes" + ); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_reservation_balanced_after_full_join() -> Result<()> { + let 
runtime = RuntimeEnvBuilder::new() + .with_memory_limit(50, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + let cfg = TaskContext::default() + .session_config() + .clone() + .with_batch_size(16); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(cfg) + .with_runtime(runtime), + ); + + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (_columns, batches, metrics) = + join_collect(left, right, &JoinType::Full, Some(filter), task_ctx).await?; + + assert!(!batches.is_empty(), "Should produce output"); + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to verify spill-path accounting" + ); + + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after NLJ full join completes" + ); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs index a86cb647e4bff..c078db7689c96 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs @@ -510,6 +510,7 @@ impl ExecutionPlan for SortMergeJoinExec { context.runtime_env(), SpillMetrics::new(&self.metrics, partition), buffered.schema(), + reservation.new_empty(), ) .with_compression_type(context.session_config().spill_compression()); diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index c4377b3189ff7..de3664b19b423 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -4021,7 +4021,7 @@ fn test_stream_resources( let ctx = TaskContext::default(); let runtime_env = ctx.runtime_env(); let reservation = MemoryConsumer::new("test").register(ctx.memory_pool()); - let spill_manager = SpillManager::new( + let spill_manager = SpillManager::new_default( 
Arc::clone(&runtime_env), SpillMetrics::new(metrics, 0), inner_schema, @@ -4676,7 +4676,7 @@ async fn spill_filtered_boundary_loses_outer_rows() -> Result<()> { let metrics = ExecutionPlanMetricsSet::new(); let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); - let spill_manager = SpillManager::new( + let spill_manager = SpillManager::new_default( Arc::clone(&runtime), SpillMetrics::new(&metrics, 0), Arc::clone(&right_schema), diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 3d30dd82762b1..fb167e9ce43df 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -1239,6 +1239,9 @@ impl ExecutionPlan for RepartitionExec { Arc::clone(&context.runtime_env()), spill_metrics, input.schema(), + MemoryConsumer::new("RepartitionSpill") + .with_can_spill(true) + .register(context.memory_pool()), ); // Get existing ordering to use for merging diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index 8985e1d8c70ee..acd55e6cbb8e7 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -236,9 +236,12 @@ impl MultiLevelMergeBuilder { (1, 0) => { let spill_file = self.sorted_spill_files.remove(0); - // Not reserving any memory for this disk as we are not holding it in memory - self.spill_manager - .read_spill_as_stream(spill_file.file, None) + let read_reservation = self.reservation.take(); + self.spill_manager.read_spill_as_stream( + spill_file.file, + Some(spill_file.max_record_batch_memory), + Some(read_reservation), + ) } // Only in memory streams, so merge them all in a single pass @@ -292,6 +295,7 @@ impl MultiLevelMergeBuilder { .read_spill_as_stream( spill.file, Some(spill.max_record_batch_memory), + None, )?; sorted_streams.push(stream); } diff --git 
a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 929ff4f7dfc85..11a9222c32e29 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -292,6 +292,7 @@ impl ExternalSorter { Arc::clone(&runtime), metrics.spill_metrics.clone(), Arc::clone(&schema), + reservation.new_empty(), ) .with_compression_type(spill_compression); @@ -2980,4 +2981,59 @@ mod tests { assert_eq!(desc.self_filters()[0].len(), 1); Ok(()) } + + #[tokio::test] + async fn test_sort_spill_reservation_balanced() -> Result<()> { + let session_config = SessionConfig::new(); + let sort_spill_reservation_bytes = session_config + .options() + .execution + .sort_spill_reservation_bytes; + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(sort_spill_reservation_bytes + 12288, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); + + let partitions = 100; + let input = test::scan_partitioned(partitions); + let schema = input.schema(); + + let sort_exec = Arc::new(SortExec::new( + [PhysicalSortExpr { + expr: col("i", &schema)?, + options: SortOptions::default(), + }] + .into(), + Arc::new(CoalescePartitionsExec::new(input)), + )); + + let result = collect( + Arc::clone(&sort_exec) as Arc, + Arc::clone(&task_ctx), + ) + .await?; + assert!(!result.is_empty(), "Should produce output"); + + let metrics = sort_exec.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to verify spill-path accounting" + ); + + drop(result); + drop(sort_exec); + drop(task_ctx); + + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after sort with spilling completes" + ); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs index 
e0548bd5bf860..77745e5458193 100644 --- a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs +++ b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use arrow::array::RecordBatch; use datafusion_common::exec_datafusion_err; use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::MemoryReservation; use super::{ IPCStreamWriter, gc_view_arrays, @@ -38,20 +39,31 @@ pub struct InProgressSpillFile { writer: Option, /// Lazily initialized in-progress file, it will be moved out when the `finish` method is invoked in_progress_file: Option, + /// Memory reservation for tracking IPC write buffer overhead. + /// `append_batch` reserves memory before writing and releases it + /// after the write completes. Freed automatically on Drop. + reservation: MemoryReservation, } impl InProgressSpillFile { pub fn new( spill_writer: Arc, in_progress_file: RefCountedTempFile, + reservation: MemoryReservation, ) -> Self { Self { spill_writer, in_progress_file: Some(in_progress_file), writer: None, + reservation, } } + #[cfg(test)] + pub(crate) fn reservation_size(&self) -> usize { + self.reservation.size() + } + /// Appends a `RecordBatch` to the spill file, initializing the writer if necessary. /// /// Before writing, performs GC on StringView/BinaryView arrays to compact backing @@ -71,13 +83,24 @@ impl InProgressSpillFile { )); } + // Named exception (infallible grow): spill writes MUST complete + // even under memory pressure. Operators free their main reservation + // before spilling; failing here would prevent memory recovery. + // The grow/shrink is balanced within each call — the reservation + // returns to its pre-call size after the write completes or errors. 
+ let write_overhead = batch.get_array_memory_size(); + self.reservation.grow(write_overhead); + + let result = self.append_batch_inner(batch); + + self.reservation.shrink(write_overhead); + result + } + + fn append_batch_inner(&mut self, batch: &RecordBatch) -> Result { let gc_batch = gc_view_arrays(batch)?; if self.writer.is_none() { - // Use the SpillManager's declared schema rather than the batch's schema. - // Individual batches may have different schemas (e.g., different nullability) - // when they come from different branches of a UnionExec. The SpillManager's - // schema represents the canonical schema that all batches should conform to. let schema = self.spill_writer.schema(); if let Some(in_progress_file) = &mut self.in_progress_file { self.writer = Some(IPCStreamWriter::new( @@ -86,10 +109,8 @@ impl InProgressSpillFile { self.spill_writer.compression, )?); - // Update metrics self.spill_writer.metrics.spill_file_count.add(1); - // Update initial size (schema/header) in_progress_file.update_disk_usage()?; let initial_size = in_progress_file.current_disk_usage(); self.spill_writer @@ -111,9 +132,10 @@ impl InProgressSpillFile { .spilled_bytes .add((post_size - pre_size) as usize); } else { - unreachable!() // Already checked inside current function + unreachable!() } } + gc_batch.get_sliced_size() } @@ -180,7 +202,7 @@ mod tests { let runtime = Arc::new(RuntimeEnvBuilder::new().build()?); let metrics_set = ExecutionPlanMetricsSet::new(); let spill_metrics = SpillMetrics::new(&metrics_set, 0); - let spill_manager = Arc::new(SpillManager::new( + let spill_manager = Arc::new(SpillManager::new_default( runtime, spill_metrics, Arc::clone(&nullable_schema), @@ -210,7 +232,7 @@ mod tests { let spill_file = in_progress.finish()?.unwrap(); - let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; // Stream schema should be nullable assert_eq!(stream.schema(), nullable_schema); 
diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 3c95a1da5b33c..9e7184830ac47 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -54,8 +54,10 @@ use datafusion_common::config::SpillCompression; use datafusion_common::{DataFusionError, Result, exec_datafusion_err, exec_err}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::RecordBatchStream; +use datafusion_execution::SendableRecordBatchStream; use datafusion_execution::disk_manager::RefCountedTempFile; -use futures::{FutureExt as _, Stream}; +use datafusion_execution::memory_pool::MemoryReservation; +use futures::{FutureExt as _, Stream, StreamExt as _}; use log::debug; /// Stream that reads spill files from disk where each batch is read in a spawned blocking task @@ -69,10 +71,13 @@ struct SpillReaderStream { schema: SchemaRef, state: SpillReaderStreamState, /// Maximum memory size observed among spilling sorted record batches. - /// This is used for validation purposes during reading each RecordBatch from spill. - /// For context on why this value is recorded and validated, - /// see `physical_plan/sort/multi_level_merge.rs`. max_record_batch_memory: Option, + /// Optional reservation for tracking decoded batch memory. + /// When provided, grows when a batch is decoded and shrinks + /// when the next batch is read (previous batch consumed by caller). + reservation: Option, + /// Size of the last decoded batch, used for shrinking on next poll. 
+ last_batch_size: usize, } // Small margin allowed to accommodate slight memory accounting variation @@ -102,20 +107,72 @@ impl SpillReaderStream { schema: SchemaRef, spill_file: RefCountedTempFile, max_record_batch_memory: Option, + reservation: Option, ) -> Self { Self { schema, state: SpillReaderStreamState::Uninitialized(spill_file), max_record_batch_memory, + reservation, + last_batch_size: 0, } } + /// Pre-reserve one decoded-batch slot before launching a blocking read. + /// Shrinks the previous batch's reservation first, then grows for the + /// next expected batch. Returns Err(ResourcesExhausted) if the pool + /// cannot accommodate the next batch. + fn pre_reserve_next_batch(&mut self) -> Result<()> { + if let (Some(res), Some(max_mem)) = + (&self.reservation, self.max_record_batch_memory) + { + if self.last_batch_size > 0 { + res.shrink(self.last_batch_size); + self.last_batch_size = 0; + } + res.try_grow(max_mem)?; + self.last_batch_size = max_mem; + } + Ok(()) + } + + /// After a decoded batch arrives, adjust the reservation from the + /// pre-reserved estimate to the actual batch size. 
+ fn adjust_reservation_to_actual(&mut self, batch: &RecordBatch) { + let Some(res) = &self.reservation else { + return; + }; + let actual_size = get_record_batch_memory_size(batch); + + if self.max_record_batch_memory.is_some() { + // Pre-reserved path: adjust from estimate to actual + if actual_size > self.last_batch_size { + res.grow(actual_size - self.last_batch_size); + } else if actual_size < self.last_batch_size { + res.shrink(self.last_batch_size - actual_size); + } + } else { + // Fallback: no max_record_batch_memory, post-decode accounting + if self.last_batch_size > 0 { + res.shrink(self.last_batch_size); + } + res.grow(actual_size); + } + self.last_batch_size = actual_size; + } + fn poll_next_inner( &mut self, cx: &mut Context<'_>, ) -> Poll>> { match &mut self.state { SpillReaderStreamState::Uninitialized(_) => { + // Pre-reserve before the first blocking read + if let Err(e) = self.pre_reserve_next_batch() { + self.state = SpillReaderStreamState::Done; + return Poll::Ready(Some(Err(e))); + } + // Temporarily replace with `Done` to be able to pass the file to the task. let SpillReaderStreamState::Uninitialized(spill_file) = std::mem::replace(&mut self.state, SpillReaderStreamState::Done) @@ -126,15 +183,10 @@ impl SpillReaderStream { let expected_schema = Arc::clone(&self.schema); let task = SpawnedTask::spawn_blocking(move || { let file = BufReader::new(File::open(spill_file.path())?); - // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications - // with validated schemas and buffers. Skip redundant validation during read - // to speedup read operation. This is safe for DataFusion as input guaranteed to be correct when written. 
let mut reader = unsafe { StreamReader::try_new(file, None)?.with_skip_validation(true) }; - // Validate the schema read from Arrow IPC file is the same as the - // schema of the current `SpillManager` let actual_schema = reader.schema(); if actual_schema != expected_schema { @@ -146,8 +198,6 @@ impl SpillReaderStream { ); } - // TODO: Same-schema reads from a different SpillManager still pass today. - // Add a SpillManager UID to IPC metadata and validate it here as well. let next_batch = reader.next().transpose()?; Ok((reader, next_batch)) @@ -155,8 +205,6 @@ impl SpillReaderStream { self.state = SpillReaderStreamState::ReadInProgress(task); - // Poll again immediately so the inner task is polled and the waker is - // registered. self.poll_next_inner(cx) } @@ -185,12 +233,21 @@ impl SpillReaderStream { ); } } + + self.adjust_reservation_to_actual(&batch); + self.state = SpillReaderStreamState::Waiting(reader); Poll::Ready(Some(Ok(batch))) } None => { - // Stream is done + // Stream done — release any remaining reservation + if let Some(res) = &self.reservation + && self.last_batch_size > 0 + { + res.shrink(self.last_batch_size); + self.last_batch_size = 0; + } self.state = SpillReaderStreamState::Done; Poll::Ready(None) @@ -198,6 +255,13 @@ impl SpillReaderStream { } } Err(err) => { + // Release pre-reservation on error + if let Some(res) = &self.reservation + && self.last_batch_size > 0 + { + res.shrink(self.last_batch_size); + self.last_batch_size = 0; + } self.state = SpillReaderStreamState::Done; Poll::Ready(Some(Err(err))) @@ -206,6 +270,12 @@ impl SpillReaderStream { } SpillReaderStreamState::Waiting(_) => { + // Pre-reserve before the next blocking read + if let Err(e) = self.pre_reserve_next_batch() { + self.state = SpillReaderStreamState::Done; + return Poll::Ready(Some(Err(e))); + } + // Temporarily replace with `Done` to be able to pass the file to the task. 
let SpillReaderStreamState::Waiting(mut reader) = std::mem::replace(&mut self.state, SpillReaderStreamState::Done) @@ -221,8 +291,6 @@ impl SpillReaderStream { self.state = SpillReaderStreamState::ReadInProgress(task); - // Poll again immediately so the inner task is polled and the waker is - // registered. self.poll_next_inner(cx) } @@ -245,6 +313,58 @@ impl RecordBatchStream for SpillReaderStream { } } +/// Wraps a buffered read stream with a capacity-level `MemoryReservation`. +/// +/// For buffered reads (via `spawn_buffered`), multiple decoded batches can +/// be simultaneously live in the channel. Per-batch tracking in the inner +/// stream would under-count because shrinking happens before the consumer +/// actually drops the previous batch. Instead, the caller pre-reserves +/// `max_batch_memory * buffer_capacity` and this wrapper holds that +/// reservation alive until the stream is fully consumed or dropped. +pub(crate) struct ReadStreamWithReservation { + stream: SendableRecordBatchStream, + reservation: MemoryReservation, +} + +impl ReadStreamWithReservation { + pub(crate) fn new( + stream: SendableRecordBatchStream, + reservation: MemoryReservation, + ) -> Self { + Self { + stream, + reservation, + } + } +} + +impl Stream for ReadStreamWithReservation { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.stream.poll_next_unpin(cx) { + Poll::Ready(Some(Err(err))) => { + self.reservation.free(); + Poll::Ready(Some(Err(err))) + } + Poll::Ready(None) => { + self.reservation.free(); + Poll::Ready(None) + } + other => other, + } + } +} + +impl RecordBatchStream for ReadStreamWithReservation { + fn schema(&self) -> SchemaRef { + self.stream.schema() + } +} + /// Spill the `RecordBatch` to disk as smaller batches /// split by `batch_size_rows` #[deprecated( @@ -560,7 +680,6 @@ mod tests { use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; use 
datafusion_execution::runtime_env::RuntimeEnv; - use futures::StreamExt as _; #[tokio::test] async fn test_batch_spill_and_read() -> Result<()> { @@ -582,7 +701,7 @@ mod tests { // Construct SpillManager let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let spill_manager = SpillManager::new_default(env, metrics, Arc::clone(&schema)); let spill_file = spill_manager .spill_record_batch_and_finish(&[batch1, batch2], "Test")? @@ -591,7 +710,7 @@ mod tests { let spilled_rows = spill_manager.metrics.spilled_rows.value(); assert_eq!(spilled_rows, num_rows); - let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; assert_eq!(stream.schema(), schema); let batches = collect(stream).await?; @@ -646,7 +765,8 @@ mod tests { // Construct SpillManager let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, Arc::clone(&dict_schema)); + let spill_manager = + SpillManager::new_default(env, metrics, Arc::clone(&dict_schema)); let num_rows = batch1.num_rows() + batch2.num_rows(); let spill_file = spill_manager @@ -655,7 +775,7 @@ mod tests { let spilled_rows = spill_manager.metrics.spilled_rows.value(); assert_eq!(spilled_rows, num_rows); - let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; assert_eq!(stream.schema(), dict_schema); let batches = collect(stream).await?; assert_eq!(batches.len(), 2); @@ -674,7 +794,7 @@ mod tests { let schema = batch1.schema(); let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let 
spill_manager = SpillManager::new_default(env, metrics, Arc::clone(&schema)); let row_batches: Vec = (0..batch1.num_rows()).map(|i| batch1.slice(i, 1)).collect(); @@ -687,7 +807,7 @@ mod tests { assert!(spill_file.path().exists()); assert!(max_batch_mem > 0); - let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; assert_eq!(stream.schema(), schema); let batches = collect(stream).await?; @@ -722,7 +842,7 @@ mod tests { let spilled_rows = spill_manager.metrics.spilled_rows.value(); assert_eq!(spilled_rows, num_rows); - let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; assert_eq!(stream.schema(), schema); let batches = collect(stream).await?; @@ -744,16 +864,16 @@ mod tests { let uncompressed_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let lz4_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let zstd_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let uncompressed_spill_manager = SpillManager::new( + let uncompressed_spill_manager = SpillManager::new_default( Arc::clone(&env), uncompressed_metrics, Arc::clone(&schema), ); let lz4_spill_manager = - SpillManager::new(Arc::clone(&env), lz4_metrics, Arc::clone(&schema)) + SpillManager::new_default(Arc::clone(&env), lz4_metrics, Arc::clone(&schema)) .with_compression_type(SpillCompression::Lz4Frame); let zstd_spill_manager = - SpillManager::new(env, zstd_metrics, Arc::clone(&schema)) + SpillManager::new_default(env, zstd_metrics, Arc::clone(&schema)) .with_compression_type(SpillCompression::Zstd); let uncompressed_spill_file = uncompressed_spill_manager .spill_record_batch_and_finish(&batches, "Test")? 
@@ -814,7 +934,7 @@ mod tests { Field::new("b", DataType::Utf8, false), ])); - let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let spill_manager = SpillManager::new_default(env, metrics, Arc::clone(&schema)); let batch = RecordBatch::try_new( schema, @@ -872,7 +992,7 @@ mod tests { ])); let spill_manager = - Arc::new(SpillManager::new(env, metrics, Arc::clone(&schema))); + Arc::new(SpillManager::new_default(env, metrics, Arc::clone(&schema))); let mut in_progress_file = spill_manager.create_in_progress_file("Test")?; let batch1 = RecordBatch::try_new( @@ -920,7 +1040,7 @@ mod tests { ])); let spill_manager = - Arc::new(SpillManager::new(env, metrics, Arc::clone(&schema))); + Arc::new(SpillManager::new_default(env, metrics, Arc::clone(&schema))); // Test write empty batch with interface `InProgressSpillFile` and `append_batch()` let mut in_progress_file = spill_manager.create_in_progress_file("Test")?; @@ -968,7 +1088,8 @@ mod tests { // Construct SpillManager let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let spill_manager = + SpillManager::new_default(env, metrics, Arc::clone(&schema)); let batches: [_; 10] = std::array::from_fn(|_| batch.clone()); let spill_file_1 = spill_manager @@ -979,9 +1100,9 @@ mod tests { .unwrap(); let mut stream_1 = - spill_manager.read_spill_as_stream(spill_file_1, None)?; + spill_manager.read_spill_as_stream(spill_file_1, None, None)?; let mut stream_2 = - spill_manager.read_spill_as_stream(spill_file_2, None)?; + spill_manager.read_spill_as_stream(spill_file_2, None, None)?; stream_1.next().await; stream_2.next().await; @@ -1012,7 +1133,7 @@ mod tests { Field::new("b", DataType::Utf8, false), ])); - let spill_manager = Arc::new(SpillManager::new( + let spill_manager = Arc::new(SpillManager::new_default( Arc::clone(&env), metrics.clone(), Arc::clone(&schema), @@ 
-1305,7 +1426,7 @@ mod tests { let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, schema); + let spill_manager = SpillManager::new_default(env, metrics, schema); let mut in_progress_file = spill_manager.create_in_progress_file("Test GC")?; @@ -1478,7 +1599,7 @@ mod tests { // 3. Spill to disk using SpillManager let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, schema); + let spill_manager = SpillManager::new_default(env, metrics, schema); let spill_file = spill_manager .spill_record_batch_and_finish(&[sliced_batch], "TestGC")? .unwrap(); @@ -1522,7 +1643,7 @@ mod tests { // 3. Spill to disk using SpillManager let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, schema); + let spill_manager = SpillManager::new_default(env, metrics, schema); let spill_file = spill_manager .spill_record_batch_and_finish(&[sliced_batch], "TestGCBinary")? .unwrap(); diff --git a/datafusion/physical-plan/src/spill/replayable_spill_input.rs b/datafusion/physical-plan/src/spill/replayable_spill_input.rs index fea998d268c59..c4ab15eae6893 100644 --- a/datafusion/physical-plan/src/spill/replayable_spill_input.rs +++ b/datafusion/physical-plan/src/spill/replayable_spill_input.rs @@ -225,7 +225,7 @@ impl ReplayableSpillStream { spill_file: Option, ) -> Result { let inner = if let Some(file) = spill_file.as_ref() { - spill_manager.read_spill_as_stream(file.clone(), None)? + spill_manager.read_spill_as_stream(file.clone(), None, None)? 
} else { Box::pin(EmptyRecordBatchStream::new(Arc::clone(&schema))) }; @@ -344,7 +344,7 @@ mod tests { let runtime = Arc::new(RuntimeEnvBuilder::new().build()?); let metrics_set = ExecutionPlanMetricsSet::new(); let spill_metrics = SpillMetrics::new(&metrics_set, 0); - Ok(SpillManager::new(runtime, spill_metrics, schema)) + Ok(SpillManager::new_default(runtime, spill_metrics, schema)) } fn build_batch(schema: SchemaRef, values: Vec) -> Result { diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 365a9f977eace..9a2f8707b399a 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -17,7 +17,10 @@ //! Define the `SpillManager` struct, which is responsible for reading and writing `RecordBatch`es to raw files based on the provided configurations. -use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile}; +use super::{ + ReadStreamWithReservation, SpillReaderStream, + in_progress_spill_file::InProgressSpillFile, +}; use crate::coop::cooperative; use crate::{common::spawn_buffered, metrics::SpillMetrics}; use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray}; @@ -26,6 +29,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result, config::SpillCompression}; use datafusion_execution::SendableRecordBatchStream; use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::runtime_env::RuntimeEnv; use std::borrow::Borrow; use std::sync::Arc; @@ -36,7 +40,7 @@ use std::sync::Arc; /// /// Note: The caller (external operators such as `SortExec`) is responsible for interpreting the spilled files. /// For example, all records within the same spill file are ordered according to a specific order. 
-#[derive(Debug, Clone)] +#[derive(Debug)] pub struct SpillManager { env: Arc, pub(crate) metrics: SpillMetrics, @@ -45,19 +49,57 @@ pub struct SpillManager { batch_read_buffer_capacity: usize, /// general-purpose compression options pub(crate) compression: SpillCompression, + /// Owned reservation split from the operator's reservation. + /// Per-file write reservations are split from this. + reservation: MemoryReservation, +} + +impl Clone for SpillManager { + fn clone(&self) -> Self { + Self { + env: Arc::clone(&self.env), + metrics: self.metrics.clone(), + schema: Arc::clone(&self.schema), + batch_read_buffer_capacity: self.batch_read_buffer_capacity, + compression: self.compression, + reservation: self.reservation.new_empty(), + } + } } impl SpillManager { - pub fn new(env: Arc, metrics: SpillMetrics, schema: SchemaRef) -> Self { + pub fn new( + env: Arc, + metrics: SpillMetrics, + schema: SchemaRef, + reservation: MemoryReservation, + ) -> Self { Self { env, metrics, schema, batch_read_buffer_capacity: 2, compression: SpillCompression::default(), + reservation, } } + /// Convenience constructor for tests that creates an independent + /// reservation. Production code must use `new()` with an operator-split + /// reservation. + #[cfg(test)] + pub(crate) fn new_default( + env: Arc, + metrics: SpillMetrics, + schema: SchemaRef, + ) -> Self { + use datafusion_execution::memory_pool::MemoryConsumer; + let reservation = MemoryConsumer::new("SpillManager") + .with_can_spill(true) + .register(&env.memory_pool); + Self::new(env, metrics, schema, reservation) + } + pub fn with_batch_read_buffer_capacity( mut self, batch_read_buffer_capacity: usize, @@ -76,15 +118,19 @@ impl SpillManager { &self.schema } - /// Creates a temporary file for in-progress operations, returning an error - /// message if file creation fails. The file can be used to append batches - /// incrementally and then finish the file when done. 
+ /// Creates a temporary file for in-progress operations with automatic + /// IPC write buffer accounting via the memory pool. pub fn create_in_progress_file( &self, request_msg: &str, ) -> Result { let temp_file = self.env.disk_manager.create_tmp_file(request_msg)?; - Ok(InProgressSpillFile::new(Arc::new(self.clone()), temp_file)) + let reservation = self.reservation.new_empty(); + Ok(InProgressSpillFile::new( + Arc::new(self.clone()), + temp_file, + reservation, + )) } /// Spill input `batches` into a single file in a atomic operation. If it is @@ -176,30 +222,68 @@ impl SpillManager { /// /// That path uses the maximum spilled batch size to conservatively estimate /// the merge degree when merging multiple sorted runs. + /// + /// # Arg `reservation` + /// + /// Optional caller-owned reservation for tracking decoded batch memory. + /// When provided along with `max_record_batch_memory`, it is grown to at least + /// `max_record_batch_memory * batch_read_buffer_capacity` bytes (what it already + /// holds counts toward that), covering all batches that can be simultaneously + /// live in the `spawn_buffered` channel; an error is returned if that growth + /// fails. The reservation is freed when the stream ends or is dropped. + /// + /// Callers that already track decoded batches via their own reservation + /// (e.g., NLJ build-side, sort merge multi-file path) should pass + /// `None` to avoid double-counting. pub fn read_spill_as_stream( &self, spill_file_path: RefCountedTempFile, max_record_batch_memory: Option, + reservation: Option, ) -> Result { + // Reserve capacity BEFORE spawning the producer task. + // spawn_buffered immediately starts a producer that decodes batches, + // so the reservation must be in place before any allocation happens. 
+ if let (Some(res), Some(max_mem)) = (&reservation, max_record_batch_memory) { + let capacity_bytes = max_mem.saturating_mul(self.batch_read_buffer_capacity); + let deficit = capacity_bytes.saturating_sub(res.size()); + if deficit > 0 { + res.try_grow(deficit)?; + } + } + let stream = Box::pin(cooperative(SpillReaderStream::new( Arc::clone(&self.schema), spill_file_path, max_record_batch_memory, + None, // per-batch tracking not used for buffered reads ))); - Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) + let buffered = spawn_buffered(stream, self.batch_read_buffer_capacity); + + match reservation { + Some(res) => Ok(Box::pin(ReadStreamWithReservation::new(buffered, res))), + None => Ok(buffered), + } } /// Same as `read_spill_as_stream`, but without buffering. + /// + /// When `reservation` is provided, per-batch tracking is used: + /// each decoded batch grows the reservation and the previous batch's + /// reservation is shrunk. This is correct for unbuffered reads where + /// only one batch is live at a time. pub fn read_spill_as_stream_unbuffered( &self, spill_file_path: RefCountedTempFile, max_record_batch_memory: Option, + reservation: Option, ) -> Result { Ok(Box::pin(cooperative(SpillReaderStream::new( Arc::clone(&self.schema), spill_file_path, max_record_batch_memory, + reservation, )))) } } @@ -265,7 +349,7 @@ mod tests { schema: Arc, ) -> SpillManager { let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - SpillManager::new(env, metrics, schema) + SpillManager::new_default(env, metrics, schema) } fn build_writer_batch(schema: Arc) -> Result { @@ -307,7 +391,7 @@ mod tests { // Same-schema reads through a different SpillManager currently pass // because only schema compatibility is validated. This is not a // supported usage pattern. 
- let stream = reader.read_spill_as_stream(spill_file, None)?; + let stream = reader.read_spill_as_stream(spill_file, None, None)?; assert_eq!(stream.schema(), reader_schema); let batches = collect(stream).await?; @@ -341,7 +425,7 @@ mod tests { )? .unwrap(); - let stream = reader.read_spill_as_stream(spill_file, None)?; + let stream = reader.read_spill_as_stream(spill_file, None, None)?; let err = collect(stream) .await .expect_err("schema mismatch should fail fast"); @@ -402,4 +486,344 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_spill_write_reservation_balanced() -> Result<()> { + let env = Arc::new(RuntimeEnv::default()); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + let batch = build_writer_batch(schema)?; + + let mut in_progress = spill_manager.create_in_progress_file("test_balanced")?; + in_progress.append_batch(&batch)?; + in_progress.append_batch(&batch)?; + + // After each append the grow/shrink is balanced within the call + let reserved_during = env.memory_pool.reserved(); + + let _file = in_progress.finish()?; + drop(in_progress); + + // After drop, reservation should be fully released + assert_eq!( + env.memory_pool.reserved(), + 0, + "Pool should have zero reserved after InProgressSpillFile is dropped, got {reserved_during} during" + ); + Ok(()) + } + + #[tokio::test] + async fn test_spill_read_reservation_tracked() -> Result<()> { + let env = Arc::new(RuntimeEnv::default()); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + let batch = build_writer_batch(schema)?; + + // Write a spill file + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch], "test_read")? 
+ .expect("should have spill file"); + + // Read back without a read reservation (`None`); no decoded-batch tracking is attached + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; + let batches = collect(stream).await?; + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 3); + + // After stream is consumed and dropped, pool reserved should be zero + assert_eq!( + env.memory_pool.reserved(), + 0, + "Pool should have zero reserved after read stream is consumed" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_spill_write_balanced_under_exhausted_pool() -> Result<()> { + use datafusion_execution::memory_pool::GreedyMemoryPool; + + let pool: Arc = + Arc::new(GreedyMemoryPool::new(64)); + let env = Arc::new( + datafusion_execution::runtime_env::RuntimeEnvBuilder::new() + .with_memory_pool(Arc::clone(&pool)) + .build()?, + ); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + + // Exhaust most of the pool + let blocker = datafusion_execution::memory_pool::MemoryConsumer::new("blocker") + .register(&pool); + blocker.grow(60); + + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + let batch = build_writer_batch(schema)?; + + // append_batch uses infallible grow — it must succeed even when + // the pool is nearly exhausted (spill-write make-progress exception). 
+ let mut in_progress = spill_manager.create_in_progress_file("exhausted")?; + in_progress.append_batch(&batch)?; + + // Reservation should be balanced within the call + let reserved_after_append = in_progress.reservation_size(); + assert_eq!( + reserved_after_append, 0, + "Write reservation should be zero after append (grow/shrink balanced)" + ); + + let _file = in_progress.finish()?; + drop(in_progress); + + // Release the blocker + blocker.free(); + + assert_eq!( + pool.reserved(), + 0, + "Pool should be zero after all reservations freed" + ); + Ok(()) + } + + #[tokio::test] + async fn test_unbuffered_read_reservation_tracks_batches() -> Result<()> { + use datafusion_execution::memory_pool::MemoryConsumer; + + let env = Arc::new(RuntimeEnv::default()); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + let batch = build_writer_batch(Arc::clone(&schema))?; + + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch.clone(), batch], "test_read")? 
+ .expect("should have spill file"); + + let read_reservation = + MemoryConsumer::new("read_test").register(&env.memory_pool); + + let mut stream = spill_manager.read_spill_as_stream_unbuffered( + spill_file, + None, + Some(read_reservation), + )?; + + use futures::StreamExt; + // Read first batch — reservation should grow + let b1 = stream.next().await.unwrap()?; + assert_eq!(b1.num_rows(), 3); + let reserved_after_b1 = env.memory_pool.reserved(); + assert!( + reserved_after_b1 > 0, + "Reservation should be non-zero after reading first batch" + ); + + // Read second batch — previous shrinks, current grows + let b2 = stream.next().await.unwrap()?; + assert_eq!(b2.num_rows(), 3); + let reserved_after_b2 = env.memory_pool.reserved(); + assert!( + reserved_after_b2 > 0, + "Reservation should be non-zero after reading second batch" + ); + + // Stream ends — remaining reservation freed + assert!(stream.next().await.is_none()); + assert_eq!( + env.memory_pool.reserved(), + 0, + "Pool should be zero after unbuffered read stream consumed" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_buffered_read_reservation_prereserves_capacity() -> Result<()> { + use datafusion_execution::memory_pool::MemoryConsumer; + + let env = Arc::new(RuntimeEnv::default()); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + let batch = build_writer_batch(Arc::clone(&schema))?; + let batch_mem = get_record_batch_memory_size(&batch); + + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch], "test_read")? 
+ .expect("should have spill file"); + + let read_reservation = + MemoryConsumer::new("read_test").register(&env.memory_pool); + + let buffer_capacity = spill_manager.batch_read_buffer_capacity; + let expected_capacity = batch_mem * buffer_capacity; + + let stream = spill_manager.read_spill_as_stream( + spill_file, + Some(batch_mem), + Some(read_reservation), + )?; + + // Capacity should be pre-reserved immediately + assert_eq!( + env.memory_pool.reserved(), + expected_capacity, + "Pool should have capacity pre-reserved for buffered read" + ); + + // Consume the stream + let batches = collect(stream).await?; + assert_eq!(batches.len(), 1); + + // After stream consumed and dropped, reservation freed + assert_eq!( + env.memory_pool.reserved(), + 0, + "Pool should be zero after buffered read stream consumed" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_buffered_read_reuses_transferred_reservation() -> Result<()> { + use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer}; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let batch = build_writer_batch(Arc::clone(&schema))?; + let batch_mem = get_record_batch_memory_size(&batch); + + // Pool just large enough for the read buffer capacity + let buffer_capacity = 2usize; + let capacity_bytes = batch_mem * buffer_capacity; + let pool_size = capacity_bytes + batch_mem; // extra for write overhead + let pool: Arc = + Arc::new(GreedyMemoryPool::new(pool_size)); + let env = Arc::new( + datafusion_execution::runtime_env::RuntimeEnvBuilder::new() + .with_memory_pool(Arc::clone(&pool)) + .build()?, + ); + + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch], "test")? 
+ .expect("should have spill file"); + + // Pre-grow read reservation to full capacity (simulating take() + // from a merge reservation that already holds these bytes) + let read_reservation = MemoryConsumer::new("read").register(&env.memory_pool); + read_reservation.grow(capacity_bytes); + + // Consume ALL remaining pool capacity with a competing consumer + let remaining = pool_size - pool.reserved(); + let blocker = MemoryConsumer::new("blocker").register(&env.memory_pool); + blocker.grow(remaining); + assert_eq!(pool.reserved(), pool_size); + + // With grow-to-at-least semantics, read_spill_as_stream should + // succeed because the pre-grown reservation already covers the + // required capacity — no additional pool allocation needed. + let stream = spill_manager.read_spill_as_stream( + spill_file, + Some(batch_mem), + Some(read_reservation), + )?; + + let batches = collect(stream).await?; + assert_eq!(batches.len(), 1); + + blocker.free(); + assert_eq!( + pool.reserved(), + 0, + "Pool should be zero after all reservations freed" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_unbuffered_read_exhausted_pool_returns_error() -> Result<()> { + use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer}; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let batch = build_writer_batch(Arc::clone(&schema))?; + let batch_mem = get_record_batch_memory_size(&batch); + + // Pool large enough for writing but not for a subsequent read + let pool_size = batch_mem + 64; + let pool: Arc = + Arc::new(GreedyMemoryPool::new(pool_size)); + let env = Arc::new( + datafusion_execution::runtime_env::RuntimeEnvBuilder::new() + .with_memory_pool(Arc::clone(&pool)) + .build()?, + ); + + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch], "test")? 
+ .expect("should have spill file"); + + // Exhaust the pool completely + let blocker = MemoryConsumer::new("blocker").register(&pool); + blocker.grow(pool_size - pool.reserved()); + + let read_reservation = MemoryConsumer::new("read").register(&pool); + + // Unbuffered read with pre-reservation: try_grow(batch_mem) + // should fail with controlled ResourcesExhausted + let mut stream = spill_manager.read_spill_as_stream_unbuffered( + spill_file, + Some(batch_mem), + Some(read_reservation), + )?; + + use futures::StreamExt; + let result = stream.next().await; + assert!(result.is_some()); + let err = result.unwrap().unwrap_err(); + assert!( + err.to_string().contains("Resources exhausted"), + "Expected ResourcesExhausted error, got: {err}" + ); + + blocker.free(); + assert_eq!(pool.reserved(), 0); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/spill/spill_pool.rs b/datafusion/physical-plan/src/spill/spill_pool.rs index 2639188a2609d..beb0a0b07373e 100644 --- a/datafusion/physical-plan/src/spill/spill_pool.rs +++ b/datafusion/physical-plan/src/spill/spill_pool.rs @@ -399,6 +399,7 @@ impl Drop for SpillPoolWriter { /// # use datafusion_physical_plan::spill::spill_pool; /// # use datafusion_physical_plan::spill::SpillManager; // Re-exported for doctests /// # use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, SpillMetrics}; +/// # use datafusion_execution::memory_pool::MemoryConsumer; /// # /// # #[tokio::main] /// # async fn main() -> datafusion_common::Result<()> { @@ -406,7 +407,8 @@ impl Drop for SpillPoolWriter { /// # let env = Arc::new(RuntimeEnv::default()); /// # let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); /// # let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); -/// # let spill_manager = Arc::new(SpillManager::new(env, metrics, schema.clone())); +/// # let reservation = MemoryConsumer::new("example").with_can_spill(true).register(&env.memory_pool); +/// # let spill_manager = 
Arc::new(SpillManager::new(Arc::clone(&env), metrics, schema.clone(), reservation)); /// # /// // Create channel with 1MB file size limit /// let (writer, mut reader) = spill_pool::channel(1024 * 1024, spill_manager); @@ -565,7 +567,7 @@ impl Stream for SpillFile { // we want this unbuffered because files are actively being written to match self .spill_manager - .read_spill_as_stream_unbuffered(file, None) + .read_spill_as_stream_unbuffered(file, None, None) { Ok(stream) => { self.reader = Some(SpillFileReader { @@ -767,7 +769,7 @@ mod tests { let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let schema = create_test_schema(); - let spill_manager = Arc::new(SpillManager::new(env, metrics, schema)); + let spill_manager = Arc::new(SpillManager::new_default(env, metrics, schema)); channel(max_file_size, spill_manager) } @@ -778,7 +780,8 @@ mod tests { let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let schema = create_test_schema(); - let spill_manager = Arc::new(SpillManager::new(env, metrics.clone(), schema)); + let spill_manager = + Arc::new(SpillManager::new_default(env, metrics.clone(), schema)); let (writer, reader) = channel(max_file_size, spill_manager); (writer, reader, metrics) @@ -1319,8 +1322,11 @@ mod tests { let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let schema = create_test_schema(); - let spill_manager = - Arc::new(SpillManager::new(Arc::clone(&env), metrics.clone(), schema)); + let spill_manager = Arc::new(SpillManager::new_default( + Arc::clone(&env), + metrics.clone(), + schema, + )); let (writer, mut reader) = channel(1024 * 1024, spill_manager); @@ -1461,7 +1467,8 @@ mod tests { let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let schema = create_test_schema(); - let spill_manager = Arc::new(SpillManager::new(runtime, metrics.clone(), 
schema)); + let spill_manager = + Arc::new(SpillManager::new_default(runtime, metrics.clone(), schema)); let (writer, mut reader) = channel(batch_size, spill_manager); diff --git a/datafusion/sqllogictest/src/accounting_pool.rs b/datafusion/sqllogictest/src/accounting_pool.rs index a9d2db9f12261..aec453f0994f6 100644 --- a/datafusion/sqllogictest/src/accounting_pool.rs +++ b/datafusion/sqllogictest/src/accounting_pool.rs @@ -39,8 +39,9 @@ use std::sync::Arc; /// Headroom over the pool's declared limit. Anything past this is an /// untracked allocation — by definition, since DF's pool didn't see it. /// -/// 800% high, but that's what it takes to pass the SLT suite right now. Goal should be ~10% -const HEADROOM_FACTOR: f64 = 8.0; +/// Reduced from 8.0 after adding SpillManager IPC buffer accounting, +/// GroupValuesRows emit pre-reservation, and NLJ probe-side tracking. +const HEADROOM_FACTOR: f64 = 6.0; pub struct AccountingMemoryPool { inner: Arc,