diff --git a/datafusion/ffi/src/schema_provider.rs b/datafusion/ffi/src/schema_provider.rs index 8441b84b48697..dedf127534c86 100644 --- a/datafusion/ffi/src/schema_provider.rs +++ b/datafusion/ffi/src/schema_provider.rs @@ -22,6 +22,7 @@ use async_ffi::{FfiFuture, FutureExt}; use async_trait::async_trait; use datafusion_catalog::{SchemaProvider, TableProvider}; use datafusion_common::error::{DataFusionError, Result}; +use datafusion_expr::TableType; use datafusion_proto::logical_plan::{ DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; @@ -32,6 +33,7 @@ use tokio::runtime::Handle; use crate::execution::FFI_TaskContextProvider; use crate::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use crate::table_provider::{FFI_TableProvider, ForeignTableProvider}; +use crate::table_source::FFI_TableType; use crate::util::{FFI_Option, FFI_Result}; use crate::{df_result, sresult_return}; @@ -50,6 +52,12 @@ pub struct FFI_SchemaProvider { FFI_Result>, >, + pub table_type: + unsafe extern "C" fn( + provider: &Self, + name: SString, + ) -> FfiFuture>>, + pub register_table: unsafe extern "C" fn( provider: &Self, name: SString, @@ -144,6 +152,24 @@ unsafe extern "C" fn table_fn_wrapper( } } +unsafe extern "C" fn table_type_fn_wrapper( + provider: &FFI_SchemaProvider, + name: SString, +) -> FfiFuture>> { + unsafe { + let provider = Arc::clone(provider.inner()); + + async move { + let table_type = sresult_return!(provider.table_type(name.as_str()).await) + .map(Into::into) + .into(); + + FFI_Result::Ok(table_type) + } + .into_ffi() + } +} + unsafe extern "C" fn register_table_fn_wrapper( provider: &FFI_SchemaProvider, name: SString, @@ -216,6 +242,7 @@ unsafe extern "C" fn clone_fn_wrapper( owner_name: provider.owner_name.clone(), table_names: table_names_fn_wrapper, table: table_fn_wrapper, + table_type: table_type_fn_wrapper, register_table: register_table_fn_wrapper, deregister_table: deregister_table_fn_wrapper, table_exist: table_exist_fn_wrapper, @@ -270,6 +297,7 @@ impl FFI_SchemaProvider { owner_name, table_names: table_names_fn_wrapper, table: table_fn_wrapper, + table_type: table_type_fn_wrapper, register_table: register_table_fn_wrapper, deregister_table: deregister_table_fn_wrapper, table_exist: table_exist_fn_wrapper, @@ -339,6 +367,15 @@ impl SchemaProvider for ForeignSchemaProvider { } } + async fn table_type(&self, name: &str) -> Result> { + unsafe { + let table_type: Option = + df_result!((self.0.table_type)(&self.0, name.into()).await)?.into(); + + Ok(table_type.map(Into::into)) + } + } + fn register_table( &self, name: String, @@ -384,6 +421,7 @@ mod tests { use arrow::datatypes::Schema; use datafusion::catalog::MemorySchemaProvider; use datafusion::datasource::empty::EmptyTable; + use std::sync::atomic::{AtomicUsize, Ordering}; use super::*; @@ -452,6 +490,96 @@ mod tests { assert!(foreign_schema_provider.table_exist("second_table")); } + #[derive(Debug)] + struct TableTypeSchemaProvider { + table_calls: Arc, + table_type_calls: Arc, + } + + #[async_trait] + impl SchemaProvider for TableTypeSchemaProvider { + fn table_names(&self) -> Vec { + vec!["view_table".to_string()] + } + + async fn table( + &self, + _name: &str, + ) -> Result>, DataFusionError> { + self.table_calls.fetch_add(1, Ordering::SeqCst); + Ok(Some(empty_table())) + } + + async fn table_type(&self, name: &str) -> Result> { + self.table_type_calls.fetch_add(1, Ordering::SeqCst); + Ok((name == "view_table").then_some(TableType::View)) + } + + fn table_exist(&self, name: &str) -> bool { + name == "view_table" + } + } + + #[tokio::test] + async fn test_ffi_schema_provider_table_type_uses_foreign_hook() { + let table_calls = Arc::new(AtomicUsize::new(0)); + let table_type_calls = Arc::new(AtomicUsize::new(0)); + let schema_provider = Arc::new(TableTypeSchemaProvider { + table_calls: Arc::clone(&table_calls), + table_type_calls: Arc::clone(&table_type_calls), + }); + + let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let mut ffi_schema_provider = + FFI_SchemaProvider::new(schema_provider, None, task_ctx_provider, None); + ffi_schema_provider.library_marker_id = crate::mock_foreign_marker_id; + + let foreign_schema_provider: Arc = + (&ffi_schema_provider).into(); + + let table_type = foreign_schema_provider + .table_type("view_table") + .await + .expect("Unable to query table type"); + + assert_eq!(table_type, Some(TableType::View)); + assert_eq!(table_type_calls.load(Ordering::SeqCst), 1); + assert_eq!(table_calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn test_ffi_schema_provider_table_type_local_bypass() { + let table_calls = Arc::new(AtomicUsize::new(0)); + let table_type_calls = Arc::new(AtomicUsize::new(0)); + let schema_provider = Arc::new(TableTypeSchemaProvider { + table_calls: Arc::clone(&table_calls), + table_type_calls: Arc::clone(&table_type_calls), + }); + + let (_ctx, task_ctx_provider) = crate::util::tests::test_session_and_ctx(); + + let ffi_schema_provider = + FFI_SchemaProvider::new(schema_provider, None, task_ctx_provider, None); + + let schema_provider: Arc = (&ffi_schema_provider).into(); + + assert!( + schema_provider + .downcast_ref::() + .is_some() + ); + + let table_type = schema_provider + .table_type("view_table") + .await + .expect("Unable to query table type"); + + assert_eq!(table_type, Some(TableType::View)); + assert_eq!(table_type_calls.load(Ordering::SeqCst), 1); + assert_eq!(table_calls.load(Ordering::SeqCst), 0); + } + #[test] fn test_ffi_schema_provider_local_bypass() { let schema_provider = Arc::new(MemorySchemaProvider::new()); diff --git a/datafusion/ffi/src/tests/catalog.rs b/datafusion/ffi/src/tests/catalog.rs index 0c02de5d049ae..f16de6d7ca296 100644 --- a/datafusion/ffi/src/tests/catalog.rs +++ b/datafusion/ffi/src/tests/catalog.rs @@ -35,6 +35,7 @@ use datafusion_catalog::{ MemoryCatalogProviderList, MemorySchemaProvider, SchemaProvider, TableProvider, }; use datafusion_common::{Result, exec_err}; +use datafusion_expr::TableType; use crate::catalog_provider::FFI_CatalogProvider; use crate::catalog_provider_list::FFI_CatalogProviderList; @@ -96,6 +97,14 @@ impl SchemaProvider for FixedSchemaProvider { self.inner.table(name).await } + async fn table_type(&self, name: &str) -> Result> { + if name == "purchases" { + return Ok(Some(TableType::View)); + } + + self.inner.table_type(name).await + } + fn table_exist(&self, name: &str) -> bool { self.inner.table_exist(name) } @@ -119,8 +128,8 @@ impl SchemaProvider for FixedSchemaProvider { } } -/// This catalog provider is intended only for unit tests. It prepopulates with one -/// schema and only allows for schemas named after four types of fruit. +/// This catalog provider is intended only for unit tests. It prepopulates with +/// two schemas and only allows for schemas named after four types of fruit. #[derive(Debug)] pub struct FixedCatalogProvider { inner: MemoryCatalogProvider, @@ -131,6 +140,11 @@ impl Default for FixedCatalogProvider { let inner = MemoryCatalogProvider::new(); let _ = inner.register_schema("apple", Arc::new(FixedSchemaProvider::default())); + let fallback_schema = Arc::new(MemorySchemaProvider::new()); + fallback_schema + .register_table("sales".to_string(), fruit_table()) + .unwrap(); + let _ = inner.register_schema("banana", fallback_schema); Self { inner } } diff --git a/datafusion/ffi/tests/ffi_catalog.rs b/datafusion/ffi/tests/ffi_catalog.rs index 440a435f75c7d..4864f9cecaaa2 100644 --- a/datafusion/ffi/tests/ffi_catalog.rs +++ b/datafusion/ffi/tests/ffi_catalog.rs @@ -24,6 +24,7 @@ mod tests { use std::sync::Arc; use datafusion::catalog::{CatalogProvider, CatalogProviderList}; + use datafusion_expr::TableType; use datafusion_ffi::tests::utils::get_module; #[tokio::test] @@ -47,6 +48,25 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_schema_provider_table_type() -> datafusion_common::Result<()> { + let module = get_module()?; + let (_ctx, codec) = super::utils::ctx_and_codec(); + + let ffi_catalog = (module.create_catalog)(codec); + let foreign_catalog: Arc = (&ffi_catalog).into(); + + let override_schema = foreign_catalog.schema("apple").expect("apple schema"); + let fallback_schema = foreign_catalog.schema("banana").expect("banana schema"); + let table_type = override_schema.table_type("purchases").await?; + let fallback_table_type = fallback_schema.table_type("sales").await?; + + assert_eq!(table_type, Some(TableType::View)); + assert_eq!(fallback_table_type, Some(TableType::Base)); + + Ok(()) + } + #[tokio::test] async fn test_catalog_list() -> datafusion_common::Result<()> { let module = get_module()?;