Compare commits
No commits in common. "3e7ccf506d95e5f06cd13015a233a891bec908d3" and "dd305aabeb59fed1cc38798668d545202117d0f6" have entirely different histories.
3e7ccf506d
...
dd305aabeb
1220
poetry.lock
generated
1220
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
|
|
@ -28,8 +28,6 @@ from .hardware import available_devices
|
||||||
from .legacy_executor import main as _legacy_main
|
from .legacy_executor import main as _legacy_main
|
||||||
from .namegen import generate_agent_name
|
from .namegen import generate_agent_name
|
||||||
|
|
||||||
DEFAULT_HUB_URL = "https://riahub.ai"
|
|
||||||
|
|
||||||
_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"}
|
_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -56,9 +54,18 @@ _REGISTER_TIMEOUT_S = 15
|
||||||
|
|
||||||
|
|
||||||
REGISTRATION_REASON_MESSAGES = {
|
REGISTRATION_REASON_MESSAGES = {
|
||||||
"invalid_key": ("Registration key not recognized. Generate a fresh key from " "Settings → RIA Agents on the hub."),
|
"invalid_key": (
|
||||||
"expired": ("This registration key has expired. Generate a new one from " "Settings → RIA Agents on the hub."),
|
"Registration key not recognized. Generate a fresh key from "
|
||||||
"revoked": ("This registration key was revoked. Generate a new one from " "Settings → RIA Agents on the hub."),
|
"Settings → RIA Agents on the hub."
|
||||||
|
),
|
||||||
|
"expired": (
|
||||||
|
"This registration key has expired. Generate a new one from "
|
||||||
|
"Settings → RIA Agents on the hub."
|
||||||
|
),
|
||||||
|
"revoked": (
|
||||||
|
"This registration key was revoked. Generate a new one from "
|
||||||
|
"Settings → RIA Agents on the hub."
|
||||||
|
),
|
||||||
"already_consumed": (
|
"already_consumed": (
|
||||||
"This single-use registration key has already been used. "
|
"This single-use registration key has already been used. "
|
||||||
"Generate a new one, or mint a reusable key instead."
|
"Generate a new one, or mint a reusable key instead."
|
||||||
|
|
@ -219,7 +226,7 @@ def main() -> None:
|
||||||
sub.add_parser("detect", help="List available SDR drivers")
|
sub.add_parser("detect", help="List available SDR drivers")
|
||||||
|
|
||||||
p_reg = sub.add_parser("register", help="Register agent with RIA Hub and save credentials")
|
p_reg = sub.add_parser("register", help="Register agent with RIA Hub and save credentials")
|
||||||
p_reg.add_argument("--hub", default=DEFAULT_HUB_URL, help=f"RIA Hub URL (default: {DEFAULT_HUB_URL})")
|
p_reg.add_argument("--hub", required=True, help="RIA Hub URL (e.g. http://whitehorse:3005)")
|
||||||
p_reg.add_argument(
|
p_reg.add_argument(
|
||||||
"--api-key",
|
"--api-key",
|
||||||
dest="api_key",
|
dest="api_key",
|
||||||
|
|
|
||||||
|
|
@ -45,14 +45,7 @@ def is_annotation_contained(inner: Annotation, outer: Annotation) -> bool:
|
||||||
outer_sample_stop = outer.sample_start + outer.sample_count
|
outer_sample_stop = outer.sample_start + outer.sample_count
|
||||||
|
|
||||||
if inner.sample_start > outer.sample_start and inner_sample_stop < outer_sample_stop:
|
if inner.sample_start > outer.sample_start and inner_sample_stop < outer_sample_stop:
|
||||||
if (
|
if inner.freq_lower_edge > outer.freq_lower_edge and inner.freq_upper_edge < outer.freq_upper_edge:
|
||||||
inner.freq_lower_edge is not None
|
|
||||||
and inner.freq_upper_edge is not None
|
|
||||||
and outer.freq_lower_edge is not None
|
|
||||||
and outer.freq_upper_edge is not None
|
|
||||||
and inner.freq_lower_edge > outer.freq_lower_edge
|
|
||||||
and inner.freq_upper_edge < outer.freq_upper_edge
|
|
||||||
):
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -17,10 +17,10 @@ class Annotation:
|
||||||
:type sample_start: int
|
:type sample_start: int
|
||||||
:param sample_count: The index of the ending sample of the annotation, inclusive.
|
:param sample_count: The index of the ending sample of the annotation, inclusive.
|
||||||
:type sample_count: int
|
:type sample_count: int
|
||||||
:param freq_lower_edge: The lower frequency of the annotation. Optional; None if not specified in source.
|
:param freq_lower_edge: The lower frequency of the annotation.
|
||||||
:type freq_lower_edge: float, optional
|
:type freq_lower_edge: float
|
||||||
:param freq_upper_edge: The upper frequency of the annotation. Optional; None if not specified in source.
|
:param freq_upper_edge: The upper frequency of the annotation.
|
||||||
:type freq_upper_edge: float, optional
|
:type freq_upper_edge: float
|
||||||
:param label: The label that will be displayed with the bounding box in compatible viewers including IQEngine.
|
:param label: The label that will be displayed with the bounding box in compatible viewers including IQEngine.
|
||||||
Defaults to an emtpy string.
|
Defaults to an emtpy string.
|
||||||
:type label: str, optional
|
:type label: str, optional
|
||||||
|
|
@ -34,8 +34,8 @@ class Annotation:
|
||||||
self,
|
self,
|
||||||
sample_start: int,
|
sample_start: int,
|
||||||
sample_count: int,
|
sample_count: int,
|
||||||
freq_lower_edge: Optional[float] = None,
|
freq_lower_edge: float,
|
||||||
freq_upper_edge: Optional[float] = None,
|
freq_upper_edge: float,
|
||||||
label: Optional[str] = "",
|
label: Optional[str] = "",
|
||||||
comment: Optional[str] = "",
|
comment: Optional[str] = "",
|
||||||
detail: Optional[dict] = None,
|
detail: Optional[dict] = None,
|
||||||
|
|
@ -43,8 +43,8 @@ class Annotation:
|
||||||
"""Initialize a new Annotation instance."""
|
"""Initialize a new Annotation instance."""
|
||||||
self.sample_start = int(sample_start)
|
self.sample_start = int(sample_start)
|
||||||
self.sample_count = int(sample_count)
|
self.sample_count = int(sample_count)
|
||||||
self.freq_lower_edge = float(freq_lower_edge) if freq_lower_edge is not None else None
|
self.freq_lower_edge = float(freq_lower_edge)
|
||||||
self.freq_upper_edge = float(freq_upper_edge) if freq_upper_edge is not None else None
|
self.freq_upper_edge = float(freq_upper_edge)
|
||||||
self.label = str(label)
|
self.label = str(label)
|
||||||
self.comment = str(comment)
|
self.comment = str(comment)
|
||||||
|
|
||||||
|
|
@ -62,8 +62,6 @@ class Annotation:
|
||||||
:returns: True if valid, False if not.
|
:returns: True if valid, False if not.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.freq_lower_edge is None or self.freq_upper_edge is None:
|
|
||||||
return self.sample_count > 0
|
|
||||||
return self.sample_count > 0 and self.freq_lower_edge < self.freq_upper_edge
|
return self.sample_count > 0 and self.freq_lower_edge < self.freq_upper_edge
|
||||||
|
|
||||||
def overlap(self, other):
|
def overlap(self, other):
|
||||||
|
|
@ -75,14 +73,6 @@ class Annotation:
|
||||||
|
|
||||||
:returns: The area of the overlap in samples*frequency, or 0 if they do not overlap."""
|
:returns: The area of the overlap in samples*frequency, or 0 if they do not overlap."""
|
||||||
|
|
||||||
if (
|
|
||||||
self.freq_lower_edge is None
|
|
||||||
or self.freq_upper_edge is None
|
|
||||||
or other.freq_lower_edge is None
|
|
||||||
or other.freq_upper_edge is None
|
|
||||||
):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
sample_overlap_start = max(self.sample_start, other.sample_start)
|
sample_overlap_start = max(self.sample_start, other.sample_start)
|
||||||
sample_overlap_end = min(self.sample_start + self.sample_count, other.sample_start + other.sample_count)
|
sample_overlap_end = min(self.sample_start + self.sample_count, other.sample_start + other.sample_count)
|
||||||
|
|
||||||
|
|
@ -101,8 +91,6 @@ class Annotation:
|
||||||
|
|
||||||
:returns: sample length multiplied by bandwidth."""
|
:returns: sample length multiplied by bandwidth."""
|
||||||
|
|
||||||
if self.freq_lower_edge is None or self.freq_upper_edge is None:
|
|
||||||
return 0
|
|
||||||
return self.sample_count * (self.freq_upper_edge - self.freq_lower_edge)
|
return self.sample_count * (self.freq_upper_edge - self.freq_lower_edge)
|
||||||
|
|
||||||
def __eq__(self, other: Annotation) -> bool:
|
def __eq__(self, other: Annotation) -> bool:
|
||||||
|
|
@ -115,16 +103,13 @@ class Annotation:
|
||||||
|
|
||||||
annotation_dict = {SigMFFile.START_INDEX_KEY: self.sample_start, SigMFFile.LENGTH_INDEX_KEY: self.sample_count}
|
annotation_dict = {SigMFFile.START_INDEX_KEY: self.sample_start, SigMFFile.LENGTH_INDEX_KEY: self.sample_count}
|
||||||
|
|
||||||
metadata = {
|
annotation_dict["metadata"] = {
|
||||||
SigMFFile.LABEL_KEY: self.label,
|
SigMFFile.LABEL_KEY: self.label,
|
||||||
SigMFFile.COMMENT_KEY: self.comment,
|
SigMFFile.COMMENT_KEY: self.comment,
|
||||||
|
SigMFFile.FHI_KEY: self.freq_upper_edge,
|
||||||
|
SigMFFile.FLO_KEY: self.freq_lower_edge,
|
||||||
"ria:detail": self.detail,
|
"ria:detail": self.detail,
|
||||||
}
|
}
|
||||||
if self.freq_upper_edge is not None:
|
|
||||||
metadata[SigMFFile.FHI_KEY] = self.freq_upper_edge
|
|
||||||
if self.freq_lower_edge is not None:
|
|
||||||
metadata[SigMFFile.FLO_KEY] = self.freq_lower_edge
|
|
||||||
annotation_dict["metadata"] = metadata
|
|
||||||
|
|
||||||
if _is_jsonable(annotation_dict):
|
if _is_jsonable(annotation_dict):
|
||||||
return annotation_dict
|
return annotation_dict
|
||||||
|
|
|
||||||
|
|
@ -81,8 +81,6 @@ def view_annotations(
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
for annotation in sorted(annotations, key=_threshold_sort_key, reverse=True):
|
for annotation in sorted(annotations, key=_threshold_sort_key, reverse=True):
|
||||||
if annotation.freq_lower_edge is None or annotation.freq_upper_edge is None:
|
|
||||||
continue
|
|
||||||
t_start = annotation.sample_start / sample_rate
|
t_start = annotation.sample_start / sample_rate
|
||||||
t_width = annotation.sample_count / sample_rate
|
t_width = annotation.sample_count / sample_rate
|
||||||
f_start = annotation.freq_lower_edge
|
f_start = annotation.freq_lower_edge
|
||||||
|
|
|
||||||
|
|
@ -2,57 +2,15 @@
|
||||||
This module contains the main group for the ria toolkit oss CLI.
|
This module contains the main group for the ria toolkit oss CLI.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import subprocess
|
import click
|
||||||
import sys
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
warnings.filterwarnings(
|
from ria_toolkit_oss_cli.ria_toolkit_oss import commands
|
||||||
"ignore",
|
|
||||||
message="Unable to import Axes3D",
|
|
||||||
category=UserWarning,
|
|
||||||
module="matplotlib",
|
|
||||||
)
|
|
||||||
|
|
||||||
import click # noqa: E402
|
|
||||||
|
|
||||||
from ria_toolkit_oss_cli.ria_toolkit_oss import commands # noqa: E402
|
|
||||||
|
|
||||||
|
|
||||||
def _git_lfs_installed() -> bool:
|
@click.group()
|
||||||
"""Return True if git-lfs is available on PATH."""
|
|
||||||
try:
|
|
||||||
return (
|
|
||||||
subprocess.run(
|
|
||||||
["git", "lfs", "version"],
|
|
||||||
capture_output=True,
|
|
||||||
).returncode
|
|
||||||
== 0
|
|
||||||
)
|
|
||||||
except FileNotFoundError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@click.group(invoke_without_command=True)
|
|
||||||
@click.option("-v", "--verbose", is_flag=True, type=bool, help="Increase verbosity, especially useful for debugging.")
|
@click.option("-v", "--verbose", is_flag=True, type=bool, help="Increase verbosity, especially useful for debugging.")
|
||||||
@click.pass_context
|
def cli(verbose):
|
||||||
def cli(ctx, verbose):
|
pass
|
||||||
lfs_missing = not _git_lfs_installed()
|
|
||||||
if lfs_missing:
|
|
||||||
click.echo(
|
|
||||||
"Warning: git-lfs is not installed. RIA Hub projects require git-lfs to\n"
|
|
||||||
"track large binary files (models, recordings, datasets).\n"
|
|
||||||
"\n"
|
|
||||||
" Linux: sudo apt-get install git-lfs\n"
|
|
||||||
" macOS: brew install git-lfs\n"
|
|
||||||
" Other platforms: https://git-lfs.com\n"
|
|
||||||
"\n"
|
|
||||||
"After installing, run: git lfs install",
|
|
||||||
err=True,
|
|
||||||
)
|
|
||||||
if ctx.invoked_subcommand is None:
|
|
||||||
if lfs_missing and sys.stdin.isatty():
|
|
||||||
click.pause(info="\nPress Enter to continue...", err=True)
|
|
||||||
click.echo(ctx.get_help())
|
|
||||||
|
|
||||||
|
|
||||||
# Loop through project commands, binding them all to the CLI.
|
# Loop through project commands, binding them all to the CLI.
|
||||||
|
|
|
||||||
|
|
@ -1,97 +0,0 @@
|
||||||
"""Shared authentication and security helpers for RIA Hub API calls."""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import subprocess
|
|
||||||
import urllib.error
|
|
||||||
import urllib.parse
|
|
||||||
import urllib.request
|
|
||||||
|
|
||||||
import click
|
|
||||||
|
|
||||||
DEFAULT_HUB = "https://riahub.ai"
|
|
||||||
|
|
||||||
|
|
||||||
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
|
|
||||||
"""Block redirects on authenticated requests to prevent credential exfiltration.
|
|
||||||
|
|
||||||
urllib re-sends the Authorization header on same-host redirects by default.
|
|
||||||
A malicious server could redirect a POST to a different host to harvest
|
|
||||||
credentials. We refuse all redirects — API clients should not encounter them
|
|
||||||
in normal operation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def redirect_request(self, req, fp, code, msg, headers, newurl):
|
|
||||||
raise urllib.error.URLError(f"Unexpected redirect ({code}) to {newurl} — aborting to protect credentials")
|
|
||||||
|
|
||||||
|
|
||||||
def hub_opener() -> urllib.request.OpenerDirector:
|
|
||||||
"""Return a urllib opener that blocks redirects."""
|
|
||||||
return urllib.request.build_opener(_NoRedirectHandler)
|
|
||||||
|
|
||||||
|
|
||||||
def warn_if_insecure(hub: str) -> None:
|
|
||||||
"""Warn when credentials would be sent over plain HTTP to a non-localhost host."""
|
|
||||||
parsed = urllib.parse.urlparse(hub)
|
|
||||||
if parsed.scheme == "http":
|
|
||||||
host = parsed.hostname or ""
|
|
||||||
if host not in ("localhost", "127.0.0.1", "::1"):
|
|
||||||
click.echo(
|
|
||||||
f"Warning: sending credentials over plain HTTP to {host}. " "Use HTTPS in production.",
|
|
||||||
err=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def basic_auth(username: str, password: str) -> str:
|
|
||||||
return "Basic " + base64.b64encode(f"{username}:{password}".encode()).decode()
|
|
||||||
|
|
||||||
|
|
||||||
def get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]:
|
|
||||||
"""Ask git credential fill for stored creds. Returns (username, password) or (None, None)."""
|
|
||||||
parsed = urllib.parse.urlparse(hub_url)
|
|
||||||
payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n"
|
|
||||||
try:
|
|
||||||
result = subprocess.run(
|
|
||||||
["git", "credential", "fill"],
|
|
||||||
input=payload,
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=5,
|
|
||||||
)
|
|
||||||
creds = {}
|
|
||||||
for line in result.stdout.splitlines():
|
|
||||||
# Partition on the FIRST '=' only so passwords containing '=' are preserved.
|
|
||||||
k, sep, v = line.partition("=")
|
|
||||||
if sep:
|
|
||||||
creds[k.strip()] = v # keep value verbatim
|
|
||||||
return creds.get("username"), creds.get("password")
|
|
||||||
except Exception:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
|
|
||||||
def store_credentials(hub_url: str, username: str, password: str) -> None:
|
|
||||||
"""Cache credentials via git credential approve (uses the system keychain/store)."""
|
|
||||||
parsed = urllib.parse.urlparse(hub_url)
|
|
||||||
payload = (
|
|
||||||
f"protocol={parsed.scheme}\n" f"host={parsed.netloc}\n" f"username={username}\n" f"password={password}\n\n"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
subprocess.run(
|
|
||||||
["git", "credential", "approve"],
|
|
||||||
input=payload,
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=5,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass # non-fatal — next push just prompts again
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_credentials(hub: str) -> tuple[str, str]:
|
|
||||||
"""Return (username, password), prompting interactively if not cached."""
|
|
||||||
username, password = get_stored_credentials(hub)
|
|
||||||
if username and password:
|
|
||||||
return username, password
|
|
||||||
click.echo(f"No stored credentials found for {hub}.")
|
|
||||||
username = click.prompt("RIA Hub username")
|
|
||||||
password = click.prompt("Password / personal access token", hide_input=True)
|
|
||||||
return username, password
|
|
||||||
|
|
@ -86,7 +86,9 @@ def save_recording_auto(recording, output_path, input_path, quiet=False, overwri
|
||||||
input_path = Path(input_path)
|
input_path = Path(input_path)
|
||||||
fmt = detect_input_format(input_path)
|
fmt = detect_input_format(input_path)
|
||||||
|
|
||||||
output_path = determine_output_path(input_path=input_path, output_path=output_path, fmt=fmt, overwrite=overwrite)
|
output_path = determine_output_path(
|
||||||
|
input_path=input_path, output_path=output_path, fmt=fmt, overwrite=overwrite
|
||||||
|
)
|
||||||
|
|
||||||
if not quiet:
|
if not quiet:
|
||||||
if fmt == "sigmf":
|
if fmt == "sigmf":
|
||||||
|
|
@ -256,11 +258,7 @@ def list(input, verbose):
|
||||||
user_comment = ann.comment or ""
|
user_comment = ann.comment or ""
|
||||||
|
|
||||||
# Basic info
|
# Basic info
|
||||||
freq_range = (
|
freq_range = f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
||||||
f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
|
||||||
if ann.freq_lower_edge is not None and ann.freq_upper_edge is not None
|
|
||||||
else "N/A"
|
|
||||||
)
|
|
||||||
click.echo(
|
click.echo(
|
||||||
f" [{i}] Samples {format_sample_count(ann.sample_start)}-"
|
f" [{i}] Samples {format_sample_count(ann.sample_start)}-"
|
||||||
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {ann.label}"
|
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {ann.label}"
|
||||||
|
|
@ -504,7 +502,8 @@ def clear(input, output, overwrite, force, quiet):
|
||||||
help="Annotation type",
|
help="Annotation type",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
"--sample-rate", type=float, default=None,
|
||||||
|
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||||
)
|
)
|
||||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||||
|
|
@ -618,7 +617,8 @@ def energy(
|
||||||
help="Annotation type",
|
help="Annotation type",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
"--sample-rate", type=float, default=None,
|
||||||
|
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||||
)
|
)
|
||||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||||
|
|
@ -707,7 +707,8 @@ def cusum(input, label, min_duration, window_size, tolerance, annotation_type, s
|
||||||
)
|
)
|
||||||
@click.option("--channel", type=int, default=0, help="Channel index to annotate (default: 0)")
|
@click.option("--channel", type=int, default=0, help="Channel index to annotate (default: 0)")
|
||||||
@click.option(
|
@click.option(
|
||||||
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
"--sample-rate", type=float, default=None,
|
||||||
|
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||||
)
|
)
|
||||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||||
|
|
@ -786,7 +787,8 @@ def threshold(input, threshold, label, window_size, annotation_type, channel, sa
|
||||||
@click.option("--noise-threshold-db", type=float, help="Noise floor threshold in dB (auto-estimated if not specified)")
|
@click.option("--noise-threshold-db", type=float, help="Noise floor threshold in dB (auto-estimated if not specified)")
|
||||||
@click.option("--min-component-bw", type=float, default=50e3, help="Min component bandwidth in Hz")
|
@click.option("--min-component-bw", type=float, default=50e3, help="Min component bandwidth in Hz")
|
||||||
@click.option(
|
@click.option(
|
||||||
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
"--sample-rate", type=float, default=None,
|
||||||
|
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||||
)
|
)
|
||||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||||
|
|
@ -807,8 +809,7 @@ def _log_separate_start(quiet, recording, indices_list, nfft, noise_threshold_db
|
||||||
|
|
||||||
|
|
||||||
def separate(
|
def separate(
|
||||||
input, indices, nfft, noise_threshold_db, min_component_bw, sample_rate, output, overwrite, quiet, verbose
|
input, indices, nfft, noise_threshold_db, min_component_bw, sample_rate, output, overwrite, quiet, verbose):
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Auto-detect parallel frequency-offset signals and split into sub-bands.
|
Auto-detect parallel frequency-offset signals and split into sub-bands.
|
||||||
|
|
||||||
|
|
@ -882,11 +883,7 @@ def separate(
|
||||||
click.echo("\n Details:")
|
click.echo("\n Details:")
|
||||||
for i in range(initial_count, final_count):
|
for i in range(initial_count, final_count):
|
||||||
ann = recording.annotations[i]
|
ann = recording.annotations[i]
|
||||||
freq_range = (
|
freq_range = f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
||||||
f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
|
||||||
if ann.freq_lower_edge is not None and ann.freq_upper_edge is not None
|
|
||||||
else "N/A"
|
|
||||||
)
|
|
||||||
click.echo(
|
click.echo(
|
||||||
f" [{i}] samples {format_sample_count(ann.sample_start)}-"
|
f" [{i}] samples {format_sample_count(ann.sample_start)}-"
|
||||||
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {freq_range}"
|
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {freq_range}"
|
||||||
|
|
|
||||||
|
|
@ -16,11 +16,9 @@ from .generate import generate
|
||||||
# from .generate import generate
|
# from .generate import generate
|
||||||
from .init import init
|
from .init import init
|
||||||
from .serve import serve
|
from .serve import serve
|
||||||
from .setup_repo import setup_repo
|
|
||||||
from .split import split
|
from .split import split
|
||||||
from .transform import transform
|
from .transform import transform
|
||||||
from .transmit import transmit
|
from .transmit import transmit
|
||||||
from .upload import upload
|
|
||||||
from .view import view
|
from .view import view
|
||||||
|
|
||||||
# Aliases
|
# Aliases
|
||||||
|
|
|
||||||
|
|
@ -1,401 +0,0 @@
|
||||||
"""ria setup_repo — create and configure a RIA Hub Project repo."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import urllib.error
|
|
||||||
import urllib.parse
|
|
||||||
import urllib.request
|
|
||||||
|
|
||||||
import click
|
|
||||||
|
|
||||||
from ._hub_auth import (
|
|
||||||
DEFAULT_HUB,
|
|
||||||
_NoRedirectHandler,
|
|
||||||
basic_auth,
|
|
||||||
resolve_credentials,
|
|
||||||
store_credentials,
|
|
||||||
warn_if_insecure,
|
|
||||||
)
|
|
||||||
|
|
||||||
RIA_LFS_RULES = [
|
|
||||||
("*.pt", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
("*.pth", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
("*.onnx", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
("*.sigmf", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
("*.sigmf-data", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
("*.sigmf-meta", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
("*.npy", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
("*.npz", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
("*.h5", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
("*.hdf5", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
("*.bin", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
("*.pkl", "filter=lfs diff=lfs merge=lfs -text"),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Repo names must be safe directory names and valid git remote path components.
|
|
||||||
_SAFE_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,100}$")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# API helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _api_request(
|
|
||||||
hub: str,
|
|
||||||
path: str,
|
|
||||||
method: str,
|
|
||||||
username: str,
|
|
||||||
password: str,
|
|
||||||
body: dict | None = None,
|
|
||||||
) -> tuple[dict, int]:
|
|
||||||
"""
|
|
||||||
Make an authenticated request to the RIA Hub API.
|
|
||||||
Returns (parsed_response_body, http_status_code).
|
|
||||||
Status 0 means a network/connection error.
|
|
||||||
Credentials are sent as HTTP Basic auth — safe over HTTPS and localhost HTTP.
|
|
||||||
Redirects are blocked to prevent credential exfiltration.
|
|
||||||
"""
|
|
||||||
url = f"{hub.rstrip('/')}/api/v1{path}"
|
|
||||||
data = json.dumps(body).encode() if body is not None else None
|
|
||||||
req = urllib.request.Request(url, data=data, method=method)
|
|
||||||
req.add_header("Content-Type", "application/json")
|
|
||||||
req.add_header("Authorization", basic_auth(username, password))
|
|
||||||
|
|
||||||
opener = urllib.request.build_opener(_NoRedirectHandler)
|
|
||||||
try:
|
|
||||||
with opener.open(req, timeout=15) as resp:
|
|
||||||
return json.loads(resp.read() or b"{}"), resp.status
|
|
||||||
except urllib.error.HTTPError as e:
|
|
||||||
try:
|
|
||||||
resp_body = json.loads(e.read() or b"{}")
|
|
||||||
except Exception:
|
|
||||||
resp_body = {}
|
|
||||||
return resp_body, e.code
|
|
||||||
except urllib.error.URLError as e:
|
|
||||||
return {"message": str(e.reason)}, 0
|
|
||||||
|
|
||||||
|
|
||||||
def _get_authenticated_username(hub: str, username: str, password: str) -> str | None:
|
|
||||||
"""Return the login name of the authenticated user from GET /api/v1/user.
|
|
||||||
|
|
||||||
This is the canonical username for URL construction — it may differ from
|
|
||||||
git config user.name which is a display name, not a login.
|
|
||||||
"""
|
|
||||||
body, status = _api_request(hub, "/user", "GET", username, password)
|
|
||||||
if status == 200:
|
|
||||||
return body.get("login")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _repo_exists(hub: str, owner: str, name: str, username: str, password: str) -> bool:
|
|
||||||
body, status = _api_request(
|
|
||||||
hub,
|
|
||||||
f"/repos/{urllib.parse.quote(owner, safe='')}/{urllib.parse.quote(name, safe='')}",
|
|
||||||
"GET",
|
|
||||||
username,
|
|
||||||
password,
|
|
||||||
)
|
|
||||||
return status == 200
|
|
||||||
|
|
||||||
|
|
||||||
def _create_repo_on_hub(hub: str, name: str, username: str, password: str, private: bool) -> bool:
|
|
||||||
"""Create an RIA Hub Project repo via API.
|
|
||||||
|
|
||||||
Returns True if the repo was freshly created (server seeded README.md and
|
|
||||||
.gitattributes via auto_init + is_ria), False if the hub was unreachable
|
|
||||||
(local fallback needed). Exits on fatal errors (auth, quota, name taken).
|
|
||||||
"""
|
|
||||||
body, status = _api_request(
|
|
||||||
hub,
|
|
||||||
"/user/repos",
|
|
||||||
"POST",
|
|
||||||
username,
|
|
||||||
password,
|
|
||||||
{
|
|
||||||
"name": name,
|
|
||||||
"auto_init": True,
|
|
||||||
"is_ria": True,
|
|
||||||
"private": private,
|
|
||||||
"default_branch": "main",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if status == 201:
|
|
||||||
click.echo(f"Repository '{name}' created on RIA Hub.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
if status == 0:
|
|
||||||
click.echo(
|
|
||||||
f"Warning: could not reach RIA Hub at {hub}: {body.get('message', 'connection failed')}",
|
|
||||||
err=True,
|
|
||||||
)
|
|
||||||
click.echo("Continuing with local setup only — create the repo manually on RIA Hub.", err=True)
|
|
||||||
return False
|
|
||||||
|
|
||||||
msg = body.get("message", "")
|
|
||||||
|
|
||||||
if status == 401:
|
|
||||||
click.echo("Error: authentication failed — check your username/password.", err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if status in (403, 413) or "quota" in msg.lower() or "limit" in msg.lower():
|
|
||||||
click.echo("Error: cannot create repository — storage quota or account limit reached.", err=True)
|
|
||||||
if msg:
|
|
||||||
click.echo(f" Server message: {msg}", err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
if status == 422 or "already exist" in msg.lower():
|
|
||||||
click.echo(f"Repository '{name}' already exists on RIA Hub.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
click.echo(f"Error creating repository (HTTP {status}): {msg}", err=True)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Local git helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _tracked_patterns(ga_path: str) -> set:
|
|
||||||
if not os.path.exists(ga_path):
|
|
||||||
return set()
|
|
||||||
patterns = set()
|
|
||||||
with open(ga_path, encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if not line or line.startswith("#"):
|
|
||||||
continue
|
|
||||||
m = re.match(r"^(\S+)\s+", line)
|
|
||||||
if m:
|
|
||||||
patterns.add(m.group(1))
|
|
||||||
return patterns
|
|
||||||
|
|
||||||
|
|
||||||
def _write_local_ria_files(repo_path: str, repo_name: str) -> None:
    """Create a starter README.md and add missing RIA LFS rules to .gitattributes.

    A README is written only when none of the recognised README file names
    exists in ``repo_path``. Rules from ``RIA_LFS_RULES`` whose pattern is not
    already present in .gitattributes are appended to that file.
    """
    # --- README -----------------------------------------------------------
    readme_names = ("README.md", "README.rst", "README.txt", "README")
    found = next(
        (candidate for candidate in readme_names if os.path.exists(os.path.join(repo_path, candidate))),
        None,
    )
    if found is not None:
        click.echo(f"README: {found} already exists, skipping")
    else:
        readme_text = (
            f"# {repo_name}\n"
            "\n"
            "A RIA Hub project.\n"
            "\n"
            "## Description\n"
            "\n"
            "<!-- Add your project description here -->\n"
            "\n"
            "## Contents\n"
            "\n"
            "<!-- Describe the signals, models, or datasets in this repository -->\n"
        )
        with open(os.path.join(repo_path, "README.md"), "w", encoding="utf-8") as fh:
            fh.write(readme_text)
        click.echo("README.md: created")

    # --- .gitattributes ---------------------------------------------------
    ga_path = os.path.join(repo_path, ".gitattributes")
    already_tracked = _tracked_patterns(ga_path)
    missing = [(p, a) for p, a in RIA_LFS_RULES if p not in already_tracked]

    if not missing:
        click.echo(".gitattributes: all RIA Hub rules are already present")
        return

    current = ""
    if os.path.exists(ga_path):
        with open(ga_path, encoding="utf-8") as fh:
            current = fh.read()

    # Start on a fresh line when the existing file lacks a trailing newline.
    pieces = [f"{pattern} {attrs}\n" for pattern, attrs in missing]
    if current and not current.endswith("\n"):
        pieces.insert(0, "\n")

    with open(ga_path, "a", encoding="utf-8") as fh:
        fh.write("".join(pieces))
    click.echo(f".gitattributes: {len(missing)} rule(s) added")
|
|
||||||
|
|
||||||
|
|
||||||
def _git(repo_path: str, *args: str, check: bool = True) -> subprocess.CompletedProcess:
    """Run ``git -C <repo_path> <args...>`` with captured text output.

    Args:
        repo_path: Directory passed to git's ``-C`` option.
        *args: Remaining git command-line arguments.
        check: Forwarded to ``subprocess.run``; when true a non-zero exit
            status raises ``subprocess.CalledProcessError``.

    Returns:
        The ``CompletedProcess`` with stdout/stderr captured as text.
    """
    command = ["git", "-C", repo_path]
    command.extend(args)
    return subprocess.run(command, capture_output=True, text=True, check=check)
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_path_and_name(name: str | None, local_path: str | None) -> tuple[str, str]:
|
|
||||||
if local_path:
|
|
||||||
repo_path = os.path.abspath(local_path)
|
|
||||||
repo_name = name or os.path.basename(repo_path)
|
|
||||||
elif name:
|
|
||||||
repo_path = os.path.abspath(name)
|
|
||||||
repo_name = name
|
|
||||||
else:
|
|
||||||
repo_path = os.path.abspath(".")
|
|
||||||
repo_name = os.path.basename(repo_path)
|
|
||||||
return repo_path, repo_name
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_owner(hub: str, username: str | None, password: str | None, owner: str | None) -> str:
|
|
||||||
if not owner and username and password:
|
|
||||||
api_login = _get_authenticated_username(hub, username, password)
|
|
||||||
owner = api_login or username
|
|
||||||
return owner or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
def _git_init(repo_path: str) -> None:
    """Initialise a git repository on branch ``main`` unless one already exists.

    Nothing happens when ``<repo_path>/.git`` is a directory. Otherwise
    ``git init -b main`` is attempted first; if that exits non-zero, a plain
    ``git init`` is run and HEAD is pointed at ``refs/heads/main``.
    """
    git_dir = os.path.join(repo_path, ".git")
    if not os.path.isdir(git_dir):
        first_try = _git(repo_path, "init", "-b", "main", check=False)
        if first_try.returncode != 0:
            # Older git (< 2.28) doesn't support -b; fall back and rename.
            _git(repo_path, "init")
            _git(repo_path, "symbolic-ref", "HEAD", "refs/heads/main")
        click.echo("git init: done (branch: main)")
|
|
||||||
|
|
||||||
|
|
||||||
def _configure_remote(
    repo_path: str, hub: str, resolved_owner: str, repo_name: str, username: str | None, no_remote: bool
) -> None:
    """Point the ``origin`` remote at the hub repository, or explain how to.

    When ``no_remote`` is set or no username is available, only a hint with
    the manual ``git remote add`` command is printed. Otherwise ``origin`` is
    added if absent; an existing ``origin`` is never modified, only reported.
    """
    remote_url = f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git"

    if no_remote or not username:
        click.echo(
            f"Skipped remote setup. Add it manually:\n"
            f" git -C {repo_path} remote add origin "
            f"{remote_url}"
        )
        return

    probe = _git(repo_path, "remote", "get-url", "origin", check=False)
    if probe.returncode != 0:
        # No origin configured yet.
        _git(repo_path, "remote", "add", "origin", remote_url)
        click.echo(f"remote origin: {remote_url}")
        return

    current_url = probe.stdout.strip()
    if current_url == remote_url:
        click.echo(f"remote origin: {remote_url} (already set)")
        return

    click.echo(
        f"remote 'origin' already points to {current_url}.\n"
        f" To update: git remote set-url origin {remote_url}"
    )
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Command
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@click.command("setup_repo")
@click.argument("name", required=False)
@click.option(
    "--path", "local_path", default=None, help="Local directory (default: current dir, or created from NAME)."
)
@click.option("--hub", default=DEFAULT_HUB, show_default=True, metavar="URL", help="RIA Hub base URL.")
@click.option(
    "--owner",
    default=None,
    metavar="USER",
    help="RIA Hub login username (default: looked up from the API using your credentials).",
)
@click.option("--private", is_flag=True, default=False, help="Create the repository as private.")
@click.option(
    "--no-remote", is_flag=True, default=False, help="Skip creating the repository on RIA Hub (local setup only)."
)
def setup_repo(
    name: str | None,
    local_path: str | None,
    hub: str,
    owner: str | None,
    private: bool,
    no_remote: bool,
) -> None:
    """Create and configure a RIA Hub Project repo.

    NAME is the repository name. If the local directory does not exist or is
    not a git repo, it will be initialised automatically. Credentials are
    retrieved from git's credential store — no token setup required if you
    have used RIA Hub with git before.

    \b
    Examples:
      ria setup_repo my-dataset
      ria setup_repo my-dataset --hub https://riahub.example.com
      ria setup_repo --path ./existing-dir
      ria setup_repo my-dataset --private
    """
    repo_path, repo_name = _resolve_path_and_name(name, local_path)

    if not _SAFE_NAME_RE.match(repo_name):
        click.echo(
            f"Error: '{repo_name}' is not a valid repository name.\n"
            "Use only letters, numbers, hyphens, underscores, and dots (max 100 chars).",
            err=True,
        )
        sys.exit(1)

    if not no_remote:
        warn_if_insecure(hub)

    username, password = (None, None) if no_remote else resolve_credentials(hub)
    resolved_owner = _resolve_owner(hub, username, password, owner)

    # newly_created=True means the server ran auto_init+is_ria and seeded
    # README.md + .gitattributes in the initial commit; local setup pulls
    # those files via fetch rather than writing them from scratch.
    newly_created = False
    if not no_remote and username and password:
        if _repo_exists(hub, resolved_owner, repo_name, username, password):
            click.echo(f"Repository '{resolved_owner}/{repo_name}' already exists on RIA Hub.")
        else:
            newly_created = _create_repo_on_hub(hub, repo_name, username, password, private)
        store_credentials(hub, username, password)

    if not os.path.exists(repo_path):
        os.makedirs(repo_path)
        click.echo(f"Created directory: {repo_path}")

    _git_init(repo_path)

    if subprocess.run(["git", "lfs", "version"], capture_output=True).returncode != 0:
        click.echo(
            "Error: git-lfs is not installed.\n"
            " Linux: sudo apt-get install git-lfs\n"
            " macOS: brew install git-lfs\n"
            " Other platforms: https://git-lfs.com",
            err=True,
        )
        sys.exit(1)

    _git(repo_path, "lfs", "install", "--local")
    click.echo("git lfs install --local: done")

    _configure_remote(repo_path, hub, resolved_owner, repo_name, username, no_remote)

    # Track whether the server's initial commit actually landed locally; the
    # closing hint must describe the real state of the working tree, not just
    # whether the remote repository was created.
    pulled_initial_commit = False
    if newly_created:
        fetch = _git(repo_path, "fetch", "origin", check=False)
        if fetch.returncode != 0:
            click.echo("Warning: fetch failed — falling back to local file setup.", err=True)
        else:
            # check=False: a missing origin/main should fall back to local
            # setup instead of surfacing a CalledProcessError traceback.
            reset = _git(repo_path, "reset", "--hard", "origin/main", check=False)
            if reset.returncode == 0:
                pulled_initial_commit = True
                click.echo("Pulled initial commit from RIA Hub (README.md + .gitattributes)")
            else:
                click.echo(
                    "Warning: could not check out origin/main — falling back to local file setup.",
                    err=True,
                )

    if not pulled_initial_commit:
        _write_local_ria_files(repo_path, repo_name)

    if pulled_initial_commit:
        click.echo(f"\nRepo is ready. Push your work:\n cd {repo_path}\n git push -u origin main")
    else:
        click.echo(
            f"\nRepo is ready. Commit and push:\n"
            f" cd {repo_path}\n"
            f" git add README.md .gitattributes\n"
            f" git commit -m 'chore: initialise RIA Hub project'\n"
            f" git push -u origin main"
        )
|
|
||||||
|
|
@ -1,392 +0,0 @@
|
||||||
"""ria upload — stream large files to a RIA Hub Project via the LFS API.
|
|
||||||
|
|
||||||
How it works
|
|
||||||
------------
|
|
||||||
1. The file is hashed locally (SHA-256 + size) — this is the LFS object ID.
|
|
||||||
2. A single POST to the repo's LFS batch endpoint returns an upload URL
|
|
||||||
(and headers) for any object the server does not already have.
|
|
||||||
3. The file is streamed to that URL in fixed-size chunks — nothing is ever
|
|
||||||
fully loaded into memory, so files of any size work.
|
|
||||||
4. A commit is created via the Gitea contents API that records the LFS
|
|
||||||
pointer (a small text file) so the file appears in the repo tree.
|
|
||||||
|
|
||||||
No server-side changes are required — this uses the same authenticated LFS
|
|
||||||
protocol that `git lfs push` uses internally.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
import http.client
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import urllib.error
|
|
||||||
import urllib.parse
|
|
||||||
import urllib.request
|
|
||||||
|
|
||||||
import click
|
|
||||||
|
|
||||||
from ._hub_auth import (
|
|
||||||
DEFAULT_HUB,
|
|
||||||
basic_auth,
|
|
||||||
hub_opener,
|
|
||||||
resolve_credentials,
|
|
||||||
warn_if_insecure,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Read buffer for hashing and streaming — 8 MB keeps memory use flat
|
|
||||||
# for arbitrarily large files.
|
|
||||||
_CHUNK = 8 * 1024 * 1024
|
|
||||||
|
|
||||||
LFS_MEDIA_TYPE = "application/vnd.git-lfs+json"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# File helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _hash_file(path: str) -> tuple[str, int]:
    """Stream ``path`` once and return ``(sha256_hex, byte_size)``.

    The file is read in ``_CHUNK``-sized pieces so it is never held in memory
    as a whole.
    """
    digest = hashlib.sha256()
    total = 0
    with open(path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read (EOF).
        for piece in iter(lambda: stream.read(_CHUNK), b""):
            digest.update(piece)
            total += len(piece)
    return digest.hexdigest(), total
|
|
||||||
|
|
||||||
|
|
||||||
def _lfs_pointer_text(oid: str, size: int) -> str:
    """Return the three-line Git LFS pointer text for an object id and size."""
    pointer_lines = (
        "version https://git-lfs.github.com/spec/v1",
        f"oid sha256:{oid}",
        f"size {size}",
    )
    return "\n".join(pointer_lines) + "\n"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# LFS batch API
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _lfs_batch(
    hub: str,
    owner: str,
    repo: str,
    objects: list[dict],
    username: str,
    password: str,
) -> dict:
    """POST an "upload" request to /{owner}/{repo}.git/info/lfs/objects/batch.

    Args:
        hub: Hub base URL; a trailing slash is ignored.
        owner: Repository owner (percent-encoded into the URL).
        repo: Repository name (percent-encoded into the URL).
        objects: Object descriptors sent as the ``objects`` field.
        username: Basic-auth user.
        password: Basic-auth password.

    Returns:
        The parsed JSON response.

    Raises:
        RuntimeError: On an HTTP error status, a connection-level failure, or
            a response body that is not valid JSON. Every failure is reported
            as RuntimeError so callers need a single ``except`` clause.
    """
    url = (
        f"{hub.rstrip('/')}"
        f"/{urllib.parse.quote(owner, safe='')}"
        f"/{urllib.parse.quote(repo, safe='')}"
        f".git/info/lfs/objects/batch"
    )
    body = json.dumps(
        {
            "operation": "upload",
            "transfers": ["basic"],
            "objects": objects,
        }
    ).encode()

    req = urllib.request.Request(url, data=body, method="POST")
    req.add_header("Content-Type", LFS_MEDIA_TYPE)
    req.add_header("Accept", LFS_MEDIA_TYPE)
    req.add_header("Authorization", basic_auth(username, password))

    opener = hub_opener()
    try:
        with opener.open(req, timeout=30) as resp:
            payload = resp.read()
    except urllib.error.HTTPError as e:
        body_text = e.read().decode(errors="replace")
        raise RuntimeError(f"LFS batch request failed (HTTP {e.code}): {body_text}") from e
    except OSError as e:
        # URLError (DNS, refused connection, TLS) and socket timeouts are all
        # OSError subclasses; HTTPError was handled above.
        raise RuntimeError(f"LFS batch request failed: {getattr(e, 'reason', e)}") from e

    try:
        return json.loads(payload)
    except ValueError as e:
        raise RuntimeError("LFS batch request failed: response was not valid JSON") from e
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Gitea contents API — create / update a file to record the LFS pointer
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _get_file_sha(
    hub: str,
    owner: str,
    repo: str,
    path: str,
    branch: str,
    username: str,
    password: str,
) -> str | None:
    """Look up the ``sha`` of ``path`` on ``branch`` via the contents API.

    Returns the ``sha`` field of the JSON response, or ``None`` when the
    server answers 404. Any other HTTP error is re-raised unchanged.
    """
    base = hub.rstrip("/")
    owner_q = urllib.parse.quote(owner, safe="")
    repo_q = urllib.parse.quote(repo, safe="")
    path_q = urllib.parse.quote(path)
    ref_q = urllib.parse.quote(branch)
    url = f"{base}/api/v1/repos/{owner_q}/{repo_q}/contents/{path_q}?ref={ref_q}"

    request = urllib.request.Request(url)
    request.add_header("Authorization", basic_auth(username, password))
    try:
        with hub_opener().open(request, timeout=15) as response:
            return json.loads(response.read()).get("sha")
    except urllib.error.HTTPError as err:
        if err.code != 404:
            raise
        return None
|
|
||||||
|
|
||||||
|
|
||||||
def _commit_lfs_pointer(
    hub: str,
    owner: str,
    repo: str,
    remote_path: str,
    pointer_text: str,
    branch: str,
    message: str,
    username: str,
    password: str,
) -> None:
    """Create or update a file in the repo containing the LFS pointer.

    The contents API is first queried for an existing file at ``remote_path``
    on ``branch``. If one exists its ``sha`` is sent with a PUT (update);
    otherwise the file is created with a POST.

    Raises:
        RuntimeError: If the lookup or the commit fails, whether through an
            HTTP error status or a connection-level error. Failures are
            normalised to RuntimeError so the caller's single handler can
            report them.
    """
    url = (
        f"{hub.rstrip('/')}/api/v1"
        f"/repos/{urllib.parse.quote(owner, safe='')}/{urllib.parse.quote(repo, safe='')}"
        f"/contents/{urllib.parse.quote(remote_path)}"
    )

    try:
        existing_sha = _get_file_sha(hub, owner, repo, remote_path, branch, username, password)
    except urllib.error.HTTPError as e:
        raise RuntimeError(f"Failed to look up '{remote_path}' on branch '{branch}' (HTTP {e.code})") from e
    except OSError as e:
        raise RuntimeError(f"Failed to look up '{remote_path}': {getattr(e, 'reason', e)}") from e

    body: dict = {
        "message": message,
        "content": base64.b64encode(pointer_text.encode()).decode(),
        "branch": branch,
    }
    if existing_sha:
        # An update must name the blob it replaces.
        body["sha"] = existing_sha

    method = "PUT" if existing_sha else "POST"
    req = urllib.request.Request(url, data=json.dumps(body).encode(), method=method)
    req.add_header("Content-Type", "application/json")
    req.add_header("Authorization", basic_auth(username, password))

    try:
        with hub_opener().open(req, timeout=30) as resp:
            resp.read()
    except urllib.error.HTTPError as e:
        body_text = e.read().decode(errors="replace")
        raise RuntimeError(f"Failed to commit LFS pointer for '{remote_path}' (HTTP {e.code}): {body_text}") from e
    except OSError as e:
        raise RuntimeError(
            f"Failed to commit LFS pointer for '{remote_path}': {getattr(e, 'reason', e)}"
        ) from e
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Per-file upload logic
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _upload_single_file(
    hub: str,
    owner: str,
    repo_name: str,
    username: str,
    password: str,
    file_path: str,
    remote_dir: str,
    message: str | None,
    branch: str,
) -> None:
    """Hash, upload (if needed), and commit the LFS pointer for one file.

    Steps: hash the file, ask the LFS batch endpoint whether an upload is
    required, stream the file if the server returned an upload action, call
    the verify action when one was returned, then commit the pointer file.

    Any fatal error is printed to stderr and terminates the process with
    exit status 1. A failing verify call only produces a warning.
    """
    filename = os.path.basename(file_path)
    file_size = os.path.getsize(file_path)
    size_mb = file_size / (1024 * 1024)

    click.echo(f"\n {filename} ({size_mb:.1f} MB)")

    click.echo(" Hashing...", nl=False)
    oid, size = _hash_file(file_path)
    click.echo(f" sha256:{oid[:12]}...")

    try:
        batch = _lfs_batch(hub, owner, repo_name, [{"oid": oid, "size": size}], username, password)
    except RuntimeError as e:
        click.echo(f"\n Error: {e}", err=True)
        sys.exit(1)

    objects = batch.get("objects", [])
    if not objects:
        click.echo(" Already in LFS — skipping upload.")
    else:
        obj = objects[0]
        if "error" in obj:
            err_msg = obj["error"].get("message", "unknown error")
            err_code = obj["error"].get("code", 0)
            if err_code == 413 or "quota" in err_msg.lower() or "limit" in err_msg.lower():
                click.echo(
                    f"\n Error: storage quota exceeded for this repo.\n Server: {err_msg}",
                    err=True,
                )
            else:
                click.echo(f"\n Error from server: {err_msg}", err=True)
            sys.exit(1)

        upload_action = obj.get("actions", {}).get("upload")
        if not upload_action:
            click.echo(" Already in LFS — skipping upload.")
        else:
            href = upload_action["href"]
            up_headers = upload_action.get("header", {})
            chunks = math.ceil(size / _CHUNK)
            click.echo(f" Uploading ({size_mb:.1f} MB, {chunks} chunk{'s' if chunks != 1 else ''})...")
            try:
                _stream_upload_progress(href, up_headers, file_path, size)
            except (RuntimeError, OSError) as e:
                # OSError covers socket-level failures (reset, timeout, DNS)
                # that would otherwise escape as a traceback.
                click.echo(f"\n Upload failed: {e}", err=True)
                sys.exit(1)
            click.echo(" Upload complete.")

            verify_action = obj.get("actions", {}).get("verify")
            if verify_action:
                try:
                    vreq = urllib.request.Request(
                        verify_action["href"],
                        data=json.dumps({"oid": oid, "size": size}).encode(),
                        method="POST",
                    )
                    vreq.add_header("Content-Type", LFS_MEDIA_TYPE)
                    vreq.add_header("Accept", LFS_MEDIA_TYPE)
                    for k, v in (verify_action.get("header") or {}).items():
                        vreq.add_header(k, v)
                    with urllib.request.urlopen(vreq, timeout=15):
                        pass
                except Exception as e:
                    # verify is optional; non-fatal on failure, but tell the
                    # user instead of hiding it.
                    click.echo(f" Warning: LFS verify request failed ({e}); continuing.", err=True)

    pointer = _lfs_pointer_text(oid, size)
    remote_path = (f"{remote_dir.rstrip('/')}/{filename}").lstrip("/") if remote_dir else filename
    commit_msg = message or f"chore: upload {filename} via ria"

    click.echo(f" Committing pointer → {remote_path}...", nl=False)
    try:
        _commit_lfs_pointer(hub, owner, repo_name, remote_path, pointer, branch, commit_msg, username, password)
        click.echo(" done.")
    except RuntimeError as e:
        click.echo(f"\n Error: {e}", err=True)
        sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def _stream_upload_progress(href: str, headers: dict, file_path: str, size: int) -> None:
    """Stream file_path to href with a click progress bar.

    The file is sent with a single PUT whose body is written in
    ``_CHUNK``-sized pieces, so the file is never loaded into memory whole.

    Args:
        href: Upload URL; its scheme selects HTTP or HTTPS.
        headers: Extra request headers; ``None`` is treated as empty.
        file_path: Local file to send.
        size: Number of bytes to send; used for Content-Length and the bar.

    Raises:
        RuntimeError: On a connection-level failure or a non-2xx response.
    """
    parsed = urllib.parse.urlparse(href)
    host = parsed.netloc
    path_q = parsed.path + (f"?{parsed.query}" if parsed.query else "")

    if parsed.scheme == "https":
        conn = http.client.HTTPSConnection(host, timeout=300)
    else:
        conn = http.client.HTTPConnection(host, timeout=300)

    all_headers = dict(headers or {})
    all_headers.setdefault("Content-Type", "application/octet-stream")
    all_headers["Content-Length"] = str(size)

    try:
        conn.connect()
        conn.putrequest("PUT", path_q)
        for k, v in all_headers.items():
            conn.putheader(k, v)
        conn.endheaders()

        with click.progressbar(
            length=size,
            label=" ",
            width=40,
            show_eta=True,
            show_percent=True,
            fill_char="█",
            empty_char="░",
        ) as bar:
            with open(file_path, "rb") as f:
                while True:
                    chunk = f.read(_CHUNK)
                    if not chunk:
                        break
                    conn.send(chunk)
                    bar.update(len(chunk))

        resp = conn.getresponse()
        status = resp.status
        detail = resp.read()
    except (OSError, http.client.HTTPException) as e:
        # Normalise transport failures to the exception type callers handle.
        raise RuntimeError(f"connection error: {e}") from e
    finally:
        conn.close()

    if not 200 <= status < 300:
        # Keep a short excerpt of the body so the cause is visible.
        excerpt = detail[:200].decode(errors="replace").strip()
        raise RuntimeError(f"HTTP {status}: {excerpt}" if excerpt else f"HTTP {status}")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Command
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@click.command("upload")
@click.argument("files", nargs=-1, required=True)
@click.option(
    "--repo", required=True, metavar="OWNER/NAME", help="Target repository on RIA Hub (e.g. benchinnery/my-dataset)."
)
@click.option("--hub", default=DEFAULT_HUB, show_default=True, metavar="URL", help="RIA Hub base URL.")
@click.option("--branch", default="main", show_default=True, help="Branch to commit the files to.")
@click.option(
    "--path",
    "remote_dir",
    default="",
    metavar="DIR",
    help="Remote directory path inside the repo (default: repo root).",
)
@click.option("--message", "-m", default=None, help="Commit message (default: 'chore: upload <filename> via ria').")
def upload(
    files: tuple[str, ...],
    repo: str,
    hub: str,
    branch: str,
    remote_dir: str,
    message: str | None,
) -> None:
    """Upload large files to a RIA Hub Project via Git LFS.

    Files are streamed directly to the repo's LFS object store — nothing is
    buffered into memory, so files of any size work. Each file creates one
    commit recording the LFS pointer.

    \b
    Examples:
      ria upload recording.sigmf-data --repo benchinnery/my-recordings
      ria upload *.npy --repo benchinnery/my-recordings --branch main
      ria upload big.pt --repo benchinnery/models --path weights/
    """
    import glob  # local import: only this command needs it

    # Validate repo argument: exactly one slash with a non-empty part on each
    # side, so malformed values fail here rather than as a confusing 404.
    owner, _, repo_name = repo.partition("/")
    if not owner or not repo_name or "/" in repo_name:
        click.echo("Error: --repo must be in the form OWNER/NAME.", err=True)
        sys.exit(1)

    # Expand and validate files. An argument naming an existing file is taken
    # literally; anything else is tried as a glob pattern so wildcards also
    # work in shells that do not expand them.
    resolved = []
    for pattern in files:
        if os.path.isfile(pattern):
            matches = [pattern]
        else:
            matches = sorted(m for m in glob.glob(pattern) if os.path.isfile(m))
        if not matches:
            click.echo(f"Error: '{pattern}' is not a file or does not exist.", err=True)
            sys.exit(1)
        resolved.extend(os.path.abspath(m) for m in matches)

    hub = hub.rstrip("/")
    warn_if_insecure(hub)
    username, password = resolve_credentials(hub)

    click.echo(f"Uploading {len(resolved)} file(s) to {owner}/{repo_name} on {hub}...")

    for file_path in resolved:
        _upload_single_file(hub, owner, repo_name, username, password, file_path, remote_dir, message, branch)

    click.echo(f"\nAll done. {len(resolved)} file(s) uploaded to {owner}/{repo_name}.")
|
|
||||||
|
|
@ -79,9 +79,9 @@ def test_user_agent_is_set_and_not_python_default():
|
||||||
"""
|
"""
|
||||||
ua = agent_cli._user_agent()
|
ua = agent_cli._user_agent()
|
||||||
assert ua, "User-Agent must not be empty"
|
assert ua, "User-Agent must not be empty"
|
||||||
assert not ua.lower().startswith(
|
assert not ua.lower().startswith("python-urllib"), (
|
||||||
"python-urllib"
|
f"User-Agent must not be Python's default (got {ua!r}) — Cloudflare blocks it"
|
||||||
), f"User-Agent must not be Python's default (got {ua!r}) — Cloudflare blocks it"
|
)
|
||||||
assert ua.startswith("ria-agent/")
|
assert ua.startswith("ria-agent/")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -96,10 +96,7 @@ def test_register_request_carries_explicit_user_agent(tmp_path):
|
||||||
captured["api_key"] = req.get_header("X-api-key")
|
captured["api_key"] = req.get_header("X-api-key")
|
||||||
captured["timeout"] = kwargs.get("timeout")
|
captured["timeout"] = kwargs.get("timeout")
|
||||||
raise urllib.error.HTTPError(
|
raise urllib.error.HTTPError(
|
||||||
url=req.full_url,
|
url=req.full_url, code=403, msg="", hdrs=None, # type: ignore[arg-type]
|
||||||
code=403,
|
|
||||||
msg="",
|
|
||||||
hdrs=None, # type: ignore[arg-type]
|
|
||||||
fp=BytesIO(_structured("invalid_key")),
|
fp=BytesIO(_structured("invalid_key")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -199,44 +199,3 @@ def test_annotation_to_sigmf_format_values():
|
||||||
values = list(result.values())
|
values = list(result.values())
|
||||||
assert 50 in values or ann.sample_start in values
|
assert 50 in values or ann.sample_start in values
|
||||||
assert 100 in values or ann.sample_count in values
|
assert 100 in values or ann.sample_count in values
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# None freq-edge regression tests (SigMF optional fields)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_annotation_no_freq_edges():
    """An annotation built without frequency edges reports None for both."""
    burst = Annotation(sample_start=0, sample_count=10, label="burst")
    lower, upper = burst.freq_lower_edge, burst.freq_upper_edge
    assert lower is None
    assert upper is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_annotation_is_valid_no_freq_edges():
    """Without frequency edges, validity follows the sample count alone."""
    with_samples = Annotation(sample_start=0, sample_count=10, label="burst")
    assert with_samples.is_valid() is True

    without_samples = Annotation(sample_start=0, sample_count=0, label="burst")
    assert without_samples.is_valid() is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_annotation_overlap_none_edges_returns_zero():
    """Overlap is zero in both directions when one side has no frequency edges."""
    edgeless = Annotation(sample_start=0, sample_count=10)
    bounded = Annotation(sample_start=0, sample_count=10, freq_lower_edge=0, freq_upper_edge=100)
    assert edgeless.overlap(bounded) == 0
    assert bounded.overlap(edgeless) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_annotation_area_none_edges_returns_zero():
    """Area is zero for an annotation that has no frequency edges."""
    burst = Annotation(sample_start=0, sample_count=10, label="burst")
    computed_area = burst.area()
    assert computed_area == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_annotation_to_sigmf_omits_freq_keys_when_none():
    """SigMF export leaves out both frequency-edge keys when the edges are None."""
    from sigmf import SigMFFile

    burst = Annotation(sample_start=0, sample_count=10, label="burst")
    exported_metadata = burst.to_sigmf_format()["metadata"]
    assert SigMFFile.FLO_KEY not in exported_metadata
    assert SigMFFile.FHI_KEY not in exported_metadata
|
|
||||||
|
|
|
||||||
|
|
@ -189,21 +189,3 @@ def test_sigmf_3(tmp_path):
|
||||||
to_sigmf(recording=recording1, path=tmp_path, filename=filename.name)
|
to_sigmf(recording=recording1, path=tmp_path, filename=filename.name)
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
assert str(e) == "File already exists"
|
assert str(e) == "File already exists"
|
||||||
|
|
||||||
|
|
||||||
def test_sigmf_annotation_without_freq_edges(tmp_path):
    """Round-trip an annotation that has no frequency edges through SigMF.

    Regression: annotations that omit the optional SigMF freq edge fields must
    load without error; edges should be None and the annotation still valid.
    """
    original = Annotation(sample_start=0, sample_count=5, label="burst")
    written = Recording(data=complex_data_1, metadata=sample_metadata, annotations=[original])

    target = tmp_path / "test"
    to_sigmf(recording=written, path=tmp_path, filename=target.name, overwrite=True)
    reloaded = from_sigmf(target)

    assert len(reloaded.annotations) == 1
    restored = reloaded.annotations[0]
    assert restored.freq_lower_edge is None
    assert restored.freq_upper_edge is None
    assert restored.is_valid() is True
    assert restored.label == "burst"
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user