forked from qoherent/modrec-workflow
changed the training to only do 5 epochs, added in a convert to ORT at the end of worklflow
This commit is contained in:
parent
7e255a704d
commit
015f6d4db5
|
@ -79,7 +79,7 @@ if __name__ == "__main__":
|
||||||
convert_to_onnx(
|
convert_to_onnx(
|
||||||
ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False
|
ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False
|
||||||
)
|
)
|
||||||
|
|
||||||
output_file = "inference_recognition_model.onnx"
|
output_file = "inference_recognition_model.onnx"
|
||||||
|
|
||||||
print("Conversion complete stored at: ", os.path.join(ONNX_DIR, output_file))
|
print("Conversion complete stored at: ", os.path.join(ONNX_DIR, output_file))
|
||||||
|
|
|
@ -27,17 +27,18 @@ from modulation_dataset import ModulationH5Dataset
|
||||||
|
|
||||||
import mobilenetv3
|
import mobilenetv3
|
||||||
|
|
||||||
|
|
||||||
class CleanProgressCallback(Callback):
|
class CleanProgressCallback(Callback):
|
||||||
"""Clean progress callback that only shows epoch summaries"""
|
"""Clean progress callback that only shows epoch summaries"""
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module):
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
epoch = trainer.current_epoch + 1
|
epoch = trainer.current_epoch + 1
|
||||||
|
|
||||||
# Get metrics
|
# Get metrics
|
||||||
train_loss = trainer.callback_metrics.get("train_loss")
|
train_loss = trainer.callback_metrics.get("train_loss")
|
||||||
val_loss = trainer.callback_metrics.get("val_loss")
|
val_loss = trainer.callback_metrics.get("val_loss")
|
||||||
val_acc = trainer.callback_metrics.get("val_acc")
|
val_acc = trainer.callback_metrics.get("val_acc")
|
||||||
|
|
||||||
# Print clean output
|
# Print clean output
|
||||||
print(f"Epoch {epoch}:")
|
print(f"Epoch {epoch}:")
|
||||||
if train_loss is not None:
|
if train_loss is not None:
|
||||||
|
@ -160,9 +161,9 @@ def train_model():
|
||||||
mode="max",
|
mode="max",
|
||||||
enable_version_counter=False,
|
enable_version_counter=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
clean_progress = CleanProgressCallback()
|
clean_progress = CleanProgressCallback()
|
||||||
|
|
||||||
trainer = L.Trainer(
|
trainer = L.Trainer(
|
||||||
max_epochs=epochs,
|
max_epochs=epochs,
|
||||||
callbacks=[checkpoint_callback, clean_progress],
|
callbacks=[checkpoint_callback, clean_progress],
|
||||||
|
@ -180,4 +181,4 @@ def train_model():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_model()
|
train_model()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user