Fixed merging errors and import errors

This commit is contained in:
M madrigal 2025-12-11 16:53:26 -05:00
parent 806fcf8293
commit 5f0ab7ac71
12 changed files with 120 additions and 77 deletions

View File

@ -6,16 +6,16 @@ from pathlib import Path
import click import click
import numpy as np import numpy as np
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
from ria_toolkit_oss.datatypes import Recording
from ria_toolkit_oss.io import from_npy_legacy, load_recording
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import (
echo_progress, echo_progress,
echo_verbose, echo_verbose,
format_sample_count, format_sample_count,
save_recording, save_recording,
) )
from ria_toolkit_oss.datatypes import Recording
from ria_toolkit_oss.io import from_npy_legacy, load_recording
def load_recording_list(inputs, legacy, verbose, quiet): def load_recording_list(inputs, legacy, verbose, quiet):
recordings = [] recordings = []

View File

@ -6,10 +6,10 @@ This module contains all the CLI bindings for the ria package.
from .capture import capture from .capture import capture
from .combine import combine from .combine import combine
from .convert import convert from .convert import convert
from .generate import generate
# Import all command functions # Import all command functions
from .discover import discover from .discover import discover
from .generate import generate
# from .generate import generate # from .generate import generate
from .init import init from .init import init

View File

@ -4,13 +4,6 @@ import os
from pathlib import Path from pathlib import Path
import click import click
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
check_for_overwriting,
detect_file_format,
echo_progress,
echo_verbose,
format_sample_count,
)
from ria_toolkit_oss.io.recording import ( from ria_toolkit_oss.io.recording import (
from_npy, from_npy,
@ -20,6 +13,13 @@ from ria_toolkit_oss.io.recording import (
to_sigmf, to_sigmf,
to_wav, to_wav,
) )
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import (
check_for_overwriting,
detect_file_format,
echo_progress,
echo_verbose,
format_sample_count,
)
from .config import load_user_config from .config import load_user_config

View File

@ -5,42 +5,11 @@ from typing import Optional
import click import click
import numpy as np import numpy as np
import ria_toolkit_oss.signal.basic_signal_generator as basic_gen
import yaml import yaml
import ria_toolkit_oss.signal.basic_signal_generator as basic_gen
from ria_toolkit_oss.datatypes import Recording from ria_toolkit_oss.datatypes import Recording
from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import FSKModulator from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import (
from ria_toolkit_oss.signal.block_generator.basic import FrequencyShift
from ria_toolkit_oss.signal.block_generator.data_types import DataType
from ria_toolkit_oss.signal.block_generator.mapping.apsk_mapper import _APSKMapper
from ria_toolkit_oss.signal.block_generator.mapping.cross_qam_mapper import _CrossQAMMapper
from ria_toolkit_oss.signal.block_generator.mapping.mapper import Mapper
from ria_toolkit_oss.signal.block_generator.symbol_modulation import (
GMSKModulator,
OOKModulator,
OQPSKModulator,
)
from ria_toolkit_oss.signal.block_generator.pulse_shaping import (
RaisedCosineFilter,
RootRaisedCosineFilter,
Upsampling,
)
from ria_toolkit_oss.signal.block_generator.source import (
LFMChirpSource,
BinarySource,
RecordingSource,
SawtoothSource,
SquareSource,
)
# Block Generator Imports
from ria_toolkit_oss.signal.block_generator.source_block import SourceBlock
from ria_toolkit_oss.transforms.iq_impairments import (
iq_imbalance,
)
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
echo_progress, echo_progress,
echo_verbose, echo_verbose,
format_frequency, format_frequency,
@ -48,7 +17,40 @@ from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
parse_metadata_args, parse_metadata_args,
save_recording, save_recording,
) )
from ria_toolkit_oss_cli.ria_toolkit_oss.config import load_user_config from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.config import load_user_config
from ria_toolkit_oss.signal.block_generator.basic import FrequencyShift
from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import (
FSKModulator,
)
from ria_toolkit_oss.signal.block_generator.data_types import DataType
from ria_toolkit_oss.signal.block_generator.mapping.apsk_mapper import _APSKMapper
from ria_toolkit_oss.signal.block_generator.mapping.cross_qam_mapper import (
_CrossQAMMapper,
)
from ria_toolkit_oss.signal.block_generator.mapping.mapper import Mapper
from ria_toolkit_oss.signal.block_generator.pulse_shaping import (
RaisedCosineFilter,
RootRaisedCosineFilter,
Upsampling,
)
from ria_toolkit_oss.signal.block_generator.source import (
BinarySource,
LFMChirpSource,
RecordingSource,
SawtoothSource,
SquareSource,
)
# Block Generator Imports
from ria_toolkit_oss.signal.block_generator.source_block import SourceBlock
from ria_toolkit_oss.signal.block_generator.symbol_modulation import (
GMSKModulator,
OOKModulator,
OQPSKModulator,
)
from ria_toolkit_oss.transforms.iq_impairments import (
iq_imbalance,
)
# Extend Mapper to support new types # Extend Mapper to support new types
@ -149,6 +151,10 @@ class FileSourceBlock(SourceBlock):
self.bits = bits.astype(np.float32) # SourceBlock expects float32 bits (0.0, 1.0) self.bits = bits.astype(np.float32) # SourceBlock expects float32 bits (0.0, 1.0)
self.idx = 0 self.idx = 0
@property
def input_type(self) -> DataType:
return [DataType.NONE]
@property @property
def output_type(self) -> DataType: def output_type(self) -> DataType:
return DataType.BITS return DataType.BITS
@ -662,7 +668,7 @@ def sawtooth(
def load_source(message_source, message_content, num_bits=None): def load_source(message_source, message_content, num_bits=None):
if num_bits is not None: if num_bits is not None:
if message_source == "random": if message_source == "random":
return RandomBinarySource()((1, num_bits)) return BinarySource()((1, num_bits))
elif message_source == "string": elif message_source == "string":
if not message_content: if not message_content:
raise click.BadParameter("Message content required for string source") raise click.BadParameter("Message content required for string source")
@ -679,7 +685,7 @@ def load_source(message_source, message_content, num_bits=None):
return FileSourceBlock(p.read_bytes(), repeat=True)(num_bits).reshape(1, -1) return FileSourceBlock(p.read_bytes(), repeat=True)(num_bits).reshape(1, -1)
else: else:
if message_source == "random": if message_source == "random":
return RandomBinarySource() # Infinite source return BinarySource() # Infinite source
elif message_source == "string": elif message_source == "string":
if not message_content: if not message_content:

View File

@ -4,7 +4,9 @@ from pathlib import Path
import click import click
import numpy as np import numpy as np
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
from ria_toolkit_oss.io import from_npy_legacy, load_recording
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import (
detect_file_format, detect_file_format,
echo_progress, echo_progress,
echo_verbose, echo_verbose,
@ -12,8 +14,6 @@ from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
save_recording, save_recording,
) )
from ria_toolkit_oss.io import from_npy_legacy, load_recording
def get_output_extension(format_name): def get_output_extension(format_name):
"""Get file extension for format name.""" """Get file extension for format name."""

