diff --git a/Cargo.toml b/Cargo.toml index 58e654e..781fdd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,9 @@ authors = ["Oddity.ai Developers "] repository = "https://github.com/oddity-ai/async-tensorrt" license = "MIT OR Apache-2.0" +[features] +lean = [] + [dependencies] async-cuda = "0.5.4" cpp = "0.5" diff --git a/build.rs b/build.rs index 7206204..af61c80 100644 --- a/build.rs +++ b/build.rs @@ -33,6 +33,11 @@ fn main() { #[cfg(not(windows))] println!("cargo:rustc-link-search=/usr/local/tensorrt/lib64"); + #[cfg(feature = "lean")] + println!("cargo:rustc-link-lib=nvinfer_lean"); + + #[cfg(not(feature = "lean"))] println!("cargo:rustc-link-lib=nvinfer"); + #[cfg(not(feature = "lean"))] println!("cargo:rustc-link-lib=nvonnxparser"); } diff --git a/src/engine.rs b/src/engine.rs index bc3d54f..64f0310 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -5,6 +5,7 @@ use crate::ffi::memory::HostBuffer; use crate::ffi::sync::engine::Engine as InnerEngine; use crate::ffi::sync::engine::ExecutionContext as InnerExecutionContext; +pub use crate::ffi::sync::engine::TensorDataType; pub use crate::ffi::sync::engine::TensorIoMode; type Result = std::result::Result; @@ -77,6 +78,18 @@ impl Engine { pub fn tensor_io_mode(&self, tensor_name: &str) -> TensorIoMode { self.inner.tensor_io_mode(tensor_name) } + + /// Get the data type of a tensor. + /// + /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#a569361fe7b7fced4b9c3f500346baca2) + /// + /// # Arguments + /// + /// * `tensor_name` - Tensor name. + #[inline(always)] + pub fn tensor_data_type(&self, tensor_name: &str) -> TensorDataType { + self.inner.tensor_data_type(tensor_name) + } } /// Context for executing inference using an engine. diff --git a/src/ffi/builder_config.rs b/src/ffi/builder_config.rs index 3b455da..8f300a8 100644 --- a/src/ffi/builder_config.rs +++ b/src/ffi/builder_config.rs @@ -66,6 +66,34 @@ impl BuilderConfig { self } + /// Set the `kVERSION_COMPATIBLE` flag. + /// + /// [TensorRT documentation for `setFlag`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder_config.html#ac9821504ae7a11769e48b0e62761837e) + /// [TensorRT documentation for `kVERSION_COMPATIBLE`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#abdc74c40fe7a0c3d05d2caeccfbc29c1a64917aa1f8d9238c555a46fa1d4e83b7) + pub fn with_version_compability(mut self) -> Self { + let internal = self.as_mut_ptr(); + cpp!(unsafe [ + internal as "void*" + ] { + ((IBuilderConfig*) internal)->setFlag(BuilderFlag::kVERSION_COMPATIBLE); + }); + self + } + + /// Set the `kEXCLUDE_LEAN_RUNTIME` flag. + /// + /// [TensorRT documentation for `setFlag`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder_config.html#ac9821504ae7a11769e48b0e62761837e) + /// [TensorRT documentation for `kEXCLUDE_LEAN_RUNTIME`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#abdc74c40fe7a0c3d05d2caeccfbc29c1a239d59ead8393beeecaadd21ce3b3502) + pub fn with_exclude_lean_runtime(mut self) -> Self { + let internal = self.as_mut_ptr(); + cpp!(unsafe [ + internal as "void*" + ] { + ((IBuilderConfig*) internal)->setFlag(BuilderFlag::kEXCLUDE_LEAN_RUNTIME); + }); + self + } + /// Set the `kFP16` flag. /// /// [TensorRT documentation for `setFlag`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder_config.html#ac9821504ae7a11769e48b0e62761837e) @@ -80,6 +108,20 @@ impl BuilderConfig { self } + /// Set the `kINT8` flag. + /// + /// [TensorRT documentation for `setFlag`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder_config.html#ac9821504ae7a11769e48b0e62761837e) + /// [TensorRT documentation for `kINT8`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#abdc74c40fe7a0c3d05d2caeccfbc29c1a69c1a4a69db0e50820cf63122f90ad09) + pub fn with_int8(mut self) -> Self { + let internal = self.as_mut_ptr(); + cpp!(unsafe [ + internal as "void*" + ] { + ((IBuilderConfig*) internal)->setFlag(BuilderFlag::kINT8); + }); + self + } + /// Add an optimization profile. /// /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder_config.html#ab97fa40c85fa8afab65fc2659e38da82) diff --git a/src/ffi/mod.rs b/src/ffi/mod.rs index 66c0f70..4d43243 100644 --- a/src/ffi/mod.rs +++ b/src/ffi/mod.rs @@ -8,11 +8,16 @@ mod pre { mod logger; } +#[cfg(not(feature = "lean"))] pub mod builder_config; + pub mod error; pub mod memory; pub mod network; + +#[cfg(not(feature = "lean"))] pub mod optimization_profile; + pub mod parser; pub mod sync; diff --git a/src/ffi/sync/engine.rs b/src/ffi/sync/engine.rs index 359ec9a..bdd65fc 100644 --- a/src/ffi/sync/engine.rs +++ b/src/ffi/sync/engine.rs @@ -110,6 +110,19 @@ impl Engine { TensorIoMode::from_i32(tensor_io_mode) } + pub fn tensor_data_type(&self, tensor_name: &str) -> TensorDataType { + let internal = self.as_ptr(); + let tensor_name_cstr = std::ffi::CString::new(tensor_name).unwrap(); + let tensor_name_ptr = tensor_name_cstr.as_ptr(); + let tensor_data_type = cpp!(unsafe [ + internal as "const void*", + tensor_name_ptr as "const char*" + ] -> i32 as "std::int32_t" { + return (std::int32_t) ((const ICudaEngine*) internal)->getTensorDataType(tensor_name_ptr); + }); + TensorDataType::from_i32(tensor_data_type) + } + #[inline(always)] pub fn as_ptr(&self) -> *const std::ffi::c_void { let Engine { internal, .. } = *self; @@ -342,3 +355,41 @@ struct Dims { pub nbDims: i32, pub d: [i32; 8usize], } + +/// Tensor DataType. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum TensorDataType { + FLOAT, + HALF, + INT8, + INT32, + BOOL, + UINT8, + FP8, + BF16, + INT64, + INT4, +} + +impl TensorDataType { + /// Create [`TensorDataType`] from `value`. + /// + /// # Arguments + /// + /// * `value` - Integer representation of IO mode. + fn from_i32(value: i32) -> Self { + match value { + 0 => TensorDataType::FLOAT, + 1 => TensorDataType::HALF, + 2 => TensorDataType::INT8, + 3 => TensorDataType::INT32, + 4 => TensorDataType::BOOL, + 5 => TensorDataType::UINT8, + 6 => TensorDataType::FP8, + 7 => TensorDataType::BF16, + 8 => TensorDataType::INT64, + 9 => TensorDataType::INT4, + _ => panic!("Unknown data type {}", value), + } + } +} diff --git a/src/ffi/sync/mod.rs b/src/ffi/sync/mod.rs index ee05aa7..1cc2ee6 100644 --- a/src/ffi/sync/mod.rs +++ b/src/ffi/sync/mod.rs @@ -1,3 +1,5 @@ +#[cfg(not(feature = "lean"))] pub mod builder; + pub mod engine; pub mod runtime; diff --git a/src/ffi/sync/runtime.rs b/src/ffi/sync/runtime.rs index efc9045..8e2dd8d 100644 --- a/src/ffi/sync/runtime.rs +++ b/src/ffi/sync/runtime.rs @@ -85,6 +85,16 @@ impl Runtime { result!(internal_engine, Engine::wrap(internal_engine, self)) } + pub fn set_engine_host_code_allowed(&mut self, allowed: bool) { + let internal = self.as_mut_ptr(); + cpp!(unsafe [ + internal as "void*", + allowed as "bool" + ] { + ((IRuntime*) internal)->setEngineHostCodeAllowed(allowed); + }); + } + #[inline(always)] pub fn as_ptr(&self) -> *const std::ffi::c_void { self.addr diff --git a/src/lib.rs b/src/lib.rs index 95745e1..4214573 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,8 @@ #![recursion_limit = "256"] +#[cfg(not(feature = "lean"))] pub mod builder; + pub mod engine; pub mod error; pub mod ffi; @@ -9,12 +11,20 @@ pub mod runtime; #[cfg(test)] mod tests; +#[cfg(not(feature = "lean"))] pub use builder::Builder; + pub use engine::{Engine, ExecutionContext}; pub use error::Error; + +#[cfg(not(feature = "lean"))] pub use ffi::builder_config::BuilderConfig; + pub use ffi::memory::HostBuffer; pub use ffi::network::{NetworkDefinition, NetworkDefinitionCreationFlags, Tensor}; + +#[cfg(not(feature = "lean"))] pub use ffi::optimization_profile::OptimizationProfile; + pub use ffi::parser::Parser; pub use runtime::Runtime; diff --git a/src/runtime.rs b/src/runtime.rs index c8208e7..fe912b0 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -20,6 +20,17 @@ impl Runtime { Self { inner } } + /// Set whether the runtime is allowed to deserialize engines with host executable code. + /// + /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_runtime.html#a5a19c2524f74179cd9b781c6240eb3ce) + /// + /// # Arguments + /// + /// * `allowed` - Whether the runtime is allowed to deserialize engines with host executable code. + pub fn set_engine_host_code_allowed(&mut self, allowed: bool) { + self.inner.set_engine_host_code_allowed(allowed); + } + /// Deserialize engine from a plan (a [`HostBuffer`]). /// /// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_runtime.html#ad0dc765e77cab99bfad901e47216a767)