62 lines
1.8 KiB
Python
62 lines
1.8 KiB
Python
import sys, os
|
|
|
|
sys.path.insert(0, os.path.abspath("../..")) # or ".." if needed
|
|
import numpy as np
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
import h5py
|
|
from helpers.app_settings import get_app_settings
|
|
|
|
# Resolve the application settings once, when this module is imported.
settings = get_app_settings()

# NOTE: despite its name, ``dataset`` is the value of
# ``settings.dataset.modulation_types`` (the configured modulation types),
# not a Dataset instance.
dataset = settings.dataset.modulation_types
|
|
|
|
|
|
class ModulationH5Dataset(Dataset):
    """Map-style dataset that serves normalized samples from an HDF5 file.

    Samples are read from the HDF5 dataset named ``data_key``; labels come
    from the ``"modulation"`` field of the ``metadata/metadata`` table in the
    same file and are integer-encoded with ``label_encoder``.

    Attributes set by the constructor:
        hdf5_path, data_key, label_name, label_encoder, transform:
            The constructor arguments (``label_encoder`` is replaced by a
            freshly fitted encoder when ``None`` was passed).
        length: Size of the first axis of ``f[data_key]``.
        metadata: In-memory copy of ``f["metadata"]["metadata"]``.
        encoded_labels: Integer-encoded ``"modulation"`` value of every
            metadata row.
    """

    def __init__(
        self,
        hdf5_path,
        label_name,
        data_key="training_data",
        label_encoder=None,
        transform=None,
    ):
        """Read the dataset size and metadata, and encode all labels.

        Args:
            hdf5_path: Path of the HDF5 file to read.
            label_name: Stored on the instance as ``self.label_name``; it is
                not otherwise used by this class.
            data_key: Name of the HDF5 dataset holding the samples.
            label_encoder: Fitted encoder exposing ``transform``. When
                ``None``, a ``sklearn.preprocessing.LabelEncoder`` is fitted
                on ``get_app_settings().dataset.modulation_types``.
            transform: Optional callable applied to each sample tensor after
                normalization (see ``__getitem__``).
        """
        super().__init__()
        self.hdf5_path = hdf5_path
        self.data_key = data_key
        self.label_name = label_name
        self.label_encoder = label_encoder
        self.transform = transform

        # Only the sample count and the metadata table are loaded eagerly;
        # the samples themselves are read on demand in ``__getitem__``.
        with h5py.File(hdf5_path, "r") as f:
            self.length = f[data_key].shape[0]
            self.metadata = f["metadata"]["metadata"][:]

        if self.label_encoder is None:
            # Imported lazily so scikit-learn is only required when the
            # caller does not supply an encoder.
            from sklearn.preprocessing import LabelEncoder

            # Settings are only needed to build the default encoder, so they
            # are not loaded when an encoder was passed in.
            all_labels = get_app_settings().dataset.modulation_types
            self.label_encoder = LabelEncoder()
            self.label_encoder.fit(all_labels)

        # Per-sample labels come from the metadata table.
        raw_labels = [self._decode_label(row["modulation"]) for row in self.metadata]
        self.encoded_labels = self.label_encoder.transform(raw_labels)

    @staticmethod
    def _decode_label(value):
        """Return ``value`` as ``str``, decoding UTF-8 when it is ``bytes``."""
        if isinstance(value, bytes):
            return value.decode("utf-8")
        return str(value)

    def _prepare_sample(self, raw):
        """Normalize ``raw`` along its last axis and return a float32 tensor.

        The array is shifted to zero mean and scaled by its standard
        deviation over the last axis, converted to ``torch.float32``, and
        finally passed through ``self.transform`` when one is set.
        """
        mean = np.mean(raw, axis=-1, keepdims=True)
        std = np.std(raw, axis=-1, keepdims=True)
        # The 1e-6 offset keeps the division finite for constant signals.
        sample = torch.tensor((raw - mean) / (std + 1e-6), dtype=torch.float32)

        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        """Return the number of samples (first axis of the data array)."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for position ``idx``.

        ``sample`` is the normalized float32 tensor produced by
        ``_prepare_sample``; ``label`` is a ``torch.long`` tensor holding the
        encoded label at ``idx``.
        """
        # The file is opened and closed on every access, so no HDF5 handle is
        # ever stored on the instance.
        with h5py.File(self.hdf5_path, "r") as f:
            raw = f[self.data_key][idx]  # shape (1, 128) or similar

        sample = self._prepare_sample(raw)
        label = torch.tensor(self.encoded_labels[idx], dtype=torch.long)
        return sample, label
|