forked from qoherent/modrec-workflow
fixed paths
This commit is contained in:
parent
9326505fca
commit
1dc5383162
19
.gitignore
vendored
19
.gitignore
vendored
|
@ -1,9 +1,22 @@
|
|||
.venv/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.ckpt
|
||||
*.ipynb
|
||||
*.onnx
|
||||
*.json
|
||||
*.h5
|
||||
*.h5
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
|
||||
# Visual Studio Code
|
||||
.vscode/
|
|
@ -68,6 +68,7 @@ jobs:
|
|||
NO_NNPACK: 1
|
||||
PYTORCH_NO_NNPACK: 1
|
||||
run: |
|
||||
mkdir -p checkpoint_files
|
||||
PYTHONPATH=. python scripts/training/train.py
|
||||
echo "training model"
|
||||
|
||||
|
@ -79,6 +80,8 @@ jobs:
|
|||
|
||||
- name: 4. Convert to ONNX file
|
||||
run: |
|
||||
mkdir -p onnx_scripts
|
||||
mkdir -p onnx_files
|
||||
MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python onnx_scripts/convert_to_onnx.py
|
||||
echo "building inference app"
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import onnxruntime as ort
|
||||
import numpy as np
|
||||
from helpers.app_settings import get_app_settings
|
||||
from onnx_files import ONNX_DIR
|
||||
import os
|
||||
|
||||
def profile_onnx_model(path_to_onnx: str, num_runs: int = 100):
|
||||
|
@ -38,5 +37,5 @@ def profile_onnx_model(path_to_onnx: str, num_runs: int = 100):
|
|||
|
||||
if __name__ == "__main__":
|
||||
settings = get_app_settings()
|
||||
output_path = os.path.join(ONNX_DIR, f"{settings.inference.onnx_model_filename}.onnx")
|
||||
output_path = os.path.join("onnx_files", "convert_to_onnx.onnx")
|
||||
profile_onnx_model(output_path)
|
||||
|
|
|
@ -15,7 +15,7 @@ mods = {
|
|||
def generate_modulated_signals(output_dir):
|
||||
for modulation in ["bpsk", "qpsk", "qam16", "qam64"]:
|
||||
for snr in np.arange(-6, 13, 3):
|
||||
for i in range(100):
|
||||
for i in range(3):
|
||||
recording_length = 1024
|
||||
beta = 0.3 # the rolloff factor, can be changed to add variety
|
||||
sps = 4 # samples per symbol, or the relative bandwidth of the digital signal. Can also be changed.
|
||||
|
|
|
@ -29,11 +29,9 @@ def train_model():
|
|||
batch_size = training_cfg.batch_size
|
||||
epochs = training_cfg.epochs
|
||||
|
||||
checkpoint_dir = training_cfg.checkpoint_dir
|
||||
checkpoint_filename = training_cfg.checkpoint_filename
|
||||
|
||||
train_data = f"{dataset_cfg.output_dir}/train.h5"
|
||||
val_data = f"{dataset_cfg.output_dir}/val.h5"
|
||||
train_data = "data/dataset/train.h5"
|
||||
val_data = "data/dataset/val.h5"
|
||||
|
||||
label = "modulation"
|
||||
|
||||
|
@ -122,8 +120,8 @@ def train_model():
|
|||
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=checkpoint_dir,
|
||||
filename=checkpoint_filename,
|
||||
dirpath="checkpoint_files",
|
||||
filename="inference_recognition_model.ckpt",
|
||||
save_top_k=1,
|
||||
verbose=True, # Disable checkpoint verbose output
|
||||
monitor="val_acc",
|
||||
|
|
Loading…
Reference in New Issue
Block a user