addressing feedback
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 16s
Build Project / Build Project (3.10) (pull_request) Successful in 57s
Build Project / Build Project (3.11) (pull_request) Successful in 54s
Build Project / Build Project (3.12) (pull_request) Successful in 2m17s
Test with tox / Test with tox (3.10) (pull_request) Successful in 2m51s
Test with tox / Test with tox (3.11) (pull_request) Successful in 4m44s
Test with tox / Test with tox (3.12) (pull_request) Successful in 8m39s

This commit is contained in:
ben 2026-06-04 15:49:44 -04:00
parent 5a67b39d22
commit 0b75013653
6 changed files with 161 additions and 230 deletions

View File

@ -56,18 +56,9 @@ _REGISTER_TIMEOUT_S = 15
REGISTRATION_REASON_MESSAGES = { REGISTRATION_REASON_MESSAGES = {
"invalid_key": ( "invalid_key": ("Registration key not recognized. Generate a fresh key from " "Settings → RIA Agents on the hub."),
"Registration key not recognized. Generate a fresh key from " "expired": ("This registration key has expired. Generate a new one from " "Settings → RIA Agents on the hub."),
"Settings → RIA Agents on the hub." "revoked": ("This registration key was revoked. Generate a new one from " "Settings → RIA Agents on the hub."),
),
"expired": (
"This registration key has expired. Generate a new one from "
"Settings → RIA Agents on the hub."
),
"revoked": (
"This registration key was revoked. Generate a new one from "
"Settings → RIA Agents on the hub."
),
"already_consumed": ( "already_consumed": (
"This single-use registration key has already been used. " "This single-use registration key has already been used. "
"Generate a new one, or mint a reusable key instead." "Generate a new one, or mint a reusable key instead."

View File

@ -3,6 +3,7 @@ This module contains the main group for the ria toolkit oss CLI.
""" """
import subprocess import subprocess
import sys
import warnings import warnings
warnings.filterwarnings( warnings.filterwarnings(
@ -49,7 +50,7 @@ def cli(ctx, verbose):
err=True, err=True,
) )
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:
if lfs_missing: if lfs_missing and sys.stdin.isatty():
click.pause(info="\nPress Enter to continue...", err=True) click.pause(info="\nPress Enter to continue...", err=True)
click.echo(ctx.get_help()) click.echo(ctx.get_help())

View File

@ -0,0 +1,97 @@
"""Shared authentication and security helpers for RIA Hub API calls."""
import base64
import subprocess
import urllib.error
import urllib.parse
import urllib.request
import click
# Hub base URL used when the caller does not supply one.
DEFAULT_HUB = "https://riahub.ai"
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
    """Redirect handler that refuses every redirect.

    Intended for authenticated requests: if a redirect were followed, the
    Authorization header could be re-sent to wherever the server points,
    letting a malicious server harvest credentials.  API clients should not
    encounter redirects in normal operation, so any redirect is treated as
    an error instead of being followed.
    """

    def redirect_request(self, req, fp, code, msg, headers, newurl):
        """Reject the redirect by raising ``urllib.error.URLError``."""
        reason = f"Unexpected redirect ({code}) to {newurl} — aborting to protect credentials"
        raise urllib.error.URLError(reason)
def hub_opener() -> urllib.request.OpenerDirector:
    """Build a urllib opener that uses ``_NoRedirectHandler``, so redirects are refused."""
    opener = urllib.request.build_opener(_NoRedirectHandler)
    return opener
def warn_if_insecure(hub: str) -> None:
    """Print a warning to stderr if *hub* is a plain-HTTP URL on a non-local host.

    Nothing is printed for non-``http`` schemes, nor for the hosts
    ``localhost``, ``127.0.0.1`` and ``::1``.
    """
    url = urllib.parse.urlparse(hub)
    if url.scheme != "http":
        return

    target = url.hostname or ""
    if target in ("localhost", "127.0.0.1", "::1"):
        return

    click.echo(
        f"Warning: sending credentials over plain HTTP to {target}. " "Use HTTPS in production.",
        err=True,
    )
def basic_auth(username: str, password: str) -> str:
    """Return the HTTP Basic ``Authorization`` header value for the given pair.

    The result is ``"Basic "`` followed by the base64 encoding of
    ``username:password``.  Base64 is an encoding, not encryption.
    """
    raw_pair = f"{username}:{password}".encode()
    token = base64.b64encode(raw_pair).decode()
    return "Basic " + token
def get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]:
"""Ask git credential fill for stored creds. Returns (username, password) or (None, None)."""
parsed = urllib.parse.urlparse(hub_url)
payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n"
try:
result = subprocess.run(
["git", "credential", "fill"],
input=payload,
capture_output=True,
text=True,
timeout=5,
)
creds = {}
for line in result.stdout.splitlines():
# Partition on the FIRST '=' only so passwords containing '=' are preserved.
k, sep, v = line.partition("=")
if sep:
creds[k.strip()] = v # keep value verbatim
return creds.get("username"), creds.get("password")
except Exception:
return None, None
def store_credentials(hub_url: str, username: str, password: str) -> None:
    """Cache credentials via ``git credential approve`` (best effort).

    git hands the values to whatever credential helper is configured.
    Failures are deliberately non-fatal: if nothing gets stored, the next
    run simply prompts again.

    Args:
        hub_url: Base URL of the hub the credentials belong to.
        username: Login name to store.
        password: Password or personal access token to store.

    Nothing is stored when either value is empty/``None`` or contains a
    newline or NUL character, because such values cannot be represented in
    git's line-oriented credential format.
    """
    if not username or not password:
        # Avoid persisting an empty or literal "None" credential.
        return
    if any(bad in value for value in (username, password) for bad in ("\n", "\0")):
        # A newline would start a new attribute line in the payload
        # (e.g. an injected "host=..."), so refuse rather than store it.
        return

    parsed = urllib.parse.urlparse(hub_url)
    payload = (
        f"protocol={parsed.scheme}\n" f"host={parsed.netloc}\n" f"username={username}\n" f"password={password}\n\n"
    )
    try:
        subprocess.run(
            ["git", "credential", "approve"],
            input=payload,
            capture_output=True,
            text=True,
            timeout=5,
        )
    except (OSError, subprocess.SubprocessError, ValueError):
        pass  # non-fatal — next push just prompts again
def resolve_credentials(hub: str) -> tuple[str, str]:
    """Return ``(username, password)`` for *hub*.

    Uses the pair from ``get_stored_credentials`` when both parts are
    present; otherwise announces that nothing is stored and prompts for
    both values (password input is hidden).
    """
    login, secret = get_stored_credentials(hub)
    if not (login and secret):
        click.echo(f"No stored credentials found for {hub}.")
        login = click.prompt("RIA Hub username")
        secret = click.prompt("Password / personal access token", hide_input=True)
    return login, secret

View File

@ -1,6 +1,5 @@
"""ria setup_repo — create and configure a RIA Hub Project repo.""" """ria setup_repo — create and configure a RIA Hub Project repo."""
import base64
import json import json
import os import os
import re import re
@ -12,6 +11,15 @@ import urllib.request
import click import click
from ._hub_auth import (
DEFAULT_HUB,
_NoRedirectHandler,
basic_auth,
resolve_credentials,
store_credentials,
warn_if_insecure,
)
RIA_LFS_RULES = [ RIA_LFS_RULES = [
("*.pt", "filter=lfs diff=lfs merge=lfs -text"), ("*.pt", "filter=lfs diff=lfs merge=lfs -text"),
("*.pth", "filter=lfs diff=lfs merge=lfs -text"), ("*.pth", "filter=lfs diff=lfs merge=lfs -text"),
@ -27,89 +35,15 @@ RIA_LFS_RULES = [
("*.pkl", "filter=lfs diff=lfs merge=lfs -text"), ("*.pkl", "filter=lfs diff=lfs merge=lfs -text"),
] ]
DEFAULT_HUB = "https://riahub.ai"
# Repo names must be safe directory names and valid git remote path components. # Repo names must be safe directory names and valid git remote path components.
_SAFE_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,100}$") _SAFE_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,100}$")
# ---------------------------------------------------------------------------
# Credential helpers — all credential I/O goes through git's own store
# ---------------------------------------------------------------------------
def _get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]:
"""Ask git credential fill for stored creds. Returns (username, password) or (None, None)."""
parsed = urllib.parse.urlparse(hub_url)
payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n"
try:
result = subprocess.run(
["git", "credential", "fill"],
input=payload,
capture_output=True,
text=True,
timeout=5,
)
creds = {}
for line in result.stdout.splitlines():
# partition on the FIRST '=' only so passwords containing '=' are preserved.
# Only strip whitespace from the key, not the value.
k, sep, v = line.partition("=")
if sep:
creds[k.strip()] = v # keep password value verbatim
return creds.get("username"), creds.get("password")
except Exception:
return None, None
def _store_credentials(hub_url: str, username: str, password: str) -> None:
"""Cache credentials via git credential approve (uses the system keychain/store)."""
parsed = urllib.parse.urlparse(hub_url)
payload = (
f"protocol={parsed.scheme}\n" f"host={parsed.netloc}\n" f"username={username}\n" f"password={password}\n\n"
)
try:
subprocess.run(
["git", "credential", "approve"],
input=payload,
capture_output=True,
text=True,
timeout=5,
)
except Exception:
pass # non-fatal — next push just prompts again
def _resolve_credentials(hub: str) -> tuple[str, str]:
"""Return (username, password), prompting interactively if not cached."""
username, password = _get_stored_credentials(hub)
if username and password:
return username, password
click.echo(f"No stored credentials found for {hub}.")
username = click.prompt("RIA Hub username")
password = click.prompt("Password / personal access token", hide_input=True)
return username, password
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# API helpers # API helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class _NoRedirectHandler(urllib.request.HTTPRedirectHandler):
"""Block redirects on API requests to prevent credential exfiltration.
urllib follows redirects by default and re-sends the Authorization header
on same-host redirects. A malicious server could redirect a POST to a
different host to harvest credentials. We refuse all redirects API
clients should not encounter them in normal operation.
"""
def redirect_request(self, req, fp, code, msg, headers, newurl):
raise urllib.error.URLError(f"Unexpected redirect ({code}) to {newurl} — aborting to protect credentials")
def _api_request( def _api_request(
hub: str, hub: str,
path: str, path: str,
@ -129,10 +63,7 @@ def _api_request(
data = json.dumps(body).encode() if body is not None else None data = json.dumps(body).encode() if body is not None else None
req = urllib.request.Request(url, data=data, method=method) req = urllib.request.Request(url, data=data, method=method)
req.add_header("Content-Type", "application/json") req.add_header("Content-Type", "application/json")
# Credentials are base64-encoded (not encrypted). Callers must ensure req.add_header("Authorization", basic_auth(username, password))
# they only call this over HTTPS or localhost — enforced in _warn_if_insecure.
cred = base64.b64encode(f"{username}:{password}".encode()).decode()
req.add_header("Authorization", f"Basic {cred}")
opener = urllib.request.build_opener(_NoRedirectHandler) opener = urllib.request.build_opener(_NoRedirectHandler)
try: try:
@ -148,18 +79,6 @@ def _api_request(
return {"message": str(e.reason)}, 0 return {"message": str(e.reason)}, 0
def _warn_if_insecure(hub: str) -> None:
"""Warn when sending credentials over plain HTTP to a non-localhost host."""
parsed = urllib.parse.urlparse(hub)
if parsed.scheme == "http":
host = parsed.hostname or ""
if host not in ("localhost", "127.0.0.1", "::1"):
click.echo(
f"Warning: sending credentials over plain HTTP to {host}. " "Use HTTPS in production.",
err=True,
)
def _get_authenticated_username(hub: str, username: str, password: str) -> str | None: def _get_authenticated_username(hub: str, username: str, password: str) -> str | None:
"""Return the login name of the authenticated user from GET /api/v1/user. """Return the login name of the authenticated user from GET /api/v1/user.
@ -349,7 +268,7 @@ def _configure_remote(
click.echo( click.echo(
f"Skipped remote setup. Add it manually:\n" f"Skipped remote setup. Add it manually:\n"
f" git -C {repo_path} remote add origin " f" git -C {repo_path} remote add origin "
f"{hub.rstrip('/')}/<owner>/{repo_name}.git" f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git"
) )
return return
remote_url = f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git" remote_url = f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git"
@ -422,9 +341,9 @@ def setup_repo(
sys.exit(1) sys.exit(1)
if not no_remote: if not no_remote:
_warn_if_insecure(hub) warn_if_insecure(hub)
username, password = (None, None) if no_remote else _resolve_credentials(hub) username, password = (None, None) if no_remote else resolve_credentials(hub)
resolved_owner = _resolve_owner(hub, username, password, owner) resolved_owner = _resolve_owner(hub, username, password, owner)
# newly_created=True means the server ran auto_init+is_ria and seeded # newly_created=True means the server ran auto_init+is_ria and seeded
@ -436,7 +355,7 @@ def setup_repo(
click.echo(f"Repository '{resolved_owner}/{repo_name}' already exists on RIA Hub.") click.echo(f"Repository '{resolved_owner}/{repo_name}' already exists on RIA Hub.")
else: else:
newly_created = _create_repo_on_hub(hub, repo_name, username, password, private) newly_created = _create_repo_on_hub(hub, repo_name, username, password, private)
_store_credentials(hub, username, password) store_credentials(hub, username, password)
if not os.path.exists(repo_path): if not os.path.exists(repo_path):
os.makedirs(repo_path) os.makedirs(repo_path)

View File

@ -27,6 +27,14 @@ import urllib.request
import click import click
from ._hub_auth import (
DEFAULT_HUB,
basic_auth,
hub_opener,
resolve_credentials,
warn_if_insecure,
)
# Read buffer for hashing and streaming — 8 MB keeps memory use flat # Read buffer for hashing and streaming — 8 MB keeps memory use flat
# for arbitrarily large files. # for arbitrarily large files.
_CHUNK = 8 * 1024 * 1024 _CHUNK = 8 * 1024 * 1024
@ -34,48 +42,6 @@ _CHUNK = 8 * 1024 * 1024
LFS_MEDIA_TYPE = "application/vnd.git-lfs+json" LFS_MEDIA_TYPE = "application/vnd.git-lfs+json"
# ---------------------------------------------------------------------------
# Credential helpers (reused from setup_repo pattern)
# ---------------------------------------------------------------------------
def _get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]:
import subprocess
parsed = urllib.parse.urlparse(hub_url)
payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n"
try:
result = subprocess.run(
["git", "credential", "fill"],
input=payload,
capture_output=True,
text=True,
timeout=5,
)
creds = {}
for line in result.stdout.splitlines():
k, sep, v = line.partition("=")
if sep:
creds[k.strip()] = v
return creds.get("username"), creds.get("password")
except Exception:
return None, None
def _resolve_credentials(hub: str) -> tuple[str, str]:
username, password = _get_stored_credentials(hub)
if username and password:
return username, password
click.echo(f"No stored credentials found for {hub}.")
username = click.prompt("RIA Hub username")
password = click.prompt("Password / personal access token", hide_input=True)
return username, password
def _basic_auth(username: str, password: str) -> str:
return "Basic " + base64.b64encode(f"{username}:{password}".encode()).decode()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# File helpers # File helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -134,64 +100,17 @@ def _lfs_batch(
req = urllib.request.Request(url, data=body, method="POST") req = urllib.request.Request(url, data=body, method="POST")
req.add_header("Content-Type", LFS_MEDIA_TYPE) req.add_header("Content-Type", LFS_MEDIA_TYPE)
req.add_header("Accept", LFS_MEDIA_TYPE) req.add_header("Accept", LFS_MEDIA_TYPE)
req.add_header("Authorization", _basic_auth(username, password)) req.add_header("Authorization", basic_auth(username, password))
opener = hub_opener()
try: try:
with urllib.request.urlopen(req, timeout=30) as resp: with opener.open(req, timeout=30) as resp:
return json.loads(resp.read()) return json.loads(resp.read())
except urllib.error.HTTPError as e: except urllib.error.HTTPError as e:
body_text = e.read().decode(errors="replace") body_text = e.read().decode(errors="replace")
raise RuntimeError(f"LFS batch request failed (HTTP {e.code}): {body_text}") from e raise RuntimeError(f"LFS batch request failed (HTTP {e.code}): {body_text}") from e
# ---------------------------------------------------------------------------
# Streaming PUT upload
# ---------------------------------------------------------------------------
def _stream_upload(href: str, headers: dict, file_path: str, size: int) -> None:
"""
PUT file_path content to href, streaming in _CHUNK-sized pieces.
Uses http.client directly so Content-Length is set without buffering
the whole file in memory. Works for files of any size.
"""
parsed = urllib.parse.urlparse(href)
host = parsed.netloc
path = parsed.path
if parsed.query:
path += "?" + parsed.query
if parsed.scheme == "https":
conn = http.client.HTTPSConnection(host, timeout=300)
else:
conn = http.client.HTTPConnection(host, timeout=300)
all_headers = dict(headers or {})
all_headers.setdefault("Content-Type", "application/octet-stream")
all_headers["Content-Length"] = str(size)
try:
conn.connect()
conn.putrequest("PUT", path)
for k, v in all_headers.items():
conn.putheader(k, v)
conn.endheaders()
with open(file_path, "rb") as f:
while True:
chunk = f.read(_CHUNK)
if not chunk:
break
conn.send(chunk)
resp = conn.getresponse()
resp.read() # drain
if resp.status not in (200, 201):
raise RuntimeError(f"LFS object upload failed: HTTP {resp.status}")
finally:
conn.close()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Gitea contents API — create / update a file to record the LFS pointer # Gitea contents API — create / update a file to record the LFS pointer
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -214,9 +133,9 @@ def _get_file_sha(
f"?ref={urllib.parse.quote(branch)}" f"?ref={urllib.parse.quote(branch)}"
) )
req = urllib.request.Request(url) req = urllib.request.Request(url)
req.add_header("Authorization", _basic_auth(username, password)) req.add_header("Authorization", basic_auth(username, password))
try: try:
with urllib.request.urlopen(req, timeout=15) as resp: with hub_opener().open(req, timeout=15) as resp:
return json.loads(resp.read()).get("sha") return json.loads(resp.read()).get("sha")
except urllib.error.HTTPError as e: except urllib.error.HTTPError as e:
if e.code == 404: if e.code == 404:
@ -255,10 +174,10 @@ def _commit_lfs_pointer(
method = "PUT" if existing_sha else "POST" method = "PUT" if existing_sha else "POST"
req = urllib.request.Request(url, data=json.dumps(body).encode(), method=method) req = urllib.request.Request(url, data=json.dumps(body).encode(), method=method)
req.add_header("Content-Type", "application/json") req.add_header("Content-Type", "application/json")
req.add_header("Authorization", _basic_auth(username, password)) req.add_header("Authorization", basic_auth(username, password))
try: try:
with urllib.request.urlopen(req, timeout=30) as resp: with hub_opener().open(req, timeout=30) as resp:
resp.read() resp.read()
except urllib.error.HTTPError as e: except urllib.error.HTTPError as e:
body_text = e.read().decode(errors="replace") body_text = e.read().decode(errors="replace")
@ -375,36 +294,37 @@ def _stream_upload_progress(href: str, headers: dict, file_path: str, size: int)
all_headers.setdefault("Content-Type", "application/octet-stream") all_headers.setdefault("Content-Type", "application/octet-stream")
all_headers["Content-Length"] = str(size) all_headers["Content-Length"] = str(size)
with click.progressbar( try:
length=size,
label=" ",
width=40,
show_eta=True,
show_percent=True,
fill_char="",
empty_char="",
) as bar:
conn.connect() conn.connect()
conn.putrequest("PUT", path_q) conn.putrequest("PUT", path_q)
for k, v in all_headers.items(): for k, v in all_headers.items():
conn.putheader(k, v) conn.putheader(k, v)
conn.endheaders() conn.endheaders()
with open(file_path, "rb") as f: with click.progressbar(
while True: length=size,
chunk = f.read(_CHUNK) label=" ",
if not chunk: width=40,
break show_eta=True,
conn.send(chunk) show_percent=True,
bar.update(len(chunk)) fill_char="",
empty_char="",
) as bar:
with open(file_path, "rb") as f:
while True:
chunk = f.read(_CHUNK)
if not chunk:
break
conn.send(chunk)
bar.update(len(chunk))
resp = conn.getresponse() resp = conn.getresponse()
resp.read() resp.read()
if resp.status not in (200, 201):
raise RuntimeError(f"HTTP {resp.status}")
finally:
conn.close() conn.close()
if resp.status not in (200, 201):
raise RuntimeError(f"HTTP {resp.status}")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Command # Command
@ -416,7 +336,7 @@ def _stream_upload_progress(href: str, headers: dict, file_path: str, size: int)
@click.option( @click.option(
"--repo", required=True, metavar="OWNER/NAME", help="Target repository on RIA Hub (e.g. benchinnery/my-dataset)." "--repo", required=True, metavar="OWNER/NAME", help="Target repository on RIA Hub (e.g. benchinnery/my-dataset)."
) )
@click.option("--hub", default="https://riahub.ai", show_default=True, metavar="URL", help="RIA Hub base URL.") @click.option("--hub", default=DEFAULT_HUB, show_default=True, metavar="URL", help="RIA Hub base URL.")
@click.option("--branch", default="main", show_default=True, help="Branch to commit the files to.") @click.option("--branch", default="main", show_default=True, help="Branch to commit the files to.")
@click.option( @click.option(
"--path", "--path",
@ -460,9 +380,9 @@ def upload(
sys.exit(1) sys.exit(1)
resolved.append(os.path.abspath(pattern)) resolved.append(os.path.abspath(pattern))
# Credentials
hub = hub.rstrip("/") hub = hub.rstrip("/")
username, password = _resolve_credentials(hub) warn_if_insecure(hub)
username, password = resolve_credentials(hub)
click.echo(f"Uploading {len(resolved)} file(s) to {owner}/{repo_name} on {hub}...") click.echo(f"Uploading {len(resolved)} file(s) to {owner}/{repo_name} on {hub}...")

View File

@ -79,9 +79,9 @@ def test_user_agent_is_set_and_not_python_default():
""" """
ua = agent_cli._user_agent() ua = agent_cli._user_agent()
assert ua, "User-Agent must not be empty" assert ua, "User-Agent must not be empty"
assert not ua.lower().startswith("python-urllib"), ( assert not ua.lower().startswith(
f"User-Agent must not be Python's default (got {ua!r}) — Cloudflare blocks it" "python-urllib"
) ), f"User-Agent must not be Python's default (got {ua!r}) — Cloudflare blocks it"
assert ua.startswith("ria-agent/") assert ua.startswith("ria-agent/")
@ -96,7 +96,10 @@ def test_register_request_carries_explicit_user_agent(tmp_path):
captured["api_key"] = req.get_header("X-api-key") captured["api_key"] = req.get_header("X-api-key")
captured["timeout"] = kwargs.get("timeout") captured["timeout"] = kwargs.get("timeout")
raise urllib.error.HTTPError( raise urllib.error.HTTPError(
url=req.full_url, code=403, msg="", hdrs=None, # type: ignore[arg-type] url=req.full_url,
code=403,
msg="",
hdrs=None, # type: ignore[arg-type]
fp=BytesIO(_structured("invalid_key")), fp=BytesIO(_structured("invalid_key")),
) )