diff --git a/Cargo.lock b/Cargo.lock index 28fe6a43a5c5..9e5963a70a0b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -348,6 +348,7 @@ dependencies = [ "base64", "bytes", "clap", + "criterion", "futures", "http", "http-body", @@ -1063,6 +1064,7 @@ dependencies = [ "serde", "serde_json", "tinytemplate", + "tokio", "walkdir", ] diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index 8f95e1995a67..8e399fbc5a52 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -76,6 +76,7 @@ cli = ["arrow-array/chrono-tz", "arrow-cast/prettyprint", "tonic/tls-webpki-root [dev-dependencies] arrow-cast = { workspace = true, features = ["prettyprint"] } assert_cmd = "2.0.8" +criterion = { workspace = true, default-features = false, features = ["async_tokio"] } http = "1.1.0" http-body = "1.0.0" hyper-util = "0.1" @@ -105,3 +106,8 @@ required-features = ["flight-sql", "tls-ring"] name = "flight_sql_client_cli" path = "tests/flight_sql_client_cli.rs" required-features = ["cli", "flight-sql", "tls-ring"] + +[[bench]] +name = "flight" +path = "benches/flight.rs" +harness = false \ No newline at end of file diff --git a/arrow-flight/benches/common/mod.rs b/arrow-flight/benches/common/mod.rs new file mode 100644 index 000000000000..eb8bea8dc591 --- /dev/null +++ b/arrow-flight/benches/common/mod.rs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::{Arc, RwLock}; + +use arrow_array::{ + Array, ArrayRef, DictionaryArray, Int32Array, Int64Array, ListArray, RecordBatch, StringArray, + types::Int32Type, +}; +use arrow_buffer::OffsetBuffer; +use arrow_flight::{ + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, + flight_service_server::{FlightService, FlightServiceServer}, +}; +use arrow_schema::{DataType, Field, Schema}; +use bytes::Bytes; +use futures::{StreamExt, TryStreamExt, stream::BoxStream}; +use hyper_util::rt::TokioIo; +use tonic::{ + Request, Response, Status, Streaming, + transport::{Channel, Endpoint, Server}, +}; + +pub type Builder = fn(usize) -> ArrayRef; + +pub const TYPES: &[(&str, Builder)] = &[ + ("fixed", fixed), + ("nested", nested), + ("variable", variable), + ("dict", dict), +]; + +fn fixed(n: usize) -> ArrayRef { + Arc::new(Int64Array::from_iter_values(0..n as i64)) +} + +fn variable(n: usize) -> ArrayRef { + Arc::new(StringArray::from_iter_values( + (0..n).map(|i| format!("variable_string_{i}{}", "_".repeat(i % 16))), + )) +} + +fn nested(n: usize) -> ArrayRef { + let values = Int32Array::from_iter_values(0..(n * 4) as i32); + let offsets = OffsetBuffer::::from_lengths(std::iter::repeat_n(4usize, n)); + let field = Arc::new(Field::new_list_field(DataType::Int32, false)); + Arc::new(ListArray::new(field, offsets, Arc::new(values), None)) +} + +fn dict(n: usize) -> ArrayRef { + let keys = Int32Array::from_iter_values((0..n).map(|i| (i % 32) as i32)); + let values = StringArray::from_iter_values((0..32).map(|i| format!("dictionary_value_{i:03}"))); + Arc::new(DictionaryArray::::try_new(keys, Arc::new(values)).unwrap()) +} + +pub fn build_batch(name: &str, rows: usize, cols: usize, build: Builder) -> RecordBatch { + let arrays: Vec = (0..cols).map(|_| build(rows)).collect(); + let fields: Vec = arrays + .iter() + .enumerate() + .map(|(i, a)| Field::new(format!("column_{i}_{name}"), a.data_type().clone(), false)) + .collect(); + RecordBatch::try_new(Arc::new(Schema::new(fields)), arrays).unwrap() +} + +#[derive(Clone, Default)] +pub struct BenchServer { + frames: Arc>>, +} + +impl BenchServer { + #[allow(dead_code)] + pub fn set_frames(&self, frames: Vec) { + *self.frames.write().unwrap() = frames; + } +} + +fn unimpl() -> Result { + Err(Status::unimplemented("")) +} + +#[rustfmt::skip] +#[tonic::async_trait] +impl FlightService for BenchServer { + type HandshakeStream = BoxStream<'static, Result>; + type ListFlightsStream = BoxStream<'static, Result>; + type DoGetStream = BoxStream<'static, Result>; + type DoPutStream = BoxStream<'static, Result>; + type DoActionStream = BoxStream<'static, Result>; + type ListActionsStream = BoxStream<'static, Result>; + type DoExchangeStream = BoxStream<'static, Result>; + + async fn do_get(&self, _: Request) -> Result, Status> { + let frames = self.frames.read().unwrap().clone(); + Ok(Response::new(futures::stream::iter(frames.into_iter().map(Ok)).boxed())) + } + + async fn do_put(&self, req: Request>) -> Result, Status> { + let _: Vec = req.into_inner().try_collect().await?; + let ack = PutResult { app_metadata: Bytes::new() }; + Ok(Response::new(futures::stream::iter([Ok(ack)]).boxed())) + } + + async fn do_exchange(&self, req: Request>) -> Result, Status> { + Ok(Response::new(req.into_inner().boxed())) + } + + async fn handshake(&self, _: Request>) -> Result, Status> { unimpl() } + async fn list_flights(&self, _: Request) -> Result, Status> { unimpl() } + async fn get_flight_info(&self, _: Request) -> Result, Status> { unimpl() } + async fn poll_flight_info(&self, _: Request) -> Result, Status> { unimpl() } + async fn get_schema(&self, _: Request) -> Result, Status> { unimpl() } + async fn do_action(&self, _: Request) -> Result, Status> { unimpl() } + async fn list_actions(&self, _: Request) -> Result, Status> { unimpl() } +} +#[allow(dead_code)] +pub async fn start_server() -> (Channel, BenchServer) { + const DUMMY_URL: &str = "http://localhost:50051"; + + let bench_server = BenchServer::default(); + + let (client, server) = tokio::io::duplex(1024 * 1024); + + let mut client = Some(client); + let channel = Endpoint::try_from(DUMMY_URL) + .expect("Invalid dummy URL for building an endpoint. This should never happen") + .connect_with_connector_lazy(tower::service_fn(move |_| { + let client = client + .take() + .expect("Client taken twice. This should never happen"); + async move { Ok::<_, std::io::Error>(TokioIo::new(client)) } + })); + tokio::spawn( + Server::builder() + .add_service(FlightServiceServer::new(bench_server.clone())) + .serve_with_incoming(tokio_stream::once(Ok::<_, std::io::Error>(server))), + ); + (channel, bench_server) +} diff --git a/arrow-flight/benches/flight.rs b/arrow-flight/benches/flight.rs new file mode 100644 index 000000000000..4841e9dd9822 --- /dev/null +++ b/arrow-flight/benches/flight.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::RecordBatch; +use arrow_flight::{FlightClient, FlightData, encode::FlightDataEncoderBuilder}; +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; +use futures::TryStreamExt; +use tonic::transport::Channel; + +mod common; +use common::{TYPES, build_batch, start_server}; + +const ROWS: [usize; 2] = [8 * 1024, 64 * 1024]; +const COLS: [usize; 2] = [1, 8]; + +fn bench_encode(c: &mut Criterion) { + let rt = tokio::runtime::Runtime::new().unwrap(); + let mut g = c.benchmark_group("encode"); + + for &(name, build) in TYPES { + for &rows in &ROWS { + for &cols in &COLS { + let batch = build_batch(name, rows, cols, build); + let id = BenchmarkId::new(name, format!("{rows}x{cols}")); + g.throughput(Throughput::Bytes(batch.get_array_memory_size() as u64)); + g.bench_with_input(id, &batch, |b, batch| { + b.to_async(&rt).iter(|| async { + let _: Vec = FlightDataEncoderBuilder::new() + .build(futures::stream::iter([Ok(batch.clone())])) + .try_collect() + .await + .unwrap(); + }); + }); + } + } + } +} + +async fn roundtrip(channel: Channel, batch: RecordBatch) { + let mut client = FlightClient::new(channel); + let frames = FlightDataEncoderBuilder::new().build(futures::stream::iter([Ok(batch)])); + let _: Vec = client + .do_exchange(frames) + .await + .unwrap() + .try_collect() + .await + .unwrap(); +} + +fn bench_roundtrip(c: &mut Criterion) { + let rt = tokio::runtime::Runtime::new().unwrap(); + let (channel, _) = rt.block_on(start_server()); + let mut g = c.benchmark_group("roundtrip"); + + for &(name, build) in TYPES { + for &rows in &ROWS { + for &cols in &COLS { + let batch = build_batch(name, rows, cols, build); + let id = BenchmarkId::new(name, format!("{rows}x{cols}")); + g.throughput(Throughput::Bytes(batch.get_array_memory_size() as u64)); + g.bench_with_input(id, &batch, |b, batch| { + b.to_async(&rt) + .iter(|| roundtrip(channel.clone(), batch.clone())); + }); + } + } + } +} + +criterion_group!(benches, bench_encode, bench_roundtrip); +criterion_main!(benches);