Skip to content

Commit 16ceb85

Browse files
authored
feat(rust)!: return Box<RecordBatchReader + 'static> for caller flexibility (#3904)
This avoids inadvertently tying the result lifetime to the lifetimes of any input arguments, and fully type-erases the result. Closes #2694.
1 parent 7cc8b6a commit 16ceb85

6 files changed

Lines changed: 76 additions & 55 deletions

File tree

rust/core/src/sync.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ pub trait Connection: Optionable<Option = OptionConnection> {
125125
fn get_info(
126126
&self,
127127
codes: Option<HashSet<options::InfoCode>>,
128-
) -> Result<impl RecordBatchReader + Send>;
128+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>>;
129129

130130
/// Get a hierarchical view of all catalogs, database schemas, tables, and
131131
/// columns.
@@ -233,7 +233,7 @@ pub trait Connection: Optionable<Option = OptionConnection> {
233233
table_name: Option<&str>,
234234
table_type: Option<Vec<&str>>,
235235
column_name: Option<&str>,
236-
) -> Result<impl RecordBatchReader + Send>;
236+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>>;
237237

238238
/// Get the Arrow schema of a table.
239239
///
@@ -258,7 +258,7 @@ pub trait Connection: Optionable<Option = OptionConnection> {
258258
/// Field Name | Field Type
259259
/// ---------------|--------------
260260
/// table_type | utf8 not null
261-
fn get_table_types(&self) -> Result<impl RecordBatchReader + Send>;
261+
fn get_table_types(&self) -> Result<Box<dyn RecordBatchReader + Send + 'static>>;
262262

263263
/// Get the names of statistics specific to this driver.
264264
///
@@ -273,7 +273,7 @@ pub trait Connection: Optionable<Option = OptionConnection> {
273273
///
274274
/// # Since
275275
/// ADBC API revision 1.1.0
276-
fn get_statistic_names(&self) -> Result<impl RecordBatchReader + Send>;
276+
fn get_statistic_names(&self) -> Result<Box<dyn RecordBatchReader + Send + 'static>>;
277277

278278
/// Get statistics about the data distribution of table(s).
279279
///
@@ -339,7 +339,7 @@ pub trait Connection: Optionable<Option = OptionConnection> {
339339
db_schema: Option<&str>,
340340
table_name: Option<&str>,
341341
approximate: bool,
342-
) -> Result<impl RecordBatchReader + Send>;
342+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>>;
343343

344344
/// Commit any pending transactions. Only used if autocommit is disabled.
345345
///
@@ -358,7 +358,10 @@ pub trait Connection: Optionable<Option = OptionConnection> {
358358
/// # Arguments
359359
///
360360
/// - `partition` - The partition descriptor.
361-
fn read_partition(&self, partition: impl AsRef<[u8]>) -> Result<impl RecordBatchReader + Send>;
361+
fn read_partition(
362+
&self,
363+
partition: impl AsRef<[u8]>,
364+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>>;
362365
}
363366

364367
/// A handle to an ADBC statement.
@@ -391,10 +394,7 @@ pub trait Statement: Optionable<Option = OptionStatement> {
391394
/// Execute a statement and get the results.
392395
///
393396
/// This invalidates any prior result sets.
394-
// TODO(alexandreyc): is the Send bound absolutely necessary? same question
395-
// for all methods that return an impl RecordBatchReader
396-
// See: https://github.com/apache/arrow-adbc/pull/1725#discussion_r1567748242
397-
fn execute(&mut self) -> Result<impl RecordBatchReader + Send>;
397+
fn execute(&mut self) -> Result<Box<dyn RecordBatchReader + Send + 'static>>;
398398

399399
/// Execute a statement that doesn’t have a result set and get the number
400400
/// of affected rows.

rust/driver/datafusion/src/lib.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ impl Connection for DataFusionConnection {
742742
fn get_info(
743743
&self,
744744
codes: Option<std::collections::HashSet<adbc_core::options::InfoCode>>,
745-
) -> Result<impl RecordBatchReader + Send> {
745+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
746746
let mut get_info_builder = GetInfoBuilder::new();
747747

748748
codes.unwrap().into_iter().for_each(|f| match f {
@@ -755,7 +755,7 @@ impl Connection for DataFusionConnection {
755755

756756
let batch = get_info_builder.finish()?;
757757
let reader = SingleBatchReader::new(batch);
758-
Ok(reader)
758+
Ok(Box::new(reader))
759759
}
760760

761761
fn get_objects(
@@ -766,10 +766,10 @@ impl Connection for DataFusionConnection {
766766
_table_name: Option<&str>,
767767
_table_type: Option<Vec<&str>>,
768768
_column_name: Option<&str>,
769-
) -> Result<impl RecordBatchReader + Send> {
769+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
770770
let batch = GetObjectsBuilder::new().build(&self.runtime, &self.ctx, &depth)?;
771771
let reader = SingleBatchReader::new(batch);
772-
Ok(reader)
772+
Ok(Box::new(reader))
773773
}
774774

775775
fn get_table_schema(
@@ -781,11 +781,11 @@ impl Connection for DataFusionConnection {
781781
todo!()
782782
}
783783

784-
fn get_table_types(&self) -> Result<SingleBatchReader> {
784+
fn get_table_types(&self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
785785
todo!()
786786
}
787787

788-
fn get_statistic_names(&self) -> Result<SingleBatchReader> {
788+
fn get_statistic_names(&self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
789789
todo!()
790790
}
791791

@@ -795,7 +795,7 @@ impl Connection for DataFusionConnection {
795795
_db_schema: Option<&str>,
796796
_table_name: Option<&str>,
797797
_approximate: bool,
798-
) -> Result<SingleBatchReader> {
798+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
799799
todo!()
800800
}
801801

@@ -807,7 +807,10 @@ impl Connection for DataFusionConnection {
807807
todo!()
808808
}
809809

810-
fn read_partition(&self, _partition: impl AsRef<[u8]>) -> Result<SingleBatchReader> {
810+
fn read_partition(
811+
&self,
812+
_partition: impl AsRef<[u8]>,
813+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
811814
todo!()
812815
}
813816
}
@@ -901,7 +904,7 @@ impl Statement for DataFusionStatement {
901904
todo!()
902905
}
903906

904-
fn execute(&mut self) -> Result<impl RecordBatchReader + Send> {
907+
fn execute(&mut self) -> Result<Box<dyn RecordBatchReader + Send>> {
905908
self.runtime.block_on(async {
906909
let df = if self.sql_query.is_some() {
907910
self.ctx
@@ -916,7 +919,7 @@ impl Statement for DataFusionStatement {
916919
self.ctx.execute_logical_plan(plan).await.unwrap()
917920
};
918921

919-
Ok(DataFusionReader::new(df).await)
922+
Ok(Box::new(DataFusionReader::new(df).await) as Box<dyn RecordBatchReader + Send>)
920923
})
921924
}
922925

rust/driver/dummy/src/lib.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,10 @@ impl Connection for DummyConnection {
310310
Ok(())
311311
}
312312

313-
fn get_info(&self, _codes: Option<HashSet<InfoCode>>) -> Result<impl RecordBatchReader> {
313+
fn get_info(
314+
&self,
315+
_codes: Option<HashSet<InfoCode>>,
316+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
314317
let string_value_array = StringArray::from(vec!["MyVendorName"]);
315318
let bool_value_array = BooleanArray::from(vec![true]);
316319
let int64_value_array = Int64Array::from(vec![42]);
@@ -407,7 +410,7 @@ impl Connection for DummyConnection {
407410
vec![Arc::new(name_array), Arc::new(value_array)],
408411
)?;
409412
let reader = SingleBatchReader::new(batch);
410-
Ok(reader)
413+
Ok(Box::new(reader))
411414
}
412415

413416
fn get_objects(
@@ -418,7 +421,7 @@ impl Connection for DummyConnection {
418421
_table_name: Option<&str>,
419422
_table_type: Option<Vec<&str>>,
420423
_column_name: Option<&str>,
421-
) -> Result<impl RecordBatchReader> {
424+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
422425
let constraint_column_usage_array_inner = StructArray::from(vec![
423426
(
424427
Arc::new(Field::new("fk_catalog", DataType::Utf8, true)),
@@ -645,7 +648,7 @@ impl Connection for DummyConnection {
645648
],
646649
)?;
647650
let reader = SingleBatchReader::new(batch);
648-
Ok(reader)
651+
Ok(Box::new(reader))
649652
}
650653

651654
fn get_statistics(
@@ -654,7 +657,7 @@ impl Connection for DummyConnection {
654657
_db_schema: Option<&str>,
655658
_table_name: Option<&str>,
656659
_approximate: bool,
657-
) -> Result<impl RecordBatchReader> {
660+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
658661
let statistic_value_int64_array = Int64Array::from(Vec::<i64>::new());
659662
let statistic_value_uint64_array = UInt64Array::from(vec![42]);
660663
let statistic_value_float64_array = Float64Array::from(Vec::<f64>::new());
@@ -759,18 +762,18 @@ impl Connection for DummyConnection {
759762
)?;
760763

761764
let reader = SingleBatchReader::new(batch);
762-
Ok(reader)
765+
Ok(Box::new(reader))
763766
}
764767

765-
fn get_statistic_names(&self) -> Result<impl RecordBatchReader> {
768+
fn get_statistic_names(&self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
766769
let name_array = StringArray::from(vec!["sum", "min", "max"]);
767770
let key_array = Int16Array::from(vec![0, 1, 2]);
768771
let batch = RecordBatch::try_new(
769772
schemas::GET_STATISTIC_NAMES_SCHEMA.clone(),
770773
vec![Arc::new(name_array), Arc::new(key_array)],
771774
)?;
772775
let reader = SingleBatchReader::new(batch);
773-
Ok(reader)
776+
Ok(Box::new(reader))
774777
}
775778

776779
fn get_table_schema(
@@ -792,17 +795,20 @@ impl Connection for DummyConnection {
792795
}
793796
}
794797

795-
fn get_table_types(&self) -> Result<impl RecordBatchReader> {
798+
fn get_table_types(&self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
796799
let array = Arc::new(StringArray::from(vec!["table", "view"]));
797800
let batch = RecordBatch::try_new(schemas::GET_TABLE_TYPES_SCHEMA.clone(), vec![array])?;
798801
let reader = SingleBatchReader::new(batch);
799-
Ok(reader)
802+
Ok(Box::new(reader))
800803
}
801804

802-
fn read_partition(&self, _partition: impl AsRef<[u8]>) -> Result<impl RecordBatchReader> {
805+
fn read_partition(
806+
&self,
807+
_partition: impl AsRef<[u8]>,
808+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
803809
let batch = get_table_data();
804810
let reader = SingleBatchReader::new(batch);
805-
Ok(reader)
811+
Ok(Box::new(reader))
806812
}
807813

808814
fn rollback(&mut self) -> Result<()> {
@@ -852,11 +858,11 @@ impl Statement for DummyStatement {
852858
Ok(())
853859
}
854860

855-
fn execute(&mut self) -> Result<impl RecordBatchReader> {
861+
fn execute(&mut self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
856862
maybe_panic("StatementExecuteQuery");
857863
let batch = get_table_data();
858864
let reader = SingleBatchReader::new(batch);
859-
Ok(reader)
865+
Ok(Box::new(reader))
860866
}
861867

862868
fn execute_partitions(&mut self) -> Result<PartitionedResult> {

rust/driver/snowflake/src/connection.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,10 @@ impl adbc_core::Connection for Connection {
7474
self.0.cancel()
7575
}
7676

77-
fn get_info(&self, codes: Option<HashSet<InfoCode>>) -> Result<impl RecordBatchReader + Send> {
77+
fn get_info(
78+
&self,
79+
codes: Option<HashSet<InfoCode>>,
80+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
7881
self.0.get_info(codes)
7982
}
8083

@@ -86,7 +89,7 @@ impl adbc_core::Connection for Connection {
8689
table_name: Option<&str>,
8790
table_type: Option<Vec<&str>>,
8891
column_name: Option<&str>,
89-
) -> Result<impl RecordBatchReader + Send> {
92+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
9093
self.0.get_objects(
9194
depth,
9295
catalog,
@@ -106,11 +109,11 @@ impl adbc_core::Connection for Connection {
106109
self.0.get_table_schema(catalog, db_schema, table_name)
107110
}
108111

109-
fn get_table_types(&self) -> Result<impl RecordBatchReader + Send> {
112+
fn get_table_types(&self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
110113
self.0.get_table_types()
111114
}
112115

113-
fn get_statistic_names(&self) -> Result<impl RecordBatchReader + Send> {
116+
fn get_statistic_names(&self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
114117
self.0.get_statistic_names()
115118
}
116119

@@ -120,7 +123,7 @@ impl adbc_core::Connection for Connection {
120123
db_schema: Option<&str>,
121124
table_name: Option<&str>,
122125
approximate: bool,
123-
) -> Result<impl RecordBatchReader + Send> {
126+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
124127
self.0
125128
.get_statistics(catalog, db_schema, table_name, approximate)
126129
}
@@ -133,7 +136,10 @@ impl adbc_core::Connection for Connection {
133136
self.0.rollback()
134137
}
135138

136-
fn read_partition(&self, partition: impl AsRef<[u8]>) -> Result<impl RecordBatchReader + Send> {
139+
fn read_partition(
140+
&self,
141+
partition: impl AsRef<[u8]>,
142+
) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
137143
self.0.read_partition(partition)
138144
}
139145
}

rust/driver/snowflake/src/statement.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ impl adbc_core::Statement for Statement {
6464
self.0.bind_stream(reader)
6565
}
6666

67-
fn execute(&mut self) -> Result<impl RecordBatchReader + Send> {
67+
fn execute(&mut self) -> Result<Box<dyn RecordBatchReader + Send + 'static>> {
6868
self.0.execute()
6969
}
7070

0 commit comments

Comments
 (0)