forked from qoherent/modrec-workflow
Final changes: training now runs for 5 epochs, and the workflow converts the model to ORT as its last step
This commit is contained in:
parent
015f6d4db5
commit
489fecf113
|
@ -37,12 +37,19 @@ jobs:
|
|||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
|
||||
# - name: 1. Build HDF5 Dataset
|
||||
# run: |
|
||||
# mkdir -p data/dataset
|
||||
# PYTHONPATH=. python data/scripts/produce_dataset.py
|
||||
# echo "datasets produced successfully"
|
||||
# shell: bash
|
||||
- name: 1. Build HDF5 Dataset
|
||||
run: |
|
||||
mkdir -p data/dataset
|
||||
PYTHONPATH=. python data/scripts/produce_dataset.py
|
||||
echo "datasets produced successfully"
|
||||
shell: bash
|
||||
|
||||
- name: Upload Dataset Artifacts
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: ria-dataset
|
||||
path: data/dataset/**
|
||||
|
||||
|
||||
- name: 2. Train Model
|
||||
env:
|
||||
|
@ -52,26 +59,32 @@ jobs:
|
|||
PYTHONPATH=. python data/training/train.py 2>/dev/null
|
||||
echo "training model"
|
||||
|
||||
- name: 3. Build inference app
|
||||
run: |
|
||||
PYTHONPATH=. python convert_to_onnx.py
|
||||
echo "building inference app"
|
||||
|
||||
- name: Upload Dataset Artifacts
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: ria-dataset
|
||||
path: data/dataset/**
|
||||
|
||||
- name: Upload Checkpoints
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: ria-checkpoints
|
||||
path: checkpoint_files/inference_recognition_model.ckpt
|
||||
|
||||
- name: 3. Convert to ONNX file
|
||||
run: |
|
||||
PYTHONPATH=. python convert_to_onnx.py
|
||||
echo "building inference app"
|
||||
|
||||
- name: Upload Inference App
|
||||
- name: Upload ONNX file
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: ria-demo-app
|
||||
name: ria-demo-onnx
|
||||
path: onnx_files/inference_recognition_model.onnx
|
||||
|
||||
- name: 4. Convert to ORT file
|
||||
run: |
|
||||
python -m onnxruntime.tools.convert_onnx_models_to_ort \
|
||||
--input /onnx_files/inference_recognition_model.onnx \
|
||||
--output /ort_files/inference_recognition_model.ort \
|
||||
|
||||
- name: Upload ORT file
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: ria-demo-ort
|
||||
path: ort_files/inference_recognition_model.ort
|
||||
|
||||
|
|
|
@ -12,15 +12,16 @@ dataset:
|
|||
output_dir: data/dataset
|
||||
|
||||
training:
|
||||
batch_size: 64
|
||||
epochs: 50
|
||||
learning_rate: 0.001
|
||||
batch_size: 256
|
||||
epochs: 5
|
||||
learning_rate: 1e-4
|
||||
checkpoint_dir: checkpoint_files
|
||||
checkpoint_filename: inference_recognition_model
|
||||
use_gpu: true
|
||||
|
||||
inference:
|
||||
num_classes: 4
|
||||
onnx_model_filename: inference_recognition_model
|
||||
|
||||
app:
|
||||
build_dir: dist
|
|
@ -72,7 +72,9 @@ def convert_to_onnx(ckpt_path, fp16=False):
|
|||
if __name__ == "__main__":
|
||||
from checkpoint_files import CHECKPOINTS_DIR
|
||||
|
||||
model_checkpoint = "inference_recognition_model.ckpt"
|
||||
settings = get_app_settings()
|
||||
|
||||
model_checkpoint = settings.training.checkpoint_filename + ".ckpt"
|
||||
|
||||
print("Converting to ONNX...")
|
||||
|
||||
|
@ -80,6 +82,6 @@ if __name__ == "__main__":
|
|||
ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False
|
||||
)
|
||||
|
||||
output_file = "inference_recognition_model.onnx"
|
||||
output_file = settings.inference.onnx_model_filename + ".onnx"
|
||||
|
||||
print("Conversion complete stored at: ", os.path.join(ONNX_DIR, output_file))
|
||||
|
|
|
@ -1,31 +1,18 @@
|
|||
import sys, os
|
||||
|
||||
os.environ["NNPACK"] = "0"
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
|
||||
project_root = os.path.abspath(os.path.join(script_dir, "../.."))
|
||||
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from helpers.app_settings import get_app_settings
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
import lightning as L
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchmetrics
|
||||
|
||||
from helpers.app_settings import get_app_settings
|
||||
from modulation_dataset import ModulationH5Dataset
|
||||
|
||||
import mobilenetv3
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
|
||||
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
|
||||
class CleanProgressCallback(Callback):
|
||||
|
@ -56,8 +43,8 @@ def train_model():
|
|||
dataset_cfg = settings.dataset
|
||||
|
||||
train_flag = True
|
||||
batch_size = 128
|
||||
epochs = 50
|
||||
batch_size = training_cfg.batch_size
|
||||
epochs = training_cfg.epochs
|
||||
|
||||
checkpoint_dir = training_cfg.checkpoint_dir
|
||||
checkpoint_filename = training_cfg.checkpoint_filename
|
||||
|
@ -98,7 +85,7 @@ def train_model():
|
|||
hparams = {
|
||||
"drop_path_rate": 0.2,
|
||||
"drop_rate": 0.5,
|
||||
"learning_rate": 1e-4,
|
||||
"learning_rate": training_cfg.learning_rate,
|
||||
"wd": 0.01,
|
||||
}
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ class TrainingConfig:
|
|||
@dataclass
|
||||
class InferenceConfig:
|
||||
num_classes: int
|
||||
onnx_model_filename: str
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
Loading…
Reference in New Issue
Block a user