Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 70 additions & 1 deletion examples/benchmarks/ort_inference_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,28 @@
"""Micro benchmark example for ONNXRuntime inference performance.

Commands to run:
In-house models:
python3 examples/benchmarks/ort_inference_performance.py
python3 examples/benchmarks/ort_inference_performance.py --model_source in-house

HuggingFace models:
python3 examples/benchmarks/ort_inference_performance.py \
--model_source huggingface --model_identifier bert-base-uncased
python3 examples/benchmarks/ort_inference_performance.py \
--model_source huggingface --model_identifier microsoft/resnet-50

Environment variables:
HF_TOKEN: HuggingFace token for gated models (optional)
"""

import argparse

from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.common.utils import logger

if __name__ == '__main__':

def run_inhouse_benchmark():
"""Run ORT inference with in-house torchvision models."""
context = BenchmarkRegistry.create_benchmark_context(
'ort-inference', platform=Platform.CUDA, parameters='--pytorch_models resnet50 resnet101 --precision float16'
)
Expand All @@ -21,3 +36,57 @@
benchmark.name, benchmark.return_code, benchmark.result
)
)
return benchmark


def run_huggingface_benchmark(model_identifier, precision='float16', batch_size=32, seq_length=512):
"""Run ORT inference with a HuggingFace model.

Args:
model_identifier: HuggingFace model ID (e.g., 'bert-base-uncased').
precision: Inference precision ('float32', 'float16', 'int8').
batch_size: Batch size for inference.
seq_length: Sequence length for transformer models.
"""
parameters = (
f'--model_source huggingface '
f'--model_identifier {model_identifier} '
f'--precision {precision} '
f'--batch_size {batch_size} '
f'--seq_length {seq_length}'
)

logger.info(f'Running ORT inference benchmark with HuggingFace model: {model_identifier}')

context = BenchmarkRegistry.create_benchmark_context('ort-inference', platform=Platform.CUDA, parameters=parameters)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
logger.info(
'benchmark: {}, return code: {}, result: {}'.format(
benchmark.name, benchmark.return_code, benchmark.result
)
)
return benchmark


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='ORT inference benchmark')
parser.add_argument(
'--model_source',
type=str,
default='in-house',
choices=['in-house', 'huggingface'],
help='Source of the model: in-house (default) or huggingface'
)
parser.add_argument(
'--model_identifier', type=str, default='bert-base-uncased', help='HuggingFace model identifier'
)
parser.add_argument('--precision', type=str, default='float16', choices=['float32', 'float16', 'int8'])
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--seq_length', type=int, default=512)
args = parser.parse_args()

if args.model_source == 'huggingface':
run_huggingface_benchmark(args.model_identifier, args.precision, args.batch_size, args.seq_length)
else:
run_inhouse_benchmark()
78 changes: 77 additions & 1 deletion examples/benchmarks/tensorrt_inference_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,28 @@
"""Micro benchmark example for TensorRT inference performance.

Commands to run:
In-house models:
python3 examples/benchmarks/tensorrt_inference_performance.py
python3 examples/benchmarks/tensorrt_inference_performance.py --model_source in-house

HuggingFace models:
python3 examples/benchmarks/tensorrt_inference_performance.py \
--model_source huggingface --model_identifier bert-base-uncased
python3 examples/benchmarks/tensorrt_inference_performance.py \
--model_source huggingface --model_identifier microsoft/resnet-50

Environment variables:
HF_TOKEN: HuggingFace token for gated models (optional)
"""

import argparse

from superbench.benchmarks import BenchmarkRegistry, Platform
from superbench.common.utils import logger

if __name__ == '__main__':

def run_inhouse_benchmark():
"""Run TensorRT inference with in-house torchvision models."""
context = BenchmarkRegistry.create_benchmark_context('tensorrt-inference', platform=Platform.CUDA)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
Expand All @@ -19,3 +34,64 @@
benchmark.name, benchmark.return_code, benchmark.result
)
)
return benchmark


