From 5f0ab7ac71cdb542e665e86c3da8841b77e3d25a Mon Sep 17 00:00:00 2001 From: madrigal Date: Thu, 11 Dec 2025 16:53:26 -0500 Subject: [PATCH] Fixed merging errors and import errors --- .../ria_toolkit_oss/combine.py | 8 +- .../ria_toolkit_oss/commands.py | 2 +- .../ria_toolkit_oss/convert.py | 14 ++-- .../ria_toolkit_oss/generate.py | 80 ++++++++++--------- .../ria_toolkit_oss/split.py | 6 +- .../ria_toolkit_oss/transform.py | 8 +- .../block_generator/pulse_shaping/__init__.py | 4 +- .../block_generator/source/binary_source.py | 7 +- tests/ria_toolkit_oss_cli/test.combine.py | 2 +- tests/ria_toolkit_oss_cli/test_capture.py | 10 ++- tests/ria_toolkit_oss_cli/test_transmit.py | 54 ++++++++++--- .../test_transmit_generate.py | 2 +- 12 files changed, 120 insertions(+), 77 deletions(-) diff --git a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/combine.py b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/combine.py index 3b8e915..8fb917b 100644 --- a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/combine.py +++ b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/combine.py @@ -6,16 +6,16 @@ from pathlib import Path import click import numpy as np -from ria_toolkit_oss_cli.ria_toolkit_oss.common import ( + +from ria_toolkit_oss.datatypes import Recording +from ria_toolkit_oss.io import from_npy_legacy, load_recording +from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import ( echo_progress, echo_verbose, format_sample_count, save_recording, ) -from ria_toolkit_oss.datatypes import Recording -from ria_toolkit_oss.io import from_npy_legacy, load_recording - def load_recording_list(inputs, legacy, verbose, quiet): recordings = [] diff --git a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/commands.py b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/commands.py index 3f76388..60ddba9 100644 --- a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/commands.py +++ b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/commands.py @@ -6,10 +6,10 @@ This module contains all the CLI bindings for the ria package. from .capture import capture from .combine import combine from .convert import convert -from .generate import generate # Import all command functions from .discover import discover +from .generate import generate # from .generate import generate from .init import init diff --git a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/convert.py b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/convert.py index 3350117..f245a69 100644 --- a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/convert.py +++ b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/convert.py @@ -4,13 +4,6 @@ import os from pathlib import Path import click -from ria_toolkit_oss_cli.ria_toolkit_oss.common import ( - check_for_overwriting, - detect_file_format, - echo_progress, - echo_verbose, - format_sample_count, -) from ria_toolkit_oss.io.recording import ( from_npy, @@ -20,6 +13,13 @@ from ria_toolkit_oss.io.recording import ( to_sigmf, to_wav, ) +from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import ( + check_for_overwriting, + detect_file_format, + echo_progress, + echo_verbose, + format_sample_count, +) from .config import load_user_config diff --git a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py index 365135a..b4d183b 100644 --- a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py +++ b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/generate.py @@ -5,42 +5,11 @@ from typing import Optional import click import numpy as np -import ria_toolkit_oss.signal.basic_signal_generator as basic_gen import yaml + +import ria_toolkit_oss.signal.basic_signal_generator as basic_gen from ria_toolkit_oss.datatypes import Recording -from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import FSKModulator -from ria_toolkit_oss.signal.block_generator.basic import FrequencyShift -from ria_toolkit_oss.signal.block_generator.data_types import DataType -from ria_toolkit_oss.signal.block_generator.mapping.apsk_mapper import _APSKMapper -from ria_toolkit_oss.signal.block_generator.mapping.cross_qam_mapper import _CrossQAMMapper -from ria_toolkit_oss.signal.block_generator.mapping.mapper import Mapper -from ria_toolkit_oss.signal.block_generator.symbol_modulation import ( - GMSKModulator, - OOKModulator, - OQPSKModulator, -) -from ria_toolkit_oss.signal.block_generator.pulse_shaping import ( - RaisedCosineFilter, - RootRaisedCosineFilter, - Upsampling, -) -from ria_toolkit_oss.signal.block_generator.source import ( - LFMChirpSource, - BinarySource, - RecordingSource, - SawtoothSource, - SquareSource, -) - -# Block Generator Imports -from ria_toolkit_oss.signal.block_generator.source_block import SourceBlock - -from ria_toolkit_oss.transforms.iq_impairments import ( - iq_imbalance, -) - - -from ria_toolkit_oss_cli.ria_toolkit_oss.common import ( +from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import ( echo_progress, echo_verbose, format_frequency, @@ -48,7 +17,40 @@ from ria_toolkit_oss_cli.ria_toolkit_oss.common import ( parse_metadata_args, save_recording, ) -from ria_toolkit_oss_cli.ria_toolkit_oss.config import load_user_config +from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.config import load_user_config +from ria_toolkit_oss.signal.block_generator.basic import FrequencyShift +from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import ( + FSKModulator, +) +from ria_toolkit_oss.signal.block_generator.data_types import DataType +from ria_toolkit_oss.signal.block_generator.mapping.apsk_mapper import _APSKMapper +from ria_toolkit_oss.signal.block_generator.mapping.cross_qam_mapper import ( + _CrossQAMMapper, +) +from ria_toolkit_oss.signal.block_generator.mapping.mapper import Mapper +from ria_toolkit_oss.signal.block_generator.pulse_shaping import ( + RaisedCosineFilter, + RootRaisedCosineFilter, + Upsampling, +) +from ria_toolkit_oss.signal.block_generator.source import ( + BinarySource, + LFMChirpSource, + RecordingSource, + SawtoothSource, + SquareSource, +) + +# Block Generator Imports +from ria_toolkit_oss.signal.block_generator.source_block import SourceBlock +from ria_toolkit_oss.signal.block_generator.symbol_modulation import ( + GMSKModulator, + OOKModulator, + OQPSKModulator, +) +from ria_toolkit_oss.transforms.iq_impairments import ( + iq_imbalance, +) # Extend Mapper to support new types @@ -149,6 +151,10 @@ class FileSourceBlock(SourceBlock): self.bits = bits.astype(np.float32) # SourceBlock expects float32 bits (0.0, 1.0) self.idx = 0 + @property + def input_type(self) -> DataType: + return [DataType.NONE] + @property def output_type(self) -> DataType: return DataType.BITS @@ -662,7 +668,7 @@ def sawtooth( def load_source(message_source, message_content, num_bits=None): if num_bits is not None: if message_source == "random": - return RandomBinarySource()((1, num_bits)) + return BinarySource()((1, num_bits)) elif message_source == "string": if not message_content: raise click.BadParameter("Message content required for string source") @@ -679,7 +685,7 @@ def load_source(message_source, message_content, num_bits=None): return FileSourceBlock(p.read_bytes(), repeat=True)(num_bits).reshape(1, -1) else: if message_source == "random": - return RandomBinarySource() # Infinite source + return BinarySource() # Infinite source elif message_source == "string": if not message_content: diff --git a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/split.py b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/split.py index 6fd4c23..30c0c95 100644 --- a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/split.py +++ b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/split.py @@ -4,7 +4,9 @@ from pathlib import Path import click import numpy as np -from ria_toolkit_oss_cli.ria_toolkit_oss.common import ( + +from ria_toolkit_oss.io import from_npy_legacy, load_recording +from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import ( detect_file_format, echo_progress, echo_verbose, @@ -12,8 +14,6 @@ from ria_toolkit_oss_cli.ria_toolkit_oss.common import ( save_recording, ) -from ria_toolkit_oss.io import from_npy_legacy, load_recording - def get_output_extension(format_name): """Get file extension for format name.""" diff --git a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/transform.py b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/transform.py index 71ea6ab..361d67c 100644 --- a/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/transform.py +++ b/src/ria_toolkit_oss/ria_toolkit_oss_cli/ria_toolkit_oss/transform.py @@ -7,15 +7,15 @@ import os from pathlib import Path import click -from ria_toolkit_oss_cli.ria_toolkit_oss.common import ( + +from ria_toolkit_oss.datatypes.recording import Recording +from ria_toolkit_oss.io.recording import load_recording +from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import ( echo_progress, echo_verbose, format_sample_count, save_recording, ) - -from ria_toolkit_oss.datatypes.recording import Recording -from ria_toolkit_oss.io.recording import load_recording from ria_toolkit_oss.transforms import iq_augmentations, iq_impairments diff --git a/src/ria_toolkit_oss/signal/block_generator/pulse_shaping/__init__.py b/src/ria_toolkit_oss/signal/block_generator/pulse_shaping/__init__.py index fc80459..3cc9231 100644 --- a/src/ria_toolkit_oss/signal/block_generator/pulse_shaping/__init__.py +++ b/src/ria_toolkit_oss/signal/block_generator/pulse_shaping/__init__.py @@ -17,10 +17,10 @@ upsampling factor. Example Usage: - >>> from ria_toolkit_oss.signal.block_generator import RandomBinarySource, Mapper, Upsampling, RaisedCosineFilter + >>> from ria_toolkit_oss.signal.block_generator import BinarySource, Mapper, Upsampling, RaisedCosineFilter >>> # create digital modulaiton symbols - >>> source = RandomBinarySource() + >>> source = BinarySource() >>> mapper = Mapper(constellation_type='psk', num_bits_per_symbol=2) >>> mapper.connect_input([source]) diff --git a/src/ria_toolkit_oss/signal/block_generator/source/binary_source.py b/src/ria_toolkit_oss/signal/block_generator/source/binary_source.py index f743f93..47f229f 100644 --- a/src/ria_toolkit_oss/signal/block_generator/source/binary_source.py +++ b/src/ria_toolkit_oss/signal/block_generator/source/binary_source.py @@ -35,7 +35,7 @@ class BinarySource(SourceBlock): def __call__( self, num_samples: int = 1, - num_bits: int = 1024, + num_bits: Optional[int] = None, file_path: Optional[Union[str, Path]] = None, *, cycle: bool = True, @@ -56,7 +56,10 @@ class BinarySource(SourceBlock): """ if file_path is None: # Random mode: 0 with prob p, 1 with prob (1-p) - return (self.rng.random((num_samples, num_bits)) > self.p).astype(np.float32) + if num_bits: + return (self.rng.random((num_samples, num_bits)) > self.p).astype(np.float32) + else: + return (self.rng.random((num_samples)) > self.p).astype(np.float32) # File mode: read raw bytes and unpack to bits path = Path(file_path) diff --git a/tests/ria_toolkit_oss_cli/test.combine.py b/tests/ria_toolkit_oss_cli/test.combine.py index b6f7d8b..423400f 100644 --- a/tests/ria_toolkit_oss_cli/test.combine.py +++ b/tests/ria_toolkit_oss_cli/test.combine.py @@ -6,10 +6,10 @@ from pathlib import Path import numpy as np import pytest from click.testing import CliRunner +from ria_toolkit_oss_cli.cli import cli from ria_toolkit_oss.datatypes import Annotation, Recording from ria_toolkit_oss.io import load_recording, to_npy, to_sigmf -from ria_toolkit_oss_cli.cli import cli class TestCombineHelp: diff --git a/tests/ria_toolkit_oss_cli/test_capture.py b/tests/ria_toolkit_oss_cli/test_capture.py index 81749d6..7e83c1b 100644 --- a/tests/ria_toolkit_oss_cli/test_capture.py +++ b/tests/ria_toolkit_oss_cli/test_capture.py @@ -140,7 +140,10 @@ class TestSaveVisualization: mock_recording = MagicMock() with ( - patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig", side_effect=ImportError("Module not found")), + patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig", + side_effect=ImportError("Module not found"), + ), patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo") as mock_echo, ): @@ -155,7 +158,10 @@ class TestSaveVisualization: mock_recording = MagicMock() with ( - patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig", side_effect=Exception("Failed to plot")), + patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig", + side_effect=Exception("Failed to plot"), + ), patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo") as mock_echo, ): diff --git a/tests/ria_toolkit_oss_cli/test_transmit.py b/tests/ria_toolkit_oss_cli/test_transmit.py index d2eaa71..39325d9 100644 --- a/tests/ria_toolkit_oss_cli/test_transmit.py +++ b/tests/ria_toolkit_oss_cli/test_transmit.py @@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch import numpy as np import pytest -import yaml from click.testing import CliRunner from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device @@ -64,7 +63,9 @@ class TestAutoSelectTxDevice: patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_uhd_devices", return_value=[]), patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_pluto_devices", return_value=[]), patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[]), - patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]), + patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[] + ), ): with pytest.raises(ClickException) as exc_info: @@ -82,7 +83,9 @@ class TestAutoSelectTxDevice: "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[{"type": "HackRF One", "serial": "123456"}], ), - patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]), + patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[] + ), ): device_type = auto_select_tx_device(quiet=True) @@ -103,7 +106,9 @@ class TestAutoSelectTxDevice: "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[{"type": "HackRF One", "serial": "123456"}], ), - patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]), + patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[] + ), ): with pytest.raises(ClickException) as exc_info: @@ -124,10 +129,19 @@ class TestAutoSelectTxDevice: for device_name, expected_type in test_cases: with ( patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_sdr_drivers"), - patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_uhd_devices", return_value=[]), - patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_pluto_devices", return_value=[]), - patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[]), - patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[{"type": device_name}]), + patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_uhd_devices", return_value=[] + ), + patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_pluto_devices", return_value=[] + ), + patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[] + ), + patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", + return_value=[{"type": device_name}], + ), ): device_type = auto_select_tx_device(quiet=True) @@ -154,7 +168,10 @@ class TestLoadInputFile: try: mock_recording = MagicMock() - with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording", return_value=mock_recording): + with patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording", + return_value=mock_recording, + ): recording = load_input_file(test_file, legacy=False) assert recording == mock_recording @@ -169,7 +186,10 @@ class TestLoadInputFile: try: mock_recording = MagicMock() - with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.from_npy_legacy", return_value=mock_recording): + with patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.from_npy_legacy", + return_value=mock_recording, + ): recording = load_input_file(test_file, legacy=True) assert recording == mock_recording @@ -184,7 +204,10 @@ class TestLoadInputFile: test_file = f.name try: - with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording", side_effect=Exception("Unsupported format")): + with patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording", + side_effect=Exception("Unsupported format"), + ): with pytest.raises(ClickException) as exc_info: load_input_file(test_file) @@ -319,8 +342,13 @@ class TestTransmitCommand: mock_recording.metadata = {} with ( - patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.get_sdr_device", return_value=mock_sdr), - patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_input_file", return_value=mock_recording), + patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.get_sdr_device", return_value=mock_sdr + ), + patch( + "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_input_file", + return_value=mock_recording, + ), ): result = self.runner.invoke( diff --git a/tests/ria_toolkit_oss_cli/test_transmit_generate.py b/tests/ria_toolkit_oss_cli/test_transmit_generate.py index 493d37e..f4a3e0e 100644 --- a/tests/ria_toolkit_oss_cli/test_transmit_generate.py +++ b/tests/ria_toolkit_oss_cli/test_transmit_generate.py @@ -2,7 +2,7 @@ from click.testing import CliRunner -from ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli +from src.ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli class TestTransmitGenerate: