formatted all files

liyuxiao2 2025-05-22 14:12:36 -04:00
parent 3557f854e8
commit ff3d45653d
7 changed files with 90 additions and 88 deletions

View File

@@ -3,4 +3,4 @@ import os
 CHECKPOINTS_DIR = os.path.dirname(os.path.abspath(__file__))

 if __name__ == "__main__":
-    print(CHECKPOINTS_DIR)
+    print(CHECKPOINTS_DIR)

View File

@@ -8,7 +8,6 @@ from onnx_files import ONNX_DIR
 from helpers.app_settings import get_app_settings
-
 def convert_to_onnx(ckpt_path, fp16=False):
     """
     Convert a PyTorch model to ONNX format.
@@ -19,16 +18,15 @@ def convert_to_onnx(ckpt_path, fp16=False):
         output_path (str): The path to save the converted ONNX model.
     """
     settings = get_app_settings()
     inference_cfg = settings.inference
     dataset_cfg = settings.dataset
     in_channels = 2
     batch_size = 1
-    slice_length = int(1024/dataset_cfg.num_slices)
+    slice_length = int(1024 / dataset_cfg.num_slices)
     num_classes = inference_cfg.num_classes
     model = RFClassifier(
         model=mobilenetv3(
             model_size="mobilenetv3_small_050",
@@ -36,18 +34,17 @@ def convert_to_onnx(ckpt_path, fp16=False):
             in_chans=in_channels,
         )
     )
     checkpoint = torch.load(
-        ckpt_path, weights_only = True, map_location=torch.device("cpu")
+        ckpt_path, weights_only=True, map_location=torch.device("cpu")
     )
     model.load_state_dict(checkpoint["state_dict"])
     if fp16:
         model.half()
     model.eval()
     # Generate random sample data
     base, ext = os.path.splitext(os.path.basename(ckpt_path))
     if fp16:
@@ -58,8 +55,7 @@ def convert_to_onnx(ckpt_path, fp16=False):
     else:
         output_path = os.path.join(ONNX_DIR, f"{base}_fp32.onnx")
     sample_input = torch.rand(batch_size, in_channels, slice_length)
     torch.onnx.export(
         model=model,
         args=sample_input,
@@ -71,14 +67,17 @@ def convert_to_onnx(ckpt_path, fp16=False):
         output_names=["output"],
         dynamo=False,  # Requires onnxscript
     )
 if __name__ == "__main__":
     from checkpoint_files import CHECKPOINTS_DIR
     model_checkpoint = "interference_recognition_model.ckpt"
     print("Converting to ONNX...")
-    convert_to_onnx(ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False)
-    print("Conversion complete stored at: ", os.path.join(ONNX_DIR, model_checkpoint))
+    convert_to_onnx(
+        ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False
+    )
+    print("Conversion complete stored at: ", os.path.join(ONNX_DIR, model_checkpoint))

View File

@@ -23,7 +23,6 @@ info_dtype = np.dtype(
 )
-
 def write_hdf5_file(records, output_path, dataset_name="data"):
     """
     Writes a list of records to an HDF5 file.
@@ -77,11 +76,12 @@ def write_hdf5_file(records, output_path, dataset_name="data"):
     return output_path
 def complex_to_channel(data):
     """
     Convert complex-valued IQ data of shape (1, N) to 2-channel real array of shape (2, N).
     """
-    assert np.iscomplexobj(data) #check if the data is in the form a+bi
+    assert np.iscomplexobj(data)  # check if the data is in the form a+bi
     real = np.real(data[0])  # (N,)
     imag = np.imag(data[0])  # (N,)
     stacked = np.stack([real, imag], axis=0)  # shape (2, N)
@@ -114,21 +114,17 @@ def generate_datasets(cfg):
     for fname in files:
         rec = from_npy(os.path.join(cfg.input_dir, fname))
-        data = rec.data #here data is a numpy array with the shape (1, N)
+        data = rec.data  # here data is a numpy array with the shape (1, N)
         data = complex_to_channel(data)  # convert to 2-channel real array
         md = rec.metadata  # pull metadata from the recording
         md.setdefault("recid", len(records))
         records.append((data, md))
     # split each recording into <num_slices> snippets each
     records = split_recording(records, cfg.num_slices)
     train_records, val_records = split(records, cfg.train_split, cfg.seed)
@@ -147,6 +143,6 @@ def main():
     train_path, val_path = generate_datasets(dataset_cfg)
     print(f"✅ Train: {train_path}\n✅ Val: {val_path}")
 if __name__ == "__main__":
     main()

View File

@@ -2,7 +2,7 @@ import random
 from collections import defaultdict
-def split(dataset, train_frac=0.8, seed=42, label_key = "modulation"):
+def split(dataset, train_frac=0.8, seed=42, label_key="modulation"):
     """
     Splits a dataset into smaller datasets based on the specified lengths.
