added updates to confusion matrix/riahub secrets
All checks were successful
RIA Hub Workflow Demo / ria-demo (push) Successful in 2m45s
All checks were successful
RIA Hub Workflow Demo / ria-demo (push) Successful in 2m45s
This commit is contained in:
parent
3437512d7c
commit
ea6beda81b
|
@ -34,10 +34,14 @@ jobs:
|
|||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install dependencies
|
||||
- name: Install dependencies (incl. RIA Hub utils)
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install \
|
||||
--index-url "https://${{ secrets.RIAHUB_USER }}:${{ secrets.RIAHUB_TOKEN }}@git.riahub.ai/api/packages/qoherent/pypi/simple/" \
|
||||
utils \
|
||||
-r requirements.txt
|
||||
|
||||
|
||||
|
||||
- name: 1. Generate Recordings
|
||||
|
@ -89,7 +93,7 @@ jobs:
|
|||
path: checkpoint_files/*
|
||||
|
||||
|
||||
- name: 4. Convert to ONNX file
|
||||
- name: 5. Convert to ONNX file
|
||||
env:
|
||||
NO_NNPACK: 1
|
||||
PYTORCH_NO_NNPACK: 1
|
||||
|
@ -104,11 +108,7 @@ jobs:
|
|||
name: onnx-file
|
||||
path: onnx_files/inference_recognition_model.onnx
|
||||
|
||||
- name: List checkpoint directory
|
||||
run: ls -lh onnx_files
|
||||
|
||||
|
||||
- name: 5. Profile ONNX model
|
||||
- name: 6. Profile ONNX model
|
||||
run: |
|
||||
PYTHONPATH=. python scripts/onnx/profile_onnx.py
|
||||
|
||||
|
@ -118,7 +118,7 @@ jobs:
|
|||
name: profile-data
|
||||
path: '**/onnxruntime_profile_*.json'
|
||||
|
||||
- name: 6. Convert to ORT file
|
||||
- name: 7. Convert to ORT file
|
||||
run: |
|
||||
PYTHONPATH=. python scripts/ort/convert_to_ort.py
|
||||
|
||||
|
|
11
README.md
11
README.md
|
@ -23,6 +23,17 @@ dataset:
|
|||
val_split : 0.2
|
||||
```
|
||||
|
||||
### Configure GitHub Secrets
|
||||
|
||||
Before running the pipeline, add the following repository secrets in GitHub (Settings → Secrets and variables → Actions):
|
||||
|
||||
- **RIAHUB_USER**: Your RIA Hub username.
|
||||
- **RIAHUB_TOKEN**: RIA Hub access token with `read:packages` scope (from your RIA Hub account **Settings → Access Tokens**).
|
||||
- **CLONER_TOKEN**: Personal access token for `stark_cloner_bot` with `read_repository` scope (from your on-prem Git server user settings).
|
||||
|
||||
Once secrets are configured, you can run the pipeline:
|
||||
|
||||
|
||||
3. Run the Pipeline
|
||||
Once you update the changes to app.yaml, you can make any push or pull to your repo to start running the workflow
|
||||
|
||||
|
|
|
@ -13,9 +13,12 @@ command = [
|
|||
"-m",
|
||||
"onnxruntime.tools.convert_onnx_models_to_ort",
|
||||
input_path,
|
||||
"--output_dir", "ort_files",
|
||||
"--optimization_style", optimization_style,
|
||||
"--target_platform", target_platform,
|
||||
"--output_dir",
|
||||
"ort_files",
|
||||
"--optimization_style",
|
||||
optimization_style,
|
||||
"--target_platform",
|
||||
target_platform,
|
||||
]
|
||||
|
||||
# Run the command
|
||||
|
|
|
@ -2,6 +2,7 @@ import os
|
|||
import torch
|
||||
import numpy as np
|
||||
from sklearn.metrics import classification_report
|
||||
|
||||
os.environ["NNPACK"] = "0"
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
@ -13,13 +14,11 @@ from scripts.training.modulation_dataset import ModulationH5Dataset
|
|||
|
||||
def load_validation_data():
|
||||
val_dataset = ModulationH5Dataset(
|
||||
"data/dataset/val.h5",
|
||||
label_name="modulation",
|
||||
data_key="validation_data"
|
||||
"data/dataset/val.h5", label_name="modulation", data_key="validation_data"
|
||||
)
|
||||
|
||||
X = np.stack([x.numpy() for x, _ in val_dataset]) # shape: (N, C, L)
|
||||
y = np.array([y.item() for _, y in val_dataset]) # shape: (N,)
|
||||
y = np.array([y.item() for _, y in val_dataset]) # shape: (N,)
|
||||
class_names = list(val_dataset.label_encoder.classes_)
|
||||
|
||||
return X, y, class_names
|
||||
|
@ -72,16 +71,71 @@ def evaluate_checkpoint(ckpt_path: str):
|
|||
|
||||
# Print classification report
|
||||
print("\nClassification Report:")
|
||||
print(classification_report(y_true, y_pred, target_names=class_names, zero_division=0))
|
||||
print(
|
||||
classification_report(y_true, y_pred, target_names=class_names, zero_division=0)
|
||||
)
|
||||
|
||||
# Plot confusion matrix
|
||||
plot_confusion_matrix(
|
||||
plot_confusion_matrix_with_counts(
|
||||
y_true=np.array(y_true),
|
||||
y_pred=np.array(y_pred),
|
||||
classes=class_names,
|
||||
normalize=True,
|
||||
title="Normalized Confusion Matrix",
|
||||
)
|
||||
|
||||
|
||||
def plot_confusion_matrix_with_counts(
|
||||
y_true: np.ndarray,
|
||||
y_pred: np.ndarray,
|
||||
classes: list[str],
|
||||
normalize: bool = True,
|
||||
title: str = "Confusion Matrix (counts and normalized)",
|
||||
) -> None:
|
||||
"""
|
||||
Plot a confusion matrix showing both raw counts and (optionally) normalized values.
|
||||
|
||||
Args:
|
||||
y_true: true labels (integers 0..C-1)
|
||||
y_pred: predicted labels (same shape as y_true)
|
||||
classes: list of class‐name strings in index order
|
||||
normalize: if True, each row is normalized to sum=1
|
||||
title: title for the plot
|
||||
"""
|
||||
# 1) build raw CM
|
||||
C = len(classes)
|
||||
cm = np.zeros((C, C), dtype=int)
|
||||
for t, p in zip(y_true, y_pred):
|
||||
cm[t, p] += 1
|
||||
|
||||
# 2) normalize if requested
|
||||
if normalize:
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
cm_norm = cm.astype(float) / cm.sum(axis=1)[:, None]
|
||||
cm_norm = np.nan_to_num(cm_norm)
|
||||
else:
|
||||
cm_norm = cm
|
||||
|
||||
# 3) plot
|
||||
fig, ax = plt.subplots(figsize=(8, 8))
|
||||
im = ax.imshow(cm_norm, interpolation="nearest")
|
||||
ax.set_title(title)
|
||||
ax.set_xlabel("Predicted label")
|
||||
ax.set_ylabel("True label")
|
||||
ax.set_xticks(np.arange(C))
|
||||
ax.set_yticks(np.arange(C))
|
||||
ax.set_xticklabels(classes, rotation=45, ha="right")
|
||||
ax.set_yticklabels(classes)
|
||||
|
||||
# 4) annotate
|
||||
for i in range(C):
|
||||
for j in range(C):
|
||||
count = cm[i, j]
|
||||
val = cm_norm[i, j]
|
||||
txt = f"{count}\n{val:.2f}"
|
||||
ax.text(j, i, txt, ha="center", va="center")
|
||||
|
||||
fig.colorbar(im, ax=ax, label="Normalized value" if normalize else "Count")
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user