forked from qoherent/modrec-workflow
fixed paths
This commit is contained in:
parent
9326505fca
commit
1dc5383162
19
.gitignore
vendored
19
.gitignore
vendored
|
@ -1,9 +1,22 @@
|
||||||
.venv/
|
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.ckpt
|
*.ckpt
|
||||||
*.ipynb
|
*.ipynb
|
||||||
*.onnx
|
*.onnx
|
||||||
*.json
|
*.json
|
||||||
*.h5
|
*.h5
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# Visual Studio Code
|
||||||
|
.vscode/
|
|
@ -68,6 +68,7 @@ jobs:
|
||||||
NO_NNPACK: 1
|
NO_NNPACK: 1
|
||||||
PYTORCH_NO_NNPACK: 1
|
PYTORCH_NO_NNPACK: 1
|
||||||
run: |
|
run: |
|
||||||
|
mkdir -p checkpoint_files
|
||||||
PYTHONPATH=. python scripts/training/train.py
|
PYTHONPATH=. python scripts/training/train.py
|
||||||
echo "training model"
|
echo "training model"
|
||||||
|
|
||||||
|
@ -79,6 +80,8 @@ jobs:
|
||||||
|
|
||||||
- name: 4. Convert to ONNX file
|
- name: 4. Convert to ONNX file
|
||||||
run: |
|
run: |
|
||||||
|
mkdir -p onnx_scripts
|
||||||
|
mkdir -p onnx_files
|
||||||
MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python onnx_scripts/convert_to_onnx.py
|
MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python onnx_scripts/convert_to_onnx.py
|
||||||
echo "building inference app"
|
echo "building inference app"
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from helpers.app_settings import get_app_settings
|
from helpers.app_settings import get_app_settings
|
||||||
from onnx_files import ONNX_DIR
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
def profile_onnx_model(path_to_onnx: str, num_runs: int = 100):
|
def profile_onnx_model(path_to_onnx: str, num_runs: int = 100):
|
||||||
|
@ -38,5 +37,5 @@ def profile_onnx_model(path_to_onnx: str, num_runs: int = 100):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
settings = get_app_settings()
|
settings = get_app_settings()
|
||||||
output_path = os.path.join(ONNX_DIR, f"{settings.inference.onnx_model_filename}.onnx")
|
output_path = os.path.join("onnx_files", "convert_to_onnx.onnx")
|
||||||
profile_onnx_model(output_path)
|
profile_onnx_model(output_path)
|
||||||
|
|
|
@ -15,7 +15,7 @@ mods = {
|
||||||
def generate_modulated_signals(output_dir):
|
def generate_modulated_signals(output_dir):
|
||||||
for modulation in ["bpsk", "qpsk", "qam16", "qam64"]:
|
for modulation in ["bpsk", "qpsk", "qam16", "qam64"]:
|
||||||
for snr in np.arange(-6, 13, 3):
|
for snr in np.arange(-6, 13, 3):
|
||||||
for i in range(100):
|
for i in range(3):
|
||||||
recording_length = 1024
|
recording_length = 1024
|
||||||
beta = 0.3 # the rolloff factor, can be changed to add variety
|
beta = 0.3 # the rolloff factor, can be changed to add variety
|
||||||
sps = 4 # samples per symbol, or the relative bandwidth of the digital signal. Can also be changed.
|
sps = 4 # samples per symbol, or the relative bandwidth of the digital signal. Can also be changed.
|
||||||
|
|
|
@ -29,11 +29,9 @@ def train_model():
|
||||||
batch_size = training_cfg.batch_size
|
batch_size = training_cfg.batch_size
|
||||||
epochs = training_cfg.epochs
|
epochs = training_cfg.epochs
|
||||||
|
|
||||||
checkpoint_dir = training_cfg.checkpoint_dir
|
|
||||||
checkpoint_filename = training_cfg.checkpoint_filename
|
|
||||||
|
|
||||||
train_data = f"{dataset_cfg.output_dir}/train.h5"
|
train_data = "data/dataset/train.h5"
|
||||||
val_data = f"{dataset_cfg.output_dir}/val.h5"
|
val_data = "data/dataset/val.h5"
|
||||||
|
|
||||||
label = "modulation"
|
label = "modulation"
|
||||||
|
|
||||||
|
@ -122,8 +120,8 @@ def train_model():
|
||||||
|
|
||||||
|
|
||||||
checkpoint_callback = ModelCheckpoint(
|
checkpoint_callback = ModelCheckpoint(
|
||||||
dirpath=checkpoint_dir,
|
dirpath="checkpoint_files",
|
||||||
filename=checkpoint_filename,
|
filename="inference_recognition_model.ckpt",
|
||||||
save_top_k=1,
|
save_top_k=1,
|
||||||
verbose=True, # Disable checkpoint verbose output
|
verbose=True, # Disable checkpoint verbose output
|
||||||
monitor="val_acc",
|
monitor="val_acc",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user