def run_huggingface_benchmark(model_identifier, precision='fp16', batch_size=32, seq_length=512, iterations=2048):
"""Run TensorRT inference with a HuggingFace model.

Args:
model_identifier: HuggingFace model ID (e.g., 'bert-base-uncased').
precision: Inference precision ('fp32', 'fp16', 'int8').
batch_size: Batch size for inference.
seq_length: Sequence length for transformer models.
iterations: Number of inference iterations.
"""
parameters = (
f'--model_source huggingface '
f'--model_identifier {model_identifier} '
f'--precision {precision} '
f'--batch_size {batch_size} '
f'--seq_length {seq_length} '
f'--iterations {iterations}'
)

logger.info(f'Running TensorRT inference benchmark with HuggingFace model: {model_identifier}')

context = BenchmarkRegistry.create_benchmark_context(
'tensorrt-inference', platform=Platform.CUDA, parameters=parameters
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
logger.info(
'benchmark: {}, return code: {}, result: {}'.format(
benchmark.name, benchmark.return_code, benchmark.result
)
)
return benchmark


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='TensorRT inference benchmark')
parser.add_argument(
'--model_source',
type=str,
default='in-house',
choices=['in-house', 'huggingface'],
help='Source of the model: in-house (default) or huggingface'
)
parser.add_argument(
'--model_identifier', type=str, default='bert-base-uncased', help='HuggingFace model identifier'
)
parser.add_argument('--precision', type=str, default='fp16', choices=['fp32', 'fp16', 'int8'])
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--seq_length', type=int, default=512)
parser.add_argument('--iterations', type=int, default=2048)
args = parser.parse_args()

if args.model_source == 'huggingface':
run_huggingface_benchmark(
args.model_identifier, args.precision, args.batch_size, args.seq_length, args.iterations
)
else:
run_inhouse_benchmark()
166 changes: 158 additions & 8 deletions superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,23 @@
import torch.hub
import torch.onnx
import torchvision.models
from transformers import BertConfig, GPT2Config, LlamaConfig

from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
import traceback

if MixtralBenchmarkModel is not None:
from transformers import MixtralConfig
from superbench.common.utils import logger


class torch2onnxExporter():
"""PyTorch model to ONNX exporter."""
def __init__(self):
"""Constructor."""
from transformers import BertConfig, GPT2Config, LlamaConfig
from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel

