From 7e255a704dbc18ed864d47cc6ee8336d036507e2 Mon Sep 17 00:00:00 2001
From: Liyu Xiao
Date: Fri, 23 May 2025 10:47:27 -0400
Subject: [PATCH] fixed the training bar, so it only shows final epoch results.

---
 data/training/train.py | 60 +++++++++++++++++++-----------------------
 1 file changed, 27 insertions(+), 33 deletions(-)

diff --git a/data/training/train.py b/data/training/train.py
index ac790a8..6f7c1ce 100644
--- a/data/training/train.py
+++ b/data/training/train.py
@@ -2,7 +2,6 @@
 import sys, os
 
 os.environ["NNPACK"] = "0"
-
 script_dir = os.path.dirname(os.path.abspath(__file__))
 data_dir = os.path.abspath(os.path.join(script_dir, ".."))
 project_root = os.path.abspath(os.path.join(script_dir, "../.."))
@@ -17,7 +16,7 @@ if project_root not in sys.path:
     sys.path.insert(0, project_root)
 
 import lightning as L
-from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.callbacks import ModelCheckpoint, Callback
 
 import torch
 import torch.nn.functional as F
@@ -28,27 +27,25 @@
 from modulation_dataset import ModulationH5Dataset
 
 import mobilenetv3
 
-from lightning.pytorch.callbacks import TQDMProgressBar
-
-class EpochOnlyProgressBar(TQDMProgressBar):
-    def __init__(self):
-        super().__init__()
-        self.enable = True
-
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-        # Skip batch-level updates
-        pass
-
+class CleanProgressCallback(Callback):
+    """Clean progress callback that only shows epoch summaries"""
+
     def on_train_epoch_end(self, trainer, pl_module):
         epoch = trainer.current_epoch + 1
+
+        # Get metrics
         train_loss = trainer.callback_metrics.get("train_loss")
-        val_loss = trainer.callback_metrics.get("val_loss")
+        val_loss = trainer.callback_metrics.get("val_loss")
         val_acc = trainer.callback_metrics.get("val_acc")
 
-        print(f"\nEpoch {epoch}:")
-        if train_loss: print(f"  Train Loss: {train_loss:.4f}")
-        if val_loss: print(f"  Val Loss: {val_loss:.4f}")
-        if val_acc: print(f"  Val Acc: {val_acc:.4f}")
+        # Print clean output
+        print(f"Epoch {epoch}:")
+        if train_loss is not None:
+            print(f"  Train Loss: {train_loss:.4f}")
+        if val_loss is not None:
+            print(f"  Val Loss: {val_loss:.4f}")
+        if val_acc is not None:
+            print(f"  Val Acc: {val_acc:.4f}")
         print("-" * 30)
 
@@ -64,12 +61,8 @@ def train_model():
     checkpoint_dir = training_cfg.checkpoint_dir
     checkpoint_filename = training_cfg.checkpoint_filename
 
-    train_data = (
-        f"{dataset_cfg.output_dir}/train.h5"
-    )
-    val_data = (
-        f"{dataset_cfg.output_dir}/val.h5"
-    )
+    train_data = f"{dataset_cfg.output_dir}/train.h5"
+    val_data = f"{dataset_cfg.output_dir}/val.h5"
 
     dataset_name = "Modulation Inference - Initial Model"
     metadata_names = "Modulation"
@@ -138,7 +131,7 @@
             x, y = batch
             y_hat = self(x)
             loss = F.cross_entropy(y_hat, y)
-            self.log("train_loss", loss, on_epoch=True, prog_bar=True)
+            self.log("train_loss", loss, on_epoch=True, prog_bar=False)
             return loss
 
         def validation_step(self, batch, batch_idx):
@@ -146,8 +139,8 @@
             y_hat = self(x)
             loss = F.cross_entropy(y_hat, y)
             self.accuracy(y_hat, y)
-            self.log("val_loss", loss, prog_bar=True)
-            self.log("val_acc", self.accuracy, prog_bar=True)
+            self.log("val_loss", loss, prog_bar=False)
+            self.log("val_acc", self.accuracy, prog_bar=False)
 
     model = RFClassifier(
         mobilenetv3.mobilenetv3(
@@ -161,24 +154,25 @@
     checkpoint_callback = ModelCheckpoint(
         dirpath=checkpoint_dir,
         filename=checkpoint_filename,
-        save_top_k=True,
-        verbose=True,
+        save_top_k=1,
+        verbose=False,  # Disable checkpoint verbose output
        monitor="val_acc",
         mode="max",
         enable_version_counter=False,
     )
 
-    progress_bar = EpochOnlyProgressBar()
+    clean_progress = CleanProgressCallback()
 
     trainer = L.Trainer(
         max_epochs=epochs,
-        callbacks=[checkpoint_callback,progress_bar],
+        callbacks=[checkpoint_callback, clean_progress],
         accelerator="gpu",
         devices=1,
         benchmark=True,
         precision="bf16-mixed",
         logger=False,
-        enable_progress_bar=True,
+        enable_progress_bar=False,  # Disable all progress bars
+        enable_model_summary=False,  # Disable model summary
     )
 
     if train_flag:
@@ -186,4 +180,4 @@
 
 
 if __name__ == "__main__":
-    train_model()
+    train_model()
\ No newline at end of file