diff --git a/plugin_execution_providers/tensorrt/CMakeLists.txt b/plugin_execution_providers/tensorrt/CMakeLists.txt index 85e6ca9f..f67aca39 100644 --- a/plugin_execution_providers/tensorrt/CMakeLists.txt +++ b/plugin_execution_providers/tensorrt/CMakeLists.txt @@ -5,6 +5,8 @@ cmake_minimum_required(VERSION 3.26) project(TensorRTEp VERSION 1.0) set(CMAKE_CXX_STANDARD 17) +set(plugin_ep_common_dir ${CMAKE_SOURCE_DIR}/../common) +include(${plugin_ep_common_dir}/cmake/onnxruntime_library_utils.cmake) enable_language(CUDA) # via nvcc to get the CUDA tool kit file(TO_CMAKE_PATH "/usr/local/cuda" CUDAToolkit_ROOT) @@ -28,12 +30,17 @@ endif() add_definitions(-DONNX_NAMESPACE=onnx) add_definitions(-DONNX_ML) add_definitions(-DNOMINMAX) -file(GLOB tensorrt_src "./*.cc" "./utils/*.cc" "./cuda/unary_elementwise_ops_impl.cu" "./*.h") +file(GLOB tensorrt_src "./src/*.cc" "./src/utils/*.cc" "./src/cuda/unary_elementwise_ops_impl.cu" "./src/*.h") add_library(TensorRTEp SHARED ${tensorrt_src}) -if (NOT ORT_HOME) - message(FATAL_ERROR "Please specify ORT_HOME, e.g. -DORT_HOME=/path/to/ort/") -endif() +set_onnxruntime_paths( + ORT_HOME ${ORT_HOME} + DEFAULT_ORT_VERSION "1.23.2" + ORT_INCLUDE_DIR_VAR ORT_INCLUDE_DIR + ORT_LIBRARY_DIR_VAR ORT_LIBRARY_DIR) + +message(STATUS "ORT_LIBRARY_DIR: ${ORT_LIBRARY_DIR}") +message(STATUS "ORT_INCLUDE_DIR: ${ORT_INCLUDE_DIR}") if (NOT TENSORRT_HOME) message(FATAL_ERROR "Please specify TENSORRT_HOME, e.g. 
-DTENSORRT_HOME=/path/to/trt/") @@ -111,7 +118,7 @@ if (WIN32) # Windows "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib") set(TRT_EP_LIB_LINK_FLAG - "-DEF:${CMAKE_SOURCE_DIR}/tensorrt_execution_provider.def") + "-DEF:${CMAKE_SOURCE_DIR}/src/tensorrt_execution_provider.def") else() # Linux set(ORT_LIB "${ORT_HOME}/lib/libonnxruntime.so") @@ -142,7 +149,7 @@ set_property(TARGET TensorRTEp APPEND_STRING PROPERTY LINK_FLAGS ${TRT_EP_LIB_LINK_FLAG}) target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include" - "./utils" + "./src/utils" "/usr/local/cuda/include" "${TENSORRT_HOME}/include" "${DEPS_PATH}/flatbuffers-src/include" diff --git a/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh b/plugin_execution_providers/tensorrt/src/cuda/cu_inc/unary_elementwise_impl.cuh similarity index 100% rename from plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh rename to plugin_execution_providers/tensorrt/src/cuda/cu_inc/unary_elementwise_impl.cuh diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu b/plugin_execution_providers/tensorrt/src/cuda/unary_elementwise_ops_impl.cu similarity index 100% rename from plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu rename to plugin_execution_providers/tensorrt/src/cuda/unary_elementwise_ops_impl.cu diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h b/plugin_execution_providers/tensorrt/src/cuda/unary_elementwise_ops_impl.h similarity index 100% rename from plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h rename to plugin_execution_providers/tensorrt/src/cuda/unary_elementwise_ops_impl.h diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.cc b/plugin_execution_providers/tensorrt/src/cuda_allocator.cc similarity index 100% rename from plugin_execution_providers/tensorrt/cuda_allocator.cc rename to plugin_execution_providers/tensorrt/src/cuda_allocator.cc 
diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.h b/plugin_execution_providers/tensorrt/src/cuda_allocator.h similarity index 100% rename from plugin_execution_providers/tensorrt/cuda_allocator.h rename to plugin_execution_providers/tensorrt/src/cuda_allocator.h diff --git a/plugin_execution_providers/tensorrt/nv_includes.h b/plugin_execution_providers/tensorrt/src/nv_includes.h similarity index 100% rename from plugin_execution_providers/tensorrt/nv_includes.h rename to plugin_execution_providers/tensorrt/src/nv_includes.h diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc b/plugin_execution_providers/tensorrt/src/onnx_ctx_model_helper.cc similarity index 100% rename from plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc rename to plugin_execution_providers/tensorrt/src/onnx_ctx_model_helper.cc diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h b/plugin_execution_providers/tensorrt/src/onnx_ctx_model_helper.h similarity index 100% rename from plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h rename to plugin_execution_providers/tensorrt/src/onnx_ctx_model_helper.h diff --git a/plugin_execution_providers/tensorrt/ort_trt_int8_cal_table.fbs.h b/plugin_execution_providers/tensorrt/src/ort_trt_int8_cal_table.fbs.h similarity index 100% rename from plugin_execution_providers/tensorrt/ort_trt_int8_cal_table.fbs.h rename to plugin_execution_providers/tensorrt/src/ort_trt_int8_cal_table.fbs.h diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc similarity index 95% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc index 09041339..a418994d 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc +++ 
b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.cc @@ -812,9 +812,153 @@ bool TensorrtExecutionProvider::IsSubGraphFullySupported(const OrtGraph* graph, SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t nodes_vector_input, int iterations, const int max_iterations, const OrtGraph* graph, bool* early_termination) const { - // Temporarily make all nodes supported - SubGraphCollection_t nodes_list_output = nodes_vector_input; + // Return if iterations are exceeding predefined number + SubGraphCollection_t nodes_list_output; + if (iterations > max_iterations) { + *early_termination = true; + return nodes_list_output; + } + + iterations++; + + auto ort_graph = Ort::ConstGraph(graph); + + // Sort the nodes in priority-based topological order + std::vector topo_sorted_nodes; + Ort::Status status(KahnsTopologicalSort( + *ort_graph, + [&](const OrtNode* node) { + topo_sorted_nodes.push_back(Ort::ConstNode(node)); + }, + PriorityNodeCompare())); + ENFORCE(status.IsOK()); + + for (const auto& group : nodes_vector_input) { + // Construct subgraph + if (!group.first.empty()) { + if (group.second) { + nodes_list_output.push_back(group); + } else { + std::vector selected_nodes(group.first.size()); + size_t i = 0; + for (const auto& index : group.first) { + selected_nodes[i++] = topo_sorted_nodes[index]; + } + + Ort::Graph sub_graph = ort_graph.GetGraphView(selected_nodes); + + // Check if input tensors have shapes + if (iterations > 1) { + auto graph_inputs = sub_graph.GetInputs(); + for (auto& input_arg : graph_inputs) { + bool has_dim_value_or_param = true; + + auto type_info = input_arg.TypeInfo(); + if (type_info.GetONNXType() == ONNX_TYPE_TENSOR) { + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + + if (tensor_info.GetDimensionsCount() == 0) { + has_dim_value_or_param = false; + } + } + + if (type_info.GetONNXType() != ONNX_TYPE_TENSOR || !has_dim_value_or_param) { + std::string message = "TensorRT input: " 
+ input_arg.GetName() + " has no shape specified. " + + "Please run shape inference on the onnx model first. Details can be found in " + + "https://onnxruntime.ai/docs/execution-providers/" + + "TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs"; + THROW_IF_ERROR(ort_api.CreateStatus(ORT_EP_FAIL, message.c_str())); + } + } + } + + // Construct ModelProto from OrtGraph + ONNX_NAMESPACE::ModelProto model_proto; + + // add back handle_initializer_data to save initializer to external file + OrtEpUtils::OrtGraphToProto(*sub_graph, model_proto /*, handle_initializer_data */); + + std::string string_buf; + model_proto.SerializeToString(&string_buf); + + if (dump_subgraphs_) { + // Dump TensorRT subgraph for debugging + std::fstream dump("TensorrtExecutionProvider_TRT_Subgraph.onnx", + std::ios::out | std::ios::trunc | std::ios::binary); + model_proto.SerializeToOstream(&dump); + } + + // Get supported node list recursively + SubGraphCollection_t parser_nodes_list; + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_, logger_, &ort_api); + auto trt_builder = GetBuilder(trt_logger); + auto network_flags = 0; +#if NV_TENSORRT_MAJOR > 8 + network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) + ? 0 + : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); +#else + network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); +#endif + + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + auto trt_parser = + tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); + bool is_model_supported = false; + +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 + is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + + // Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined + // behavior. 
+ auto num_subgraphs = trt_parser->getNbSubgraphs(); + parser_nodes_list.reserve(num_subgraphs); + + for (int64_t i = 0; i < num_subgraphs; ++i) { + int64_t subgraph_len = 0; + int64_t* nodes = trt_parser->getSubgraphNodes(i, subgraph_len); + parser_nodes_list.emplace_back(); + parser_nodes_list.back().first.reserve(subgraph_len); + for (int64_t j = 0; j < subgraph_len; ++j) { + parser_nodes_list.back().first.push_back(nodes[j]); + } + parser_nodes_list.back().second = is_model_supported ? true : false; + } +#else + trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); +#endif // (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 + + // Sort the nodes in priority-based topological order + std::vector sub_graph_topo_sorted_nodes; + Ort::Status status(KahnsTopologicalSort( + *sub_graph, + [&](const OrtNode* node) { + sub_graph_topo_sorted_nodes.push_back(Ort::ConstNode(node)); + }, + PriorityNodeCompare())); + ENFORCE(status.IsOK()); + + // This is the mapping table that stores the "node id to sub_graph's index" pair. + // It's used for locating the node index in original `group.first` given a node id. 
+ std::unordered_map node_id_to_sub_graph_id; + size_t sub_graph_id = 0; + for (const auto& node : sub_graph_topo_sorted_nodes) { + node_id_to_sub_graph_id.emplace(node.GetId(), sub_graph_id++); + } + + SubGraphCollection_t next_nodes_list = + GetSupportedList(parser_nodes_list, iterations, max_iterations, sub_graph, early_termination); + for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { + for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { + Ort::ConstNode sub_graph_node = sub_graph_topo_sorted_nodes[next_nodes_list[i].first[j]]; + next_nodes_list[i].first[j] = group.first[node_id_to_sub_graph_id[sub_graph_node.GetId()]]; + } + nodes_list_output.push_back(next_nodes_list[i]); + } + } + } + } return nodes_list_output; } @@ -822,14 +966,17 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this OrtEpGraphSupportInfo* graph_support_info) noexcept { TensorrtExecutionProvider* ep = static_cast(this_ptr); const OrtApi& ort_api = ep->ort_api; - auto ort_graph = Ort::ConstGraph(graph); - size_t num_nodes = 0; - RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + auto ort_graph = Ort::ConstGraph(graph); - // Get all the nodes from the graph - std::vector nodes(num_nodes); - RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); + // Sort the nodes in priority-based topological order + std::vector topo_sorted_nodes; + RETURN_IF_ERROR(KahnsTopologicalSort( + *ort_graph, + [&](const OrtNode* node) { + topo_sorted_nodes.push_back(Ort::ConstNode(node)); + }, + PriorityNodeCompare())); SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; bool new_subgraph = true; @@ -855,8 +1002,8 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. * 2. Its op type is in the exclusion list. 
*/ - for (size_t index = 0; index < nodes.size(); index++) { - const OrtNode* node = nodes[index]; + for (size_t index = 0; index < topo_sorted_nodes.size(); index++) { + const OrtNode* node = topo_sorted_nodes[index]; bool supported_node = true; /* If current node is control flow op, we take different approach based on following four cases: @@ -1003,7 +1150,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this for (const auto& group : supported_nodes_vector) { if (!group.first.empty()) { for (const auto& index : group.first) { - const OrtNode* supported_node = nodes[index]; + const OrtNode* supported_node = topo_sorted_nodes[index]; RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_node)); } } @@ -1024,7 +1171,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this supported_nodes.reserve(group.first.size()); for (const auto& index : group.first) { - const OrtNode* supported_node = nodes[index]; + const OrtNode* supported_node = topo_sorted_nodes[index]; supported_nodes.push_back(supported_node); } @@ -1048,7 +1195,7 @@ OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - } else if (number_of_trt_nodes == nodes.size()) { + } else if (number_of_trt_nodes == topo_sorted_nodes.size()) { std::string message = "[TensorRT EP] Whole graph will run on TensorRT execution provider"; Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, @@ -1071,6 +1218,19 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this /* out */ OrtNodeComputeInfo** node_compute_info, /* out */ OrtNode** ep_context_node) { TensorrtExecutionProvider* ep = static_cast(this_ptr); + auto ort_graph = Ort::ConstGraph(graph); + + // Sort the 
nodes in priority-based topological order + std::vector topo_sorted_nodes; + Ort::Status status(KahnsTopologicalSort( + *ort_graph, + [&](const OrtNode* node) { + topo_sorted_nodes.push_back(Ort::ConstNode(node)); + }, + PriorityNodeCompare())); + ENFORCE(status.IsOK()); + + Ort::Graph topo_sorted_graph = ort_graph.GetGraphView(topo_sorted_nodes); // Comment out following code if you want the "large" initializers to be saved to a external file. /* @@ -1104,7 +1264,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this ONNX_NAMESPACE::ModelProto model_proto; // add back handle_initializer_data to save initializer to external file - OrtEpUtils::OrtGraphToProto(*graph, model_proto /*, handle_initializer_data */); + OrtEpUtils::OrtGraphToProto(*topo_sorted_graph, model_proto /*, handle_initializer_data */); std::string string_buf; model_proto.SerializeToString(&string_buf); @@ -1123,7 +1283,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this auto trt_builder = GetBuilder(trt_logger); auto network_flags = 0; #if NV_TENSORRT_MAJOR > 8 - network_flags |= (fp16_enable_ || int8_enable_) + network_flags |= (fp16_enable_ || int8_enable_ || bf16_enable_) ? 
0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #else @@ -1143,7 +1303,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this #pragma warning(push) #pragma warning(disable : 4996) #endif - if (fp16_enable_ && layer_norm_fp32_fallback_) { + if ((fp16_enable_ || bf16_enable_) && layer_norm_fp32_fallback_) { for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { auto layer = trt_network->getLayer(idx); auto next_layer = trt_network->getLayer(idx + 1); @@ -1310,7 +1470,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this } // Check platform availability for low precision - if (fp16_enable_) { + if (fp16_enable_ || bf16_enable_) { #if defined(_MSC_VER) #pragma warning(push) #pragma warning(disable : 4996) @@ -1320,6 +1480,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this #pragma warning(pop) #endif fp16_enable_ = false; + bf16_enable_ = false; std::string message = "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE or ORT_TENSORRT_BF16_ENABLE is set, but platform doesn't support fast native fp16/bf16"; Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, @@ -1371,6 +1532,16 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } + + if (bf16_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); + trt_node_name_with_precision += "_bf16"; + std::string message = "[TensorRT EP] BF16 mode is enabled"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + if (int8_enable_) { trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); trt_node_name_with_precision += "_int8"; @@ -1834,7 +2005,7 @@ OrtStatus* 
TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this profiles_.emplace(fused_node_name, std::move(trt_profiles)); // Create EP Context nodes - std::unique_ptr ep_ctx_node_helper = std::make_unique(*ep, graph, fused_node); + std::unique_ptr ep_ctx_node_helper = std::make_unique(*ep, topo_sorted_graph, fused_node); if (dump_ep_context_model_) { std::string compute_capability_hw_compat = compute_capability_; if (engine_cache_enable_ && engine_hw_compatible_) { @@ -1883,6 +2054,7 @@ OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this &tensorrt_mu_, compute_capability_, max_workspace_size_, + bf16_enable_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, @@ -2496,6 +2668,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa max_workspace_size_ = info_.max_workspace_size; fp16_enable_ = info_.fp16_enable; int8_enable_ = info_.int8_enable; + bf16_enable_ = info_.bf16_enable; if (int8_enable_) { int8_calibration_cache_name_ = info_.int8_calibration_table_name; int8_use_native_tensorrt_calibration_table_ = info_.int8_use_native_calibration_table; @@ -2537,7 +2710,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFa } force_sequential_engine_build_ = info_.force_sequential_engine_build; context_memory_sharing_enable_ = info_.context_memory_sharing_enable; - if (fp16_enable_) { + if (fp16_enable_ || bf16_enable_) { layer_norm_fp32_fallback_ = info_.layer_norm_fp32_fallback; } build_heuristics_enable_ = info_.build_heuristics_enable; @@ -3065,6 +3238,13 @@ OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); } + if (trt_state->bf16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kBF16); + std::string message = "[TensorRT EP] BF16 mode is enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + 
OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } #if defined(_MSC_VER) #pragma warning(pop) #endif diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.def b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.def similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider.def rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.def diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h similarity index 99% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider.h rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h index 953b2b05..363baa02 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.h @@ -124,6 +124,7 @@ struct TensorrtComputeState { std::string compute_capability; size_t max_workspace_size = 1 << 30; // 1GB; bool fp16_enable = false; + bool bf16_enable = false; bool int8_enable = false; bool int8_calibration_cache_available = false; bool dla_enable = false; @@ -276,6 +277,7 @@ struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { size_t max_workspace_size_ = 1 << 30; // 1GB bool fp16_enable_ = false; bool int8_enable_ = false; + bool bf16_enable_ = false; bool dla_enable_ = false; int dla_core_ = 0; bool force_sequential_engine_build_ = false; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.lds b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.lds similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider.lds rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider.lds diff --git 
a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_data_transfer.cc similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_data_transfer.cc diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_data_transfer.h similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_data_transfer.h diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.cc similarity index 98% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.cc index 17c65ef4..98a5684f 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.cc @@ -18,6 +18,7 @@ constexpr const char* kMinSubgraphSize = "trt_min_subgraph_size"; constexpr const char* kMaxWorkspaceSize = "trt_max_workspace_size"; constexpr const char* kFp16Enable = "trt_fp16_enable"; constexpr const char* kInt8Enable = "trt_int8_enable"; +constexpr const char* kBf16Enable = "trt_bf16_enable"; constexpr const char* kInt8CalibTable = "trt_int8_calibration_table_name"; constexpr const char* kInt8UseNativeCalibTable = "trt_int8_use_native_calibration_table"; constexpr const char* kDLAEnable = "trt_dla_enable"; @@ -95,6 +96,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions 
.AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size) .AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8Enable, info.int8_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kBf16Enable, info.bf16_enable) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8CalibTable, info.int8_calibration_table_name) .AddAssignmentToReference(tensorrt::provider_option_names::kInt8UseNativeCalibTable, info.int8_use_native_calibration_table) .AddAssignmentToReference(tensorrt::provider_option_names::kDLAEnable, info.dla_enable) diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.h similarity index 98% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.h index df315cf9..f8bfb266 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_info.h @@ -18,6 +18,7 @@ struct TensorrtExecutionProviderInfo { size_t max_workspace_size{1 << 30}; bool fp16_enable{false}; bool int8_enable{false}; + bool bf16_enable{false}; std::string int8_calibration_table_name{""}; bool int8_use_native_calibration_table{false}; bool dla_enable{false}; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_stream_support.cc similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_stream_support.cc diff --git 
a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_stream_support.h similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_stream_support.h diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_utils.h similarity index 99% rename from plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h rename to plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_utils.h index 091a7a16..ce788cb3 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h +++ b/plugin_execution_providers/tensorrt/src/tensorrt_execution_provider_utils.h @@ -55,9 +55,6 @@ AllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, return AllocatorUniquePtr{p, [ort_allocator](T* p) { ort_allocator->Free(ort_allocator, p); }}; } -// Following helper functions/struct, GetNodeInputEdgeCount, GetOutputNodes, KahnsTopologicalSort, VisitorPriorityQueue, PriorityNodeCompare are added but are not used for now. -// TODO: They will be used for graph partition in the following PR. 
- template struct VisitorPriorityQueue { using ComparatorType = std::function; diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.cc similarity index 99% rename from plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc rename to plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.cc index c0a61dad..ab52ecfe 100644 --- a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc +++ b/plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.cc @@ -14,7 +14,7 @@ namespace trt_ep { TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* ep_name, const OrtLogger& default_logger, ApiPtrs apis) - : ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { + : OrtEpFactory {}, ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. 
GetName = GetNameImpl; GetVendor = GetVendorImpl; diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.h similarity index 100% rename from plugin_execution_providers/tensorrt/tensorrt_provider_factory.h rename to plugin_execution_providers/tensorrt/src/tensorrt_provider_factory.h diff --git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h b/plugin_execution_providers/tensorrt/src/utils/cuda/cuda_call.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h rename to plugin_execution_providers/tensorrt/src/utils/cuda/cuda_call.h diff --git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h b/plugin_execution_providers/tensorrt/src/utils/cuda/cuda_common.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h rename to plugin_execution_providers/tensorrt/src/utils/cuda/cuda_common.h diff --git a/plugin_execution_providers/tensorrt/utils/ep_utils.h b/plugin_execution_providers/tensorrt/src/utils/ep_utils.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/ep_utils.h rename to plugin_execution_providers/tensorrt/src/utils/ep_utils.h diff --git a/plugin_execution_providers/tensorrt/utils/helper.cc b/plugin_execution_providers/tensorrt/src/utils/helper.cc similarity index 100% rename from plugin_execution_providers/tensorrt/utils/helper.cc rename to plugin_execution_providers/tensorrt/src/utils/helper.cc diff --git a/plugin_execution_providers/tensorrt/utils/make_string.h b/plugin_execution_providers/tensorrt/src/utils/make_string.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/make_string.h rename to plugin_execution_providers/tensorrt/src/utils/make_string.h diff --git a/plugin_execution_providers/tensorrt/src/utils/ort_graph_to_proto.h b/plugin_execution_providers/tensorrt/src/utils/ort_graph_to_proto.h new file mode 
100644 index 00000000..aab899a8 --- /dev/null +++ b/plugin_execution_providers/tensorrt/src/utils/ort_graph_to_proto.h @@ -0,0 +1,849 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// DO NOT include ORT header files as this is meant to be a header-only utility that can be copied +// to other projects. + +/* + SUMMARY: + Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider + implementations that need to convert an OrtGraph instance into an ONNX protobuf model. + + Users may copy this file and modify as needed. + + USAGE: + This is a header-only implementation that includes both the function declarations and definitions. Copy this file + into a project that links with both ONNX Runtime and ONNX. + + Define the ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL preprocessor macro before the #include statement in exactly one C++ + file to define the implementation. Example: + + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + Other compilation units that depend on these utilities should include this file without defining the + preprocessor macro. + + Example program snippets are shown below. Refer to the function declarations for detailed usage information. 
+ + EXAMPLE SNIPPET (initializers stored within TensorProto): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + onnx::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto); + + // graph_proto stores initializers internally + } + ``` + + EXAMPLE SNIPPET (large initializers stored in external file): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + std::string external_file_path = "weights.bin"; + std::ofstream out_file(external_file_path, std::ios::binary); + + auto handle_initializer_data = [&external_file_path, &out_file](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = out_file.tellp(); + location = external_file_path; + out_file.write(static_cast<const char*>(data), bytes); + out_file.flush(); + is_external = true; // True if the initializer is stored externally + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); + + // graph_proto stores large initializers in an external file + } + ``` + + EXAMPLE SNIPPET (external initializers that point to data in memory, not officially supported by ONNX spec): + + This example stores initializers externally.
However, instead of storing the initializers in a separate + file, the onnx::TensorProto objects point directly to memory addresses. This requires setting the initializer's + location to a special tag like "_MEM_ADDR_" (instead of a file path). The offset is set to the pointer to the + initializer's data in memory (instead of an offset into a file). + + Because this is not standard ONNX, such an onnx::GraphProto should not be saved as an ONNX file. + However, it allows custom tools that operate directly on an onnx::GraphProto to get the initializer data + if it has already been loaded into memory. + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + auto handle_initializer_data = [](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + (void)value_info; + (void)bytes; + + offset = reinterpret_cast<int64_t>(data); + location = "_MEM_ADDR_"; // Some special location tag that indicates the offset is a pointer. + is_external = true; // True if this is an external initializer + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); + + // graph_proto has initializers that look like they are stored in an external file, + // but they are actually pointing to the data in memory. + } + ``` +*/ + +#ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ +#define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +#include +#include +#include "onnxruntime_cxx_api.h" +#include "onnx/onnx_pb.h" + +namespace OrtEpUtils { + +/// +/// Signature of user-provided function to handle initializer data. Called by OrtGraphToProto() for every initializer. 
+/// +/// If the function sets the `is_external` output parameter to false, OrtGraphToProto() stores initializer data +/// within the TensorProto as raw_data. +/// +/// Otherwise, if the function sets `is_external` to true, OrtGraphToProto() assumes that this function stores the +/// initializer data in a file. In this case, OrtGraphToProto() configures the corresponding TensorProto to point to the +/// location and offset returned via the `location` and `offset` output parameters. +/// +/// It is recommended to keep small initializers with byte size <= 127 stored inline in the TensorProto to ensure +/// ONNX shape inference works correctly with the serialized ONNX model. +/// +/// OrtValueInfo for the initializer. Can be used to query name, type, shape, +/// and consumer nodes. +/// Opaque pointer to the initializer data. +/// Size in bytes of the initializer data. +/// Output parameter set to true if the initializer data is stored externally. The +/// implementer is responsible for writing the initializer data to file. If set to false, +/// the initializer will be stored within the TensorProto. +/// Output parameter set to the location (e.g., file) into which the initializer is stored +/// by the implementer of this function. Ignored if `is_external` is set to false. +/// Output parameter set to the offset (e.g., file offset) into which the initializer is stored +/// by the implementer of this function. Ignored if `is_external` is set to false. +/// An Ort::Status indicating success or an error. Serialization exits if this returns an error. +using HandleInitializerDataFunc = std::function; + +/// +/// Serializes the provided OrtGraph to an onnx::GraphProto. +/// Allows the caller to provide a function that specifies whether an initializer should be stored +/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). +/// +/// OrtGraph instance to serialize. 
+/// Destination GraphProto into which to serialize the input OrtGraph. +/// Optional function called to allow the user to determine +/// where the initializer data is stored. +/// An Ort::Status indicating success or an error. +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::GraphProto& graph_proto, + HandleInitializerDataFunc handle_initializer_data_func = nullptr); + +/// +/// Serializes the provided top-level OrtGraph to an onnx::ModelProto. +/// Allows the caller to provide a function that specifies whether an initializer should be stored +/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). +/// +/// OrtGraph instance to serialize. +/// Destination ModelProto into which to serialize the input OrtGraph. +/// Optional function called to allow the user to determine +/// where the initializer data is stored. +/// An Ort::Status indicating success or an error. +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func = nullptr); +/// +/// Converts the endianness of data based on the tensor element type. Mainly used on big-endian (BE) systems. +/// +/// OrtValueInfo for the initializer. Can be used to query name, type, shape, +/// and consumer nodes. +/// Pointer to data buffer. +/// Length of data buffer in bytes. +/// An Ort::Status indicating success or an error. 
+Ort::Status ConvertExternalData(const OrtValueInfo* value_info, void* data, size_t bytes); + +} // namespace OrtEpUtils + +// End of header +#endif // INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +// +// IMPLEMENTATION BELOW +// +#ifdef ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + +#include +#include +#include +#include +#include +#include + +#define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ + do { \ + Ort::Status _status{(fn)}; \ + if (!_status.IsOK()) { \ + return _status; \ + } \ + } while (0) + +#define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ + ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) + +#define ORT_EP_UTILS_C_RETURN_IF(cond, msg) \ + do { \ + if ((cond)) { \ + return Ort::Status{msg, ORT_FAIL}; \ + } \ + } while (0) + +namespace OrtEpUtils { + +static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims, + /*out*/ bool& has_shape); +static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, onnx::ValueInfoProto& value_info_proto); +static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status GetTensorElementSize(const ONNXTensorElementDataType& element_type, size_t& element_size); +static void SwapByteOrderInplace(void* data, const size_t& data_len, const size_t& element_size); + +// Below endian enum class is referenced from include/onnxruntime/core/framework/endian.h +enum class endian { +#if defined(_WIN32) + little = 0, + big = 1, + native = little, +#elif defined(__GNUC__) || defined(__clang__) + little = __ORDER_LITTLE_ENDIAN__, + big = __ORDER_BIG_ENDIAN__, + native = __BYTE_ORDER__, +#else +#error onnxruntime::endian is not implemented in this environment. 
+#endif +}; + +Ort::Status OrtGraphToProto(const OrtGraph& graph, + onnx::GraphProto& graph_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + try { + Ort::ConstGraph ort_graph{&graph}; + // + // Set GraphProto metadata + // + auto graph_name = ort_graph.GetName(); + graph_proto.set_name(graph_name); + graph_proto.set_doc_string("Serialized from OrtGraph"); + + // + // Set GraphProto inputs and outputs + // + std::vector graph_inputs = ort_graph.GetInputs(); + std::vector graph_outputs = ort_graph.GetOutputs(); + + for (const auto& ort_value_info : graph_inputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto)); + } + + for (const auto& ort_value_info : graph_outputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(ort_value_info, *value_info_proto)); + } + + // + // Set GraphProto nodes, value_infos, and initializers. + // + + // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. + // A std::map maintains its elements in a stable ordering. + std::map value_infos; // For GraphProto.value_info + std::map initializer_value_infos; // For GraphProto.initializer + + // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. + // Optionally returns the OrtValueInfo name to the caller. + auto collect_value_info = [&value_infos, + &initializer_value_infos](Ort::ConstValueInfo ort_value_info, + /*out*/ std::optional& value_name_out) { + auto value_name = ort_value_info.GetName(); + + if (value_name_out) { + *value_name_out = value_name; + } + + if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { + return; // Already processed this OrtValueInfo. 
+ } + + bool is_required_graph_input = ort_value_info.IsRequiredGraphInput(); + bool is_optional_graph_input = ort_value_info.IsOptionalGraphInput(); + bool is_graph_output = ort_value_info.IsGraphOutput(); + bool is_constant_initializer = ort_value_info.IsConstantInitializer(); + bool is_from_outer_scope = ort_value_info.IsFromOuterScope(); + + // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. + // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. + if (is_from_outer_scope) { + value_infos.emplace(value_name, ort_value_info); + if (is_constant_initializer) { + initializer_value_infos.emplace(value_name, ort_value_info); + } + } else if (is_optional_graph_input) { + initializer_value_infos.emplace(value_name, ort_value_info); + } else if (is_constant_initializer) { + value_infos.emplace(value_name, ort_value_info); + initializer_value_infos.emplace(value_name, ort_value_info); + } else if (!is_required_graph_input && !is_graph_output) { + value_infos.emplace(value_name, ort_value_info); // This is an internal OrtValueInfo. + } + }; + + std::vector nodes = ort_graph.GetNodes(); + // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos + // that will be stored in GraphProto.value_info and GraphProto.initializer. 
+ for (const auto& ort_node : nodes) { + onnx::NodeProto* node_proto = graph_proto.add_node(); + + std::string node_name = ort_node.GetName(); + std::string node_domain = ort_node.GetDomain(); + std::string node_op_type = ort_node.GetOperatorType(); + + node_proto->set_name(node_name); + node_proto->set_domain(node_domain); + node_proto->set_op_type(node_op_type); + + // Handle node attributes + std::vector ort_attrs = ort_node.GetAttributes(); + for (const auto& attr : ort_attrs) { + OrtOpAttrType attr_type = attr.GetType(); + if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) { + // ORT does not support reading subgraphs via ReadOpAttr(), so skip it. + // Can use Node_GetSubgraphs to get subgraphs. + continue; + } + + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(attr, *attr_proto)); + } + + // Handle node subgraphs + std::vector ort_subgraphs = ort_node.GetSubgraphs(); + for (const auto& [subgraph_attr_name, ort_subgraph] : ort_subgraphs) { + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); + attr_proto->set_name(subgraph_attr_name); + attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); + } + + // Handle node inputs + std::vector ort_inputs = ort_node.GetInputs(); + for (const auto& vi : ort_inputs) { + if (vi == nullptr) { + // missing optional input. + node_proto->add_input(""); + continue; + } + + std::optional value_name; + value_name.emplace(); + collect_value_info(vi, value_name); + node_proto->add_input(*value_name); + } + + // Handle implicit inputs to this node. 
+ std::vector ort_implicit_inputs = ort_node.GetImplicitInputs(); + for (const auto& vi : ort_implicit_inputs) { + assert(vi != nullptr); + std::optional value_name; + collect_value_info(vi, value_name); + } + + // Handle node outputs + std::vector ort_outputs = ort_node.GetOutputs(); + for (const auto& vi : ort_outputs) { + if (vi == nullptr) { + // missing optional output. + node_proto->add_output(""); + continue; + } + + std::optional value_name; + value_name.emplace(); + collect_value_info(vi, value_name); + node_proto->add_output(*value_name); + } + } + + // Add value_infos to GraphProto as ValueInfoProto objects. + for (const auto& [value_name, value_info] : value_infos) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(value_info, *value_info_proto)); + } + + // There may be initializers in the original OrtGraph that have not been added yet. + // For example, an initializer may not be used by any node but is still a graph output. + // Iterating through all nodes to collect initializer value info is therefore not sufficient, + // initializers must also be obtained from ort_graph.GetInitializers(). + // Add those missing initializers and skip the ones that already in `initializer_value_infos` + std::vector ort_graph_initializers = ort_graph.GetInitializers(); + for (const auto& initializer : ort_graph_initializers) { + initializer_value_infos.emplace(initializer.GetName(), initializer); + } + + // Add initializers to GraphProto as TensorProto objects. 
+ for (const auto& [initializer_name, initializer_value_info] : initializer_value_infos) { + std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + bool has_shape = false; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(initializer_value_info, /*get_sym_dims*/ false, + initializer_elem_type, initializer_dims, + initializer_sym_dims, has_shape)); + + onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); + tensor_proto->set_name(initializer_name); + tensor_proto->set_data_type(initializer_elem_type); + + auto* tensor_proto_dims = tensor_proto->mutable_dims(); + for (int64_t dim : initializer_dims) { + tensor_proto_dims->Add(dim); + } + + Ort::ConstValue ort_value{nullptr}; + ORT_EP_UTILS_C_RETURN_IF_ERROR(initializer_value_info.GetInitializer(ort_value)); + + assert(ort_value.IsTensor()); + const void* data = ort_value.GetTensorRawData(); + const size_t data_bytes = ort_value.GetTensorSizeInBytes(); + + std::string ext_location; + int64_t ext_offset = 0; + bool is_external = false; + + if (handle_initializer_data_func != nullptr) { + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, + is_external, ext_location, ext_offset)); + } + + if (is_external) { + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + auto* ext_data_entries = tensor_proto->mutable_external_data(); + onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); + + location_entry->set_key("location"); + location_entry->set_value(ext_location); + offset_entry->set_key("offset"); + offset_entry->set_value(std::to_string(ext_offset)); + length_entry->set_key("length"); + length_entry->set_value(std::to_string(data_bytes)); + } else { + // 
User wants to store data inline the TensorProto's raw_data + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); + if constexpr (endian::native == endian::big) { + size_t element_size = 0; + GetTensorElementSize(initializer_elem_type, element_size); + // create local copy of data and do endianess conversion + auto raw_data_buf = std::make_unique(data_bytes); + std::memcpy(raw_data_buf.get(), data, data_bytes); + SwapByteOrderInplace(raw_data_buf.get(), data_bytes, element_size); + tensor_proto->set_raw_data(raw_data_buf.get(), data_bytes); + } else { + tensor_proto->set_raw_data(data, data_bytes); + } + } + } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_FAIL}; + } + + return Ort::Status{nullptr}; +} + +Ort::Status OrtGraphToProto(const OrtGraph& graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + try { + Ort::ConstGraph ort_graph{&graph}; + + // Set model description. + model_proto.set_doc_string("Serialized from OrtGraph"); + model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); + + // Set ir version. + int64_t ir_version = ort_graph.GetOnnxIRVersion(); + model_proto.set_ir_version(ir_version); + + // Set operator sets. 
+ std::vector op_sets = ort_graph.GetOperatorSets(); + ORT_EP_UTILS_C_RETURN_IF(op_sets.empty(), "OrtGraph should have at least one operator set."); + + auto* operator_sets = model_proto.mutable_opset_import(); + + for (const auto& op_set : op_sets) { + onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); + operator_set->set_domain(op_set.domain); + operator_set->set_version(op_set.version); + } + + model_proto.clear_graph(); + onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_graph, *graph_proto, handle_initializer_data_func)); + + } catch (const Ort::Exception& ex) { + return Ort::Status(ex); + } catch (const std::exception& ex) { + return Ort::Status(ex.what(), ORT_EP_FAIL); + } + + return Ort::Status{nullptr}; +} + +static Ort::Status GetOrtValueInfoTensorTypeShape(Ort::ConstValueInfo vi, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims, + /*out*/ bool& has_shape) { + try { + Ort::ConstTypeInfo ort_type_info = vi.TypeInfo(); + ONNXType ort_onnx_type = ort_type_info.GetONNXType(); + ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, "Expected OrtValueInfo to represent a Tensor"); + + Ort::ConstTensorTypeAndShapeInfo ort_type_shape = ort_type_info.GetTensorTypeAndShapeInfo(); + elem_type = ort_type_shape.GetElementType(); + has_shape = ort_type_shape.HasShape(); + + if (has_shape) { + const size_t num_dims = ort_type_shape.GetDimensionsCount(); + dims = ort_type_shape.GetShape(); + + if (get_symbolic_dims) { + std::vector ort_dim_syms(num_dims, nullptr); + ort_type_shape.GetSymbolicDimensions(ort_dim_syms.data(), ort_dim_syms.size()); + + symbolic_dims.reserve(num_dims); + for (const char* sym_dim : ort_dim_syms) { + symbolic_dims.push_back(sym_dim); + } + } + } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return 
Ort::Status{ex.what(), ORT_EP_FAIL}; + } + return Ort::Status{nullptr}; +} + +// Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). +static Ort::Status OrtValueInfoToProto(Ort::ConstValueInfo ort_value_info, + onnx::ValueInfoProto& value_info_proto) { + std::vector ort_dims; + std::vector ort_dim_syms; + ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. + bool has_shape = false; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, + ort_elem_type, ort_dims, ort_dim_syms, + has_shape)); + + value_info_proto.set_name(ort_value_info.GetName()); + + onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); + type_proto_tensor->set_elem_type(ort_elem_type); + + // If there is no shape, do not set a TensorShapeProto. + if (has_shape) { + onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); + + for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { + onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); + + if (ort_dims[dim_idx] >= 0) { + dim_proto->set_dim_value(ort_dims[dim_idx]); + } else { + const std::string& dim_param = ort_dim_syms[dim_idx]; + + // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, + // which represents an unknown dimension. 
+ if (!dim_param.empty()) { + dim_proto->set_dim_param(dim_param); + } + } + } + } + + return Ort::Status{nullptr}; +} + +static Ort::Status OrtOpAttrToProto(Ort::ConstOpAttr attr, onnx::AttributeProto& attr_proto) { + try { + std::string attr_name = attr.GetName(); + attr_proto.set_name(attr_name); + + OrtOpAttrType attr_type = attr.GetType(); + + switch (attr_type) { + case OrtOpAttrType::ORT_OP_ATTR_INT: { + int64_t i_val = 0; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(i_val)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); + attr_proto.set_i(i_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_INTS: { + std::vector i_vals; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(i_vals)); + auto* ints = attr_proto.mutable_ints(); + ints->Assign(i_vals.begin(), i_vals.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { + float f_val = 0.0f; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(f_val)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); + attr_proto.set_f(f_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { + std::vector f_vals; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(f_vals)); + auto* floats = attr_proto.mutable_floats(); + floats->Assign(f_vals.begin(), f_vals.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRING: { + std::string* str = attr_proto.mutable_s(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValue(*str)); + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { + std::vector result; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(attr.GetValueArray(result)); + auto* strs = attr_proto.mutable_strings(); + strs->Assign(result.begin(), result.end()); + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { + 
attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); + + onnx::TensorProto tensor_proto; + + // TensorProto as an attribute value doesn't require a name. + + Ort::Value tensor; + ORT_EP_UTILS_C_RETURN_IF_ERROR(attr.GetTensorAttributeAsOrtValue(tensor)); + + // Get tensor type and shape info + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + + // Get tensor type + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + break; + } + default: { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + 
std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + auto shape = type_shape_info.GetShape(); + + for (auto& dim : shape) { + tensor_proto.add_dims(dim); + } + + const void* data = tensor.GetTensorRawData(); + const size_t data_bytes = tensor.GetTensorSizeInBytes(); + + // Copy the OrtValue to TensorProto as raw data + if constexpr (endian::native == endian::big) { + size_t element_size = 0; // Stays 0 if the lookup below fails, so the failure must be propagated. + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetTensorElementSize(element_type, element_size)); + // ONNX raw_data is little-endian: swap bytes in a local copy, never in the tensor's own buffer. + auto raw_data_buf = std::make_unique(data_bytes); + std::memcpy(raw_data_buf.get(), data, data_bytes); + SwapByteOrderInplace(raw_data_buf.get(), data_bytes, element_size); + tensor_proto.set_raw_data(raw_data_buf.get(), data_bytes); + } else { + tensor_proto.set_raw_data(data, data_bytes); + } + + *(attr_proto.mutable_t()) = std::move(tensor_proto); + break; + } + default: { + std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + } catch (const Ort::Exception& ex) { + return Ort::Status{ex}; + } catch (const std::exception& ex) { + return Ort::Status{ex.what(), ORT_FAIL}; + } + + return Ort::Status{nullptr}; +} + +Ort::Status ConvertExternalData(const OrtValueInfo* value_info, void* data, size_t bytes) { +#if !defined(_WIN32) + if constexpr (endian::native == endian::little) { + return Ort::Status{nullptr}; + } + std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + size_t element_size = 0; + Ort::ConstValueInfo ort_value_info{value_info}; + bool has_shape{false}; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, false, + initializer_elem_type, initializer_dims, + initializer_sym_dims, has_shape)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetTensorElementSize(initializer_elem_type, element_size)); // 0 would divide by zero + if
(element_size != 1) { + SwapByteOrderInplace(data, bytes, element_size); + } +#else + (value_info); + (data); + (bytes); +#endif + return Ort::Status{nullptr}; +} + +static Ort::Status GetTensorElementSize(const ONNXTensorElementDataType& element_type, size_t& element_size) { + using TensorElemDataMap = std::unordered_map; + static TensorElemDataMap tensor_elem_data_size{ + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, sizeof(float)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, sizeof(int8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, sizeof(uint16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, sizeof(int16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, sizeof(uint16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, sizeof(uint16_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, sizeof(int32_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, sizeof(uint32_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, sizeof(int64_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, sizeof(uint64_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, sizeof(double)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, sizeof(uint8_t)}, + {ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, sizeof(uint8_t)}, + }; + auto pos = tensor_elem_data_size.find(element_type); + if (pos == tensor_elem_data_size.end()) { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + element_size = pos->second; + return Ort::Status{nullptr}; +} + +static void SwapByteOrderInplace(void* data, const size_t& data_len, const size_t& element_size) { + char* bytes = reinterpret_cast(data); + 
size_t num_elements = data_len / element_size; + for (size_t i = 0; i < num_elements; ++i) { + char* start_byte = bytes + i * element_size; + char* end_byte = start_byte + element_size - 1; + for (size_t count = 0; count < element_size / 2; ++count) { + std::swap(*start_byte++, *end_byte--); + } + } +} + +} // namespace OrtEpUtils +#endif // ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL diff --git a/plugin_execution_providers/tensorrt/utils/parse_string.h b/plugin_execution_providers/tensorrt/src/utils/parse_string.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/parse_string.h rename to plugin_execution_providers/tensorrt/src/utils/parse_string.h diff --git a/plugin_execution_providers/tensorrt/utils/path_string.h b/plugin_execution_providers/tensorrt/src/utils/path_string.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/path_string.h rename to plugin_execution_providers/tensorrt/src/utils/path_string.h diff --git a/plugin_execution_providers/tensorrt/utils/provider_options.h b/plugin_execution_providers/tensorrt/src/utils/provider_options.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/provider_options.h rename to plugin_execution_providers/tensorrt/src/utils/provider_options.h diff --git a/plugin_execution_providers/tensorrt/utils/provider_options_utils.h b/plugin_execution_providers/tensorrt/src/utils/provider_options_utils.h similarity index 100% rename from plugin_execution_providers/tensorrt/utils/provider_options_utils.h rename to plugin_execution_providers/tensorrt/src/utils/provider_options_utils.h diff --git a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h deleted file mode 100644 index 6f07c67a..00000000 --- a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h +++ /dev/null @@ -1,868 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. 
-// Licensed under the MIT License. - -// DO NOT include ORT header files as this is meant to be a header-only utility that can be copied -// to other projects. - -/* - SUMMARY: - Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider - implementations that need to convert an OrtGraph instance into an ONNX protobuf model. - - Users may copy this file and modify as needed. - - USAGE: - This is a header-only implementation that includes both the function declarations and definitions. Copy this file - into a project that links with both ONNX Runtime and ONNX. - - Define the ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL preprocessor macro before the #include statement in exactly one C++ - file to define the implementation. Example: - - #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - #include "ort_graph_to_proto.h" - - Other compilation units that depend on these utilities should include this file without defining the - preprocessor macro. - - Example program snippets are shown below. Refer to the function declarations for detailed usage information. 
- - EXAMPLE SNIPPET (initializers stored within TensorProto): - - ```C++ - #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - #include "ort_graph_to_proto.h" - - OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, - OrtEpGraphSupportInfo* graph_support_info) { - onnx::GraphProto graph_proto; - OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto); - - // graph_proto stores initializers internally - } - ``` - - EXAMPLE SNIPPET (large initializers stored in external file): - - ```C++ - #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - #include "ort_graph_to_proto.h" - - OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, - OrtEpGraphSupportInfo* graph_support_info) { - std::string external_file_path = "weights.bin"; - std::ofstream out_file(external_file_path, std::ios::binary); - - auto handle_initializer_data = [&external_file_path, &out_file](const OrtValueInfo* value_info, - const void* data, size_t bytes, - bool& is_external, std::string& location, - int64_t& offset) -> Ort::Status { - // OrtValueInfo* could be used to query initializer's name, type, shape, consumers, etc. - (void)value_info; - - if (bytes <= 127) { - is_external = false; // Keep small initializers stored inside the TensorProto. - return Ort::Status{nullptr}; - } - - offset = out_file.tellp(); - location = external_file_path; - out_file.write(static_cast(data), bytes); - out_file.flush(); - is_external = true; // True if is external initializer - return Ort::Status{nullptr}; - } - - ONNX_NAMESPACE::GraphProto graph_proto; - OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); - - // graph_proto stores large initializers in an external file - } - ``` - - EXAMPLE SNIPPET (external initializers that point to data in memory, not officially supported by ONNX spec): - - This example stores initializers externally. 
However, instead of storing the initializers in a separate - file, the onnx::TensorProto objects point directly to memory addresses. This requires setting the initializer's - location to a special tag like "_MEM_ADDR_" (instead of a file path). The offset is set to the pointer to the - initializer's data in memory (instead of an offset into a file). - - Because this is not standard ONNX, such a onnx::GraphProto should not be saved as an ONNX file. - However, it allows custom tools that operate directly on a onnx::GraphProto to get the initializer data - if it has already been loaded into memory. - - ```C++ - #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - #include "ort_graph_to_proto.h" - - OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, - OrtEpGraphSupportInfo* graph_support_info) { - auto handle_initializer_data = [](const OrtValueInfo* value_info, - const void* data, size_t bytes, - bool& is_external, std::string& location, - int64_t& offset) -> Ort::Status { - (void)value_info; - (void)bytes; - - offset = reinterpret_cast(data); - location = "_MEM_ADDR_"; // Some special location tag that indicates the offset is a pointer. - is_external = true; // True if is external initializer - return Ort::Status{nullptr}; - } - - ONNX_NAMESPACE::GraphProto graph_proto; - OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); - - // graph_proto has initializers that look like they are stored in an external file, - // but they are actually pointing to the data in memory. - } - ``` -*/ - -#ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ -#define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ - -#include -#include "onnxruntime_cxx_api.h" -#include "onnx/onnx_pb.h" - -namespace OrtEpUtils { - -/// -/// Signature of user-provided function to handle initializer data. Called by OrtGraphToProto() for every initializer. 
-/// -/// If the function sets the `is_external` output parameter to false, OrtGraphToProto() stores initializer data -/// within the TensorProto as raw_data. -/// -/// Otherwise, if the function sets `is_external` to true, OrtGraphToProto() assumes that this function stores the -/// initializer data in a file. In this case, OrtGraphToProto() configures the corresponding TensorProto to point the -/// location and offset returned via the `location` and `offset` output parameters. -/// -/// It is recommended to keep small initializers with byte size <= 127 stored inline the TensorProto to ensure -/// ONNX shape inference works correctly with the serialized ONNX model. -/// -/// OrtValueInfo for the initializer. Can be used to query name, type, shape, -/// and consumer nodes. -/// Opaque pointer to the initializer data. -/// Size in bytes of the initializer data. -/// Output parameter set to true if the initializer data is stored externally. The -/// implementer is responsible for writing the initializer data to file. If set to false, -/// the initializer will be stored within the TensorProto. -/// Output parameter set to the location (e.g., file) into which the initializer is stored -/// by the implementer of this function. Ignored if `is_external` is set to false. -/// Output parameter set to the offset (e.g., file offset) into which the initializer is stored -/// by the implementer of this function. Ignored if `is_external` is set to false. -/// An Ort::Status indicating success or an error. Serialization exits if this returns an error. -using HandleInitializerDataFunc = std::function; - -/// -/// Serializes the provided OrtGraph to a onnx::GraphProto. -/// Allows the caller to provide a function that specifies whether an initializer should be stored -/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). -/// -/// OrtGraph instance to serialize. 
-/// Destination GraphProto into which to serialize the input OrtGraph. -/// Optional function called to allow the user to determine -/// where the initializer data is stored. -/// An Ort::Status indicating success or an error. -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::GraphProto& graph_proto, - HandleInitializerDataFunc handle_initializer_data_func = nullptr); - -/// -/// Serializes the provided top-level OrtGraph to a onnx::ModelProto. -/// Allows the caller to provide a function that specifies whether an initializer should be stored -/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). -/// -/// OrtGraph instance to serialize. -/// Destination ModelProto into which to serialize the input OrtGraph. -/// Optional function called to allow the user to determine -/// where the initializer data is stored. -/// An Ort::Status indicating success or an error. -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::ModelProto& model_proto, - HandleInitializerDataFunc handle_initializer_data_func = nullptr); -} // namespace OrtEpUtils - -// End of header -#endif // INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ - -// -// IMPLEMENTATION BELOW -// -#ifdef ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL - -#include -#include -#include -#include -#include -#include - -#define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ - do { \ - OrtStatus* _status = (fn); \ - if (_status != nullptr) { \ - return Ort::Status{_status}; \ - } \ - } while (0) - -#define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ - do { \ - Ort::Status _status = (fn); \ - if (!_status.IsOK()) { \ - return _status; \ - } \ - } while (0) - -#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ - do { \ - if ((cond)) { \ - return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ - } \ - } while (0) - -namespace OrtEpUtils { - -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, - bool 
get_symbolic_dims, - /*out*/ ONNXTensorElementDataType& elem_type, - /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims); -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); - -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::GraphProto& graph_proto, - HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // - // Set GraphProto metadata - // - const char* graph_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); - graph_proto.set_name(graph_name); - graph_proto.set_doc_string("Serialized from OrtGraph"); - - // - // Set GraphProto inputs and outputs - // - size_t num_graph_inputs = 0; - size_t num_graph_outputs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); - - std::vector graph_inputs(num_graph_inputs); - std::vector graph_outputs(num_graph_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); - - for (const OrtValueInfo* ort_value_info : graph_inputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - for (const OrtValueInfo* ort_value_info : graph_outputs) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_output()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); - } - - // - // Set GraphProto nodes, value_infos, and initializers. 
- // - - // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. - // A std::map maintains its elements in a stable ordering. - std::map value_infos; // For GraphProto.value_info - std::map initializer_value_infos; // For GraphProto.initializer - - // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. - // Optionally returns the OrtValueInfo name to the caller. - auto collect_value_info = [&ort_api, &value_infos, - &initializer_value_infos](const OrtValueInfo& ort_value_info, - /*out*/ const char** value_name_out = nullptr) -> Ort::Status { - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - - if (value_name_out != nullptr) { - *value_name_out = value_name; - } - - if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { - return Ort::Status{nullptr}; // Already processed this OrtValueInfo. - } - - bool is_required_graph_input = false; - bool is_optional_graph_input = false; - bool is_graph_output = false; - bool is_constant_initializer = false; - bool is_from_outer_scope = false; - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); - - // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. - // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. 
- // For values defined in an outer scope, just add the value info but not the initializer. - if (is_from_outer_scope) { - value_infos.emplace(value_name, &ort_value_info); - } else if (is_optional_graph_input) { - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (is_constant_initializer) { - value_infos.emplace(value_name, &ort_value_info); - initializer_value_infos.emplace(value_name, &ort_value_info); - } else if (!is_required_graph_input && !is_graph_output) { - value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. - } - - return Ort::Status{nullptr}; - }; - - size_t num_nodes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); - - std::vector nodes(num_nodes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); - - // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos - // that will be stored in GraphProto.value_info and GraphProto.initializer. 
- for (size_t i = 0; i < num_nodes; i++) { - const OrtNode* ort_node = nodes[i]; - onnx::NodeProto* node_proto = graph_proto.add_node(); - - const char* node_name = nullptr; - const char* node_domain = nullptr; - const char* node_op_type = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); - - node_proto->set_name(node_name); - node_proto->set_domain(node_domain); - node_proto->set_op_type(node_op_type); - - size_t num_inputs = 0; - size_t num_implicit_inputs = 0; - size_t num_outputs = 0; - size_t num_attrs = 0; - size_t num_subgraphs = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); - - // Handle node attributes - if (num_attrs > 0) { - std::vector ort_attrs(num_attrs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); - - for (const OrtOpAttr* ort_attr : ort_attrs) { - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - - Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; - if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) { - // ORT does not support reading subgraphs via ReadOpAttr(), so skip it. - // Can use Node_GetSubgraphs to get subgraphs. - continue; - } - - if (!attr_type_status.IsOK()) { - // Unsupported attribute type. 
- return attr_type_status; - } - - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); - } - } - - // Handle node subgraphs - if (num_subgraphs > 0) { - std::vector ort_subgraphs(num_subgraphs); - std::vector subgraph_attr_names(num_subgraphs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), - subgraph_attr_names.data())); - - for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { - const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; - const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; - - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); - - attr_proto->set_name(subgraph_attr_name); - attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); - } - } - - // Handle node inputs - if (num_inputs > 0) { - std::vector ort_inputs(num_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); - - for (const OrtValueInfo* ort_value_info : ort_inputs) { - if (ort_value_info == nullptr) { - // missing optional input. - node_proto->add_input(""); - continue; - } - - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_input(value_name); - } - } - - // Handle implicit inputs to this node. 
- if (num_implicit_inputs > 0) { - std::vector ort_implicit_inputs(num_implicit_inputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), - ort_implicit_inputs.size())); - - for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { - assert(ort_value_info != nullptr); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); - } - } - - // Handle node outputs - if (num_outputs > 0) { - std::vector ort_outputs(num_outputs); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); - - for (const OrtValueInfo* ort_value_info : ort_outputs) { - if (ort_value_info == nullptr) { - // missing optional output. - node_proto->add_output(""); - continue; - } - - const char* value_name = nullptr; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); - - node_proto->add_output(value_name); - } - } - } - - // Add value_infos to GraphProto as ValueInfoProto objects. - for (const std::pair& entry : value_infos) { - onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); - } - - // Add initializers to GraphProto as TensorProto objects. - for (const std::pair& entry : initializer_value_infos) { - const OrtValueInfo* initializer_value_info = entry.second; - std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. 
- std::vector initializer_dims; - std::vector initializer_sym_dims; - ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, - initializer_elem_type, initializer_dims, - initializer_sym_dims)); - - onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); - tensor_proto->set_name(initializer_name); - tensor_proto->set_data_type(initializer_elem_type); - - auto* tensor_proto_dims = tensor_proto->mutable_dims(); - for (int64_t dim : initializer_dims) { - tensor_proto_dims->Add(dim); - } - - const OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); - - const void* data = nullptr; - size_t data_bytes = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); - - std::string ext_location; - int64_t ext_offset = 0; - bool is_external = false; - - if (handle_initializer_data_func != nullptr) { - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, - is_external, ext_location, ext_offset)); - } - - if (is_external) { - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); - auto* ext_data_entries = tensor_proto->mutable_external_data(); - onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); - onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); - - location_entry->set_key("location"); - location_entry->set_value(ext_location); - offset_entry->set_key("offset"); - offset_entry->set_value(std::to_string(ext_offset)); - length_entry->set_key("length"); - length_entry->set_value(std::to_string(data_bytes)); - } else { - // User wants to store data 
inline the TensorProto's raw_data - tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); - tensor_proto->set_raw_data(data, data_bytes); - } - } - - return Ort::Status{nullptr}; -} - -Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, - onnx::ModelProto& model_proto, - HandleInitializerDataFunc handle_initializer_data_func) { - const OrtApi& ort_api = Ort::GetApi(); - - // Check that OrtGraph is a top-level graph (no parent node). - const OrtNode* parent_node = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); - ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); - - // Set model description. - model_proto.set_doc_string("Serialized from OrtGraph"); - model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); - - // Set ir version. - int64_t ir_version = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); - model_proto.set_ir_version(ir_version); - - // Set operator sets. 
- size_t num_operator_sets = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); - ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); - - std::vector domains(num_operator_sets, nullptr); - std::vector opset_versions(num_operator_sets); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), - num_operator_sets)); - - auto* operator_sets = model_proto.mutable_opset_import(); - - for (size_t i = 0; i < num_operator_sets; ++i) { - onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); - operator_set->set_domain(domains[i]); - operator_set->set_version(opset_versions[i]); - } - - model_proto.clear_graph(); - onnx::GraphProto* graph_proto = model_proto.mutable_graph(); - - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); - - return Ort::Status{nullptr}; -} - -static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, - bool get_symbolic_dims, - /*out*/ ONNXTensorElementDataType& elem_type, - /*out*/ std::vector& dims, - /*out*/ std::vector& symbolic_dims) { - const OrtApi& ort_api = Ort::GetApi(); - - const OrtTypeInfo* ort_type_info = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); - - ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); - ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); - - const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; - ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); - 
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); - - size_t num_dims = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); - - std::vector ort_dims(num_dims, 0); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); - - elem_type = ort_elem_type; - dims = std::move(ort_dims); - - if (get_symbolic_dims) { - std::vector ort_dim_syms(num_dims, nullptr); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), - ort_dim_syms.size())); - - symbolic_dims.reserve(num_dims); - for (const char* sym_dim : ort_dim_syms) { - symbolic_dims.push_back(sym_dim); - } - } - - return Ort::Status{nullptr}; -} - -// Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). -static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, - onnx::ValueInfoProto& value_info_proto) { - const OrtApi& ort_api = Ort::GetApi(); - - std::vector ort_dims; - std::vector ort_dim_syms; - ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - - // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, - ort_elem_type, ort_dims, ort_dim_syms)); - - const char* value_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); - value_info_proto.set_name(value_name); - - onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); - type_proto_tensor->set_elem_type(ort_elem_type); - - // If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks - // like a scalar value. 
- if (!ort_dims.empty()) { - onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); - - for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { - onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); - - if (ort_dims[dim_idx] >= 0) { - dim_proto->set_dim_value(ort_dims[dim_idx]); - } else { - const std::string& dim_param = ort_dim_syms[dim_idx]; - - // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, - // which represents an unknown dimension. - if (!dim_param.empty()) { - dim_proto->set_dim_param(dim_param); - } - } - } - } - - return Ort::Status{nullptr}; -} - -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { - const OrtApi& ort_api = Ort::GetApi(); - - const char* attr_name = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); - attr_proto.set_name(attr_name); - - size_t total_attr_bytes = 0; - OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); - - switch (attr_type) { - case OrtOpAttrType::ORT_OP_ATTR_INT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); - - int64_t i_val = 0; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); - attr_proto.set_i(i_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_INTS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. 
- Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector i_vals(total_attr_bytes / sizeof(int64_t)); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* ints = attr_proto.mutable_ints(); - for (int64_t val : i_vals) { - ints->Add(val); - } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); - - float f_val = 0.0f; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); - attr_proto.set_f(f_val); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector f_vals(total_attr_bytes / sizeof(float)); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* floats = attr_proto.mutable_floats(); - for (float val : f_vals) { - floats->Add(val); - } - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRING: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); - - // First call to ReadOpAttr gets the total byte size. Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::string* str = attr_proto.mutable_s(); - - str->resize(total_attr_bytes); - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, - &total_attr_bytes)); - - str->resize(total_attr_bytes); - break; - } - case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); - - // First call to ReadOpAttr gets the total byte size. 
Second call reads the data. - Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; - std::vector chars(total_attr_bytes, '\0'); - - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, - &total_attr_bytes)); - - auto* strs = attr_proto.mutable_strings(); - - // Strings are all in a single buffer, each separated with a '\0'. - // Extract each string and add it to the STRINGS attribute array. - char* at = chars.data(); - char* end = at + chars.size(); - - while (at < end) { - char* str_begin = at; - - while (*at && at < end) { - at++; - } - - strs->Add()->assign(str_begin, at - str_begin); - if (at < end) { - assert(*at == '\0'); - at++; // Skip '\0' to get to the beginning of the next string. - } - } - - break; - } - case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { - attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); - - onnx::TensorProto tensor_proto; - - // TensorProto as an attribute value doesn't require a name. 
- - OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value)); - - Ort::Value tensor(ort_value); - - // Get tensor type and shape info - Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); - - // Get tensor type - ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); - - size_t element_size = 0; - switch (element_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); - element_size = sizeof(float); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); - element_size = sizeof(uint8_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); - element_size = sizeof(int8_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); - element_size = sizeof(uint16_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); - element_size = sizeof(int16_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); - element_size = sizeof(int32_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); - element_size = sizeof(int64_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); - element_size = sizeof(bool); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); - element_size = sizeof(double); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); - element_size = 
sizeof(uint32_t); - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { - tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); - element_size = sizeof(uint64_t); - break; - } - default: { - std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); - return Ort::Status(err_msg.c_str(), ORT_FAIL); - } - } - - auto shape = type_shape_info.GetShape(); - - for (auto& dim : shape) { - tensor_proto.add_dims(dim); - } - - size_t element_count = type_shape_info.GetElementCount(); - size_t data_bytes = element_count * element_size; - const void* data = tensor.GetTensorData(); - - // Copy the Ortvalue to TensorProto as raw data - tensor_proto.set_raw_data(data, data_bytes); - - *(attr_proto.mutable_t()) = std::move(tensor_proto); - break; - } - default: { - std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); - return Ort::Status(err_msg.c_str(), ORT_FAIL); - } - } - - return Ort::Status{nullptr}; -} - -} // namespace OrtEpUtils -#endif // ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL