formatted all files

This commit is contained in:
parent 3557f854e8
commit ff3d45653d
@@ -8,7 +8,6 @@ from onnx_files import ONNX_DIR
 from helpers.app_settings import get_app_settings
 
 
-
 def convert_to_onnx(ckpt_path, fp16=False):
     """
     Convert a PyTorch model to ONNX format.
@@ -23,10 +22,9 @@ def convert_to_onnx(ckpt_path, fp16=False):
     inference_cfg = settings.inference
     dataset_cfg = settings.dataset
 
-
     in_channels = 2
     batch_size = 1
-    slice_length = int(1024/dataset_cfg.num_slices)
+    slice_length = int(1024 / dataset_cfg.num_slices)
     num_classes = inference_cfg.num_classes
 
     model = RFClassifier(
@@ -38,7 +36,7 @@ def convert_to_onnx(ckpt_path, fp16=False):
     )
 
     checkpoint = torch.load(
-        ckpt_path, weights_only = True, map_location=torch.device("cpu")
+        ckpt_path, weights_only=True, map_location=torch.device("cpu")
     )
     model.load_state_dict(checkpoint["state_dict"])
 
@@ -47,7 +45,6 @@ def convert_to_onnx(ckpt_path, fp16=False):
 
     model.eval()
 
-
    # Generate random sample data
     base, ext = os.path.splitext(os.path.basename(ckpt_path))
     if fp16:
@@ -59,7 +56,6 @@ def convert_to_onnx(ckpt_path, fp16=False):
         output_path = os.path.join(ONNX_DIR, f"{base}_fp32.onnx")
     sample_input = torch.rand(batch_size, in_channels, slice_length)
 
-
     torch.onnx.export(
         model=model,
         args=sample_input,
@@ -72,6 +68,7 @@ def convert_to_onnx(ckpt_path, fp16=False):
         dynamo=False,  # Requires onnxscript
     )
 
+
 if __name__ == "__main__":
     from checkpoint_files import CHECKPOINTS_DIR
 
@@ -79,6 +76,8 @@ if __name__ == "__main__":
 
     print("Converting to ONNX...")
 
-    convert_to_onnx(ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False)
+    convert_to_onnx(
+        ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False
+    )
 
     print("Conversion complete stored at: ", os.path.join(ONNX_DIR, model_checkpoint))
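
Note: the reformatted export above can be sanity-checked by running the resulting .onnx file through onnxruntime. A minimal sketch, assuming onnxruntime is installed and that the exported graph has a single (batch, channels, length) float input like sample_input and a single output; the input name is looked up at runtime rather than assumed:

import numpy as np
import onnxruntime as ort

def check_onnx(onnx_path, batch_size=1, in_channels=2, slice_length=128):
    # slice_length=128 is a placeholder; the script derives it as 1024 / dataset_cfg.num_slices
    session = ort.InferenceSession(onnx_path)
    input_name = session.get_inputs()[0].name
    dummy = np.random.rand(batch_size, in_channels, slice_length).astype(np.float32)
    (logits,) = session.run(None, {input_name: dummy})  # assumes a single output tensor
    print("output shape:", logits.shape)
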
@@ -23,7 +23,6 @@ info_dtype = np.dtype(
 )
 
 
-
 def write_hdf5_file(records, output_path, dataset_name="data"):
     """
     Writes a list of records to an HDF5 file.
@@ -77,11 +76,12 @@ def write_hdf5_file(records, output_path, dataset_name="data"):
 
     return output_path
 
+
 def complex_to_channel(data):
     """
     Convert complex-valued IQ data of shape (1, N) to 2-channel real array of shape (2, N).
     """
-    assert np.iscomplexobj(data) #check if the data is in the form a+bi
+    assert np.iscomplexobj(data)  # check if the data is in the form a+bi
     real = np.real(data[0])  # (N,)
     imag = np.imag(data[0])  # (N,)
     stacked = np.stack([real, imag], axis=0)  # shape (2, N)
@@ -114,22 +114,18 @@ def generate_datasets(cfg):
     for fname in files:
         rec = from_npy(os.path.join(cfg.input_dir, fname))
 
-        data = rec.data #here data is a numpy array with the shape (1, N)
+        data = rec.data  # here data is a numpy array with the shape (1, N)
 
         data = complex_to_channel(data)  # convert to 2-channel real array
 
-
         md = rec.metadata  # pull metadata from the recording
         md.setdefault("recid", len(records))
         records.append((data, md))
 
     # split each recording into <num_slices> snippets each
 
-
     records = split_recording(records, cfg.num_slices)
 
-
-
     train_records, val_records = split(records, cfg.train_split, cfg.seed)
 
     train_path = os.path.join(cfg.output_dir, "train.h5")
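
For reference, a tiny sketch of what complex_to_channel does to a toy IQ array (the operations are the ones in the hunk above; the sample values are made up):

import numpy as np

iq = np.array([[1 + 2j, 3 - 4j, 0 + 1j]])  # complex IQ recording, shape (1, 3)
real = np.real(iq[0])                      # (3,)
imag = np.imag(iq[0])                      # (3,)
stacked = np.stack([real, imag], axis=0)   # shape (2, 3): row 0 = I, row 1 = Q
print(stacked)
# [[ 1.  3.  0.]
#  [ 2. -4.  1.]]
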
@@ -2,7 +2,7 @@ import random
 from collections import defaultdict
 
 
-def split(dataset, train_frac=0.8, seed=42, label_key = "modulation"):
+def split(dataset, train_frac=0.8, seed=42, label_key="modulation"):
     """
     Splits a dataset into smaller datasets based on the specified lengths.
 
@@ -17,15 +17,14 @@ def split(dataset, train_frac=0.8, seed=42, label_key = "modulation"):
     for data, md in dataset:
         rec_buckets[md["recid"]].append((data, md))
 
-
-    rec_labels = {} #store labels for each recording
+    rec_labels = {}  # store labels for each recording
     for rec_id, group in rec_buckets.items():
         label = group[0][1][label_key]
-        if isinstance(label, bytes): #if the label is a byte string
+        if isinstance(label, bytes):  # if the label is a byte string
             label = label.decode("utf-8")
         rec_labels[rec_id] = label
 
-    label_rec_ids = defaultdict(list) #group rec_ids by label
+    label_rec_ids = defaultdict(list)  # group rec_ids by label
     for rec_id, label in rec_labels.items():
         label_rec_ids[label].append(rec_id)
 
@@ -35,12 +34,11 @@ def split(dataset, train_frac=0.8, seed=42, label_key = "modulation"):
     for label, rec_ids in label_rec_ids.items():
         random.shuffle(rec_ids)
         split_idx = int(len(rec_ids) * train_frac)
-        train_recs.update(rec_ids[:split_idx]) #pulls train_frac or rec_ids per label, guarantees all modulations are represented
+        train_recs.update(
+            rec_ids[:split_idx]
+        )  # pulls train_frac or rec_ids per label, guarantees all modulations are represented
         val_recs.update(rec_ids[split_idx:])
 
-
-
-
     # add the assigned recordings to the train and val datasets
     train_dataset, val_dataset = [], []
     for rec_id, group in rec_buckets.items():
@@ -49,9 +47,9 @@ def split(dataset, train_frac=0.8, seed=42, label_key = "modulation"):
         elif rec_id in val_recs:
             val_dataset.extend(group)
 
-
     return train_dataset, val_dataset
 
+
 def split_recording(recording_list, num_snippets):
     """
     Splits a list of recordings into smaller chunks.
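
A small usage sketch of the split() helper above, with hypothetical toy records shaped like the (data, metadata) tuples built in generate_datasets; it illustrates that whole recordings (grouped by "recid") land in either the train or the val set, stratified by modulation label. The import is hypothetical since the module's file name is not shown in this diff:

import numpy as np
# from split_utils import split  # hypothetical module name; use the module shown above

records = []
for rec_id, mod in enumerate(["bpsk", "bpsk", "qpsk", "qpsk"]):
    for _ in range(2):  # two slices per toy recording
        data = np.zeros((2, 8), dtype=np.float32)
        records.append((data, {"recid": rec_id, "modulation": mod}))

train_set, val_set = split(records, train_frac=0.5, seed=42)
print(len(train_set), len(val_set))  # 4 4 -> one whole recording per label on each side
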
@@ -24,7 +24,6 @@ from modulation_dataset import ModulationH5Dataset
 import mobilenetv3
 
 
-
 def train_model():
     settings = get_app_settings()
     dataset = settings.dataset.modulation_types
@@ -33,16 +32,20 @@ def train_model():
     batch_size = 128
     epochs = 1
 
-    checkpoint_filename = f'/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model'
+    checkpoint_filename = f"/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model"
 
-    train_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5'
-    val_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5'
+    train_data = (
+        "/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5"
+    )
+    val_data = (
+        "/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5"
+    )
 
-    dataset_name = 'Modulation Inference - Initial Model'
-    metadata_names = 'Modulation'
-    label = 'modulation'
+    dataset_name = "Modulation Inference - Initial Model"
+    metadata_names = "Modulation"
+    label = "modulation"
 
-    torch.set_float32_matmul_precision('high')
+    torch.set_float32_matmul_precision("high")
 
     ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
     ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")
@@ -69,17 +72,19 @@ def train_model():
     num_classes = len(ds_train.label_encoder.classes_)
 
     hparams = {
-        'drop_path_rate': 0.2,
-        'drop_rate': 0.5,
-        'learning_rate': 3e-4,
-        'wd': 0.2
+        "drop_path_rate": 0.2,
+        "drop_rate": 0.5,
+        "learning_rate": 3e-4,
+        "wd": 0.2,
     }
 
     class RFClassifier(L.LightningModule):
         def __init__(self, model):
             super().__init__()
             self.model = model
-            self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
+            self.accuracy = torchmetrics.Accuracy(
+                task="multiclass", num_classes=num_classes
+            )
 
         def forward(self, x):
             return self.model(x)
@@ -87,26 +92,23 @@ def train_model():
         def configure_optimizers(self):
            optimizer = torch.optim.AdamW(
                 self.parameters(),
-                lr=hparams['learning_rate'],
-                weight_decay=hparams['wd'],
+                lr=hparams["learning_rate"],
+                weight_decay=hparams["wd"],
             )
             lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                 optimizer,
                 T_0=len(train_loader),
             )
             return {
-                'optimizer': optimizer,
-                'lr_scheduler': {
-                    'scheduler': lr_scheduler,
-                    'interval': 'step'
-                }
+                "optimizer": optimizer,
+                "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
             }
 
         def training_step(self, batch, batch_idx):
             x, y = batch
             y_hat = self(x)
             loss = F.cross_entropy(y_hat, y)
-            self.log('train_loss', loss, on_epoch=True, prog_bar=True)
+            self.log("train_loss", loss, on_epoch=True, prog_bar=True)
             return loss
 
         def validation_step(self, batch, batch_idx):
@@ -114,15 +116,15 @@ def train_model():
             y_hat = self(x)
             loss = F.cross_entropy(y_hat, y)
             self.accuracy(y_hat, y)
-            self.log('val_loss', loss, prog_bar=True)
-            self.log('val_acc', self.accuracy, prog_bar=True)
+            self.log("val_loss", loss, prog_bar=True)
+            self.log("val_acc", self.accuracy, prog_bar=True)
 
     model = RFClassifier(
         mobilenetv3.mobilenetv3(
-            model_size='mobilenetv3_small_050',
+            model_size="mobilenetv3_small_050",
             num_classes=num_classes,
-            drop_rate=hparams['drop_rate'],
-            drop_path_rate=hparams['drop_path_rate']
+            drop_rate=hparams["drop_rate"],
+            drop_path_rate=hparams["drop_path_rate"],
         )
     )
 
@@ -130,24 +132,24 @@ def train_model():
         filename=checkpoint_filename,
         save_top_k=True,
         verbose=True,
-        monitor='val_acc',
-        mode='max',
+        monitor="val_acc",
+        mode="max",
         enable_version_counter=False,
     )
 
     trainer = L.Trainer(
         max_epochs=epochs,
         callbacks=[checkpoint_callback],
-        accelerator='gpu',
+        accelerator="gpu",
         devices=1,
         benchmark=True,
-        precision='bf16-mixed',
-        logger=False
+        precision="bf16-mixed",
+        logger=False,
     )
 
     if train_flag:
         trainer.fit(model, train_loader, val_loader)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     train_model()
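
The ModelCheckpoint callback above writes a Lightning checkpoint whose "state_dict" entry is exactly what convert_to_onnx loads in the first file of this commit. A short sketch of inspecting that hand-off; the .ckpt path is hypothetical (Lightning appends the extension to checkpoint_filename), and the torch.load call mirrors the one in the diff:

import torch

ckpt_path = "results/interference_recognition_model.ckpt"  # hypothetical local path
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=torch.device("cpu"))

# Weights live under "state_dict"; keys are prefixed with "model." because
# RFClassifier wraps the backbone as self.model.
for name, tensor in list(checkpoint["state_dict"].items())[:3]:
    print(name, tuple(tensor.shape))
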
@@ -4,10 +4,12 @@ from functools import lru_cache
 
 import yaml
 
+
 @dataclass
 class GeneralConfig:
     run_mode: str
 
+
 @dataclass
 class DataSetConfig:
     input_dir: str
@@ -18,6 +20,7 @@ class DataSetConfig:
     val_split: float
     output_dir: str
 
+
 @dataclass
 class TrainingConfig:
     batch_size: int
@@ -26,16 +29,19 @@ class TrainingConfig:
     checkpoint_path: str
     use_gpu: bool
 
+
 @dataclass
 class InferenceConfig:
     model_path: str
     num_classes: int
     output_path: str
 
+
 @dataclass
 class AppConfig:
     build_dir: str
 
+
 class AppSettings:
     """Application settings, to be initialized from app.yaml configuration file."""
 
@@ -51,6 +57,7 @@ class AppSettings:
         self.inference = InferenceConfig(**config_data["inference"])
         self.app = AppConfig(**config_data["app"])
 
+
 @lru_cache
 def get_app_settings() -> AppSettings:
     """Return application configuration settings."""
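
For context, a short sketch of how these settings objects are consumed elsewhere in this commit (attribute names are the ones visible in the hunks above; actual values come from the local app.yaml):

from helpers.app_settings import get_app_settings

settings = get_app_settings()        # cached by @lru_cache, so repeated calls reuse one instance
dataset_cfg = settings.dataset       # DataSetConfig: input_dir, num_slices, output_dir, ...
inference_cfg = settings.inference   # InferenceConfig: model_path, num_classes, output_path

slice_length = int(1024 / dataset_cfg.num_slices)  # same derivation as in convert_to_onnx
print(slice_length, inference_cfg.num_classes)
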