From 187381ba84e1c5920f3f512a2eaad10c55353ebc Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 14 Nov 2025 10:07:32 -0800 Subject: [PATCH 01/20] move code under src --- plugin_execution_providers/tensorrt/CMakeLists.txt | 6 +++--- .../{ => src}/cuda/cu_inc/unary_elementwise_impl.cuh | 0 .../tensorrt/{ => src}/cuda/unary_elementwise_ops_impl.cu | 0 .../tensorrt/{ => src}/cuda/unary_elementwise_ops_impl.h | 0 .../tensorrt/{ => src}/cuda_allocator.cc | 0 .../tensorrt/{ => src}/cuda_allocator.h | 0 plugin_execution_providers/tensorrt/{ => src}/nv_includes.h | 0 .../tensorrt/{ => src}/onnx_ctx_model_helper.cc | 0 .../tensorrt/{ => src}/onnx_ctx_model_helper.h | 0 .../tensorrt/{ => src}/ort_trt_int8_cal_table.fbs.h | 0 .../tensorrt/{ => src}/tensorrt_execution_provider.cc | 0 .../tensorrt/{ => src}/tensorrt_execution_provider.def | 0 .../tensorrt/{ => src}/tensorrt_execution_provider.h | 0 .../tensorrt/{ => src}/tensorrt_execution_provider.lds | 0 .../{ => src}/tensorrt_execution_provider_data_transfer.cc | 0 .../{ => src}/tensorrt_execution_provider_data_transfer.h | 0 .../tensorrt/{ => src}/tensorrt_execution_provider_info.cc | 0 .../tensorrt/{ => src}/tensorrt_execution_provider_info.h | 0 .../{ => src}/tensorrt_execution_provider_stream_support.cc | 0 .../{ => src}/tensorrt_execution_provider_stream_support.h | 0 .../tensorrt/{ => src}/tensorrt_execution_provider_utils.h | 0 .../tensorrt/{ => src}/tensorrt_provider_factory.cc | 0 .../tensorrt/{ => src}/tensorrt_provider_factory.h | 0 .../tensorrt/{ => src}/utils/cuda/cuda_call.h | 0 .../tensorrt/{ => src}/utils/cuda/cuda_common.h | 0 .../tensorrt/{ => src}/utils/ep_utils.h | 0 .../tensorrt/{ => src}/utils/helper.cc | 0 .../tensorrt/{ => src}/utils/make_string.h | 0 .../tensorrt/{ => src}/utils/ort_graph_to_proto.h | 0 .../tensorrt/{ => src}/utils/parse_string.h | 0 .../tensorrt/{ => src}/utils/path_string.h | 0 .../tensorrt/{ => src}/utils/provider_options.h | 0 .../tensorrt/{ => src}/utils/provider_options_utils.h | 0 33 files changed, 3 insertions(+), 3 deletions(-) rename plugin_execution_providers/tensorrt/{ => src}/cuda/cu_inc/unary_elementwise_impl.cuh (100%) rename plugin_execution_providers/tensorrt/{ => src}/cuda/unary_elementwise_ops_impl.cu (100%) rename plugin_execution_providers/tensorrt/{ => src}/cuda/unary_elementwise_ops_impl.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/cuda_allocator.cc (100%) rename plugin_execution_providers/tensorrt/{ => src}/cuda_allocator.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/nv_includes.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/onnx_ctx_model_helper.cc (100%) rename plugin_execution_providers/tensorrt/{ => src}/onnx_ctx_model_helper.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/ort_trt_int8_cal_table.fbs.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_execution_provider.cc (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_execution_provider.def (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_execution_provider.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_execution_provider.lds (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_execution_provider_data_transfer.cc (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_execution_provider_data_transfer.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_execution_provider_info.cc (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_execution_provider_info.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_execution_provider_stream_support.cc (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_execution_provider_stream_support.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_execution_provider_utils.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_provider_factory.cc (100%) rename plugin_execution_providers/tensorrt/{ => src}/tensorrt_provider_factory.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/utils/cuda/cuda_call.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/utils/cuda/cuda_common.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/utils/ep_utils.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/utils/helper.cc (100%) rename plugin_execution_providers/tensorrt/{ => src}/utils/make_string.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/utils/ort_graph_to_proto.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/utils/parse_string.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/utils/path_string.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/utils/provider_options.h (100%) rename plugin_execution_providers/tensorrt/{ => src}/utils/provider_options_utils.h (100%) diff --git a/plugin_execution_providers/tensorrt/CMakeLists.txt b/plugin_execution_providers/tensorrt/CMakeLists.txt index 85e6ca9f..cd44d594 100644 --- a/plugin_execution_providers/tensorrt/CMakeLists.txt +++ b/plugin_execution_providers/tensorrt/CMakeLists.txt @@ -28,7 +28,7 @@ endif() add_definitions(-DONNX_NAMESPACE=onnx) add_definitions(-DONNX_ML) add_definitions(-DNOMINMAX) -file(GLOB tensorrt_src "./*.cc" "./utils/*.cc" "./cuda/unary_elementwise_ops_impl.cu" "./*.h") +file(GLOB tensorrt_src "./src/*.cc" "./src/utils/*.cc" "./src/cuda/unary_elementwise_ops_impl.cu" "./src/*.h") add_library(TensorRTEp SHARED ${tensorrt_src}) if (NOT ORT_HOME) @@ -111,7 +111,7 @@ if (WIN32) # Windows "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib") set(TRT_EP_LIB_LINK_FLAG - "-DEF:${CMAKE_SOURCE_DIR}/tensorrt_execution_provider.def") + "-DEF:${CMAKE_SOURCE_DIR}/src/tensorrt_execution_provider.def") else() # Linux set(ORT_LIB "${ORT_HOME}/lib/libonnxruntime.so") @@ -142,7 +142,7 @@ set_property(TARGET TensorRTEp APPEND_STRING PROPERTY LINK_FLAGS ${TRT_EP_LIB_LINK_FLAG}) target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include" - "./utils" + "./src/utils" "/usr/local/cuda/include" "${TENSORRT_HOME}/include" "${DEPS_PATH}/flatbuffers-src/include" diff --git a/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh b/plugin_execution_providers/tensorrt/src/cuda/cu_inc/unary_elementwise_impl.cuh similarity index 100% rename from plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh rename to plugin_execution_providers/tensorrt/src/cuda/cu_inc/unary_elementwise_impl.cuh diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu b/plugin_execution_providers/tensorrt/src/cuda/unary_elementwise_ops_impl.cu similarity index 100% rename from plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu rename to plugin_execution_providers/tensorrt/src/cuda/unary_elementwise_ops_impl.cu diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h b/plugin_execution_providers/tensorrt/src/cuda/unary_elementwise_ops_impl.h similarity index 100% rename from plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h rename to plugin_execution_providers/tensorrt/src/cuda/unary_elementwise_ops_impl.h diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.cc b/plugin_execution_providers/tensorrt/src/cuda_allocator.cc similarity index 100% rename from plugin_execution_providers/tensorrt/cuda_allocator.cc rename to plugin_execution_providers/tensorrt/src/cuda_allocator.cc diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.h b/plugin_execution_providers/tensorrt/src/cuda_allocator.h similarity index 100% rename from plugin_execution_providers/tensorrt/cuda_allocator.h rename to plugin_execution_providers/tensorrt/src/cuda_allocator.h diff --git a/plugin_execution_providers/tensorrt/nv_includes.h b/plugin_execution_providers/tensorrt/src/nv_includes.h similarity index 100% rename from plugin_execution_providers/tensorrt/nv_includes.h rename to plugin_execution_providers/tensorrt/src/nv_includes.h diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc b/plugin_execution_providers/tensorrt/src/onnx_ctx_model_helper.cc similarity index 100% rename from plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc rename to plugin_execution_providers/tensorrt/src/onnx_ctx_model_helper.cc diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h b/plugin_execution_providers/tensorrt/src/onnx_ctx_model_helper.h similarity index 100% rename from plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h rename to plugin_execution_providers/tensorrt/src/onnx_ctx_model_helper.h diff --git a/plugin_execution_providers/tensorrt/ort_trt_int8_cal_table.fbs.h b/plugin_execution_providers/tensorrt/src/ort_trt_int8_cal_table.fbs.h similarity index 100% rename from plugin_execution_providers/tensorrt/ort_trt_int8_cal_table.fbs.h rename to plugin_execution_providers/tensorrt/src/ort_trt_int8_cal_table.fbs.h diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.def b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.def similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider.def rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.def diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider.h rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.lds b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.lds similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider.lds rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.lds diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_data_transfer.cc similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_data_transfer.cc diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_data_transfer.h similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_data_transfer.h diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.cc similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.cc diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.h similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.h diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_stream_support.cc similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_stream_support.cc diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_stream_support.h similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_stream_support.h diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_utils.h similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_utils.h diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.cc similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc rename to plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.cc diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.h similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_provider_factory.h rename to plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.h diff --git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h b/plugin_execution_providers/tensorrt/src/utils/cuda/cuda_call.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h rename to plugin_execution_providers/tensorrt/src/utils/cuda/cuda_call.h diff --git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h b/plugin_execution_providers/tensorrt/src/utils/cuda/cuda_common.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h rename to plugin_execution_providers/tensorrt/src/utils/cuda/cuda_common.h diff --git a/plugin_execution_providers/tensorrt/utils/ep_utils.h b/plugin_execution_providers/tensorrt/src/utils/ep_utils.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/ep_utils.h rename to plugin_execution_providers/tensorrt/src/utils/ep_utils.h diff --git a/plugin_execution_providers/tensorrt/utils/helper.cc b/plugin_execution_providers/tensorrt/src/utils/helper.cc similarity index 100% rename from plugin_execution_providers/tensorrt/utils/helper.cc rename to plugin_execution_providers/tensorrt/src/utils/helper.cc diff --git a/plugin_execution_providers/tensorrt/utils/make_string.h b/plugin_execution_providers/tensorrt/src/utils/make_string.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/make_string.h rename to plugin_execution_providers/tensorrt/src/utils/make_string.h diff --git a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h b/plugin_execution_providers/tensorrt/src/utils/ort_graph_to_proto.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h rename to plugin_execution_providers/tensorrt/src/utils/ort_graph_to_proto.h diff --git a/plugin_execution_providers/tensorrt/utils/parse_string.h b/plugin_execution_providers/tensorrt/src/utils/parse_string.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/parse_string.h rename to plugin_execution_providers/tensorrt/src/utils/parse_string.h diff --git a/plugin_execution_providers/tensorrt/utils/path_string.h b/plugin_execution_providers/tensorrt/src/utils/path_string.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/path_string.h rename to plugin_execution_providers/tensorrt/src/utils/path_string.h diff --git a/plugin_execution_providers/tensorrt/utils/provider_options.h b/plugin_execution_providers/tensorrt/src/utils/provider_options.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/provider_options.h rename to plugin_execution_providers/tensorrt/src/utils/provider_options.h diff --git a/plugin_execution_providers/tensorrt/utils/provider_options_utils.h b/plugin_execution_providers/tensorrt/src/utils/provider_options_utils.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/provider_options_utils.h rename to plugin_execution_providers/tensorrt/src/utils/provider_options_utils.h From a3168490101c4e5bba95035570264c0216cc74ea Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 14 Nov 2025 10:18:36 -0800 Subject: [PATCH 02/20] add files for building wheel --- .../tensorrt/plugin_trt_ep/__init__.py | 28 +++++++++++++++++++ plugin_execution_providers/tensorrt/setup.py | 21 ++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 plugin_execution_providers/tensorrt/plugin_trt_ep/__init__.py create mode 100644 plugin_execution_providers/tensorrt/setup.py diff --git a/plugin_execution_providers/tensorrt/plugin_trt_ep/__init__.py b/plugin_execution_providers/tensorrt/plugin_trt_ep/__init__.py new file mode 100644 index 00000000..40eed2ca --- /dev/null +++ b/plugin_execution_providers/tensorrt/plugin_trt_ep/__init__.py @@ -0,0 +1,28 @@ +import os +import importlib.resources +import ctypes +import onnxruntime as ort + +ort_dir = os.path.dirname(os.path.abspath(ort.__file__)) +dll_path = os.path.join(ort_dir, "capi", "onnxruntime.dll") + +# When the application calls ort.register_execution_provider_library() with the path to the plugin EP DLL, +# ORT internally uses LoadLibraryExW() to load that DLL. Since the plugin EP depends on onnxruntime.dll, +# the operating system will attempt to locate and load onnxruntime.dll first. +# +# On Windows, LoadLibraryExW() searches the directory containing the plugin EP DLL before searching system directories. +# Because onnxruntime.dll is not located in the plugin EP’s directory, Windows ends up loading the copy from a +# system directory instead — which is not the correct version. +# +# To ensure the plugin EP uses the correct onnxruntime.dll bundled with the ONNX Runtime package, +# we load that DLL explicitly before loading the plugin EP DLL. +ctypes.WinDLL(dll_path) + +def get_path(filename: str = "TensorRTEp.dll") -> str: + """ + Returns the absolute filesystem path to a DLL (or any file) + packaged inside plugin_trt_ep/libs. + """ + package = __name__ + ".libs" + with importlib.resources.as_file(importlib.resources.files(package) / filename) as path: + return str(path) \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/setup.py b/plugin_execution_providers/tensorrt/setup.py new file mode 100644 index 00000000..d6fe9ba4 --- /dev/null +++ b/plugin_execution_providers/tensorrt/setup.py @@ -0,0 +1,21 @@ +from setuptools import setup, find_packages +from setuptools.dist import Distribution + +class BinaryDistribution(Distribution): + # This ensures wheel is marked as "non-pure" (has binary files) + def has_ext_modules(self): + return True + +setup( + name="plugin_trt_ep", + version="0.1.0", + packages=find_packages(), + include_package_data=True, # include MANIFEST.in contents + package_data={ + "plugin_trt_ep": ["libs/*.dll"], # include DLLs inside the wheel + }, + distclass=BinaryDistribution, + description="Example package including DLLs", + author="ORT", + python_requires=">=3.8", +) From 431a3fbc5a78199f825ae84e4aeaa4e003f36911 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 14 Nov 2025 10:20:17 -0800 Subject: [PATCH 03/20] rename --- plugin_execution_providers/tensorrt/README.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 plugin_execution_providers/tensorrt/README.md diff --git a/plugin_execution_providers/tensorrt/README.md b/plugin_execution_providers/tensorrt/README.md new file mode 100644 index 00000000..e69de29b From 13125148b72175f75a2e0b08fbd4c1b4b9076d48 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 14 Nov 2025 10:52:06 -0800 Subject: [PATCH 04/20] update README.md --- plugin_execution_providers/tensorrt/README.md | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/plugin_execution_providers/tensorrt/README.md b/plugin_execution_providers/tensorrt/README.md index e69de29b..bef8b073 100644 --- a/plugin_execution_providers/tensorrt/README.md +++ b/plugin_execution_providers/tensorrt/README.md @@ -0,0 +1,37 @@ +# Plugin TensorRT EP + +This repo contains: +- the plugin TRT EP implementation +- How to build plugin TRT EP +- How to build python wheel for plugin TRT EP + +Plugin TRT EP is migrated from the original TRT EP and provides the implementations of `OrtEpFactory`, `OrtEp`, `OrtNodeComputeInfo`, `OrtDataTransferImpl` ... that are required for a plugin EP to be able to interact with ONNX Runtime via the EP ABI (introduced in ORT 1.23.0). + +## How to build (on Windows) ## +````bash +mkdir build;cd build +```` +````bash +cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DTENSORRT_HOME=C:/folder/to/trt -DORT_HOME=C:/folder/to/ort +```` +````bash +cmake --build ./ --config Debug +````` +Note: The `ORT_HOME` should contain the `include` and `lib` folder as below +``` +C:/folder/to/ort + | + | ----- lib + | | ----- onnxruntime.dll + | | ----- onnxruntime.lib + | | ----- onnxruntime.pdb + | ... + | + | ---- include + | | ----- onnxruntime_c_api.h + | | ----- onnxruntime_ep_c_api.h + | | ----- onnxruntime_cxx_api.h + | | ----- onnxruntime_cxx_inline_api.h + | ... +``` +## How to build python wheel (on Windows) ## \ No newline at end of file From 31d7cd800a895f7e73fdd14250e349aad2c9469e Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 17 Nov 2025 10:32:12 -0800 Subject: [PATCH 05/20] update README.md --- plugin_execution_providers/tensorrt/README.md | 5 +-- plugin_execution_providers/tensorrt/setup.py | 34 ++++++++++++++++++- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/plugin_execution_providers/tensorrt/README.md b/plugin_execution_providers/tensorrt/README.md index bef8b073..55789ee0 100644 --- a/plugin_execution_providers/tensorrt/README.md +++ b/plugin_execution_providers/tensorrt/README.md @@ -1,7 +1,7 @@ # Plugin TensorRT EP This repo contains: -- the plugin TRT EP implementation +- The plugin TRT EP implementation - How to build plugin TRT EP - How to build python wheel for plugin TRT EP @@ -34,4 +34,5 @@ C:/folder/to/ort | | ----- onnxruntime_cxx_inline_api.h | ... ``` -## How to build python wheel (on Windows) ## \ No newline at end of file +## How to build python wheel (on Windows) ## +setup.py bdist_wheel \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/setup.py b/plugin_execution_providers/tensorrt/setup.py index d6fe9ba4..84a71cbb 100644 --- a/plugin_execution_providers/tensorrt/setup.py +++ b/plugin_execution_providers/tensorrt/setup.py @@ -1,15 +1,47 @@ from setuptools import setup, find_packages from setuptools.dist import Distribution +import os +import shutil + +ep_dll = "TensorRTEp.dll" +src_folder = r".\build\\Debug" +dst_folder = r".\\plugin_trt_ep\\libs" class BinaryDistribution(Distribution): # This ensures wheel is marked as "non-pure" (has binary files) def has_ext_modules(self): return True + +def copy_ep_dll(src_folder: str, dst_folder: str, ep_dll: str = "TensorRTEp.dll"): + """ + Copy EP DLL from src_folder to dst_folder. + Create dst_folder if it doesn't exist. + """ + src_dll_path = os.path.join(src_folder, ep_dll) + + # Validate source file + if not os.path.isfile(src_dll_path): + raise FileNotFoundError(f"Source DLL not found: {src_dll_path}") + + # Create destination folder if needed + os.makedirs(dst_folder, exist_ok=True) + + dst_dll_path = os.path.join(dst_folder, ep_dll) + + # Copy file + shutil.copy2(src_dll_path, dst_dll_path) + + print(f"Copied {ep_dll} to: {dst_dll_path}") +try: + copy_ep_dll(src_folder, dst_folder, ep_dll) +except Exception as e: + print(f"Error: {e}") + setup( name="plugin_trt_ep", version="0.1.0", - packages=find_packages(), + packages=["plugin_trt_ep"], include_package_data=True, # include MANIFEST.in contents package_data={ "plugin_trt_ep": ["libs/*.dll"], # include DLLs inside the wheel From 563f74196c94aa4525b252c443a8f767f885741e Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 17 Nov 2025 10:33:04 -0800 Subject: [PATCH 06/20] update README.md --- plugin_execution_providers/tensorrt/README.md | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/plugin_execution_providers/tensorrt/README.md b/plugin_execution_providers/tensorrt/README.md index 55789ee0..5383fd70 100644 --- a/plugin_execution_providers/tensorrt/README.md +++ b/plugin_execution_providers/tensorrt/README.md @@ -17,9 +17,18 @@ cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DTENSORRT_HOME=C:/folder/to/trt -DO ````bash cmake --build ./ --config Debug ````` + +If the build succeeds, you will see the TRT EP DLL being generated at: +``` +C:\repos\onnxruntime-inference-examples\plugin_execution_providers\tensorrt\build> ls .\Debug + +TensorRTEp.dll +``` + + Note: The `ORT_HOME` should contain the `include` and `lib` folder as below ``` -C:/folder/to/ort +C:\folder\to\ort | | ----- lib | | ----- onnxruntime.dll @@ -35,4 +44,12 @@ C:/folder/to/ort | ... ``` ## How to build python wheel (on Windows) ## -setup.py bdist_wheel \ No newline at end of file +``` +setup.py bdist_wheel +``` +Once it's done, you will see the wheel file at: +``` +C:\repos\onnxruntime-inference-examples\plugin_execution_providers\tensorrt> ls .\dist + +plugin_trt_ep-0.1.0-cp312-cp312-win_amd64.whl +``` \ No newline at end of file From 73bfd38dac9b97d58fecf370e8640d4f535326ce Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 17 Nov 2025 10:46:38 -0800 Subject: [PATCH 07/20] add example for running inference using python --- .../tensorrt/example/plugin_ep_inference.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 plugin_execution_providers/tensorrt/example/plugin_ep_inference.py diff --git a/plugin_execution_providers/tensorrt/example/plugin_ep_inference.py b/plugin_execution_providers/tensorrt/example/plugin_ep_inference.py new file mode 100644 index 00000000..ca9599f7 --- /dev/null +++ b/plugin_execution_providers/tensorrt/example/plugin_ep_inference.py @@ -0,0 +1,52 @@ +import onnxruntime as onnxrt +import plugin_trt_ep +import numpy as np + +# Path to the plugin EP library +ep_lib_path = plugin_trt_ep.get_path() +# Registration name can be anything the application chooses +ep_registration_name = "TensorRTEp" +# EP name should match the name assigned by the EP factory when creating the EP (i.e., in the implementation of OrtEP::CreateEp) +ep_name = ep_registration_name + +# Register plugin EP library with ONNX Runtime +onnxrt.register_execution_provider_library(ep_registration_name, ep_lib_path) + +# +# Create ORT session with explicit OrtEpDevice(s) +# + +# Find the OrtEpDevice for "TensorRTEp" +ep_devices = onnxrt.get_ep_devices() +trt_ep_device = None +for ep_device in ep_devices: + if ep_device.ep_name == ep_name: + trt_ep_device = ep_device + +assert trt_ep_device != None + +sess_options = onnxrt.SessionOptions() + +# Equivalent to the C API's SessionOptionsAppendExecutionProvider_V2 that appends "TensorRTEp" to ORT session option +sess_options.add_provider_for_devices([trt_ep_device], {'trt_engine_cache_enable': '1'}) + +assert sess_options.has_providers() == True + +# Create ORT session with "TensorRTEp" plugin EP +sess = onnxrt.InferenceSession("C:\\models\\mul_1.onnx", sess_options=sess_options) + +# Run sample model and check output +x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) +input_name = sess.get_inputs()[0].name +res = sess.run([], {input_name: x}) +output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) +np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + +# Unregister the library using the application-specified registration name. +# Must only unregister a library after all sessions that use the library have been released. +onnxrt.unregister_execution_provider_library(ep_registration_name) + + +# Note: +# The mul_1.onnx can be found here: +# https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/mul_1.onnx \ No newline at end of file From 72e3c2e6da9a9b70e5b73ad5e569d7f51901d2f4 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 17 Nov 2025 10:50:02 -0800 Subject: [PATCH 08/20] update README.md --- plugin_execution_providers/tensorrt/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/plugin_execution_providers/tensorrt/README.md b/plugin_execution_providers/tensorrt/README.md index 5383fd70..482703c9 100644 --- a/plugin_execution_providers/tensorrt/README.md +++ b/plugin_execution_providers/tensorrt/README.md @@ -4,6 +4,7 @@ This repo contains: - The plugin TRT EP implementation - How to build plugin TRT EP - How to build python wheel for plugin TRT EP +- How to run inference with plugin TRT EP using python API Plugin TRT EP is migrated from the original TRT EP and provides the implementations of `OrtEpFactory`, `OrtEp`, `OrtNodeComputeInfo`, `OrtDataTransferImpl` ... that are required for a plugin EP to be able to interact with ONNX Runtime via the EP ABI (introduced in ORT 1.23.0). From 7ae7d802947e41915c83e8909752710ec4ebc778 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 18 Nov 2025 09:37:10 -0800 Subject: [PATCH 09/20] remove unnecessary files --- plugin_execution_providers/tensorrt/README.md | 56 ------------------- .../tensorrt/example/plugin_ep_inference.py | 52 ----------------- .../tensorrt/plugin_trt_ep/__init__.py | 28 ---------- 3 files changed, 136 deletions(-) delete mode 100644 plugin_execution_providers/tensorrt/README.md delete mode 100644 plugin_execution_providers/tensorrt/example/plugin_ep_inference.py delete mode 100644 plugin_execution_providers/tensorrt/plugin_trt_ep/__init__.py diff --git a/plugin_execution_providers/tensorrt/README.md b/plugin_execution_providers/tensorrt/README.md deleted file mode 100644 index 482703c9..00000000 --- a/plugin_execution_providers/tensorrt/README.md +++ /dev/null @@ -1,56 +0,0 @@ -# Plugin TensorRT EP - -This repo contains: -- The plugin TRT EP implementation -- How to build plugin TRT EP -- How to build python wheel for plugin TRT EP -- How to run inference with plugin TRT EP using python API - -Plugin TRT EP is migrated from the original TRT EP and provides the implementations of `OrtEpFactory`, `OrtEp`, `OrtNodeComputeInfo`, `OrtDataTransferImpl` ... that are required for a plugin EP to be able to interact with ONNX Runtime via the EP ABI (introduced in ORT 1.23.0). - -## How to build (on Windows) ## -````bash -mkdir build;cd build -```` -````bash -cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DTENSORRT_HOME=C:/folder/to/trt -DORT_HOME=C:/folder/to/ort -```` -````bash -cmake --build ./ --config Debug -````` - -If the build succeeds, you will see the TRT EP DLL being generated at: -``` -C:\repos\onnxruntime-inference-examples\plugin_execution_providers\tensorrt\build> ls .\Debug - -TensorRTEp.dll -``` - - -Note: The `ORT_HOME` should contain the `include` and `lib` folder as below -``` -C:\folder\to\ort - | - | ----- lib - | | ----- onnxruntime.dll - | | ----- onnxruntime.lib - | | ----- onnxruntime.pdb - | ... - | - | ---- include - | | ----- onnxruntime_c_api.h - | | ----- onnxruntime_ep_c_api.h - | | ----- onnxruntime_cxx_api.h - | | ----- onnxruntime_cxx_inline_api.h - | ... -``` -## How to build python wheel (on Windows) ## -``` -setup.py bdist_wheel -``` -Once it's done, you will see the wheel file at: -``` -C:\repos\onnxruntime-inference-examples\plugin_execution_providers\tensorrt> ls .\dist - -plugin_trt_ep-0.1.0-cp312-cp312-win_amd64.whl -``` \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/example/plugin_ep_inference.py b/plugin_execution_providers/tensorrt/example/plugin_ep_inference.py deleted file mode 100644 index ca9599f7..00000000 --- a/plugin_execution_providers/tensorrt/example/plugin_ep_inference.py +++ /dev/null @@ -1,52 +0,0 @@ -import onnxruntime as onnxrt -import plugin_trt_ep -import numpy as np - -# Path to the plugin EP library -ep_lib_path = plugin_trt_ep.get_path() -# Registration name can be anything the application chooses -ep_registration_name = "TensorRTEp" -# EP name should match the name assigned by the EP factory when creating the EP (i.e., in the implementation of OrtEP::CreateEp) -ep_name = ep_registration_name - -# Register plugin EP library with ONNX Runtime -onnxrt.register_execution_provider_library(ep_registration_name, ep_lib_path) - -# -# Create ORT session with explicit OrtEpDevice(s) -# - -# Find the OrtEpDevice for "TensorRTEp" -ep_devices = onnxrt.get_ep_devices() -trt_ep_device = None -for ep_device in ep_devices: - if ep_device.ep_name == ep_name: - trt_ep_device = ep_device - -assert trt_ep_device != None - -sess_options = onnxrt.SessionOptions() - -# Equivalent to the C API's SessionOptionsAppendExecutionProvider_V2 that appends "TensorRTEp" to ORT session option -sess_options.add_provider_for_devices([trt_ep_device], {'trt_engine_cache_enable': '1'}) - -assert sess_options.has_providers() == True - -# Create ORT session with "TensorRTEp" plugin EP -sess = onnxrt.InferenceSession("C:\\models\\mul_1.onnx", sess_options=sess_options) - -# Run sample model and check output -x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) -input_name = sess.get_inputs()[0].name -res = sess.run([], {input_name: x}) -output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) -np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) - -# Unregister the library using the application-specified registration name. -# Must only unregister a library after all sessions that use the library have been released. -onnxrt.unregister_execution_provider_library(ep_registration_name) - - -# Note: -# The mul_1.onnx can be found here: -# https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/mul_1.onnx \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/plugin_trt_ep/__init__.py b/plugin_execution_providers/tensorrt/plugin_trt_ep/__init__.py deleted file mode 100644 index 40eed2ca..00000000 --- a/plugin_execution_providers/tensorrt/plugin_trt_ep/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -import importlib.resources -import ctypes -import onnxruntime as ort - -ort_dir = os.path.dirname(os.path.abspath(ort.__file__)) -dll_path = os.path.join(ort_dir, "capi", "onnxruntime.dll") - -# When the application calls ort.register_execution_provider_library() with the path to the plugin EP DLL, -# ORT internally uses LoadLibraryExW() to load that DLL. Since the plugin EP depends on onnxruntime.dll, -# the operating system will attempt to locate and load onnxruntime.dll first. -# -# On Windows, LoadLibraryExW() searches the directory containing the plugin EP DLL before searching system directories. -# Because onnxruntime.dll is not located in the plugin EP’s directory, Windows ends up loading the copy from a -# system directory instead — which is not the correct version. -# -# To ensure the plugin EP uses the correct onnxruntime.dll bundled with the ONNX Runtime package, -# we load that DLL explicitly before loading the plugin EP DLL. -ctypes.WinDLL(dll_path) - -def get_path(filename: str = "TensorRTEp.dll") -> str: - """ - Returns the absolute filesystem path to a DLL (or any file) - packaged inside plugin_trt_ep/libs. - """ - package = __name__ + ".libs" - with importlib.resources.as_file(importlib.resources.files(package) / filename) as path: - return str(path) \ No newline at end of file From 54d204ae2c8775aacb26e3072dbfffb7d1428855 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 18 Nov 2025 12:14:59 -0800 Subject: [PATCH 10/20] add GetCapability() implementation --- .../src/tensorrt_execution_provider.cc | 171 +++++++++++++++++- .../src/tensorrt_execution_provider.h | 1 + .../src/tensorrt_execution_provider_info.cc | 2 + .../src/tensorrt_execution_provider_info.h | 1 + 4 files changed, 173 insertions(+), 2 deletions(-) diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc index 09041339..d061509b 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc @@ -812,9 +812,176 @@ bool TensorrtExecutionProvider::IsSubGraphFullySupported(const OrtGraph* graph, SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t nodes_vector_input, int iterations, const int max_iterations, const OrtGraph* graph, bool* early_termination) const { - // Temporarily make all nodes supported - SubGraphCollection_t nodes_list_output = nodes_vector_input; + // Return if iterations are exceeding predefined number + SubGraphCollection_t nodes_list_output; + if (iterations > max_iterations) { + *early_termination = true; + return nodes_list_output; + } + + iterations++; + auto ort_graph = Ort::ConstGraph(graph); + + // Sort OrtGraph with a custom Kahn's topological sorting algorithm. + std::vector node_index; + THROW_IF_ERROR(KahnsTopologicalSort( + *ort_graph, + [&](const OrtNode* node) { + size_t node_id = 0; + Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); + ENFORCE(status.IsOK()); + + node_index.push_back(node_id); + }, + PriorityNodeCompare())); + + for (const auto& group : nodes_vector_input) { + // Construct subgraph + if (!group.first.empty()) { + if (group.second) { + nodes_list_output.push_back(group); + } else { + + std::vector nodes = ort_graph.GetNodes(); + std::vector selected_nodes(group.first.size()); + size_t i = 0; + for (const auto& index : group.first) { + selected_nodes[i++] = nodes[node_index[index]]; + } + + Ort::Graph sub_graph = ort_graph.GetGraphView(selected_nodes); + + // Check if input tensors have shapes + if (iterations > 1) { + auto graph_inputs = sub_graph.GetInputs(); + for (auto& input_arg : graph_inputs) { + bool has_dim_value_or_param = true; + + auto type_info = input_arg.TypeInfo(); + if (type_info.GetONNXType() == ONNX_TYPE_TENSOR) { + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + + if (tensor_info.GetDimensionsCount() == 0) { + has_dim_value_or_param = false; + } + } + + if (type_info.GetONNXType() != ONNX_TYPE_TENSOR || !has_dim_value_or_param) { + std::string message = "TensorRT input: " + input_arg.GetName() + " has no shape specified. " + + "Please run shape inference on the onnx model first. Details can be found in " + + "https://onnxruntime.ai/docs/execution-providers/" + + "TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs"; + THROW_IF_ERROR(ort_api.CreateStatus(ORT_EP_FAIL, message.c_str())); + } + } + } + // Construct ModelProto from OrtGraph + ONNX_NAMESPACE::ModelProto model_proto; + + // add back handle_initializer_data to save initializer to external file + OrtEpUtils::OrtGraphToProto(*graph, model_proto /*, handle_initializer_data */); + + std::string string_buf; + model_proto.SerializeToString(&string_buf); + + if (dump_subgraphs_) { + // Dump TensorRT subgraph for debugging + std::fstream dump("TensorrtExecutionProvider_TRT_Subgraph.onnx", + std::ios::out | std::ios::trunc | std::ios::binary); + model_proto.SerializeToOstream(&dump); + } + + // Get supported node list recursively + SubGraphCollection_t parser_nodes_list; + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_, logger_, &ort_api); + auto trt_builder = GetBuilder(trt_logger); + auto network_flags = 0; +#if NV_TENSORRT_MAJOR > 8 + network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) + ? 0 + : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); +#else + network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); +#endif + + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + auto trt_parser = + tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); + bool is_model_supported = false; + +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 + is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + + // Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined + // behavior. + auto num_subgraphs = trt_parser->getNbSubgraphs(); + parser_nodes_list.reserve(num_subgraphs); + + for (int64_t i = 0; i < num_subgraphs; ++i) { + int64_t subgraph_len = 0; + int64_t* nodes = trt_parser->getSubgraphNodes(i, subgraph_len); + parser_nodes_list.emplace_back(); + parser_nodes_list.back().first.reserve(subgraph_len); + for (int64_t j = 0; j < subgraph_len; ++j) { + parser_nodes_list.back().first.push_back(nodes[j]); + } + parser_nodes_list.back().second = is_model_supported ? true : false; + } +#else + trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); +#endif // (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 + + + std::vector sub_graph_node_index; + THROW_IF_ERROR(KahnsTopologicalSort( + *sub_graph, + [&](const OrtNode* node) { + size_t node_id = 0; + Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); + ENFORCE(status.IsOK()); + + sub_graph_node_index.push_back(node_id); + }, + PriorityNodeCompare())); + + SubGraphCollection_t next_nodes_list = + GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph, early_termination); + + for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { + for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { + /* + * Convert the supported node list returning from onnx-tensorrt parser to the node list recognized by ORT + * TRT. + * + * TRT EP reconstructs the graph based on the nodes in group.first and feeds this graph (converts to model + * proto and to string buffer) to onnx-tensorrt parser. The node index in the list returning from + * onnx-tensorrt parser might not be the same as the node index in group.first. Therefore, TRT EP needs a + * node index mapping table here. + * + * The order of iterating the nodes in group.first and calling graph_build.AddNode() determines the node + * order in the newly constructed graph (see Graph::AllocateNode() in graph.cc), however, once the graph is + * converted to model proto, the node proto order in model proto (ex: onnx-tensorrt calls + * model.graph().node() to iterate NodeProto in ModelProto) is decided by topo sort. + * + * The topo sort list (i.e. subgraph_node_index) acts as the node index mapping table: + * sub_graph_node_index[node index from onnx-tensorrt parser] = index in group.first + * + * In the past, TRT EP uses ORT's default reversed DFS topo sort which might end up with the sorting result + * not sequence of 0, 1, ... n-1, ex: the subgraph_node_index = [0,2,1,3,4]. With the change of using ORT's + * priority-based topo sort (node with lower node index outputs first) the sorting result is the sequence of + * 0, 1, ... n-1 for most of the cases, therefore subgraph_node_index as a mapping table is not needed + * anymore. + * + * TODO: Remove the subgraph_node_index + */ + next_nodes_list[i].first[j] = group.first[sub_graph_node_index[next_nodes_list[i].first[j]]]; + } + nodes_list_output.push_back(next_nodes_list[i]); + } + } + } + } return nodes_list_output; } diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h index 953b2b05..96c19070 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h @@ -276,6 +276,7 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { size_t max_workspace_size_ = 1 << 30; // 1GB bool fp16_enable_ = false; bool int8_enable_ = false; + bool bf16_enable_ = false; bool dla_enable_ = false; int dla_core_ = 0; bool force_sequential_engine_build_ = false; diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.cc index 17c65ef4..98a5684f 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.cc +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.cc @@ -18,6 +18,7 @@ constexpr const char* kMinSubgraphSize = "trt_min_subgraph_size"; constexpr const char* kMaxWorkspaceSize = "trt_max_workspace_size"; constexpr const char* kFp16Enable = "trt_fp16_enable"; constexpr const char* kInt8Enable = "trt_int8_enable"; +constexpr const char* kBf16Enable = "trt_bf16_enable"; constexpr const char* kInt8CalibTable = "trt_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "trt_int8_use_native_calibration_table"; constexpr const char* kDLAEnable = "trt_dla_enable"; @@ -95,6 +96,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size) .AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8Enable, info.int8_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kBf16Enable, info.bf16_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8CalibTable, info.int8_calibration_table_name) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8UseNativeCalibTable, info.int8_use_native_calibration_table) .AddAssignmentToReference(tensorrt::provider_option_names::kDLAEnable, info.dla_enable) diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.h index df315cf9..f8bfb266 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.h +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.h @@ -18,6 +18,7 @@ struct TensorrtExecutionProviderInfo { size_t max_workspace_size{1 << 30}; bool fp16_enable{false}; bool int8_enable{false}; + bool bf16_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; bool dla_enable{false}; From 9c725915d8314e22c82668d2b70c0b1805a777f1 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 18 Nov 2025 15:51:29 -0800 Subject: [PATCH 11/20] update --- .../src/tensorrt_execution_provider.cc | 70 +++++-------------- 1 file changed, 17 insertions(+), 53 deletions(-) diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc index d061509b..d780c841 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc @@ -822,19 +822,6 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect iterations++; auto ort_graph = Ort::ConstGraph(graph); - // Sort OrtGraph with a custom Kahn's topological sorting algorithm. - std::vector node_index; - THROW_IF_ERROR(KahnsTopologicalSort( - *ort_graph, - [&](const OrtNode* node) { - size_t node_id = 0; - Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); - ENFORCE(status.IsOK()); - - node_index.push_back(node_id); - }, - PriorityNodeCompare())); - for (const auto& group : nodes_vector_input) { // Construct subgraph if (!group.first.empty()) { @@ -846,7 +833,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect std::vector selected_nodes(group.first.size()); size_t i = 0; for (const auto& index : group.first) { - selected_nodes[i++] = nodes[node_index[index]]; + selected_nodes[i++] = nodes[index]; } Ort::Graph sub_graph = ort_graph.GetGraphView(selected_nodes); @@ -932,50 +919,12 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); #endif // (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 - - std::vector sub_graph_node_index; - THROW_IF_ERROR(KahnsTopologicalSort( - *sub_graph, - [&](const OrtNode* node) { - size_t node_id = 0; - Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); - ENFORCE(status.IsOK()); - - sub_graph_node_index.push_back(node_id); - }, - PriorityNodeCompare())); - SubGraphCollection_t next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph, early_termination); for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { - /* - * Convert the supported node list returning from onnx-tensorrt parser to the node list recognized by ORT - * TRT. - * - * TRT EP reconstructs the graph based on the nodes in group.first and feeds this graph (converts to model - * proto and to string buffer) to onnx-tensorrt parser. The node index in the list returning from - * onnx-tensorrt parser might not be the same as the node index in group.first. Therefore, TRT EP needs a - * node index mapping table here. - * - * The order of iterating the nodes in group.first and calling graph_build.AddNode() determines the node - * order in the newly constructed graph (see Graph::AllocateNode() in graph.cc), however, once the graph is - * converted to model proto, the node proto order in model proto (ex: onnx-tensorrt calls - * model.graph().node() to iterate NodeProto in ModelProto) is decided by topo sort. - * - * The topo sort list (i.e. subgraph_node_index) acts as the node index mapping table: - * sub_graph_node_index[node index from onnx-tensorrt parser] = index in group.first - * - * In the past, TRT EP uses ORT's default reversed DFS topo sort which might end up with the sorting result - * not sequence of 0, 1, ... n-1, ex: the subgraph_node_index = [0,2,1,3,4]. With the change of using ORT's - * priority-based topo sort (node with lower node index outputs first) the sorting result is the sequence of - * 0, 1, ... n-1 for most of the cases, therefore subgraph_node_index as a mapping table is not needed - * anymore. - * - * TODO: Remove the subgraph_node_index - */ - next_nodes_list[i].first[j] = group.first[sub_graph_node_index[next_nodes_list[i].first[j]]]; + next_nodes_list[i].first[j] = group.first[next_nodes_list[i].first[j]]; } nodes_list_output.push_back(next_nodes_list[i]); } @@ -994,6 +943,21 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this size_t num_nodes = 0; RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + /* + // Sort OrtGraph with a custom Kahn's topological sorting algorithm. + std::vector node_index; + THROW_IF_ERROR(KahnsTopologicalSort( + *graph, + [&](const OrtNode* node) { + size_t node_id = 0; + Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); + ENFORCE(status.IsOK()); + + node_index.push_back(node_id); + }, + PriorityNodeCompare())); + */ + // Get all the nodes from the graph std::vector nodes(num_nodes); RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); From a8e9ea1f26ab2f940f75372e5ad194cca25e2f17 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 19 Nov 2025 10:25:46 -0800 Subject: [PATCH 12/20] remove unnecessary file --- plugin_execution_providers/tensorrt/setup.py | 53 -------------------- 1 file changed, 53 deletions(-) delete mode 100644 plugin_execution_providers/tensorrt/setup.py diff --git a/plugin_execution_providers/tensorrt/setup.py b/plugin_execution_providers/tensorrt/setup.py deleted file mode 100644 index 84a71cbb..00000000 --- a/plugin_execution_providers/tensorrt/setup.py +++ /dev/null @@ -1,53 +0,0 @@ -from setuptools import setup, find_packages -from setuptools.dist import Distribution -import os -import shutil - -ep_dll = "TensorRTEp.dll" -src_folder = r".\build\\Debug" -dst_folder = r".\\plugin_trt_ep\\libs" - -class BinaryDistribution(Distribution): - # This ensures wheel is marked as "non-pure" (has binary files) - def has_ext_modules(self): - return True - -def copy_ep_dll(src_folder: str, dst_folder: str, ep_dll: str = "TensorRTEp.dll"): - """ - Copy EP DLL from src_folder to dst_folder. - Create dst_folder if it doesn't exist. - """ - src_dll_path = os.path.join(src_folder, ep_dll) - - # Validate source file - if not os.path.isfile(src_dll_path): - raise FileNotFoundError(f"Source DLL not found: {src_dll_path}") - - # Create destination folder if needed - os.makedirs(dst_folder, exist_ok=True) - - dst_dll_path = os.path.join(dst_folder, ep_dll) - - # Copy file - shutil.copy2(src_dll_path, dst_dll_path) - - print(f"Copied {ep_dll} to: {dst_dll_path}") - -try: - copy_ep_dll(src_folder, dst_folder, ep_dll) -except Exception as e: - print(f"Error: {e}") - -setup( - name="plugin_trt_ep", - version="0.1.0", - packages=["plugin_trt_ep"], - include_package_data=True, # include MANIFEST.in contents - package_data={ - "plugin_trt_ep": ["libs/*.dll"], # include DLLs inside the wheel - }, - distclass=BinaryDistribution, - description="Example package including DLLs", - author="ORT", - python_requires=">=3.8", -) From 85ee09f82327797c822b011af1afac579354360f Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 6 Jan 2026 15:07:22 -0800 Subject: [PATCH 13/20] Use topo sort for nodes in GetCapability --- .../src/tensorrt_execution_provider.cc | 72 ++++++++++++------- 1 file changed, 47 insertions(+), 25 deletions(-) diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc index d780c841..5c91a93e 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc @@ -821,19 +821,29 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect iterations++; auto ort_graph = Ort::ConstGraph(graph); + std::vector topo_sorted_nodes; + Ort::Status status(KahnsTopologicalSort( + *graph, + [&](const OrtNode* node) { + size_t node_id = 0; + Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); + ENFORCE(status.IsOK()); + + topo_sorted_nodes.push_back(Ort::ConstNode(node)); + }, + PriorityNodeCompare())); + ENFORCE(status.IsOK()); for (const auto& group : nodes_vector_input) { // Construct subgraph if (!group.first.empty()) { if (group.second) { nodes_list_output.push_back(group); - } else { - - std::vector nodes = ort_graph.GetNodes(); + } else { std::vector selected_nodes(group.first.size()); size_t i = 0; for (const auto& index : group.first) { - selected_nodes[i++] = nodes[index]; + selected_nodes[i++] = topo_sorted_nodes[index]; } Ort::Graph sub_graph = ort_graph.GetGraphView(selected_nodes); @@ -867,7 +877,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect ONNX_NAMESPACE::ModelProto model_proto; // add back handle_initializer_data to save initializer to external file - OrtEpUtils::OrtGraphToProto(*graph, model_proto /*, handle_initializer_data */); + OrtEpUtils::OrtGraphToProto(*sub_graph, model_proto /*, handle_initializer_data */); std::string string_buf; model_proto.SerializeToString(&string_buf); @@ -919,12 +929,34 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); #endif // (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 + std::vector sub_graph_topo_sorted_nodes; + Ort::Status status(KahnsTopologicalSort( + *sub_graph, + [&](const OrtNode* node) { + size_t node_id = 0; + Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); + ENFORCE(status.IsOK()); + + sub_graph_topo_sorted_nodes.push_back(Ort::ConstNode(node)); + }, + PriorityNodeCompare())); + ENFORCE(status.IsOK()); + + // This is the mapping table that stores the "node id to sub_graph's index" pair. + // It's used for locating the node index in original `group.first` given a node id. + std::unordered_map node_id_to_sub_graph_id; + size_t sub_graph_id = 0; + for (const auto& node : sub_graph_topo_sorted_nodes) { + node_id_to_sub_graph_id.emplace(node.GetId(), sub_graph_id++); + } + SubGraphCollection_t next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph, early_termination); for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { - next_nodes_list[i].first[j] = group.first[next_nodes_list[i].first[j]]; + Ort::ConstNode sub_graph_node = sub_graph_topo_sorted_nodes[next_nodes_list[i].first[j]]; + next_nodes_list[i].first[j] = group.first[node_id_to_sub_graph_id[sub_graph_node.GetId()]]; } nodes_list_output.push_back(next_nodes_list[i]); } @@ -938,29 +970,19 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this OrtEpGraphSupportInfo* graph_support_info) noexcept { TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; - auto ort_graph = Ort::ConstGraph(graph); - size_t num_nodes = 0; - RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); - - /* - // Sort OrtGraph with a custom Kahn's topological sorting algorithm. - std::vector node_index; - THROW_IF_ERROR(KahnsTopologicalSort( + auto ort_graph = Ort::ConstGraph(graph); + std::vector topo_sorted_nodes; + RETURN_IF_ERROR(KahnsTopologicalSort( *graph, [&](const OrtNode* node) { size_t node_id = 0; Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); ENFORCE(status.IsOK()); - node_index.push_back(node_id); + topo_sorted_nodes.push_back(Ort::ConstNode(node)); }, PriorityNodeCompare())); - */ - - // Get all the nodes from the graph - std::vector nodes(num_nodes); - RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; bool new_subgraph = true; @@ -986,8 +1008,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. * 2. Its op type is in the exclusion list. */ - for (size_t index = 0; index < nodes.size(); index++) { - const OrtNode* node = nodes[index]; + for (size_t index = 0; index < topo_sorted_nodes.size(); index++) { + const OrtNode* node = topo_sorted_nodes[index]; bool supported_node = true; /* If current node is control flow op, we take different approach based on following four cases: @@ -1134,7 +1156,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { for (const auto& index : group.first) { - const OrtNode* supported_node = nodes[index]; + const OrtNode* supported_node = topo_sorted_nodes[index]; RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_node)); } } @@ -1155,7 +1177,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this supported_nodes.reserve(group.first.size()); for (const auto& index : group.first) { - const OrtNode* supported_node = nodes[index]; + const OrtNode* supported_node = topo_sorted_nodes[index]; supported_nodes.push_back(supported_node); } @@ -1179,7 +1201,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - } else if (number_of_trt_nodes == nodes.size()) { + } else if (number_of_trt_nodes == topo_sorted_nodes.size()) { std::string message = "[TensorRT EP] Whole graph will run on TensorRT execution provider"; Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, From 33d59ee7eca78ad2245568aa1286b9c47c72b276 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 21 Jan 2026 09:34:14 -0800 Subject: [PATCH 14/20] update ORT to graph proto utils --- .../tensorrt/src/utils/ort_graph_to_proto.h | 1031 ++++++++--------- 1 file changed, 506 insertions(+), 525 deletions(-) diff --git a/plugin_execution_providers/tensorrt/src/utils/ort_graph_to_proto.h b/plugin_execution_providers/tensorrt/src/utils/ort_graph_to_proto.h index 6f07c67a..aab899a8 100644 --- a/plugin_execution_providers/tensorrt/src/utils/ort_graph_to_proto.h +++ b/plugin_execution_providers/tensorrt/src/utils/ort_graph_to_proto.h @@ -122,6 +122,7 @@ #define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ #include +#include #include "onnxruntime_cxx_api.h" #include "onnx/onnx_pb.h" @@ -184,6 +185,16 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, onnx::ModelProto& model_proto, HandleInitializerDataFunc handle_initializer_data_func = nullptr); +/// +/// Convert the endianess of data based of tensor element type. Mainly used in BE systems. +/// +/// OrtValueInfo for the initializer. Can be used to query name, type, shape, +/// and consumer nodes. +/// Pointer to data buffer. +/// Length of data buffer. +/// An Ort::Status indicating success or an error. +Ort::Status ConvertExternalData(const OrtValueInfo* value_info, void* data, size_t bytes); + } // namespace OrtEpUtils // End of header @@ -203,433 +214,387 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, #define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ do { \ - OrtStatus* _status = (fn); \ - if (_status != nullptr) { \ - return Ort::Status{_status}; \ + Ort::Status _status{(fn)}; \ + if (!_status.IsOK()) { \ + return _status; \ } \ } while (0) #define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ - do { \ - Ort::Status _status = (fn); \ - if (!_status.IsOK()) { \ - return _status; \ - } \ - } while (0) + ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) -#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ - do { \ - if ((cond)) { \ - return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ - } \ +#define ORT_EP_UTILS_C_RETURN_IF(cond, msg) \ + do { \ + if ((cond)) { \ + return Ort::Status{msg, ORT_FAIL}; \ + } \ } while (0) namespace OrtEpUtils { -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, +static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, bool get_symbolic_dims, /*out*/ ONNXTensorElementDataType& elem_type, /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims); -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); - -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + /*out*/ std::vector& symbolic_dims, + /*out*/ bool& has_shape); +static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto); +static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status GetTensorElementSize(const ONNXTensorElementDataType& element_type, size_t& element_size); +static void SwapByteOrderInplace(void* data, const size_t& data_len, const size_t& element_size); + +// Below endian enum class is referenced from include/onnxruntime/core/framework/endian.h +enum class endian { +#if defined(_WIN32) + little = 0, + big = 1, + native = little, +#elif defined(__GNUC__) || defined(__clang__) + little = __ORDER_LITTLE_ENDIAN__, + big = __ORDER_BIG_ENDIAN__, + native = __BYTE_ORDER__, +#else +#error onnxruntime::endian is not implemented in this environment. +#endif +}; + +Ort::Status OrtGraphToProto(const OrtGraph& graph, onnx::GraphProto& graph_proto, HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // - // Set GraphProto metadata - // - const char* graph_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); - graph_proto.set_name(graph_name); - graph_proto.set_doc_string("Serialized from OrtGraph"); - - // - // Set GraphProto inputs and outputs - // - size_t num_graph_inputs = 0; - size_t num_graph_outputs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); - - std::vector graph_inputs(num_graph_inputs); - std::vector graph_outputs(num_graph_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); - - for (const OrtValueInfo* ort_value_info : graph_inputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - for (const OrtValueInfo* ort_value_info : graph_outputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - // - // Set GraphProto nodes, value_infos, and initializers. - // - - // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. - // A std::map maintains its elements in a stable ordering. - std::map value_infos; // For GraphProto.value_info - std::map initializer_value_infos; // For GraphProto.initializer - - // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. - // Optionally returns the OrtValueInfo name to the caller. - auto collect_value_info = [&ort_api, &value_infos, - &initializer_value_infos](const OrtValueInfo& ort_value_info, - /*out*/ const char** value_name_out = nullptr) -> Ort::Status { - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - - if (value_name_out != nullptr) { - *value_name_out = value_name; + try { + Ort::ConstGraph ort_graph{&graph}; + // + // Set GraphProto metadata + // + auto graph_name = ort_graph.GetName(); + graph_proto.set_name(graph_name); + graph_proto.set_doc_string("Serialized from OrtGraph"); + + // + // Set GraphProto inputs and outputs + // + std::vector graph_inputs = ort_graph.GetInputs(); + std::vector graph_outputs = ort_graph.GetOutputs(); + + for (const auto& ort_value_info : graph_inputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto)); } - if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { - return Ort::Status{nullptr}; // Already processed this OrtValueInfo. + for (const auto& ort_value_info : graph_outputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto)); } - bool is_required_graph_input = false; - bool is_optional_graph_input = false; - bool is_graph_output = false; - bool is_constant_initializer = false; - bool is_from_outer_scope = false; - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); - - // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. - // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. - // For values defined in an outer scope, just add the value info but not the initializer. - if (is_from_outer_scope) { - value_infos.emplace(value_name, &ort_value_info); - } else if (is_optional_graph_input) { - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (is_constant_initializer) { - value_infos.emplace(value_name, &ort_value_info); - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (!is_required_graph_input && !is_graph_output) { - value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. - } + // + // Set GraphProto nodes, value_infos, and initializers. + // - return Ort::Status{nullptr}; - }; + // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. + // A std::map maintains its elements in a stable ordering. + std::map value_infos; // For GraphProto.value_info + std::map initializer_value_infos; // For GraphProto.initializer + + // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. + // Optionally returns the OrtValueInfo name to the caller. + auto collect_value_info = [&value_infos, + &initializer_value_infos](Ort::ConstValueInfo ort_value_info, + /*out*/ std::optional& value_name_out) { + auto value_name = ort_value_info.GetName(); - size_t num_nodes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); - - std::vector nodes(num_nodes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); - - // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos - // that will be stored in GraphProto.value_info and GraphProto.initializer. - for (size_t i = 0; i < num_nodes; i++) { - const OrtNode* ort_node = nodes[i]; - onnx::NodeProto* node_proto = graph_proto.add_node(); - - const char* node_name = nullptr; - const char* node_domain = nullptr; - const char* node_op_type = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); - - node_proto->set_name(node_name); - node_proto->set_domain(node_domain); - node_proto->set_op_type(node_op_type); - - size_t num_inputs = 0; - size_t num_implicit_inputs = 0; - size_t num_outputs = 0; - size_t num_attrs = 0; - size_t num_subgraphs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); - - // Handle node attributes - if (num_attrs > 0) { - std::vector ort_attrs(num_attrs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); - - for (const OrtOpAttr* ort_attr : ort_attrs) { - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - - Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + if (value_name_out) { + *value_name_out = value_name; + } + + if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { + return; // Already processed this OrtValueInfo. + } + + bool is_required_graph_input = ort_value_info.IsRequiredGraphInput(); + bool is_optional_graph_input = ort_value_info.IsOptionalGraphInput(); + bool is_graph_output = ort_value_info.IsGraphOutput(); + bool is_constant_initializer = ort_value_info.IsConstantInitializer(); + bool is_from_outer_scope = ort_value_info.IsFromOuterScope(); + + // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. + // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. + if (is_from_outer_scope) { + value_infos.emplace(value_name, ort_value_info); + if (is_constant_initializer) { + initializer_value_infos.emplace(value_name, ort_value_info); + } + } else if (is_optional_graph_input) { + initializer_value_infos.emplace(value_name, ort_value_info); + } else if (is_constant_initializer) { + value_infos.emplace(value_name, ort_value_info); + initializer_value_infos.emplace(value_name, ort_value_info); + } else if (!is_required_graph_input && !is_graph_output) { + value_infos.emplace(value_name, ort_value_info); // This is an internal OrtValueInfo. + } + }; + + std::vector nodes = ort_graph.GetNodes(); + // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos + // that will be stored in GraphProto.value_info and GraphProto.initializer. + for (const auto& ort_node : nodes) { + onnx::NodeProto* node_proto = graph_proto.add_node(); + + std::string node_name = ort_node.GetName(); + std::string node_domain = ort_node.GetDomain(); + std::string node_op_type = ort_node.GetOperatorType(); + + node_proto->set_name(node_name); + node_proto->set_domain(node_domain); + node_proto->set_op_type(node_op_type); + + // Handle node attributes + std::vector ort_attrs = ort_node.GetAttributes(); + for (const auto& attr : ort_attrs) { + OrtOpAttrType attr_type = attr.GetType(); if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) { // ORT does not support reading subgraphs via ReadOpAttr(), so skip it. // Can use Node_GetSubgraphs to get subgraphs. continue; } - if (!attr_type_status.IsOK()) { - // Unsupported attribute type. - return attr_type_status; - } - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(attr, *attr_proto)); } - } - - // Handle node subgraphs - if (num_subgraphs > 0) { - std::vector ort_subgraphs(num_subgraphs); - std::vector subgraph_attr_names(num_subgraphs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), - subgraph_attr_names.data())); - - for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { - const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; - const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; + // Handle node subgraphs + std::vector ort_subgraphs = ort_node.GetSubgraphs(); + for (const auto& [subgraph_attr_name, ort_subgraph] : ort_subgraphs) { onnx::AttributeProto* attr_proto = node_proto->add_attribute(); onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); - attr_proto->set_name(subgraph_attr_name); attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); } - } - - // Handle node inputs - if (num_inputs > 0) { - std::vector ort_inputs(num_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); - for (const OrtValueInfo* ort_value_info : ort_inputs) { - if (ort_value_info == nullptr) { + // Handle node inputs + std::vector ort_inputs = ort_node.GetInputs(); + for (const auto& vi : ort_inputs) { + if (vi == nullptr) { // missing optional input. node_proto->add_input(""); continue; } - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_input(value_name); + std::optional value_name; + value_name.emplace(); + collect_value_info(vi, value_name); + node_proto->add_input(*value_name); } - } - // Handle implicit inputs to this node. - if (num_implicit_inputs > 0) { - std::vector ort_implicit_inputs(num_implicit_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), - ort_implicit_inputs.size())); - - for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { - assert(ort_value_info != nullptr); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); + // Handle implicit inputs to this node. + std::vector ort_implicit_inputs = ort_node.GetImplicitInputs(); + for (const auto& vi : ort_implicit_inputs) { + assert(vi != nullptr); + std::optional value_name; + collect_value_info(vi, value_name); } - } - - // Handle node outputs - if (num_outputs > 0) { - std::vector ort_outputs(num_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); - for (const OrtValueInfo* ort_value_info : ort_outputs) { - if (ort_value_info == nullptr) { + // Handle node outputs + std::vector ort_outputs = ort_node.GetOutputs(); + for (const auto& vi : ort_outputs) { + if (vi == nullptr) { // missing optional output. node_proto->add_output(""); continue; } - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_output(value_name); + std::optional value_name; + value_name.emplace(); + collect_value_info(vi, value_name); + node_proto->add_output(*value_name); } } - } - // Add value_infos to GraphProto as ValueInfoProto objects. - for (const std::pair& entry : value_infos) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); - } + // Add value_infos to GraphProto as ValueInfoProto objects. + for (const auto& [value_name, value_info] : value_infos) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(value_info, *value_info_proto)); + } - // Add initializers to GraphProto as TensorProto objects. - for (const std::pair& entry : initializer_value_infos) { - const OrtValueInfo* initializer_value_info = entry.second; - std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. - std::vector initializer_dims; - std::vector initializer_sym_dims; - ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, - initializer_elem_type, initializer_dims, - initializer_sym_dims)); - - onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); - tensor_proto->set_name(initializer_name); - tensor_proto->set_data_type(initializer_elem_type); - - auto* tensor_proto_dims = tensor_proto->mutable_dims(); - for (int64_t dim : initializer_dims) { - tensor_proto_dims->Add(dim); + // There may be initializers in the original OrtGraph that have not been added yet. + // For example, an initializer may not be used by any node but is still a graph output. + // Iterating through all nodes to collect initializer value info is therefore not sufficient, + // initializers must also be obtained from ort_graph.GetInitializers(). + // Add those missing initializers and skip the ones that already in `initializer_value_infos` + std::vector ort_graph_initializers = ort_graph.GetInitializers(); + for (const auto& initializer : ort_graph_initializers) { + initializer_value_infos.emplace(initializer.GetName(), initializer); } - const OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); + // Add initializers to GraphProto as TensorProto objects. + for (const auto& [initializer_name, initializer_value_info] : initializer_value_infos) { + std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + bool has_shape = false; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(initializer_value_info, /*get_sym_dims*/ false, + initializer_elem_type, initializer_dims, + initializer_sym_dims, has_shape)); + + onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); + tensor_proto->set_name(initializer_name); + tensor_proto->set_data_type(initializer_elem_type); + + auto* tensor_proto_dims = tensor_proto->mutable_dims(); + for (int64_t dim : initializer_dims) { + tensor_proto_dims->Add(dim); + } - const void* data = nullptr; - size_t data_bytes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); + Ort::ConstValue ort_value{nullptr}; + ORT_EP_UTILS_C_RETURN_IF_ERROR(initializer_value_info.GetInitializer(ort_value)); - std::string ext_location; - int64_t ext_offset = 0; - bool is_external = false; + assert(ort_value.IsTensor()); + const void* data = ort_value.GetTensorRawData(); + const size_t data_bytes = ort_value.GetTensorSizeInBytes(); - if (handle_initializer_data_func != nullptr) { - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, - is_external, ext_location, ext_offset)); - } + std::string ext_location; + int64_t ext_offset = 0; + bool is_external = false; - if (is_external) { - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); - auto* ext_data_entries = tensor_proto->mutable_external_data(); - onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); - - location_entry->set_key("location"); - location_entry->set_value(ext_location); - offset_entry->set_key("offset"); - offset_entry->set_value(std::to_string(ext_offset)); - length_entry->set_key("length"); - length_entry->set_value(std::to_string(data_bytes)); - } else { - // User wants to store data inline the TensorProto's raw_data - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); - tensor_proto->set_raw_data(data, data_bytes); + if (handle_initializer_data_func != nullptr) { + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, + is_external, ext_location, ext_offset)); + } + + if (is_external) { + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + auto* ext_data_entries = tensor_proto->mutable_external_data(); + onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); + + location_entry->set_key("location"); + location_entry->set_value(ext_location); + offset_entry->set_key("offset"); + offset_entry->set_value(std::to_string(ext_offset)); + length_entry->set_key("length"); + length_entry->set_value(std::to_string(data_bytes)); + } else { + // User wants to store data inline the TensorProto's raw_data + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); + if constexpr (endian::native == endian::big) { + size_t element_size = 0; + GetTensorElementSize(initializer_elem_type, element_size); + // create local copy of data and do endianess conversion + auto raw_data_buf = std::make_unique(data_bytes); + std::memcpy(raw_data_buf.get(), data, data_bytes); + SwapByteOrderInplace(raw_data_buf.get(), data_bytes, element_size); + tensor_proto->set_raw_data(raw_data_buf.get(), data_bytes); + } else { + tensor_proto->set_raw_data(data, data_bytes); + } + } } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_FAIL}; } return Ort::Status{nullptr}; } -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, +Ort::Status OrtGraphToProto(const OrtGraph& graph, onnx::ModelProto& model_proto, HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // Check that OrtGraph is a top-level graph (no parent node). - const OrtNode* parent_node = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); - ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); - - // Set model description. - model_proto.set_doc_string("Serialized from OrtGraph"); - model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); - - // Set ir version. - int64_t ir_version = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); - model_proto.set_ir_version(ir_version); - - // Set operator sets. - size_t num_operator_sets = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); - ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); - - std::vector domains(num_operator_sets, nullptr); - std::vector opset_versions(num_operator_sets); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), - num_operator_sets)); - - auto* operator_sets = model_proto.mutable_opset_import(); - - for (size_t i = 0; i < num_operator_sets; ++i) { - onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); - operator_set->set_domain(domains[i]); - operator_set->set_version(opset_versions[i]); - } + try { + Ort::ConstGraph ort_graph{&graph}; - model_proto.clear_graph(); - onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + // Set model description. + model_proto.set_doc_string("Serialized from OrtGraph"); + model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); + // Set ir version. + int64_t ir_version = ort_graph.GetOnnxIRVersion(); + model_proto.set_ir_version(ir_version); - return Ort::Status{nullptr}; -} - -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, - bool get_symbolic_dims, - /*out*/ ONNXTensorElementDataType& elem_type, - /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims) { - const OrtApi& ort_api = Ort::GetApi(); - - const OrtTypeInfo* ort_type_info = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); + // Set operator sets. + std::vector op_sets = ort_graph.GetOperatorSets(); + ORT_EP_UTILS_C_RETURN_IF(op_sets.empty(), "OrtGraph should have at least one operator set."); - ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); - ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); - - const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; - ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); + auto* operator_sets = model_proto.mutable_opset_import(); - size_t num_dims = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); + for (const auto& op_set : op_sets) { + onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); + operator_set->set_domain(op_set.domain); + operator_set->set_version(op_set.version); + } - std::vector ort_dims(num_dims, 0); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); + model_proto.clear_graph(); + onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_graph, *graph_proto, handle_initializer_data_func)); - elem_type = ort_elem_type; - dims = std::move(ort_dims); + } catch (const Ort::Exception& ex) { + return Ort::Status(ex); + } catch (const std::exception& ex) { + return Ort::Status(ex.what(), ORT_EP_FAIL); + } - if (get_symbolic_dims) { - std::vector ort_dim_syms(num_dims, nullptr); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), - ort_dim_syms.size())); + return Ort::Status{nullptr}; +} - symbolic_dims.reserve(num_dims); - for (const char* sym_dim : ort_dim_syms) { - symbolic_dims.push_back(sym_dim); +static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims, + /*out*/ bool& has_shape) { + try { + Ort::ConstTypeInfo ort_type_info = vi.TypeInfo(); + ONNXType ort_onnx_type = ort_type_info.GetONNXType(); + ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, "Expected OrtValueInfo to represent a Tensor"); + + Ort::ConstTensorTypeAndShapeInfo ort_type_shape = ort_type_info.GetTensorTypeAndShapeInfo(); + elem_type = ort_type_shape.GetElementType(); + has_shape = ort_type_shape.HasShape(); + + if (has_shape) { + const size_t num_dims = ort_type_shape.GetDimensionsCount(); + dims = ort_type_shape.GetShape(); + + if (get_symbolic_dims) { + std::vector ort_dim_syms(num_dims, nullptr); + ort_type_shape.GetSymbolicDimensions(ort_dim_syms.data(), ort_dim_syms.size()); + + symbolic_dims.reserve(num_dims); + for (const char* sym_dim : ort_dim_syms) { + symbolic_dims.push_back(sym_dim); + } + } } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_EP_FAIL}; } - return Ort::Status{nullptr}; } // Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, +static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto) { - const OrtApi& ort_api = Ort::GetApi(); - std::vector ort_dims; std::vector ort_dim_syms; ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. + bool has_shape = false; ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, - ort_elem_type, ort_dims, ort_dim_syms)); + ort_elem_type, ort_dims, ort_dim_syms, + has_shape)); - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - value_info_proto.set_name(value_name); + value_info_proto.set_name(ort_value_info.GetName()); onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); type_proto_tensor->set_elem_type(ort_elem_type); - // If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks - // like a scalar value. - if (!ort_dims.empty()) { + // If there is no shape, do not set a TensorShapeProto. + if (has_shape) { onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { @@ -652,217 +617,233 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { - const OrtApi& ort_api = Ort::GetApi(); - - const char* attr_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); - attr_proto.set_name(attr_name); - - size_t total_attr_bytes = 0; - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); +static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr attr, onnx::AttributeProto& attr_proto) { + try { + std::string attr_name = attr.GetName(); + attr_proto.set_name(attr_name); - switch (attr_type) { - case OrtOpAttrType::ORT_OP_ATTR_INT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); + OrtOpAttrType attr_type = attr.GetType(); - int64_t i_val = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); - attr_proto.set_i(i_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_INTS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector i_vals(total_attr_bytes / sizeof(int64_t)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* ints = attr_proto.mutable_ints(); - for (int64_t val : i_vals) { - ints->Add(val); + switch (attr_type) { + case OrtOpAttrType::ORT_OP_ATTR_INT: { + int64_t i_val = 0; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(i_val)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); + attr_proto.set_i(i_val); + break; } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); - - float f_val = 0.0f; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); - attr_proto.set_f(f_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector f_vals(total_attr_bytes / sizeof(float)); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* floats = attr_proto.mutable_floats(); - for (float val : f_vals) { - floats->Add(val); + case OrtOpAttrType::ORT_OP_ATTR_INTS: { + std::vector i_vals; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(i_vals)); + auto* ints = attr_proto.mutable_ints(); + ints->Assign(i_vals.begin(), i_vals.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); + break; } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRING: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::string* str = attr_proto.mutable_s(); - - str->resize(total_attr_bytes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, - &total_attr_bytes)); - - str->resize(total_attr_bytes); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector chars(total_attr_bytes, '\0'); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* strs = attr_proto.mutable_strings(); - - // Strings are all in a single buffer, each separated with a '\0'. - // Extract each string and add it to the STRINGS attribute array. - char* at = chars.data(); - char* end = at + chars.size(); + case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { + float f_val = 0.0f; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(f_val)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_proto.set_f(f_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { + std::vector f_vals; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(f_vals)); + auto* floats = attr_proto.mutable_floats(); + floats->Assign(f_vals.begin(), f_vals.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRING: { + std::string* str = attr_proto.mutable_s(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(*str)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { + std::vector result; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(result)); + auto* strs = attr_proto.mutable_strings(); + strs->Assign(result.begin(), result.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); + + onnx::TensorProto tensor_proto; + + // TensorProto as an attribute value doesn't require a name. + + Ort::Value tensor; + ORT_EP_UTILS_C_RETURN_IF_ERROR(attr.GetTensorAttributeAsOrtValue(tensor)); + + // Get tensor type and shape info + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + + // Get tensor type + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + break; + } + default: { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } - while (at < end) { - char* str_begin = at; + auto shape = type_shape_info.GetShape(); - while (*at && at < end) { - at++; + for (auto& dim : shape) { + tensor_proto.add_dims(dim); } - strs->Add()->assign(str_begin, at - str_begin); - if (at < end) { - assert(*at == '\0'); - at++; // Skip '\0' to get to the beginning of the next string. + const void* data = tensor.GetTensorRawData(); + const size_t data_bytes = tensor.GetTensorSizeInBytes(); + + // Copy the Ortvalue to TensorProto as raw data + if constexpr (endian::native == endian::big) { + size_t element_size = 0; + GetTensorElementSize(element_type, element_size); + // create local copy of data and do endianess conversion + auto raw_data_buf = std::make_unique(data_bytes); + std::memcpy(raw_data_buf.get(), data, data_bytes); + SwapByteOrderInplace(raw_data_buf.get(), data_bytes, element_size); + tensor_proto.set_raw_data(raw_data_buf.get(), data_bytes); + } else { + tensor_proto.set_raw_data(data, data_bytes); } - } - - break; - } - case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); - - onnx::TensorProto tensor_proto; - - // TensorProto as an attribute value doesn't require a name. - - OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value)); - - Ort::Value tensor(ort_value); - // Get tensor type and shape info - Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); - - // Get tensor type - ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); - - size_t element_size = 0; - switch (element_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); - element_size = sizeof(float); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); - element_size = sizeof(uint8_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); - element_size = sizeof(int8_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); - element_size = sizeof(uint16_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); - element_size = sizeof(int16_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); - element_size = sizeof(int32_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); - element_size = sizeof(int64_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); - element_size = sizeof(bool); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); - element_size = sizeof(double); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); - element_size = sizeof(uint32_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); - element_size = sizeof(uint64_t); - break; - } - default: { - std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); - return Ort::Status(err_msg.c_str(), ORT_FAIL); - } + *(attr_proto.mutable_t()) = std::move(tensor_proto); + break; } - - auto shape = type_shape_info.GetShape(); - - for (auto& dim : shape) { - tensor_proto.add_dims(dim); + default: { + std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); } + } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_FAIL}; + } - size_t element_count = type_shape_info.GetElementCount(); - size_t data_bytes = element_count * element_size; - const void* data = tensor.GetTensorData(); - - // Copy the Ortvalue to TensorProto as raw data - tensor_proto.set_raw_data(data, data_bytes); + return Ort::Status{nullptr}; +} - *(attr_proto.mutable_t()) = std::move(tensor_proto); - break; - } - default: { - std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); - return Ort::Status(err_msg.c_str(), ORT_FAIL); - } +Ort::Status ConvertExternalData(const OrtValueInfo* value_info, void* data, size_t bytes) { +#if !defined(_WIN32) + if constexpr (endian::native == endian::little) { + return Ort::Status{nullptr}; + } + std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + size_t element_size = 0; + Ort::ConstValueInfo ort_value_info{value_info}; + bool has_shape{false}; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, false, + initializer_elem_type, initializer_dims, + initializer_sym_dims, has_shape)); + GetTensorElementSize(initializer_elem_type, element_size); + if (element_size != 1) { + SwapByteOrderInplace(data, bytes, element_size); } +#else + (value_info); + (data); + (bytes); +#endif + return Ort::Status{nullptr}; +} +static Ort::Status GetTensorElementSize(const ONNXTensorElementDataType& element_type, size_t& element_size) { + using TensorElemDataMap = std::unordered_map; + static TensorElemDataMap tensor_elem_data_size{ + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, sizeof(float)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, sizeof(int8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, sizeof(uint16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, sizeof(int16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, sizeof(uint16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, sizeof(uint16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, sizeof(int32_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, sizeof(uint32_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, sizeof(int64_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, sizeof(uint64_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, sizeof(double)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, sizeof(uint8_t)}, + }; + auto pos = tensor_elem_data_size.find(element_type); + if (pos == tensor_elem_data_size.end()) { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + element_size = pos->second; return Ort::Status{nullptr}; } +static void SwapByteOrderInplace(void* data, const size_t& data_len, const size_t& element_size) { + char* bytes = reinterpret_cast(data); + size_t num_elements = data_len / element_size; + for (size_t i = 0; i < num_elements; ++i) { + char* start_byte = bytes + i * element_size; + char* end_byte = start_byte + element_size - 1; + for (size_t count = 0; count < element_size / 2; ++count) { + std::swap(*start_byte++, *end_byte--); + } + } +} + } // namespace OrtEpUtils #endif // ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL From 9d985facc412485893bf24b5f07772eb857fd94f Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 21 Jan 2026 09:38:39 -0800 Subject: [PATCH 15/20] add default initialization of OrtFactory in ctor --- .../tensorrt/src/tensorrt_provider_factory.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.cc index c0a61dad..ab52ecfe 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.cc @@ -14,7 +14,7 @@ namespace trt_ep { TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* ep_name, const OrtLogger& default_logger, ApiPtrs apis) - : ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { + : OrtEpFactory {}, ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. GetName = GetNameImpl; GetVendor = GetVendorImpl; From a84d2fe3c5831e733ead6fb244a913f4d8a5991f Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 26 Jan 2026 15:44:06 -0800 Subject: [PATCH 16/20] update --- .../src/tensorrt_execution_provider.cc | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc index 5c91a93e..9575ffdb 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc @@ -820,10 +820,13 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect } iterations++; + auto ort_graph = Ort::ConstGraph(graph); + + // Sort the nodes in priority-based topological order std::vector topo_sorted_nodes; Ort::Status status(KahnsTopologicalSort( - *graph, + *ort_graph, [&](const OrtNode* node) { size_t node_id = 0; Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); @@ -929,6 +932,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); #endif // (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 + // Sort the nodes in priority-based topological order std::vector sub_graph_topo_sorted_nodes; Ort::Status status(KahnsTopologicalSort( *sub_graph, @@ -972,9 +976,11 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this const OrtApi& ort_api = ep->ort_api; auto ort_graph = Ort::ConstGraph(graph); + + // Sort the nodes in priority-based topological order std::vector topo_sorted_nodes; RETURN_IF_ERROR(KahnsTopologicalSort( - *graph, + *ort_graph, [&](const OrtNode* node) { size_t node_id = 0; Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); @@ -1224,6 +1230,23 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this /* out */ OrtNodeComputeInfo** node_compute_info, /* out */ OrtNode** ep_context_node) { TensorrtExecutionProvider* ep = static_cast(this_ptr); + auto ort_graph = Ort::ConstGraph(graph); + + // Sort the nodes in priority-based topological order + std::vector topo_sorted_nodes; + Ort::Status status(KahnsTopologicalSort( + *ort_graph, + [&](const OrtNode* node) { + size_t node_id = 0; + Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); + ENFORCE(status.IsOK()); + + topo_sorted_nodes.push_back(Ort::ConstNode(node)); + }, + PriorityNodeCompare())); + ENFORCE(status.IsOK()); + + Ort::Graph topo_sorted_graph = ort_graph.GetGraphView(topo_sorted_nodes); // Comment out following code if you want the "large" initializers to be saved to a external file. /* @@ -1987,7 +2010,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this profiles_.emplace(fused_node_name, std::move(trt_profiles)); // Create EP Context nodes - std::unique_ptr ep_ctx_node_helper = std::make_unique(*ep, graph, fused_node); + std::unique_ptr ep_ctx_node_helper = std::make_unique(*ep, topo_sorted_graph, fused_node); if (dump_ep_context_model_) { std::string compute_capability_hw_compat = compute_capability_; if (engine_cache_enable_ && engine_hw_compatible_) { From 3ec823e0e9b3a26586ee273bef441182374309ce Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 26 Jan 2026 15:46:02 -0800 Subject: [PATCH 17/20] update --- .../tensorrt/src/tensorrt_execution_provider.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc index 9575ffdb..84178e65 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc @@ -1280,7 +1280,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this ONNX_NAMESPACE::ModelProto model_proto; // add back handle_initializer_data to save initializer to external file - OrtEpUtils::OrtGraphToProto(*graph, model_proto /*, handle_initializer_data */); + OrtEpUtils::OrtGraphToProto(*topo_sorted_graph, model_proto /*, handle_initializer_data */); std::string string_buf; model_proto.SerializeToString(&string_buf); From bfb25c9c23ea309c52939a5488c9fdc6a1ddbe49 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 27 Jan 2026 12:30:43 -0800 Subject: [PATCH 18/20] address reviewer's comments --- plugin_execution_providers/tensorrt/CMakeLists.txt | 13 ++++++++++--- .../src/tensorrt_execution_provider_utils.h | 3 --- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/plugin_execution_providers/tensorrt/CMakeLists.txt b/plugin_execution_providers/tensorrt/CMakeLists.txt index cd44d594..f67aca39 100644 --- a/plugin_execution_providers/tensorrt/CMakeLists.txt +++ b/plugin_execution_providers/tensorrt/CMakeLists.txt @@ -5,6 +5,8 @@ cmake_minimum_required(VERSION 3.26) project(TensorRTEp VERSION 1.0) set(CMAKE_CXX_STANDARD 17) +set(plugin_ep_common_dir ${CMAKE_SOURCE_DIR}/../common) +include(${plugin_ep_common_dir}/cmake/onnxruntime_library_utils.cmake) enable_language(CUDA) # via nvcc to get the CUDA tool kit file(TO_CMAKE_PATH "/usr/local/cuda" CUDAToolkit_ROOT) @@ -31,9 +33,14 @@ add_definitions(-DNOMINMAX) file(GLOB tensorrt_src "./src/*.cc" "./src/utils/*.cc" "./src/cuda/unary_elementwise_ops_impl.cu" "./src/*.h") add_library(TensorRTEp SHARED ${tensorrt_src}) -if (NOT ORT_HOME) - message(FATAL_ERROR "Please specify ORT_HOME, e.g. -DORT_HOME=/path/to/ort/") -endif() +set_onnxruntime_paths( + ORT_HOME ${ORT_HOME} + DEFAULT_ORT_VERSION "1.23.2" + ORT_INCLUDE_DIR_VAR ORT_INCLUDE_DIR + ORT_LIBRARY_DIR_VAR ORT_LIBRARY_DIR) + +message(STATUS "ORT_LIBRARY_DIR: ${ORT_LIBRARY_DIR}") +message(STATUS "ORT_INCLUDE_DIR: ${ORT_INCLUDE_DIR}") if (NOT TENSORRT_HOME) message(FATAL_ERROR "Please specify TENSORRT_HOME, e.g. -DTENSORRT_HOME=/path/to/trt/") diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_utils.h index 091a7a16..ce788cb3 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_utils.h +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_utils.h @@ -55,9 +55,6 @@ AllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, return AllocatorUniquePtr{p, [ort_allocator](T* p) { ort_allocator->Free(ort_allocator, p); }}; } -// Following helper functions/struct, GetNodeInputEdgeCount, GetOutputNodes, KahnsTopologicalSort, VisitorPriorityQueue, PriorityNodeCompare are added but are not used for now. -// TODO: They will be used for graph partition in the following PR. - template struct VisitorPriorityQueue { using ComparatorType = std::function; From 58814ba46ddb273ed2a886a8ef745a4f92a88ff1 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 27 Jan 2026 12:34:16 -0800 Subject: [PATCH 19/20] address reviewer's comments --- .../tensorrt/src/tensorrt_execution_provider.cc | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc index 84178e65..c8ff2944 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc @@ -828,10 +828,6 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect Ort::Status status(KahnsTopologicalSort( *ort_graph, [&](const OrtNode* node) { - size_t node_id = 0; - Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); - ENFORCE(status.IsOK()); - topo_sorted_nodes.push_back(Ort::ConstNode(node)); }, PriorityNodeCompare())); @@ -937,10 +933,6 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect Ort::Status status(KahnsTopologicalSort( *sub_graph, [&](const OrtNode* node) { - size_t node_id = 0; - Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); - ENFORCE(status.IsOK()); - sub_graph_topo_sorted_nodes.push_back(Ort::ConstNode(node)); }, PriorityNodeCompare())); @@ -982,10 +974,6 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this RETURN_IF_ERROR(KahnsTopologicalSort( *ort_graph, [&](const OrtNode* node) { - size_t node_id = 0; - Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); - ENFORCE(status.IsOK()); - topo_sorted_nodes.push_back(Ort::ConstNode(node)); }, PriorityNodeCompare())); @@ -1237,10 +1225,6 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this Ort::Status status(KahnsTopologicalSort( *ort_graph, [&](const OrtNode* node) { - size_t node_id = 0; - Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); - ENFORCE(status.IsOK()); - topo_sorted_nodes.push_back(Ort::ConstNode(node)); }, PriorityNodeCompare())); From 82faff0cf13349b8dc7573ca807d80243b52236c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 27 Jan 2026 15:01:25 -0800 Subject: [PATCH 20/20] update --- .../src/tensorrt_execution_provider.cc | 28 ++++++++++++++++--- .../src/tensorrt_execution_provider.h | 1 + 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc index c8ff2944..a418994d 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc @@ -1283,7 +1283,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this auto trt_builder = GetBuilder(trt_logger); auto network_flags = 0; #if NV_TENSORRT_MAJOR > 8 - network_flags |= (fp16_enable_ || int8_enable_) + network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #else @@ -1303,7 +1303,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this #pragma warning(push) #pragma warning(disable : 4996) #endif - if (fp16_enable_ && layer_norm_fp32_fallback_) { + if ((fp16_enable_ || bf16_enable_) && layer_norm_fp32_fallback_) { for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { auto layer = trt_network->getLayer(idx); auto next_layer = trt_network->getLayer(idx + 1); @@ -1470,7 +1470,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } // Check platform availability for low precision - if (fp16_enable_) { + if (fp16_enable_ || bf16_enable_) { #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) @@ -1480,6 +1480,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this #pragma warning(pop) #endif fp16_enable_ = false; + bf16_enable_ = false; std::string message = "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE or ORT_TENSORRT_BF16_ENABLE is set, but platform doesn't support fast native fp16/bf16"; Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, @@ -1531,6 +1532,16 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } + + if (bf16_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); + trt_node_name_with_precision += "_bf16"; + std::string message = "[TensorRT EP] BF16 mode is enabled"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + if (int8_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); trt_node_name_with_precision += "_int8"; @@ -2043,6 +2054,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this &tensorrt_mu_, compute_capability_, max_workspace_size_, + bf16_enable_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, @@ -2656,6 +2668,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa max_workspace_size_ = info_.max_workspace_size; fp16_enable_ = info_.fp16_enable; int8_enable_ = info_.int8_enable; + bf16_enable_ = info_.bf16_enable; if (int8_enable_) { int8_calibration_cache_name_ = info_.int8_calibration_table_name; int8_use_native_tensorrt_calibration_table_ = info_.int8_use_native_calibration_table; @@ -2697,7 +2710,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa } force_sequential_engine_build_ = info_.force_sequential_engine_build; context_memory_sharing_enable_ = info_.context_memory_sharing_enable; - if (fp16_enable_) { + if (fp16_enable_ || bf16_enable_) { layer_norm_fp32_fallback_ = info_.layer_norm_fp32_fallback; } build_heuristics_enable_ = info_.build_heuristics_enable; @@ -3225,6 +3238,13 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } + if (trt_state->bf16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); + std::string message = "[TensorRT EP] BF16 mode is enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } #if defined(_MSC_VER) #pragma warning(pop) #endif diff --git a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h index 96c19070..363baa02 100644 --- a/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h @@ -124,6 +124,7 @@ struct TensorrtComputeState { std::string compute_capability; size_t max_workspace_size = 1 << 30; // 1GB; bool fp16_enable = false; + bool bf16_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; bool dla_enable = false;