Forked from qoherent/modrec-workflow
Added profiling for the ONNX model
This commit is contained in:
parent
b3d17f804c
commit
3a32a83c34
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -5,3 +5,4 @@ __pycache__/
|
||||||
*.ckpt
|
*.ckpt
|
||||||
*.ipynb
|
*.ipynb
|
||||||
*.onnx
|
*.onnx
|
||||||
|
*.json
|
|
@ -67,7 +67,7 @@ jobs:
|
||||||
|
|
||||||
- name: 3. Convert to ONNX file
|
- name: 3. Convert to ONNX file
|
||||||
run: |
|
run: |
|
||||||
PYTHONPATH=. python convert_to_onnx.py
|
PYTHONPATH=. python onnx_scripts/convert_to_onnx.py
|
||||||
echo "building inference app"
|
echo "building inference app"
|
||||||
|
|
||||||
- name: Upload ONNX file
|
- name: Upload ONNX file
|
||||||
|
@ -76,11 +76,23 @@ jobs:
|
||||||
name: ria-demo-onnx
|
name: ria-demo-onnx
|
||||||
path: onnx_files/inference_recognition_model.onnx
|
path: onnx_files/inference_recognition_model.onnx
|
||||||
|
|
||||||
|
- name: 4. Profile ONNX model
|
||||||
|
run: |
|
||||||
|
PYTHONPATH=. python onnx_scripts/profile_onnx.py
|
||||||
|
|
||||||
|
- name: Upload JSON profiling data
|
||||||
|
uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
name: profile-data
|
||||||
|
path: '**/onnxruntime_profile_*.json'
|
||||||
|
|
||||||
- name: 5. Convert to ORT file
|
- name: 5. Convert to ORT file
|
||||||
run: |
|
run: |
|
||||||
python -m onnxruntime.tools.convert_onnx_models_to_ort \
|
python -m onnxruntime.tools.convert_onnx_models_to_ort \
|
||||||
--input /onnx_files/inference_recognition_model.onnx \
|
--input /onnx_files/inference_recognition_model.onnx \
|
||||||
--output /ort_files/inference_recognition_model.ort \
|
--output /ort_files/inference_recognition_model.ort \
|
||||||
|
--optimization_style Fixed \
|
||||||
|
--target_platform amd64
|
||||||
|
|
||||||
- name: Upload ORT file
|
- name: Upload ORT file
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v3
|
||||||
|
|
42
onnx_scripts/profile_onnx.py
Normal file
42
onnx_scripts/profile_onnx.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
import onnxruntime as ort
|
||||||
|
import numpy as np
|
||||||
|
from helpers.app_settings import get_app_settings
|
||||||
|
from onnx_files import ONNX_DIR
|
||||||
|
import os
|
||||||
|
|
||||||
|
def profile_onnx_model(path_to_onnx: str, num_runs: int = 100) -> str:
    """Profile an ONNX model on CPU and write an onnxruntime profiling trace.

    Runs ``num_runs`` inference passes with random dummy input so the trace
    captures steady-state per-node timings, then stops the profiler.

    Args:
        path_to_onnx: Path to the ``.onnx`` model file to profile.
        num_runs: Number of inference passes to record (must be >= 1).

    Returns:
        Path to the ``onnxruntime_profile_*.json`` trace file written by
        onnxruntime (previously this path was only printed, not returned).

    Raises:
        ValueError: If ``num_runs`` is less than 1 — no profiling data
            would be collected.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    # Enable onnxruntime's built-in profiler; it writes a JSON trace
    # when end_profiling() is called.
    options = ort.SessionOptions()
    options.enable_profiling = True

    # Enables cleanup of QuantizeLinear/DequantizeLinear node pairs
    # (optional optimization).
    options.add_session_config_entry("session.enable_quant_qdq_cleanup", "1")

    # Low scheduling priority so profiling runs don't contend with other work.
    options.add_session_config_entry("ep.dynamic.workload_type", "Efficient")

    # CPU-only session keeps the profile comparable across machines.
    session = ort.InferenceSession(
        path_to_onnx, sess_options=options, providers=["CPUExecutionProvider"]
    )
    print("Session providers:", session.get_providers())

    # Build dummy input from the model's declared input shape, replacing any
    # dynamic/symbolic dimension (None, str, or <= 0) with a fixed size of 1
    # (e.g. batch size 1).
    model_input = session.get_inputs()[0]
    input_name = model_input.name
    input_shape = [
        dim if isinstance(dim, int) and dim > 0 else 1 for dim in model_input.shape
    ]
    input_data = np.random.randn(*input_shape).astype(np.float32)

    # Run inference multiple times to collect profiling data.
    for _ in range(num_runs):
        session.run(None, {input_name: input_data})

    # End profiling; onnxruntime returns the path of the JSON trace it wrote.
    profile_file = session.end_profiling()
    print(f"Profiling saved to: {profile_file}")
    return profile_file
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Resolve the exported ONNX model's location from the app settings,
    # then profile it with the default number of runs.
    app_settings = get_app_settings()
    model_filename = f"{app_settings.inference.onnx_model_filename}.onnx"
    profile_onnx_model(os.path.join(ONNX_DIR, model_filename))
|
|
@ -8,4 +8,5 @@ scikit_learn
|
||||||
timm
|
timm
|
||||||
torch
|
torch
|
||||||
onnx
|
onnx
|
||||||
|
onnxruntime
|
||||||
./wheel/utils-0.1.2.dev0-py3-none-any.whl
|
./wheel/utils-0.1.2.dev0-py3-none-any.whl
|
Loading…
Reference in New Issue
Block a user