reorganized file structure

This commit is contained in:
liyuxiao2 2025-05-22 14:11:18 -04:00
parent 4198e1b929
commit ba796961a3
5 changed files with 225 additions and 3494 deletions

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@ from typing import Optional
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix from sklearn.metrics import confusion_matrix
def plot_confusion_matrix( def plot_confusion_matrix(
y_true: np.array, y_true: np.array,
y_pred: np.array, y_pred: np.array,
@ -11,7 +12,7 @@ def plot_confusion_matrix(
title: Optional[str] = None, title: Optional[str] = None,
text: bool = True, text: bool = True,
rotate_x_text: int = 90, rotate_x_text: int = 90,
figsize: tuple = (16,9), figsize: tuple = (16, 9),
cmap: plt.cm = plt.cm.Blues, cmap: plt.cm = plt.cm.Blues,
): ):
"""Function to help plot confusion matrices """Function to help plot confusion matrices
@ -50,7 +51,14 @@ def plot_confusion_matrix(
for i in range(cm.shape[0]): for i in range(cm.shape[0]):
for j in range(cm.shape[1]): for j in range(cm.shape[1]):
if text: if text:
ax.text(j, i, format(cm[i,j], fmt), ha="center", va="center", color="white" if cm[i,j] > thresh else "black") ax.text(
j,
i,
format(cm[i, j], fmt),
ha="center",
va="center",
color="white" if cm[i, j] > thresh else "black",
)
if len(classes) == 2: if len(classes) == 2:
plt.axis([-0.5, 1.5, 1.5, -0.5]) plt.axis([-0.5, 1.5, 1.5, -0.5])
fig.tight_layout() fig.tight_layout()

View File

@ -2,21 +2,23 @@ import numpy as np
import torch import torch
import timm import timm
from torch import nn from torch import nn
import lightning as L
sizes = [ sizes = [
'mobilenetv3_large_075', "mobilenetv3_large_075",
'mobilenetv3_large_100', "mobilenetv3_large_100",
'mobilenetv3_rw', "mobilenetv3_rw",
'mobilenetv3_small_050', "mobilenetv3_small_050",
'mobilenetv3_small_075', "mobilenetv3_small_075",
'mobilenetv3_small_100', "mobilenetv3_small_100",
'tf_mobilenetv3_large_075', "tf_mobilenetv3_large_075",
'tf_mobilenetv3_large_100', "tf_mobilenetv3_large_100",
'tf_mobilenetv3_large_minimal_100', "tf_mobilenetv3_large_minimal_100",
'tf_mobilenetv3_small_075', "tf_mobilenetv3_small_075",
'tf_mobilenetv3_small_100', "tf_mobilenetv3_small_100",
'tf_mobilenetv3_small_minimal_100' "tf_mobilenetv3_small_minimal_100",
] ]
class SqueezeExcite(nn.Module): class SqueezeExcite(nn.Module):
def __init__( def __init__(
@ -54,8 +56,9 @@ class FastGlobalAvgPool1d(nn.Module):
in_size = x.size() in_size = x.size()
return x.view((in_size[0], in_size[1], -1)).mean(dim=2) return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
else: else:
return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1) return (
x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1)
)
class GBN(torch.nn.Module): class GBN(torch.nn.Module):
@ -87,7 +90,7 @@ class GBN(torch.nn.Module):
def replace_bn(parent): def replace_bn(parent):
for n, m in parent.named_children(): for n, m in parent.named_children():
if type(m) is timm.layers.norm_act.BatchNormAct2d: if type(m) is timm.layers.norm_act.BatchNormAct2d:
# if type(m) is nn.BatchNorm2d: # if type(m) is nn.BatchNorm2d:
# print(type(m)) # print(type(m))
setattr( setattr(
parent, parent,
@ -97,6 +100,7 @@ def replace_bn(parent):
else: else:
replace_bn(m) replace_bn(m)
def replace_se(parent): def replace_se(parent):
for n, m in parent.named_children(): for n, m in parent.named_children():
if type(m) is timm.models._efficientnet_blocks.SqueezeExcite: if type(m) is timm.models._efficientnet_blocks.SqueezeExcite:
@ -111,6 +115,7 @@ def replace_se(parent):
else: else:
replace_se(m) replace_se(m)
def replace_conv(parent, ds_rate): def replace_conv(parent, ds_rate):
for n, m in parent.named_children(): for n, m in parent.named_children():
if type(m) is nn.Conv2d: if type(m) is nn.Conv2d:
@ -145,6 +150,7 @@ def replace_conv(parent, ds_rate):
else: else:
replace_conv(m, ds_rate) replace_conv(m, ds_rate)
def create_mobilenetv3(network, ds_rate=2, in_chans=2): def create_mobilenetv3(network, ds_rate=2, in_chans=2):
replace_se(network) replace_se(network)
replace_bn(network) replace_bn(network)
@ -152,19 +158,20 @@ def create_mobilenetv3(network, ds_rate=2, in_chans=2):
network.global_pool = FastGlobalAvgPool1d() network.global_pool = FastGlobalAvgPool1d()
network.conv_stem = nn.Conv1d( network.conv_stem = nn.Conv1d(
in_channels=in_chans, in_channels=in_chans,
out_channels=network.conv_stem.out_channels, out_channels=network.conv_stem.out_channels,
kernel_size=network.conv_stem.kernel_size, kernel_size=network.conv_stem.kernel_size,
stride=network.conv_stem.stride, stride=network.conv_stem.stride,
padding=network.conv_stem.padding, padding=network.conv_stem.padding,
bias=network.conv_stem.kernel_size, bias=network.conv_stem.kernel_size,
groups=network.conv_stem.groups, groups=network.conv_stem.groups,
) )
return network return network
def mobilenetv3( def mobilenetv3(
model_size = 'mobilenetv3_small_050', model_size="mobilenetv3_small_050",
num_classes: int = 10, num_classes: int = 10,
drop_rate: float = 0, drop_rate: float = 0,
drop_path_rate: float = 0, drop_path_rate: float = 0,
@ -183,24 +190,11 @@ def mobilenetv3(
) )
return mdl return mdl
import torch.nn as nn
class Simple1DCNN(nn.Module): class RFClassifier(L.LightningModule):
def __init__(self, in_chans=2, num_classes=4): def __init__(self, model):
super().__init__() super().__init__()
self.net = nn.Sequential( self.model = model
nn.Conv1d(in_chans, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool1d(2),
nn.Conv1d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool1d(1),
nn.Flatten(),
nn.Linear(64, num_classes)
)
def forward(self, x): def forward(self, x):
return self.net(x) # x shape: [B, 2, 128] return self.model(x)
def simple_cnn(in_chans=2, num_classes=4):
return Simple1DCNN(in_chans, num_classes)

View File

@ -1,4 +1,5 @@
import sys, os import sys, os
sys.path.insert(0, os.path.abspath("../..")) # or ".." if needed sys.path.insert(0, os.path.abspath("../..")) # or ".." if needed
import numpy as np import numpy as np
import torch import torch
@ -11,26 +12,31 @@ dataset = settings.dataset.modulation_types
class ModulationH5Dataset(Dataset): class ModulationH5Dataset(Dataset):
def __init__(self, hdf5_path, label_name, data_key="training_data", label_encoder=None, transform=None): def __init__(
self,
hdf5_path,
label_name,
data_key="training_data",
label_encoder=None,
transform=None,
):
self.hdf5_path = hdf5_path self.hdf5_path = hdf5_path
self.data_key = data_key self.data_key = data_key
self.label_name = label_name self.label_name = label_name
self.label_encoder = label_encoder self.label_encoder = label_encoder
self.transform = transform self.transform = transform
with h5py.File(hdf5_path, "r") as f:
with h5py.File(hdf5_path, 'r') as f:
self.length = f[data_key].shape[0] self.length = f[data_key].shape[0]
self.metadata = f["metadata"]["metadata"][:] self.metadata = f["metadata"]["metadata"][:]
settings = get_app_settings() settings = get_app_settings()
dataset_cfg = settings.dataset dataset_cfg = settings.dataset
all_labels = dataset_cfg.modulation_types all_labels = dataset_cfg.modulation_types
if self.label_encoder is None: if self.label_encoder is None:
from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import LabelEncoder
self.label_encoder = LabelEncoder() self.label_encoder = LabelEncoder()
self.label_encoder.fit(all_labels) self.label_encoder.fit(all_labels)
@ -38,12 +44,11 @@ class ModulationH5Dataset(Dataset):
raw_labels = [row["modulation"].decode("utf-8") for row in self.metadata] raw_labels = [row["modulation"].decode("utf-8") for row in self.metadata]
self.encoded_labels = self.label_encoder.transform(raw_labels) self.encoded_labels = self.label_encoder.transform(raw_labels)
def __len__(self): def __len__(self):
return self.length return self.length
def __getitem__(self, idx): def __getitem__(self, idx):
with h5py.File(self.hdf5_path, 'r') as f: with h5py.File(self.hdf5_path, "r") as f:
x = f[self.data_key][idx] # shape (1, 128) or similar x = f[self.data_key][idx] # shape (1, 128) or similar
# Normalize # Normalize
@ -54,4 +59,3 @@ class ModulationH5Dataset(Dataset):
label = torch.tensor(self.encoded_labels[idx], dtype=torch.long) label = torch.tensor(self.encoded_labels[idx], dtype=torch.long)
return x, label return x, label

153
data/training/train.py Normal file
View File

@ -0,0 +1,153 @@
import os
import sys

# Locations derived from this script's own path, so they do not depend on the
# directory the script is launched from.
script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
project_root = os.path.abspath(os.path.join(script_dir, "../.."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Imported here, before the second path entry is added below, so `helpers`
# is resolved against the script-relative project root.
from helpers.app_settings import get_app_settings

# NOTE(review): this second entry is the parent of the *current working
# directory* and rebinds `project_root`. It is kept because the sibling
# imports below may rely on it when the script is launched from a
# subdirectory -- confirm whether it is still needed.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import lightning as L
import torch
import torch.nn.functional as F
import torchmetrics

from modulation_dataset import ModulationH5Dataset
import mobilenetv3
def train_model(
    *,
    train_data='/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5',
    val_data='/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5',
    checkpoint_filename='/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model',
    batch_size=128,
    epochs=1,
    train_flag=True,
):
    """Train a MobileNetV3 modulation classifier and checkpoint the best model.

    Builds training/validation datasets from two HDF5 files, wraps a
    ``mobilenetv3_small_050`` network in a Lightning module, and (when
    ``train_flag`` is true) fits it while keeping the single checkpoint with
    the highest ``val_acc``.

    All parameters are keyword-only and default to the values that were
    previously hard-coded, so ``train_model()`` behaves as before.

    Args:
        train_data: Path to the HDF5 file read with ``data_key="training_data"``.
        val_data: Path to the HDF5 file read with ``data_key="validation_data"``.
        checkpoint_filename: Checkpoint path without extension; its directory
            part becomes the checkpoint ``dirpath`` and its last component the
            checkpoint ``filename``.
        batch_size: Batch size of the training loader (the validation loader
            uses a fixed batch size of 2048).
        epochs: Value passed to the trainer as ``max_epochs``.
        train_flag: When false, everything is constructed but ``fit`` is not
            called.
    """
    label = 'modulation'

    torch.set_float32_matmul_precision('high')

    ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
    ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")

    train_loader = torch.utils.data.DataLoader(
        dataset=ds_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=ds_val,
        batch_size=2048,
        shuffle=False,
        num_workers=8,
    )

    # Sanity check: show the shape and first labels of one training batch.
    for x, y in train_loader:
        print("X shape:", x.shape)
        print("Y values:", y[:10])
        break

    # The label encoder attached to the training set defines the class count.
    num_classes = len(ds_train.label_encoder.classes_)

    hparams = {
        'drop_path_rate': 0.2,
        'drop_rate': 0.5,
        'learning_rate': 3e-4,
        'wd': 0.2,
    }

    class RFClassifier(L.LightningModule):
        """Lightning wrapper adding loss, accuracy and optimizer setup.

        Defined locally because it closes over ``num_classes``, ``hparams``
        and ``train_loader`` from the enclosing function.
        """

        def __init__(self, model):
            super().__init__()
            self.model = model
            self.accuracy = torchmetrics.Accuracy(
                task='multiclass', num_classes=num_classes
            )

        def forward(self, x):
            return self.model(x)

        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=hparams['learning_rate'],
                weight_decay=hparams['wd'],
            )
            # The scheduler is stepped per batch ('interval': 'step') and
            # T_0 is the number of batches in the training loader.
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=len(train_loader),
            )
            return {
                'optimizer': optimizer,
                'lr_scheduler': {
                    'scheduler': lr_scheduler,
                    'interval': 'step',
                },
            }

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.log('train_loss', loss, on_epoch=True, prog_bar=True)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            self.accuracy(y_hat, y)
            self.log('val_loss', loss, prog_bar=True)
            self.log('val_acc', self.accuracy, prog_bar=True)

    model = RFClassifier(
        mobilenetv3.mobilenetv3(
            model_size='mobilenetv3_small_050',
            num_classes=num_classes,
            drop_rate=hparams['drop_rate'],
            drop_path_rate=hparams['drop_path_rate'],
        )
    )

    # Pass directory and file name separately instead of a full path as
    # `filename`. `or None` leaves the directory unset when the path has no
    # directory component.
    checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
        dirpath=os.path.dirname(checkpoint_filename) or None,
        filename=os.path.basename(checkpoint_filename),
        save_top_k=1,
        verbose=True,
        monitor='val_acc',
        mode='max',
        enable_version_counter=False,
    )

    trainer = L.Trainer(
        max_epochs=epochs,
        callbacks=[checkpoint_callback],
        accelerator='gpu',
        devices=1,
        benchmark=True,
        precision='bf16-mixed',
        logger=False,
    )

    if train_flag:
        trainer.fit(model, train_loader, val_loader)


if __name__ == '__main__':
    train_model()