fixed paths

This commit is contained in:
Liyu Xiao 2025-06-13 15:03:42 -04:00
parent 9326505fca
commit 1dc5383162
5 changed files with 25 additions and 12 deletions

17
.gitignore vendored
View File

@ -1,5 +1,3 @@
.venv/
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
*.ckpt *.ckpt
@ -7,3 +5,18 @@ __pycache__/
*.onnx *.onnx
*.json *.json
*.h5 *.h5
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# PyCharm
.idea/
# Visual Studio Code
.vscode/

View File

@ -68,6 +68,7 @@ jobs:
NO_NNPACK: 1 NO_NNPACK: 1
PYTORCH_NO_NNPACK: 1 PYTORCH_NO_NNPACK: 1
run: | run: |
mkdir -p checkpoint_files
PYTHONPATH=. python scripts/training/train.py PYTHONPATH=. python scripts/training/train.py
echo "training model" echo "training model"
@ -79,6 +80,8 @@ jobs:
- name: 4. Convert to ONNX file - name: 4. Convert to ONNX file
run: | run: |
mkdir -p onnx_scripts
mkdir -p onnx_files
MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python onnx_scripts/convert_to_onnx.py MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python onnx_scripts/convert_to_onnx.py
echo "building inference app" echo "building inference app"

View File

@ -1,7 +1,6 @@
import onnxruntime as ort import onnxruntime as ort
import numpy as np import numpy as np
from helpers.app_settings import get_app_settings from helpers.app_settings import get_app_settings
from onnx_files import ONNX_DIR
import os import os
def profile_onnx_model(path_to_onnx: str, num_runs: int = 100): def profile_onnx_model(path_to_onnx: str, num_runs: int = 100):
@ -38,5 +37,5 @@ def profile_onnx_model(path_to_onnx: str, num_runs: int = 100):
if __name__ == "__main__": if __name__ == "__main__":
settings = get_app_settings() settings = get_app_settings()
output_path = os.path.join(ONNX_DIR, f"{settings.inference.onnx_model_filename}.onnx") output_path = os.path.join("onnx_files", "convert_to_onnx.onnx")
profile_onnx_model(output_path) profile_onnx_model(output_path)

View File

@ -15,7 +15,7 @@ mods = {
def generate_modulated_signals(output_dir): def generate_modulated_signals(output_dir):
for modulation in ["bpsk", "qpsk", "qam16", "qam64"]: for modulation in ["bpsk", "qpsk", "qam16", "qam64"]:
for snr in np.arange(-6, 13, 3): for snr in np.arange(-6, 13, 3):
for i in range(100): for i in range(3):
recording_length = 1024 recording_length = 1024
beta = 0.3 # the rolloff factor, can be changed to add variety beta = 0.3 # the rolloff factor, can be changed to add variety
sps = 4 # samples per symbol, or the relative bandwidth of the digital signal. Can also be changed. sps = 4 # samples per symbol, or the relative bandwidth of the digital signal. Can also be changed.

View File

@ -29,11 +29,9 @@ def train_model():
batch_size = training_cfg.batch_size batch_size = training_cfg.batch_size
epochs = training_cfg.epochs epochs = training_cfg.epochs
checkpoint_dir = training_cfg.checkpoint_dir
checkpoint_filename = training_cfg.checkpoint_filename
train_data = f"{dataset_cfg.output_dir}/train.h5" train_data = "data/dataset/train.h5"
val_data = f"{dataset_cfg.output_dir}/val.h5" val_data = "data/dataset/val.h5"
label = "modulation" label = "modulation"
@ -122,8 +120,8 @@ def train_model():
checkpoint_callback = ModelCheckpoint( checkpoint_callback = ModelCheckpoint(
dirpath=checkpoint_dir, dirpath="checkpoint_files",
filename=checkpoint_filename, filename="inference_recognition_model.ckpt",
save_top_k=1, save_top_k=1,
verbose=True, # Disable checkpoint verbose output verbose=True, # Disable checkpoint verbose output
monitor="val_acc", monitor="val_acc",