From 015f6d4db5263ba150f4da477190f71433606f94 Mon Sep 17 00:00:00 2001 From: Liyu Xiao Date: Mon, 26 May 2025 09:43:42 -0400 Subject: [PATCH] changed the training to only do 5 epochs, added in a convert to ORT at the end of worklflow --- convert_to_onnx.py | 4 ++-- data/training/train.py | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/convert_to_onnx.py b/convert_to_onnx.py index b7d2e9f..59b1aff 100644 --- a/convert_to_onnx.py +++ b/convert_to_onnx.py @@ -79,7 +79,7 @@ if __name__ == "__main__": convert_to_onnx( ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False ) - + output_file = "inference_recognition_model.onnx" - + print("Conversion complete stored at: ", os.path.join(ONNX_DIR, output_file)) diff --git a/data/training/train.py b/data/training/train.py index 6f7c1ce..ef8b863 100644 --- a/data/training/train.py +++ b/data/training/train.py @@ -27,17 +27,18 @@ from modulation_dataset import ModulationH5Dataset import mobilenetv3 + class CleanProgressCallback(Callback): """Clean progress callback that only shows epoch summaries""" - + def on_train_epoch_end(self, trainer, pl_module): epoch = trainer.current_epoch + 1 - + # Get metrics train_loss = trainer.callback_metrics.get("train_loss") val_loss = trainer.callback_metrics.get("val_loss") val_acc = trainer.callback_metrics.get("val_acc") - + # Print clean output print(f"Epoch {epoch}:") if train_loss is not None: @@ -160,9 +161,9 @@ def train_model(): mode="max", enable_version_counter=False, ) - + clean_progress = CleanProgressCallback() - + trainer = L.Trainer( max_epochs=epochs, callbacks=[checkpoint_callback, clean_progress], @@ -180,4 +181,4 @@ def train_model(): if __name__ == "__main__": - train_model() \ No newline at end of file + train_model()