@@ -16,31 +16,29 @@ def split(dataset, train_frac=0.8, seed=42, label_key="modulation"):
     rec_buckets = defaultdict(list)
     for data, md in dataset:
         rec_buckets[md["recid"]].append((data, md))
-    rec_labels = {} #store labels for each recording
+    rec_labels = {}  # store labels for each recording
     for rec_id, group in rec_buckets.items():
         label = group[0][1][label_key]
-        if isinstance(label, bytes): #if the label is a byte string
+        if isinstance(label, bytes):  # if the label is a byte string
             label = label.decode("utf-8")
         rec_labels[rec_id] = label
-    label_rec_ids = defaultdict(list) #group rec_ids by label
+    label_rec_ids = defaultdict(list)  # group rec_ids by label
     for rec_id, label in rec_labels.items():
         label_rec_ids[label].append(rec_id)
     random.seed(seed)
     train_recs, val_recs = set(), set()
     for label, rec_ids in label_rec_ids.items():
         random.shuffle(rec_ids)
         split_idx = int(len(rec_ids) * train_frac)
-        train_recs.update(rec_ids[:split_idx]) #pulls train_frac or rec_ids per label, guarantees all modulations are represented
+        train_recs.update(
+            rec_ids[:split_idx]
+        )  # pulls train_frac or rec_ids per label, guarantees all modulations are represented
         val_recs.update(rec_ids[split_idx:])
     # add the assigned recordings to the train and val datasets
     train_dataset, val_dataset = [], []
     for rec_id, group in rec_buckets.items():
