forked from qoherent/modrec-workflow
commented functions, added in type defintions
This commit is contained in:
parent
06bd4d4001
commit
bb4f739535
|
@ -42,12 +42,18 @@ jobs:
|
|||
mkdir -p data/recordings
|
||||
PYTHONPATH=. python scripts/dataset_building/data_gen.py --output-dir data/recordings
|
||||
echo "recordings produced successfully"
|
||||
|
||||
- name: Upload Recordings
|
||||
|
||||
- name: 📦 Zip and Upload Recordings
|
||||
run: |
|
||||
echo "📦 Zipping recordings..."
|
||||
zip -qr recordings.zip data/recordings
|
||||
shell: bash
|
||||
|
||||
- name: ⬆️ Upload zipped recordings
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: recordings
|
||||
path: data/recordings/**
|
||||
path: recordings.zip
|
||||
|
||||
- name: 2. Build HDF5 Dataset
|
||||
run: |
|
||||
|
@ -56,12 +62,14 @@ jobs:
|
|||
echo "datasets produced successfully"
|
||||
shell: bash
|
||||
|
||||
- name: Upload Dataset Artifacts
|
||||
- name: 📦 Zip Dataset
|
||||
run: zip -qr dataset.zip data/dataset
|
||||
|
||||
- name: 📤 Upload Zipped Dataset
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: dataset
|
||||
path: data/dataset/**
|
||||
|
||||
path: dataset.zip
|
||||
|
||||
- name: 3. Train Model
|
||||
env:
|
||||
|
@ -80,6 +88,9 @@ jobs:
|
|||
|
||||
|
||||
- name: 4. Convert to ONNX file
|
||||
env:
|
||||
NO_NNPACK: 1
|
||||
PYTORCH_NO_NNPACK: 1
|
||||
run: |
|
||||
mkdir -p onnx_files
|
||||
MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python scripts/onnx/convert_to_onnx.py
|
||||
|
|
|
@ -33,6 +33,24 @@ dataset:
|
|||
# Number of samples per generated recording
|
||||
recording_length: 1024
|
||||
|
||||
# Settings for each modulation scheme
|
||||
# Keys must match entries in `modulation_types`
|
||||
# - `num_bits_per_symbol`: how many bits each symbol encodes (e.g., 1 for BPSK, 4 for 16-QAM)
|
||||
# - `constellation_type`: type of modulation (e.g., "psk", "qam", "fsk", "ofdm")
|
||||
modulation_settings:
|
||||
bpsk:
|
||||
num_bits_per_symbol: 1
|
||||
constellation_type: psk
|
||||
qpsk:
|
||||
num_bits_per_symbol: 2
|
||||
constellation_type: psk
|
||||
qam16:
|
||||
num_bits_per_symbol: 4
|
||||
constellation_type: qam
|
||||
qam64:
|
||||
num_bits_per_symbol: 6
|
||||
constellation_type: qam
|
||||
|
||||
|
||||
|
||||
training:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Dict
|
||||
|
||||
import yaml
|
||||
|
||||
|
@ -24,6 +25,7 @@ class DataSetConfig:
|
|||
snr_step: int
|
||||
num_iterations: int
|
||||
recording_length: int
|
||||
modulation_settings: Dict[str, Dict[str, str]]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -35,7 +37,6 @@ class TrainingConfig:
|
|||
drop_rate: float
|
||||
drop_path_rate: float
|
||||
wd: int
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -2,25 +2,39 @@ from utils.data import Recording
|
|||
import numpy as np
|
||||
from utils.signal import block_generator
|
||||
import argparse
|
||||
import os
|
||||
from helpers.app_settings import get_app_settings
|
||||
|
||||
mods = {
|
||||
"bpsk": {"num_bits_per_symbol": 1, "constellation_type": "psk"},
|
||||
"qpsk": {"num_bits_per_symbol": 2, "constellation_type": "psk"},
|
||||
"qam16": {"num_bits_per_symbol": 4, "constellation_type": "qam"},
|
||||
"qam64": {"num_bits_per_symbol": 6, "constellation_type": "qam"},
|
||||
}
|
||||
settings = get_app_settings().dataset
|
||||
|
||||
mods = settings.modulation_settings
|
||||
|
||||
|
||||
def generate_modulated_signals(output_dir):
|
||||
settings = get_app_settings().dataset
|
||||
def generate_modulated_signals(output_dir: str) -> None:
|
||||
"""
|
||||
Generates modulated IQ signal recordings across specified modulation types and SNR values,
|
||||
adds AWGN noise, and saves the resulting samples as .npy files to the given output directory.
|
||||
|
||||
The function uses modulation parameters defined in app.yaml and supports modulation types like
|
||||
PSK and QAM through configurable constellation settings. The generated recordings are tagged
|
||||
with metadata such as modulation type, SNR, roll-off factor (beta), and samples-per-symbol (sps).
|
||||
|
||||
Parameters:
|
||||
output_dir (str): Path to the directory where .npy files will be saved.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
for modulation in settings.modulation_types:
|
||||
for snr in np.arange(settings.snr_start, settings.snr_end, settings.snr_step):
|
||||
for i in range(3):
|
||||
recording_length = settings.recording_length
|
||||
beta = settings.beta # the rolloff factor, can be changed to add variety
|
||||
sps = settings.sps # samples per symbol, or the relative bandwidth of the digital signal. Can also be changed.
|
||||
beta = (
|
||||
settings.beta
|
||||
) # the rolloff factor, can be changed to add variety
|
||||
sps = (
|
||||
settings.sps
|
||||
) # samples per symbol, or the relative bandwidth of the digital signal. Can also be changed.
|
||||
|
||||
# blocks don't directly take the string 'qpsk' so we use the dict 'mods' to get parameters
|
||||
constellation_type = mods[modulation]["constellation_type"]
|
||||
|
@ -64,17 +78,17 @@ def generate_modulated_signals(output_dir):
|
|||
|
||||
# view if you want
|
||||
# output_recording.view()
|
||||
|
||||
|
||||
# save to file
|
||||
output_recording.to_npy(path = output_dir) # optionally add path and filename parameters
|
||||
output_recording.to_npy(
|
||||
path=output_dir
|
||||
) # optionally add path and filename parameters
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = argparse.ArgumentParser(description="Generate modulated signal .npy files")
|
||||
p.add_argument(
|
||||
"--output-dir",
|
||||
default=".",
|
||||
help="Folder where .npy files will be saved"
|
||||
"--output-dir", default=".", help="Folder where .npy files will be saved"
|
||||
)
|
||||
args = p.parse_args()
|
||||
generate_modulated_signals(args.output_dir)
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import os, h5py, numpy as np
|
||||
from typing import List
|
||||
from utils.io import from_npy
|
||||
from split_dataset import split, split_recording
|
||||
from helpers.app_settings import get_app_settings
|
||||
from helpers.app_settings import DataSetConfig, get_app_settings
|
||||
|
||||
meta_dtype = np.dtype(
|
||||
[
|
||||
|
@ -23,7 +24,7 @@ info_dtype = np.dtype(
|
|||
)
|
||||
|
||||
|
||||
def write_hdf5_file(records, output_path, dataset_name="data"):
|
||||
def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data") -> str:
|
||||
"""
|
||||
Writes a list of records to an HDF5 file.
|
||||
Parameters:
|
||||
|
@ -52,7 +53,6 @@ def write_hdf5_file(records, output_path, dataset_name="data"):
|
|||
data_arr = np.stack([rec[0] for rec in records])
|
||||
dset = hf.create_dataset(dataset_name, data=data_arr, compression="gzip")
|
||||
|
||||
|
||||
mg = hf.create_group("metadata")
|
||||
mg.create_dataset("metadata", data=meta_arr, compression="gzip")
|
||||
|
||||
|
@ -74,9 +74,15 @@ def write_hdf5_file(records, output_path, dataset_name="data"):
|
|||
return output_path
|
||||
|
||||
|
||||
def complex_to_channel(data):
|
||||
def complex_to_channel(data: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert complex-valued IQ data of shape (1, N) to 2-channel real array of shape (2, N).
|
||||
Converts complex-valued IQ data of shape (1, N) to a 2-channel real array of shape (2, N).
|
||||
|
||||
Parameters:
|
||||
data (np.ndarray): Complex-valued array of shape (1, N)
|
||||
|
||||
Returns:
|
||||
np.ndarray: Real-valued array of shape (2, N) with separate real and imaginary channels
|
||||
"""
|
||||
assert np.iscomplexobj(data) # check if the data is in the form a+bi
|
||||
real = np.real(data[0]) # (N,)
|
||||
|
@ -85,14 +91,12 @@ def complex_to_channel(data):
|
|||
return stacked.astype(np.float32)
|
||||
|
||||
|
||||
def generate_datasets(cfg):
|
||||
def generate_datasets(cfg: DataSetConfig) -> tuple:
|
||||
"""
|
||||
Generates a dataset from a folder of .npy files and saves it to an HDF5 file
|
||||
|
||||
Parameters:
|
||||
path_to_recordings (str): Path to the folder containing .npy files
|
||||
output_path (str): Path to the output HDF5 file
|
||||
dataset_name (str): Name of the dataset in the HDF5 file (default: "data")
|
||||
cfg (DataSetConfig): Dataset configuration loaded from app.yaml
|
||||
|
||||
Returns:
|
||||
dset (h5py.Dataset): The created dataset object
|
||||
|
@ -137,8 +141,24 @@ def generate_datasets(cfg):
|
|||
def main():
|
||||
settings = get_app_settings()
|
||||
dataset_cfg = settings.dataset
|
||||
|
||||
print("📦 Generating training and validation datasets...")
|
||||
print(f" ➤ Slicing each recording into {dataset_cfg.num_slices} snippets")
|
||||
print(f" ➤ Train/Val split: {int(dataset_cfg.train_split * 100)}% / {int((1 - dataset_cfg.train_split) * 100)}%")
|
||||
print(f" ➤ Output directory: data/dataset\n")
|
||||
|
||||
train_path, val_path = generate_datasets(dataset_cfg)
|
||||
print(f"✅ Train: {train_path}\n✅ Val: {val_path}")
|
||||
|
||||
# Count number of samples in each file
|
||||
with h5py.File(train_path, "r") as f:
|
||||
num_train = f["training_data"].shape[0]
|
||||
with h5py.File(val_path, "r") as f:
|
||||
num_val = f["validation_data"].shape[0]
|
||||
|
||||
print("✅ Dataset generation complete!")
|
||||
print(f" 🔹 Training samples saved to: {train_path} ({num_train} samples)")
|
||||
print(f" 🔸 Validation samples saved to: {val_path} ({num_val} samples)")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,17 +1,27 @@
|
|||
import random
|
||||
from collections import defaultdict
|
||||
from typing import List, Tuple, Dict
|
||||
import numpy as np
|
||||
|
||||
|
||||
def split(dataset, train_frac=0.8, seed=42, label_key="modulation"):
|
||||
def split(
|
||||
dataset: List[Tuple[np.ndarray, Dict[str, any]]],
|
||||
train_frac: float,
|
||||
seed: int,
|
||||
label_key: str = "modulation"
|
||||
) -> Tuple[List[Tuple[np.ndarray, Dict[str, any]]], List[Tuple[np.ndarray, Dict[str, any]]]]:
|
||||
"""
|
||||
Splits a dataset into smaller datasets based on the specified lengths.
|
||||
Splits a dataset of modulated IQ signal recordings into training and validation subsets.
|
||||
|
||||
Parameters:
|
||||
dataset (list): The dataset to be split.
|
||||
lengths (list): A list of lengths for each split.
|
||||
dataset (list): List of tuples where each tuple contains:
|
||||
- np.ndarray: 2xN real array (channels x samples)
|
||||
- dict: Metadata for the sample
|
||||
train_frac (float): Fraction of the dataset to use for training (default: 0.8)
|
||||
seed (int): Random seed for reproducibility (default: 42)
|
||||
label_key (str): Metadata key to group by during splitting (default: "modulation")
|
||||
|
||||
Returns:
|
||||
list: A list of split datasets.
|
||||
tuple: Two lists of (np.ndarray, dict) pairs — (train_records, val_records)
|
||||
"""
|
||||
rec_buckets = defaultdict(list)
|
||||
for data, md in dataset:
|
||||
|
@ -50,15 +60,29 @@ def split(dataset, train_frac=0.8, seed=42, label_key="modulation"):
|
|||
return train_dataset, val_dataset
|
||||
|
||||
|
||||
def split_recording(recording_list, num_snippets):
|
||||
def split_recording(
|
||||
recording_list: List[Tuple[np.ndarray, Dict[str, any]]],
|
||||
num_snippets: int
|
||||
) -> List[Tuple[np.ndarray, Dict[str, any]]]:
|
||||
"""
|
||||
Splits a list of recordings into smaller chunks.
|
||||
Splits each full recording into a specified number of smaller snippets.
|
||||
|
||||
Each recording is a tuple of:
|
||||
- data (np.ndarray): A 2xN real-valued array representing I/Q signal data.
|
||||
- metadata (dict): Metadata describing the recording (e.g., modulation, SNR, etc.)
|
||||
|
||||
The split is typically done along the time axis (axis=1), dividing each (2, N)
|
||||
array into `num_snippets` contiguous chunks of shape (2, N // num_snippets).
|
||||
|
||||
Parameters:
|
||||
recording_list (list): List of recordings to be split
|
||||
recording_list (List[Tuple[np.ndarray, dict]]):
|
||||
List of (data, metadata) tuples to be split.
|
||||
num_snippets (int):
|
||||
Number of equal-length segments to divide each recording into.
|
||||
|
||||
Returns: yeah yeah
|
||||
list: List of split recordings
|
||||
Returns:
|
||||
List[Tuple[np.ndarray, dict]]:
|
||||
A flat list containing all resulting (snippet, metadata) pairs.
|
||||
"""
|
||||
snippet_list = []
|
||||
|
||||
|
|
|
@ -7,17 +7,20 @@ from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
|
|||
from helpers.app_settings import get_app_settings
|
||||
|
||||
|
||||
def convert_to_onnx(ckpt_path, fp16=False):
|
||||
|
||||
def convert_to_onnx(
|
||||
ckpt_path: str,
|
||||
fp16: bool=False
|
||||
) -> None :
|
||||
"""
|
||||
Convert a PyTorch model to ONNX format.
|
||||
|
||||
Parameters:
|
||||
model (torch.nn.Module): The PyTorch model to convert.
|
||||
input_shape (tuple): The shape of the input tensor.
|
||||
output_path (str): The path to save the converted ONNX model.
|
||||
fp16 (bool): 16 float point percision
|
||||
"""
|
||||
settings = get_app_settings()
|
||||
|
||||
|
||||
dataset_cfg = settings.dataset
|
||||
|
||||
in_channels = 2
|
||||
|
@ -32,9 +35,10 @@ def convert_to_onnx(ckpt_path, fp16=False):
|
|||
in_chans=in_channels,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
checkpoint = torch.load(
|
||||
ckpt_path, weights_only=True, map_location=torch.device("cpu")
|
||||
ckpt_path, weights_only=True, map_location=device
|
||||
)
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
|
||||
|
|
|
@ -2,38 +2,79 @@ import onnxruntime as ort
|
|||
import numpy as np
|
||||
from helpers.app_settings import get_app_settings
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
|
||||
def profile_onnx_model(path_to_onnx: str, num_runs: int = 100):
|
||||
# Set up session options
|
||||
def profile_onnx_model(path_to_onnx: str, num_runs: int = 100, warmup_runs: int = 5) -> None:
|
||||
"""
|
||||
Profiles an ONNX model by running inference multiple times and collecting performance data.
|
||||
|
||||
Prints session initialization time, provider used, average inference time (excluding warm-up),
|
||||
and parses the ONNX Runtime JSON trace to show the most expensive operation.
|
||||
|
||||
Parameters:
|
||||
path_to_onnx (str): Path to the ONNX model file.
|
||||
num_runs (int): Number of inference runs (including warm-ups).
|
||||
warmup_runs (int): Number of warm-up runs to skip from timing.
|
||||
"""
|
||||
# Session setup
|
||||
options = ort.SessionOptions()
|
||||
options.enable_profiling = True
|
||||
|
||||
# Enables cleanup of QuantizeLinear/DequantizeLinear node pairs (optional optimization)
|
||||
options.add_session_config_entry("session.enable_quant_qdq_cleanup", "1")
|
||||
|
||||
# Set workload type for efficiency (low scheduling priority)
|
||||
options.add_session_config_entry("ep.dynamic.workload_type", "Efficient")
|
||||
|
||||
# Create inference session on CPU
|
||||
session = ort.InferenceSession(path_to_onnx, sess_options=options, providers=["CPUExecutionProvider"])
|
||||
# Try GPU, then fallback to CPU
|
||||
try:
|
||||
start_time = time.time()
|
||||
session = ort.InferenceSession(
|
||||
path_to_onnx, sess_options=options, providers=["CUDAExecutionProvider"]
|
||||
)
|
||||
print("Running on the GPU")
|
||||
except Exception as e:
|
||||
session = ort.InferenceSession(
|
||||
path_to_onnx, sess_options=options, providers=["CPUExecutionProvider"]
|
||||
)
|
||||
print("Could not find GPU, running on CPU")
|
||||
|
||||
end_time = time.time()
|
||||
print(f"[Timing] Model load + session init time: {end_time - start_time:.4f} sec")
|
||||
print("Session providers:", session.get_providers())
|
||||
|
||||
# Get model input details
|
||||
# Prepare dummy input
|
||||
input_name = session.get_inputs()[0].name
|
||||
input_shape = session.get_inputs()[0].shape
|
||||
|
||||
# Generate dummy input data
|
||||
# If model expects dynamic shape (None), replace with fixed size (e.g. batch 1)
|
||||
input_shape = [dim if isinstance(dim, int) and dim > 0 else 1 for dim in input_shape]
|
||||
input_shape = [
|
||||
dim if isinstance(dim, int) and dim > 0 else 1
|
||||
for dim in session.get_inputs()[0].shape
|
||||
]
|
||||
input_data = np.random.randn(*input_shape).astype(np.float32)
|
||||
|
||||
# Run inference multiple times to collect profiling data
|
||||
for _ in range(num_runs):
|
||||
# Time multiple inferences (skip warm-up)
|
||||
times = []
|
||||
for i in range(num_runs):
|
||||
t0 = time.time()
|
||||
session.run(None, {input_name: input_data})
|
||||
t1 = time.time()
|
||||
if i >= warmup_runs:
|
||||
times.append(t1 - t0)
|
||||
|
||||
# End profiling and get profile file path
|
||||
avg_time = sum(times) / len(times)
|
||||
print(f"[Timing] Avg inference time (excluding {warmup_runs} warm-ups): {avg_time:.6f} sec")
|
||||
|
||||
# End profiling & parse JSON
|
||||
profile_file = session.end_profiling()
|
||||
print(f"Profiling saved to: {profile_file}")
|
||||
print(f"[Output] Profiling trace saved to: {profile_file}")
|
||||
|
||||
try:
|
||||
with open(profile_file, "r") as f:
|
||||
trace = json.load(f)
|
||||
nodes = [e for e in trace if e.get("cat") == "Node"]
|
||||
print(f"[Profile] Number of nodes executed: {len(nodes)}")
|
||||
if nodes:
|
||||
top = max(nodes, key=lambda x: x.get("dur", 0))
|
||||
print(f"[Profile] Most expensive op: {top['name']} — {top['dur'] / 1e6:.3f} ms")
|
||||
except Exception as e:
|
||||
print(f"[Warning] Failed to parse profiling JSON: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
settings = get_app_settings()
|
||||
|
|
|
@ -15,9 +15,12 @@ command = [
|
|||
"-m",
|
||||
"onnxruntime.tools.convert_onnx_models_to_ort",
|
||||
input_path,
|
||||
"--output_dir", "ort_files",
|
||||
"--optimization_style", optimization_style,
|
||||
"--target_platform", target_platform
|
||||
"--output_dir",
|
||||
"ort_files",
|
||||
"--optimization_style",
|
||||
optimization_style,
|
||||
"--target_platform",
|
||||
target_platform,
|
||||
]
|
||||
|
||||
# Run it
|
||||
|
|
92
scripts/training/plot_data.py
Normal file
92
scripts/training/plot_data.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import h5py
|
||||
from sklearn.metrics import classification_report
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
|
||||
from helpers.app_settings import get_app_settings
|
||||
from cm_plotter import plot_confusion_matrix
|
||||
|
||||
|
||||
def load_validation_data(h5_path:str ="data/datasets/validation.h5"):
|
||||
"""
|
||||
Loads validation data from an HDF5 file.
|
||||
|
||||
Returns:
|
||||
X_val: np.ndarray of shape (N, C, L)
|
||||
y_val: np.ndarray of shape (N,)
|
||||
class_names: list of class names
|
||||
"""
|
||||
with h5py.File(h5_path, "r") as f:
|
||||
X = f["X"][:] # shape: (N, C, L)
|
||||
y = f["y"][:] # shape: (N,)
|
||||
if "class_names" in f:
|
||||
class_names = [s.decode("utf-8") for s in f["class_names"][:]]
|
||||
else:
|
||||
class_names = [str(i) for i in np.unique(y)]
|
||||
return X, y, class_names
|
||||
|
||||
|
||||
def build_model_from_ckpt(ckpt_path: str, in_channels: int, num_classes: int) -> torch.nn.Module:
|
||||
"""
|
||||
Build and return a PyTorch model loaded from a checkpoint.
|
||||
"""
|
||||
model = RFClassifier(
|
||||
model=mobilenetv3(
|
||||
model_size="mobilenetv3_small_050",
|
||||
num_classes=num_classes,
|
||||
in_chans=in_channels
|
||||
)
|
||||
)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
checkpoint = torch.load(
|
||||
ckpt_path, weights_only=True, map_location=device
|
||||
)
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def evaluate_checkpoint(ckpt_path: str):
|
||||
"""
|
||||
Loads the model from checkpoint and evaluates it on a validation set.
|
||||
Prints classification metrics and plots a confusion matrix.
|
||||
"""
|
||||
# Load validation data
|
||||
X_val, y_true, class_names = load_validation_data()
|
||||
num_classes = len(class_names)
|
||||
in_channels = X_val.shape[1]
|
||||
|
||||
# Load model
|
||||
model = build_model_from_ckpt(ckpt_path, in_channels=in_channels, num_classes=num_classes)
|
||||
|
||||
# Inference
|
||||
y_pred = []
|
||||
with torch.no_grad():
|
||||
for x in X_val:
|
||||
x_tensor = torch.tensor(x[np.newaxis, ...], dtype=torch.float32)
|
||||
logits = model(x_tensor)
|
||||
pred = torch.argmax(logits, dim=1).item()
|
||||
y_pred.append(pred)
|
||||
|
||||
# Print classification report
|
||||
print("\nClassification Report:")
|
||||
print(classification_report(y_true, y_pred, target_names=class_names))
|
||||
|
||||
# Plot confusion matrix
|
||||
plot_confusion_matrix(
|
||||
y_true=np.array(y_true),
|
||||
y_pred=np.array(y_pred),
|
||||
classes=class_names,
|
||||
normalize=True,
|
||||
title="Normalized Confusion Matrix"
|
||||
)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
settings = get_app_settings()
|
||||
ckpt_path = os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
|
||||
evaluate_checkpoint(ckpt_path)
|
|
@ -1,4 +1,5 @@
|
|||
import sys, os
|
||||
|
||||
os.environ["NNPACK"] = "0"
|
||||
import lightning as L
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
|
@ -8,6 +9,7 @@ import torchmetrics
|
|||
from helpers.app_settings import get_app_settings
|
||||
from modulation_dataset import ModulationH5Dataset
|
||||
import mobilenetv3
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
|
||||
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
|
||||
|
@ -15,6 +17,7 @@ if project_root not in sys.path:
|
|||
sys.path.insert(0, project_root)
|
||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||
|
||||
|
||||
class CustomProgressBar(TQDMProgressBar):
|
||||
def __init__(self):
|
||||
super().__init__(refresh_rate=128) # update every batch
|
||||
|
@ -27,7 +30,6 @@ def train_model():
|
|||
batch_size = training_cfg.batch_size
|
||||
epochs = training_cfg.epochs
|
||||
|
||||
|
||||
train_data = "data/dataset/train.h5"
|
||||
val_data = "data/dataset/val.h5"
|
||||
|
||||
|
@ -115,7 +117,6 @@ def train_model():
|
|||
drop_path_rate=hparams["drop_path_rate"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath="checkpoint_files",
|
||||
|
|
Loading…
Reference in New Issue
Block a user