self.num_classes = 100
self.lstm_input_size = 256
self.benchmark_models = {
Expand Down Expand Up @@ -129,6 +130,7 @@ def __init__(self):

# Only include Mixtral models if MixtralBenchmarkModel is available
if MixtralBenchmarkModel is not None:
from transformers import MixtralConfig
self.benchmark_models.update(
{
'mixtral-8x7b':
Expand Down Expand Up @@ -270,3 +272,151 @@ def export_benchmark_model(self, model_name, batch_size=1, seq_length=512):
del dummy_input
torch.cuda.empty_cache()
return file_name

Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new export_huggingface_model() introduces multiple branches (vision vs NLP, dynamic axes behavior, and external-data conversion for >2GB) but there’s no targeted unit test coverage shown for this method. Consider adding mocked/unit tests that validate: (1) correct input/output names for vision and NLP paths, and (2) external-data conversion is invoked when the size threshold is exceeded (can be done by mocking parameter sizing and ONNX helpers).

Suggested change
_ONNX_EXTERNAL_DATA_THRESHOLD_BYTES = 2 * 1024 * 1024 * 1024
def _get_model_parameter_size_bytes(self, model):
"""Return the total serialized parameter size in bytes for a model.
This helper is intentionally isolated so unit tests can mock parameter
shapes/sizes without performing a real ONNX export.
Args:
model: Model instance exposing ``parameters()``.
Returns:
int: Total parameter size in bytes.
"""
total_size = 0
for parameter in model.parameters():
total_size += parameter.nelement() * parameter.element_size()
return total_size
def _should_use_external_data_format(self, model):
"""Return whether ONNX external-data format should be used.
Args:
model: Model instance exposing ``parameters()``.
Returns:
bool: True when the model size exceeds the ONNX 2GB threshold.
"""
return self._get_model_parameter_size_bytes(model) > self._ONNX_EXTERNAL_DATA_THRESHOLD_BYTES
def _build_huggingface_export_config(self, model, batch_size=1, seq_length=512):
"""Build dummy input and ONNX I/O metadata for HuggingFace export.
This helper extracts the vision-vs-NLP branch logic into a directly
testable unit so mocked tests can validate input/output names and
dynamic axes without requiring a full export.
Args:
model: HuggingFace model instance to export.
batch_size (int): Batch size of input. Defaults to 1.
seq_length (int): Sequence length of input. Defaults to 512.
Returns:
tuple: (dummy_input, input_names, output_names, dynamic_axes)
"""
config = getattr(model, 'config', None)
model_type = getattr(config, 'model_type', '')
# Vision models typically consume pixel_values with NCHW layout.
if model_type in ('vit', 'swin', 'convnext', 'beit', 'deit', 'resnet', 'detr'):
dummy_input = torch.randn((batch_size, 3, 224, 224), device='cuda')
input_names = ['pixel_values']
output_names = ['output']
dynamic_axes = {
'pixel_values': {
0: 'batch_size',
},
'output': {
0: 'batch_size',
}
}
return dummy_input, input_names, output_names, dynamic_axes
# Default HuggingFace NLP-style export.
dummy_input = torch.ones((batch_size, seq_length), dtype=torch.int64, device='cuda')
input_names = ['input_ids']
output_names = ['output']
dynamic_axes = {
'input_ids': {
0: 'batch_size',
1: 'seq_length',
},
'output': {
0: 'batch_size',
}
}
return dummy_input, input_names, output_names, dynamic_axes

Copilot uses AI. Check for mistakes.
def export_huggingface_model(self, model, model_name, batch_size=1, seq_length=512, output_dir=None):
"""Export a HuggingFace model to ONNX format.

Args:
model: HuggingFace model instance to export.
model_name (str): Name for the exported ONNX model file.
batch_size (int): Batch size of input. Defaults to 1.
seq_length (int): Sequence length of input. Defaults to 512.
output_dir (str): Output directory path. If None, uses default path.

Returns:
str: Exported ONNX model file path, or empty string if export fails.
"""
try:
# Use custom output directory if provided
output_path = Path(output_dir) if output_dir else self._onnx_model_path
file_name = str(output_path / (model_name + '.onnx'))

# Put model in eval mode and move to CUDA if available
model.eval()

# Disable cache to avoid DynamicCache issues with ONNX export
if hasattr(model.config, 'use_cache'):
model.config.use_cache = False

if torch.cuda.is_available():
model = model.cuda()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Get model's dtype for inputs
model_dtype = next(model.parameters()).dtype
Comment on lines +294 to +307
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This exporter forces the model onto CUDA whenever available, even if the caller intentionally loaded the model on CPU (both ORT/TensorRT preprocessing comments indicate CPU loading to avoid device-map/dispatch issues). This can reintroduce GPU memory pressure and device mismatch problems during export. Consider respecting the model’s current device (or adding a parameter to control export device) rather than always moving to CUDA.

Suggested change
# Put model in eval mode and move to CUDA if available
model.eval()
# Disable cache to avoid DynamicCache issues with ONNX export
if hasattr(model.config, 'use_cache'):
model.config.use_cache = False
if torch.cuda.is_available():
model = model.cuda()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Get model's dtype for inputs
model_dtype = next(model.parameters()).dtype
# Put model in eval mode and respect its current device placement
model.eval()
# Disable cache to avoid DynamicCache issues with ONNX export
if hasattr(model.config, 'use_cache'):
model.config.use_cache = False
try:
first_param = next(model.parameters())
device = first_param.device
model_dtype = first_param.dtype
except StopIteration:
try:
first_buffer = next(model.buffers())
device = first_buffer.device
model_dtype = first_buffer.dtype
except StopIteration:
device = torch.device('cpu')
model_dtype = torch.float32
# Get model's dtype for inputs

Copilot uses AI. Check for mistakes.

# Detect model type and create appropriate inputs
# Vision models use pixel_values, NLP models use input_ids
# Use HuggingFace's main_input_name property for automatic detection
main_input = getattr(model, 'main_input_name', 'input_ids')
is_vision_model = main_input == 'pixel_values'

if is_vision_model:
# Vision models: use pixel_values (batch_size, channels, height, width)
# Standard ImageNet size is 224x224, 3 channels
# Match the dtype of the model
dummy_input = torch.randn(batch_size, 3, 224, 224, dtype=model_dtype, device=device)
input_names = ['pixel_values']
dynamic_axes = {'pixel_values': {0: 'batch_size'}, 'output': {0: 'batch_size'}}

# Wrapper for vision models
class VisionModelWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, pixel_values):
outputs = self.model(pixel_values=pixel_values)
if hasattr(outputs, 'logits'):
return outputs.logits
elif hasattr(outputs, 'last_hidden_state'):
return outputs.last_hidden_state
else:
return outputs[0] if isinstance(outputs, (tuple, list)) else outputs

wrapped_model = VisionModelWrapper(model)
export_args = (dummy_input, )
else:
# NLP models: use input_ids and attention_mask
dummy_input = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
input_names = ['input_ids', 'attention_mask']
dynamic_axes = {
'input_ids': {
0: 'batch_size',
1: 'seq_length'
},
'attention_mask': {
0: 'batch_size',
1: 'seq_length'
},
'output': {
0: 'batch_size'
},
}
Comment on lines +345 to +357
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For many NLP models the exported output shape is sequence-dependent (e.g., logits/hidden states often have a seq_length dimension). Currently only the batch dimension is marked dynamic for output, which can lock the exported ONNX to a fixed seq_length and break dynamic-shape inference/engine building. Consider adding the sequence dimension to output’s dynamic_axes when the model output is 3D (batch, seq, hidden/vocab).

Copilot uses AI. Check for mistakes.

# Wrapper for NLP models
class NLPModelWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, input_ids, attention_mask):
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
if hasattr(outputs, 'logits'):
return outputs.logits
elif hasattr(outputs, 'last_hidden_state'):
return outputs.last_hidden_state
else:
return outputs[0] if isinstance(outputs, (tuple, list)) else outputs

wrapped_model = NLPModelWrapper(model)
export_args = (dummy_input, attention_mask)

# Export to ONNX for large models (>2GB), use external data format
model_size_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**3)
use_external_data = model_size_gb > 2.0

