Skip to content

Commit 2dc9a02

Browse files
authored
Added the run_throughput_pipeline_py PyO3 binding function so make test_python passes (#1098)
1 parent 5db5e8d commit 2dc9a02

1 file changed

Lines changed: 36 additions & 0 deletions

File tree

qdp/qdp-python/src/lib.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,46 @@ mod pytorch;
2121
mod tensor;
2222

2323
use engine::QdpEngine;
24+
use pyo3::exceptions::PyRuntimeError;
2425
use pyo3::prelude::*;
2526
use tensor::QuantumTensor;
2627

2728
#[cfg(target_os = "linux")]
2829
use loader::PyQuantumLoader;
2930

31+
#[cfg(target_os = "linux")]
32+
#[pyfunction]
33+
#[pyo3(signature = (device_id, num_qubits, batch_size, total_batches, encoding_method, warmup_batches=0, seed=None))]
34+
#[allow(clippy::too_many_arguments)]
35+
fn run_throughput_pipeline_py(
36+
py: Python<'_>,
37+
device_id: usize,
38+
num_qubits: u32,
39+
batch_size: usize,
40+
total_batches: usize,
41+
encoding_method: String,
42+
warmup_batches: usize,
43+
seed: Option<u64>,
44+
) -> PyResult<(f64, f64, f64)> {
45+
let config = qdp_core::PipelineConfig {
46+
device_id,
47+
num_qubits,
48+
batch_size,
49+
total_batches,
50+
encoding_method,
51+
seed,
52+
warmup_batches,
53+
};
54+
let result = py
55+
.detach(|| qdp_core::run_throughput_pipeline(&config))
56+
.map_err(|e| PyRuntimeError::new_err(format!("Pipeline failed: {e}")))?;
57+
Ok((
58+
result.duration_sec,
59+
result.vectors_per_sec,
60+
result.latency_ms_per_vector,
61+
))
62+
}
63+
3064
/// Quantum Data Plane (QDP) Python module
3165
///
3266
/// GPU-accelerated quantum data encoding with DLPack integration.
@@ -39,5 +73,7 @@ fn _qdp(m: &Bound<'_, PyModule>) -> PyResult<()> {
3973
m.add_class::<QuantumTensor>()?;
4074
#[cfg(target_os = "linux")]
4175
m.add_class::<PyQuantumLoader>()?;
76+
#[cfg(target_os = "linux")]
77+
m.add_function(wrap_pyfunction!(run_throughput_pipeline_py, m)?)?;
4278
Ok(())
4379
}

0 commit comments

Comments
 (0)