Fixed the training progress bar so that it shows only the end-of-epoch summary results.
All checks were successful
RIA Hub Workflow Demo / ria-demo (push) Successful in 37m46s
All checks were successful
RIA Hub Workflow Demo / ria-demo (push) Successful in 37m46s
This commit is contained in:
parent
5f0ac995c8
commit
7e255a704d
|
@ -2,7 +2,6 @@ import sys, os
|
||||||
|
|
||||||
os.environ["NNPACK"] = "0"
|
os.environ["NNPACK"] = "0"
|
||||||
|
|
||||||
|
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
|
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
|
||||||
project_root = os.path.abspath(os.path.join(script_dir, "../.."))
|
project_root = os.path.abspath(os.path.join(script_dir, "../.."))
|
||||||
|
@ -17,7 +16,7 @@ if project_root not in sys.path:
|
||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -28,27 +27,25 @@ from modulation_dataset import ModulationH5Dataset
|
||||||
|
|
||||||
import mobilenetv3
|
import mobilenetv3
|
||||||
|
|
||||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
class CleanProgressCallback(Callback):
|
||||||
|
"""Clean progress callback that only shows epoch summaries"""
|
||||||
class EpochOnlyProgressBar(TQDMProgressBar):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.enable = True
|
|
||||||
|
|
||||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
|
||||||
# Skip batch-level updates
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module):
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
epoch = trainer.current_epoch + 1
|
epoch = trainer.current_epoch + 1
|
||||||
|
|
||||||
|
# Get metrics
|
||||||
train_loss = trainer.callback_metrics.get("train_loss")
|
train_loss = trainer.callback_metrics.get("train_loss")
|
||||||
val_loss = trainer.callback_metrics.get("val_loss")
|
val_loss = trainer.callback_metrics.get("val_loss")
|
||||||
val_acc = trainer.callback_metrics.get("val_acc")
|
val_acc = trainer.callback_metrics.get("val_acc")
|
||||||
|
|
||||||
print(f"\nEpoch {epoch}:")
|
# Print clean output
|
||||||
if train_loss: print(f" Train Loss: {train_loss:.4f}")
|
print(f"Epoch {epoch}:")
|
||||||
if val_loss: print(f" Val Loss: {val_loss:.4f}")
|
if train_loss is not None:
|
||||||
if val_acc: print(f" Val Acc: {val_acc:.4f}")
|
print(f" Train Loss: {train_loss:.4f}")
|
||||||
|
if val_loss is not None:
|
||||||
|
print(f" Val Loss: {val_loss:.4f}")
|
||||||
|
if val_acc is not None:
|
||||||
|
print(f" Val Acc: {val_acc:.4f}")
|
||||||
print("-" * 30)
|
print("-" * 30)
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,12 +61,8 @@ def train_model():
|
||||||
checkpoint_dir = training_cfg.checkpoint_dir
|
checkpoint_dir = training_cfg.checkpoint_dir
|
||||||
checkpoint_filename = training_cfg.checkpoint_filename
|
checkpoint_filename = training_cfg.checkpoint_filename
|
||||||
|
|
||||||
train_data = (
|
train_data = f"{dataset_cfg.output_dir}/train.h5"
|
||||||
f"{dataset_cfg.output_dir}/train.h5"
|
val_data = f"{dataset_cfg.output_dir}/val.h5"
|
||||||
)
|
|
||||||
val_data = (
|
|
||||||
f"{dataset_cfg.output_dir}/val.h5"
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_name = "Modulation Inference - Initial Model"
|
dataset_name = "Modulation Inference - Initial Model"
|
||||||
metadata_names = "Modulation"
|
metadata_names = "Modulation"
|
||||||
|
@ -138,7 +131,7 @@ def train_model():
|
||||||
x, y = batch
|
x, y = batch
|
||||||
y_hat = self(x)
|
y_hat = self(x)
|
||||||
loss = F.cross_entropy(y_hat, y)
|
loss = F.cross_entropy(y_hat, y)
|
||||||
self.log("train_loss", loss, on_epoch=True, prog_bar=True)
|
self.log("train_loss", loss, on_epoch=True, prog_bar=False)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
def validation_step(self, batch, batch_idx):
|
||||||
|
@ -146,8 +139,8 @@ def train_model():
|
||||||
y_hat = self(x)
|
y_hat = self(x)
|
||||||
loss = F.cross_entropy(y_hat, y)
|
loss = F.cross_entropy(y_hat, y)
|
||||||
self.accuracy(y_hat, y)
|
self.accuracy(y_hat, y)
|
||||||
self.log("val_loss", loss, prog_bar=True)
|
self.log("val_loss", loss, prog_bar=False)
|
||||||
self.log("val_acc", self.accuracy, prog_bar=True)
|
self.log("val_acc", self.accuracy, prog_bar=False)
|
||||||
|
|
||||||
model = RFClassifier(
|
model = RFClassifier(
|
||||||
mobilenetv3.mobilenetv3(
|
mobilenetv3.mobilenetv3(
|
||||||
|
@ -161,24 +154,25 @@ def train_model():
|
||||||
checkpoint_callback = ModelCheckpoint(
|
checkpoint_callback = ModelCheckpoint(
|
||||||
dirpath=checkpoint_dir,
|
dirpath=checkpoint_dir,
|
||||||
filename=checkpoint_filename,
|
filename=checkpoint_filename,
|
||||||
save_top_k=True,
|
save_top_k=1,
|
||||||
verbose=True,
|
verbose=False, # Disable checkpoint verbose output
|
||||||
monitor="val_acc",
|
monitor="val_acc",
|
||||||
mode="max",
|
mode="max",
|
||||||
enable_version_counter=False,
|
enable_version_counter=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
progress_bar = EpochOnlyProgressBar()
|
clean_progress = CleanProgressCallback()
|
||||||
|
|
||||||
trainer = L.Trainer(
|
trainer = L.Trainer(
|
||||||
max_epochs=epochs,
|
max_epochs=epochs,
|
||||||
callbacks=[checkpoint_callback,progress_bar],
|
callbacks=[checkpoint_callback, clean_progress],
|
||||||
accelerator="gpu",
|
accelerator="gpu",
|
||||||
devices=1,
|
devices=1,
|
||||||
benchmark=True,
|
benchmark=True,
|
||||||
precision="bf16-mixed",
|
precision="bf16-mixed",
|
||||||
logger=False,
|
logger=False,
|
||||||
enable_progress_bar=True,
|
enable_progress_bar=False, # Disable all progress bars
|
||||||
|
enable_model_summary=False, # Disable model summary
|
||||||
)
|
)
|
||||||
|
|
||||||
if train_flag:
|
if train_flag:
|
||||||
|
@ -186,4 +180,4 @@ def train_model():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
train_model()
|
train_model()
|
Loading…
Reference in New Issue
Block a user