# modrec-workflow/scripts/onnx/profile_onnx.py
import json
import os
import time

import numpy as np
import onnxruntime as ort

from helpers.app_settings import get_app_settings


def profile_onnx_model(
    path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5
) -> None:
    """
    Profile an ONNX model by running inference multiple times and collecting
    performance data.

    Prints session initialization time, the execution provider in use, and the
    average inference time (excluding warm-up runs), then parses the ONNX
    Runtime JSON trace to report the most expensive operation.

    Parameters:
        path_to_onnx (str): Path to the ONNX model file.
        num_runs (int): Total number of inference runs, including warm-ups.
        warmup_runs (int): Number of initial runs excluded from timing.
    """
    # Session setup
    options = ort.SessionOptions()
    options.enable_profiling = True
    options.add_session_config_entry("session.enable_quant_qdq_cleanup", "1")
    options.add_session_config_entry("ep.dynamic.workload_type", "Efficient")
    # Try GPU first, then fall back to CPU
    start_time = time.time()
    try:
        session = ort.InferenceSession(
            path_to_onnx, sess_options=options, providers=["CUDAExecutionProvider"]
        )
        print("Running on the GPU")
    except Exception as e:
        print(f"Could not create a CUDA session ({e}); falling back to CPU")
        session = ort.InferenceSession(
            path_to_onnx, sess_options=options, providers=["CPUExecutionProvider"]
        )
    end_time = time.time()
    print(f"[Timing] Model load + session init time: {end_time - start_time:.4f} sec")
    print("Session providers:", session.get_providers())
    # Prepare a dummy input, replacing dynamic/symbolic dims with 1
    input_name = session.get_inputs()[0].name
    input_shape = [
        dim if isinstance(dim, int) and dim > 0 else 1
        for dim in session.get_inputs()[0].shape
    ]
    input_data = np.random.randn(*input_shape).astype(np.float32)
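    # This assumes a single float32 input. If the model expects another dtype
    # (e.g. int64 token ids), the dummy tensor must match, e.g.
    # np.random.randint(0, 100, size=input_shape, dtype=np.int64).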
    # Time multiple inferences, skipping the warm-up runs
    times = []
    for i in range(num_runs):
        t0 = time.time()
        session.run(None, {input_name: input_data})
        t1 = time.time()
        if i >= warmup_runs:
            times.append(t1 - t0)
    if not times:
        raise ValueError("num_runs must be greater than warmup_runs to collect timings")
    avg_time = sum(times) / len(times)
    print(
        f"[Timing] Avg inference time (excluding {warmup_runs} warm-ups): {avg_time:.6f} sec"
    )
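    # Optional: tail latency is often more telling than the mean. A minimal
    # sketch using NumPy percentiles:
    p50, p95 = np.percentile(times, [50, 95])
    print(f"[Timing] p50: {p50:.6f} sec | p95: {p95:.6f} sec")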
    # End profiling and parse the JSON trace
    profile_file = session.end_profiling()
    print(f"[Output] Profiling trace saved to: {profile_file}")
    try:
        with open(profile_file, "r") as f:
            trace = json.load(f)
        nodes = [e for e in trace if e.get("cat") == "Node"]
        print(f"[Profile] Number of nodes executed: {len(nodes)}")
        if nodes:
            top = max(nodes, key=lambda x: x.get("dur", 0))
            # "dur" is in microseconds in the trace, so divide by 1e3 for ms
            print(
                f"[Profile] Most expensive op: {top['name']} ({top['dur'] / 1e3:.3f} ms)"
            )
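            # Optional sketch: total time per op type often matters more than
            # the single slowest node. Each node event records its operator
            # under args["op_name"] (hedged: trace fields can vary across
            # onnxruntime versions).
            per_op: dict = {}
            for n in nodes:
                op = n.get("args", {}).get("op_name", "unknown")
                per_op[op] = per_op.get(op, 0) + n.get("dur", 0)
            for op, dur in sorted(per_op.items(), key=lambda kv: kv[1], reverse=True)[:5]:
                print(f"[Profile] {op}: {dur / 1e3:.3f} ms total")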
    except Exception as e:
        print(f"[Warning] Failed to parse profiling JSON: {e}")

if __name__ == "__main__":
    settings = get_app_settings()
    model_path = os.path.join("onnx_files", "inference_recognition_model.onnx")
    profile_onnx_model(model_path)
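    # The profiling trace is in Chrome trace format; it can be inspected in
    # chrome://tracing or at https://ui.perfetto.dev.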