if use_external_data:
logger.info(f'Model size is {model_size_gb:.2f}GB, using external data format for ONNX export')

torch.onnx.export(
wrapped_model,
export_args,
file_name,
opset_version=14,
do_constant_folding=True,
input_names=input_names,
output_names=['output'],
dynamic_axes=dynamic_axes,
Copy link

Copilot AI Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For models >2GB, exporting without enabling external-data at export time can fail due to protobuf size limits (the subsequent convert_model_to_external_data() may never run). Pass the appropriate export-time option (e.g., use_external_data_format=use_external_data) so large-model exports succeed reliably.

Suggested change
dynamic_axes=dynamic_axes,
dynamic_axes=dynamic_axes,
use_external_data_format=use_external_data,

Copilot uses AI. Check for mistakes.
)
Comment on lines +377 to +393
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For models larger than ~2GB, torch.onnx.export(...) may fail before the later convert_model_to_external_data(...) step due to protobuf size limits, because the export itself still attempts to serialize initializers into the main ONNX file. For large-model support to be reliable, enable PyTorch's large/external-data export mode at export time (e.g., using the appropriate large_model / use_external_data_format option supported by your PyTorch version) rather than only converting after the fact.

Copilot uses AI. Check for mistakes.

# If using external data, convert to external data format
if use_external_data:
import onnx
from onnx.external_data_helper import convert_model_to_external_data

onnx_model = onnx.load(file_name)
external_data_path = model_name + '_data.bin'
convert_model_to_external_data(
onnx_model,
all_tensors_to_one_file=True,
location=external_data_path,
size_threshold=1024,
convert_attribute=False
)
onnx.save(onnx_model, file_name)
logger.info(f'Converted ONNX model to external data format: {external_data_path}')

# Clean up
del dummy_input
if torch.cuda.is_available():
torch.cuda.empty_cache()

return file_name

except Exception as e:
logger.error(f'Failed to export HuggingFace model to ONNX: {str(e)}')
logger.error(traceback.format_exc())
return ''
Loading
Loading