Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 42 additions & 6 deletions crates/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, Table
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::execution::context::TaskContext;
use datafusion::logical_expr::SortExpr;
use datafusion::logical_expr::dml::InsertOp;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
Expand All @@ -51,6 +52,13 @@ use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};

use crate::common::data_type::PyScalarValue;
use datafusion::physical_plan::{
ExecutionPlan as DFExecutionPlan,
collect as df_collect,
collect_partitioned as df_collect_partitioned,
execute_stream as df_execute_stream,
execute_stream_partitioned as df_execute_stream_partitioned,
};
use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err};
use crate::expr::PyExpr;
use crate::expr::sort_expr::{PySortExpr, to_sort_expressions};
Expand Down Expand Up @@ -308,6 +316,9 @@ pub struct PyDataFrame {

// In IPython environment cache batches between __repr__ and _repr_html_ calls.
batches: SharedCachedBatches,

// Cache the last physical plan so that metrics are available after execution.
last_plan: Arc<Mutex<Option<Arc<dyn DFExecutionPlan>>>>,
}

impl PyDataFrame {
Expand All @@ -316,6 +327,7 @@ impl PyDataFrame {
Self {
df: Arc::new(df),
batches: Arc::new(Mutex::new(None)),
last_plan: Arc::new(Mutex::new(None)),
}
}

Expand Down Expand Up @@ -387,6 +399,20 @@ impl PyDataFrame {
Ok(html_str)
}

/// Create the physical plan, cache it in `last_plan`, and return the plan together
/// with a task context. Centralises the repeated three-line pattern that appears in
/// `collect`, `collect_partitioned`, `execute_stream`, and `execute_stream_partitioned`.
fn create_and_cache_plan(
    &self,
    py: Python,
) -> PyDataFusionResult<(Arc<dyn DFExecutionPlan>, Arc<TaskContext>)> {
    // Planning consumes a DataFrame, so work on a clone of the inner value.
    let df = self.df.as_ref().clone();
    // The outer `?` unwraps the `wait_for_future` result, the inner `?` the
    // planner's own Result.
    let new_plan = wait_for_future(py, df.create_physical_plan())??;
    // Remember the plan so `execution_plan()` can later expose the metrics
    // populated during execution.
    // NOTE(review): `.lock()` used without `unwrap()` — assumes a
    // parking_lot-style Mutex; confirm against the `Mutex` import in this file.
    *self.last_plan.lock() = Some(Arc::clone(&new_plan));
    let task_ctx = Arc::new(self.df.as_ref().task_ctx());
    Ok((new_plan, task_ctx))
}

async fn collect_column_inner(&self, column: &str) -> Result<ArrayRef, DataFusionError> {
let batches = self
.df
Expand Down Expand Up @@ -645,7 +671,8 @@ impl PyDataFrame {
/// Unless some order is specified in the plan, there is no
/// guarantee of the order of the result.
fn collect<'py>(&self, py: Python<'py>) -> PyResult<Vec<Bound<'py, PyAny>>> {
let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
let (plan, task_ctx) = self.create_and_cache_plan(py)?;
let batches = wait_for_future(py, df_collect(plan, task_ctx))?
.map_err(PyDataFusionError::from)?;
// cannot use PyResult<Vec<RecordBatch>> return type due to
// https://github.com/PyO3/pyo3/issues/1813
Expand All @@ -661,7 +688,8 @@ impl PyDataFrame {
/// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
/// maintaining the input partitioning.
fn collect_partitioned<'py>(&self, py: Python<'py>) -> PyResult<Vec<Vec<Bound<'py, PyAny>>>> {
let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
let (plan, task_ctx) = self.create_and_cache_plan(py)?;
let batches = wait_for_future(py, df_collect_partitioned(plan, task_ctx))?
.map_err(PyDataFusionError::from)?;

batches
Expand Down Expand Up @@ -821,7 +849,13 @@ impl PyDataFrame {
}

/// Get the execution plan for this `DataFrame`
///
/// If the DataFrame has already been executed (e.g. via `collect()`),
/// returns the cached plan which includes populated metrics. On a cache
/// miss the freshly created plan is cached as well, so repeated calls do
/// not re-plan the query and later metrics inspection sees the same plan.
fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
    // Fast path: reuse the plan from the last execution, metrics included.
    if let Some(plan) = self.last_plan.lock().as_ref() {
        return Ok(PyExecutionPlan::new(Arc::clone(plan)));
    }
    // Cache miss: create and cache the plan exactly like `collect()` does;
    // the task context is not needed here.
    let (plan, _task_ctx) = self.create_and_cache_plan(py)?;
    Ok(PyExecutionPlan::new(plan))
}
Expand Down Expand Up @@ -1146,14 +1180,16 @@ impl PyDataFrame {
}

/// Execute this DataFrame and return a single merged stream of record batches.
///
/// The physical plan is created (and cached for later metrics inspection)
/// before spawning the async execution.
fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
    let (plan, task_ctx) = self.create_and_cache_plan(py)?;
    // `df_execute_stream` itself is synchronous; the async block only exists
    // so the stream is set up on the spawned runtime.
    let stream = spawn_future(py, async move { df_execute_stream(plan, task_ctx) })?;
    Ok(PyRecordBatchStream::new(stream))
}

