Adding unit tests

This commit is contained in:
Michael Luciuk 2025-09-04 14:40:24 -04:00
parent b4b8d27bfd
commit 42211a7453
10 changed files with 528 additions and 9 deletions

View File

@ -7,9 +7,9 @@ from typing import Any, Optional
from packaging.version import Version from packaging.version import Version
from ria_toolkit_oss.utils.abstract_attribute import abstract_attribute
from ria_toolkit_oss.datatypes.datasets.license.dataset_license import DatasetLicense from ria_toolkit_oss.datatypes.datasets.license.dataset_license import DatasetLicense
from ria_toolkit_oss.datatypes.datasets.radio_dataset import RadioDataset from ria_toolkit_oss.datatypes.datasets.radio_dataset import RadioDataset
from ria_toolkit_oss.utils.abstract_attribute import abstract_attribute
class DatasetBuilder(ABC): class DatasetBuilder(ABC):

View File

@ -7,7 +7,10 @@ import numpy as np
from numpy.random import Generator from numpy.random import Generator
from ria_toolkit_oss.datatypes.datasets import RadioDataset from ria_toolkit_oss.datatypes.datasets import RadioDataset
from ria_toolkit_oss.datatypes.datasets.h5helpers import copy_over_example, make_empty_clone from ria_toolkit_oss.datatypes.datasets.h5helpers import (
copy_over_example,
make_empty_clone,
)
def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDataset]: def split(dataset: RadioDataset, lengths: list[int | float]) -> list[RadioDataset]:
@ -123,7 +126,8 @@ def random_split(
training and test datasets. training and test datasets.
This restriction makes it unlikely that a random split will produce datasets with the exact lengths specified. This restriction makes it unlikely that a random split will produce datasets with the exact lengths specified.
If it is important to ensure the closest possible split, consider using ria_toolkit_oss.datatypes.datasets.split instead. If it is important to ensure the closest possible split, consider using ria_toolkit_oss.datatypes.datasets.split
instead.
:param dataset: Dataset to be split. :param dataset: Dataset to be split.
:type dataset: RadioDataset :type dataset: RadioDataset

View File

@ -233,7 +233,7 @@ class Recording:
:return: Data-type of the data array's elements. :return: Data-type of the data array's elements.
:type: numpy dtype object :type: numpy dtype object
""" """
return self.datatypes.dtype return self.data.dtype
@property @property
def timestamp(self) -> float | int: def timestamp(self) -> float | int:
@ -282,7 +282,7 @@ class Recording:
# cross-platform support where the types are aliased across platforms. # cross-platform support where the types are aliased across platforms.
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") # Casting may generate user warnings. E.g., complex -> real warnings.simplefilter("ignore") # Casting may generate user warnings. E.g., complex -> real
data = self.datatypes.astype(dtype) data = self.data.astype(dtype)
if np.iscomplexobj(data): if np.iscomplexobj(data):
return Recording(data=data, metadata=self.metadata, annotations=self.annotations) return Recording(data=data, metadata=self.metadata, annotations=self.annotations)

View File

@ -188,7 +188,7 @@ def to_sigmf(recording: Recording, filename: Optional[str] = None, path: Optiona
meta_dict = sigMF_metafile.ordered_metadata() meta_dict = sigMF_metafile.ordered_metadata()
meta_dict["ria"] = metadata meta_dict["ria"] = metadata
sigMF_metafile.tofile(f"{os.path.join(path,filename)}.sigmf-meta") sigMF_metafile.tofile(f"{os.path.join(path, filename)}.sigmf-meta")
def from_sigmf(file: os.PathLike | str) -> Recording: def from_sigmf(file: os.PathLike | str) -> Recording:
@ -205,6 +205,7 @@ def from_sigmf(file: os.PathLike | str) -> Recording:
:rtype: ria_toolkit_oss.datatypes.Recording :rtype: ria_toolkit_oss.datatypes.Recording
""" """
file = str(file)
if len(file) > 11: if len(file) > 11:
if file[-11:-5] != ".sigmf": if file[-11:-5] != ".sigmf":
file = file + ".sigmf-data" file = file + ".sigmf-data"

View File

@ -679,7 +679,7 @@ def patch_shuffle(signal: ArrayLike | Recording, max_patch_size: Optional[int] =
array([[2+5j, 1+8j, 3+4j, 6+9j, 4+7j]]) array([[2+5j, 1+8j, 3+4j, 6+9j, 4+7j]])
""" """
if isinstance(signal, Recording): if isinstance(signal, Recording):
data = signal.datatypes.copy() # Cannot shuffle read-only array. data = signal.data.copy() # Cannot shuffle read-only array.
else: else:
data = np.asarray(signal) data = np.asarray(signal)

View File

@ -0,0 +1,69 @@
from ria_toolkit_oss.datatypes import Annotation
def test_annotation_creation():
# Test creating an Annotation instance
sample_start = 100
sample_count = 200
freq_upper_edge = 1000.0
freq_lower_edge = 500.0
label = "Event"
comment = "This is a test annotation"
annotation = Annotation(
sample_start=sample_start,
sample_count=sample_count,
freq_lower_edge=freq_lower_edge,
freq_upper_edge=freq_upper_edge,
label=label,
comment=comment,
)
assert annotation.sample_start == sample_start
assert annotation.sample_count == sample_count
assert annotation.freq_lower_edge == freq_lower_edge
assert annotation.freq_upper_edge == freq_upper_edge
assert annotation.label == label
assert annotation.comment == comment
def test_annotation_overlap():
annotation_1 = Annotation(sample_start=0, sample_count=2, freq_lower_edge=0, freq_upper_edge=2)
annotation_2 = Annotation(sample_start=1, sample_count=2, freq_lower_edge=1, freq_upper_edge=3)
assert annotation_1.overlap(annotation_2) == 1
def test_annotation_equality():
# Test equality of two Annotation instances with the same attributes.
annotation1 = Annotation(100, 200, 1000.0, 500.0, "Event", "Comment 1")
annotation2 = Annotation(100, 200, 1000.0, 500.0, "Event", "Comment 1")
assert annotation1 == annotation2
def test_annotation_inequality():
# Test inequality of two Annotation instances with the different attributes.
annotation1 = Annotation(100, 300, 1000.0, 500.0, "Event", "Comment 1")
annotation2 = Annotation(100, 200, 1000.0, 500.0, "Event", "Comment 1")
assert annotation1 != annotation2
def test_annotation_validity():
# Test annotations' validity by checking illegal inputs (sample count and frequency edges)
annotation1 = Annotation(100, 0, 1000.0, 3000.0, "Event", "Comment 1")
annotation2 = Annotation(100, 300, 1000.0, 500.0, "Event", "Comment 2")
annotation3 = Annotation(100, 300, 1000.0, 3000.0, "Event", "Comment 3")
assert annotation1.is_valid() is False
assert annotation2.is_valid() is False
assert annotation3.is_valid() is True
def test_annotation_area():
# Test annotation area
sample_annotation = Annotation(100, 300, 1000.0, 3000.0, "Event", "Comment")
annotation_area = sample_annotation.area()
assert annotation_area == 600000

View File

@ -0,0 +1,220 @@
from typing import Iterable
import numpy as np
import pytest
from ria_toolkit_oss.datatypes import Annotation, Recording
from ria_toolkit_oss.datatypes.recording import generate_recording_id
COMPLEX_DATA_1 = [[0.5 + 0.5j, 0.1 + 0.1j, 0.3 + 0.3j, 0.4 + 0.4j, 0.5 + 0.5j]]
COMPLEX_DATA_2 = [
[0.5 + 0.5j, 0.1 + 0.1j, 0.3 + 0.3j, 0.4 + 0.4j, 0.5 + 0.5j],
[0.5 + 0.5j, 0.1 + 0.1j, 0.3 + 0.3j, 0.4 + 0.4j, 0.5 + 0.5j],
[0.5 + 0.5j, 0.1 + 0.1j, 0.3 + 0.3j, 0.45 + 0.45j, 0.5 + 0.5j],
]
REAL_DATA_2 = [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]
SAMPLE_METADATA = {"source": "test", "timestamp": 1723472227.698788}
COMPLEX_DATA_OUT_OF_RANGE = [[1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j]]
def test_without_metadata():
# Verify we can create a new Recording object without specifying metadata
rec = Recording(data=COMPLEX_DATA_1)
assert np.array_equal(rec.data, np.asarray(COMPLEX_DATA_1))
# The following class attributes should be initialized automatically.
assert "rec_id" in rec.metadata
assert "timestamp" in rec.metadata
assert len(rec.rec_id) == 64
def test_1d_input():
# Verify the recording works with 1D complex array or as input.
x = [0.5 + 0.5j, 0.1 + 0.1j, 0.3 + 0.3j, 0.4 + 0.4j, 0.5 + 0.5j]
rec = Recording(data=x)
assert np.array_equal(rec.data, np.asarray([x]))
def test_with_sample_metadata():
# Test creating a new Recording without specifying metadata
rec = Recording(data=COMPLEX_DATA_1, metadata=SAMPLE_METADATA)
expected_metadata = SAMPLE_METADATA.copy()
expected_metadata["rec_id"] = generate_recording_id(data=rec.data, timestamp=SAMPLE_METADATA["timestamp"])
assert np.array_equal(rec.data, np.array(COMPLEX_DATA_1))
assert rec.metadata == expected_metadata
sample_rate = 10e5
rec.add_to_metadata(key="sample_rate", value=sample_rate)
assert rec.metadata["sample_rate"] == sample_rate
with pytest.raises(ValueError):
rec.add_to_metadata(key="SampleRate", value=sample_rate) # Invalid key
with pytest.raises(ValueError):
rec.add_to_metadata(key="rec", value=Recording) # Invalid value
# with pytest.raises(ValueError):
# rec.update_metadata(key="rec_id", value=45) # protected key
rec.update_metadata(key="source", value="foo")
assert rec.metadata["source"] == "foo"
rec.metadata["source"] = "boo" # Expect statement to have no effect
assert rec.metadata["source"] == "foo"
def test_property_assignment():
# Verify protected properties cannot be set.
rec = Recording(data=COMPLEX_DATA_1)
with pytest.raises(AttributeError):
rec.data = COMPLEX_DATA_1
with pytest.raises(AttributeError):
rec.metadata = SAMPLE_METADATA
def test_sample_rate():
# Test Recording.sample_rate property.
recording = Recording(data=COMPLEX_DATA_1)
sample_rate = 100
recording.sample_rate = sample_rate
assert recording.sample_rate == sample_rate
def test_equality():
# Test recording equality
# We expect these two recordings to be equal because there were generated with the same data and timestamps.
recording1 = Recording(data=COMPLEX_DATA_1, metadata=SAMPLE_METADATA)
recording2 = Recording(data=COMPLEX_DATA_1, metadata=SAMPLE_METADATA)
assert recording1 == recording2
meta_w_rec_id = {"rec_id": "e08603ebcd4c481be8e0204992170386e72623baa1a91ca80a714de8ffda3452"}
recording1 = Recording(data=COMPLEX_DATA_1, metadata=meta_w_rec_id)
recording2 = Recording(data=COMPLEX_DATA_1, metadata=meta_w_rec_id)
assert recording1 == recording2
recording1 = Recording(data=COMPLEX_DATA_1)
recording2 = Recording(data=COMPLEX_DATA_1, timestamp=recording1.timestamp + 0.001)
assert recording1 != recording2
def test_shape_len():
# Verify that the shape parameter and calling len() on Recording objects works as expected.
rec = Recording(data=COMPLEX_DATA_1)
assert len(rec) == np.asarray(COMPLEX_DATA_1).shape[1]
assert rec.shape == np.shape(np.asarray(COMPLEX_DATA_1))
def test_iterator():
# Test the iterator returned by __iter__.
rec = Recording(data=COMPLEX_DATA_2)
assert isinstance(rec, Iterable)
iterator = rec.__iter__()
values = np.asarray([next(iterator) for _ in range(len(COMPLEX_DATA_2))])
assert np.array_equal(values, np.asarray(COMPLEX_DATA_2))
# Confirm we can iterate over the recording object, which works just the same as iterating over the data.
arr = np.asarray(COMPLEX_DATA_2)
x = np.full(shape=np.shape(arr), fill_value=np.nan, dtype=np.asarray(arr).dtype)
for c, channel in enumerate(rec):
for n, sample in enumerate(channel):
x[c, n] = sample
assert np.array_equal(arr, x)
def test_normalize():
# Check that the max of normalized data is 1
rec = Recording(data=COMPLEX_DATA_OUT_OF_RANGE)
normalized_rec = rec.normalize()
assert np.isclose(np.max(abs(normalized_rec.data)), 1)
assert rec.metadata == normalized_rec.metadata
# Check that the normalized recording is a scaled version of original data.
ratios = normalized_rec.data / rec.data
unique_ratios = np.unique(ratios)
assert len(unique_ratios) == 1
def test_asdtype():
# Verify we can cast to other complex scalar types, but not to any other type.
rec_64 = Recording(data=COMPLEX_DATA_1, dtype=np.complex64)
assert rec_64.dtype == np.complex64
rec_128 = rec_64.astype(dtype=np.complex128)
assert rec_128.dtype == np.complex128
with pytest.raises(ValueError):
rec_128.astype(np.float64)
rec_128.astype(np.bool_)
rec_64.astype(str)
def test_indexing():
# Verify recording indexing, slicing, and filtering works as expected using the [] syntax.
rec = Recording(data=COMPLEX_DATA_2)
for i in range(rec.n_chan):
assert np.array_equal(rec[i], COMPLEX_DATA_2[i])
assert rec[2, 3] == 0.45 + 0.45j
assert np.array_equal(rec[:], rec.data)
assert np.array_equal(rec[:, 0], np.asarray([0.5 + 0.5j, 0.5 + 0.5j, 0.5 + 0.5j]))
def test_trim_1():
# ensure trimming works as expected including shifting annotations
anno1 = Annotation(sample_start=15, sample_count=10, freq_lower_edge=-1, freq_upper_edge=1, label="anno")
anno2 = Annotation(sample_start=5, sample_count=10, freq_lower_edge=-1, freq_upper_edge=1, label="anno")
anno3 = Annotation(sample_start=12, sample_count=2, freq_lower_edge=-1, freq_upper_edge=1, label="anno")
annotations = [anno1, anno2, anno3]
data = np.array(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
dtype=np.complex64,
)
orig_rec = Recording(data=data, metadata=SAMPLE_METADATA, annotations=annotations)
trimmed_rec = orig_rec.trim(start_sample=10, num_samples=10)
assert len(trimmed_rec) == 10
assert np.array_equal(trimmed_rec.data[0], np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.complex64))
shifted_anno1 = trimmed_rec.annotations[0]
shifted_anno2 = trimmed_rec.annotations[1]
shifted_anno3 = trimmed_rec.annotations[2]
assert shifted_anno1.sample_start == 5
assert shifted_anno1.sample_count == 5
assert shifted_anno2.sample_start == 0
assert shifted_anno2.sample_count == 5
assert shifted_anno3.sample_start == 2
assert shifted_anno3.sample_count == 2
def test_remove_from_metadata_1():
data = COMPLEX_DATA_2
metadata = {"source": "test", "timestamp": 1723472227.698788}
recording = Recording(data=data, metadata=metadata)
recording.remove_from_metadata("source")
with pytest.raises(ValueError):
recording.remove_from_metadata("timestamp")

View File

@ -0,0 +1,173 @@
import numpy as np
from ria_toolkit_oss.datatypes import Annotation, Recording
from ria_toolkit_oss.io.recording import (
from_npy,
from_sigmf,
load_rec,
to_npy,
to_sigmf,
)
complex_data_1 = np.array([0.5 + 0.5j, 0.1 + 0.1j, 0.3 + 0.3j, 0.4 + 0.4j, 0.5 + 0.5j], dtype=np.complex64)
real_data_1 = np.array([[0.5, 0.1, 0.3, 0.4, 0.5], [0.5, 0.1, 0.3, 0.4, 0.5]])
sample_metadata = {"source": "test", "timestamp": 1723472227.698788}
nd_complex_data_1 = np.array(
[
[0.5 + 0.5j, 0.1 + 0.1j, 0.3 + 0.3j, 0.4 + 0.4j, 0.5 + 0.5j],
[0.5 + 0.5j, 0.1 + 0.1j, 0.3 + 0.3j, 0.4 + 0.4j, 0.5 + 0.5j],
]
)
nd_real_data_1 = np.array([[0.1, 0.2, 0.3], [0.1, 0.2, 0.3], [0.1, 0.2, 0.3]])
complex_data_out_of_range_1 = np.array([1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j])
def test_npy_save_1(tmp_path):
# Create test recording
recording1 = Recording(data=complex_data_1, metadata=sample_metadata)
# Save to tmp_path
filename = tmp_path / "test"
to_npy(filename=filename.name, path=tmp_path, recording=recording1)
# Reload
recording2 = from_npy(filename)
# Verify
assert np.array_equal(recording1.data, recording2.data)
assert recording1.metadata == recording2.metadata
def test_npy_save_2(tmp_path):
# Create test recording
recording1 = Recording(data=nd_complex_data_1, metadata=sample_metadata)
# Save to tmp_path
filename = tmp_path / "test"
to_npy(filename=filename.name, path=tmp_path, recording=recording1)
# Reload
recording2 = from_npy(filename)
# Verify
assert np.array_equal(recording1.data, recording2.data)
assert recording1.metadata is not None
assert recording1.metadata == recording2.metadata
# Check that metadata is loaded properly as a dict
assert recording1.metadata.get("source") == recording2.metadata.get("source")
def test_npy_save_3(tmp_path):
# Create test recording without metadata
recording1 = Recording(data=nd_complex_data_1)
# Save to tmp_path
filename = tmp_path / "test"
to_npy(filename=filename.name, path=tmp_path, recording=recording1)
# Reload
recording2 = from_npy(filename)
# Verify
assert np.array_equal(recording1.data, recording2.data)
assert recording1.metadata == recording2.metadata
def test_npy_annotations(tmp_path):
# Create annotations
annotation1 = Annotation(sample_start=0, sample_count=100, freq_lower_edge=0, freq_upper_edge=100)
annotation2 = Annotation(sample_start=1, sample_count=101, freq_lower_edge=1, freq_upper_edge=101)
annotations = [annotation1, annotation2]
# Create test recording with annotations
recording1 = Recording(data=nd_complex_data_1, metadata=sample_metadata, annotations=annotations)
# Save to tmp_path
filename = tmp_path / "test"
to_npy(filename=filename.name, path=tmp_path, recording=recording1)
# Reload
recording2 = from_npy(filename)
# Verify annotations
assert recording1.annotations == recording2.annotations
def test_load_recording_npy(tmp_path):
# test to_npy and load_recording methods with npy
annotation1 = Annotation(sample_start=0, sample_count=1, freq_lower_edge=0, freq_upper_edge=1)
annotation2 = Annotation(sample_start=1, sample_count=2, freq_lower_edge=1, freq_upper_edge=2)
annotations = [annotation1, annotation2]
recording1 = Recording(data=complex_data_1, metadata=sample_metadata, annotations=annotations)
# Save to tmp_path
filename = tmp_path / "test.npy"
recording1.to_npy(path=tmp_path, filename=filename.name)
# Load from tmp_path
recording2 = load_rec(filename)
assert recording1.annotations == recording2.annotations
# Check that original metadata was preserved
assert all(
key in recording2.metadata and recording2.metadata[key] == value for key, value in recording1.metadata.items()
)
assert np.array_equal(recording1.data, recording2.data)
def test_sigmf_1(tmp_path):
# Create annotations
annotation1 = Annotation(sample_start=0, sample_count=1, freq_lower_edge=0, freq_upper_edge=1)
annotation2 = Annotation(sample_start=1, sample_count=2, freq_lower_edge=1, freq_upper_edge=2)
annotations = [annotation1, annotation2]
# Create test recording with annotations
recording1 = Recording(data=complex_data_1, metadata=sample_metadata, annotations=annotations)
# Save to tmp_path in SigMF format
filename = tmp_path / "test"
to_sigmf(recording=recording1, path=tmp_path, filename=filename.name)
# Reload
recording2 = from_sigmf(filename)
# Verify annotations
assert recording1.annotations == recording2.annotations
# Verify metadata (original keys preserved)
assert all(
key in recording2.metadata and recording2.metadata[key] == value for key, value in recording1.metadata.items()
)
# Verify data
assert np.array_equal(recording1.data, recording2.data)
def test_sigmf_2(tmp_path):
# checks that recording can be saved to sigmf and then retrieved without data loss
annotation1 = Annotation(sample_start=0, sample_count=1, freq_lower_edge=0, freq_upper_edge=1)
annotation2 = Annotation(sample_start=1, sample_count=2, freq_lower_edge=1, freq_upper_edge=2)
annotations = [annotation1, annotation2]
recording1 = Recording(data=complex_data_1, metadata=sample_metadata, annotations=annotations)
# Save to tmp_path using the base name
filename = tmp_path / "test"
to_sigmf(recording=recording1, path=tmp_path, filename=filename.name)
# Load from tmp_path; from_sigmf expects the base name
recording2 = from_sigmf(filename)
assert recording1.annotations == recording2.annotations
# checks that the original metadata was preserved (although some sigmf specific metadata may have been added)
assert all(
key in recording2.metadata and recording2.metadata[key] == value for key, value in recording1.metadata.items()
)
assert np.array_equal(recording1.data, recording2.data)

View File

@ -0,0 +1,52 @@
import pytest
from ria_toolkit_oss.utils.abstract_attribute import ABCMeta2, abstract_attribute
class InterfaceWithAbstractClassAttributes(metaclass=ABCMeta2):
_url = abstract_attribute()
_name = abstract_attribute()
def __init__(self):
pass
@property
def name(self):
return self._name
class ClassWithNeitherAbstractAttributeImplemented(InterfaceWithAbstractClassAttributes):
def __init__(self):
super().__init__()
class ClassWithOnlyOneAbstractAttributeImplemented(InterfaceWithAbstractClassAttributes):
_url = "https://www.google.com/"
def __init__(self):
super().__init__()
class ClassWithAllAbstractAttributesImplemented(InterfaceWithAbstractClassAttributes):
_url = "https://www.google.com/"
_name = "Michael Luciuk"
def __init__(self):
super().__init__()
def test_with_neither_attribute_implemented():
with pytest.raises(NotImplementedError):
ClassWithNeitherAbstractAttributeImplemented()
def test_with_one_attribute_missing():
with pytest.raises(NotImplementedError):
ClassWithOnlyOneAbstractAttributeImplemented()
def test_with_both_attributes_implemented():
my_class = ClassWithAllAbstractAttributesImplemented()
assert my_class.name == "Michael Luciuk"