reorganized file structure

parent 4198e1b929
commit ba796961a3
@@ -3,6 +3,7 @@ from typing import Optional
 from matplotlib import pyplot as plt
 from sklearn.metrics import confusion_matrix
 
+
 def plot_confusion_matrix(
     y_true: np.array,
     y_pred: np.array,
@@ -11,11 +12,11 @@ def plot_confusion_matrix(
     title: Optional[str] = None,
     text: bool = True,
     rotate_x_text: int = 90,
-    figsize: tuple = (16,9),
+    figsize: tuple = (16, 9),
     cmap: plt.cm = plt.cm.Blues,
 ):
     """Function to help plot confusion matrices
 
     https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
     """
     if not title:
@@ -32,13 +33,13 @@ def plot_confusion_matrix(
     fig, ax = plt.subplots()
     im = ax.imshow(cm, interpolation="none", cmap=cmap)
     ax.figure.colorbar(im, ax=ax)
-    ax.set(
-        xticks=np.arange(cm.shape[1]),
-        yticks=np.arange(cm.shape[0]),
-        xticklabels=classes,
-        yticklabels=classes,
-        title=title,
-        ylabel="True label",
+    ax.set(
+        xticks=np.arange(cm.shape[1]),
+        yticks=np.arange(cm.shape[0]),
+        xticklabels=classes,
+        yticklabels=classes,
+        title=title,
+        ylabel="True label",
         xlabel="Predicted label",
     )
     ax.set_xticklabels(classes, rotation=rotate_x_text)
@@ -50,9 +51,16 @@ def plot_confusion_matrix(
     for i in range(cm.shape[0]):
         for j in range(cm.shape[1]):
             if text:
-                ax.text(j, i, format(cm[i,j], fmt), ha="center", va="center", color="white" if cm[i,j] > thresh else "black")
+                ax.text(
+                    j,
+                    i,
+                    format(cm[i, j], fmt),
+                    ha="center",
+                    va="center",
+                    color="white" if cm[i, j] > thresh else "black",
+                )
     if len(classes) == 2:
         plt.axis([-0.5, 1.5, 1.5, -0.5])
     fig.tight_layout()
 
-    return ax
+    return ax
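For reference, a minimal usage sketch of the reformatted helper. It assumes the parameters elided between the hunks above (notably a classes list) keep their obvious meaning; the labels and class names below are illustrative placeholders:

import numpy as np

# Hypothetical 3-class example; values are illustrative only.
y_true = np.array([0, 1, 2, 2, 1, 0])
y_pred = np.array([0, 2, 2, 2, 1, 0])
classes = ["bpsk", "qpsk", "qam16"]

ax = plot_confusion_matrix(
    y_true,
    y_pred,
    classes=classes,  # assumed parameter, hidden between the hunks above
    title="Validation confusion matrix",
)
ax.figure.savefig("confusion_matrix.png")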
@@ -2,21 +2,23 @@ import numpy as np
 import torch
 import timm
 from torch import nn
 import lightning as L
 
 sizes = [
-    'mobilenetv3_large_075',
-    'mobilenetv3_large_100',
-    'mobilenetv3_rw',
-    'mobilenetv3_small_050',
-    'mobilenetv3_small_075',
-    'mobilenetv3_small_100',
-    'tf_mobilenetv3_large_075',
-    'tf_mobilenetv3_large_100',
-    'tf_mobilenetv3_large_minimal_100',
-    'tf_mobilenetv3_small_075',
-    'tf_mobilenetv3_small_100',
-    'tf_mobilenetv3_small_minimal_100'
+    "mobilenetv3_large_075",
+    "mobilenetv3_large_100",
+    "mobilenetv3_rw",
+    "mobilenetv3_small_050",
+    "mobilenetv3_small_075",
+    "mobilenetv3_small_100",
+    "tf_mobilenetv3_large_075",
+    "tf_mobilenetv3_large_100",
+    "tf_mobilenetv3_large_minimal_100",
+    "tf_mobilenetv3_small_075",
+    "tf_mobilenetv3_small_100",
+    "tf_mobilenetv3_small_minimal_100",
 ]
+
+
 class SqueezeExcite(nn.Module):
     def __init__(
@@ -54,10 +56,11 @@ class FastGlobalAvgPool1d(nn.Module):
             in_size = x.size()
             return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
         else:
-            return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1)
-
+            return (
+                x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1)
+            )
 
 
 class GBN(torch.nn.Module):
     """
     Ghost Batch Normalization
@@ -87,7 +90,7 @@ class GBN(torch.nn.Module):
 def replace_bn(parent):
     for n, m in parent.named_children():
         if type(m) is timm.layers.norm_act.BatchNormAct2d:
-        # if type(m) is nn.BatchNorm2d:
+            # if type(m) is nn.BatchNorm2d:
             # print(type(m))
             setattr(
                 parent,
@@ -97,6 +100,7 @@ def replace_bn(parent):
         else:
             replace_bn(m)
 
+
 def replace_se(parent):
     for n, m in parent.named_children():
         if type(m) is timm.models._efficientnet_blocks.SqueezeExcite:
@@ -111,6 +115,7 @@ def replace_se(parent):
         else:
             replace_se(m)
 
+
 def replace_conv(parent, ds_rate):
     for n, m in parent.named_children():
         if type(m) is nn.Conv2d:
@@ -145,6 +150,7 @@ def replace_conv(parent, ds_rate):
         else:
             replace_conv(m, ds_rate)
 
+
 def create_mobilenetv3(network, ds_rate=2, in_chans=2):
     replace_se(network)
     replace_bn(network)
@@ -152,19 +158,20 @@ def create_mobilenetv3(network, ds_rate=2, in_chans=2):
     network.global_pool = FastGlobalAvgPool1d()
 
     network.conv_stem = nn.Conv1d(
-        in_channels=in_chans,
-        out_channels=network.conv_stem.out_channels,
-        kernel_size=network.conv_stem.kernel_size,
-        stride=network.conv_stem.stride,
-        padding=network.conv_stem.padding,
-        bias=network.conv_stem.kernel_size,
-        groups=network.conv_stem.groups,
-    )
+        in_channels=in_chans,
+        out_channels=network.conv_stem.out_channels,
+        kernel_size=network.conv_stem.kernel_size,
+        stride=network.conv_stem.stride,
+        padding=network.conv_stem.padding,
+        bias=network.conv_stem.kernel_size,
+        groups=network.conv_stem.groups,
+    )
 
     return network
 
+
 def mobilenetv3(
-    model_size = 'mobilenetv3_small_050',
+    model_size="mobilenetv3_small_050",
     num_classes: int = 10,
     drop_rate: float = 0,
     drop_path_rate: float = 0,
@@ -183,24 +190,11 @@ def mobilenetv3(
     )
     return mdl
 
-import torch.nn as nn
-
-class Simple1DCNN(nn.Module):
-    def __init__(self, in_chans=2, num_classes=4):
+
+class RFClassifier(L.LightningModule):
+    def __init__(self, model):
         super().__init__()
-        self.net = nn.Sequential(
-            nn.Conv1d(in_chans, 32, kernel_size=3, padding=1),
-            nn.ReLU(),
-            nn.MaxPool1d(2),
-            nn.Conv1d(32, 64, kernel_size=3, padding=1),
-            nn.ReLU(),
-            nn.AdaptiveAvgPool1d(1),
-            nn.Flatten(),
-            nn.Linear(64, num_classes)
-        )
+        self.model = model
 
     def forward(self, x):
-        return self.net(x)  # x shape: [B, 2, 128]
-
-def simple_cnn(in_chans=2, num_classes=4):
-    return Simple1DCNN(in_chans, num_classes)
+        return self.model(x)
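A quick smoke test of the adapted backbone, as a sketch: model_size comes from the sizes list above, while the input length (1024) is an assumption, chosen only to match the 2-channel I/Q data this commit targets:

import torch
import mobilenetv3

# Build the 1D-adapted MobileNetV3; arguments mirror the signature above.
model = mobilenetv3.mobilenetv3(
    model_size="mobilenetv3_small_050",
    num_classes=10,
    drop_rate=0.0,
    drop_path_rate=0.0,
)
model.eval()

x = torch.randn(8, 2, 1024)  # batch of 8 two-channel signals; length is assumed
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([8, 10])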
@@ -1,4 +1,5 @@
 import sys, os
+
 sys.path.insert(0, os.path.abspath("../.."))  # or ".." if needed
 import numpy as np
 import torch
@@ -11,39 +12,43 @@ dataset = settings.dataset.modulation_types
 
 
 class ModulationH5Dataset(Dataset):
-    def __init__(self, hdf5_path, label_name, data_key="training_data", label_encoder=None, transform=None):
-        self.hdf5_path = hdf5_path
-        self.data_key = data_key
+    def __init__(
+        self,
+        hdf5_path,
+        label_name,
+        data_key="training_data",
+        label_encoder=None,
+        transform=None,
+    ):
+        self.hdf5_path = hdf5_path
+        self.data_key = data_key
         self.label_name = label_name
         self.label_encoder = label_encoder
         self.transform = transform
 
-        with h5py.File(hdf5_path, 'r') as f:
+        with h5py.File(hdf5_path, "r") as f:
             self.length = f[data_key].shape[0]
             self.metadata = f["metadata"]["metadata"][:]
 
         settings = get_app_settings()
         dataset_cfg = settings.dataset
         all_labels = dataset_cfg.modulation_types
 
         if self.label_encoder is None:
             from sklearn.preprocessing import LabelEncoder
 
             self.label_encoder = LabelEncoder()
             self.label_encoder.fit(all_labels)
 
         # Get per-sample labels from metadata
         raw_labels = [row["modulation"].decode("utf-8") for row in self.metadata]
         self.encoded_labels = self.label_encoder.transform(raw_labels)
 
     def __len__(self):
         return self.length
 
     def __getitem__(self, idx):
-        with h5py.File(self.hdf5_path, 'r') as f:
+        with h5py.File(self.hdf5_path, "r") as f:
             x = f[self.data_key][idx]  # shape (1, 128) or similar
 
         # Normalize
@@ -54,4 +59,3 @@ class ModulationH5Dataset(Dataset):
 
         label = torch.tensor(self.encoded_labels[idx], dtype=torch.long)
         return x, label
-
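A minimal sketch of how the dataset is consumed; the path is a placeholder, and the HDF5 layout (a training_data array plus a metadata table with a modulation column) follows the class above:

from torch.utils.data import DataLoader

ds = ModulationH5Dataset(
    "train.h5",  # placeholder path
    label_name="modulation",
    data_key="training_data",
)
loader = DataLoader(ds, batch_size=128, shuffle=True)

x, y = next(iter(loader))
print(x.shape, y.shape)
print(ds.label_encoder.classes_)  # encoder fitted from the app settings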
153  data/training/train.py  Normal file
@@ -0,0 +1,153 @@
+import sys, os
+
+script_dir = os.path.dirname(os.path.abspath(__file__))
+data_dir = os.path.abspath(os.path.join(script_dir, ".."))
+project_root = os.path.abspath(os.path.join(script_dir, "../.."))
+
+if project_root not in sys.path:
+    sys.path.insert(0, project_root)
+
+from helpers.app_settings import get_app_settings
+
+project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
+if project_root not in sys.path:
+    sys.path.insert(0, project_root)
+
+import lightning as L
+import torch
+import torch.nn.functional as F
+import torchmetrics
+
+from helpers.app_settings import get_app_settings
+from modulation_dataset import ModulationH5Dataset
+
+import mobilenetv3
+
+
+def train_model():
+    settings = get_app_settings()
+    dataset = settings.dataset.modulation_types
+
+    train_flag = True
+    batch_size = 128
+    epochs = 1
+
+    checkpoint_filename = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model'
+
+    train_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5'
+    val_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5'
+
+    dataset_name = 'Modulation Inference - Initial Model'
+    metadata_names = 'Modulation'
+    label = 'modulation'
+
+    torch.set_float32_matmul_precision('high')
+
+    ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
+    ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")
+
+    train_loader = torch.utils.data.DataLoader(
+        dataset=ds_train,
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=8,
+    )
+    val_loader = torch.utils.data.DataLoader(
+        dataset=ds_val,
+        batch_size=2048,
+        shuffle=False,
+        num_workers=8,
+    )
+
+    for x, y in train_loader:
+        print("X shape:", x.shape)
+        print("Y values:", y[:10])
+        break
+
+    unique_labels = list(set([row[label].decode("utf-8") for row in ds_train.metadata]))
+    num_classes = len(ds_train.label_encoder.classes_)
+
+    hparams = {
+        'drop_path_rate': 0.2,
+        'drop_rate': 0.5,
+        'learning_rate': 3e-4,
+        'wd': 0.2
+    }
+
+    class RFClassifier(L.LightningModule):
+        def __init__(self, model):
+            super().__init__()
+            self.model = model
+            self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
+
+        def forward(self, x):
+            return self.model(x)
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.AdamW(
+                self.parameters(),
+                lr=hparams['learning_rate'],
+                weight_decay=hparams['wd'],
+            )
+            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+                optimizer,
+                T_0=len(train_loader),
+            )
+            return {
+                'optimizer': optimizer,
+                'lr_scheduler': {
+                    'scheduler': lr_scheduler,
+                    'interval': 'step'
+                }
+            }
+
+        def training_step(self, batch, batch_idx):
+            x, y = batch
+            y_hat = self(x)
+            loss = F.cross_entropy(y_hat, y)
+            self.log('train_loss', loss, on_epoch=True, prog_bar=True)
+            return loss
+
+        def validation_step(self, batch, batch_idx):
+            x, y = batch
+            y_hat = self(x)
+            loss = F.cross_entropy(y_hat, y)
+            self.accuracy(y_hat, y)
+            self.log('val_loss', loss, prog_bar=True)
+            self.log('val_acc', self.accuracy, prog_bar=True)
+
+    model = RFClassifier(
+        mobilenetv3.mobilenetv3(
+            model_size='mobilenetv3_small_050',
+            num_classes=num_classes,
+            drop_rate=hparams['drop_rate'],
+            drop_path_rate=hparams['drop_path_rate']
+        )
+    )
+
+    checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
+        filename=checkpoint_filename,
+        save_top_k=1,
+        verbose=True,
+        monitor='val_acc',
+        mode='max',
+        enable_version_counter=False,
+    )
+
+    trainer = L.Trainer(
+        max_epochs=epochs,
+        callbacks=[checkpoint_callback],
+        accelerator='gpu',
+        devices=1,
+        benchmark=True,
+        precision='bf16-mixed',
+        logger=False
+    )
+
+    if train_flag:
+        trainer.fit(model, train_loader, val_loader)
+
+
+if __name__ == '__main__':
+    train_model()
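For inference after training, one way to reload the weights. This is a sketch under stated assumptions: the checkpoint path and num_classes are placeholders, and it relies on RFClassifier storing the backbone under self.model, so Lightning's saved state_dict carries a "model." attribute prefix:

import torch
import mobilenetv3

ckpt = torch.load("interference_recognition_model.ckpt", map_location="cpu")
net = mobilenetv3.mobilenetv3(model_size="mobilenetv3_small_050", num_classes=10)

# Strip the LightningModule attribute prefix before loading into the bare backbone.
state = {
    k.removeprefix("model."): v
    for k, v in ckpt["state_dict"].items()
    if k.startswith("model.")
}
net.load_state_dict(state)
net.eval()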