@@ -48,10 +46,10 @@ def split(dataset, train_frac=0.8, seed=42, label_key="modulation"):
             train_dataset.extend(group)
         elif rec_id in val_recs:
             val_dataset.extend(group)
     return train_dataset, val_dataset
 def split_recording(recording_list, num_snippets):
     """
     Splits a list of recordings into smaller chunks.
@@ -71,7 +69,7 @@ def split_recording(recording_list, num_snippets):
             start = i * L
             end = (i + 1) * L
             snippet = data[:, start:end]
             # copy the metadata, adding a snippet index
             snippet_md = md.copy()
             snippet_md["snippet_idx"] = i

View File

@@ -197,4 +197,4 @@ class RFClassifier(L.LightningModule):
         self.model = model

     def forward(self, x):
-        return self.model(x)
+        return self.model(x)

View File

@@ -12,7 +12,7 @@ from helpers.app_settings import get_app_settings
 project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
 if project_root not in sys.path:
     sys.path.insert(0, project_root)
 import lightning as L
 import torch
 import torch.nn.functional as F
@@ -24,7 +24,6 @@ from modulation_dataset import ModulationH5Dataset
 import mobilenetv3
-
 def train_model():
     settings = get_app_settings()
     dataset = settings.dataset.modulation_types
@@ -33,16 +32,20 @@ def train_model():
     batch_size = 128
     epochs = 1
-    checkpoint_filename = f'/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model'
+    checkpoint_filename = f"/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model"
-    train_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5'
-    val_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5'
+    train_data = (
+        "/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5"
+    )
+    val_data = (
+        "/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5"
+    )
-    dataset_name = 'Modulation Inference - Initial Model'
-    metadata_names = 'Modulation'
-    label = 'modulation'
+    dataset_name = "Modulation Inference - Initial Model"
+    metadata_names = "Modulation"
+    label = "modulation"
-    torch.set_float32_matmul_precision('high')
+    torch.set_float32_matmul_precision("high")
     ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
     ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")
@@ -69,17 +72,19 @@ def train_model():
     num_classes = len(ds_train.label_encoder.classes_)
     hparams = {
-        'drop_path_rate': 0.2,
-        'drop_rate': 0.5,
-        'learning_rate': 3e-4,
-        'wd': 0.2
+        "drop_path_rate": 0.2,
+        "drop_rate": 0.5,
+        "learning_rate": 3e-4,
+        "wd": 0.2,
     }
     class RFClassifier(L.LightningModule):
         def __init__(self, model):
             super().__init__()
             self.model = model
-            self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
+            self.accuracy = torchmetrics.Accuracy(
+                task="multiclass", num_classes=num_classes
+            )
         def forward(self, x):
             return self.model(x)
@@ -87,26 +92,23 @@ def train_model():
         def configure_optimizers(self):
             optimizer = torch.optim.AdamW(
                 self.parameters(),
-                lr=hparams['learning_rate'],
-                weight_decay=hparams['wd'],
+                lr=hparams["learning_rate"],
+                weight_decay=hparams["wd"],
             )
             lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                 optimizer,
                 T_0=len(train_loader),
             )
             return {
-                'optimizer': optimizer,
-                'lr_scheduler': {
-                    'scheduler': lr_scheduler,
-                    'interval': 'step'
-                }
+                "optimizer": optimizer,
+                "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
             }
         def training_step(self, batch, batch_idx):
             x, y = batch
             y_hat = self(x)
             loss = F.cross_entropy(y_hat, y)
-            self.log('train_loss', loss, on_epoch=True, prog_bar=True)
+            self.log("train_loss", loss, on_epoch=True, prog_bar=True)
             return loss
         def validation_step(self, batch, batch_idx):
@@ -114,15 +116,15 @@ def train_model():
             y_hat = self(x)
             loss = F.cross_entropy(y_hat, y)
             self.accuracy(y_hat, y)
-            self.log('val_loss', loss, prog_bar=True)
-            self.log('val_acc', self.accuracy, prog_bar=True)
+            self.log("val_loss", loss, prog_bar=True)
+            self.log("val_acc", self.accuracy, prog_bar=True)
     model = RFClassifier(
         mobilenetv3.mobilenetv3(
-            model_size='mobilenetv3_small_050',
+            model_size="mobilenetv3_small_050",
             num_classes=num_classes,
-            drop_rate=hparams['drop_rate'],
-            drop_path_rate=hparams['drop_path_rate']
+            drop_rate=hparams["drop_rate"],
+            drop_path_rate=hparams["drop_path_rate"],
         )
     )
@@ -130,24 +132,24 @@ def train_model():
         filename=checkpoint_filename,
         save_top_k=True,
         verbose=True,
-        monitor='val_acc',
-        mode='max',
+        monitor="val_acc",
+        mode="max",
         enable_version_counter=False,
     )
     trainer = L.Trainer(
         max_epochs=epochs,
         callbacks=[checkpoint_callback],
-        accelerator='gpu',
+        accelerator="gpu",
         devices=1,
         benchmark=True,
-        precision='bf16-mixed',
-        logger=False
+        precision="bf16-mixed",
+        logger=False,
     )
     if train_flag:
         trainer.fit(model, train_loader, val_loader)
-if __name__ == '__main__':
-    train_model()
+if __name__ == "__main__":
+    train_model()

View File

@@ -4,10 +4,12 @@ from functools import lru_cache
 import yaml
+
 @dataclass
 class GeneralConfig:
-    run_mode: str
+    run_mode: str
+
 @dataclass
 class DataSetConfig:
     input_dir: str
@@ -17,7 +19,8 @@ class DataSetConfig:
     modulation_types: list
     val_split: float
     output_dir: str
+
 @dataclass
 class TrainingConfig:
     batch_size: int
@@ -26,16 +29,19 @@ class TrainingConfig:
     checkpoint_path: str
     use_gpu: bool
+
 @dataclass
 class InferenceConfig:
     model_path: str
     num_classes: int
     output_path: str
+
 @dataclass
 class AppConfig:
     build_dir: str
+
 class AppSettings:
     """Application settings, to be initialized from app.yaml configuration file."""
@@ -51,6 +57,7 @@ class AppSettings:
         self.inference = InferenceConfig(**config_data["inference"])
         self.app = AppConfig(**config_data["app"])
+
 @lru_cache
 def get_app_settings() -> AppSettings:
     """Return application configuration settings."""
@@ -60,4 +67,4 @@ def get_app_settings() -> AppSettings:
 if __name__ == "__main__":
-    s = get_app_settings()
+    s = get_app_settings()