signal-view-fix #35
1220
poetry.lock
generated
1220
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
|
|
@ -28,6 +28,8 @@ from .hardware import available_devices
|
|||
from .legacy_executor import main as _legacy_main
|
||||
from .namegen import generate_agent_name
|
||||
|
||||
DEFAULT_HUB_URL = "https://riahub.ai"
|
||||
|
||||
_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"}
|
||||
|
||||
|
||||
|
|
@ -54,18 +56,9 @@ _REGISTER_TIMEOUT_S = 15
|
|||
|
||||
|
||||
REGISTRATION_REASON_MESSAGES = {
|
||||
"invalid_key": (
|
||||
"Registration key not recognized. Generate a fresh key from "
|
||||
"Settings → RIA Agents on the hub."
|
||||
),
|
||||
"expired": (
|
||||
"This registration key has expired. Generate a new one from "
|
||||
"Settings → RIA Agents on the hub."
|
||||
),
|
||||
"revoked": (
|
||||
"This registration key was revoked. Generate a new one from "
|
||||
"Settings → RIA Agents on the hub."
|
||||
),
|
||||
"invalid_key": ("Registration key not recognized. Generate a fresh key from " "Settings → RIA Agents on the hub."),
|
||||
"expired": ("This registration key has expired. Generate a new one from " "Settings → RIA Agents on the hub."),
|
||||
"revoked": ("This registration key was revoked. Generate a new one from " "Settings → RIA Agents on the hub."),
|
||||
"already_consumed": (
|
||||
"This single-use registration key has already been used. "
|
||||
"Generate a new one, or mint a reusable key instead."
|
||||
|
|
@ -226,7 +219,7 @@ def main() -> None:
|
|||
sub.add_parser("detect", help="List available SDR drivers")
|
||||
|
||||
p_reg = sub.add_parser("register", help="Register agent with RIA Hub and save credentials")
|
||||
p_reg.add_argument("--hub", required=True, help="RIA Hub URL (e.g. http://whitehorse:3005)")
|
||||
p_reg.add_argument("--hub", default=DEFAULT_HUB_URL, help=f"RIA Hub URL (default: {DEFAULT_HUB_URL})")
|
||||
p_reg.add_argument(
|
||||
"--api-key",
|
||||
dest="api_key",
|
||||
|
|
|
|||
|
|
@ -45,7 +45,14 @@ def is_annotation_contained(inner: Annotation, outer: Annotation) -> bool:
|
|||
outer_sample_stop = outer.sample_start + outer.sample_count
|
||||
|
||||
if inner.sample_start > outer.sample_start and inner_sample_stop < outer_sample_stop:
|
||||
if inner.freq_lower_edge > outer.freq_lower_edge and inner.freq_upper_edge < outer.freq_upper_edge:
|
||||
if (
|
||||
inner.freq_lower_edge is not None
|
||||
and inner.freq_upper_edge is not None
|
||||
and outer.freq_lower_edge is not None
|
||||
and outer.freq_upper_edge is not None
|
||||
and inner.freq_lower_edge > outer.freq_lower_edge
|
||||
and inner.freq_upper_edge < outer.freq_upper_edge
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -17,10 +17,10 @@ class Annotation:
|
|||
:type sample_start: int
|
||||
:param sample_count: The index of the ending sample of the annotation, inclusive.
|
||||
:type sample_count: int
|
||||
:param freq_lower_edge: The lower frequency of the annotation.
|
||||
:type freq_lower_edge: float
|
||||
:param freq_upper_edge: The upper frequency of the annotation.
|
||||
:type freq_upper_edge: float
|
||||
:param freq_lower_edge: The lower frequency of the annotation. Optional; None if not specified in source.
|
||||
:type freq_lower_edge: float, optional
|
||||
:param freq_upper_edge: The upper frequency of the annotation. Optional; None if not specified in source.
|
||||
:type freq_upper_edge: float, optional
|
||||
:param label: The label that will be displayed with the bounding box in compatible viewers including IQEngine.
|
||||
Defaults to an emtpy string.
|
||||
:type label: str, optional
|
||||
|
|
@ -34,8 +34,8 @@ class Annotation:
|
|||
self,
|
||||
sample_start: int,
|
||||
sample_count: int,
|
||||
freq_lower_edge: float,
|
||||
freq_upper_edge: float,
|
||||
freq_lower_edge: Optional[float] = None,
|
||||
freq_upper_edge: Optional[float] = None,
|
||||
label: Optional[str] = "",
|
||||
comment: Optional[str] = "",
|
||||
detail: Optional[dict] = None,
|
||||
|
|
@ -43,8 +43,8 @@ class Annotation:
|
|||
"""Initialize a new Annotation instance."""
|
||||
self.sample_start = int(sample_start)
|
||||
self.sample_count = int(sample_count)
|
||||
self.freq_lower_edge = float(freq_lower_edge)
|
||||
self.freq_upper_edge = float(freq_upper_edge)
|
||||
self.freq_lower_edge = float(freq_lower_edge) if freq_lower_edge is not None else None
|
||||
self.freq_upper_edge = float(freq_upper_edge) if freq_upper_edge is not None else None
|
||||
self.label = str(label)
|
||||
self.comment = str(comment)
|
||||
|
||||
|
|
@ -62,6 +62,8 @@ class Annotation:
|
|||
:returns: True if valid, False if not.
|
||||
"""
|
||||
|
||||
if self.freq_lower_edge is None or self.freq_upper_edge is None:
|
||||
return self.sample_count > 0
|
||||
return self.sample_count > 0 and self.freq_lower_edge < self.freq_upper_edge
|
||||
|
||||
def overlap(self, other):
|
||||
|
|
@ -73,6 +75,14 @@ class Annotation:
|
|||
|
||||
:returns: The area of the overlap in samples*frequency, or 0 if they do not overlap."""
|
||||
|
||||
if (
|
||||
self.freq_lower_edge is None
|
||||
or self.freq_upper_edge is None
|
||||
or other.freq_lower_edge is None
|
||||
or other.freq_upper_edge is None
|
||||
):
|
||||
return 0
|
||||
|
||||
sample_overlap_start = max(self.sample_start, other.sample_start)
|
||||
sample_overlap_end = min(self.sample_start + self.sample_count, other.sample_start + other.sample_count)
|
||||
|
||||
|
|
@ -91,6 +101,8 @@ class Annotation:
|
|||
|
||||
:returns: sample length multiplied by bandwidth."""
|
||||
|
||||
if self.freq_lower_edge is None or self.freq_upper_edge is None:
|
||||
return 0
|
||||
return self.sample_count * (self.freq_upper_edge - self.freq_lower_edge)
|
||||
|
||||
def __eq__(self, other: Annotation) -> bool:
|
||||
|
|
@ -103,13 +115,16 @@ class Annotation:
|
|||
|
||||
annotation_dict = {SigMFFile.START_INDEX_KEY: self.sample_start, SigMFFile.LENGTH_INDEX_KEY: self.sample_count}
|
||||
|
||||
annotation_dict["metadata"] = {
|
||||
metadata = {
|
||||
SigMFFile.LABEL_KEY: self.label,
|
||||
SigMFFile.COMMENT_KEY: self.comment,
|
||||
SigMFFile.FHI_KEY: self.freq_upper_edge,
|
||||
SigMFFile.FLO_KEY: self.freq_lower_edge,
|
||||
"ria:detail": self.detail,
|
||||
}
|
||||
if self.freq_upper_edge is not None:
|
||||
metadata[SigMFFile.FHI_KEY] = self.freq_upper_edge
|
||||
if self.freq_lower_edge is not None:
|
||||
metadata[SigMFFile.FLO_KEY] = self.freq_lower_edge
|
||||
annotation_dict["metadata"] = metadata
|
||||
|
||||
if _is_jsonable(annotation_dict):
|
||||
return annotation_dict
|
||||
|
|
|
|||
|
|
@ -81,6 +81,8 @@ def view_annotations(
|
|||
return 0
|
||||
|
||||
for annotation in sorted(annotations, key=_threshold_sort_key, reverse=True):
|
||||
if annotation.freq_lower_edge is None or annotation.freq_upper_edge is None:
|
||||
continue
|
||||
t_start = annotation.sample_start / sample_rate
|
||||
t_width = annotation.sample_count / sample_rate
|
||||
f_start = annotation.freq_lower_edge
|
||||
|
|
|
|||
|
|
@ -2,15 +2,57 @@
|
|||
This module contains the main group for the ria toolkit oss CLI.
|
||||
"""
|
||||
|
||||
import click
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from ria_toolkit_oss_cli.ria_toolkit_oss import commands
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message="Unable to import Axes3D",
|
||||
category=UserWarning,
|
||||
module="matplotlib",
|
||||
)
|
||||
|
||||
import click # noqa: E402
|
||||
|
||||
from ria_toolkit_oss_cli.ria_toolkit_oss import commands # noqa: E402
|
||||
|
||||
|
||||
@click.group()
|
||||
def _git_lfs_installed() -> bool:
|
||||
"""Return True if git-lfs is available on PATH."""
|
||||
try:
|
||||
return (
|
||||
subprocess.run(
|
||||
["git", "lfs", "version"],
|
||||
capture_output=True,
|
||||
).returncode
|
||||
== 0
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
@click.group(invoke_without_command=True)
|
||||
@click.option("-v", "--verbose", is_flag=True, type=bool, help="Increase verbosity, especially useful for debugging.")
|
||||
def cli(verbose):
|
||||
pass
|
||||
@click.pass_context
|
||||
def cli(ctx, verbose):
|
||||
lfs_missing = not _git_lfs_installed()
|
||||
if lfs_missing:
|
||||
click.echo(
|
||||
"Warning: git-lfs is not installed. RIA Hub projects require git-lfs to\n"
|
||||
"track large binary files (models, recordings, datasets).\n"
|
||||
"\n"
|
||||
" Linux: sudo apt-get install git-lfs\n"
|
||||
" macOS: brew install git-lfs\n"
|
||||
" Other platforms: https://git-lfs.com\n"
|
||||
"\n"
|
||||
"After installing, run: git lfs install",
|
||||
err=True,
|
||||
)
|
||||
if ctx.invoked_subcommand is None:
|
||||
if lfs_missing and sys.stdin.isatty():
|
||||
click.pause(info="\nPress Enter to continue...", err=True)
|
||||
click.echo(ctx.get_help())
|
||||
|
||||
|
||||
# Loop through project commands, binding them all to the CLI.
|
||||
|
|
|
|||
97
src/ria_toolkit_oss_cli/ria_toolkit_oss/_hub_auth.py
Normal file
97
src/ria_toolkit_oss_cli/ria_toolkit_oss/_hub_auth.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""Shared authentication and security helpers for RIA Hub API calls."""
|
||||
|
||||
import base64
|
||||
import subprocess
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
import click
|
||||
|
||||
DEFAULT_HUB = "https://riahub.ai"
|
||||
|
||||
|
||||
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
|
||||
"""Block redirects on authenticated requests to prevent credential exfiltration.
|
||||
|
||||
urllib re-sends the Authorization header on same-host redirects by default.
|
||||
A malicious server could redirect a POST to a different host to harvest
|
||||
credentials. We refuse all redirects — API clients should not encounter them
|
||||
in normal operation.
|
||||
"""
|
||||
|
||||
def redirect_request(self, req, fp, code, msg, headers, newurl):
|
||||
raise urllib.error.URLError(f"Unexpected redirect ({code}) to {newurl} — aborting to protect credentials")
|
||||
|
||||
|
||||
def hub_opener() -> urllib.request.OpenerDirector:
|
||||
"""Return a urllib opener that blocks redirects."""
|
||||
return urllib.request.build_opener(_NoRedirectHandler)
|
||||
|
||||
|
||||
def warn_if_insecure(hub: str) -> None:
|
||||
"""Warn when credentials would be sent over plain HTTP to a non-localhost host."""
|
||||
parsed = urllib.parse.urlparse(hub)
|
||||
if parsed.scheme == "http":
|
||||
host = parsed.hostname or ""
|
||||
if host not in ("localhost", "127.0.0.1", "::1"):
|
||||
click.echo(
|
||||
f"Warning: sending credentials over plain HTTP to {host}. " "Use HTTPS in production.",
|
||||
err=True,
|
||||
)
|
||||
|
||||
|
||||
def basic_auth(username: str, password: str) -> str:
|
||||
return "Basic " + base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
|
||||
|
||||
def get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]:
|
||||
"""Ask git credential fill for stored creds. Returns (username, password) or (None, None)."""
|
||||
parsed = urllib.parse.urlparse(hub_url)
|
||||
payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n"
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "credential", "fill"],
|
||||
input=payload,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
creds = {}
|
||||
for line in result.stdout.splitlines():
|
||||
# Partition on the FIRST '=' only so passwords containing '=' are preserved.
|
||||
k, sep, v = line.partition("=")
|
||||
if sep:
|
||||
creds[k.strip()] = v # keep value verbatim
|
||||
return creds.get("username"), creds.get("password")
|
||||
except Exception:
|
||||
return None, None
|
||||
|
||||
|
||||
def store_credentials(hub_url: str, username: str, password: str) -> None:
|
||||
"""Cache credentials via git credential approve (uses the system keychain/store)."""
|
||||
parsed = urllib.parse.urlparse(hub_url)
|
||||
payload = (
|
||||
f"protocol={parsed.scheme}\n" f"host={parsed.netloc}\n" f"username={username}\n" f"password={password}\n\n"
|
||||
)
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "credential", "approve"],
|
||||
input=payload,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
except Exception:
|
||||
pass # non-fatal — next push just prompts again
|
||||
|
||||
|
||||
def resolve_credentials(hub: str) -> tuple[str, str]:
|
||||
"""Return (username, password), prompting interactively if not cached."""
|
||||
username, password = get_stored_credentials(hub)
|
||||
if username and password:
|
||||
return username, password
|
||||
click.echo(f"No stored credentials found for {hub}.")
|
||||
username = click.prompt("RIA Hub username")
|
||||
password = click.prompt("Password / personal access token", hide_input=True)
|
||||
return username, password
|
||||
|
|
@ -86,9 +86,7 @@ def save_recording_auto(recording, output_path, input_path, quiet=False, overwri
|
|||
input_path = Path(input_path)
|
||||
fmt = detect_input_format(input_path)
|
||||
|
||||
output_path = determine_output_path(
|
||||
input_path=input_path, output_path=output_path, fmt=fmt, overwrite=overwrite
|
||||
)
|
||||
output_path = determine_output_path(input_path=input_path, output_path=output_path, fmt=fmt, overwrite=overwrite)
|
||||
|
||||
if not quiet:
|
||||
if fmt == "sigmf":
|
||||
|
|
@ -258,7 +256,11 @@ def list(input, verbose):
|
|||
user_comment = ann.comment or ""
|
||||
|
||||
# Basic info
|
||||
freq_range = f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
||||
freq_range = (
|
||||
f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
||||
if ann.freq_lower_edge is not None and ann.freq_upper_edge is not None
|
||||
else "N/A"
|
||||
)
|
||||
click.echo(
|
||||
f" [{i}] Samples {format_sample_count(ann.sample_start)}-"
|
||||
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {ann.label}"
|
||||
|
|
@ -502,8 +504,7 @@ def clear(input, output, overwrite, force, quiet):
|
|||
help="Annotation type",
|
||||
)
|
||||
@click.option(
|
||||
"--sample-rate", type=float, default=None,
|
||||
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||
)
|
||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||
|
|
@ -617,8 +618,7 @@ def energy(
|
|||
help="Annotation type",
|
||||
)
|
||||
@click.option(
|
||||
"--sample-rate", type=float, default=None,
|
||||
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||
)
|
||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||
|
|
@ -707,8 +707,7 @@ def cusum(input, label, min_duration, window_size, tolerance, annotation_type, s
|
|||
)
|
||||
@click.option("--channel", type=int, default=0, help="Channel index to annotate (default: 0)")
|
||||
@click.option(
|
||||
"--sample-rate", type=float, default=None,
|
||||
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||
)
|
||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||
|
|
@ -787,8 +786,7 @@ def threshold(input, threshold, label, window_size, annotation_type, channel, sa
|
|||
@click.option("--noise-threshold-db", type=float, help="Noise floor threshold in dB (auto-estimated if not specified)")
|
||||
@click.option("--min-component-bw", type=float, default=50e3, help="Min component bandwidth in Hz")
|
||||
@click.option(
|
||||
"--sample-rate", type=float, default=None,
|
||||
help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||
"--sample-rate", type=float, default=None, help="Sample rate in Hz (overrides metadata; required if not in file)"
|
||||
)
|
||||
@click.option("--output", "-o", type=click.Path(), help="Output file path")
|
||||
@click.option("--overwrite", is_flag=True, help="Overwrite input file (non-SigMF only)")
|
||||
|
|
@ -809,7 +807,8 @@ def _log_separate_start(quiet, recording, indices_list, nfft, noise_threshold_db
|
|||
|
||||
|
||||
def separate(
|
||||
input, indices, nfft, noise_threshold_db, min_component_bw, sample_rate, output, overwrite, quiet, verbose):
|
||||
input, indices, nfft, noise_threshold_db, min_component_bw, sample_rate, output, overwrite, quiet, verbose
|
||||
):
|
||||
"""
|
||||
Auto-detect parallel frequency-offset signals and split into sub-bands.
|
||||
|
||||
|
|
@ -883,7 +882,11 @@ def separate(
|
|||
click.echo("\n Details:")
|
||||
for i in range(initial_count, final_count):
|
||||
ann = recording.annotations[i]
|
||||
freq_range = f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
||||
freq_range = (
|
||||
f"{format_frequency(ann.freq_lower_edge)} - {format_frequency(ann.freq_upper_edge)}"
|
||||
if ann.freq_lower_edge is not None and ann.freq_upper_edge is not None
|
||||
else "N/A"
|
||||
)
|
||||
click.echo(
|
||||
f" [{i}] samples {format_sample_count(ann.sample_start)}-"
|
||||
f"{format_sample_count(ann.sample_start + ann.sample_count)}: {freq_range}"
|
||||
|
|
|
|||
|
|
@ -16,9 +16,11 @@ from .generate import generate
|
|||
# from .generate import generate
|
||||
from .init import init
|
||||
from .serve import serve
|
||||
from .setup_repo import setup_repo
|
||||
from .split import split
|
||||
from .transform import transform
|
||||
from .transmit import transmit
|
||||
from .upload import upload
|
||||
from .view import view
|
||||
|
||||
# Aliases
|
||||
|
|
|
|||
401
src/ria_toolkit_oss_cli/ria_toolkit_oss/setup_repo.py
Normal file
401
src/ria_toolkit_oss_cli/ria_toolkit_oss/setup_repo.py
Normal file
|
|
@ -0,0 +1,401 @@
|
|||
"""ria setup_repo — create and configure a RIA Hub Project repo."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
import click
|
||||
|
||||
from ._hub_auth import (
|
||||
DEFAULT_HUB,
|
||||
_NoRedirectHandler,
|
||||
basic_auth,
|
||||
resolve_credentials,
|
||||
store_credentials,
|
||||
warn_if_insecure,
|
||||
)
|
||||
|
||||
RIA_LFS_RULES = [
|
||||
("*.pt", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
("*.pth", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
("*.onnx", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
("*.sigmf", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
("*.sigmf-data", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
("*.sigmf-meta", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
("*.npy", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
("*.npz", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
("*.h5", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
("*.hdf5", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
("*.bin", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
("*.pkl", "filter=lfs diff=lfs merge=lfs -text"),
|
||||
]
|
||||
|
||||
# Repo names must be safe directory names and valid git remote path components.
|
||||
_SAFE_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,100}$")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _api_request(
|
||||
hub: str,
|
||||
path: str,
|
||||
method: str,
|
||||
username: str,
|
||||
password: str,
|
||||
body: dict | None = None,
|
||||
) -> tuple[dict, int]:
|
||||
"""
|
||||
Make an authenticated request to the RIA Hub API.
|
||||
Returns (parsed_response_body, http_status_code).
|
||||
Status 0 means a network/connection error.
|
||||
Credentials are sent as HTTP Basic auth — safe over HTTPS and localhost HTTP.
|
||||
Redirects are blocked to prevent credential exfiltration.
|
||||
"""
|
||||
url = f"{hub.rstrip('/')}/api/v1{path}"
|
||||
data = json.dumps(body).encode() if body is not None else None
|
||||
req = urllib.request.Request(url, data=data, method=method)
|
||||
req.add_header("Content-Type", "application/json")
|
||||
req.add_header("Authorization", basic_auth(username, password))
|
||||
|
||||
opener = urllib.request.build_opener(_NoRedirectHandler)
|
||||
try:
|
||||
with opener.open(req, timeout=15) as resp:
|
||||
return json.loads(resp.read() or b"{}"), resp.status
|
||||
except urllib.error.HTTPError as e:
|
||||
try:
|
||||
resp_body = json.loads(e.read() or b"{}")
|
||||
except Exception:
|
||||
resp_body = {}
|
||||
return resp_body, e.code
|
||||
except urllib.error.URLError as e:
|
||||
return {"message": str(e.reason)}, 0
|
||||
|
||||
|
||||
def _get_authenticated_username(hub: str, username: str, password: str) -> str | None:
|
||||
"""Return the login name of the authenticated user from GET /api/v1/user.
|
||||
|
||||
This is the canonical username for URL construction — it may differ from
|
||||
git config user.name which is a display name, not a login.
|
||||
"""
|
||||
body, status = _api_request(hub, "/user", "GET", username, password)
|
||||
if status == 200:
|
||||
return body.get("login")
|
||||
return None
|
||||
|
||||
|
||||
def _repo_exists(hub: str, owner: str, name: str, username: str, password: str) -> bool:
|
||||
body, status = _api_request(
|
||||
hub,
|
||||
f"/repos/{urllib.parse.quote(owner, safe='')}/{urllib.parse.quote(name, safe='')}",
|
||||
"GET",
|
||||
username,
|
||||
password,
|
||||
)
|
||||
return status == 200
|
||||
|
||||
|
||||
def _create_repo_on_hub(hub: str, name: str, username: str, password: str, private: bool) -> bool:
|
||||
"""Create an RIA Hub Project repo via API.
|
||||
|
||||
Returns True if the repo was freshly created (server seeded README.md and
|
||||
.gitattributes via auto_init + is_ria), False if the hub was unreachable
|
||||
(local fallback needed). Exits on fatal errors (auth, quota, name taken).
|
||||
"""
|
||||
body, status = _api_request(
|
||||
hub,
|
||||
"/user/repos",
|
||||
"POST",
|
||||
username,
|
||||
password,
|
||||
{
|
||||
"name": name,
|
||||
"auto_init": True,
|
||||
"is_ria": True,
|
||||
"private": private,
|
||||
"default_branch": "main",
|
||||
},
|
||||
)
|
||||
|
||||
if status == 201:
|
||||
click.echo(f"Repository '{name}' created on RIA Hub.")
|
||||
return True
|
||||
|
||||
if status == 0:
|
||||
click.echo(
|
||||
f"Warning: could not reach RIA Hub at {hub}: {body.get('message', 'connection failed')}",
|
||||
err=True,
|
||||
)
|
||||
click.echo("Continuing with local setup only — create the repo manually on RIA Hub.", err=True)
|
||||
return False
|
||||
|
||||
msg = body.get("message", "")
|
||||
|
||||
if status == 401:
|
||||
click.echo("Error: authentication failed — check your username/password.", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
if status in (403, 413) or "quota" in msg.lower() or "limit" in msg.lower():
|
||||
click.echo("Error: cannot create repository — storage quota or account limit reached.", err=True)
|
||||
if msg:
|
||||
click.echo(f" Server message: {msg}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
if status == 422 or "already exist" in msg.lower():
|
||||
click.echo(f"Repository '{name}' already exists on RIA Hub.")
|
||||
return False
|
||||
|
||||
click.echo(f"Error creating repository (HTTP {status}): {msg}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Local git helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _tracked_patterns(ga_path: str) -> set:
|
||||
if not os.path.exists(ga_path):
|
||||
return set()
|
||||
patterns = set()
|
||||
with open(ga_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
m = re.match(r"^(\S+)\s+", line)
|
||||
if m:
|
||||
patterns.add(m.group(1))
|
||||
return patterns
|
||||
|
||||
|
||||
def _write_local_ria_files(repo_path: str, repo_name: str) -> None:
|
||||
"""Seed README.md and .gitattributes locally (used when hub is unreachable or --no-remote)."""
|
||||
# README
|
||||
for candidate in ("README.md", "README.rst", "README.txt", "README"):
|
||||
if os.path.exists(os.path.join(repo_path, candidate)):
|
||||
click.echo(f"README: {candidate} already exists, skipping")
|
||||
break
|
||||
else:
|
||||
with open(os.path.join(repo_path, "README.md"), "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
f"# {repo_name}\n"
|
||||
"\n"
|
||||
"A RIA Hub project.\n"
|
||||
"\n"
|
||||
"## Description\n"
|
||||
"\n"
|
||||
"<!-- Add your project description here -->\n"
|
||||
"\n"
|
||||
"## Contents\n"
|
||||
"\n"
|
||||
"<!-- Describe the signals, models, or datasets in this repository -->\n"
|
||||
)
|
||||
click.echo("README.md: created")
|
||||
|
||||
# .gitattributes
|
||||
ga_path = os.path.join(repo_path, ".gitattributes")
|
||||
existing = _tracked_patterns(ga_path)
|
||||
new_rules = [(p, a) for p, a in RIA_LFS_RULES if p not in existing]
|
||||
|
||||
if new_rules:
|
||||
existing_content = ""
|
||||
if os.path.exists(ga_path):
|
||||
with open(ga_path, encoding="utf-8") as f:
|
||||
existing_content = f.read()
|
||||
|
||||
separator = "" if (not existing_content or existing_content.endswith("\n")) else "\n"
|
||||
addition = separator + "".join(f"{pattern} {attrs}\n" for pattern, attrs in new_rules)
|
||||
|
||||
with open(ga_path, "a", encoding="utf-8") as f:
|
||||
f.write(addition)
|
||||
click.echo(f".gitattributes: {len(new_rules)} rule(s) added")
|
||||
else:
|
||||
click.echo(".gitattributes: all RIA Hub rules are already present")
|
||||
|
||||
|
||||
def _git(repo_path: str, *args: str, check: bool = True) -> subprocess.CompletedProcess:
|
||||
return subprocess.run(
|
||||
["git", "-C", repo_path, *args],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=check,
|
||||
)
|
||||
|
||||
|
||||
def _resolve_path_and_name(name: str | None, local_path: str | None) -> tuple[str, str]:
|
||||
if local_path:
|
||||
repo_path = os.path.abspath(local_path)
|
||||
repo_name = name or os.path.basename(repo_path)
|
||||
elif name:
|
||||
repo_path = os.path.abspath(name)
|
||||
repo_name = name
|
||||
else:
|
||||
repo_path = os.path.abspath(".")
|
||||
repo_name = os.path.basename(repo_path)
|
||||
return repo_path, repo_name
|
||||
|
||||
|
||||
def _resolve_owner(hub: str, username: str | None, password: str | None, owner: str | None) -> str:
|
||||
if not owner and username and password:
|
||||
api_login = _get_authenticated_username(hub, username, password)
|
||||
owner = api_login or username
|
||||
return owner or "unknown"
|
||||
|
||||
|
||||
def _git_init(repo_path: str) -> None:
|
||||
if os.path.isdir(os.path.join(repo_path, ".git")):
|
||||
return
|
||||
result = _git(repo_path, "init", "-b", "main", check=False)
|
||||
if result.returncode != 0:
|
||||
# Older git (< 2.28) doesn't support -b; fall back and rename.
|
||||
_git(repo_path, "init")
|
||||
_git(repo_path, "symbolic-ref", "HEAD", "refs/heads/main")
|
||||
click.echo("git init: done (branch: main)")
|
||||
|
||||
|
||||
def _configure_remote(
|
||||
repo_path: str, hub: str, resolved_owner: str, repo_name: str, username: str | None, no_remote: bool
|
||||
) -> None:
|
||||
if no_remote or not username:
|
||||
click.echo(
|
||||
f"Skipped remote setup. Add it manually:\n"
|
||||
f" git -C {repo_path} remote add origin "
|
||||
f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git"
|
||||
)
|
||||
return
|
||||
remote_url = f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git"
|
||||
existing = _git(repo_path, "remote", "get-url", "origin", check=False)
|
||||
if existing.returncode == 0:
|
||||
existing_url = existing.stdout.strip()
|
||||
if existing_url == remote_url:
|
||||
click.echo(f"remote origin: {remote_url} (already set)")
|
||||
else:
|
||||
click.echo(
|
||||
f"remote 'origin' already points to {existing_url}.\n"
|
||||
f" To update: git remote set-url origin {remote_url}"
|
||||
)
|
||||
else:
|
||||
_git(repo_path, "remote", "add", "origin", remote_url)
|
||||
click.echo(f"remote origin: {remote_url}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@click.command("setup_repo")
|
||||
@click.argument("name", required=False)
|
||||
@click.option(
|
||||
"--path", "local_path", default=None, help="Local directory (default: current dir, or created from NAME)."
|
||||
)
|
||||
@click.option("--hub", default=DEFAULT_HUB, show_default=True, metavar="URL", help="RIA Hub base URL.")
|
||||
@click.option(
|
||||
"--owner",
|
||||
default=None,
|
||||
metavar="USER",
|
||||
help="RIA Hub login username (default: looked up from the API using your credentials).",
|
||||
)
|
||||
@click.option("--private", is_flag=True, default=False, help="Create the repository as private.")
|
||||
@click.option(
|
||||
"--no-remote", is_flag=True, default=False, help="Skip creating the repository on RIA Hub (local setup only)."
|
||||
)
|
||||
def setup_repo(
|
||||
name: str | None,
|
||||
local_path: str | None,
|
||||
hub: str,
|
||||
owner: str | None,
|
||||
private: bool,
|
||||
no_remote: bool,
|
||||
) -> None:
|
||||
"""Create and configure a RIA Hub Project repo.
|
||||
|
||||
NAME is the repository name. If the local directory does not exist or is
|
||||
not a git repo, it will be initialised automatically. Credentials are
|
||||
retrieved from git's credential store — no token setup required if you
|
||||
have used RIA Hub with git before.
|
||||
|
||||
\b
|
||||
Examples:
|
||||
ria setup_repo my-dataset
|
||||
ria setup_repo my-dataset --hub https://riahub.example.com
|
||||
ria setup_repo --path ./existing-dir
|
||||
ria setup_repo my-dataset --private
|
||||
"""
|
||||
repo_path, repo_name = _resolve_path_and_name(name, local_path)
|
||||
|
||||
if not _SAFE_NAME_RE.match(repo_name):
|
||||
click.echo(
|
||||
f"Error: '{repo_name}' is not a valid repository name.\n"
|
||||
"Use only letters, numbers, hyphens, underscores, and dots (max 100 chars).",
|
||||
err=True,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if not no_remote:
|
||||
warn_if_insecure(hub)
|
||||
|
||||
username, password = (None, None) if no_remote else resolve_credentials(hub)
|
||||
resolved_owner = _resolve_owner(hub, username, password, owner)
|
||||
|
||||
# newly_created=True means the server ran auto_init+is_ria and seeded
|
||||
# README.md + .gitattributes in the initial commit; local setup pulls
|
||||
# those files via fetch rather than writing them from scratch.
|
||||
newly_created = False
|
||||
if not no_remote and username and password:
|
||||
if _repo_exists(hub, resolved_owner, repo_name, username, password):
|
||||
click.echo(f"Repository '{resolved_owner}/{repo_name}' already exists on RIA Hub.")
|
||||
else:
|
||||
newly_created = _create_repo_on_hub(hub, repo_name, username, password, private)
|
||||
store_credentials(hub, username, password)
|
||||
|
||||
if not os.path.exists(repo_path):
|
||||
os.makedirs(repo_path)
|
||||
click.echo(f"Created directory: {repo_path}")
|
||||
|
||||
_git_init(repo_path)
|
||||
|
||||
if subprocess.run(["git", "lfs", "version"], capture_output=True).returncode != 0:
|
||||
click.echo(
|
||||
"Error: git-lfs is not installed.\n"
|
||||
" Linux: sudo apt-get install git-lfs\n"
|
||||
" macOS: brew install git-lfs\n"
|
||||
" Other platforms: https://git-lfs.com",
|
||||
err=True,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
_git(repo_path, "lfs", "install", "--local")
|
||||
click.echo("git lfs install --local: done")
|
||||
|
||||
_configure_remote(repo_path, hub, resolved_owner, repo_name, username, no_remote)
|
||||
|
||||
if newly_created:
|
||||
fetch = _git(repo_path, "fetch", "origin", check=False)
|
||||
if fetch.returncode == 0:
|
||||
_git(repo_path, "reset", "--hard", "origin/main")
|
||||
click.echo("Pulled initial commit from RIA Hub (README.md + .gitattributes)")
|
||||
else:
|
||||
click.echo("Warning: fetch failed — falling back to local file setup.", err=True)
|
||||
_write_local_ria_files(repo_path, repo_name)
|
||||
else:
|
||||
_write_local_ria_files(repo_path, repo_name)
|
||||
|
||||
if newly_created:
|
||||
click.echo(f"\nRepo is ready. Push your work:\n cd {repo_path}\n git push -u origin main")
|
||||
else:
|
||||
click.echo(
|
||||
f"\nRepo is ready. Commit and push:\n"
|
||||
f" cd {repo_path}\n"
|
||||
f" git add README.md .gitattributes\n"
|
||||
f" git commit -m 'chore: initialise RIA Hub project'\n"
|
||||
f" git push -u origin main"
|
||||
)
|
||||
392
src/ria_toolkit_oss_cli/ria_toolkit_oss/upload.py
Normal file
392
src/ria_toolkit_oss_cli/ria_toolkit_oss/upload.py
Normal file
|
|
@ -0,0 +1,392 @@
|
|||
"""ria upload — stream large files to a RIA Hub Project via the LFS API.
|
||||
|
||||
How it works
|
||||
------------
|
||||
1. The file is hashed locally (SHA-256 + size) — this is the LFS object ID.
|
||||
2. A single POST to the repo's LFS batch endpoint returns an upload URL
|
||||
(and headers) for any object the server does not already have.
|
||||
3. The file is streamed to that URL in fixed-size chunks — nothing is ever
|
||||
fully loaded into memory, so files of any size work.
|
||||
4. A commit is created via the Gitea contents API that records the LFS
|
||||
pointer (a small text file) so the file appears in the repo tree.
|
||||
|
||||
No server-side changes are required — this uses the same authenticated LFS
|
||||
protocol that `git lfs push` uses internally.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import http.client
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
import click
|
||||
|
||||
from ._hub_auth import (
|
||||
DEFAULT_HUB,
|
||||
basic_auth,
|
||||
hub_opener,
|
||||
resolve_credentials,
|
||||
warn_if_insecure,
|
||||
)
|
||||
|
||||
# Read buffer for hashing and streaming — 8 MB keeps memory use flat
|
||||
# for arbitrarily large files.
|
||||
_CHUNK = 8 * 1024 * 1024
|
||||
|
||||
LFS_MEDIA_TYPE = "application/vnd.git-lfs+json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _hash_file(path: str) -> tuple[str, int]:
|
||||
"""Return (sha256_hex, byte_size) by streaming the file."""
|
||||
h = hashlib.sha256()
|
||||
size = 0
|
||||
with open(path, "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(_CHUNK)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
size += len(chunk)
|
||||
return h.hexdigest(), size
|
||||
|
||||
|
||||
def _lfs_pointer_text(oid: str, size: int) -> str:
|
||||
return f"version https://git-lfs.github.com/spec/v1\noid sha256:{oid}\nsize {size}\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LFS batch API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _lfs_batch(
|
||||
hub: str,
|
||||
owner: str,
|
||||
repo: str,
|
||||
objects: list[dict],
|
||||
username: str,
|
||||
password: str,
|
||||
) -> dict:
|
||||
"""
|
||||
POST to /{owner}/{repo}.git/info/lfs/objects/batch.
|
||||
Returns the parsed JSON response.
|
||||
Raises on HTTP error or JSON decode failure.
|
||||
"""
|
||||
url = (
|
||||
f"{hub.rstrip('/')}"
|
||||
f"/{urllib.parse.quote(owner, safe='')}"
|
||||
f"/{urllib.parse.quote(repo, safe='')}"
|
||||
f".git/info/lfs/objects/batch"
|
||||
)
|
||||
body = json.dumps(
|
||||
{
|
||||
"operation": "upload",
|
||||
"transfers": ["basic"],
|
||||
"objects": objects,
|
||||
}
|
||||
).encode()
|
||||
|
||||
req = urllib.request.Request(url, data=body, method="POST")
|
||||
req.add_header("Content-Type", LFS_MEDIA_TYPE)
|
||||
req.add_header("Accept", LFS_MEDIA_TYPE)
|
||||
req.add_header("Authorization", basic_auth(username, password))
|
||||
|
||||
opener = hub_opener()
|
||||
try:
|
||||
with opener.open(req, timeout=30) as resp:
|
||||
return json.loads(resp.read())
|
||||
except urllib.error.HTTPError as e:
|
||||
body_text = e.read().decode(errors="replace")
|
||||
raise RuntimeError(f"LFS batch request failed (HTTP {e.code}): {body_text}") from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gitea contents API — create / update a file to record the LFS pointer
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_file_sha(
|
||||
hub: str,
|
||||
owner: str,
|
||||
repo: str,
|
||||
path: str,
|
||||
branch: str,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> str | None:
|
||||
"""Return the blob SHA of an existing file, or None if it doesn't exist."""
|
||||
url = (
|
||||
f"{hub.rstrip('/')}/api/v1"
|
||||
f"/repos/{urllib.parse.quote(owner, safe='')}/{urllib.parse.quote(repo, safe='')}"
|
||||
f"/contents/{urllib.parse.quote(path)}"
|
||||
f"?ref={urllib.parse.quote(branch)}"
|
||||
)
|
||||
req = urllib.request.Request(url)
|
||||
req.add_header("Authorization", basic_auth(username, password))
|
||||
try:
|
||||
with hub_opener().open(req, timeout=15) as resp:
|
||||
return json.loads(resp.read()).get("sha")
|
||||
except urllib.error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
return None
|
||||
raise
|
||||
|
||||
|
||||
def _commit_lfs_pointer(
|
||||
hub: str,
|
||||
owner: str,
|
||||
repo: str,
|
||||
remote_path: str,
|
||||
pointer_text: str,
|
||||
branch: str,
|
||||
message: str,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> None:
|
||||
"""Create or update a file in the repo containing the LFS pointer."""
|
||||
url = (
|
||||
f"{hub.rstrip('/')}/api/v1"
|
||||
f"/repos/{urllib.parse.quote(owner, safe='')}/{urllib.parse.quote(repo, safe='')}"
|
||||
f"/contents/{urllib.parse.quote(remote_path)}"
|
||||
)
|
||||
|
||||
existing_sha = _get_file_sha(hub, owner, repo, remote_path, branch, username, password)
|
||||
|
||||
body: dict = {
|
||||
"message": message,
|
||||
"content": base64.b64encode(pointer_text.encode()).decode(),
|
||||
"branch": branch,
|
||||
}
|
||||
if existing_sha:
|
||||
body["sha"] = existing_sha
|
||||
|
||||
method = "PUT" if existing_sha else "POST"
|
||||
req = urllib.request.Request(url, data=json.dumps(body).encode(), method=method)
|
||||
req.add_header("Content-Type", "application/json")
|
||||
req.add_header("Authorization", basic_auth(username, password))
|
||||
|
||||
try:
|
||||
with hub_opener().open(req, timeout=30) as resp:
|
||||
resp.read()
|
||||
except urllib.error.HTTPError as e:
|
||||
body_text = e.read().decode(errors="replace")
|
||||
raise RuntimeError(f"Failed to commit LFS pointer for '{remote_path}' (HTTP {e.code}): {body_text}") from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-file upload logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _upload_single_file(
|
||||
hub: str,
|
||||
owner: str,
|
||||
repo_name: str,
|
||||
username: str,
|
||||
password: str,
|
||||
file_path: str,
|
||||
remote_dir: str,
|
||||
message: str | None,
|
||||
branch: str,
|
||||
) -> None:
|
||||
"""Hash, upload (if needed), and commit the LFS pointer for one file."""
|
||||
filename = os.path.basename(file_path)
|
||||
file_size = os.path.getsize(file_path)
|
||||
size_mb = file_size / (1024 * 1024)
|
||||
|
||||
click.echo(f"\n {filename} ({size_mb:.1f} MB)")
|
||||
|
||||
click.echo(" Hashing...", nl=False)
|
||||
oid, size = _hash_file(file_path)
|
||||
click.echo(f" sha256:{oid[:12]}...")
|
||||
|
||||
try:
|
||||
batch = _lfs_batch(hub, owner, repo_name, [{"oid": oid, "size": size}], username, password)
|
||||
except RuntimeError as e:
|
||||
click.echo(f"\n Error: {e}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
objects = batch.get("objects", [])
|
||||
if not objects:
|
||||
click.echo(" Already in LFS — skipping upload.")
|
||||
else:
|
||||
obj = objects[0]
|
||||
if "error" in obj:
|
||||
err_msg = obj["error"].get("message", "unknown error")
|
||||
err_code = obj["error"].get("code", 0)
|
||||
if err_code == 413 or "quota" in err_msg.lower() or "limit" in err_msg.lower():
|
||||
click.echo(
|
||||
f"\n Error: storage quota exceeded for this repo.\n Server: {err_msg}",
|
||||
err=True,
|
||||
)
|
||||
else:
|
||||
click.echo(f"\n Error from server: {err_msg}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
upload_action = obj.get("actions", {}).get("upload")
|
||||
if not upload_action:
|
||||
click.echo(" Already in LFS — skipping upload.")
|
||||
else:
|
||||
href = upload_action["href"]
|
||||
up_headers = upload_action.get("header", {})
|
||||
chunks = math.ceil(size / _CHUNK)
|
||||
click.echo(f" Uploading ({size_mb:.1f} MB, {chunks} chunk{'s' if chunks != 1 else ''})...")
|
||||
try:
|
||||
_stream_upload_progress(href, up_headers, file_path, size)
|
||||
except RuntimeError as e:
|
||||
click.echo(f"\n Upload failed: {e}", err=True)
|
||||
sys.exit(1)
|
||||
click.echo(" Upload complete.")
|
||||
|
||||
verify_action = obj.get("actions", {}).get("verify")
|
||||
if verify_action:
|
||||
try:
|
||||
vreq = urllib.request.Request(
|
||||
verify_action["href"],
|
||||
data=json.dumps({"oid": oid, "size": size}).encode(),
|
||||
method="POST",
|
||||
)
|
||||
vreq.add_header("Content-Type", LFS_MEDIA_TYPE)
|
||||
vreq.add_header("Accept", LFS_MEDIA_TYPE)
|
||||
for k, v in verify_action.get("header", {}).items():
|
||||
vreq.add_header(k, v)
|
||||
with urllib.request.urlopen(vreq, timeout=15):
|
||||
pass
|
||||
except Exception:
|
||||
pass # verify is optional; non-fatal on failure
|
||||
|
||||
pointer = _lfs_pointer_text(oid, size)
|
||||
remote_path = (f"{remote_dir.rstrip('/')}/{filename}").lstrip("/") if remote_dir else filename
|
||||
commit_msg = message or f"chore: upload {filename} via ria"
|
||||
|
||||
click.echo(f" Committing pointer → {remote_path}...", nl=False)
|
||||
try:
|
||||
_commit_lfs_pointer(hub, owner, repo_name, remote_path, pointer, branch, commit_msg, username, password)
|
||||
click.echo(" done.")
|
||||
except RuntimeError as e:
|
||||
click.echo(f"\n Error: {e}", err=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _stream_upload_progress(href: str, headers: dict, file_path: str, size: int) -> None:
|
||||
"""Stream file_path to href with a click progress bar."""
|
||||
parsed = urllib.parse.urlparse(href)
|
||||
host = parsed.netloc
|
||||
path_q = parsed.path + (f"?{parsed.query}" if parsed.query else "")
|
||||
|
||||
if parsed.scheme == "https":
|
||||
conn = http.client.HTTPSConnection(host, timeout=300)
|
||||
else:
|
||||
conn = http.client.HTTPConnection(host, timeout=300)
|
||||
|
||||
all_headers = dict(headers)
|
||||
all_headers.setdefault("Content-Type", "application/octet-stream")
|
||||
all_headers["Content-Length"] = str(size)
|
||||
|
||||
try:
|
||||
conn.connect()
|
||||
conn.putrequest("PUT", path_q)
|
||||
for k, v in all_headers.items():
|
||||
conn.putheader(k, v)
|
||||
conn.endheaders()
|
||||
|
||||
with click.progressbar(
|
||||
length=size,
|
||||
label=" ",
|
||||
width=40,
|
||||
show_eta=True,
|
||||
show_percent=True,
|
||||
fill_char="█",
|
||||
empty_char="░",
|
||||
) as bar:
|
||||
with open(file_path, "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(_CHUNK)
|
||||
if not chunk:
|
||||
break
|
||||
conn.send(chunk)
|
||||
bar.update(len(chunk))
|
||||
|
||||
resp = conn.getresponse()
|
||||
resp.read()
|
||||
if resp.status not in (200, 201):
|
||||
raise RuntimeError(f"HTTP {resp.status}")
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@click.command("upload")
|
||||
@click.argument("files", nargs=-1, required=True)
|
||||
@click.option(
|
||||
"--repo", required=True, metavar="OWNER/NAME", help="Target repository on RIA Hub (e.g. benchinnery/my-dataset)."
|
||||
)
|
||||
@click.option("--hub", default=DEFAULT_HUB, show_default=True, metavar="URL", help="RIA Hub base URL.")
|
||||
@click.option("--branch", default="main", show_default=True, help="Branch to commit the files to.")
|
||||
@click.option(
|
||||
"--path",
|
||||
"remote_dir",
|
||||
default="",
|
||||
metavar="DIR",
|
||||
help="Remote directory path inside the repo (default: repo root).",
|
||||
)
|
||||
@click.option("--message", "-m", default=None, help="Commit message (default: 'chore: upload <filename> via ria').")
|
||||
def upload(
|
||||
files: tuple[str],
|
||||
repo: str,
|
||||
hub: str,
|
||||
branch: str,
|
||||
remote_dir: str,
|
||||
message: str | None,
|
||||
) -> None:
|
||||
"""Upload large files to a RIA Hub Project via Git LFS.
|
||||
|
||||
Files are streamed directly to the repo's LFS object store — nothing is
|
||||
buffered into memory, so files of any size work. Each file creates one
|
||||
commit recording the LFS pointer.
|
||||
|
||||
\b
|
||||
Examples:
|
||||
ria upload recording.sigmf-data --repo benchinnery/my-recordings
|
||||
ria upload *.npy --repo benchinnery/my-recordings --branch main
|
||||
ria upload big.pt --repo benchinnery/models --path weights/
|
||||
"""
|
||||
# Validate repo argument
|
||||
if "/" not in repo:
|
||||
click.echo("Error: --repo must be in the form OWNER/NAME.", err=True)
|
||||
sys.exit(1)
|
||||
owner, repo_name = repo.split("/", 1)
|
||||
|
||||
# Expand and validate files
|
||||
resolved = []
|
||||
for pattern in files:
|
||||
if not os.path.isfile(pattern):
|
||||
click.echo(f"Error: '{pattern}' is not a file or does not exist.", err=True)
|
||||
sys.exit(1)
|
||||
resolved.append(os.path.abspath(pattern))
|
||||
|
||||
hub = hub.rstrip("/")
|
||||
warn_if_insecure(hub)
|
||||
username, password = resolve_credentials(hub)
|
||||
|
||||
click.echo(f"Uploading {len(resolved)} file(s) to {owner}/{repo_name} on {hub}...")
|
||||
|
||||
for file_path in resolved:
|
||||
_upload_single_file(hub, owner, repo_name, username, password, file_path, remote_dir, message, branch)
|
||||
|
||||
click.echo(f"\nAll done. {len(resolved)} file(s) uploaded to {owner}/{repo_name}.")
|
||||
|
|
@ -79,9 +79,9 @@ def test_user_agent_is_set_and_not_python_default():
|
|||
"""
|
||||
ua = agent_cli._user_agent()
|
||||
assert ua, "User-Agent must not be empty"
|
||||
assert not ua.lower().startswith("python-urllib"), (
|
||||
f"User-Agent must not be Python's default (got {ua!r}) — Cloudflare blocks it"
|
||||
)
|
||||
assert not ua.lower().startswith(
|
||||
"python-urllib"
|
||||
), f"User-Agent must not be Python's default (got {ua!r}) — Cloudflare blocks it"
|
||||
assert ua.startswith("ria-agent/")
|
||||
|
||||
|
||||
|
|
@ -96,7 +96,10 @@ def test_register_request_carries_explicit_user_agent(tmp_path):
|
|||
captured["api_key"] = req.get_header("X-api-key")
|
||||
captured["timeout"] = kwargs.get("timeout")
|
||||
raise urllib.error.HTTPError(
|
||||
url=req.full_url, code=403, msg="", hdrs=None, # type: ignore[arg-type]
|
||||
url=req.full_url,
|
||||
code=403,
|
||||
msg="",
|
||||
hdrs=None, # type: ignore[arg-type]
|
||||
fp=BytesIO(_structured("invalid_key")),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -199,3 +199,44 @@ def test_annotation_to_sigmf_format_values():
|
|||
values = list(result.values())
|
||||
assert 50 in values or ann.sample_start in values
|
||||
assert 100 in values or ann.sample_count in values
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# None freq-edge regression tests (SigMF optional fields)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_annotation_no_freq_edges():
|
||||
ann = Annotation(sample_start=0, sample_count=10, label="burst")
|
||||
assert ann.freq_lower_edge is None
|
||||
assert ann.freq_upper_edge is None
|
||||
|
||||
|
||||
def test_annotation_is_valid_no_freq_edges():
|
||||
ann = Annotation(sample_start=0, sample_count=10, label="burst")
|
||||
assert ann.is_valid() is True
|
||||
|
||||
ann_zero = Annotation(sample_start=0, sample_count=0, label="burst")
|
||||
assert ann_zero.is_valid() is False
|
||||
|
||||
|
||||
def test_annotation_overlap_none_edges_returns_zero():
|
||||
ann1 = Annotation(sample_start=0, sample_count=10)
|
||||
ann2 = Annotation(sample_start=0, sample_count=10, freq_lower_edge=0, freq_upper_edge=100)
|
||||
assert ann1.overlap(ann2) == 0
|
||||
assert ann2.overlap(ann1) == 0
|
||||
|
||||
|
||||
def test_annotation_area_none_edges_returns_zero():
|
||||
ann = Annotation(sample_start=0, sample_count=10, label="burst")
|
||||
assert ann.area() == 0
|
||||
|
||||
|
||||
def test_annotation_to_sigmf_omits_freq_keys_when_none():
|
||||
from sigmf import SigMFFile
|
||||
|
||||
ann = Annotation(sample_start=0, sample_count=10, label="burst")
|
||||
result = ann.to_sigmf_format()
|
||||
metadata = result["metadata"]
|
||||
assert SigMFFile.FLO_KEY not in metadata
|
||||
assert SigMFFile.FHI_KEY not in metadata
|
||||
|
|
|
|||
|
|
@ -189,3 +189,21 @@ def test_sigmf_3(tmp_path):
|
|||
to_sigmf(recording=recording1, path=tmp_path, filename=filename.name)
|
||||
except IOError as e:
|
||||
assert str(e) == "File already exists"
|
||||
|
||||
|
||||
def test_sigmf_annotation_without_freq_edges(tmp_path):
|
||||
# Regression: annotations that omit the optional SigMF freq edge fields must
|
||||
# load without error; edges should be None and the annotation still valid.
|
||||
ann = Annotation(sample_start=0, sample_count=5, label="burst")
|
||||
recording1 = Recording(data=complex_data_1, metadata=sample_metadata, annotations=[ann])
|
||||
|
||||
filename = tmp_path / "test"
|
||||
to_sigmf(recording=recording1, path=tmp_path, filename=filename.name, overwrite=True)
|
||||
recording2 = from_sigmf(filename)
|
||||
|
||||
assert len(recording2.annotations) == 1
|
||||
loaded = recording2.annotations[0]
|
||||
assert loaded.freq_lower_edge is None
|
||||
assert loaded.freq_upper_edge is None
|
||||
assert loaded.is_valid() is True
|
||||
assert loaded.label == "burst"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user