diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index e50f72632b3f2..696bf3acf8483 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -521,6 +521,68 @@ impl MemoryReservation { pub fn take(&mut self) -> MemoryReservation { self.split(self.size.load(atomic::Ordering::Relaxed)) } + + /// Attempts to grow the reservation by `capacity` bytes and returns + /// a [`ReservationGuard`] that will automatically shrink the reservation + /// when dropped. + /// + /// This is useful for tracking transient allocations that are freed + /// when a scope exits (including error paths via `?`). + /// + /// Call [`ReservationGuard::release`] to prevent the automatic shrink + /// when ownership of the allocated memory is transferred elsewhere. + pub fn try_grow_guard(&self, capacity: usize) -> Result> { + self.try_grow(capacity)?; + Ok(ReservationGuard { + reservation: self, + size: capacity, + }) + } + + /// Grows the reservation by `capacity` bytes (infallible) and returns + /// a [`ReservationGuard`] that will automatically shrink on drop. + /// + /// Use only for named exceptions where the allocation is required for + /// correctness/progress and cannot be deferred (e.g., join probe-side + /// index arrays that must be allocated to produce results). + pub fn grow_guard(&self, capacity: usize) -> ReservationGuard<'_> { + self.grow(capacity); + ReservationGuard { + reservation: self, + size: capacity, + } + } +} + +/// RAII guard that automatically shrinks a [`MemoryReservation`] on drop. +/// +/// Created by [`MemoryReservation::try_grow_guard`]. When the guard is +/// dropped, it shrinks the reservation by the guarded size. Call +/// [`Self::release`] to transfer ownership and prevent the automatic shrink. +pub struct ReservationGuard<'a> { + reservation: &'a MemoryReservation, + size: usize, +} + +impl ReservationGuard<'_> { + /// Prevents the automatic shrink on drop, effectively transferring + /// ownership of the reserved bytes to a longer-lived reservation. + pub fn release(mut self) { + self.size = 0; + } + + /// Returns the guarded size in bytes. + pub fn size(&self) -> usize { + self.size + } +} + +impl Drop for ReservationGuard<'_> { + fn drop(&mut self) { + if self.size > 0 { + self.reservation.shrink(self.size); + } + } } impl Drop for MemoryReservation { @@ -670,4 +732,55 @@ mod tests { assert_eq!(r1.size(), 0); assert_eq!(pool.reserved(), 80); } + + #[test] + fn test_try_grow_guard_auto_shrinks() { + let pool = Arc::new(GreedyMemoryPool::new(1000)) as _; + let r1 = MemoryConsumer::new("test").register(&pool); + + { + let _guard = r1.try_grow_guard(100).unwrap(); + assert_eq!(r1.size(), 100); + assert_eq!(pool.reserved(), 100); + } + assert_eq!(r1.size(), 0); + assert_eq!(pool.reserved(), 0); + } + + #[test] + fn test_try_grow_guard_release_prevents_shrink() { + let pool = Arc::new(GreedyMemoryPool::new(1000)) as _; + let r1 = MemoryConsumer::new("test").register(&pool); + + { + let guard = r1.try_grow_guard(100).unwrap(); + guard.release(); + } + assert_eq!(r1.size(), 100); + assert_eq!(pool.reserved(), 100); + } + + #[test] + fn test_grow_guard_auto_shrinks() { + let pool = Arc::new(GreedyMemoryPool::new(1000)) as _; + let r1 = MemoryConsumer::new("test").register(&pool); + + { + let _guard = r1.grow_guard(200); + assert_eq!(r1.size(), 200); + } + assert_eq!(r1.size(), 0); + assert_eq!(pool.reserved(), 0); + } + + #[test] + fn test_try_grow_guard_error_path() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let r1 = MemoryConsumer::new("test").register(&pool); + + let result = r1.try_grow_guard(100); + assert!(result.is_err()); + assert_eq!(r1.size(), 0); + assert_eq!(pool.reserved(), 0); + } } diff --git a/datafusion/physical-plan/benches/spill_io.rs b/datafusion/physical-plan/benches/spill_io.rs index fac2547a131b4..e34c24ee32fdd 100644 --- a/datafusion/physical-plan/benches/spill_io.rs +++ b/datafusion/physical-plan/benches/spill_io.rs @@ -27,6 +27,7 @@ use criterion::{ use datafusion_common::config::SpillCompression; use datafusion_common::human_readable_size; use datafusion_common::instant::Instant; +use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_physical_plan::SpillManager; use datafusion_physical_plan::common::collect; @@ -90,7 +91,14 @@ fn bench_spill_io(c: &mut Criterion) { Field::new("c2", DataType::Date32, true), Field::new("c3", DataType::Decimal128(11, 2), true), ])); - let spill_manager = SpillManager::new(env, metrics, schema); + let spill_manager = SpillManager::new( + Arc::clone(&env), + metrics, + Arc::clone(&schema), + MemoryConsumer::new("bench") + .with_can_spill(true) + .register(&env.memory_pool), + ); let mut group = c.benchmark_group("spill_io"); let rt = Runtime::new().unwrap(); @@ -116,7 +124,7 @@ fn bench_spill_io(c: &mut Criterion) { |spill_file| { rt.block_on(async { let stream = spill_manager - .read_spill_as_stream(spill_file, None) + .read_spill_as_stream(spill_file, None, None) .unwrap(); let _ = collect(stream).await.unwrap(); }) @@ -504,9 +512,15 @@ fn benchmark_spill_batches_for_all_codec( for &compression in compressions { let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = - SpillManager::new(Arc::clone(&env), metrics.clone(), Arc::clone(&schema)) - .with_compression_type(compression); + let spill_manager = SpillManager::new( + Arc::clone(&env), + metrics.clone(), + Arc::clone(&schema), + MemoryConsumer::new("bench") + .with_can_spill(true) + .register(&env.memory_pool), + ) + .with_compression_type(compression); let bench_id = BenchmarkId::new(batch_label, compression.to_string()); group.bench_with_input(bench_id, &spill_manager, |b, spill_manager| { @@ -522,7 +536,7 @@ fn benchmark_spill_batches_for_all_codec( .unwrap() .unwrap(); let stream = spill_manager - .read_spill_as_stream(spill_file, None) + .read_spill_as_stream(spill_file, None, None) .unwrap(); let _ = collect(stream).await.unwrap(); }) @@ -557,7 +571,7 @@ fn benchmark_spill_batches_for_all_codec( let start = Instant::now(); rt.block_on(async { let stream = spill_manager - .read_spill_as_stream(spill_file, None) + .read_spill_as_stream(spill_file, None, None) .unwrap(); let _ = collect(stream).await.unwrap(); }); diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index ee253e5d7afdd..b9fdaf4b330fd 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -111,6 +111,14 @@ pub trait GroupValues: Send { /// Emits the group values fn emit(&mut self, emit_to: EmitTo) -> Result>; + /// Returns an estimate of the memory that will be allocated by [`Self::emit`] + /// for the decode/output buffers. + /// + /// This is used by the aggregation operator to pre-reserve memory before + /// calling `emit()`, ensuring the memory pool is aware of transient + /// decode buffer allocations. + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize; + /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage) fn clear_shrink(&mut self, num_rows: usize); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index ee2d300d9bff8..47fa7efc36673 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -1088,6 +1088,19 @@ impl GroupValues for GroupValuesColumn { self.group_values[0].len() } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let total = self.len(); + let emit_count = match emit_to { + EmitTo::All => total, + EmitTo::First(n) => (*n).min(total), + }; + if total == 0 { + return 0; + } + let group_values_size: usize = self.group_values.iter().map(|v| v.size()).sum(); + group_values_size * emit_count / total + } + fn emit(&mut self, emit_to: EmitTo) -> Result> { let mut output = match emit_to { EmitTo::All => { diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index a3bd31f76c233..5463c88c06580 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -195,6 +195,21 @@ impl GroupValues for GroupValuesRows { .unwrap_or(0) } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let total_rows = self.len(); + if total_rows == 0 { + return 0; + } + let rows_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0); + match emit_to { + EmitTo::All => rows_size, + EmitTo::First(n) => { + let n = (*n).min(total_rows); + rows_size * n / total_rows + } + } + } + fn emit(&mut self, emit_to: EmitTo) -> Result> { let mut group_values = self .group_values diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs index e993c0c53d199..c71e318741ea2 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs @@ -97,6 +97,14 @@ impl GroupValues for GroupValuesBoolean { + self.null_group.is_some() as usize } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let emit_count = match emit_to { + EmitTo::All => self.len(), + EmitTo::First(n) => (*n).min(self.len()), + }; + emit_count.div_ceil(8) * 2 + } + fn emit(&mut self, emit_to: EmitTo) -> Result> { let len = self.len(); let mut builder = BooleanBufferBuilder::new(len); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs index b881a51b25474..1f19e6e4f5b75 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs @@ -84,6 +84,21 @@ impl GroupValues for GroupValuesBytes { self.num_groups } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let total = self.len(); + let emit_count = match emit_to { + EmitTo::All => total, + EmitTo::First(n) => (*n).min(total), + }; + if total == 0 { + return 0; + } + // Offsets + data bytes (proportional) + null bitmap + (emit_count + 1) * std::mem::size_of::() + + self.size() * emit_count / total + + emit_count.div_ceil(8) + } + fn emit(&mut self, emit_to: EmitTo) -> Result> { // Reset the map to default, and convert it into a single array let map_contents = self.map.take().into_state(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs index 7a56f7c52c11a..04a69cf270bb0 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs @@ -86,6 +86,17 @@ impl GroupValues for GroupValuesBytesView { self.num_groups } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let emit_count = match emit_to { + EmitTo::All => self.len(), + EmitTo::First(n) => (*n).min(self.len()), + }; + // Views (16 bytes each) + out-of-line data estimate + null bitmap + emit_count * 16 + + self.size() * emit_count / self.len().max(1) + + emit_count.div_ceil(8) + } + fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { // Reset the map to default, and convert it into a single array let map_contents = self.map.take().into_state(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index 07535cfdaa6de..96e168c23f31e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -165,6 +165,14 @@ where self.values.len() } + fn estimated_emit_size(&self, emit_to: &EmitTo) -> usize { + let emit_count = match emit_to { + EmitTo::All => self.len(), + EmitTo::First(n) => (*n).min(self.len()), + }; + emit_count * std::mem::size_of::() + emit_count.div_ceil(8) + } + fn emit(&mut self, emit_to: EmitTo) -> Result> { fn build_primitive( values: Vec, diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index c3f73976c721a..3f3dc100a727a 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -447,6 +447,12 @@ pub(crate) struct GroupedHashAggregateStream { /// The memory reservation for this grouping reservation: MemoryReservation, + /// Tracks memory for emitted output arrays that are live between + /// emit() and poll_next yielding them downstream. Ensures the pool + /// knows about in-flight batches that have left internal state but + /// haven't been consumed yet. + transient_reservation: MemoryReservation, + /// The behavior to trigger when out of memory occurs oom_mode: OutOfMemoryMode, @@ -601,6 +607,7 @@ impl GroupedHashAggregateStream { // to ensure fair application of back pressure amongst the memory consumers. .with_can_spill(oom_mode != OutOfMemoryMode::ReportError) .register(context.memory_pool()); + let transient_reservation = reservation.new_empty(); timer.done(); let exec_state = ExecutionState::ReadingInput; @@ -609,6 +616,7 @@ impl GroupedHashAggregateStream { context.runtime_env(), metrics::SpillMetrics::new(&agg.metrics, partition), Arc::clone(&spill_schema), + reservation.new_empty(), ) .with_compression_type(context.session_config().spill_compression()); @@ -682,6 +690,7 @@ impl GroupedHashAggregateStream { filter_expressions, group_by: agg_group_by, reservation, + transient_reservation, oom_mode, group_values, current_group_indices: Default::default(), @@ -862,6 +871,8 @@ impl Stream for GroupedHashAggregateStream { let output_batch; let size = self.batch_size; (self.exec_state, output_batch) = if batch.num_rows() <= size { + // Entire batch yielded — release transient reservation + self.transient_reservation.free(); ( if self.input_done { ExecutionState::Done @@ -1116,6 +1127,16 @@ impl GroupedHashAggregateStream { return Ok(None); } + // RAII guard pre-reserves memory for decode buffers allocated + // inside group_values.emit(). Guard automatically releases on + // drop (including error paths via ?). + let emit_estimate = self.group_values.estimated_emit_size(&emit_to); + let emit_guard = if emit_estimate > 0 { + Some(self.reservation.grow_guard(emit_estimate)) + } else { + None + }; + let timer = self.group_by_metrics.emitting_time.timer(); let mut output = self.group_values.emit(emit_to)?; if let EmitTo::First(n) = emit_to { @@ -1127,19 +1148,24 @@ impl GroupedHashAggregateStream { if self.mode.output_mode() == AggregateOutputMode::Final && !spilling { output.push(acc.evaluate(emit_to)?) } else { - // Output partial state: either because we're in a non-final mode, - // or because we're spilling and will merge/re-evaluate later. output.extend(acc.state(emit_to)?) } } drop(timer); - // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is - // over the target memory size after emission, we can emit again rather than returning Err. + // Release the pre-reserved emit estimate (decode buffers consumed) + drop(emit_guard); + + // emit reduces the memory usage. Ignore Err from update_memory_reservation. let _ = self.update_memory_reservation(); let batch = RecordBatch::try_new(schema, output)?; debug_assert!(batch.num_rows() > 0); + // Track emitted batch in transient_reservation until yielded downstream. + // This ensures the pool knows about in-flight output arrays. + let batch_size = batch.get_array_memory_size(); + self.transient_reservation.grow(batch_size); + Ok(Some(batch)) } @@ -1844,4 +1870,131 @@ mod tests { "ratio == threshold should not trigger skip (boundary is exclusive)" ); } + + #[tokio::test] + async fn test_aggregation_reservation_balanced() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("group_col", DataType::Int32, false), + Field::new("value_col", DataType::Int64, false), + ])); + + let num_rows = 500; + let group_ids: Vec = (0..num_rows).map(|i| i % 50).collect(); + let values: Vec = vec![1; num_rows as usize]; + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(group_ids)), + Arc::new(Int64Array::from(values)), + ], + )?; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(8192, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + + let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime)); + + let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())]; + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count_value") + .build()?, + )]; + + let exec = TestMemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + let aggregate_exec = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(group_expr), + aggr_expr, + vec![None], + exec, + Arc::clone(&schema), + )?; + + let stream = aggregate_exec.execute(0, task_ctx)?; + let batches: Vec = crate::common::collect(stream).await?; + assert!(!batches.is_empty(), "Should produce output"); + + drop(batches); + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after aggregation completes" + ); + Ok(()) + } + + #[tokio::test] + async fn test_aggregation_reservation_balanced_group_values_rows() -> Result<()> { + use arrow::datatypes::TimeUnit; + let schema = Arc::new(Schema::new(vec![ + Field::new("dur_key", DataType::Duration(TimeUnit::Microsecond), false), + Field::new("str_key", DataType::Utf8, false), + Field::new("value_col", DataType::Int64, false), + ])); + + let num_rows = 400; + let dur_keys: Vec = (0..num_rows as i64).map(|i| (i % 40) * 1000).collect(); + let str_keys: Vec = (0..num_rows) + .map(|i| format!("group_{:04}", i % 10)) + .collect(); + let values: Vec = vec![1; num_rows]; + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(DurationMicrosecondArray::from(dur_keys)), + Arc::new(StringArray::from(str_keys)), + Arc::new(Int64Array::from(values)), + ], + )?; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(16384, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + + let task_ctx = Arc::new(TaskContext::default().with_runtime(runtime)); + + let group_expr = vec![ + (col("dur_key", &schema)?, "dur_key".to_string()), + (col("str_key", &schema)?, "str_key".to_string()), + ]; + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count_value") + .build()?, + )]; + + let exec = TestMemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None)?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + let aggregate_exec = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(group_expr), + aggr_expr, + vec![None], + exec, + Arc::clone(&schema), + )?; + + let stream = aggregate_exec.execute(0, task_ctx)?; + let batches: Vec = crate::common::collect(stream).await?; + assert!(!batches.is_empty(), "Should produce output"); + + drop(batches); + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after GroupValuesRows aggregation completes" + ); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index a18ec0cbe4504..2d2698ff58843 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -61,7 +61,7 @@ use arrow::record_batch::RecordBatch; use arrow_schema::DataType; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ - JoinSide, NullEquality, Result, ScalarValue, Statistics, arrow_err, + DataFusionError, JoinSide, NullEquality, Result, ScalarValue, Statistics, arrow_err, assert_eq_or_internal_err, internal_datafusion_err, internal_err, project_schema, unwrap_or_internal_err, }; @@ -673,6 +673,11 @@ impl ExecutionPlan for NestedLoopJoinExec { SpillState::Disabled }; + let probe_reservation = + MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]")) + .with_can_spill(true) + .register(context.memory_pool()); + Ok(Box::pin(NestedLoopJoinStream::new( self.schema(), self.filter.clone(), @@ -683,6 +688,7 @@ impl ExecutionPlan for NestedLoopJoinExec { metrics, batch_size, spill_state, + probe_reservation, ))) } @@ -1037,6 +1043,10 @@ pub(crate) struct NestedLoopJoinStream { /// Output buffer holds the join result to output. It will emit eagerly when /// the threshold is reached. output_buffer: Box, + /// Best-effort tracked size of output buffer data via try_grow. + /// Only the successfully reserved portion is tracked; if try_grow fails + /// (pool full), the batch is still pushed but not counted. + output_buffer_reserved: usize, /// See comments in [`NLJState::Done`] for its purpose handled_empty_output: bool, @@ -1064,6 +1074,10 @@ pub(crate) struct NestedLoopJoinStream { /// Memory-limited spill fallback state. See [`SpillState`] for details. spill_state: SpillState, + + /// Memory reservation for transient probe-side allocations + /// (Cartesian product indices, take() intermediates, output buffering). + probe_reservation: MemoryReservation, } pub(crate) struct NestedLoopJoinMetrics { @@ -1315,6 +1329,7 @@ impl NestedLoopJoinStream { metrics: NestedLoopJoinMetrics, batch_size: usize, spill_state: SpillState, + probe_reservation: MemoryReservation, ) -> Self { Self { output_schema: Arc::clone(&schema), @@ -1326,6 +1341,7 @@ impl NestedLoopJoinStream { metrics, buffered_left_data: None, output_buffer: Box::new(BatchCoalescer::new(schema, batch_size)), + output_buffer_reserved: 0, batch_size, current_right_batch: None, current_right_batch_matched: None, @@ -1337,6 +1353,7 @@ impl NestedLoopJoinStream { handled_empty_output: false, should_track_unmatched_right: need_produce_right_in_final(join_type), spill_state, + probe_reservation, } } @@ -1346,12 +1363,9 @@ impl NestedLoopJoinStream { } /// Check if we can fall back to memory-limited mode on this error. - fn can_fallback_to_spill(&self, error: &datafusion_common::DataFusionError) -> bool { + fn can_fallback_to_spill(&self, error: &DataFusionError) -> bool { matches!(self.spill_state, SpillState::Pending { .. }) - && matches!( - error.find_root(), - datafusion_common::DataFusionError::ResourcesExhausted(_) - ) + && matches!(error.find_root(), DataFusionError::ResourcesExhausted(_)) } /// Switch from the standard OnceFut path to memory-limited mode. @@ -1389,6 +1403,9 @@ impl NestedLoopJoinStream { ctx.runtime_env(), spill_metrics, Arc::clone(&schema), + MemoryConsumer::new("NestedLoopJoinLeftSpill") + .with_can_spill(true) + .register(ctx.memory_pool()), ) .with_compression_type(ctx.session_config().spill_compression()); @@ -1438,6 +1455,9 @@ impl NestedLoopJoinStream { context.runtime_env(), self.metrics.spill_metrics.clone(), right_schema, + MemoryConsumer::new("NestedLoopJoinRightSpill") + .with_can_spill(true) + .register(context.memory_pool()), ) .with_compression_type(context.session_config().spill_compression()); @@ -1528,10 +1548,11 @@ impl NestedLoopJoinStream { if active.left_stream.is_none() { match active.left_spill_fut.get_shared(cx) { Poll::Ready(Ok(spill_data)) => { - match spill_data - .spill_manager - .read_spill_as_stream(spill_data.spill_file.clone(), None) - { + match spill_data.spill_manager.read_spill_as_stream( + spill_data.spill_file.clone(), + None, + None, + ) { Ok(stream) => { active.left_schema = Some(Arc::clone(&spill_data.schema)); active.left_stream = Some(stream); @@ -1570,8 +1591,10 @@ impl NestedLoopJoinStream { if !can_grow && !active.pending_batches.is_empty() { // Memory limit reached and we already have data. - // Push this batch into pending (it's already in memory) - // and stop buffering for this chunk. + // This batch is already in memory and can't be returned + // to the source. It's the documented make-progress + // exception — one over-budget batch is accepted without + // reservation to avoid blocking downstream operators. active.pending_batches.push(batch); self.left_exhausted = false; self.left_buffered_in_one_pass = false; @@ -1614,6 +1637,18 @@ impl NestedLoopJoinStream { return ControlFlow::Continue(()); } + // RAII guard for concat dual-live overhead: both pending batches + // and concat result coexist during concat_batches. + // Named exception (infallible grow): concat is required to create + // JoinLeftData for probe processing, and pending batches are already + // in memory — the total footprint doesn't actually increase. + let pending_size: usize = active + .pending_batches + .iter() + .map(|b| b.get_array_memory_size()) + .sum(); + let _concat_guard = active.reservation.grow_guard(pending_size); + let merged_batch = match concat_batches( active .left_schema @@ -1808,13 +1843,13 @@ impl NestedLoopJoinStream { "This state is yielding output for unmatched rows in the current right batch, so both the right batch and the bitmap must be present" ); match self.process_right_unmatched() { - Ok(Some(batch)) => match self.output_buffer.push_batch(batch) { + Ok(Some(batch)) => match self.push_output_batch(batch) { Ok(()) => { debug_assert!(self.current_right_batch.is_none()); self.state = NLJState::FetchingRight; ControlFlow::Continue(()) } - Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))), + Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), }, Ok(None) => { debug_assert!(self.current_right_batch.is_none()); @@ -1960,9 +1995,9 @@ impl NestedLoopJoinStream { self.join_type, JoinSide::Right, ) { - Ok(Some(batch)) => match self.output_buffer.push_batch(batch) { + Ok(Some(batch)) => match self.push_output_batch(batch) { Ok(()) => ControlFlow::Continue(()), - Err(e) => ControlFlow::Break(Poll::Ready(Some(arrow_err!(e)))), + Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), }, Ok(None) => ControlFlow::Continue(()), Err(e) => ControlFlow::Break(Poll::Ready(Some(Err(e)))), @@ -2066,7 +2101,7 @@ impl NestedLoopJoinStream { )?; if let Some(batch) = joined_batch { - self.output_buffer.push_batch(batch)?; + self.push_output_batch(batch)?; } self.left_probe_idx += l_row_count; @@ -2079,7 +2114,7 @@ impl NestedLoopJoinStream { self.process_single_left_row_join(&left_data, &right_batch, l_idx)?; if let Some(batch) = joined_batch { - self.output_buffer.push_batch(batch)?; + self.push_output_batch(batch)?; } // ==== Prepare for the next iteration ==== @@ -2111,6 +2146,14 @@ impl NestedLoopJoinStream { let right_rows = right_batch.num_rows(); let total_rows = l_row_count * right_rows; + // Named exception (infallible grow): Cartesian product index arrays + // are required for join correctness. The probe phase cannot be deferred + // or spilled once started — the left chunk is already loaded. These + // arrays are function-scoped and freed when _indices_guard drops. + // Bounded by: 2 * l_row_count * right_rows * sizeof(u32). + let indices_size = 2 * total_rows * size_of::(); + let _indices_guard = self.probe_reservation.grow_guard(indices_size); + // Build index arrays for cartesian product: left_range X right_batch let left_indices: UInt32Array = UInt32Array::from_iter_values((0..l_row_count).flat_map(|i| { @@ -2129,9 +2172,31 @@ impl NestedLoopJoinStream { // Evaluate the join filter (if any) over an intermediate batch built // using the filter's own schema/column indices. let bitmap_combined = if let Some(filter) = &self.join_filter { - // Build the intermediate batch for filter evaluation + // Pre-estimate filter take sizes before allocation. + // Guard declared here (outside the if/else) so it covers the + // intermediate_batch through filter evaluation at line ~2213. + let filter_estimate = if !filter.schema.fields().is_empty() { + let mut est = 0usize; + for ci in filter.column_indices() { + let col = if ci.side == JoinSide::Left { + left_data.batch().column(ci.index) + } else { + right_batch.column(ci.index) + }; + let col_len = col.len().max(1); + est += total_rows * col.get_array_memory_size() / col_len; + } + est + } else { + 0 + }; + // Named exception (infallible grow): filter intermediate batch is + // built from take() on source columns for filter evaluation. The + // probe phase cannot be deferred. Function-scoped, freed on drop. + // Bounded by: sum(total_rows * col_size / col_len) per filter column. + let _filter_guard = self.probe_reservation.grow_guard(filter_estimate); + let intermediate_batch = if filter.schema.fields().is_empty() { - // Constant predicate (e.g., TRUE/FALSE). Use an empty schema with row_count create_record_batch_with_empty_schema( Arc::new((*filter.schema).clone()), total_rows, @@ -2149,7 +2214,6 @@ impl NestedLoopJoinStream { }; filter_columns.push(array); } - RecordBatch::try_new(Arc::new((*filter.schema).clone()), filter_columns)? }; @@ -2159,7 +2223,6 @@ impl NestedLoopJoinStream { .into_array(intermediate_batch.num_rows())?; let filter_arr = as_boolean_array(&filter_result)?; - // Combine with null bitmap to get a unified mask boolean_mask_from_filter(filter_arr) } else { // No filter: all pairs match @@ -2258,6 +2321,24 @@ impl NestedLoopJoinStream { )?)); } + // Pre-estimate output take sizes before allocation. + let mut output_estimate = 0usize; + for column_index in &self.column_indices { + let col = if column_index.side == JoinSide::Left { + left_data.batch().column(column_index.index) + } else { + right_batch.column(column_index.index) + }; + let col_len = col.len().max(1); + output_estimate += total_rows * col.get_array_memory_size() / col_len; + } + // Named exception (infallible grow): output batch columns are built + // from take() on source arrays. The probe phase cannot be deferred. + // Function-scoped, freed on drop. The output batch is then moved to + // push_output_batch which tracks it separately via best-effort try_grow. + // Bounded by: sum(total_rows * col_size / col_len) per output column. + let _output_guard = self.probe_reservation.grow_guard(output_estimate); + let mut out_columns: Vec> = Vec::with_capacity(self.output_schema.fields().len()); for column_index in &self.column_indices { @@ -2370,7 +2451,7 @@ impl NestedLoopJoinStream { if let Some(batch) = self.process_left_unmatched_range(left_data, start_idx, end_idx)? { - self.output_buffer.push_batch(batch)?; + self.push_output_batch(batch)?; } // ==== Prepare for the next iteration ==== @@ -2476,13 +2557,43 @@ impl NestedLoopJoinStream { .ok_or_else(|| internal_datafusion_err!("LeftData should be available")) } + /// Push a batch into the output buffer with best-effort memory tracking. + /// + /// Named exception (best-effort): the output batch is already allocated + /// by Arrow take() before reaching this point. Rejecting it would discard + /// valid join results without freeing the underlying allocation. Uses + /// try_grow so the pool counter reflects tracked output when capacity + /// allows, but accepts the batch regardless to avoid aborting mid-probe. + fn push_output_batch(&mut self, batch: RecordBatch) -> Result<()> { + let size = batch.get_array_memory_size(); + let reserved = self.probe_reservation.try_grow(size).is_ok(); + if reserved { + self.output_buffer_reserved += size; + } + if let Err(e) = self.output_buffer.push_batch(batch) { + if reserved { + self.probe_reservation.shrink(size); + self.output_buffer_reserved -= size; + } + return Err(DataFusionError::ArrowError(Box::new(e), None)); + } + Ok(()) + } + /// Flush the `output_buffer` if there are batches ready to output /// None if no result batch ready. fn maybe_flush_ready_batch(&mut self) -> Option>>> { if self.output_buffer.has_completed_batch() && let Some(batch) = self.output_buffer.next_completed_batch() { - // Update output rows for selectivity metric + // Release tracked output buffer reservation + let batch_size = batch.get_array_memory_size(); + let shrink = batch_size.min(self.output_buffer_reserved); + if shrink > 0 { + self.probe_reservation.shrink(shrink); + self.output_buffer_reserved -= shrink; + } + let output_rows = batch.num_rows(); self.metrics.selectivity.add_part(output_rows); @@ -3931,4 +4042,158 @@ pub(crate) mod tests { ")); Ok(()) } + + /// Verify that NLJ probe-side reservation returns to zero after + /// join completion for both spilling (inner) and non-spilling paths. + #[tokio::test] + async fn test_nlj_reservation_balanced_after_inner_join() -> Result<()> { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(50, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + let cfg = TaskContext::default() + .session_config() + .clone() + .with_batch_size(16); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(cfg) + .with_runtime(runtime), + ); + + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (_columns, batches, metrics) = + join_collect(left, right, &JoinType::Inner, Some(filter), task_ctx).await?; + + assert!(!batches.is_empty(), "Should produce output"); + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to verify spill-path accounting" + ); + + // After join completes and all streams/operators are dropped, + // pool should have zero reserved (all grow/shrink balanced). + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after NLJ inner join completes" + ); + Ok(()) + } + + /// Verify reservation balance for RIGHT JOIN which exercises the + /// global right bitmap accumulation and unmatched-row emission paths. + #[tokio::test] + async fn test_nlj_reservation_balanced_after_right_join() -> Result<()> { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(50, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + let cfg = TaskContext::default() + .session_config() + .clone() + .with_batch_size(16); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(cfg) + .with_runtime(runtime), + ); + + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (_columns, batches, metrics) = + join_collect(left, right, &JoinType::Right, Some(filter), task_ctx).await?; + + assert!(!batches.is_empty(), "Should produce output"); + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to verify spill-path accounting" + ); + + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after NLJ right join completes" + ); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_reservation_balanced_after_left_join() -> Result<()> { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(50, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + let cfg = TaskContext::default() + .session_config() + .clone() + .with_batch_size(16); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(cfg) + .with_runtime(runtime), + ); + + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (_columns, batches, metrics) = + join_collect(left, right, &JoinType::Left, Some(filter), task_ctx).await?; + + assert!(!batches.is_empty(), "Should produce output"); + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to verify spill-path accounting" + ); + + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after NLJ left join completes" + ); + Ok(()) + } + + #[tokio::test] + async fn test_nlj_reservation_balanced_after_full_join() -> Result<()> { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(50, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + let cfg = TaskContext::default() + .session_config() + .clone() + .with_batch_size(16); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(cfg) + .with_runtime(runtime), + ); + + let left = build_left_table(); + let right = build_right_table(); + let filter = prepare_join_filter(); + + let (_columns, batches, metrics) = + join_collect(left, right, &JoinType::Full, Some(filter), task_ctx).await?; + + assert!(!batches.is_empty(), "Should produce output"); + assert!( + metrics.spill_count().unwrap_or(0) > 0, + "Expected spilling to verify spill-path accounting" + ); + + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after NLJ full join completes" + ); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs index a86cb647e4bff..c078db7689c96 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs @@ -510,6 +510,7 @@ impl ExecutionPlan for SortMergeJoinExec { context.runtime_env(), SpillMetrics::new(&self.metrics, partition), buffered.schema(), + reservation.new_empty(), ) .with_compression_type(context.session_config().spill_compression()); diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index c4377b3189ff7..de3664b19b423 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -4021,7 +4021,7 @@ fn test_stream_resources( let ctx = TaskContext::default(); let runtime_env = ctx.runtime_env(); let reservation = MemoryConsumer::new("test").register(ctx.memory_pool()); - let spill_manager = SpillManager::new( + let spill_manager = SpillManager::new_default( Arc::clone(&runtime_env), SpillMetrics::new(metrics, 0), inner_schema, @@ -4676,7 +4676,7 @@ async fn spill_filtered_boundary_loses_outer_rows() -> Result<()> { let metrics = ExecutionPlanMetricsSet::new(); let reservation = MemoryConsumer::new("test").register(&runtime.memory_pool); - let spill_manager = SpillManager::new( + let spill_manager = SpillManager::new_default( Arc::clone(&runtime), SpillMetrics::new(&metrics, 0), Arc::clone(&right_schema), diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 3d30dd82762b1..fb167e9ce43df 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -1239,6 +1239,9 @@ impl ExecutionPlan for RepartitionExec { Arc::clone(&context.runtime_env()), spill_metrics, input.schema(), + MemoryConsumer::new("RepartitionSpill") + .with_can_spill(true) + .register(context.memory_pool()), ); // Get existing ordering to use for merging diff --git a/datafusion/physical-plan/src/sorts/multi_level_merge.rs b/datafusion/physical-plan/src/sorts/multi_level_merge.rs index 8985e1d8c70ee..acd55e6cbb8e7 100644 --- a/datafusion/physical-plan/src/sorts/multi_level_merge.rs +++ b/datafusion/physical-plan/src/sorts/multi_level_merge.rs @@ -236,9 +236,12 @@ impl MultiLevelMergeBuilder { (1, 0) => { let spill_file = self.sorted_spill_files.remove(0); - // Not reserving any memory for this disk as we are not holding it in memory - self.spill_manager - .read_spill_as_stream(spill_file.file, None) + let read_reservation = self.reservation.take(); + self.spill_manager.read_spill_as_stream( + spill_file.file, + Some(spill_file.max_record_batch_memory), + Some(read_reservation), + ) } // Only in memory streams, so merge them all in a single pass @@ -292,6 +295,7 @@ impl MultiLevelMergeBuilder { .read_spill_as_stream( spill.file, Some(spill.max_record_batch_memory), + None, )?; sorted_streams.push(stream); } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index 929ff4f7dfc85..11a9222c32e29 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -292,6 +292,7 @@ impl ExternalSorter { Arc::clone(&runtime), metrics.spill_metrics.clone(), Arc::clone(&schema), + reservation.new_empty(), ) .with_compression_type(spill_compression); @@ -2980,4 +2981,59 @@ mod tests { assert_eq!(desc.self_filters()[0].len(), 1); Ok(()) } + + #[tokio::test] + async fn test_sort_spill_reservation_balanced() -> Result<()> { + let session_config = SessionConfig::new(); + let sort_spill_reservation_bytes = session_config + .options() + .execution + .sort_spill_reservation_bytes; + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(sort_spill_reservation_bytes + 12288, 1.0) + .build_arc()?; + let pool = Arc::clone(&runtime.memory_pool); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); + + let partitions = 100; + let input = test::scan_partitioned(partitions); + let schema = input.schema(); + + let sort_exec = Arc::new(SortExec::new( + [PhysicalSortExpr { + expr: col("i", &schema)?, + options: SortOptions::default(), + }] + .into(), + Arc::new(CoalescePartitionsExec::new(input)), + )); + + let result = collect( + Arc::clone(&sort_exec) as Arc, + Arc::clone(&task_ctx), + ) + .await?; + assert!(!result.is_empty(), "Should produce output"); + + let metrics = sort_exec.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to verify spill-path accounting" + ); + + drop(result); + drop(sort_exec); + drop(task_ctx); + + assert_eq!( + pool.reserved(), + 0, + "Pool reservation should be zero after sort with spilling completes" + ); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs index e0548bd5bf860..77745e5458193 100644 --- a/datafusion/physical-plan/src/spill/in_progress_spill_file.rs +++ b/datafusion/physical-plan/src/spill/in_progress_spill_file.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use arrow::array::RecordBatch; use datafusion_common::exec_datafusion_err; use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::MemoryReservation; use super::{ IPCStreamWriter, gc_view_arrays, @@ -38,20 +39,31 @@ pub struct InProgressSpillFile { writer: Option, /// Lazily initialized in-progress file, it will be moved out when the `finish` method is invoked in_progress_file: Option, + /// Memory reservation for tracking IPC write buffer overhead. + /// `append_batch` reserves memory before writing and releases it + /// after the write completes. Freed automatically on Drop. + reservation: MemoryReservation, } impl InProgressSpillFile { pub fn new( spill_writer: Arc, in_progress_file: RefCountedTempFile, + reservation: MemoryReservation, ) -> Self { Self { spill_writer, in_progress_file: Some(in_progress_file), writer: None, + reservation, } } + #[cfg(test)] + pub(crate) fn reservation_size(&self) -> usize { + self.reservation.size() + } + /// Appends a `RecordBatch` to the spill file, initializing the writer if necessary. /// /// Before writing, performs GC on StringView/BinaryView arrays to compact backing @@ -71,13 +83,24 @@ impl InProgressSpillFile { )); } + // Named exception (infallible grow): spill writes MUST complete + // even under memory pressure. Operators free their main reservation + // before spilling; failing here would prevent memory recovery. + // The grow/shrink is balanced within each call — the reservation + // returns to its pre-call size after the write completes or errors. + let write_overhead = batch.get_array_memory_size(); + self.reservation.grow(write_overhead); + + let result = self.append_batch_inner(batch); + + self.reservation.shrink(write_overhead); + result + } + + fn append_batch_inner(&mut self, batch: &RecordBatch) -> Result { let gc_batch = gc_view_arrays(batch)?; if self.writer.is_none() { - // Use the SpillManager's declared schema rather than the batch's schema. - // Individual batches may have different schemas (e.g., different nullability) - // when they come from different branches of a UnionExec. The SpillManager's - // schema represents the canonical schema that all batches should conform to. let schema = self.spill_writer.schema(); if let Some(in_progress_file) = &mut self.in_progress_file { self.writer = Some(IPCStreamWriter::new( @@ -86,10 +109,8 @@ impl InProgressSpillFile { self.spill_writer.compression, )?); - // Update metrics self.spill_writer.metrics.spill_file_count.add(1); - // Update initial size (schema/header) in_progress_file.update_disk_usage()?; let initial_size = in_progress_file.current_disk_usage(); self.spill_writer @@ -111,9 +132,10 @@ impl InProgressSpillFile { .spilled_bytes .add((post_size - pre_size) as usize); } else { - unreachable!() // Already checked inside current function + unreachable!() } } + gc_batch.get_sliced_size() } @@ -180,7 +202,7 @@ mod tests { let runtime = Arc::new(RuntimeEnvBuilder::new().build()?); let metrics_set = ExecutionPlanMetricsSet::new(); let spill_metrics = SpillMetrics::new(&metrics_set, 0); - let spill_manager = Arc::new(SpillManager::new( + let spill_manager = Arc::new(SpillManager::new_default( runtime, spill_metrics, Arc::clone(&nullable_schema), @@ -210,7 +232,7 @@ mod tests { let spill_file = in_progress.finish()?.unwrap(); - let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; // Stream schema should be nullable assert_eq!(stream.schema(), nullable_schema); diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index 3c95a1da5b33c..9e7184830ac47 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -54,8 +54,10 @@ use datafusion_common::config::SpillCompression; use datafusion_common::{DataFusionError, Result, exec_datafusion_err, exec_err}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::RecordBatchStream; +use datafusion_execution::SendableRecordBatchStream; use datafusion_execution::disk_manager::RefCountedTempFile; -use futures::{FutureExt as _, Stream}; +use datafusion_execution::memory_pool::MemoryReservation; +use futures::{FutureExt as _, Stream, StreamExt as _}; use log::debug; /// Stream that reads spill files from disk where each batch is read in a spawned blocking task @@ -69,10 +71,13 @@ struct SpillReaderStream { schema: SchemaRef, state: SpillReaderStreamState, /// Maximum memory size observed among spilling sorted record batches. - /// This is used for validation purposes during reading each RecordBatch from spill. - /// For context on why this value is recorded and validated, - /// see `physical_plan/sort/multi_level_merge.rs`. max_record_batch_memory: Option, + /// Optional reservation for tracking decoded batch memory. + /// When provided, grows when a batch is decoded and shrinks + /// when the next batch is read (previous batch consumed by caller). + reservation: Option, + /// Size of the last decoded batch, used for shrinking on next poll. + last_batch_size: usize, } // Small margin allowed to accommodate slight memory accounting variation @@ -102,20 +107,72 @@ impl SpillReaderStream { schema: SchemaRef, spill_file: RefCountedTempFile, max_record_batch_memory: Option, + reservation: Option, ) -> Self { Self { schema, state: SpillReaderStreamState::Uninitialized(spill_file), max_record_batch_memory, + reservation, + last_batch_size: 0, } } + /// Pre-reserve one decoded-batch slot before launching a blocking read. + /// Shrinks the previous batch's reservation first, then grows for the + /// next expected batch. Returns Err(ResourcesExhausted) if the pool + /// cannot accommodate the next batch. + fn pre_reserve_next_batch(&mut self) -> Result<()> { + if let (Some(res), Some(max_mem)) = + (&self.reservation, self.max_record_batch_memory) + { + if self.last_batch_size > 0 { + res.shrink(self.last_batch_size); + self.last_batch_size = 0; + } + res.try_grow(max_mem)?; + self.last_batch_size = max_mem; + } + Ok(()) + } + + /// After a decoded batch arrives, adjust the reservation from the + /// pre-reserved estimate to the actual batch size. + fn adjust_reservation_to_actual(&mut self, batch: &RecordBatch) { + let Some(res) = &self.reservation else { + return; + }; + let actual_size = get_record_batch_memory_size(batch); + + if self.max_record_batch_memory.is_some() { + // Pre-reserved path: adjust from estimate to actual + if actual_size > self.last_batch_size { + res.grow(actual_size - self.last_batch_size); + } else if actual_size < self.last_batch_size { + res.shrink(self.last_batch_size - actual_size); + } + } else { + // Fallback: no max_record_batch_memory, post-decode accounting + if self.last_batch_size > 0 { + res.shrink(self.last_batch_size); + } + res.grow(actual_size); + } + self.last_batch_size = actual_size; + } + fn poll_next_inner( &mut self, cx: &mut Context<'_>, ) -> Poll>> { match &mut self.state { SpillReaderStreamState::Uninitialized(_) => { + // Pre-reserve before the first blocking read + if let Err(e) = self.pre_reserve_next_batch() { + self.state = SpillReaderStreamState::Done; + return Poll::Ready(Some(Err(e))); + } + // Temporarily replace with `Done` to be able to pass the file to the task. let SpillReaderStreamState::Uninitialized(spill_file) = std::mem::replace(&mut self.state, SpillReaderStreamState::Done) @@ -126,15 +183,10 @@ impl SpillReaderStream { let expected_schema = Arc::clone(&self.schema); let task = SpawnedTask::spawn_blocking(move || { let file = BufReader::new(File::open(spill_file.path())?); - // SAFETY: DataFusion's spill writer strictly follows Arrow IPC specifications - // with validated schemas and buffers. Skip redundant validation during read - // to speedup read operation. This is safe for DataFusion as input guaranteed to be correct when written. let mut reader = unsafe { StreamReader::try_new(file, None)?.with_skip_validation(true) }; - // Validate the schema read from Arrow IPC file is the same as the - // schema of the current `SpillManager` let actual_schema = reader.schema(); if actual_schema != expected_schema { @@ -146,8 +198,6 @@ impl SpillReaderStream { ); } - // TODO: Same-schema reads from a different SpillManager still pass today. - // Add a SpillManager UID to IPC metadata and validate it here as well. let next_batch = reader.next().transpose()?; Ok((reader, next_batch)) @@ -155,8 +205,6 @@ impl SpillReaderStream { self.state = SpillReaderStreamState::ReadInProgress(task); - // Poll again immediately so the inner task is polled and the waker is - // registered. self.poll_next_inner(cx) } @@ -185,12 +233,21 @@ impl SpillReaderStream { ); } } + + self.adjust_reservation_to_actual(&batch); + self.state = SpillReaderStreamState::Waiting(reader); Poll::Ready(Some(Ok(batch))) } None => { - // Stream is done + // Stream done — release any remaining reservation + if let Some(res) = &self.reservation + && self.last_batch_size > 0 + { + res.shrink(self.last_batch_size); + self.last_batch_size = 0; + } self.state = SpillReaderStreamState::Done; Poll::Ready(None) @@ -198,6 +255,13 @@ impl SpillReaderStream { } } Err(err) => { + // Release pre-reservation on error + if let Some(res) = &self.reservation + && self.last_batch_size > 0 + { + res.shrink(self.last_batch_size); + self.last_batch_size = 0; + } self.state = SpillReaderStreamState::Done; Poll::Ready(Some(Err(err))) @@ -206,6 +270,12 @@ impl SpillReaderStream { } SpillReaderStreamState::Waiting(_) => { + // Pre-reserve before the next blocking read + if let Err(e) = self.pre_reserve_next_batch() { + self.state = SpillReaderStreamState::Done; + return Poll::Ready(Some(Err(e))); + } + // Temporarily replace with `Done` to be able to pass the file to the task. let SpillReaderStreamState::Waiting(mut reader) = std::mem::replace(&mut self.state, SpillReaderStreamState::Done) @@ -221,8 +291,6 @@ impl SpillReaderStream { self.state = SpillReaderStreamState::ReadInProgress(task); - // Poll again immediately so the inner task is polled and the waker is - // registered. self.poll_next_inner(cx) } @@ -245,6 +313,58 @@ impl RecordBatchStream for SpillReaderStream { } } +/// Wraps a buffered read stream with a capacity-level `MemoryReservation`. +/// +/// For buffered reads (via `spawn_buffered`), multiple decoded batches can +/// be simultaneously live in the channel. Per-batch tracking in the inner +/// stream would under-count because shrinking happens before the consumer +/// actually drops the previous batch. Instead, the caller pre-reserves +/// `max_batch_memory * buffer_capacity` and this wrapper holds that +/// reservation alive until the stream is fully consumed or dropped. +pub(crate) struct ReadStreamWithReservation { + stream: SendableRecordBatchStream, + reservation: MemoryReservation, +} + +impl ReadStreamWithReservation { + pub(crate) fn new( + stream: SendableRecordBatchStream, + reservation: MemoryReservation, + ) -> Self { + Self { + stream, + reservation, + } + } +} + +impl Stream for ReadStreamWithReservation { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match self.stream.poll_next_unpin(cx) { + Poll::Ready(Some(Err(err))) => { + self.reservation.free(); + Poll::Ready(Some(Err(err))) + } + Poll::Ready(None) => { + self.reservation.free(); + Poll::Ready(None) + } + other => other, + } + } +} + +impl RecordBatchStream for ReadStreamWithReservation { + fn schema(&self) -> SchemaRef { + self.stream.schema() + } +} + /// Spill the `RecordBatch` to disk as smaller batches /// split by `batch_size_rows` #[deprecated( @@ -560,7 +680,6 @@ mod tests { use arrow::compute::cast; use arrow::datatypes::{DataType, Field}; use datafusion_execution::runtime_env::RuntimeEnv; - use futures::StreamExt as _; #[tokio::test] async fn test_batch_spill_and_read() -> Result<()> { @@ -582,7 +701,7 @@ mod tests { // Construct SpillManager let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let spill_manager = SpillManager::new_default(env, metrics, Arc::clone(&schema)); let spill_file = spill_manager .spill_record_batch_and_finish(&[batch1, batch2], "Test")? @@ -591,7 +710,7 @@ mod tests { let spilled_rows = spill_manager.metrics.spilled_rows.value(); assert_eq!(spilled_rows, num_rows); - let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; assert_eq!(stream.schema(), schema); let batches = collect(stream).await?; @@ -646,7 +765,8 @@ mod tests { // Construct SpillManager let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, Arc::clone(&dict_schema)); + let spill_manager = + SpillManager::new_default(env, metrics, Arc::clone(&dict_schema)); let num_rows = batch1.num_rows() + batch2.num_rows(); let spill_file = spill_manager @@ -655,7 +775,7 @@ mod tests { let spilled_rows = spill_manager.metrics.spilled_rows.value(); assert_eq!(spilled_rows, num_rows); - let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; assert_eq!(stream.schema(), dict_schema); let batches = collect(stream).await?; assert_eq!(batches.len(), 2); @@ -674,7 +794,7 @@ mod tests { let schema = batch1.schema(); let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let spill_manager = SpillManager::new_default(env, metrics, Arc::clone(&schema)); let row_batches: Vec = (0..batch1.num_rows()).map(|i| batch1.slice(i, 1)).collect(); @@ -687,7 +807,7 @@ mod tests { assert!(spill_file.path().exists()); assert!(max_batch_mem > 0); - let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; assert_eq!(stream.schema(), schema); let batches = collect(stream).await?; @@ -722,7 +842,7 @@ mod tests { let spilled_rows = spill_manager.metrics.spilled_rows.value(); assert_eq!(spilled_rows, num_rows); - let stream = spill_manager.read_spill_as_stream(spill_file, None)?; + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; assert_eq!(stream.schema(), schema); let batches = collect(stream).await?; @@ -744,16 +864,16 @@ mod tests { let uncompressed_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let lz4_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let zstd_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let uncompressed_spill_manager = SpillManager::new( + let uncompressed_spill_manager = SpillManager::new_default( Arc::clone(&env), uncompressed_metrics, Arc::clone(&schema), ); let lz4_spill_manager = - SpillManager::new(Arc::clone(&env), lz4_metrics, Arc::clone(&schema)) + SpillManager::new_default(Arc::clone(&env), lz4_metrics, Arc::clone(&schema)) .with_compression_type(SpillCompression::Lz4Frame); let zstd_spill_manager = - SpillManager::new(env, zstd_metrics, Arc::clone(&schema)) + SpillManager::new_default(env, zstd_metrics, Arc::clone(&schema)) .with_compression_type(SpillCompression::Zstd); let uncompressed_spill_file = uncompressed_spill_manager .spill_record_batch_and_finish(&batches, "Test")? @@ -814,7 +934,7 @@ mod tests { Field::new("b", DataType::Utf8, false), ])); - let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let spill_manager = SpillManager::new_default(env, metrics, Arc::clone(&schema)); let batch = RecordBatch::try_new( schema, @@ -872,7 +992,7 @@ mod tests { ])); let spill_manager = - Arc::new(SpillManager::new(env, metrics, Arc::clone(&schema))); + Arc::new(SpillManager::new_default(env, metrics, Arc::clone(&schema))); let mut in_progress_file = spill_manager.create_in_progress_file("Test")?; let batch1 = RecordBatch::try_new( @@ -920,7 +1040,7 @@ mod tests { ])); let spill_manager = - Arc::new(SpillManager::new(env, metrics, Arc::clone(&schema))); + Arc::new(SpillManager::new_default(env, metrics, Arc::clone(&schema))); // Test write empty batch with interface `InProgressSpillFile` and `append_batch()` let mut in_progress_file = spill_manager.create_in_progress_file("Test")?; @@ -968,7 +1088,8 @@ mod tests { // Construct SpillManager let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + let spill_manager = + SpillManager::new_default(env, metrics, Arc::clone(&schema)); let batches: [_; 10] = std::array::from_fn(|_| batch.clone()); let spill_file_1 = spill_manager @@ -979,9 +1100,9 @@ mod tests { .unwrap(); let mut stream_1 = - spill_manager.read_spill_as_stream(spill_file_1, None)?; + spill_manager.read_spill_as_stream(spill_file_1, None, None)?; let mut stream_2 = - spill_manager.read_spill_as_stream(spill_file_2, None)?; + spill_manager.read_spill_as_stream(spill_file_2, None, None)?; stream_1.next().await; stream_2.next().await; @@ -1012,7 +1133,7 @@ mod tests { Field::new("b", DataType::Utf8, false), ])); - let spill_manager = Arc::new(SpillManager::new( + let spill_manager = Arc::new(SpillManager::new_default( Arc::clone(&env), metrics.clone(), Arc::clone(&schema), @@ -1305,7 +1426,7 @@ mod tests { let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, schema); + let spill_manager = SpillManager::new_default(env, metrics, schema); let mut in_progress_file = spill_manager.create_in_progress_file("Test GC")?; @@ -1478,7 +1599,7 @@ mod tests { // 3. Spill to disk using SpillManager let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, schema); + let spill_manager = SpillManager::new_default(env, metrics, schema); let spill_file = spill_manager .spill_record_batch_and_finish(&[sliced_batch], "TestGC")? .unwrap(); @@ -1522,7 +1643,7 @@ mod tests { // 3. Spill to disk using SpillManager let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - let spill_manager = SpillManager::new(env, metrics, schema); + let spill_manager = SpillManager::new_default(env, metrics, schema); let spill_file = spill_manager .spill_record_batch_and_finish(&[sliced_batch], "TestGCBinary")? .unwrap(); diff --git a/datafusion/physical-plan/src/spill/replayable_spill_input.rs b/datafusion/physical-plan/src/spill/replayable_spill_input.rs index fea998d268c59..c4ab15eae6893 100644 --- a/datafusion/physical-plan/src/spill/replayable_spill_input.rs +++ b/datafusion/physical-plan/src/spill/replayable_spill_input.rs @@ -225,7 +225,7 @@ impl ReplayableSpillStream { spill_file: Option, ) -> Result { let inner = if let Some(file) = spill_file.as_ref() { - spill_manager.read_spill_as_stream(file.clone(), None)? + spill_manager.read_spill_as_stream(file.clone(), None, None)? } else { Box::pin(EmptyRecordBatchStream::new(Arc::clone(&schema))) }; @@ -344,7 +344,7 @@ mod tests { let runtime = Arc::new(RuntimeEnvBuilder::new().build()?); let metrics_set = ExecutionPlanMetricsSet::new(); let spill_metrics = SpillMetrics::new(&metrics_set, 0); - Ok(SpillManager::new(runtime, spill_metrics, schema)) + Ok(SpillManager::new_default(runtime, spill_metrics, schema)) } fn build_batch(schema: SchemaRef, values: Vec) -> Result { diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index 365a9f977eace..9a2f8707b399a 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -17,7 +17,10 @@ //! Define the `SpillManager` struct, which is responsible for reading and writing `RecordBatch`es to raw files based on the provided configurations. -use super::{SpillReaderStream, in_progress_spill_file::InProgressSpillFile}; +use super::{ + ReadStreamWithReservation, SpillReaderStream, + in_progress_spill_file::InProgressSpillFile, +}; use crate::coop::cooperative; use crate::{common::spawn_buffered, metrics::SpillMetrics}; use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray}; @@ -26,6 +29,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result, config::SpillCompression}; use datafusion_execution::SendableRecordBatchStream; use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::runtime_env::RuntimeEnv; use std::borrow::Borrow; use std::sync::Arc; @@ -36,7 +40,7 @@ use std::sync::Arc; /// /// Note: The caller (external operators such as `SortExec`) is responsible for interpreting the spilled files. /// For example, all records within the same spill file are ordered according to a specific order. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct SpillManager { env: Arc, pub(crate) metrics: SpillMetrics, @@ -45,19 +49,57 @@ pub struct SpillManager { batch_read_buffer_capacity: usize, /// general-purpose compression options pub(crate) compression: SpillCompression, + /// Owned reservation split from the operator's reservation. + /// Per-file write reservations are split from this. + reservation: MemoryReservation, +} + +impl Clone for SpillManager { + fn clone(&self) -> Self { + Self { + env: Arc::clone(&self.env), + metrics: self.metrics.clone(), + schema: Arc::clone(&self.schema), + batch_read_buffer_capacity: self.batch_read_buffer_capacity, + compression: self.compression, + reservation: self.reservation.new_empty(), + } + } } impl SpillManager { - pub fn new(env: Arc, metrics: SpillMetrics, schema: SchemaRef) -> Self { + pub fn new( + env: Arc, + metrics: SpillMetrics, + schema: SchemaRef, + reservation: MemoryReservation, + ) -> Self { Self { env, metrics, schema, batch_read_buffer_capacity: 2, compression: SpillCompression::default(), + reservation, } } + /// Convenience constructor for tests that creates an independent + /// reservation. Production code must use `new()` with an operator-split + /// reservation. + #[cfg(test)] + pub(crate) fn new_default( + env: Arc, + metrics: SpillMetrics, + schema: SchemaRef, + ) -> Self { + use datafusion_execution::memory_pool::MemoryConsumer; + let reservation = MemoryConsumer::new("SpillManager") + .with_can_spill(true) + .register(&env.memory_pool); + Self::new(env, metrics, schema, reservation) + } + pub fn with_batch_read_buffer_capacity( mut self, batch_read_buffer_capacity: usize, @@ -76,15 +118,19 @@ impl SpillManager { &self.schema } - /// Creates a temporary file for in-progress operations, returning an error - /// message if file creation fails. The file can be used to append batches - /// incrementally and then finish the file when done. + /// Creates a temporary file for in-progress operations with automatic + /// IPC write buffer accounting via the memory pool. pub fn create_in_progress_file( &self, request_msg: &str, ) -> Result { let temp_file = self.env.disk_manager.create_tmp_file(request_msg)?; - Ok(InProgressSpillFile::new(Arc::new(self.clone()), temp_file)) + let reservation = self.reservation.new_empty(); + Ok(InProgressSpillFile::new( + Arc::new(self.clone()), + temp_file, + reservation, + )) } /// Spill input `batches` into a single file in a atomic operation. If it is @@ -176,30 +222,68 @@ impl SpillManager { /// /// That path uses the maximum spilled batch size to conservatively estimate /// the merge degree when merging multiple sorted runs. + /// + /// # Arg `reservation` + /// + /// Optional caller-owned reservation for tracking decoded batch memory. + /// When provided along with `max_record_batch_memory`, pre-reserves + /// `max_record_batch_memory * buffer_capacity` to account for all + /// batches that can be simultaneously live in the `spawn_buffered` + /// channel. The reservation is freed when the stream ends or is + /// dropped. + /// + /// Callers that already track decoded batches via their own reservation + /// (e.g., NLJ build-side, sort merge multi-file path) should pass + /// `None` to avoid double-counting. pub fn read_spill_as_stream( &self, spill_file_path: RefCountedTempFile, max_record_batch_memory: Option, + reservation: Option, ) -> Result { + // Reserve capacity BEFORE spawning the producer task. + // spawn_buffered immediately starts a producer that decodes batches, + // so the reservation must be in place before any allocation happens. + if let (Some(res), Some(max_mem)) = (&reservation, max_record_batch_memory) { + let capacity_bytes = max_mem.saturating_mul(self.batch_read_buffer_capacity); + let deficit = capacity_bytes.saturating_sub(res.size()); + if deficit > 0 { + res.try_grow(deficit)?; + } + } + let stream = Box::pin(cooperative(SpillReaderStream::new( Arc::clone(&self.schema), spill_file_path, max_record_batch_memory, + None, // per-batch tracking not used for buffered reads ))); - Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) + let buffered = spawn_buffered(stream, self.batch_read_buffer_capacity); + + match reservation { + Some(res) => Ok(Box::pin(ReadStreamWithReservation::new(buffered, res))), + None => Ok(buffered), + } } /// Same as `read_spill_as_stream`, but without buffering. + /// + /// When `reservation` is provided, per-batch tracking is used: + /// each decoded batch grows the reservation and the previous batch's + /// reservation is shrunk. This is correct for unbuffered reads where + /// only one batch is live at a time. pub fn read_spill_as_stream_unbuffered( &self, spill_file_path: RefCountedTempFile, max_record_batch_memory: Option, + reservation: Option, ) -> Result { Ok(Box::pin(cooperative(SpillReaderStream::new( Arc::clone(&self.schema), spill_file_path, max_record_batch_memory, + reservation, )))) } } @@ -265,7 +349,7 @@ mod tests { schema: Arc, ) -> SpillManager { let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); - SpillManager::new(env, metrics, schema) + SpillManager::new_default(env, metrics, schema) } fn build_writer_batch(schema: Arc) -> Result { @@ -307,7 +391,7 @@ mod tests { // Same-schema reads through a different SpillManager currently pass // because only schema compatibility is validated. This is not a // supported usage pattern. - let stream = reader.read_spill_as_stream(spill_file, None)?; + let stream = reader.read_spill_as_stream(spill_file, None, None)?; assert_eq!(stream.schema(), reader_schema); let batches = collect(stream).await?; @@ -341,7 +425,7 @@ mod tests { )? .unwrap(); - let stream = reader.read_spill_as_stream(spill_file, None)?; + let stream = reader.read_spill_as_stream(spill_file, None, None)?; let err = collect(stream) .await .expect_err("schema mismatch should fail fast"); @@ -402,4 +486,344 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_spill_write_reservation_balanced() -> Result<()> { + let env = Arc::new(RuntimeEnv::default()); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + let batch = build_writer_batch(schema)?; + + let mut in_progress = spill_manager.create_in_progress_file("test_balanced")?; + in_progress.append_batch(&batch)?; + in_progress.append_batch(&batch)?; + + // After each append the grow/shrink is balanced within the call + let reserved_during = env.memory_pool.reserved(); + + let _file = in_progress.finish()?; + drop(in_progress); + + // After drop, reservation should be fully released + assert_eq!( + env.memory_pool.reserved(), + 0, + "Pool should have zero reserved after InProgressSpillFile is dropped, got {reserved_during} during" + ); + Ok(()) + } + + #[tokio::test] + async fn test_spill_read_reservation_tracked() -> Result<()> { + let env = Arc::new(RuntimeEnv::default()); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + let batch = build_writer_batch(schema)?; + + // Write a spill file + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch], "test_read")? + .expect("should have spill file"); + + // Read back — reservation should track decoded batches + let stream = spill_manager.read_spill_as_stream(spill_file, None, None)?; + let batches = collect(stream).await?; + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].num_rows(), 3); + + // After stream is consumed and dropped, pool reserved should be zero + assert_eq!( + env.memory_pool.reserved(), + 0, + "Pool should have zero reserved after read stream is consumed" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_spill_write_balanced_under_exhausted_pool() -> Result<()> { + use datafusion_execution::memory_pool::GreedyMemoryPool; + + let pool: Arc = + Arc::new(GreedyMemoryPool::new(64)); + let env = Arc::new( + datafusion_execution::runtime_env::RuntimeEnvBuilder::new() + .with_memory_pool(Arc::clone(&pool)) + .build()?, + ); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + + // Exhaust most of the pool + let blocker = datafusion_execution::memory_pool::MemoryConsumer::new("blocker") + .register(&pool); + blocker.grow(60); + + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + let batch = build_writer_batch(schema)?; + + // append_batch uses infallible grow — it must succeed even when + // the pool is nearly exhausted (spill-write make-progress exception). + let mut in_progress = spill_manager.create_in_progress_file("exhausted")?; + in_progress.append_batch(&batch)?; + + // Reservation should be balanced within the call + let reserved_after_append = in_progress.reservation_size(); + assert_eq!( + reserved_after_append, 0, + "Write reservation should be zero after append (grow/shrink balanced)" + ); + + let _file = in_progress.finish()?; + drop(in_progress); + + // Release the blocker + blocker.free(); + + assert_eq!( + pool.reserved(), + 0, + "Pool should be zero after all reservations freed" + ); + Ok(()) + } + + #[tokio::test] + async fn test_unbuffered_read_reservation_tracks_batches() -> Result<()> { + use datafusion_execution::memory_pool::MemoryConsumer; + + let env = Arc::new(RuntimeEnv::default()); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + let batch = build_writer_batch(Arc::clone(&schema))?; + + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch.clone(), batch], "test_read")? + .expect("should have spill file"); + + let read_reservation = + MemoryConsumer::new("read_test").register(&env.memory_pool); + + let mut stream = spill_manager.read_spill_as_stream_unbuffered( + spill_file, + None, + Some(read_reservation), + )?; + + use futures::StreamExt; + // Read first batch — reservation should grow + let b1 = stream.next().await.unwrap()?; + assert_eq!(b1.num_rows(), 3); + let reserved_after_b1 = env.memory_pool.reserved(); + assert!( + reserved_after_b1 > 0, + "Reservation should be non-zero after reading first batch" + ); + + // Read second batch — previous shrinks, current grows + let b2 = stream.next().await.unwrap()?; + assert_eq!(b2.num_rows(), 3); + let reserved_after_b2 = env.memory_pool.reserved(); + assert!( + reserved_after_b2 > 0, + "Reservation should be non-zero after reading second batch" + ); + + // Stream ends — remaining reservation freed + assert!(stream.next().await.is_none()); + assert_eq!( + env.memory_pool.reserved(), + 0, + "Pool should be zero after unbuffered read stream consumed" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_buffered_read_reservation_prereserves_capacity() -> Result<()> { + use datafusion_execution::memory_pool::MemoryConsumer; + + let env = Arc::new(RuntimeEnv::default()); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + let batch = build_writer_batch(Arc::clone(&schema))?; + let batch_mem = get_record_batch_memory_size(&batch); + + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch], "test_read")? + .expect("should have spill file"); + + let read_reservation = + MemoryConsumer::new("read_test").register(&env.memory_pool); + + let buffer_capacity = spill_manager.batch_read_buffer_capacity; + let expected_capacity = batch_mem * buffer_capacity; + + let stream = spill_manager.read_spill_as_stream( + spill_file, + Some(batch_mem), + Some(read_reservation), + )?; + + // Capacity should be pre-reserved immediately + assert_eq!( + env.memory_pool.reserved(), + expected_capacity, + "Pool should have capacity pre-reserved for buffered read" + ); + + // Consume the stream + let batches = collect(stream).await?; + assert_eq!(batches.len(), 1); + + // After stream consumed and dropped, reservation freed + assert_eq!( + env.memory_pool.reserved(), + 0, + "Pool should be zero after buffered read stream consumed" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_buffered_read_reuses_transferred_reservation() -> Result<()> { + use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer}; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let batch = build_writer_batch(Arc::clone(&schema))?; + let batch_mem = get_record_batch_memory_size(&batch); + + // Pool just large enough for the read buffer capacity + let buffer_capacity = 2usize; + let capacity_bytes = batch_mem * buffer_capacity; + let pool_size = capacity_bytes + batch_mem; // extra for write overhead + let pool: Arc = + Arc::new(GreedyMemoryPool::new(pool_size)); + let env = Arc::new( + datafusion_execution::runtime_env::RuntimeEnvBuilder::new() + .with_memory_pool(Arc::clone(&pool)) + .build()?, + ); + + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch], "test")? + .expect("should have spill file"); + + // Pre-grow read reservation to full capacity (simulating take() + // from a merge reservation that already holds these bytes) + let read_reservation = MemoryConsumer::new("read").register(&env.memory_pool); + read_reservation.grow(capacity_bytes); + + // Consume ALL remaining pool capacity with a competing consumer + let remaining = pool_size - pool.reserved(); + let blocker = MemoryConsumer::new("blocker").register(&env.memory_pool); + blocker.grow(remaining); + assert_eq!(pool.reserved(), pool_size); + + // With grow-to-at-least semantics, read_spill_as_stream should + // succeed because the pre-grown reservation already covers the + // required capacity — no additional pool allocation needed. + let stream = spill_manager.read_spill_as_stream( + spill_file, + Some(batch_mem), + Some(read_reservation), + )?; + + let batches = collect(stream).await?; + assert_eq!(batches.len(), 1); + + blocker.free(); + assert_eq!( + pool.reserved(), + 0, + "Pool should be zero after all reservations freed" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_unbuffered_read_exhausted_pool_returns_error() -> Result<()> { + use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer}; + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + ])); + let batch = build_writer_batch(Arc::clone(&schema))?; + let batch_mem = get_record_batch_memory_size(&batch); + + // Pool large enough for writing but not for a subsequent read + let pool_size = batch_mem + 64; + let pool: Arc = + Arc::new(GreedyMemoryPool::new(pool_size)); + let env = Arc::new( + datafusion_execution::runtime_env::RuntimeEnvBuilder::new() + .with_memory_pool(Arc::clone(&pool)) + .build()?, + ); + + let spill_manager = + build_test_spill_manager(Arc::clone(&env), Arc::clone(&schema)); + + let spill_file = spill_manager + .spill_record_batch_and_finish(&[batch], "test")? + .expect("should have spill file"); + + // Exhaust the pool completely + let blocker = MemoryConsumer::new("blocker").register(&pool); + blocker.grow(pool_size - pool.reserved()); + + let read_reservation = MemoryConsumer::new("read").register(&pool); + + // Unbuffered read with pre-reservation: try_grow(batch_mem) + // should fail with controlled ResourcesExhausted + let mut stream = spill_manager.read_spill_as_stream_unbuffered( + spill_file, + Some(batch_mem), + Some(read_reservation), + )?; + + use futures::StreamExt; + let result = stream.next().await; + assert!(result.is_some()); + let err = result.unwrap().unwrap_err(); + assert!( + err.to_string().contains("Resources exhausted"), + "Expected ResourcesExhausted error, got: {err}" + ); + + blocker.free(); + assert_eq!(pool.reserved(), 0); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/spill/spill_pool.rs b/datafusion/physical-plan/src/spill/spill_pool.rs index 2639188a2609d..beb0a0b07373e 100644 --- a/datafusion/physical-plan/src/spill/spill_pool.rs +++ b/datafusion/physical-plan/src/spill/spill_pool.rs @@ -399,6 +399,7 @@ impl Drop for SpillPoolWriter { /// # use datafusion_physical_plan::spill::spill_pool; /// # use datafusion_physical_plan::spill::SpillManager; // Re-exported for doctests /// # use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, SpillMetrics}; +/// # use datafusion_execution::memory_pool::MemoryConsumer; /// # /// # #[tokio::main] /// # async fn main() -> datafusion_common::Result<()> { @@ -406,7 +407,8 @@ impl Drop for SpillPoolWriter { /// # let env = Arc::new(RuntimeEnv::default()); /// # let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); /// # let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); -/// # let spill_manager = Arc::new(SpillManager::new(env, metrics, schema.clone())); +/// # let reservation = MemoryConsumer::new("example").with_can_spill(true).register(&env.memory_pool); +/// # let spill_manager = Arc::new(SpillManager::new(Arc::clone(&env), metrics, schema.clone(), reservation)); /// # /// // Create channel with 1MB file size limit /// let (writer, mut reader) = spill_pool::channel(1024 * 1024, spill_manager); @@ -565,7 +567,7 @@ impl Stream for SpillFile { // we want this unbuffered because files are actively being written to match self .spill_manager - .read_spill_as_stream_unbuffered(file, None) + .read_spill_as_stream_unbuffered(file, None, None) { Ok(stream) => { self.reader = Some(SpillFileReader { @@ -767,7 +769,7 @@ mod tests { let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let schema = create_test_schema(); - let spill_manager = Arc::new(SpillManager::new(env, metrics, schema)); + let spill_manager = Arc::new(SpillManager::new_default(env, metrics, schema)); channel(max_file_size, spill_manager) } @@ -778,7 +780,8 @@ mod tests { let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let schema = create_test_schema(); - let spill_manager = Arc::new(SpillManager::new(env, metrics.clone(), schema)); + let spill_manager = + Arc::new(SpillManager::new_default(env, metrics.clone(), schema)); let (writer, reader) = channel(max_file_size, spill_manager); (writer, reader, metrics) @@ -1319,8 +1322,11 @@ mod tests { let env = Arc::new(RuntimeEnv::default()); let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let schema = create_test_schema(); - let spill_manager = - Arc::new(SpillManager::new(Arc::clone(&env), metrics.clone(), schema)); + let spill_manager = Arc::new(SpillManager::new_default( + Arc::clone(&env), + metrics.clone(), + schema, + )); let (writer, mut reader) = channel(1024 * 1024, spill_manager); @@ -1461,7 +1467,8 @@ mod tests { let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let schema = create_test_schema(); - let spill_manager = Arc::new(SpillManager::new(runtime, metrics.clone(), schema)); + let spill_manager = + Arc::new(SpillManager::new_default(runtime, metrics.clone(), schema)); let (writer, mut reader) = channel(batch_size, spill_manager); diff --git a/datafusion/sqllogictest/src/accounting_pool.rs b/datafusion/sqllogictest/src/accounting_pool.rs index a9d2db9f12261..aec453f0994f6 100644 --- a/datafusion/sqllogictest/src/accounting_pool.rs +++ b/datafusion/sqllogictest/src/accounting_pool.rs @@ -39,8 +39,9 @@ use std::sync::Arc; /// Headroom over the pool's declared limit. Anything past this is an /// untracked allocation — by definition, since DF's pool didn't see it. /// -/// 800% high, but that's what it takes to pass the SLT suite right now. Goal should be ~10% -const HEADROOM_FACTOR: f64 = 8.0; +/// Reduced from 8.0 after adding SpillManager IPC buffer accounting, +/// GroupValuesRows emit pre-reservation, and NLJ probe-side tracking. +const HEADROOM_FACTOR: f64 = 6.0; pub struct AccountingMemoryPool { inner: Arc,