View File

@ -7,15 +7,15 @@ import os
from pathlib import Path from pathlib import Path
import click import click
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.io.recording import load_recording
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import (
echo_progress, echo_progress,
echo_verbose, echo_verbose,
format_sample_count, format_sample_count,
save_recording, save_recording,
) )
from ria_toolkit_oss.datatypes.recording import Recording
from ria_toolkit_oss.io.recording import load_recording
from ria_toolkit_oss.transforms import iq_augmentations, iq_impairments from ria_toolkit_oss.transforms import iq_augmentations, iq_impairments

View File

@ -17,10 +17,10 @@ upsampling factor.
Example Usage: Example Usage:
>>> from ria_toolkit_oss.signal.block_generator import RandomBinarySource, Mapper, Upsampling, RaisedCosineFilter >>> from ria_toolkit_oss.signal.block_generator import BinarySource, Mapper, Upsampling, RaisedCosineFilter
>>> # create digital modulaiton symbols >>> # create digital modulaiton symbols
>>> source = RandomBinarySource() >>> source = BinarySource()
>>> mapper = Mapper(constellation_type='psk', num_bits_per_symbol=2) >>> mapper = Mapper(constellation_type='psk', num_bits_per_symbol=2)
>>> mapper.connect_input([source]) >>> mapper.connect_input([source])

View File

@ -35,7 +35,7 @@ class BinarySource(SourceBlock):
def __call__( def __call__(
self, self,
num_samples: int = 1, num_samples: int = 1,
num_bits: int = 1024, num_bits: Optional[int] = None,
file_path: Optional[Union[str, Path]] = None, file_path: Optional[Union[str, Path]] = None,
*, *,
cycle: bool = True, cycle: bool = True,
@ -56,7 +56,10 @@ class BinarySource(SourceBlock):
""" """
if file_path is None: if file_path is None:
# Random mode: 0 with prob p, 1 with prob (1-p) # Random mode: 0 with prob p, 1 with prob (1-p)
return (self.rng.random((num_samples, num_bits)) > self.p).astype(np.float32) if num_bits:
return (self.rng.random((num_samples, num_bits)) > self.p).astype(np.float32)
else:
return (self.rng.random((num_samples)) > self.p).astype(np.float32)
# File mode: read raw bytes and unpack to bits # File mode: read raw bytes and unpack to bits
path = Path(file_path) path = Path(file_path)

View File

@ -6,10 +6,10 @@ from pathlib import Path
import numpy as np import numpy as np
import pytest import pytest
from click.testing import CliRunner from click.testing import CliRunner
from ria_toolkit_oss_cli.cli import cli
from ria_toolkit_oss.datatypes import Annotation, Recording from ria_toolkit_oss.datatypes import Annotation, Recording
from ria_toolkit_oss.io import load_recording, to_npy, to_sigmf from ria_toolkit_oss.io import load_recording, to_npy, to_sigmf
from ria_toolkit_oss_cli.cli import cli
class TestCombineHelp: class TestCombineHelp:

View File

@ -140,7 +140,10 @@ class TestSaveVisualization:
mock_recording = MagicMock() mock_recording = MagicMock()
with ( with (
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig", side_effect=ImportError("Module not found")), patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig",
side_effect=ImportError("Module not found"),
),
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo") as mock_echo, patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo") as mock_echo,
): ):
@ -155,7 +158,10 @@ class TestSaveVisualization:
mock_recording = MagicMock() mock_recording = MagicMock()
with ( with (
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig", side_effect=Exception("Failed to plot")), patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig",
side_effect=Exception("Failed to plot"),
),
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo") as mock_echo, patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo") as mock_echo,
): ):

