diff --git a/.riahub/workflows/workflow.yaml b/.riahub/workflows/workflow.yaml
index bc0f304..2304f4f 100644
--- a/.riahub/workflows/workflow.yaml
+++ b/.riahub/workflows/workflow.yaml
@@ -37,12 +37,19 @@ jobs:
           python -m pip install --upgrade pip
           pip install -r requirements.txt
 
-      # - name: 1. Build HDF5 Dataset
-      #   run: |
-      #     mkdir -p data/dataset
-      #     PYTHONPATH=. python data/scripts/produce_dataset.py
-      #     echo "datasets produced successfully"
-      #   shell: bash
+      - name: 1. Build HDF5 Dataset
+        run: |
+          mkdir -p data/dataset
+          PYTHONPATH=. python data/scripts/produce_dataset.py
+          echo "datasets produced successfully"
+        shell: bash
+
+      - name: Upload Dataset Artifacts
+        uses: actions/upload-artifact@v3
+        with:
+          name: ria-dataset
+          path: data/dataset/**
+
 
       - name: 2. Train Model
         env:
@@ -52,26 +59,32 @@ jobs:
           PYTHONPATH=. python data/training/train.py 2>/dev/null
           echo "training model"
 
-      - name: 3. Build inference app
-        run: |
-          PYTHONPATH=. python convert_to_onnx.py
-          echo "building inference app"
-
-      - name: Upload Dataset Artifacts
-        uses: actions/upload-artifact@v3
-        with:
-          name: ria-dataset
-          path: data/dataset/**
-
       - name: Upload Checkpoints
        uses: actions/upload-artifact@v3
        with:
           name: ria-checkpoints
           path: checkpoint_files/inference_recognition_model.ckpt
 
-
-      - name: Upload Inference App
+      - name: 3. Convert to ONNX file
+        run: |
+          PYTHONPATH=. python convert_to_onnx.py
+          echo "converting model to ONNX"
+
+      - name: Upload ONNX file
         uses: actions/upload-artifact@v3
         with:
-          name: ria-demo-app
-          path: onnx_files/inference_recognition_model.onnx
\ No newline at end of file
+          name: ria-demo-onnx
+          path: onnx_files/inference_recognition_model.onnx
+
+      - name: 4. Convert to ORT file
+        run: |
+          mkdir -p ort_files
+          python -m onnxruntime.tools.convert_onnx_models_to_ort \
+            onnx_files/inference_recognition_model.onnx --output_dir ort_files
+
+      - name: Upload ORT file
+        uses: actions/upload-artifact@v3
+        with:
+          name: ria-demo-ort
+          path: ort_files/inference_recognition_model.ort
+
diff --git a/conf/app.yaml b/conf/app.yaml
index 01f14df..224c76e 100644
--- a/conf/app.yaml
+++ b/conf/app.yaml
@@ -12,15 +12,16 @@ dataset:
   output_dir: data/dataset
 
 training:
-  batch_size: 64
-  epochs: 50
-  learning_rate: 0.001
+  batch_size: 256
+  epochs: 5
+  learning_rate: 1e-4
   checkpoint_dir: checkpoint_files
   checkpoint_filename: inference_recognition_model
   use_gpu: true
 
 inference:
   num_classes: 4
+  onnx_model_filename: inference_recognition_model
 
 app:
   build_dir: dist
\ No newline at end of file
diff --git a/convert_to_onnx.py b/convert_to_onnx.py
index 59b1aff..4dfa4f5 100644
--- a/convert_to_onnx.py
+++ b/convert_to_onnx.py
@@ -72,7 +72,9 @@ def convert_to_onnx(ckpt_path, fp16=False):
 
 if __name__ == "__main__":
     from checkpoint_files import CHECKPOINTS_DIR
-    model_checkpoint = "inference_recognition_model.ckpt"
+    settings = get_app_settings()
+
+    model_checkpoint = settings.training.checkpoint_filename + ".ckpt"
 
     print("Converting to ONNX...")
 
@@ -80,6 +82,6 @@ if __name__ == "__main__":
         ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint),
         fp16=False
     )
-    output_file = "inference_recognition_model.onnx"
+    output_file = settings.inference.onnx_model_filename + ".onnx"
 
     print("Conversion complete stored at: ", os.path.join(ONNX_DIR, output_file))
diff --git a/data/training/train.py b/data/training/train.py
index ef8b863..f2106ab 100644
--- a/data/training/train.py
+++ b/data/training/train.py
@@ -1,31 +1,18 @@
 import sys, os
-
 os.environ["NNPACK"] = "0"
-
-script_dir = os.path.dirname(os.path.abspath(__file__))
-data_dir = os.path.abspath(os.path.join(script_dir, ".."))
-project_root = os.path.abspath(os.path.join(script_dir, "../.."))
-
-if project_root not in sys.path:
-    sys.path.insert(0, project_root)
-
-from helpers.app_settings import get_app_settings
-
-project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
-if project_root not in sys.path:
-    sys.path.insert(0, project_root)
-
 import lightning as L
 from lightning.pytorch.callbacks import ModelCheckpoint, Callback
-
 import torch
 import torch.nn.functional as F
 import torchmetrics
-
 from helpers.app_settings import get_app_settings
 from modulation_dataset import ModulationH5Dataset
-
 import mobilenetv3
+script_dir = os.path.dirname(os.path.abspath(__file__))
+data_dir = os.path.abspath(os.path.join(script_dir, ".."))
+project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
+if project_root not in sys.path:
+    sys.path.insert(0, project_root)
 
 
 class CleanProgressCallback(Callback):
@@ -56,8 +43,8 @@ def train_model():
     dataset_cfg = settings.dataset
 
     train_flag = True
-    batch_size = 128
-    epochs = 50
+    batch_size = training_cfg.batch_size
+    epochs = training_cfg.epochs
     checkpoint_dir = training_cfg.checkpoint_dir
     checkpoint_filename = training_cfg.checkpoint_filename
 
@@ -98,7 +85,7 @@ def train_model():
     hparams = {
         "drop_path_rate": 0.2,
         "drop_rate": 0.5,
-        "learning_rate": 1e-4,
+        "learning_rate": training_cfg.learning_rate,
         "wd": 0.01,
     }
 
diff --git a/helpers/app_settings.py b/helpers/app_settings.py
index fa5b842..a53a9d2 100644
--- a/helpers/app_settings.py
+++ b/helpers/app_settings.py
@@ -34,6 +34,7 @@ class TrainingConfig:
 @dataclass
 class InferenceConfig:
     num_classes: int
+    onnx_model_filename: str
 
 
 @dataclass