changed the training to only do 5 epochs, added in a convert to ORT at the end of worklflow

This commit is contained in:
Liyu Xiao 2025-05-26 09:43:42 -04:00
parent 7e255a704d
commit 015f6d4db5
2 changed files with 9 additions and 8 deletions

View File

@ -79,7 +79,7 @@ if __name__ == "__main__":
convert_to_onnx( convert_to_onnx(
ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False
) )
output_file = "inference_recognition_model.onnx" output_file = "inference_recognition_model.onnx"
print("Conversion complete stored at: ", os.path.join(ONNX_DIR, output_file)) print("Conversion complete stored at: ", os.path.join(ONNX_DIR, output_file))

View File

@ -27,17 +27,18 @@ from modulation_dataset import ModulationH5Dataset
import mobilenetv3 import mobilenetv3
class CleanProgressCallback(Callback): class CleanProgressCallback(Callback):
"""Clean progress callback that only shows epoch summaries""" """Clean progress callback that only shows epoch summaries"""
def on_train_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch + 1 epoch = trainer.current_epoch + 1
# Get metrics # Get metrics
train_loss = trainer.callback_metrics.get("train_loss") train_loss = trainer.callback_metrics.get("train_loss")
val_loss = trainer.callback_metrics.get("val_loss") val_loss = trainer.callback_metrics.get("val_loss")
val_acc = trainer.callback_metrics.get("val_acc") val_acc = trainer.callback_metrics.get("val_acc")
# Print clean output # Print clean output
print(f"Epoch {epoch}:") print(f"Epoch {epoch}:")
if train_loss is not None: if train_loss is not None:
@ -160,9 +161,9 @@ def train_model():
mode="max", mode="max",
enable_version_counter=False, enable_version_counter=False,
) )
clean_progress = CleanProgressCallback() clean_progress = CleanProgressCallback()
trainer = L.Trainer( trainer = L.Trainer(
max_epochs=epochs, max_epochs=epochs,
callbacks=[checkpoint_callback, clean_progress], callbacks=[checkpoint_callback, clean_progress],
@ -180,4 +181,4 @@ def train_model():
if __name__ == "__main__": if __name__ == "__main__":
train_model() train_model()