ria-toolkit-oss/src/ria_toolkit_oss/server/models.py
ben 07c72294f5
All checks were successful
Build Sphinx Docs Set / Build Docs (pull_request) Successful in 19s
Test with tox / Test with tox (3.12) (pull_request) Successful in 10m47s
Test with tox / Test with tox (3.11) (pull_request) Successful in 15m47s
Build Project / Build Project (3.12) (pull_request) Successful in 15m55s
Build Project / Build Project (3.11) (pull_request) Successful in 16m46s
Build Project / Build Project (3.10) (pull_request) Successful in 16m49s
Test with tox / Test with tox (3.10) (pull_request) Successful in 18m15s
removing orchestrator references
2026-04-22 10:10:25 -04:00

115 lines
2.9 KiB
Python

"""Pydantic request and response models for the RT-OSS HTTP server."""
from __future__ import annotations
from pathlib import Path
from pydantic import BaseModel, field_validator
# ---------------------------------------------------------------------------
# Conductor
# ---------------------------------------------------------------------------
class DeployRequest(BaseModel):
    """Payload for deploying a campaign.

    ``config`` is accepted as an arbitrary, unvalidated dict.
    """

    config: dict
class DeployResponse(BaseModel):
    """Reply to a deploy request, identifying the deployed campaign."""

    campaign_id: str
class CampaignStatusResponse(BaseModel):
    """Snapshot of a campaign's state and progress."""

    campaign_id: str
    status: str
    config_name: str
    # Step counter; presumably counts up toward ``total_steps`` (confirm with producer).
    progress: int
    total_steps: int
    # Timestamp as a float — presumably epoch seconds; verify against the producer.
    started_at: float
    # The remaining fields are optional and default to None.
    ended_at: float | None = None
    result: dict | None = None
    error: str | None = None
class CancelResponse(BaseModel):
    """Outcome of a request to cancel a campaign."""

    campaign_id: str
    cancelled: bool
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
class SdrConfig(BaseModel):
    """Radio (SDR) settings supplied when starting inference."""

    device: str
    center_freq: float
    sample_rate: float
    # Either a numeric gain or a string value; defaults to the string "auto".
    gain: float | str = "auto"
class LoadModelRequest(BaseModel):
    """Request to load an ONNX classifier model.

    ``model_path`` is validated and normalized to an absolute,
    symlink-resolved path. ``label_map`` maps each class name to its
    integer class index.
    """

    model_path: str
    label_map: dict[str, int]  # class_name -> class_index

    @field_validator("model_path")
    @classmethod
    def validate_model_path(cls, v: str) -> str:
        """Validate ``model_path`` and return its resolved absolute form.

        Raises:
            ValueError: if the path contains a ``..`` component, does not end
                in ``.onnx`` either as given or after symlink resolution, or
                cannot be resolved (e.g. a symlink loop).
        """
        p = Path(v)
        if ".." in p.parts:
            raise ValueError("model_path must not contain path traversal components")
        if p.suffix.lower() != ".onnx":
            raise ValueError("model_path must point to an .onnx file")
        # Follow symlinks so that a ``*.onnx`` link cannot stand in for a file
        # of another type, and so callers get the real filesystem location.
        # NOTE: absolute paths are accepted and no base directory is enforced
        # here, so this is not a sandbox against reading arbitrary .onnx files.
        try:
            resolved = p.resolve()
        except (OSError, RuntimeError) as exc:
            # Before Python 3.13, resolve() raises RuntimeError on symlink loops.
            # Pydantic only turns ValueError/AssertionError into validation
            # errors, so convert here rather than leak an unexpected exception.
            # The message is kept generic to avoid echoing filesystem details.
            raise ValueError("model_path could not be resolved") from exc
        if resolved.suffix.lower() != ".onnx":
            raise ValueError("Resolved model_path must point to an .onnx file")
        return str(resolved)
class LoadModelResponse(BaseModel):
    """Result of a model-load request."""

    loaded: bool
    model_path: str
    num_classes: int
class StartInferenceRequest(BaseModel):
    """Request to begin inference using the given radio settings."""

    sdr_config: SdrConfig
class StartInferenceResponse(BaseModel):
    """Reply to a start-inference request."""

    running: bool
class StopInferenceResponse(BaseModel):
    """Reply to a stop-inference request."""

    stopped: bool
class ConfigureRequest(BaseModel):
    """Partial SDR reconfiguration; fields that are not supplied stay unchanged.

    Every field defaults to None, meaning "not supplied".
    """

    center_freq: float | None = None
    sample_rate: float | None = None
    gain: float | str | None = None
class ConfigureResponse(BaseModel):
    """Reply to an SDR reconfiguration request."""

    configured: bool
class InferenceStatusResponse(BaseModel):
    """Most recent inference result, as served by GET /inference/status.

    ``idle`` is True while the radio is scanning without having detected a
    signal; ``device_id`` is then None. Otherwise ``device_id`` carries the
    raw prediction label from the model's label map. Translating that label
    into a human-readable name, and deciding whether the device is
    authorized, is left to the frontend.
    """

    timestamp: float
    idle: bool = False
    device_id: str | None = None  # raw prediction label; None when idle
    confidence: float = 0.0
    snr_db: float = 0.0