/// Execute this DataFrame and return one record batch stream per partition.
///
/// The physical plan is created once (and cached for later metrics
/// inspection), then a stream is opened for every input partition.
fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
    let (plan, task_ctx) = self.create_and_cache_plan(py)?;
    let streams = spawn_future(py, async move {
        df_execute_stream_partitioned(plan, task_ctx)
    })?;
    Ok(streams.into_iter().map(PyRecordBatchStream::new).collect())
}

Expand Down
3 changes: 3 additions & 0 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub mod errors;
pub mod expr;
#[allow(clippy::borrow_deref_ref)]
mod functions;
pub mod metrics;
mod options;
pub mod physical_plan;
mod pyarrow_filter_expression;
Expand Down Expand Up @@ -92,6 +93,8 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<udtf::PyTableFunction>()?;
m.add_class::<config::PyConfig>()?;
m.add_class::<sql::logical::PyLogicalPlan>()?;
m.add_class::<metrics::PyMetricsSet>()?;
m.add_class::<metrics::PyMetric>()?;
m.add_class::<physical_plan::PyExecutionPlan>()?;
m.add_class::<record_batch::PyRecordBatch>()?;
m.add_class::<record_batch::PyRecordBatchStream>()?;
Expand Down
164 changes: 164 additions & 0 deletions crates/core/src/metrics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;
use std::sync::Arc;

use datafusion::physical_plan::metrics::{MetricValue, MetricsSet, Metric, Timestamp};
use pyo3::prelude::*;

/// Python wrapper around a DataFusion `MetricsSet`: the collection of runtime
/// metrics (row counts, timings, spill statistics, ...) recorded by one
/// operator of an executed physical plan.
#[pyclass(frozen, name = "MetricsSet", module = "datafusion")]
#[derive(Debug, Clone)]
pub struct PyMetricsSet {
    // The wrapped DataFusion metrics collection.
    metrics: MetricsSet,
}

impl PyMetricsSet {
    /// Wrap an existing DataFusion `MetricsSet`.
    pub fn new(metrics: MetricsSet) -> Self {
        Self { metrics }
    }
}

#[pymethods]
impl PyMetricsSet {
    /// Every individual metric contained in this set.
    fn metrics(&self) -> Vec<PyMetric> {
        let mut wrapped = Vec::new();
        for metric in self.metrics.iter() {
            wrapped.push(PyMetric::new(Arc::clone(metric)));
        }
        wrapped
    }

    /// Total number of output rows, if recorded.
    fn output_rows(&self) -> Option<usize> {
        self.metrics.output_rows()
    }

    /// Total elapsed compute time in nanoseconds, if recorded.
    fn elapsed_compute(&self) -> Option<usize> {
        self.metrics.elapsed_compute()
    }

    /// Number of spills, if recorded.
    fn spill_count(&self) -> Option<usize> {
        self.metrics.spill_count()
    }

    /// Number of bytes spilled to disk, if recorded.
    fn spilled_bytes(&self) -> Option<usize> {
        self.metrics.spilled_bytes()
    }

    /// Number of rows spilled to disk, if recorded.
    fn spilled_rows(&self) -> Option<usize> {
        self.metrics.spilled_rows()
    }

    /// Sum of all metrics with the given name, if any exist.
    fn sum_by_name(&self, name: &str) -> Option<usize> {
        let summed = self.metrics.sum_by_name(name)?;
        Some(summed.as_usize())
    }

    fn __repr__(&self) -> String {
        self.metrics.to_string()
    }
}

/// Python wrapper around a single DataFusion `Metric`: one named value
/// (counter, gauge, time, or timestamp) with optional partition and labels.
#[pyclass(frozen, name = "Metric", module = "datafusion")]
#[derive(Debug, Clone)]
pub struct PyMetric {
    // Shared reference to the underlying DataFusion metric.
    metric: Arc<Metric>,
}

impl PyMetric {
    /// Wrap an existing DataFusion metric.
    pub fn new(metric: Arc<Metric>) -> Self {
        Self { metric }
    }

