modrec-workflow/scripts/training/modulation_dataset.py
Liyu Xiao b14cc2cce5
Some checks failed
RIA Hub Workflow Demo / ria-demo (push) Failing after 32s
deleted files, updated workflow
2025-06-13 13:58:35 -04:00

62 lines
1.8 KiB
Python

import sys, os
sys.path.insert(0, os.path.abspath("../..")) # or ".." if needed
import numpy as np
import torch
from torch.utils.data import Dataset
import h5py
from helpers.app_settings import get_app_settings
# Load the application settings once, at import time.
settings = get_app_settings()
# Modulation type names listed under the dataset section of the settings.
dataset = settings.dataset.modulation_types
class ModulationH5Dataset(Dataset):
    """Map-style dataset of normalized samples and encoded modulation labels.

    Samples are read from ``hdf5_path[data_key]``; per-sample labels come from
    the ``"modulation"`` field of the ``metadata/metadata`` table in the same
    file and are converted to integer classes with ``label_encoder``.

    Args:
        hdf5_path: Path of the HDF5 file to read.
        label_name: Stored on the instance as ``self.label_name``. It is not
            used to select the metadata field; that field is always
            ``"modulation"``.
        data_key: Name of the HDF5 dataset holding the samples.
        label_encoder: Already-fitted encoder exposing ``transform``. When
            ``None``, a ``sklearn.preprocessing.LabelEncoder`` is fitted on
            the modulation types listed in the application settings.
        transform: Optional callable applied to the normalized ``float32``
            tensor of each sample before it is returned.
    """

    def __init__(
        self,
        hdf5_path,
        label_name,
        data_key="training_data",
        label_encoder=None,
        transform=None,
    ):
        self.hdf5_path = hdf5_path
        self.data_key = data_key
        self.label_name = label_name
        self.label_encoder = label_encoder
        self.transform = transform

        # Only the sample count and the metadata table are read here; the
        # file is closed again as soon as the block exits.
        with h5py.File(hdf5_path, "r") as f:
            self.length = f[data_key].shape[0]
            self.metadata = f["metadata"]["metadata"][:]

        if self.label_encoder is None:
            # Settings are only needed to build the default encoder, so they
            # are not loaded when the caller supplies one.
            from sklearn.preprocessing import LabelEncoder

            app_settings = get_app_settings()
            self.label_encoder = LabelEncoder()
            self.label_encoder.fit(app_settings.dataset.modulation_types)

        # Per-sample labels from metadata, encoded once up front.
        raw_labels = [
            self._decode_label(row["modulation"]) for row in self.metadata
        ]
        self.encoded_labels = self.label_encoder.transform(raw_labels)

    @staticmethod
    def _decode_label(value):
        """Return a metadata label as ``str``, decoding UTF-8 bytes if needed."""
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return str(value)

    @staticmethod
    def _normalize(x):
        """Standardize ``x`` along its last axis to zero mean and unit std."""
        mean = np.mean(x, axis=-1, keepdims=True)
        std = np.std(x, axis=-1, keepdims=True)
        # The small epsilon keeps constant rows from dividing by zero.
        return (x - mean) / (std + 1e-6)

    def __len__(self):
        """Return the number of samples (first dimension of the data)."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(x, label)`` for sample ``idx``.

        ``x`` is the normalized sample as a ``float32`` tensor, passed through
        ``transform`` when one was given. ``label`` is the encoded class as a
        ``long`` tensor.
        """
        # The file is opened per access, so no h5py handle is kept on the
        # instance between calls.
        with h5py.File(self.hdf5_path, "r") as f:
            x = f[self.data_key][idx]  # shape (1, 128) or similar

        x = torch.tensor(self._normalize(x), dtype=torch.float32)
        if self.transform is not None:
            x = self.transform(x)

        label = torch.tensor(self.encoded_labels[idx], dtype=torch.long)
        return x, label