Compare commits
13 Commits
0b736612ec
...
0b75013653
| Author | SHA1 | Date | |
|---|---|---|---|
| 0b75013653 | |||
| 5a67b39d22 | |||
| dd305aabeb | |||
| 816bc84f9a | |||
| b27b04dbc0 | |||
| 53f912f21a | |||
| 543517f0ca | |||
| ba1804a5f9 | |||
| febb1bd6cf | |||
| 5f68fd936d | |||
| 99447a581a | |||
| 2f6b5ced18 | |||
| eb5b4ce839 |
19
CHANGELOG.md
19
CHANGELOG.md
|
|
@ -1,5 +1,24 @@
|
||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## [0.1.7] - 2026-05-26
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
- **Human-readable agent names** — `ria-agent register` now generates a default `adjective-colour-animal` name (e.g. `swift-teal-falcon`) via the new `namegen` module when `--name` is omitted, instead of registering with an empty string.
|
||||||
|
- **Structured registration error messages** — `ria-agent register` translates hub responses into actionable English for the known failure reasons (`invalid_key`, `expired`, `revoked`, `already_consumed`) and rate-limit (`HTTP 429`) responses, instead of surfacing raw `HTTP 4xx` text.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- **`ria-agent register` `--api-key` help** — now describes the personal `ria_reg_*` registration key flow (minted from **Settings → RIA Agents** on the hub, shown once at mint time). The legacy shared `[wac] API_KEY` is still accepted by the hub for back-compat, but the CLI documents the per-user flow as preferred.
|
||||||
|
- **`ria-agent register` success output** — now prints both the hub-assigned agent ID and the chosen name: `Registered agent: <id> (<name>)`.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- **`ria-agent register` blocked by Cloudflare on hubs behind it** — set an explicit `User-Agent` (`ria-agent/<package-version> (+https://riahub.ai/qoherent/ria-toolkit-oss)`) so the request isn't rejected as `Python-urllib/<ver>` (Cloudflare Browser Integrity Check returns HTTP 403, edge error code 1010). Version is read from package metadata so it tracks releases automatically.
|
||||||
|
- **`ria-agent register` could hang indefinitely** — added a 15-second timeout to the hub request; previously `urllib`'s default of no timeout meant a stuck hub would block the CLI forever.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## [0.1.0] - 2026-02-20
|
## [0.1.0] - 2026-02-20
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
|
||||||
project = 'ria-toolkit-oss'
|
project = 'ria-toolkit-oss'
|
||||||
copyright = '2026, Qoherent Inc'
|
copyright = '2026, Qoherent Inc'
|
||||||
author = 'Qoherent Inc.'
|
author = 'Qoherent Inc.'
|
||||||
release = '0.1.6'
|
release = '0.1.7'
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||||
|
|
|
||||||
2
poetry.lock
generated
2
poetry.lock
generated
|
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "alabaster"
|
name = "alabaster"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "ria-toolkit-oss"
|
name = "ria-toolkit-oss"
|
||||||
version = "0.1.6"
|
version = "0.1.7"
|
||||||
description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications"
|
description = "An open-source version of the RIA Toolkit, including the fundamental tools to get started developing, testing, and deploying radio intelligence applications"
|
||||||
license = { text = "AGPL-3.0-only" }
|
license = { text = "AGPL-3.0-only" }
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,11 @@ Subcommands:
|
||||||
- ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged).
|
- ``ria-agent run [legacy args]`` — legacy long-poll NodeAgent (unchanged).
|
||||||
- ``ria-agent stream`` — new WebSocket-based IQ streamer.
|
- ``ria-agent stream`` — new WebSocket-based IQ streamer.
|
||||||
- ``ria-agent detect`` — print SDR drivers whose modules import cleanly.
|
- ``ria-agent detect`` — print SDR drivers whose modules import cleanly.
|
||||||
- ``ria-agent register --hub URL --api-key KEY`` — register with the hub and
|
- ``ria-agent register --hub URL --api-key KEY`` — register with the hub
|
||||||
save credentials (and optional TX interlocks) to ``~/.ria/agent.json``.
|
using a personal registration key (minted from **Settings → RIA Agents**
|
||||||
|
on the hub, shown once at mint time) and save credentials (and optional
|
||||||
|
TX interlocks) to ``~/.ria/agent.json``. The hub also accepts the legacy
|
||||||
|
shared ``[wac] API_KEY`` for back-compat, but that path is deprecated.
|
||||||
|
|
||||||
Invoking ``ria-agent`` with no subcommand falls through to the legacy
|
Invoking ``ria-agent`` with no subcommand falls through to the legacy
|
||||||
long-poll behavior for back-compatibility with existing deployments.
|
long-poll behavior for back-compatibility with existing deployments.
|
||||||
|
|
@ -30,6 +33,70 @@ DEFAULT_HUB_URL = "https://riahub.ai"
|
||||||
_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"}
|
_LEGACY_ALIASES = {"--hub", "--key", "--name", "--device", "--insecure", "--log-level", "--config"}
|
||||||
|
|
||||||
|
|
||||||
|
def _user_agent() -> str:
|
||||||
|
"""Build the User-Agent header for hub requests.
|
||||||
|
|
||||||
|
Set explicitly so we don't fall back to Python's default `Python-urllib/<ver>`,
|
||||||
|
which is blocked by Cloudflare's Browser Integrity Check on `riahub.ai`
|
||||||
|
(HTTP 403 edge code 1010). Version is read from package metadata so it
|
||||||
|
tracks releases instead of going stale.
|
||||||
|
"""
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
|
try:
|
||||||
|
pkg_version = version("ria-toolkit-oss")
|
||||||
|
except PackageNotFoundError:
|
||||||
|
pkg_version = "unknown"
|
||||||
|
return f"ria-agent/{pkg_version} (+https://riahub.ai/qoherent/ria-toolkit-oss)"
|
||||||
|
|
||||||
|
|
||||||
|
# How long to wait on the hub before giving up. The register endpoint is a
|
||||||
|
# small DB lookup + insert; anything past this is a stuck hub, not a slow one.
|
||||||
|
_REGISTER_TIMEOUT_S = 15
|
||||||
|
|
||||||
|
|
||||||
|
REGISTRATION_REASON_MESSAGES = {
|
||||||
|
"invalid_key": ("Registration key not recognized. Generate a fresh key from " "Settings → RIA Agents on the hub."),
|
||||||
|
"expired": ("This registration key has expired. Generate a new one from " "Settings → RIA Agents on the hub."),
|
||||||
|
"revoked": ("This registration key was revoked. Generate a new one from " "Settings → RIA Agents on the hub."),
|
||||||
|
"already_consumed": (
|
||||||
|
"This single-use registration key has already been used. "
|
||||||
|
"Generate a new one, or mint a reusable key instead."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _explain_registration_failure(status: int, body: bytes) -> str:
|
||||||
|
"""Return a human-readable explanation for a failed register call."""
|
||||||
|
try:
|
||||||
|
parsed = json.loads(body) if body else None
|
||||||
|
except ValueError:
|
||||||
|
parsed = None
|
||||||
|
|
||||||
|
if status == 429:
|
||||||
|
# 429 carries a plain string detail, never a reason code.
|
||||||
|
if isinstance(parsed, dict) and parsed.get("detail"):
|
||||||
|
detail = parsed["detail"]
|
||||||
|
else:
|
||||||
|
detail = body.decode("utf-8", "replace") or "rate limited"
|
||||||
|
return f"Registration rate-limited by the hub: {detail}"
|
||||||
|
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
text = body.decode("utf-8", "replace")
|
||||||
|
return f"HTTP {status}: {text or 'no body'}"
|
||||||
|
|
||||||
|
detail = parsed.get("detail")
|
||||||
|
if isinstance(detail, dict):
|
||||||
|
reason = detail.get("reason")
|
||||||
|
if reason in REGISTRATION_REASON_MESSAGES:
|
||||||
|
return REGISTRATION_REASON_MESSAGES[reason]
|
||||||
|
if reason:
|
||||||
|
return f"Registration rejected ({reason})"
|
||||||
|
if isinstance(detail, str) and detail:
|
||||||
|
return f"Registration rejected: {detail}"
|
||||||
|
return f"HTTP {status}: {parsed}"
|
||||||
|
|
||||||
|
|
||||||
def _cmd_detect(_args: argparse.Namespace) -> int:
|
def _cmd_detect(_args: argparse.Namespace) -> int:
|
||||||
devices = available_devices()
|
devices = available_devices()
|
||||||
if not devices:
|
if not devices:
|
||||||
|
|
@ -41,6 +108,7 @@ def _cmd_detect(_args: argparse.Namespace) -> int:
|
||||||
|
|
||||||
|
|
||||||
def _cmd_register(args: argparse.Namespace) -> int:
|
def _cmd_register(args: argparse.Namespace) -> int:
|
||||||
|
import urllib.error
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
|
||||||
hub_url = args.hub.rstrip("/")
|
hub_url = args.hub.rstrip("/")
|
||||||
|
|
@ -53,11 +121,20 @@ def _cmd_register(args: argparse.Namespace) -> int:
|
||||||
headers={
|
headers={
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"X-API-Key": args.api_key,
|
"X-API-Key": args.api_key,
|
||||||
|
"User-Agent": _user_agent(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(req) as resp:
|
with urllib.request.urlopen(req, timeout=_REGISTER_TIMEOUT_S) as resp:
|
||||||
data = json.loads(resp.read())
|
data = json.loads(resp.read())
|
||||||
|
except urllib.error.HTTPError as e:
|
||||||
|
try:
|
||||||
|
err_body = e.read()
|
||||||
|
except Exception:
|
||||||
|
err_body = b""
|
||||||
|
msg = _explain_registration_failure(e.code, err_body)
|
||||||
|
print(f"error: registration failed: {msg}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"error: registration failed: {e}", file=sys.stderr)
|
print(f"error: registration failed: {e}", file=sys.stderr)
|
||||||
return 1
|
return 1
|
||||||
|
|
@ -82,7 +159,7 @@ def _cmd_register(args: argparse.Namespace) -> int:
|
||||||
cfg.tx_allowed_freq_ranges = [[float(lo), float(hi)] for lo, hi in freq_ranges]
|
cfg.tx_allowed_freq_ranges = [[float(lo), float(hi)] for lo, hi in freq_ranges]
|
||||||
path = _config.save(cfg)
|
path = _config.save(cfg)
|
||||||
|
|
||||||
print(f"Registered agent: {agent_id}")
|
print(f"Registered agent: {agent_id} ({name})")
|
||||||
if cfg.tx_enabled:
|
if cfg.tx_enabled:
|
||||||
caps: list[str] = []
|
caps: list[str] = []
|
||||||
if cfg.tx_max_gain_db is not None:
|
if cfg.tx_max_gain_db is not None:
|
||||||
|
|
@ -143,7 +220,16 @@ def main() -> None:
|
||||||
|
|
||||||
p_reg = sub.add_parser("register", help="Register agent with RIA Hub and save credentials")
|
p_reg = sub.add_parser("register", help="Register agent with RIA Hub and save credentials")
|
||||||
p_reg.add_argument("--hub", default=DEFAULT_HUB_URL, help=f"RIA Hub URL (default: {DEFAULT_HUB_URL})")
|
p_reg.add_argument("--hub", default=DEFAULT_HUB_URL, help=f"RIA Hub URL (default: {DEFAULT_HUB_URL})")
|
||||||
p_reg.add_argument("--api-key", dest="api_key", required=True, help="Hub API key")
|
p_reg.add_argument(
|
||||||
|
"--api-key",
|
||||||
|
dest="api_key",
|
||||||
|
required=True,
|
||||||
|
help=(
|
||||||
|
"Personal registration key from the RIA Agents page on the hub "
|
||||||
|
"(format: ria_reg_...). Shown once when generated; save it then. "
|
||||||
|
"The legacy shared API key is also accepted but deprecated."
|
||||||
|
),
|
||||||
|
)
|
||||||
p_reg.add_argument("--name", default=None, help="Human-friendly agent name")
|
p_reg.add_argument("--name", default=None, help="Human-friendly agent name")
|
||||||
p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification")
|
p_reg.add_argument("--insecure", action="store_true", help="Skip TLS verification")
|
||||||
p_reg.add_argument(
|
p_reg.add_argument(
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ This module contains the main group for the ria toolkit oss CLI.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
|
|
@ -49,7 +50,7 @@ def cli(ctx, verbose):
|
||||||
err=True,
|
err=True,
|
||||||
)
|
)
|
||||||
if ctx.invoked_subcommand is None:
|
if ctx.invoked_subcommand is None:
|
||||||
if lfs_missing:
|
if lfs_missing and sys.stdin.isatty():
|
||||||
click.pause(info="\nPress Enter to continue...", err=True)
|
click.pause(info="\nPress Enter to continue...", err=True)
|
||||||
click.echo(ctx.get_help())
|
click.echo(ctx.get_help())
|
||||||
|
|
||||||
|
|
|
||||||
97
src/ria_toolkit_oss_cli/ria_toolkit_oss/_hub_auth.py
Normal file
97
src/ria_toolkit_oss_cli/ria_toolkit_oss/_hub_auth.py
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
"""Shared authentication and security helpers for RIA Hub API calls."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import subprocess
|
||||||
|
import urllib.error
|
||||||
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
DEFAULT_HUB = "https://riahub.ai"
|
||||||
|
|
||||||
|
|
||||||
|
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
|
||||||
|
"""Block redirects on authenticated requests to prevent credential exfiltration.
|
||||||
|
|
||||||
|
urllib re-sends the Authorization header on same-host redirects by default.
|
||||||
|
A malicious server could redirect a POST to a different host to harvest
|
||||||
|
credentials. We refuse all redirects — API clients should not encounter them
|
||||||
|
in normal operation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def redirect_request(self, req, fp, code, msg, headers, newurl):
|
||||||
|
raise urllib.error.URLError(f"Unexpected redirect ({code}) to {newurl} — aborting to protect credentials")
|
||||||
|
|
||||||
|
|
||||||
|
def hub_opener() -> urllib.request.OpenerDirector:
|
||||||
|
"""Return a urllib opener that blocks redirects."""
|
||||||
|
return urllib.request.build_opener(_NoRedirectHandler)
|
||||||
|
|
||||||
|
|
||||||
|
def warn_if_insecure(hub: str) -> None:
|
||||||
|
"""Warn when credentials would be sent over plain HTTP to a non-localhost host."""
|
||||||
|
parsed = urllib.parse.urlparse(hub)
|
||||||
|
if parsed.scheme == "http":
|
||||||
|
host = parsed.hostname or ""
|
||||||
|
if host not in ("localhost", "127.0.0.1", "::1"):
|
||||||
|
click.echo(
|
||||||
|
f"Warning: sending credentials over plain HTTP to {host}. " "Use HTTPS in production.",
|
||||||
|
err=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def basic_auth(username: str, password: str) -> str:
|
||||||
|
return "Basic " + base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]:
|
||||||
|
"""Ask git credential fill for stored creds. Returns (username, password) or (None, None)."""
|
||||||
|
parsed = urllib.parse.urlparse(hub_url)
|
||||||
|
payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n"
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["git", "credential", "fill"],
|
||||||
|
input=payload,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
creds = {}
|
||||||
|
for line in result.stdout.splitlines():
|
||||||
|
# Partition on the FIRST '=' only so passwords containing '=' are preserved.
|
||||||
|
k, sep, v = line.partition("=")
|
||||||
|
if sep:
|
||||||
|
creds[k.strip()] = v # keep value verbatim
|
||||||
|
return creds.get("username"), creds.get("password")
|
||||||
|
except Exception:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def store_credentials(hub_url: str, username: str, password: str) -> None:
|
||||||
|
"""Cache credentials via git credential approve (uses the system keychain/store)."""
|
||||||
|
parsed = urllib.parse.urlparse(hub_url)
|
||||||
|
payload = (
|
||||||
|
f"protocol={parsed.scheme}\n" f"host={parsed.netloc}\n" f"username={username}\n" f"password={password}\n\n"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
["git", "credential", "approve"],
|
||||||
|
input=payload,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # non-fatal — next push just prompts again
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_credentials(hub: str) -> tuple[str, str]:
|
||||||
|
"""Return (username, password), prompting interactively if not cached."""
|
||||||
|
username, password = get_stored_credentials(hub)
|
||||||
|
if username and password:
|
||||||
|
return username, password
|
||||||
|
click.echo(f"No stored credentials found for {hub}.")
|
||||||
|
username = click.prompt("RIA Hub username")
|
||||||
|
password = click.prompt("Password / personal access token", hide_input=True)
|
||||||
|
return username, password
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
"""ria setup_repo — create and configure a RIA Hub Project repo."""
|
"""ria setup_repo — create and configure a RIA Hub Project repo."""
|
||||||
|
|
||||||
import base64
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
@ -12,6 +11,15 @@ import urllib.request
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
from ._hub_auth import (
|
||||||
|
DEFAULT_HUB,
|
||||||
|
_NoRedirectHandler,
|
||||||
|
basic_auth,
|
||||||
|
resolve_credentials,
|
||||||
|
store_credentials,
|
||||||
|
warn_if_insecure,
|
||||||
|
)
|
||||||
|
|
||||||
RIA_LFS_RULES = [
|
RIA_LFS_RULES = [
|
||||||
("*.pt", "filter=lfs diff=lfs merge=lfs -text"),
|
("*.pt", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
("*.pth", "filter=lfs diff=lfs merge=lfs -text"),
|
("*.pth", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
|
|
@ -27,89 +35,15 @@ RIA_LFS_RULES = [
|
||||||
("*.pkl", "filter=lfs diff=lfs merge=lfs -text"),
|
("*.pkl", "filter=lfs diff=lfs merge=lfs -text"),
|
||||||
]
|
]
|
||||||
|
|
||||||
DEFAULT_HUB = "https://riahub.ai"
|
|
||||||
|
|
||||||
# Repo names must be safe directory names and valid git remote path components.
|
# Repo names must be safe directory names and valid git remote path components.
|
||||||
_SAFE_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,100}$")
|
_SAFE_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,100}$")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Credential helpers — all credential I/O goes through git's own store
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]:
|
|
||||||
"""Ask git credential fill for stored creds. Returns (username, password) or (None, None)."""
|
|
||||||
parsed = urllib.parse.urlparse(hub_url)
|
|
||||||
payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n"
|
|
||||||
try:
|
|
||||||
result = subprocess.run(
|
|
||||||
["git", "credential", "fill"],
|
|
||||||
input=payload,
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=5,
|
|
||||||
)
|
|
||||||
creds = {}
|
|
||||||
for line in result.stdout.splitlines():
|
|
||||||
# partition on the FIRST '=' only so passwords containing '=' are preserved.
|
|
||||||
# Only strip whitespace from the key, not the value.
|
|
||||||
k, sep, v = line.partition("=")
|
|
||||||
if sep:
|
|
||||||
creds[k.strip()] = v # keep password value verbatim
|
|
||||||
return creds.get("username"), creds.get("password")
|
|
||||||
except Exception:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
|
|
||||||
def _store_credentials(hub_url: str, username: str, password: str) -> None:
|
|
||||||
"""Cache credentials via git credential approve (uses the system keychain/store)."""
|
|
||||||
parsed = urllib.parse.urlparse(hub_url)
|
|
||||||
payload = (
|
|
||||||
f"protocol={parsed.scheme}\n" f"host={parsed.netloc}\n" f"username={username}\n" f"password={password}\n\n"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
subprocess.run(
|
|
||||||
["git", "credential", "approve"],
|
|
||||||
input=payload,
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=5,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass # non-fatal — next push just prompts again
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_credentials(hub: str) -> tuple[str, str]:
|
|
||||||
"""Return (username, password), prompting interactively if not cached."""
|
|
||||||
username, password = _get_stored_credentials(hub)
|
|
||||||
if username and password:
|
|
||||||
return username, password
|
|
||||||
|
|
||||||
click.echo(f"No stored credentials found for {hub}.")
|
|
||||||
username = click.prompt("RIA Hub username")
|
|
||||||
password = click.prompt("Password / personal access token", hide_input=True)
|
|
||||||
return username, password
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# API helpers
|
# API helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
|
|
||||||
"""Block redirects on API requests to prevent credential exfiltration.
|
|
||||||
|
|
||||||
urllib follows redirects by default and re-sends the Authorization header
|
|
||||||
on same-host redirects. A malicious server could redirect a POST to a
|
|
||||||
different host to harvest credentials. We refuse all redirects — API
|
|
||||||
clients should not encounter them in normal operation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def redirect_request(self, req, fp, code, msg, headers, newurl):
|
|
||||||
raise urllib.error.URLError(f"Unexpected redirect ({code}) to {newurl} — aborting to protect credentials")
|
|
||||||
|
|
||||||
|
|
||||||
def _api_request(
|
def _api_request(
|
||||||
hub: str,
|
hub: str,
|
||||||
path: str,
|
path: str,
|
||||||
|
|
@ -129,10 +63,7 @@ def _api_request(
|
||||||
data = json.dumps(body).encode() if body is not None else None
|
data = json.dumps(body).encode() if body is not None else None
|
||||||
req = urllib.request.Request(url, data=data, method=method)
|
req = urllib.request.Request(url, data=data, method=method)
|
||||||
req.add_header("Content-Type", "application/json")
|
req.add_header("Content-Type", "application/json")
|
||||||
# Credentials are base64-encoded (not encrypted). Callers must ensure
|
req.add_header("Authorization", basic_auth(username, password))
|
||||||
# they only call this over HTTPS or localhost — enforced in _warn_if_insecure.
|
|
||||||
cred = base64.b64encode(f"{username}:{password}".encode()).decode()
|
|
||||||
req.add_header("Authorization", f"Basic {cred}")
|
|
||||||
|
|
||||||
opener = urllib.request.build_opener(_NoRedirectHandler)
|
opener = urllib.request.build_opener(_NoRedirectHandler)
|
||||||
try:
|
try:
|
||||||
|
|
@ -148,18 +79,6 @@ def _api_request(
|
||||||
return {"message": str(e.reason)}, 0
|
return {"message": str(e.reason)}, 0
|
||||||
|
|
||||||
|
|
||||||
def _warn_if_insecure(hub: str) -> None:
|
|
||||||
"""Warn when sending credentials over plain HTTP to a non-localhost host."""
|
|
||||||
parsed = urllib.parse.urlparse(hub)
|
|
||||||
if parsed.scheme == "http":
|
|
||||||
host = parsed.hostname or ""
|
|
||||||
if host not in ("localhost", "127.0.0.1", "::1"):
|
|
||||||
click.echo(
|
|
||||||
f"Warning: sending credentials over plain HTTP to {host}. " "Use HTTPS in production.",
|
|
||||||
err=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_authenticated_username(hub: str, username: str, password: str) -> str | None:
|
def _get_authenticated_username(hub: str, username: str, password: str) -> str | None:
|
||||||
"""Return the login name of the authenticated user from GET /api/v1/user.
|
"""Return the login name of the authenticated user from GET /api/v1/user.
|
||||||
|
|
||||||
|
|
@ -349,7 +268,7 @@ def _configure_remote(
|
||||||
click.echo(
|
click.echo(
|
||||||
f"Skipped remote setup. Add it manually:\n"
|
f"Skipped remote setup. Add it manually:\n"
|
||||||
f" git -C {repo_path} remote add origin "
|
f" git -C {repo_path} remote add origin "
|
||||||
f"{hub.rstrip('/')}/<owner>/{repo_name}.git"
|
f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
remote_url = f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git"
|
remote_url = f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git"
|
||||||
|
|
@ -422,9 +341,9 @@ def setup_repo(
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if not no_remote:
|
if not no_remote:
|
||||||
_warn_if_insecure(hub)
|
warn_if_insecure(hub)
|
||||||
|
|
||||||
username, password = (None, None) if no_remote else _resolve_credentials(hub)
|
username, password = (None, None) if no_remote else resolve_credentials(hub)
|
||||||
resolved_owner = _resolve_owner(hub, username, password, owner)
|
resolved_owner = _resolve_owner(hub, username, password, owner)
|
||||||
|
|
||||||
# newly_created=True means the server ran auto_init+is_ria and seeded
|
# newly_created=True means the server ran auto_init+is_ria and seeded
|
||||||
|
|
@ -436,7 +355,7 @@ def setup_repo(
|
||||||
click.echo(f"Repository '{resolved_owner}/{repo_name}' already exists on RIA Hub.")
|
click.echo(f"Repository '{resolved_owner}/{repo_name}' already exists on RIA Hub.")
|
||||||
else:
|
else:
|
||||||
newly_created = _create_repo_on_hub(hub, repo_name, username, password, private)
|
newly_created = _create_repo_on_hub(hub, repo_name, username, password, private)
|
||||||
_store_credentials(hub, username, password)
|
store_credentials(hub, username, password)
|
||||||
|
|
||||||
if not os.path.exists(repo_path):
|
if not os.path.exists(repo_path):
|
||||||
os.makedirs(repo_path)
|
os.makedirs(repo_path)
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,14 @@ import urllib.request
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
from ._hub_auth import (
|
||||||
|
DEFAULT_HUB,
|
||||||
|
basic_auth,
|
||||||
|
hub_opener,
|
||||||
|
resolve_credentials,
|
||||||
|
warn_if_insecure,
|
||||||
|
)
|
||||||
|
|
||||||
# Read buffer for hashing and streaming — 8 MB keeps memory use flat
|
# Read buffer for hashing and streaming — 8 MB keeps memory use flat
|
||||||
# for arbitrarily large files.
|
# for arbitrarily large files.
|
||||||
_CHUNK = 8 * 1024 * 1024
|
_CHUNK = 8 * 1024 * 1024
|
||||||
|
|
@ -34,48 +42,6 @@ _CHUNK = 8 * 1024 * 1024
|
||||||
LFS_MEDIA_TYPE = "application/vnd.git-lfs+json"
|
LFS_MEDIA_TYPE = "application/vnd.git-lfs+json"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Credential helpers (reused from setup_repo pattern)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]:
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
parsed = urllib.parse.urlparse(hub_url)
|
|
||||||
payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n"
|
|
||||||
try:
|
|
||||||
result = subprocess.run(
|
|
||||||
["git", "credential", "fill"],
|
|
||||||
input=payload,
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=5,
|
|
||||||
)
|
|
||||||
creds = {}
|
|
||||||
for line in result.stdout.splitlines():
|
|
||||||
k, sep, v = line.partition("=")
|
|
||||||
if sep:
|
|
||||||
creds[k.strip()] = v
|
|
||||||
return creds.get("username"), creds.get("password")
|
|
||||||
except Exception:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_credentials(hub: str) -> tuple[str, str]:
|
|
||||||
username, password = _get_stored_credentials(hub)
|
|
||||||
if username and password:
|
|
||||||
return username, password
|
|
||||||
click.echo(f"No stored credentials found for {hub}.")
|
|
||||||
username = click.prompt("RIA Hub username")
|
|
||||||
password = click.prompt("Password / personal access token", hide_input=True)
|
|
||||||
return username, password
|
|
||||||
|
|
||||||
|
|
||||||
def _basic_auth(username: str, password: str) -> str:
|
|
||||||
return "Basic " + base64.b64encode(f"{username}:{password}".encode()).decode()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# File helpers
|
# File helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -134,64 +100,17 @@ def _lfs_batch(
|
||||||
req = urllib.request.Request(url, data=body, method="POST")
|
req = urllib.request.Request(url, data=body, method="POST")
|
||||||
req.add_header("Content-Type", LFS_MEDIA_TYPE)
|
req.add_header("Content-Type", LFS_MEDIA_TYPE)
|
||||||
req.add_header("Accept", LFS_MEDIA_TYPE)
|
req.add_header("Accept", LFS_MEDIA_TYPE)
|
||||||
req.add_header("Authorization", _basic_auth(username, password))
|
req.add_header("Authorization", basic_auth(username, password))
|
||||||
|
|
||||||
|
opener = hub_opener()
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
with opener.open(req, timeout=30) as resp:
|
||||||
return json.loads(resp.read())
|
return json.loads(resp.read())
|
||||||
except urllib.error.HTTPError as e:
|
except urllib.error.HTTPError as e:
|
||||||
body_text = e.read().decode(errors="replace")
|
body_text = e.read().decode(errors="replace")
|
||||||
raise RuntimeError(f"LFS batch request failed (HTTP {e.code}): {body_text}") from e
|
raise RuntimeError(f"LFS batch request failed (HTTP {e.code}): {body_text}") from e
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Streaming PUT upload
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _stream_upload(href: str, headers: dict, file_path: str, size: int) -> None:
|
|
||||||
"""
|
|
||||||
PUT file_path content to href, streaming in _CHUNK-sized pieces.
|
|
||||||
Uses http.client directly so Content-Length is set without buffering
|
|
||||||
the whole file in memory. Works for files of any size.
|
|
||||||
"""
|
|
||||||
parsed = urllib.parse.urlparse(href)
|
|
||||||
host = parsed.netloc
|
|
||||||
path = parsed.path
|
|
||||||
if parsed.query:
|
|
||||||
path += "?" + parsed.query
|
|
||||||
|
|
||||||
if parsed.scheme == "https":
|
|
||||||
conn = http.client.HTTPSConnection(host, timeout=300)
|
|
||||||
else:
|
|
||||||
conn = http.client.HTTPConnection(host, timeout=300)
|
|
||||||
|
|
||||||
all_headers = dict(headers or {})
|
|
||||||
all_headers.setdefault("Content-Type", "application/octet-stream")
|
|
||||||
all_headers["Content-Length"] = str(size)
|
|
||||||
|
|
||||||
try:
|
|
||||||
conn.connect()
|
|
||||||
conn.putrequest("PUT", path)
|
|
||||||
for k, v in all_headers.items():
|
|
||||||
conn.putheader(k, v)
|
|
||||||
conn.endheaders()
|
|
||||||
|
|
||||||
with open(file_path, "rb") as f:
|
|
||||||
while True:
|
|
||||||
chunk = f.read(_CHUNK)
|
|
||||||
if not chunk:
|
|
||||||
break
|
|
||||||
conn.send(chunk)
|
|
||||||
|
|
||||||
resp = conn.getresponse()
|
|
||||||
resp.read() # drain
|
|
||||||
if resp.status not in (200, 201):
|
|
||||||
raise RuntimeError(f"LFS object upload failed: HTTP {resp.status}")
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Gitea contents API — create / update a file to record the LFS pointer
|
# Gitea contents API — create / update a file to record the LFS pointer
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -214,9 +133,9 @@ def _get_file_sha(
|
||||||
f"?ref={urllib.parse.quote(branch)}"
|
f"?ref={urllib.parse.quote(branch)}"
|
||||||
)
|
)
|
||||||
req = urllib.request.Request(url)
|
req = urllib.request.Request(url)
|
||||||
req.add_header("Authorization", _basic_auth(username, password))
|
req.add_header("Authorization", basic_auth(username, password))
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
with hub_opener().open(req, timeout=15) as resp:
|
||||||
return json.loads(resp.read()).get("sha")
|
return json.loads(resp.read()).get("sha")
|
||||||
except urllib.error.HTTPError as e:
|
except urllib.error.HTTPError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
|
|
@ -255,10 +174,10 @@ def _commit_lfs_pointer(
|
||||||
method = "PUT" if existing_sha else "POST"
|
method = "PUT" if existing_sha else "POST"
|
||||||
req = urllib.request.Request(url, data=json.dumps(body).encode(), method=method)
|
req = urllib.request.Request(url, data=json.dumps(body).encode(), method=method)
|
||||||
req.add_header("Content-Type", "application/json")
|
req.add_header("Content-Type", "application/json")
|
||||||
req.add_header("Authorization", _basic_auth(username, password))
|
req.add_header("Authorization", basic_auth(username, password))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(req, timeout=30) as resp:
|
with hub_opener().open(req, timeout=30) as resp:
|
||||||
resp.read()
|
resp.read()
|
||||||
except urllib.error.HTTPError as e:
|
except urllib.error.HTTPError as e:
|
||||||
body_text = e.read().decode(errors="replace")
|
body_text = e.read().decode(errors="replace")
|
||||||
|
|
@ -375,6 +294,13 @@ def _stream_upload_progress(href: str, headers: dict, file_path: str, size: int)
|
||||||
all_headers.setdefault("Content-Type", "application/octet-stream")
|
all_headers.setdefault("Content-Type", "application/octet-stream")
|
||||||
all_headers["Content-Length"] = str(size)
|
all_headers["Content-Length"] = str(size)
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn.connect()
|
||||||
|
conn.putrequest("PUT", path_q)
|
||||||
|
for k, v in all_headers.items():
|
||||||
|
conn.putheader(k, v)
|
||||||
|
conn.endheaders()
|
||||||
|
|
||||||
with click.progressbar(
|
with click.progressbar(
|
||||||
length=size,
|
length=size,
|
||||||
label=" ",
|
label=" ",
|
||||||
|
|
@ -384,12 +310,6 @@ def _stream_upload_progress(href: str, headers: dict, file_path: str, size: int)
|
||||||
fill_char="█",
|
fill_char="█",
|
||||||
empty_char="░",
|
empty_char="░",
|
||||||
) as bar:
|
) as bar:
|
||||||
conn.connect()
|
|
||||||
conn.putrequest("PUT", path_q)
|
|
||||||
for k, v in all_headers.items():
|
|
||||||
conn.putheader(k, v)
|
|
||||||
conn.endheaders()
|
|
||||||
|
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
while True:
|
while True:
|
||||||
chunk = f.read(_CHUNK)
|
chunk = f.read(_CHUNK)
|
||||||
|
|
@ -400,10 +320,10 @@ def _stream_upload_progress(href: str, headers: dict, file_path: str, size: int)
|
||||||
|
|
||||||
resp = conn.getresponse()
|
resp = conn.getresponse()
|
||||||
resp.read()
|
resp.read()
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if resp.status not in (200, 201):
|
if resp.status not in (200, 201):
|
||||||
raise RuntimeError(f"HTTP {resp.status}")
|
raise RuntimeError(f"HTTP {resp.status}")
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -416,7 +336,7 @@ def _stream_upload_progress(href: str, headers: dict, file_path: str, size: int)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--repo", required=True, metavar="OWNER/NAME", help="Target repository on RIA Hub (e.g. benchinnery/my-dataset)."
|
"--repo", required=True, metavar="OWNER/NAME", help="Target repository on RIA Hub (e.g. benchinnery/my-dataset)."
|
||||||
)
|
)
|
||||||
@click.option("--hub", default="https://riahub.ai", show_default=True, metavar="URL", help="RIA Hub base URL.")
|
@click.option("--hub", default=DEFAULT_HUB, show_default=True, metavar="URL", help="RIA Hub base URL.")
|
||||||
@click.option("--branch", default="main", show_default=True, help="Branch to commit the files to.")
|
@click.option("--branch", default="main", show_default=True, help="Branch to commit the files to.")
|
||||||
@click.option(
|
@click.option(
|
||||||
"--path",
|
"--path",
|
||||||
|
|
@ -460,9 +380,9 @@ def upload(
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
resolved.append(os.path.abspath(pattern))
|
resolved.append(os.path.abspath(pattern))
|
||||||
|
|
||||||
# Credentials
|
|
||||||
hub = hub.rstrip("/")
|
hub = hub.rstrip("/")
|
||||||
username, password = _resolve_credentials(hub)
|
warn_if_insecure(hub)
|
||||||
|
username, password = resolve_credentials(hub)
|
||||||
|
|
||||||
click.echo(f"Uploading {len(resolved)} file(s) to {owner}/{repo_name} on {hub}...")
|
click.echo(f"Uploading {len(resolved)} file(s) to {owner}/{repo_name} on {hub}...")
|
||||||
|
|
||||||
|
|
|
||||||
145
tests/agent/test_cli_register_errors.py
Normal file
145
tests/agent/test_cli_register_errors.py
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
"""Structured error reporting for `ria-agent register` (T2)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import urllib.error
|
||||||
|
from io import BytesIO
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ria_toolkit_oss.agent import cli as agent_cli
|
||||||
|
|
||||||
|
|
||||||
|
def _structured(reason: str) -> bytes:
|
||||||
|
return json.dumps({"detail": {"reason": reason}}).encode()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"reason",
|
||||||
|
["invalid_key", "expired", "revoked", "already_consumed"],
|
||||||
|
)
|
||||||
|
def test_explain_maps_known_reasons(reason):
|
||||||
|
msg = agent_cli._explain_registration_failure(403, _structured(reason))
|
||||||
|
assert msg == agent_cli.REGISTRATION_REASON_MESSAGES[reason]
|
||||||
|
|
||||||
|
|
||||||
|
def test_explain_unknown_reason_falls_through_with_code():
|
||||||
|
msg = agent_cli._explain_registration_failure(403, _structured("brand_new_thing"))
|
||||||
|
assert "brand_new_thing" in msg
|
||||||
|
assert "rejected" in msg.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_explain_string_detail():
|
||||||
|
body = json.dumps({"detail": "Forbidden"}).encode()
|
||||||
|
msg = agent_cli._explain_registration_failure(403, body)
|
||||||
|
assert msg == "Registration rejected: Forbidden"
|
||||||
|
|
||||||
|
|
||||||
|
def test_explain_429_with_string_detail():
|
||||||
|
body = json.dumps({"detail": "Too many attempts; try again shortly"}).encode()
|
||||||
|
msg = agent_cli._explain_registration_failure(429, body)
|
||||||
|
assert "rate-limited" in msg
|
||||||
|
assert "Too many attempts" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_explain_429_with_no_body():
|
||||||
|
msg = agent_cli._explain_registration_failure(429, b"")
|
||||||
|
assert "rate-limited" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_explain_malformed_json():
|
||||||
|
msg = agent_cli._explain_registration_failure(500, b"<html>boom</html>")
|
||||||
|
assert msg.startswith("HTTP 500")
|
||||||
|
assert "boom" in msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_explain_empty_body():
|
||||||
|
msg = agent_cli._explain_registration_failure(502, b"")
|
||||||
|
assert msg == "HTTP 502: no body"
|
||||||
|
|
||||||
|
|
||||||
|
def _http_error(status: int, body: bytes) -> urllib.error.HTTPError:
|
||||||
|
return urllib.error.HTTPError(
|
||||||
|
url="http://hub/screens/agents/register",
|
||||||
|
code=status,
|
||||||
|
msg="",
|
||||||
|
hdrs=None, # type: ignore[arg-type]
|
||||||
|
fp=BytesIO(body),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_agent_is_set_and_not_python_default():
|
||||||
|
"""Cloudflare on `riahub.ai` returns 403 code 1010 to `Python-urllib/*`.
|
||||||
|
|
||||||
|
Guarding the UA explicitly is the entire point of the register-flow fix;
|
||||||
|
if this test ever breaks, the production bug is back.
|
||||||
|
"""
|
||||||
|
ua = agent_cli._user_agent()
|
||||||
|
assert ua, "User-Agent must not be empty"
|
||||||
|
assert not ua.lower().startswith(
|
||||||
|
"python-urllib"
|
||||||
|
), f"User-Agent must not be Python's default (got {ua!r}) — Cloudflare blocks it"
|
||||||
|
assert ua.startswith("ria-agent/")
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_request_carries_explicit_user_agent(tmp_path):
|
||||||
|
"""Capture the outbound urllib Request and verify the UA header is set."""
|
||||||
|
cfg_path = tmp_path / "agent.json"
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
def _fake_urlopen(req, *args, **kwargs):
|
||||||
|
# urllib normalizes header names; get_header takes the title-cased form.
|
||||||
|
captured["ua"] = req.get_header("User-agent")
|
||||||
|
captured["api_key"] = req.get_header("X-api-key")
|
||||||
|
captured["timeout"] = kwargs.get("timeout")
|
||||||
|
raise urllib.error.HTTPError(
|
||||||
|
url=req.full_url,
|
||||||
|
code=403,
|
||||||
|
msg="",
|
||||||
|
hdrs=None, # type: ignore[arg-type]
|
||||||
|
fp=BytesIO(_structured("invalid_key")),
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False),
|
||||||
|
patch("urllib.request.urlopen", side_effect=_fake_urlopen),
|
||||||
|
patch.object(
|
||||||
|
sys,
|
||||||
|
"argv",
|
||||||
|
["ria-agent", "register", "--hub", "http://hub:3005", "--api-key", "ria_reg_x"],
|
||||||
|
),
|
||||||
|
):
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
agent_cli.main()
|
||||||
|
|
||||||
|
assert captured["ua"], "User-Agent header was not sent"
|
||||||
|
assert not captured["ua"].lower().startswith("python-urllib")
|
||||||
|
assert captured["api_key"] == "ria_reg_x"
|
||||||
|
assert captured["timeout"] is not None, "register must pass a timeout to urlopen"
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_surfaces_reason_on_http_error(tmp_path, capsys):
|
||||||
|
cfg_path = tmp_path / "agent.json"
|
||||||
|
err = _http_error(403, _structured("revoked"))
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.dict("os.environ", {"RIA_AGENT_CONFIG": str(cfg_path)}, clear=False),
|
||||||
|
patch("urllib.request.urlopen", side_effect=err),
|
||||||
|
patch.object(
|
||||||
|
sys,
|
||||||
|
"argv",
|
||||||
|
["ria-agent", "register", "--hub", "http://hub:3005", "--api-key", "ria_reg_x"],
|
||||||
|
),
|
||||||
|
):
|
||||||
|
with pytest.raises(SystemExit) as exc:
|
||||||
|
agent_cli.main()
|
||||||
|
|
||||||
|
assert exc.value.code == 1
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "revoked" in captured.err.lower()
|
||||||
|
assert "Settings → RIA Agents" in captured.err
|
||||||
|
# Config must NOT be written on failure.
|
||||||
|
assert not cfg_path.exists()
|
||||||
Loading…
Reference in New Issue
Block a user