    /// Convert a DataFusion `Timestamp` into a timezone-aware Python
    /// `datetime.datetime` in UTC, or `None` when the timestamp was never set.
    ///
    /// The integral seconds and the sub-second part are handed to Python
    /// separately (`fromtimestamp(secs, utc) + timedelta(microseconds=...)`)
    /// rather than as one `float`: an f64 of present-day Unix seconds has only
    /// ~0.5 µs granularity, so a single float argument would round away
    /// microsecond precision.
    fn timestamp_to_pyobject<'py>(
        py: Python<'py>,
        ts: &Timestamp,
    ) -> PyResult<Option<Bound<'py, PyAny>>> {
        let Some(dt) = ts.value() else {
            return Ok(None);
        };
        let nanos = dt.timestamp_nanos_opt().ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyOverflowError, _>(
                "timestamp out of range",
            )
        })?;
        // Euclidean division keeps the sub-second part non-negative even for
        // pre-epoch (negative) timestamps.
        let secs = nanos.div_euclid(1_000_000_000);
        let micros = nanos.rem_euclid(1_000_000_000) / 1_000;

        let datetime_mod = py.import("datetime")?;
        let datetime_cls = datetime_mod.getattr("datetime")?;
        let tz_utc = datetime_mod.getattr("timezone")?.getattr("utc")?;
        let base = datetime_cls.call_method1("fromtimestamp", (secs, tz_utc))?;
        // timedelta's positional args are (days, seconds, microseconds).
        let delta = datetime_mod.getattr("timedelta")?.call1((0, 0, micros))?;
        Ok(Some(base.call_method1("__add__", (delta,))?))
    }
}

#[pymethods]
impl PyMetric {
    /// Name of the metric (e.g. "output_rows", "elapsed_compute").
    #[getter]
    fn name(&self) -> String {
        self.metric.value().name().to_string()
    }

    /// The metric's value converted to a Python object.
    ///
    /// Counters, gauges and times become `int`; start/end timestamps become
    /// timezone-aware `datetime.datetime`; any other metric kind maps to
    /// `None`.
    #[getter]
    fn value<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
        // All non-timestamp variants carry a plain usize, so extract it first
        // and convert once at the end.
        let raw: usize = match self.metric.value() {
            MetricValue::OutputRows(c) => c.value(),
            MetricValue::OutputBytes(c) => c.value(),
            MetricValue::ElapsedCompute(t) => t.value(),
            MetricValue::SpillCount(c) => c.value(),
            MetricValue::SpilledBytes(c) => c.value(),
            MetricValue::SpilledRows(c) => c.value(),
            MetricValue::CurrentMemoryUsage(g) => g.value(),
            MetricValue::Count { count, .. } => count.value(),
            MetricValue::Gauge { gauge, .. } => gauge.value(),
            MetricValue::Time { time, .. } => time.value(),
            MetricValue::StartTimestamp(ts) | MetricValue::EndTimestamp(ts) => {
                return Self::timestamp_to_pyobject(py, ts);
            }
            _ => return Ok(None),
        };
        Ok(Some(raw.into_pyobject(py)?.into_any()))
    }

    /// The value as a `datetime.datetime` when this metric is a start or end
    /// timestamp; `None` for every other metric kind.
    fn value_as_datetime<'py>(&self, py: Python<'py>) -> PyResult<Option<Bound<'py, PyAny>>> {
        if let MetricValue::StartTimestamp(ts) | MetricValue::EndTimestamp(ts) =
            self.metric.value()
        {
            Self::timestamp_to_pyobject(py, ts)
        } else {
            Ok(None)
        }
    }

    /// Partition this metric was recorded for, if partition-scoped.
    #[getter]
    fn partition(&self) -> Option<usize> {
        self.metric.partition()
    }

    /// The metric's labels as a name -> value mapping.
    fn labels(&self) -> HashMap<String, String> {
        let mut out = HashMap::new();
        for label in self.metric.labels() {
            out.insert(label.name().to_string(), label.value().to_string());
        }
        out
    }

    fn __repr__(&self) -> String {
        self.metric.value().to_string()
    }
}
5 changes: 5 additions & 0 deletions crates/core/src/physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use pyo3::types::PyBytes;

use crate::context::PySessionContext;
use crate::errors::PyDataFusionResult;
use crate::metrics::PyMetricsSet;

#[pyclass(
from_py_object,
Expand Down Expand Up @@ -96,6 +97,10 @@ impl PyExecutionPlan {
Ok(Self::new(plan))
}

/// Runtime metrics recorded by this plan node, or `None` when the operator
/// does not expose a metrics set (e.g. it does not implement metrics
/// collection).
pub fn metrics(&self) -> Option<PyMetricsSet> {
    self.plan.metrics().map(PyMetricsSet::new)
}

fn __repr__(&self) -> String {
    // Reuse the indented plan display for the Python repr.
    self.display_indent()
}
Expand Down
Loading
Loading