reorganized file structure
This commit is contained in:
parent
4198e1b929
commit
ba796961a3
File diff suppressed because it is too large
Load Diff
|
@ -3,6 +3,7 @@ from typing import Optional
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from sklearn.metrics import confusion_matrix
|
from sklearn.metrics import confusion_matrix
|
||||||
|
|
||||||
|
|
||||||
def plot_confusion_matrix(
|
def plot_confusion_matrix(
|
||||||
y_true: np.array,
|
y_true: np.array,
|
||||||
y_pred: np.array,
|
y_pred: np.array,
|
||||||
|
@ -11,7 +12,7 @@ def plot_confusion_matrix(
|
||||||
title: Optional[str] = None,
|
title: Optional[str] = None,
|
||||||
text: bool = True,
|
text: bool = True,
|
||||||
rotate_x_text: int = 90,
|
rotate_x_text: int = 90,
|
||||||
figsize: tuple = (16,9),
|
figsize: tuple = (16, 9),
|
||||||
cmap: plt.cm = plt.cm.Blues,
|
cmap: plt.cm = plt.cm.Blues,
|
||||||
):
|
):
|
||||||
"""Function to help plot confusion matrices
|
"""Function to help plot confusion matrices
|
||||||
|
@ -50,7 +51,14 @@ def plot_confusion_matrix(
|
||||||
for i in range(cm.shape[0]):
|
for i in range(cm.shape[0]):
|
||||||
for j in range(cm.shape[1]):
|
for j in range(cm.shape[1]):
|
||||||
if text:
|
if text:
|
||||||
ax.text(j, i, format(cm[i,j], fmt), ha="center", va="center", color="white" if cm[i,j] > thresh else "black")
|
ax.text(
|
||||||
|
j,
|
||||||
|
i,
|
||||||
|
format(cm[i, j], fmt),
|
||||||
|
ha="center",
|
||||||
|
va="center",
|
||||||
|
color="white" if cm[i, j] > thresh else "black",
|
||||||
|
)
|
||||||
if len(classes) == 2:
|
if len(classes) == 2:
|
||||||
plt.axis([-0.5, 1.5, 1.5, -0.5])
|
plt.axis([-0.5, 1.5, 1.5, -0.5])
|
||||||
fig.tight_layout()
|
fig.tight_layout()
|
|
@ -2,21 +2,23 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import timm
|
import timm
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import lightning as L
|
||||||
|
|
||||||
sizes = [
|
sizes = [
|
||||||
'mobilenetv3_large_075',
|
"mobilenetv3_large_075",
|
||||||
'mobilenetv3_large_100',
|
"mobilenetv3_large_100",
|
||||||
'mobilenetv3_rw',
|
"mobilenetv3_rw",
|
||||||
'mobilenetv3_small_050',
|
"mobilenetv3_small_050",
|
||||||
'mobilenetv3_small_075',
|
"mobilenetv3_small_075",
|
||||||
'mobilenetv3_small_100',
|
"mobilenetv3_small_100",
|
||||||
'tf_mobilenetv3_large_075',
|
"tf_mobilenetv3_large_075",
|
||||||
'tf_mobilenetv3_large_100',
|
"tf_mobilenetv3_large_100",
|
||||||
'tf_mobilenetv3_large_minimal_100',
|
"tf_mobilenetv3_large_minimal_100",
|
||||||
'tf_mobilenetv3_small_075',
|
"tf_mobilenetv3_small_075",
|
||||||
'tf_mobilenetv3_small_100',
|
"tf_mobilenetv3_small_100",
|
||||||
'tf_mobilenetv3_small_minimal_100'
|
"tf_mobilenetv3_small_minimal_100",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class SqueezeExcite(nn.Module):
|
class SqueezeExcite(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -54,8 +56,9 @@ class FastGlobalAvgPool1d(nn.Module):
|
||||||
in_size = x.size()
|
in_size = x.size()
|
||||||
return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
|
return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
|
||||||
else:
|
else:
|
||||||
return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1)
|
return (
|
||||||
|
x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GBN(torch.nn.Module):
|
class GBN(torch.nn.Module):
|
||||||
|
@ -97,6 +100,7 @@ def replace_bn(parent):
|
||||||
else:
|
else:
|
||||||
replace_bn(m)
|
replace_bn(m)
|
||||||
|
|
||||||
|
|
||||||
def replace_se(parent):
|
def replace_se(parent):
|
||||||
for n, m in parent.named_children():
|
for n, m in parent.named_children():
|
||||||
if type(m) is timm.models._efficientnet_blocks.SqueezeExcite:
|
if type(m) is timm.models._efficientnet_blocks.SqueezeExcite:
|
||||||
|
@ -111,6 +115,7 @@ def replace_se(parent):
|
||||||
else:
|
else:
|
||||||
replace_se(m)
|
replace_se(m)
|
||||||
|
|
||||||
|
|
||||||
def replace_conv(parent, ds_rate):
|
def replace_conv(parent, ds_rate):
|
||||||
for n, m in parent.named_children():
|
for n, m in parent.named_children():
|
||||||
if type(m) is nn.Conv2d:
|
if type(m) is nn.Conv2d:
|
||||||
|
@ -145,6 +150,7 @@ def replace_conv(parent, ds_rate):
|
||||||
else:
|
else:
|
||||||
replace_conv(m, ds_rate)
|
replace_conv(m, ds_rate)
|
||||||
|
|
||||||
|
|
||||||
def create_mobilenetv3(network, ds_rate=2, in_chans=2):
|
def create_mobilenetv3(network, ds_rate=2, in_chans=2):
|
||||||
replace_se(network)
|
replace_se(network)
|
||||||
replace_bn(network)
|
replace_bn(network)
|
||||||
|
@ -163,8 +169,9 @@ def create_mobilenetv3(network, ds_rate=2, in_chans=2):
|
||||||
|
|
||||||
return network
|
return network
|
||||||
|
|
||||||
|
|
||||||
def mobilenetv3(
|
def mobilenetv3(
|
||||||
model_size = 'mobilenetv3_small_050',
|
model_size="mobilenetv3_small_050",
|
||||||
num_classes: int = 10,
|
num_classes: int = 10,
|
||||||
drop_rate: float = 0,
|
drop_rate: float = 0,
|
||||||
drop_path_rate: float = 0,
|
drop_path_rate: float = 0,
|
||||||
|
@ -183,24 +190,11 @@ def mobilenetv3(
|
||||||
)
|
)
|
||||||
return mdl
|
return mdl
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class Simple1DCNN(nn.Module):
|
class RFClassifier(L.LightningModule):
|
||||||
def __init__(self, in_chans=2, num_classes=4):
|
def __init__(self, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.net = nn.Sequential(
|
self.model = model
|
||||||
nn.Conv1d(in_chans, 32, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.MaxPool1d(2),
|
|
||||||
nn.Conv1d(32, 64, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.AdaptiveAvgPool1d(1),
|
|
||||||
nn.Flatten(),
|
|
||||||
nn.Linear(64, num_classes)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.net(x) # x shape: [B, 2, 128]
|
return self.model(x)
|
||||||
|
|
||||||
def simple_cnn(in_chans=2, num_classes=4):
|
|
||||||
return Simple1DCNN(in_chans, num_classes)
|
|
|
@ -1,4 +1,5 @@
|
||||||
import sys, os
|
import sys, os
|
||||||
|
|
||||||
sys.path.insert(0, os.path.abspath("../..")) # or ".." if needed
|
sys.path.insert(0, os.path.abspath("../..")) # or ".." if needed
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -11,26 +12,31 @@ dataset = settings.dataset.modulation_types
|
||||||
|
|
||||||
|
|
||||||
class ModulationH5Dataset(Dataset):
|
class ModulationH5Dataset(Dataset):
|
||||||
def __init__(self, hdf5_path, label_name, data_key="training_data", label_encoder=None, transform=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hdf5_path,
|
||||||
|
label_name,
|
||||||
|
data_key="training_data",
|
||||||
|
label_encoder=None,
|
||||||
|
transform=None,
|
||||||
|
):
|
||||||
self.hdf5_path = hdf5_path
|
self.hdf5_path = hdf5_path
|
||||||
self.data_key = data_key
|
self.data_key = data_key
|
||||||
self.label_name = label_name
|
self.label_name = label_name
|
||||||
self.label_encoder = label_encoder
|
self.label_encoder = label_encoder
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
|
|
||||||
|
with h5py.File(hdf5_path, "r") as f:
|
||||||
with h5py.File(hdf5_path, 'r') as f:
|
|
||||||
self.length = f[data_key].shape[0]
|
self.length = f[data_key].shape[0]
|
||||||
self.metadata = f["metadata"]["metadata"][:]
|
self.metadata = f["metadata"]["metadata"][:]
|
||||||
|
|
||||||
|
|
||||||
settings = get_app_settings()
|
settings = get_app_settings()
|
||||||
dataset_cfg = settings.dataset
|
dataset_cfg = settings.dataset
|
||||||
all_labels = dataset_cfg.modulation_types
|
all_labels = dataset_cfg.modulation_types
|
||||||
|
|
||||||
|
|
||||||
if self.label_encoder is None:
|
if self.label_encoder is None:
|
||||||
from sklearn.preprocessing import LabelEncoder
|
from sklearn.preprocessing import LabelEncoder
|
||||||
|
|
||||||
self.label_encoder = LabelEncoder()
|
self.label_encoder = LabelEncoder()
|
||||||
self.label_encoder.fit(all_labels)
|
self.label_encoder.fit(all_labels)
|
||||||
|
|
||||||
|
@ -38,12 +44,11 @@ class ModulationH5Dataset(Dataset):
|
||||||
raw_labels = [row["modulation"].decode("utf-8") for row in self.metadata]
|
raw_labels = [row["modulation"].decode("utf-8") for row in self.metadata]
|
||||||
self.encoded_labels = self.label_encoder.transform(raw_labels)
|
self.encoded_labels = self.label_encoder.transform(raw_labels)
|
||||||
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
with h5py.File(self.hdf5_path, 'r') as f:
|
with h5py.File(self.hdf5_path, "r") as f:
|
||||||
x = f[self.data_key][idx] # shape (1, 128) or similar
|
x = f[self.data_key][idx] # shape (1, 128) or similar
|
||||||
|
|
||||||
# Normalize
|
# Normalize
|
||||||
|
@ -54,4 +59,3 @@ class ModulationH5Dataset(Dataset):
|
||||||
|
|
||||||
label = torch.tensor(self.encoded_labels[idx], dtype=torch.long)
|
label = torch.tensor(self.encoded_labels[idx], dtype=torch.long)
|
||||||
return x, label
|
return x, label
|
||||||
|
|
153
data/training/train.py
Normal file
153
data/training/train.py
Normal file
|
@ -0,0 +1,153 @@
|
||||||
|
import sys, os
|
||||||
|
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
|
||||||
|
project_root = os.path.abspath(os.path.join(script_dir, "../.."))
|
||||||
|
|
||||||
|
if project_root not in sys.path:
|
||||||
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
from helpers.app_settings import get_app_settings
|
||||||
|
|
||||||
|
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
|
||||||
|
if project_root not in sys.path:
|
||||||
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
|
import lightning as L
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchmetrics
|
||||||
|
|
||||||
|
from helpers.app_settings import get_app_settings
|
||||||
|
from modulation_dataset import ModulationH5Dataset
|
||||||
|
|
||||||
|
import mobilenetv3
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def train_model():
|
||||||
|
settings = get_app_settings()
|
||||||
|
dataset = settings.dataset.modulation_types
|
||||||
|
|
||||||
|
train_flag = True
|
||||||
|
batch_size = 128
|
||||||
|
epochs = 1
|
||||||
|
|
||||||
|
checkpoint_filename = f'/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model'
|
||||||
|
|
||||||
|
train_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5'
|
||||||
|
val_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5'
|
||||||
|
|
||||||
|
dataset_name = 'Modulation Inference - Initial Model'
|
||||||
|
metadata_names = 'Modulation'
|
||||||
|
label = 'modulation'
|
||||||
|
|
||||||
|
torch.set_float32_matmul_precision('high')
|
||||||
|
|
||||||
|
ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
|
||||||
|
ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")
|
||||||
|
|
||||||
|
train_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset=ds_train,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=8,
|
||||||
|
)
|
||||||
|
val_loader = torch.utils.data.DataLoader(
|
||||||
|
dataset=ds_val,
|
||||||
|
batch_size=2048,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
for x, y in train_loader:
|
||||||
|
print("X shape:", x.shape)
|
||||||
|
print("Y values:", y[:10])
|
||||||
|
break
|
||||||
|
|
||||||
|
unique_labels = list(set([row[label].decode("utf-8") for row in ds_train.metadata]))
|
||||||
|
num_classes = len(ds_train.label_encoder.classes_)
|
||||||
|
|
||||||
|
hparams = {
|
||||||
|
'drop_path_rate': 0.2,
|
||||||
|
'drop_rate': 0.5,
|
||||||
|
'learning_rate': 3e-4,
|
||||||
|
'wd': 0.2
|
||||||
|
}
|
||||||
|
|
||||||
|
class RFClassifier(L.LightningModule):
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.model(x)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = torch.optim.AdamW(
|
||||||
|
self.parameters(),
|
||||||
|
lr=hparams['learning_rate'],
|
||||||
|
weight_decay=hparams['wd'],
|
||||||
|
)
|
||||||
|
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
||||||
|
optimizer,
|
||||||
|
T_0=len(train_loader),
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
'optimizer': optimizer,
|
||||||
|
'lr_scheduler': {
|
||||||
|
'scheduler': lr_scheduler,
|
||||||
|
'interval': 'step'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
x, y = batch
|
||||||
|
y_hat = self(x)
|
||||||
|
loss = F.cross_entropy(y_hat, y)
|
||||||
|
self.log('train_loss', loss, on_epoch=True, prog_bar=True)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
x, y = batch
|
||||||
|
y_hat = self(x)
|
||||||
|
loss = F.cross_entropy(y_hat, y)
|
||||||
|
self.accuracy(y_hat, y)
|
||||||
|
self.log('val_loss', loss, prog_bar=True)
|
||||||
|
self.log('val_acc', self.accuracy, prog_bar=True)
|
||||||
|
|
||||||
|
model = RFClassifier(
|
||||||
|
mobilenetv3.mobilenetv3(
|
||||||
|
model_size='mobilenetv3_small_050',
|
||||||
|
num_classes=num_classes,
|
||||||
|
drop_rate=hparams['drop_rate'],
|
||||||
|
drop_path_rate=hparams['drop_path_rate']
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
|
||||||
|
filename=checkpoint_filename,
|
||||||
|
save_top_k=True,
|
||||||
|
verbose=True,
|
||||||
|
monitor='val_acc',
|
||||||
|
mode='max',
|
||||||
|
enable_version_counter=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = L.Trainer(
|
||||||
|
max_epochs=epochs,
|
||||||
|
callbacks=[checkpoint_callback],
|
||||||
|
accelerator='gpu',
|
||||||
|
devices=1,
|
||||||
|
benchmark=True,
|
||||||
|
precision='bf16-mixed',
|
||||||
|
logger=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if train_flag:
|
||||||
|
trainer.fit(model, train_loader, val_loader)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
train_model()
|
Loading…
Reference in New Issue
Block a user