Fixed merging errors and import errors
This commit is contained in:
parent
806fcf8293
commit
5f0ab7ac71
|
|
@ -6,16 +6,16 @@ from pathlib import Path
|
|||
|
||||
import click
|
||||
import numpy as np
|
||||
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
||||
|
||||
from ria_toolkit_oss.datatypes import Recording
|
||||
from ria_toolkit_oss.io import from_npy_legacy, load_recording
|
||||
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
||||
echo_progress,
|
||||
echo_verbose,
|
||||
format_sample_count,
|
||||
save_recording,
|
||||
)
|
||||
|
||||
from ria_toolkit_oss.datatypes import Recording
|
||||
from ria_toolkit_oss.io import from_npy_legacy, load_recording
|
||||
|
||||
|
||||
def load_recording_list(inputs, legacy, verbose, quiet):
|
||||
recordings = []
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ This module contains all the CLI bindings for the ria package.
|
|||
from .capture import capture
|
||||
from .combine import combine
|
||||
from .convert import convert
|
||||
from .generate import generate
|
||||
|
||||
# Import all command functions
|
||||
from .discover import discover
|
||||
from .generate import generate
|
||||
|
||||
# from .generate import generate
|
||||
from .init import init
|
||||
|
|
|
|||
|
|
@ -4,13 +4,6 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
import click
|
||||
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
||||
check_for_overwriting,
|
||||
detect_file_format,
|
||||
echo_progress,
|
||||
echo_verbose,
|
||||
format_sample_count,
|
||||
)
|
||||
|
||||
from ria_toolkit_oss.io.recording import (
|
||||
from_npy,
|
||||
|
|
@ -20,6 +13,13 @@ from ria_toolkit_oss.io.recording import (
|
|||
to_sigmf,
|
||||
to_wav,
|
||||
)
|
||||
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
||||
check_for_overwriting,
|
||||
detect_file_format,
|
||||
echo_progress,
|
||||
echo_verbose,
|
||||
format_sample_count,
|
||||
)
|
||||
|
||||
from .config import load_user_config
|
||||
|
||||
|
|
|
|||
|
|
@ -5,42 +5,11 @@ from typing import Optional
|
|||
|
||||
import click
|
||||
import numpy as np
|
||||
import ria_toolkit_oss.signal.basic_signal_generator as basic_gen
|
||||
import yaml
|
||||
|
||||
import ria_toolkit_oss.signal.basic_signal_generator as basic_gen
|
||||
from ria_toolkit_oss.datatypes import Recording
|
||||
from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import FSKModulator
|
||||
from ria_toolkit_oss.signal.block_generator.basic import FrequencyShift
|
||||
from ria_toolkit_oss.signal.block_generator.data_types import DataType
|
||||
from ria_toolkit_oss.signal.block_generator.mapping.apsk_mapper import _APSKMapper
|
||||
from ria_toolkit_oss.signal.block_generator.mapping.cross_qam_mapper import _CrossQAMMapper
|
||||
from ria_toolkit_oss.signal.block_generator.mapping.mapper import Mapper
|
||||
from ria_toolkit_oss.signal.block_generator.symbol_modulation import (
|
||||
GMSKModulator,
|
||||
OOKModulator,
|
||||
OQPSKModulator,
|
||||
)
|
||||
from ria_toolkit_oss.signal.block_generator.pulse_shaping import (
|
||||
RaisedCosineFilter,
|
||||
RootRaisedCosineFilter,
|
||||
Upsampling,
|
||||
)
|
||||
from ria_toolkit_oss.signal.block_generator.source import (
|
||||
LFMChirpSource,
|
||||
BinarySource,
|
||||
RecordingSource,
|
||||
SawtoothSource,
|
||||
SquareSource,
|
||||
)
|
||||
|
||||
# Block Generator Imports
|
||||
from ria_toolkit_oss.signal.block_generator.source_block import SourceBlock
|
||||
|
||||
from ria_toolkit_oss.transforms.iq_impairments import (
|
||||
iq_imbalance,
|
||||
)
|
||||
|
||||
|
||||
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
||||
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
||||
echo_progress,
|
||||
echo_verbose,
|
||||
format_frequency,
|
||||
|
|
@ -48,7 +17,40 @@ from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
|||
parse_metadata_args,
|
||||
save_recording,
|
||||
)
|
||||
from ria_toolkit_oss_cli.ria_toolkit_oss.config import load_user_config
|
||||
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.config import load_user_config
|
||||
from ria_toolkit_oss.signal.block_generator.basic import FrequencyShift
|
||||
from ria_toolkit_oss.signal.block_generator.continuous_modulation.fsk_modulator import (
|
||||
FSKModulator,
|
||||
)
|
||||
from ria_toolkit_oss.signal.block_generator.data_types import DataType
|
||||
from ria_toolkit_oss.signal.block_generator.mapping.apsk_mapper import _APSKMapper
|
||||
from ria_toolkit_oss.signal.block_generator.mapping.cross_qam_mapper import (
|
||||
_CrossQAMMapper,
|
||||
)
|
||||
from ria_toolkit_oss.signal.block_generator.mapping.mapper import Mapper
|
||||
from ria_toolkit_oss.signal.block_generator.pulse_shaping import (
|
||||
RaisedCosineFilter,
|
||||
RootRaisedCosineFilter,
|
||||
Upsampling,
|
||||
)
|
||||
from ria_toolkit_oss.signal.block_generator.source import (
|
||||
BinarySource,
|
||||
LFMChirpSource,
|
||||
RecordingSource,
|
||||
SawtoothSource,
|
||||
SquareSource,
|
||||
)
|
||||
|
||||
# Block Generator Imports
|
||||
from ria_toolkit_oss.signal.block_generator.source_block import SourceBlock
|
||||
from ria_toolkit_oss.signal.block_generator.symbol_modulation import (
|
||||
GMSKModulator,
|
||||
OOKModulator,
|
||||
OQPSKModulator,
|
||||
)
|
||||
from ria_toolkit_oss.transforms.iq_impairments import (
|
||||
iq_imbalance,
|
||||
)
|
||||
|
||||
|
||||
# Extend Mapper to support new types
|
||||
|
|
@ -149,6 +151,10 @@ class FileSourceBlock(SourceBlock):
|
|||
self.bits = bits.astype(np.float32) # SourceBlock expects float32 bits (0.0, 1.0)
|
||||
self.idx = 0
|
||||
|
||||
@property
|
||||
def input_type(self) -> DataType:
|
||||
return [DataType.NONE]
|
||||
|
||||
@property
|
||||
def output_type(self) -> DataType:
|
||||
return DataType.BITS
|
||||
|
|
@ -662,7 +668,7 @@ def sawtooth(
|
|||
def load_source(message_source, message_content, num_bits=None):
|
||||
if num_bits is not None:
|
||||
if message_source == "random":
|
||||
return RandomBinarySource()((1, num_bits))
|
||||
return BinarySource()((1, num_bits))
|
||||
elif message_source == "string":
|
||||
if not message_content:
|
||||
raise click.BadParameter("Message content required for string source")
|
||||
|
|
@ -679,7 +685,7 @@ def load_source(message_source, message_content, num_bits=None):
|
|||
return FileSourceBlock(p.read_bytes(), repeat=True)(num_bits).reshape(1, -1)
|
||||
else:
|
||||
if message_source == "random":
|
||||
return RandomBinarySource() # Infinite source
|
||||
return BinarySource() # Infinite source
|
||||
|
||||
elif message_source == "string":
|
||||
if not message_content:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ from pathlib import Path
|
|||
|
||||
import click
|
||||
import numpy as np
|
||||
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
||||
|
||||
from ria_toolkit_oss.io import from_npy_legacy, load_recording
|
||||
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
||||
detect_file_format,
|
||||
echo_progress,
|
||||
echo_verbose,
|
||||
|
|
@ -12,8 +14,6 @@ from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
|||
save_recording,
|
||||
)
|
||||
|
||||
from ria_toolkit_oss.io import from_npy_legacy, load_recording
|
||||
|
||||
|
||||
def get_output_extension(format_name):
|
||||
"""Get file extension for format name."""
|
||||
|
|
|
|||
|
|
@ -7,15 +7,15 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
import click
|
||||
from ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
from ria_toolkit_oss.io.recording import load_recording
|
||||
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import (
|
||||
echo_progress,
|
||||
echo_verbose,
|
||||
format_sample_count,
|
||||
save_recording,
|
||||
)
|
||||
|
||||
from ria_toolkit_oss.datatypes.recording import Recording
|
||||
from ria_toolkit_oss.io.recording import load_recording
|
||||
from ria_toolkit_oss.transforms import iq_augmentations, iq_impairments
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,10 +17,10 @@ upsampling factor.
|
|||
|
||||
Example Usage:
|
||||
|
||||
>>> from ria_toolkit_oss.signal.block_generator import RandomBinarySource, Mapper, Upsampling, RaisedCosineFilter
|
||||
>>> from ria_toolkit_oss.signal.block_generator import BinarySource, Mapper, Upsampling, RaisedCosineFilter
|
||||
|
||||
>>> # create digital modulaiton symbols
|
||||
>>> source = RandomBinarySource()
|
||||
>>> source = BinarySource()
|
||||
>>> mapper = Mapper(constellation_type='psk', num_bits_per_symbol=2)
|
||||
>>> mapper.connect_input([source])
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class BinarySource(SourceBlock):
|
|||
def __call__(
|
||||
self,
|
||||
num_samples: int = 1,
|
||||
num_bits: int = 1024,
|
||||
num_bits: Optional[int] = None,
|
||||
file_path: Optional[Union[str, Path]] = None,
|
||||
*,
|
||||
cycle: bool = True,
|
||||
|
|
@ -56,7 +56,10 @@ class BinarySource(SourceBlock):
|
|||
"""
|
||||
if file_path is None:
|
||||
# Random mode: 0 with prob p, 1 with prob (1-p)
|
||||
if num_bits:
|
||||
return (self.rng.random((num_samples, num_bits)) > self.p).astype(np.float32)
|
||||
else:
|
||||
return (self.rng.random((num_samples)) > self.p).astype(np.float32)
|
||||
|
||||
# File mode: read raw bytes and unpack to bits
|
||||
path = Path(file_path)
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ from pathlib import Path
|
|||
import numpy as np
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from ria_toolkit_oss_cli.cli import cli
|
||||
|
||||
from ria_toolkit_oss.datatypes import Annotation, Recording
|
||||
from ria_toolkit_oss.io import load_recording, to_npy, to_sigmf
|
||||
from ria_toolkit_oss_cli.cli import cli
|
||||
|
||||
|
||||
class TestCombineHelp:
|
||||
|
|
|
|||
|
|
@ -140,7 +140,10 @@ class TestSaveVisualization:
|
|||
mock_recording = MagicMock()
|
||||
|
||||
with (
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig", side_effect=ImportError("Module not found")),
|
||||
patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig",
|
||||
side_effect=ImportError("Module not found"),
|
||||
),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo") as mock_echo,
|
||||
):
|
||||
|
||||
|
|
@ -155,7 +158,10 @@ class TestSaveVisualization:
|
|||
mock_recording = MagicMock()
|
||||
|
||||
with (
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig", side_effect=Exception("Failed to plot")),
|
||||
patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.view_simple_sig",
|
||||
side_effect=Exception("Failed to plot"),
|
||||
),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.capture.click.echo") as mock_echo,
|
||||
):
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import yaml
|
||||
from click.testing import CliRunner
|
||||
|
||||
from ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.common import get_sdr_device
|
||||
|
|
@ -64,7 +63,9 @@ class TestAutoSelectTxDevice:
|
|||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_uhd_devices", return_value=[]),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_pluto_devices", return_value=[]),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[]),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]),
|
||||
patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]
|
||||
),
|
||||
):
|
||||
|
||||
with pytest.raises(ClickException) as exc_info:
|
||||
|
|
@ -82,7 +83,9 @@ class TestAutoSelectTxDevice:
|
|||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices",
|
||||
return_value=[{"type": "HackRF One", "serial": "123456"}],
|
||||
),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]),
|
||||
patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]
|
||||
),
|
||||
):
|
||||
|
||||
device_type = auto_select_tx_device(quiet=True)
|
||||
|
|
@ -103,7 +106,9 @@ class TestAutoSelectTxDevice:
|
|||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices",
|
||||
return_value=[{"type": "HackRF One", "serial": "123456"}],
|
||||
),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]),
|
||||
patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[]
|
||||
),
|
||||
):
|
||||
|
||||
with pytest.raises(ClickException) as exc_info:
|
||||
|
|
@ -124,10 +129,19 @@ class TestAutoSelectTxDevice:
|
|||
for device_name, expected_type in test_cases:
|
||||
with (
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_sdr_drivers"),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_uhd_devices", return_value=[]),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_pluto_devices", return_value=[]),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[]),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices", return_value=[{"type": device_name}]),
|
||||
patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_uhd_devices", return_value=[]
|
||||
),
|
||||
patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_pluto_devices", return_value=[]
|
||||
),
|
||||
patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_hackrf_devices", return_value=[]
|
||||
),
|
||||
patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.find_bladerf_devices",
|
||||
return_value=[{"type": device_name}],
|
||||
),
|
||||
):
|
||||
|
||||
device_type = auto_select_tx_device(quiet=True)
|
||||
|
|
@ -154,7 +168,10 @@ class TestLoadInputFile:
|
|||
try:
|
||||
mock_recording = MagicMock()
|
||||
|
||||
with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording", return_value=mock_recording):
|
||||
with patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording",
|
||||
return_value=mock_recording,
|
||||
):
|
||||
recording = load_input_file(test_file, legacy=False)
|
||||
assert recording == mock_recording
|
||||
|
||||
|
|
@ -169,7 +186,10 @@ class TestLoadInputFile:
|
|||
try:
|
||||
mock_recording = MagicMock()
|
||||
|
||||
with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.from_npy_legacy", return_value=mock_recording):
|
||||
with patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.from_npy_legacy",
|
||||
return_value=mock_recording,
|
||||
):
|
||||
recording = load_input_file(test_file, legacy=True)
|
||||
assert recording == mock_recording
|
||||
|
||||
|
|
@ -184,7 +204,10 @@ class TestLoadInputFile:
|
|||
test_file = f.name
|
||||
|
||||
try:
|
||||
with patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording", side_effect=Exception("Unsupported format")):
|
||||
with patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_recording",
|
||||
side_effect=Exception("Unsupported format"),
|
||||
):
|
||||
with pytest.raises(ClickException) as exc_info:
|
||||
load_input_file(test_file)
|
||||
|
||||
|
|
@ -319,8 +342,13 @@ class TestTransmitCommand:
|
|||
mock_recording.metadata = {}
|
||||
|
||||
with (
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.get_sdr_device", return_value=mock_sdr),
|
||||
patch("ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_input_file", return_value=mock_recording),
|
||||
patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.get_sdr_device", return_value=mock_sdr
|
||||
),
|
||||
patch(
|
||||
"ria_toolkit_oss.ria_toolkit_oss_cli.ria_toolkit_oss.transmit.load_input_file",
|
||||
return_value=mock_recording,
|
||||
),
|
||||
):
|
||||
|
||||
result = self.runner.invoke(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from click.testing import CliRunner
|
||||
|
||||
from ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli
|
||||
from src.ria_toolkit_oss.ria_toolkit_oss_cli.cli import cli
|
||||
|
||||
|
||||
class TestTransmitGenerate:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user