modrec-workflow/scripts/model_builder/modulation_dataset.py
Michael Luciuk 9979d84e29
All checks were successful
Modulation Recognition Demo / ria-demo (push) Successful in 2m52s
Documentation and formatting updates (#1)
Documentation and formatting updates:
- Updates to project README.
- Adding project health files (`LICENSE` and `SECURITY.md`)
- A few minor formatting changes throughout
- A few typo fixes, removal of unused code, cleanup of shadowed variables, and fixed import ordering with isort.

**Note:** These changes have not been tested.

Co-authored-by: Michael Luciuk <michael.luciuk@gmail.com>
Co-authored-by: Liyu Xiao <liyu@qoherent.ai>
Reviewed-on: https://git.riahub.ai/qoherent/modrec-workflow/pulls/1
Reviewed-by: Liyux <liyux@noreply.localhost>
Co-authored-by: Michael Luciuk <michael@qoherent.ai>
Co-committed-by: Michael Luciuk <michael@qoherent.ai>
2025-07-08 10:50:41 -04:00

64 lines
1.8 KiB
Python

import os
import sys
sys.path.insert(0, os.path.abspath("../..")) # or ".." if needed
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from helpers.app_settings import get_app_settings
# Application settings, loaded once when this module is imported.
settings = get_app_settings()
# Configured list of modulation class names (the same list ModulationH5Dataset
# fits its default LabelEncoder on).
# NOTE(review): not referenced in this visible chunk; presumably kept for
# other modules that import it — confirm before removing.
dataset = settings.dataset.modulation_types
class ModulationH5Dataset(Dataset):
    """PyTorch ``Dataset`` reading signal samples and modulation labels from HDF5.

    The file must contain a sample array under ``data_key`` and a structured
    ``metadata/metadata`` dataset whose rows have a ``modulation`` field, one
    row per sample. Samples are read lazily in ``__getitem__``; only the
    sample count and the metadata table are loaded at construction.

    Args:
        hdf5_path: Path to the HDF5 file.
        label_name: Name of the label to use.
            NOTE(review): currently unused — labels are always taken from the
            ``modulation`` metadata field. Kept for interface compatibility.
        data_key: Name of the HDF5 dataset holding the samples.
        label_encoder: Fitted encoder exposing ``transform``. If ``None``, a
            scikit-learn ``LabelEncoder`` is fitted on the configured
            ``settings.dataset.modulation_types``.
        transform: Optional callable applied to each normalized ``float32``
            sample tensor before it is returned.
    """

    def __init__(
        self,
        hdf5_path,
        label_name,
        data_key="training_data",
        label_encoder=None,
        transform=None,
    ):
        self.hdf5_path = hdf5_path
        self.data_key = data_key
        self.label_name = label_name
        self.label_encoder = label_encoder
        self.transform = transform

        with h5py.File(hdf5_path, "r") as f:
            self.length = f[data_key].shape[0]
            self.metadata = f["metadata"]["metadata"][:]

        if self.label_encoder is None:
            from sklearn.preprocessing import LabelEncoder

            # Fit on the full configured class list, not the labels present in
            # this file, so the mapping does not depend on which classes happen
            # to appear here. LabelEncoder assigns indices in sorted class order.
            all_labels = get_app_settings().dataset.modulation_types
            self.label_encoder = LabelEncoder()
            self.label_encoder.fit(all_labels)

        # Per-sample labels come from the "modulation" field of each metadata row.
        raw_labels = [self._decode_label(row["modulation"]) for row in self.metadata]
        self.encoded_labels = self.label_encoder.transform(raw_labels)

    @staticmethod
    def _decode_label(value):
        """Return a metadata label as ``str``, decoding UTF-8 if it is bytes."""
        # h5py may yield bytes (including numpy.bytes_) or str depending on the
        # stored string type; accept both.
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return str(value)

    @staticmethod
    def _normalize(x, eps=1e-6):
        """Standardize *x* to zero mean / unit std along its last axis.

        ``eps`` keeps constant (zero-std) inputs from dividing by zero.
        """
        mean = np.mean(x, axis=-1, keepdims=True)
        std = np.std(x, axis=-1, keepdims=True)
        return (x - mean) / (std + eps)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``(sample, label)`` as ``(float32 tensor, long tensor)``."""
        # The file is reopened per item instead of holding a handle on the
        # instance. NOTE(review): presumably so DataLoader worker processes never
        # share an h5py handle — confirm before caching a handle here.
        with h5py.File(self.hdf5_path, "r") as f:
            x = f[self.data_key][idx]  # shape (1, 128) or similar — TODO confirm

        x = torch.tensor(self._normalize(x), dtype=torch.float32)
        if self.transform is not None:
            x = self.transform(x)
        label = torch.tensor(self.encoded_labels[idx], dtype=torch.long)
        return x, label