From 0b750136539a01cb98708f6313c537565627c650 Mon Sep 17 00:00:00 2001 From: ben Date: Thu, 4 Jun 2026 15:49:44 -0400 Subject: [PATCH] addressing feedback --- src/ria_toolkit_oss/agent/cli.py | 15 +- src/ria_toolkit_oss_cli/cli.py | 3 +- .../ria_toolkit_oss/_hub_auth.py | 97 +++++++++++ .../ria_toolkit_oss/setup_repo.py | 109 ++---------- .../ria_toolkit_oss/upload.py | 156 +++++------------- tests/agent/test_cli_register_errors.py | 11 +- 6 files changed, 161 insertions(+), 230 deletions(-) create mode 100644 src/ria_toolkit_oss_cli/ria_toolkit_oss/_hub_auth.py diff --git a/src/ria_toolkit_oss/agent/cli.py b/src/ria_toolkit_oss/agent/cli.py index e7626c2..e410be1 100644 --- a/src/ria_toolkit_oss/agent/cli.py +++ b/src/ria_toolkit_oss/agent/cli.py @@ -56,18 +56,9 @@ _REGISTER_TIMEOUT_S = 15 REGISTRATION_REASON_MESSAGES = { - "invalid_key": ( - "Registration key not recognized. Generate a fresh key from " - "Settings → RIA Agents on the hub." - ), - "expired": ( - "This registration key has expired. Generate a new one from " - "Settings → RIA Agents on the hub." - ), - "revoked": ( - "This registration key was revoked. Generate a new one from " - "Settings → RIA Agents on the hub." - ), + "invalid_key": ("Registration key not recognized. Generate a fresh key from " "Settings → RIA Agents on the hub."), + "expired": ("This registration key has expired. Generate a new one from " "Settings → RIA Agents on the hub."), + "revoked": ("This registration key was revoked. Generate a new one from " "Settings → RIA Agents on the hub."), "already_consumed": ( "This single-use registration key has already been used. " "Generate a new one, or mint a reusable key instead." diff --git a/src/ria_toolkit_oss_cli/cli.py b/src/ria_toolkit_oss_cli/cli.py index b24e136..a1d92a5 100644 --- a/src/ria_toolkit_oss_cli/cli.py +++ b/src/ria_toolkit_oss_cli/cli.py @@ -3,6 +3,7 @@ This module contains the main group for the ria toolkit oss CLI. 
""" import subprocess +import sys import warnings warnings.filterwarnings( @@ -49,7 +50,7 @@ def cli(ctx, verbose): err=True, ) if ctx.invoked_subcommand is None: - if lfs_missing: + if lfs_missing and sys.stdin.isatty(): click.pause(info="\nPress Enter to continue...", err=True) click.echo(ctx.get_help()) diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/_hub_auth.py b/src/ria_toolkit_oss_cli/ria_toolkit_oss/_hub_auth.py new file mode 100644 index 0000000..5b62635 --- /dev/null +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/_hub_auth.py @@ -0,0 +1,97 @@ +"""Shared authentication and security helpers for RIA Hub API calls.""" + +import base64 +import subprocess +import urllib.error +import urllib.parse +import urllib.request + +import click + +DEFAULT_HUB = "https://riahub.ai" + + +class _NoRedirectHandler(urllib.request.HTTPRedirectHandler): + """Block redirects on authenticated requests to prevent credential exfiltration. + + urllib re-sends the Authorization header on same-host redirects by default. + A malicious server could redirect a POST to a different host to harvest + credentials. We refuse all redirects — API clients should not encounter them + in normal operation. + """ + + def redirect_request(self, req, fp, code, msg, headers, newurl): + raise urllib.error.URLError(f"Unexpected redirect ({code}) to {newurl} — aborting to protect credentials") + + +def hub_opener() -> urllib.request.OpenerDirector: + """Return a urllib opener that blocks redirects.""" + return urllib.request.build_opener(_NoRedirectHandler) + + +def warn_if_insecure(hub: str) -> None: + """Warn when credentials would be sent over plain HTTP to a non-localhost host.""" + parsed = urllib.parse.urlparse(hub) + if parsed.scheme == "http": + host = parsed.hostname or "" + if host not in ("localhost", "127.0.0.1", "::1"): + click.echo( + f"Warning: sending credentials over plain HTTP to {host}. 
" "Use HTTPS in production.", + err=True, + ) + + +def basic_auth(username: str, password: str) -> str: + return "Basic " + base64.b64encode(f"{username}:{password}".encode()).decode() + + +def get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]: + """Ask git credential fill for stored creds. Returns (username, password) or (None, None).""" + parsed = urllib.parse.urlparse(hub_url) + payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n" + try: + result = subprocess.run( + ["git", "credential", "fill"], + input=payload, + capture_output=True, + text=True, + timeout=5, + ) + creds = {} + for line in result.stdout.splitlines(): + # Partition on the FIRST '=' only so passwords containing '=' are preserved. + k, sep, v = line.partition("=") + if sep: + creds[k.strip()] = v # keep value verbatim + return creds.get("username"), creds.get("password") + except Exception: + return None, None + + +def store_credentials(hub_url: str, username: str, password: str) -> None: + """Cache credentials via git credential approve (uses the system keychain/store).""" + parsed = urllib.parse.urlparse(hub_url) + payload = ( + f"protocol={parsed.scheme}\n" f"host={parsed.netloc}\n" f"username={username}\n" f"password={password}\n\n" + ) + try: + subprocess.run( + ["git", "credential", "approve"], + input=payload, + capture_output=True, + text=True, + timeout=5, + ) + except Exception: + pass # non-fatal — next push just prompts again + + +def resolve_credentials(hub: str) -> tuple[str, str]: + """Return (username, password), prompting interactively if not cached.""" + username, password = get_stored_credentials(hub) + if username and password: + return username, password + click.echo(f"No stored credentials found for {hub}.") + username = click.prompt("RIA Hub username") + password = click.prompt("Password / personal access token", hide_input=True) + return username, password diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/setup_repo.py 
b/src/ria_toolkit_oss_cli/ria_toolkit_oss/setup_repo.py index 34d7496..b826ec3 100644 --- a/src/ria_toolkit_oss_cli/ria_toolkit_oss/setup_repo.py +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/setup_repo.py @@ -1,6 +1,5 @@ """ria setup_repo — create and configure a RIA Hub Project repo.""" -import base64 import json import os import re @@ -12,6 +11,15 @@ import urllib.request import click +from ._hub_auth import ( + DEFAULT_HUB, + _NoRedirectHandler, + basic_auth, + resolve_credentials, + store_credentials, + warn_if_insecure, +) + RIA_LFS_RULES = [ ("*.pt", "filter=lfs diff=lfs merge=lfs -text"), ("*.pth", "filter=lfs diff=lfs merge=lfs -text"), @@ -27,89 +35,15 @@ RIA_LFS_RULES = [ ("*.pkl", "filter=lfs diff=lfs merge=lfs -text"), ] -DEFAULT_HUB = "https://riahub.ai" - # Repo names must be safe directory names and valid git remote path components. _SAFE_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,100}$") -# --------------------------------------------------------------------------- -# Credential helpers — all credential I/O goes through git's own store -# --------------------------------------------------------------------------- - - -def _get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]: - """Ask git credential fill for stored creds. Returns (username, password) or (None, None).""" - parsed = urllib.parse.urlparse(hub_url) - payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n" - try: - result = subprocess.run( - ["git", "credential", "fill"], - input=payload, - capture_output=True, - text=True, - timeout=5, - ) - creds = {} - for line in result.stdout.splitlines(): - # partition on the FIRST '=' only so passwords containing '=' are preserved. - # Only strip whitespace from the key, not the value. 
- k, sep, v = line.partition("=") - if sep: - creds[k.strip()] = v # keep password value verbatim - return creds.get("username"), creds.get("password") - except Exception: - return None, None - - -def _store_credentials(hub_url: str, username: str, password: str) -> None: - """Cache credentials via git credential approve (uses the system keychain/store).""" - parsed = urllib.parse.urlparse(hub_url) - payload = ( - f"protocol={parsed.scheme}\n" f"host={parsed.netloc}\n" f"username={username}\n" f"password={password}\n\n" - ) - try: - subprocess.run( - ["git", "credential", "approve"], - input=payload, - capture_output=True, - text=True, - timeout=5, - ) - except Exception: - pass # non-fatal — next push just prompts again - - -def _resolve_credentials(hub: str) -> tuple[str, str]: - """Return (username, password), prompting interactively if not cached.""" - username, password = _get_stored_credentials(hub) - if username and password: - return username, password - - click.echo(f"No stored credentials found for {hub}.") - username = click.prompt("RIA Hub username") - password = click.prompt("Password / personal access token", hide_input=True) - return username, password - - # --------------------------------------------------------------------------- # API helpers # --------------------------------------------------------------------------- -class _NoRedirectHandler(urllib.request.HTTPRedirectHandler): - """Block redirects on API requests to prevent credential exfiltration. - - urllib follows redirects by default and re-sends the Authorization header - on same-host redirects. A malicious server could redirect a POST to a - different host to harvest credentials. We refuse all redirects — API - clients should not encounter them in normal operation. 
- """ - - def redirect_request(self, req, fp, code, msg, headers, newurl): - raise urllib.error.URLError(f"Unexpected redirect ({code}) to {newurl} — aborting to protect credentials") - - def _api_request( hub: str, path: str, @@ -129,10 +63,7 @@ def _api_request( data = json.dumps(body).encode() if body is not None else None req = urllib.request.Request(url, data=data, method=method) req.add_header("Content-Type", "application/json") - # Credentials are base64-encoded (not encrypted). Callers must ensure - # they only call this over HTTPS or localhost — enforced in _warn_if_insecure. - cred = base64.b64encode(f"{username}:{password}".encode()).decode() - req.add_header("Authorization", f"Basic {cred}") + req.add_header("Authorization", basic_auth(username, password)) opener = urllib.request.build_opener(_NoRedirectHandler) try: @@ -148,18 +79,6 @@ def _api_request( return {"message": str(e.reason)}, 0 -def _warn_if_insecure(hub: str) -> None: - """Warn when sending credentials over plain HTTP to a non-localhost host.""" - parsed = urllib.parse.urlparse(hub) - if parsed.scheme == "http": - host = parsed.hostname or "" - if host not in ("localhost", "127.0.0.1", "::1"): - click.echo( - f"Warning: sending credentials over plain HTTP to {host}. " "Use HTTPS in production.", - err=True, - ) - - def _get_authenticated_username(hub: str, username: str, password: str) -> str | None: """Return the login name of the authenticated user from GET /api/v1/user. @@ -349,7 +268,7 @@ def _configure_remote( click.echo( f"Skipped remote setup. 
Add it manually:\n" f" git -C {repo_path} remote add origin " - f"{hub.rstrip('/')}//{repo_name}.git" + f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git" ) return remote_url = f"{hub.rstrip('/')}/{resolved_owner}/{repo_name}.git" @@ -422,9 +341,9 @@ def setup_repo( sys.exit(1) if not no_remote: - _warn_if_insecure(hub) + warn_if_insecure(hub) - username, password = (None, None) if no_remote else _resolve_credentials(hub) + username, password = (None, None) if no_remote else resolve_credentials(hub) resolved_owner = _resolve_owner(hub, username, password, owner) # newly_created=True means the server ran auto_init+is_ria and seeded @@ -436,7 +355,7 @@ def setup_repo( click.echo(f"Repository '{resolved_owner}/{repo_name}' already exists on RIA Hub.") else: newly_created = _create_repo_on_hub(hub, repo_name, username, password, private) - _store_credentials(hub, username, password) + store_credentials(hub, username, password) if not os.path.exists(repo_path): os.makedirs(repo_path) diff --git a/src/ria_toolkit_oss_cli/ria_toolkit_oss/upload.py b/src/ria_toolkit_oss_cli/ria_toolkit_oss/upload.py index 40ad07d..10f04c7 100644 --- a/src/ria_toolkit_oss_cli/ria_toolkit_oss/upload.py +++ b/src/ria_toolkit_oss_cli/ria_toolkit_oss/upload.py @@ -27,6 +27,14 @@ import urllib.request import click +from ._hub_auth import ( + DEFAULT_HUB, + basic_auth, + hub_opener, + resolve_credentials, + warn_if_insecure, +) + # Read buffer for hashing and streaming — 8 MB keeps memory use flat # for arbitrarily large files. 
_CHUNK = 8 * 1024 * 1024 @@ -34,48 +42,6 @@ _CHUNK = 8 * 1024 * 1024 LFS_MEDIA_TYPE = "application/vnd.git-lfs+json" -# --------------------------------------------------------------------------- -# Credential helpers (reused from setup_repo pattern) -# --------------------------------------------------------------------------- - - -def _get_stored_credentials(hub_url: str) -> tuple[str | None, str | None]: - import subprocess - - parsed = urllib.parse.urlparse(hub_url) - payload = f"protocol={parsed.scheme}\nhost={parsed.netloc}\n\n" - try: - result = subprocess.run( - ["git", "credential", "fill"], - input=payload, - capture_output=True, - text=True, - timeout=5, - ) - creds = {} - for line in result.stdout.splitlines(): - k, sep, v = line.partition("=") - if sep: - creds[k.strip()] = v - return creds.get("username"), creds.get("password") - except Exception: - return None, None - - -def _resolve_credentials(hub: str) -> tuple[str, str]: - username, password = _get_stored_credentials(hub) - if username and password: - return username, password - click.echo(f"No stored credentials found for {hub}.") - username = click.prompt("RIA Hub username") - password = click.prompt("Password / personal access token", hide_input=True) - return username, password - - -def _basic_auth(username: str, password: str) -> str: - return "Basic " + base64.b64encode(f"{username}:{password}".encode()).decode() - - # --------------------------------------------------------------------------- # File helpers # --------------------------------------------------------------------------- @@ -134,64 +100,17 @@ def _lfs_batch( req = urllib.request.Request(url, data=body, method="POST") req.add_header("Content-Type", LFS_MEDIA_TYPE) req.add_header("Accept", LFS_MEDIA_TYPE) - req.add_header("Authorization", _basic_auth(username, password)) + req.add_header("Authorization", basic_auth(username, password)) + opener = hub_opener() try: - with urllib.request.urlopen(req, timeout=30) as resp: + with 
opener.open(req, timeout=30) as resp: return json.loads(resp.read()) except urllib.error.HTTPError as e: body_text = e.read().decode(errors="replace") raise RuntimeError(f"LFS batch request failed (HTTP {e.code}): {body_text}") from e -# --------------------------------------------------------------------------- -# Streaming PUT upload -# --------------------------------------------------------------------------- - - -def _stream_upload(href: str, headers: dict, file_path: str, size: int) -> None: - """ - PUT file_path content to href, streaming in _CHUNK-sized pieces. - Uses http.client directly so Content-Length is set without buffering - the whole file in memory. Works for files of any size. - """ - parsed = urllib.parse.urlparse(href) - host = parsed.netloc - path = parsed.path - if parsed.query: - path += "?" + parsed.query - - if parsed.scheme == "https": - conn = http.client.HTTPSConnection(host, timeout=300) - else: - conn = http.client.HTTPConnection(host, timeout=300) - - all_headers = dict(headers or {}) - all_headers.setdefault("Content-Type", "application/octet-stream") - all_headers["Content-Length"] = str(size) - - try: - conn.connect() - conn.putrequest("PUT", path) - for k, v in all_headers.items(): - conn.putheader(k, v) - conn.endheaders() - - with open(file_path, "rb") as f: - while True: - chunk = f.read(_CHUNK) - if not chunk: - break - conn.send(chunk) - - resp = conn.getresponse() - resp.read() # drain - if resp.status not in (200, 201): - raise RuntimeError(f"LFS object upload failed: HTTP {resp.status}") - finally: - conn.close() - - # --------------------------------------------------------------------------- # Gitea contents API — create / update a file to record the LFS pointer # --------------------------------------------------------------------------- @@ -214,9 +133,9 @@ def _get_file_sha( f"?ref={urllib.parse.quote(branch)}" ) req = urllib.request.Request(url) - req.add_header("Authorization", _basic_auth(username, password)) + 
req.add_header("Authorization", basic_auth(username, password)) try: - with urllib.request.urlopen(req, timeout=15) as resp: + with hub_opener().open(req, timeout=15) as resp: return json.loads(resp.read()).get("sha") except urllib.error.HTTPError as e: if e.code == 404: @@ -255,10 +174,10 @@ def _commit_lfs_pointer( method = "PUT" if existing_sha else "POST" req = urllib.request.Request(url, data=json.dumps(body).encode(), method=method) req.add_header("Content-Type", "application/json") - req.add_header("Authorization", _basic_auth(username, password)) + req.add_header("Authorization", basic_auth(username, password)) try: - with urllib.request.urlopen(req, timeout=30) as resp: + with hub_opener().open(req, timeout=30) as resp: resp.read() except urllib.error.HTTPError as e: body_text = e.read().decode(errors="replace") @@ -375,36 +294,37 @@ def _stream_upload_progress(href: str, headers: dict, file_path: str, size: int) all_headers.setdefault("Content-Type", "application/octet-stream") all_headers["Content-Length"] = str(size) - with click.progressbar( - length=size, - label=" ", - width=40, - show_eta=True, - show_percent=True, - fill_char="█", - empty_char="░", - ) as bar: + try: conn.connect() conn.putrequest("PUT", path_q) for k, v in all_headers.items(): conn.putheader(k, v) conn.endheaders() - with open(file_path, "rb") as f: - while True: - chunk = f.read(_CHUNK) - if not chunk: - break - conn.send(chunk) - bar.update(len(chunk)) + with click.progressbar( + length=size, + label=" ", + width=40, + show_eta=True, + show_percent=True, + fill_char="█", + empty_char="░", + ) as bar: + with open(file_path, "rb") as f: + while True: + chunk = f.read(_CHUNK) + if not chunk: + break + conn.send(chunk) + bar.update(len(chunk)) resp = conn.getresponse() resp.read() + if resp.status not in (200, 201): + raise RuntimeError(f"HTTP {resp.status}") + finally: conn.close() - if resp.status not in (200, 201): - raise RuntimeError(f"HTTP {resp.status}") - # 
--------------------------------------------------------------------------- # Command @@ -416,7 +336,7 @@ def _stream_upload_progress(href: str, headers: dict, file_path: str, size: int) @click.option( "--repo", required=True, metavar="OWNER/NAME", help="Target repository on RIA Hub (e.g. benchinnery/my-dataset)." ) -@click.option("--hub", default="https://riahub.ai", show_default=True, metavar="URL", help="RIA Hub base URL.") +@click.option("--hub", default=DEFAULT_HUB, show_default=True, metavar="URL", help="RIA Hub base URL.") @click.option("--branch", default="main", show_default=True, help="Branch to commit the files to.") @click.option( "--path", @@ -460,9 +380,9 @@ def upload( sys.exit(1) resolved.append(os.path.abspath(pattern)) - # Credentials hub = hub.rstrip("/") - username, password = _resolve_credentials(hub) + warn_if_insecure(hub) + username, password = resolve_credentials(hub) click.echo(f"Uploading {len(resolved)} file(s) to {owner}/{repo_name} on {hub}...") diff --git a/tests/agent/test_cli_register_errors.py b/tests/agent/test_cli_register_errors.py index 657dfb0..3448ffc 100644 --- a/tests/agent/test_cli_register_errors.py +++ b/tests/agent/test_cli_register_errors.py @@ -79,9 +79,9 @@ def test_user_agent_is_set_and_not_python_default(): """ ua = agent_cli._user_agent() assert ua, "User-Agent must not be empty" - assert not ua.lower().startswith("python-urllib"), ( - f"User-Agent must not be Python's default (got {ua!r}) — Cloudflare blocks it" - ) + assert not ua.lower().startswith( + "python-urllib" + ), f"User-Agent must not be Python's default (got {ua!r}) — Cloudflare blocks it" assert ua.startswith("ria-agent/") @@ -96,7 +96,10 @@ def test_register_request_carries_explicit_user_agent(tmp_path): captured["api_key"] = req.get_header("X-api-key") captured["timeout"] = kwargs.get("timeout") raise urllib.error.HTTPError( - url=req.full_url, code=403, msg="", hdrs=None, # type: ignore[arg-type] + url=req.full_url, + code=403, + msg="", + 
hdrs=None, # type: ignore[arg-type] fp=BytesIO(_structured("invalid_key")), )