import os

import numpy as np
import onnxruntime as ort

from helpers.app_settings import get_app_settings
from onnx_files import ONNX_DIR


def profile_onnx_model(path_to_onnx: str, num_runs: int = 100):
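    """Profile an ONNX model with ONNX Runtime's built-in profiler.

    Runs the model num_runs times on the CPU execution provider with random
    input, then writes a JSON trace and prints its location.
    """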
    # Set up session options
    options = ort.SessionOptions()
    options.enable_profiling = True
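    # end_profiling() will later write a JSON timeline in the chrome-tracing
    # format, viewable in chrome://tracing or https://ui.perfetto.dev.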

    # Enables cleanup of QuantizeLinear/DequantizeLinear node pairs (optional optimization)
    options.add_session_config_entry("session.enable_quant_qdq_cleanup", "1")

    # Set workload type for efficiency (low scheduling priority)
    options.add_session_config_entry("ep.dynamic.workload_type", "Efficient")

    # Create inference session on CPU
    session = ort.InferenceSession(
        path_to_onnx, sess_options=options, providers=["CPUExecutionProvider"]
    )
    print("Session providers:", session.get_providers())

    # Get model input details
    input_name = session.get_inputs()[0].name
    input_shape = session.get_inputs()[0].shape

    # Generate dummy input data.
    # Dynamic dimensions come back as None or symbolic strings; pin them to 1.
    input_shape = [dim if isinstance(dim, int) and dim > 0 else 1 for dim in input_shape]
    input_data = np.random.randn(*input_shape).astype(np.float32)
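    # NOTE: this assumes a single float32 input; adjust the dtype or feed dict
    # if the model's signature differs.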

    # Run inference multiple times to collect profiling data
    for _ in range(num_runs):
        session.run(None, {input_name: input_data})
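    # The first iteration usually includes one-off allocations, so treat it as
    # a warm-up outlier when reading the trace.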

    # End profiling and get the profile file path
    profile_file = session.end_profiling()
    print(f"Profiling saved to: {profile_file}")


if __name__ == "__main__":
    settings = get_app_settings()
    output_path = os.path.join(ONNX_DIR, f"{settings.inference.onnx_model_filename}.onnx")
    profile_onnx_model(output_path)