View File

@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
import yaml
from click.testing import CliRunner from click.testing import CliRunner
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
@ -64,7 +63,9 @@ class TestAutoSelectTxDevice:
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_uhd_devices", return_value=[]), patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_uhd_devices", return_value=[]),
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_pluto_devices", return_value=[]), patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_pluto_devices", return_value=[]),
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[]), patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[]),
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]), patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]
),
): ):
with pytest.raises(ClickException) as exc_info: with pytest.raises(ClickException) as exc_info:
@ -82,7 +83,9 @@ class TestAutoSelectTxDevice:
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices",
return_value=[{"type": "HackRF One", "serial": "123456"}], return_value=[{"type": "HackRF One", "serial": "123456"}],
), ),
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]), patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]
),
): ):
device_type = auto_select_tx_device(quiet=True) device_type = auto_select_tx_device(quiet=True)
@ -103,7 +106,9 @@ class TestAutoSelectTxDevice:
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices",
return_value=[{"type": "HackRF One", "serial": "123456"}], return_value=[{"type": "HackRF One", "serial": "123456"}],
), ),
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]), patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]
),
): ):
with pytest.raises(ClickException) as exc_info: with pytest.raises(ClickException) as exc_info:
@ -124,10 +129,19 @@ class TestAutoSelectTxDevice:
for device_name, expected_type in test_cases: for device_name, expected_type in test_cases:
with ( with (
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_sdr_drivers"), patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_sdr_drivers"),
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_uhd_devices", return_value=[]), patch(
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_pluto_devices", return_value=[]), "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_uhd_devices", return_value=[]
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[]), ),
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[{"type": device_name}]), patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_pluto_devices", return_value=[]
),
patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[]
),
patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices",
return_value=[{"type": device_name}],
),
): ):
device_type = auto_select_tx_device(quiet=True) device_type = auto_select_tx_device(quiet=True)
@ -154,7 +168,10 @@ class TestLoadInputFile:
try: try:
mock_recording = MagicMock() mock_recording = MagicMock()
with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording", return_value=mock_recording): with patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording",
return_value=mock_recording,
):
recording = load_input_file(test_file, legacy=False) recording = load_input_file(test_file, legacy=False)
assert recording == mock_recording assert recording == mock_recording
@ -169,7 +186,10 @@ class TestLoadInputFile:
try: try:
mock_recording = MagicMock() mock_recording = MagicMock()
with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.from_npy_legacy", return_value=mock_recording): with patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.from_npy_legacy",
return_value=mock_recording,
):
recording = load_input_file(test_file, legacy=True) recording = load_input_file(test_file, legacy=True)
assert recording == mock_recording assert recording == mock_recording
@ -184,7 +204,10 @@ class TestLoadInputFile:
test_file = f.name test_file = f.name
try: try:
with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording", side_effect=Exception("Unsupported format")): with patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording",
side_effect=Exception("Unsupported format"),
):
with pytest.raises(ClickException) as exc_info: with pytest.raises(ClickException) as exc_info:
load_input_file(test_file) load_input_file(test_file)
@ -319,8 +342,13 @@ class TestTransmitCommand:
mock_recording.metadata = {} mock_recording.metadata = {}
with ( with (
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.get_sdr_device", return_value=mock_sdr), patch(
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_input_file", return_value=mock_recording), "ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.get_sdr_device", return_value=mock_sdr
),
patch(
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_input_file",
return_value=mock_recording,
),
): ):
result = self.runner.invoke( result = self.runner.invoke(

View File

@ -2,7 +2,7 @@
from click.testing import CliRunner from click.testing import CliRunner
from ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli from src.ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli
class TestTransmitGenerate: class TestTransmitGenerate: