ria-toolkit-oss/src/ria_toolkit_oss/data/datasets/radio_dataset.py
M madrigal 8a66860d33
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 15m51s
Build Project / Build Project (3.10) (pull_request) Successful in 16m14s
Build Project / Build Project (3.11) (pull_request) Successful in 17m9s
Build Project / Build Project (3.12) (pull_request) Successful in 2m29s
Test with tox / Test with tox (3.12) (pull_request) Successful in 21m28s
Test with tox / Test with tox (3.10) (pull_request) Successful in 22m50s
Test with tox / Test with tox (3.11) (pull_request) Successful in 23m18s
Moved all contents of to , refactored accordingly
2026-04-21 14:38:06 -04:00

1069 lines
44 KiB
Python

from __future__ import annotations
import os
import pathlib
import re
from abc import ABC, abstractmethod
from collections import Counter
from typing import Iterator, Optional
import h5py
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from ria_toolkit_oss.data.datasets.h5helpers import (
append_entry_inplace,
copy_file,
copy_over_example,
delete_example_inplace,
duplicate_entry_inplace,
make_empty_clone,
)
class RadioDataset(ABC):
"""A radio dataset is an iterable dataset designed for machine learning applications in radio signal
processing and analysis. They are a structured collections of examples in a machine learning-ready format,
with associated metadata.
This is an abstract interface defining common properties and behavior of radio datasets. Therefore, this class
should not be instantiated directly. Instead, it should be subclassed to define specific interfaces for different
types of radio datasets. For example, see ria_toolkit_oss.data.datasets.IQDataset, which is a radio dataset
subclass tailored for tasks involving the processing of radio signals represented as IQ (In-phase and Quadrature)
samples.
:param source: Path to the dataset source file. For more information on dataset source files
and their format, see :doc:`radio_datasets`.
:type source: str or os.PathLike
"""
def __init__(self, source: str | os.PathLike):
"""Create a new RadioDataset."""
if not h5py.is_hdf5(source):
raise ValueError(f"Dataset source must be HDF5, {source} is not.")
# TODO: Check to see if source is RIA dataset, and let them know otherwise they should use dataset builder
# utilities to generate a dataset compatible with the RIA framework.
self._source = pathlib.Path(source)
self._index = 0
@property
def source(self) -> pathlib.Path:
"""
:return: Path to the dataset source file.
:type: pathlib.Path
"""
return self._source
@property
def shape(self) -> tuple[int]:
"""
:return: The shape of the dataset. The elements of the shape tuple give the lengths of the corresponding
dataset dimensions.
:type: tuple of ints
"""
with h5py.File(self.source, "r") as f:
return f["data"].shape
@property
def data(self) -> np.ndarray:
"""Retrieve the data from the source file.
.. note::
Accessing this property reads all the data from the source file into memory as a NumPy array, which can
consume significant amounts of memory and potentially degrade performance. Instead, use the
``RadioDataset`` class methods to process and manipulate the dataset source file. You can read individual
examples into memory as NumPy arrays by indexing the dataset: ``RadioDataset[idx]``.
:return: The dataset examples as a single NumPy array.
:type: np.ndarray
"""
with h5py.File(self.source, "r") as f:
return f["data"][:]
@property
def metadata(self) -> pd.DataFrame:
"""Retrieve the metadata from the source file.
.. note::
Accessing this property reads all the metadata from the source file into memory as a Pandas DataFrame.
:return: The dataset metadata as a Pandas DataFrame.
:type: pd.DataFrame
"""
with h5py.File(self.source, "r") as f:
return pd.DataFrame(f["metadata/metadata"][:]).map(decode_bytes)
@property
def labels(self) -> list[str]:
"""Retrieves the metadata labels from the dataset file.
:return: A list of metadata column headers.
:rtype: list of strings
**Examples:**
>>> awgn_builder = AWGN_Builder()
>>> awgn_builder.download_and_prepare()
>>> ds = awgn_builder.as_dataset(backend="pytorch")
>>> print(ds.labels)
['rec_id', 'modulation', 'snr']
"""
with h5py.File(self.source, "r") as f:
return [name for name, _ in f["metadata/metadata"].dtype.fields.items()]
@abstractmethod
def inspect(self):
"""
.. todo:: This method is not yet fully conceptualized. Likely, it will wrap some of the functionality in the
Dataset Inspector package (dataset_manager.inspector) to produce an image or visualization. However,
the Dataset Inspector package is not yet implemented.
"""
# TODO: Implement in subclass based on https://github.com/qoherent/QDM/blob/main/inspection_utils/inspector.py
# Consider removing moving into the dataset builder.
pass
@abstractmethod
def default_augmentations(self) -> list[callable]:
"""Returns a list of default augmentations.
:return: A list of default augmentations.
:rtype: list of callable
"""
pass
def augment( # noqa: C901 # TODO: Simplify function
self,
class_key: str,
augmentations: Optional[callable | list[callable]] = None,
level: Optional[float | list[float]] = 1.0,
target_size: Optional[int | list[int]] = None,
classes_to_augment: Optional[str | list[str]] = None,
inplace: Optional[bool] = False,
) -> RadioDataset | None:
"""
Supplement the dataset with new examples by applying various transformations
to the pre-existing examples in the dataset.
.. todo::
This method is currently under construction, and may produce unexpected results.
The process of supplementing a dataset to artificially increase the diversity
of examples is called augmentation. Training on augmented data can enhance
the generalization and robustness of deep machine learning models. For more
information, see `A Complete Guide to Data Augmentation
<https://www.datacamp.com/tutorial/complete-guide-data-augmentation>`_.
Metadata for each new example will be identical to the metadata of the
pre-existing example from which it was generated. The metadata will be
extended to include an 'augmentation' column, populated with the string
representation of the transform used.
Augmented data should only be used for model training, not for testing or
validation.
Unless specified, augmentations are applied equally across classes, maintaining
the original class distribution.
If target_size does not match the sum of the original class sizes scaled by
an integer multiple, the class distribution is slightly adjusted to satisfy
target_size.
:param class_key: Class name used to augment from and calculate class distribution.
:type class_key: str
:param augmentations: A function or list of functions that take an example
and return a transformed version. Defaults to ``default_augmentations()``.
:type augmentations: callable or list of callables, optional
:param level: The extent of augmentation from 0.0 (none) to 1.0 (full). If
``classes_to_augment`` is specified, can be either:
* A single float: All classes augmented evenly to this level.
* A list of floats: Each element corresponds to the augmentation level
target for the corresponding class.
:type level: float or list of floats, optional
:param target_size: Target size of the augmented dataset. Overrides ``level``
if specified. If ``classes_to_augment`` is specified, can be either:
* A single float: All classes are augmented proportional to their
relative frequency until the dataset reaches target_size.
* A list of floats: Each element corresponds to the target size for the
corresponding class.
:type target_size: int or list of ints, optional
:param classes_to_augment: List of metadata keys of classes to augment.
:type classes_to_augment: string or list of strings, optional
:param inplace: If True, the augmentation is performed inplace and ``None`` is returned.
:type inplace: bool, optional
:raises ValueError: If level has any values not in the range (0,1].
:raises ValueError: If target_size of dataset is already sufficed.
:raises ValueError: If a class in classes_to_augment does not exist in class_key.
:return: The augmented dataset or None if ``inplace=True``.
:rtype: RadioDataset or None
**Examples:**
>>> from ria.dataset_manager.builders import AWGN_Builder
>>> builder = AWGN_Builder()
>>> builder.download_and_prepare()
>>> ds = builder.as_dataset()
>>> ds.get_class_sizes(class_key='col')
{'a': 100, 'b': 500, 'c': 300}
>>> new_ds = ds.augment(class_key='col', classes_to_augment=['a', 'b'], target_size=1200)
>>> new_ds.get_class_sizes(class_key='col')
{'a': 150, 'b': 750, 'c': 300}
"""
if augmentations is None:
augmentations = self.default_augmentations()
if not isinstance(augmentations, list):
augmentations = [augmentations]
if isinstance(level, list):
for i in level:
if i <= 0 or i > 1:
raise ValueError("level must be in this range: (0,1]")
else:
if level <= 0 or level > 1:
raise ValueError("level must be in this range: (0,1]")
class_sizes = self.get_class_sizes(class_key=class_key)
if isinstance(target_size, int) and target_size <= sum(class_sizes.values()):
raise ValueError("target_size must be greater than the total sum of the current class sizes.")
# Encode class names to byte strings and check if all class names exist in class key
if classes_to_augment is not None:
if isinstance(classes_to_augment, list):
classes_to_augment = [cls_name.encode("utf-8") for cls_name in classes_to_augment]
for i in classes_to_augment:
if i not in class_sizes:
raise ValueError(f"class name of {i} does not belong to the class key of {class_key}")
else:
classes_to_augment = classes_to_augment.encode("utf-8")
if classes_to_augment not in class_sizes:
raise ValueError(
f"class name of {classes_to_augment} does not belong to the class key of {class_key}"
)
result_sizes = get_result_sizes(
level=level, target_size=target_size, classes_to_augment=classes_to_augment, class_sizes=class_sizes
)
if "augmentations" not in self.metadata.columns:
# Add metadata column to metadata
raise NotImplementedError
# Create new dataset object in not inplace
if not inplace:
new_source = self._get_next_file_name()
copy_file(original_source=self.source, new_source=new_source)
ds = self.__class__(source=new_source)
else:
ds = self
# Create a dict where each pair is the class name and a list of all indices of the examples of that class
indices_to_add = dict()
with h5py.File(ds.source, "a") as f:
class_labels = f["metadata/metadata"][class_key]
for i in range(len(class_labels)):
current_class = class_labels[i]
if class_sizes[current_class] < result_sizes[current_class] and current_class not in indices_to_add:
indices_to_add[current_class] = []
if class_sizes[current_class] < result_sizes[current_class] and current_class in indices_to_add:
indices_to_add[current_class].append(i)
for key in class_sizes:
if class_sizes[key] < result_sizes[key]:
# Generate a sublist which holds the indices of examples to be augmented
rand_idxs = np.random.choice(indices_to_add[key], result_sizes[key] - class_sizes[key], replace=True)
aug_idx = 0
with h5py.File(ds.source, "a") as f:
data = f["data"]
metadata = f["metadata/metadata"]
for idx in rand_idxs:
rand_example = data[idx]
augmented_example = augmentations[aug_idx](rand_example)
# Update corresponding metadata entry to contain the augmentation that was applied
original_metadata_entry = metadata[idx]
augmented_metadata_entry = original_metadata_entry.copy()
augmented_metadata_entry["augmentations"] = augmentations[aug_idx].__name__
# Update augmentation index after adding name of augmentation to metadata column
if aug_idx < len(augmentations) - 1:
aug_idx += 1
else:
aug_idx = 0
append_entry_inplace(source=ds.source, dataset_path="data", entry=augmented_example)
append_entry_inplace(
source=ds.source, dataset_path="metadata/metadata", entry=augmented_metadata_entry
)
if not inplace:
return ds
def subsample(self, class_key: str, percentage: float, inplace: Optional[bool] = False) -> RadioDataset | None:
"""Reduces the number of examples in all classes of a dataset by randomly subsampling each class according
to a specified percentage. This function reduces the number of examples per class to the specified
percentage without affecting the overall class distribution.
:param class_key: The name of the class to subsample.
:type class_key: str
:param percentage: The percentage of the original class sizes to keep.
:type percentage: float
:param inplace: If True, the operation modifies the existing source file directly and returns None.
If False, the operation creates a new dataset object and corresponding source file, leaving the original
dataset unchanged. Default is False.
:type inplace: bool, optional
:raises ValueError: If the target size of the class with the lowest frequency goes to 0.
:return: The subsampled dataset.
:rtype: RadioDataset or None
**Examples:**
>>> from ria.dataset_manager.builders import AWGN_Builder()
>>> builder = AWGN_Builder()
>>> builder.download_and_prepare()
>>> ds = builder.as_dataset()
>>> ds.get_class_sizes(class_key="col")
{a:100, b:200, c:300}
>>> new_ds = ds.subsample(percentage=0.80, class_key="col")
>>> new_ds.get_class_sizes(class_key="col")
{a:80, b:160, c:240}
"""
class_sizes = self.get_class_sizes(class_key=class_key)
channels, example_length = np.shape(self[0])
target_sizes = dict()
for key in class_sizes:
target_sizes[key] = target_sizes.get(key, int(class_sizes[key] * percentage))
if min(target_sizes.values()) <= 0:
raise ValueError("Subsampling can not be performed on dataset because class size will equal 0")
if not inplace:
ds = self._create_next_dataset(example_length=example_length)
masks = dict()
for key in class_sizes:
masks[key] = masks.get(
key, np.array([1] * target_sizes[key] + [0] * (class_sizes[key] - target_sizes[key]))
)
np.random.shuffle(masks[key])
counters = dict()
for key in class_sizes:
counters[key] = counters.get(key, 0)
idx = 0
with h5py.File(self.source, "r") as f:
while idx < len(self):
labels = f["metadata/metadata"][class_key]
current_class = labels[idx]
current_mask = masks[current_class]
current_mask_value = current_mask[counters[current_class]]
counters[current_class] += 1
if not inplace and current_mask_value == 1:
copy_over_example(self.source, ds.source, idx)
elif inplace and current_mask_value == 0:
delete_example_inplace(self.source, idx)
continue
idx += 1
if not inplace:
return ds
def resample(self, quantity_target: int, class_key: str, inplace: Optional[bool] = False) -> RadioDataset | None:
"""Adjusts an unsampled dataset by changing the number of examples per class to a user-specified quantity.
For each class:
- If there are excess examples, it randomly subsamples the class to the quantity target.
- If there are less examples, it randomly duplicates examples to reach the quantity target.
:param quantity_target: The number of examples each class should have.
:type quantity_target: int
:param class_key: The label of the class to resample.
:type class_key: str
:param inplace: If True, the operation modifies the existing source file directly and returns None.
If False, the operation creates a new dataset object and corresponding source file, leaving the original
dataset unchanged. Default is False.
:type inplace: bool, optional
:return: The resampled dataset.
:rtype: RadioDataset or None
**Examples:**
>>> from ria.dataset_manager.builders import AWGN_Builder()
>>> builder = AWGN_Builder()
>>> builder.download_and_prepare()
>>> ds = builder.as_dataset()
>>> ds.get_class_sizes(class_key="col")
{a:100, b:200, c:300}
>>> new_ds = ds.resample(quantity_target=250, class_key="col")
>>> new_ds.get_class_sizes(class_key="col")
{a:250, b:250, c:250}
"""
if not inplace:
ds = self.homogenize(class_key=class_key, example_limit=quantity_target)
else:
self.homogenize(class_key=class_key, example_limit=quantity_target, inplace=True)
ds = self
class_sizes = ds.get_class_sizes(class_key=class_key)
indices_to_add = dict()
with h5py.File(ds.source, "a") as f:
labels = f["metadata/metadata"][class_key]
for i in range(len(labels)):
current_class = labels[i]
if class_sizes[current_class] < quantity_target and current_class not in indices_to_add:
indices_to_add[current_class] = []
if class_sizes[current_class] < quantity_target and current_class in indices_to_add:
indices_to_add[current_class].append(i)
for key in indices_to_add.keys():
rand_idxs = np.random.choice(indices_to_add[key], quantity_target - class_sizes[key], replace=True)
for idx in rand_idxs:
duplicate_entry_inplace(ds.source, "data", idx)
duplicate_entry_inplace(ds.source, "metadata/metadata", idx)
if not inplace:
return ds
def homogenize(
self, class_key: str, example_limit: Optional[int] = None, inplace: Optional[bool] = False
) -> RadioDataset | None:
"""Discards excess samples by randomly subsampling all classes within a dataset that have more than a
user-specified limit of examples. If the user doesn't specify a limit, the class the with the
fewest examples is selected as the limit.
:param class_key: The label of the class to homogenize.
:type class_key: str
:param example_limit: The class size limit to which all classes are subsampled. If not specified,
the class with the fewest examples is used as the limit. Default is None.
:type example_limit: int, optional
:param inplace: If True, the operation modifies the existing source file directly and returns None.
If False, the operation creates a new dataset cbject and corresponding source file, leaving the original
dataset unchanged. Default is False.
:type inplace: bool, optional
:return: The homogenized dataset.
:rtype: RadioDataset or None
**Examples:**
>>> from ria.dataset_manager.builders import AWGN_Builder()
>>> builder = AWGN_Builder()
>>> builder.download_and_prepare()
>>> ds = builder.as_dataset()
>>> ds.get_class_sizes(class_key="col")
{a:1000, b:5000, c:1500, d:900}
>>> new_ds = ds.homogenize(example_limit=1000, class_key="col")
>>> new_ds.get_class_sizes(class_key="col")
{a:1000, b:1000, c:1000, d:900}
>>> from ria.dataset_manager.builders import AWGN_Builder()
>>> builder = AWGN_Builder()
>>> builder.download_and_prepare()
>>> ds = builder.as_dataset()
>>> ds.get_class_sizes(class_key="col")
{a:1000, b:5000, c:1500, d:900}
>>> new_ds = ds.homogenize(class_key="col")
>>> new_ds.get_class_sizes(class_key="col")
{a:900, b:900, c:900, d:900}
"""
class_sizes = self.get_class_sizes(class_key=class_key)
if example_limit is None:
example_limit = min(class_sizes.values())
channels, example_length = np.shape(self[0])
if not inplace:
ds = self._create_next_dataset(example_length=example_length)
masks, counters = get_masks_and_counters(class_sizes, example_limit)
idx = 0
with h5py.File(self.source, "r") as f:
while idx < len(self):
labels = f["metadata/metadata"][class_key]
current_class = labels[idx]
current_mask = masks[current_class]
if current_mask is None and not inplace:
copy_over_example(self.source, ds.source, idx)
if current_mask is not None:
current_mask_value = current_mask[counters[current_class]]
counters[current_class] += 1
if not inplace and current_mask_value == 1:
copy_over_example(self.source, ds.source, idx)
elif inplace and current_mask_value == 0:
delete_example_inplace(self.source, idx)
continue
idx += 1
if not inplace:
return ds
def drop_class(self, class_key: str, class_value: str, inplace: Optional[bool] = False) -> RadioDataset | None:
"""Removes an entire class from the dataset.
:param class_key: Class that will have a value dropped from it. Example: 'signal_type'
:type class_key: str
:param class_value: Value of the class to be dropped. Example: 'LTE', 'NR'
:type class_value: str
:param inplace: If True, the operation modifies the existing source file directly and returns None.
If False, the operation creates a new dataset cbject and corresponding source file, leaving the original
dataset unchanged. Defaults to False.
:type inplace: bool, optional
:raises ValueError: If the entered class name does not exist in the dataset.
:return: The dataset without the removed class.
:rtype: RadioDataset or None
**Examples:**
>>> from ria.dataset_manager.builders import AWGN_Builder()
>>> builder = AWGN_Builder()
>>> builder.download_and_prepare()
>>> ds = builder.as_dataset()
>>> ds.get_class_sizes()
{a:100, b:500, c:300}
>>> new_ds = ds.drop_class('a')
>>> new_ds.get_class_sizes()
{b:500, c:300}
"""
class_sizes = self.get_class_sizes(class_key=class_key)
if class_value.encode("utf-8") not in class_sizes.keys():
raise ValueError(f"{class_value} is not a class of this dataset.")
channels, example_length = np.shape(self[0])
if not inplace:
ds = self._create_next_dataset(example_length=example_length)
idx = 0
with h5py.File(self.source, "a") as f:
while idx < len(self):
labels = f["metadata/metadata"][class_key]
current_label = labels[idx].decode("utf-8")
if current_label == class_value and inplace:
delete_example_inplace(self.source, idx)
continue
elif current_label != class_value and not inplace:
copy_over_example(self.source, ds.source, idx)
idx += 1
if not inplace:
return ds
def add_label(self, column_name: str, data: ArrayLike, inplace: Optional[bool] = False) -> RadioDataset | None:
"""Add a new metadata label to the dataset.
.. todo:: This method is not yet implemented.
:param column_name: Name of the new metadata column header.
:type inplace: str
:param data: The contents of the new metadata column.
:type inplace: np.typing.ArrayLike
:param inplace: If True, the label is added inplace and ``None`` is returned. Defaults to False.
:type inplace: bool, optional
:raises ValueError: If the length of ``data`` is not equal to the length of the dataset.
:return: The augmented dataset or None if ``inplace=True``.
:rtype: RadioDataset or None
**Examples:**
.. todo:: Usage examples coming soon.
"""
raise NotImplementedError
def get_class_sizes(self, class_key: str) -> dict[str, int]:
"""Returns a dictionary containing the sizes of each class in the dataset at the provided key.
:param class_key: The class label.
:type class_key: str
:raises ValueError: If the specified key is not found in the dataset labels.
:return: A dictionary where each key is a distinct class label, and it's value is the class size.
:rtype: A dictionary where the keys are strings and the values are integers
**Examples:**
>>> from ria.dataset_manager.builders import AWGN_Builder()
>>> spectrogram_sensing_builder = AWGN_Builder()
>>> spectrogram_sensing_builder.download_and_prepare()
>>> ds = spectrogram_sensing_builder.as_dataset(backend="pytorch")
>>> ds.get_class_sizes(class_key='signal_type')
{'LTE': 900, 'NR': 900, 'LTE_NR': 900}
"""
with h5py.File(self.source, "r") as f:
labels = f["metadata/metadata"][class_key]
return dict(Counter(labels))
def delete_example(self, idx: int, inplace: Optional[bool] = False) -> RadioDataset | None:
"""Deletes an example and it's corresponding metadata from the dataset.
:param idx: The index of the example to be deleted.
:type idx: int
:param inplace: If True, the deletion is performed inplace and ``None`` is returned. Defaults to False.
:type inplace: bool, optional
:return: The new dataset or None if ``inplace=True``.
:rtype: RadioDataset or None
**Examples:**
>>> from ria.dataset_manager.builders import AWGN_Builder()
>>> spectrogram_sensing_builder = AWGN_Builder()
>>> spectrogram_sensing_builder.download_and_prepare()
>>> ds = spectrogram_sensing_builder.as_dataset(backend="pytorch")
>>> len(ds)
2700
>>> ds = ds.delete_example(idx=34)
>>> len(ds)
2699
"""
if inplace:
delete_example_inplace(source=self.source, idx=idx)
return None
else:
# The deletion is performed by 1. creating a new source file, 2. copying all contents to the new source
# file, and 3. deleting the example at idx inplace.
new_source = self._get_next_file_name()
copy_file(original_source=self.source, new_source=new_source)
delete_example_inplace(source=new_source, idx=idx)
return self.__class__(source=new_source)
def append(self, example: ArrayLike, metadata: dict) -> None:
"""Append a single example to the end of the dataset. This operation is performed inplace.
.. todo:: This method is not yet implemented.
:param example: The example to append.
:type example: np.typing.ArrayLike
:param metadata: The corresponding metadata dictionary.
:type metadata: dict
:raises ValueError: If example does not the same shape and type as rest of the examples in the dataset.
:return: None.
**Examples:**
.. todo:: Usage examples coming soon.
"""
raise NotImplementedError
def join(self, ds: RadioDataset) -> RadioDataset:
"""Join or merge together two radio datasets.
.. todo:: This method is not yet implemented.
- Duplicate entries are not removed; they are included.
- The examples are not shuffled; examples from ``ds`` are appended at the end.
- Metadata will be expanded to contain all columns.
:param ds: The dataset to merge together with self. Examples from both datasets must have the same shape.
:type ds: bool
:return: The combined dataset.
:rtype: RadioDataset
**Examples:**
.. todo:: Usage examples coming soon.
"""
raise NotImplementedError
def filter(self, mask: ArrayLike, inplace: Optional[bool] = False) -> RadioDataset:
"""Filter the dataset using the provided mask.
.. todo:: This method is not yet implemented.
:param mask: A boolean mask. Where True, keep the corresponding examples. Where False, discard keep the
corresponding examples. The filtering mask is often the result of applying a condition across the elements
of the dataset.
:type mask: array_like
:param inplace: If True, the filter operation is performed inplace and ``None`` is returned. Defaults to False.
:type inplace: bool, optional
:return: The filtered dataset or None if ``inplace=True``.
:rtype: RadioDataset or None
Examples:
.. todo:: Usage examples coming soon!
"""
raise NotImplementedError
def _get_next_file_name(self) -> str:
"""As we manipulate a dataset, we create new source files. That is, unless inplace==True. Each new source
file needs a new name, and so we count up. This function computes and returns the next file name.
If the file has not been manipulated before, it will add `.001` to the end of the file name before the
extension. If there is already a number at the end of the file name, it will update the current number to the
next consecutive number.
For example:
>>> from ria.dataset_manager.builders import AWGN_Builder()
>>> builder = AWGN_Builder()
>>> builder.download_and_prepare()
>>> ds = builder.as_dataset() # my_dataset.hdf5
>>> ds = ds.subsample() # my_dataset.001.hdf5
>>> ds = ds.augment() # my_dataset.002.hdf5
:raises ValueError: If the number at the end of the file name exceeds 999.
:return: The name of the next file, including the file extension
:rtype: str
"""
name, ext = os.path.splitext(str(self.source))
end_of_name = name[-3:]
if re.match(r"^\d+$", end_of_name):
operation_number = int(end_of_name)
operation_number += 1
if operation_number < 10:
operation_number_as_string = f"00{operation_number}" # 1 digits
elif operation_number < 100:
operation_number_as_string = f"0{operation_number}" # 2 digits
elif operation_number < 1000:
operation_number_as_string = f"{operation_number}" # 3 digits
else:
# We assume the maximum number of dataset manipulations will not exceed 999.
raise ValueError("The maximum allowed number of dataset manipulations is 999.")
return f"{name[:-3]}{operation_number_as_string}{ext}"
else:
return f"{name}.001{ext}"
def _create_next_dataset(self, example_length: int) -> RadioDataset:
"""Creates a new empty dataset with a new source file, but with the same file structure as self.source.
:param example_length: The length of the examples in the new dataset.
:type example_length: int
:return: A new dataset with empty data and labels.
:rtype: RadioDataset
"""
new_source = self._get_next_file_name()
make_empty_clone(self.source, new_source, example_length=example_length)
return self.__class__(source=new_source)
def __iter__(self) -> Iterator:
self._index = 0
return self
def __next__(self) -> np.ndarray:
if self._index < len(self):
with h5py.File(self.source, "r") as f:
dataset = f["data"]
result = dataset[self._index]
self._index += 1
return result
else:
raise StopIteration
def __eq__(self, other: RadioDataset) -> bool:
"""Two RadioDatasets are equal iff they share the same source file."""
return self._source == other._source
def __len__(self) -> int:
"""
:return: The number of examples in a dataset.
:rtype: int
"""
return self.shape[0]
def __getitem__(self, key: int | slice | ArrayLike) -> np.ndarray | RadioDataset:
"""If key is an integer read in and return the example at key.
If key is a slice, a new dataset instance is returned, initialized with the data and metadata corresponding
to that slice. However, if key is `[:]`, the data is read and returned as a NumPy array.
If key is array_like, it is interpreted as a boolean mask and used to filter the dataset. In this case, we
return a new instance of the dataset, initialized from a new source file with the filtered data/metadata.
"""
if isinstance(key, int):
with h5py.File(self.source, "r") as file:
return file["data"][key]
elif isinstance(key, slice):
if key == slice(None):
return self.data
else:
# Create and return a new dataset, initialized from a new source file, with the data/metadata at slice.
raise NotImplementedError("Dataset slicing not yet implemented.")
else:
try:
key = np.asarray(key)
if key.dtype == bool:
return self.filter(mask=key)
else:
raise ValueError("Array-like mask must be of boolean type.")
except (TypeError, ValueError):
raise ValueError(f"Indexing with key of type {key} is not supported.")
def __setitem__(self, *args, **kwargs) -> None:
"""Raise an error if an attempt is made to assign to the dataset."""
raise ValueError("Assignment to dataset is not allowed.")
def decode_bytes(cell: any) -> any:
"""If cell is of type bytes, returns the decoded UTF-8 string. Otherwise, returns the input value unchanged."""
if isinstance(cell, bytes):
return cell.decode("utf-8")
return cell
def get_result_sizes( # noqa: C901 # TODO: Simplify function
level: float | list[float],
target_size: int | list[int] | None,
classes_to_augment: str | list[str] | None,
class_sizes: dict,
) -> dict:
"""Returns the desired sizes of each class in the metadata. This is a helper function specifically
used by the augment method.
:param level: The level or extent of data augmentation to apply, ranging from 0.0 (no augmentation) to
1.0 (full augmentation, where each augmentation is applied to each pre-existing example).
:type level: float or list of floats
:param target_size: Target size of the augmented dataset. If specified, ``level`` is ignored, and augmentations
are applied to expand the dataset to contain the specified number of examples.
:type target_size: int or list of ints or None
:param classes_to_augment: List of the classes to augment.
:type classes_to_augment: string or list of strings or None
:param class_sizes: A dictionary where each key-value pair is the class label and the class size.
:type class_sizes: dict
:raises ValueError: If level is a list when classes_to_augment is None.
:raises ValueError: If classes_to_augment and level are lists, but they have different sizes.
:raises ValueError: If target_size is a list when classes_to_augment is None.
:raises ValueError: If classes_to_augment and target_size are lists, but they have different sizes.
:raises ValueError: If classes_to_augment and target_size are lists, but the target_size of a class is already met.
:return: A dictionary where each key is a distinct class label, and it's value is the desired class size.
:rtype: A dictionary where the keys are strings and the values are integers
"""
result_sizes = dict(class_sizes)
if target_size is None:
# Calculate off of level
if classes_to_augment is None:
# Apply to entire dataset, if classes_to_augment is None
if isinstance(level, list):
raise ValueError("Since classes_to_augment is None, level must be a single float value.")
for key in result_sizes:
result_sizes[key] = round(result_sizes[key] + class_sizes[key] * level)
else:
if not isinstance(classes_to_augment, list):
classes_to_augment = [classes_to_augment]
if isinstance(level, list):
if len(level) != len(classes_to_augment):
raise ValueError("If level is a list, there must be one value for each class you wish to augment.")
for index, class_name in enumerate(classes_to_augment):
result_sizes[class_name] = round(result_sizes[class_name] + class_sizes[class_name] * level[index])
else:
for class_name in classes_to_augment:
result_sizes[class_name] = round(result_sizes[class_name] + class_sizes[class_name] * level)
else:
# Calculate off of target_size
if classes_to_augment is None:
# apply to entire dataset, if classes_to_augment is None
if isinstance(target_size, list):
raise ValueError("Since classes_to_augment is None, target_size must be a single int value.")
result_sizes = calculate_size_with_original_distribution(
class_sizes=class_sizes, target_size=target_size, classes_to_augment=classes_to_augment
)
else:
# user specified classes to augment
# if user provides only 1 class convert it to a list
if not isinstance(classes_to_augment, list):
classes_to_augment = [classes_to_augment]
if isinstance(target_size, list):
if len(target_size) != len(classes_to_augment):
raise ValueError(
"If target_size is a list, there must be one value for each class you wish to augment."
)
# Check that each class that will be augmented does not already suffice target_size
for cls_name, target_size_value in zip(classes_to_augment, target_size):
if class_sizes[cls_name] >= target_size_value:
raise ValueError(f"""target_size of {target_size_value} is already sufficed for current size of
{class_sizes[cls_name]} for class: {cls_name}""")
for index, class_name in enumerate(classes_to_augment):
result_sizes[class_name] = target_size[index]
else:
result_sizes = calculate_size_with_original_distribution(
class_sizes=class_sizes, target_size=target_size, classes_to_augment=classes_to_augment
)
return result_sizes
def calculate_size_with_original_distribution( # noqa: C901 # TODO: Simplify function
class_sizes: dict, target_size: int, classes_to_augment: list[str] | None
) -> dict:
"""Returns the desired sizes of each class when target_size is used to calculate the resultant class sizes.
Specifically used as a helper by the get result sizes method.
:param class_sizes: A dictionary where each key-value pair is the class label and the class size.
:type class_sizes: dict
:param target_size: Target size of the augmented dataset.
:type target_size: int
:param classes_to_augment: List of the classes to augment.
:type classes_to_augment: list of strings or None
:return: A dictionary where each key is a distinct class label, and it's value is the desired class size.
:rtype: dict
"""
total_size = sum(class_sizes.values())
if classes_to_augment is None:
scaled_sizes = {cls: (size / total_size) * target_size for cls, size in class_sizes.items()}
rounded_sizes = {cls: round(size) for cls, size in scaled_sizes.items()}
difference = target_size - sum(rounded_sizes.values())
else:
partial_class_size_total = sum(size for cls, size in class_sizes.items() if cls in classes_to_augment)
partial_target_size = target_size - sum(
size for cls, size in class_sizes.items() if cls not in classes_to_augment
)
scaled_sizes = {
cls: (size / partial_class_size_total) * partial_target_size
for cls, size in class_sizes.items()
if cls in classes_to_augment
}
rounded_sizes = {cls: round(size) for cls, size in scaled_sizes.items()}
difference = partial_target_size - sum(rounded_sizes.values())
if difference != 0:
decimals = {cls: scaled_sizes[cls] % 1 for cls in scaled_sizes}
if difference > 0:
sorted_classes = sorted(rounded_sizes, key=decimals.get, reverse=True)
for cls in sorted_classes:
rounded_sizes[cls] += 1
difference -= 1
if difference == 0:
break
else:
sorted_classes = sorted(rounded_sizes, key=decimals.get)
for cls in sorted_classes:
rounded_sizes[cls] -= 1
difference += 1
if difference == 0:
break
# Put back classes that were not chosen to be augmented back into resultant size dictionary
if classes_to_augment is not None:
for cls, count in class_sizes.items():
if cls not in rounded_sizes:
rounded_sizes[cls] = count
return rounded_sizes
def get_masks_and_counters(class_sizes: dict, example_limit: int) -> tuple:
"""
Returns the masks and counters based on the class sizes and example limit of a dataset.
Specifically used for the homogenize method.
:param class_sizes: Dictionary containing the sizes of each class in the dataset.
:type class_sizes: dict
:param example_limit: The class size limit to which all classes are subsampled. If not specified,
the class with the fewest examples is used as the limit.
:type example_limit: int
:return: The mask and counter dictionaries.
:rtype: tuple
"""
masks, counters = dict(), dict()
for key in class_sizes:
if class_sizes[key] <= example_limit:
masks[key] = None
else:
masks[key] = np.array([1] * example_limit + [0] * (class_sizes[key] - example_limit))
np.random.shuffle(masks[key])
counters[key] = 0
return masks, counters