ria-toolkit-oss/src/ria_toolkit_oss/server/state.py

122 lines
3.5 KiB
Python
Raw Normal View History

2026-03-11 10:27:18 -04:00
"""In-memory state for running campaigns and inference sessions."""
from __future__ import annotations
import threading
import time
from dataclasses import dataclass, field
from typing import Any, Optional
class CampaignCancelledError(Exception):
    """Signal that a campaign was asked to stop.

    The progress callback raises this once a cancel has been requested.
    """
@dataclass
class CampaignState:
    """Bookkeeping record for a single campaign run.

    ``started_at`` defaults to ``time.time()`` at construction. The
    ``cancel_event`` and ``thread`` fields are runtime handles and are
    deliberately left out of :meth:`to_dict`.
    """

    campaign_id: str
    status: str  # one of: "running", "completed", "failed", "cancelled"
    config_name: str
    cancel_event: threading.Event  # presumably set to request cancellation
    thread: threading.Thread
    total_steps: int = 0
    progress: int = 0
    result: Optional[dict] = None
    error: Optional[str] = None
    started_at: float = field(default_factory=time.time)
    ended_at: Optional[float] = None

    def to_dict(self) -> dict:
        """Return the plain-data fields of this record as a dict.

        Keys are emitted in a fixed order; the event and thread handles
        are not included.
        """
        exported = (
            "campaign_id",
            "status",
            "config_name",
            "progress",
            "total_steps",
            "result",
            "error",
            "started_at",
            "ended_at",
        )
        return {name: getattr(self, name) for name in exported}
@dataclass
class InferenceState:
    """State held for a loaded inference model and its worker.

    The accessor methods below take ``_lock`` around every read or write of
    ``_latest``, ``_pending_sdr_config`` and ``running``.
    """

    model_path: str
    label_map: dict[str, int]  # maps class name to class index
    index_to_label: dict[int, str]  # inverse mapping: class index to class name
    session: Any  # an onnxruntime.InferenceSession
    stop_event: threading.Event = field(default_factory=threading.Event)
    thread: Optional[threading.Thread] = None
    sdr: Any = None  # SDR object kept here while inference is live
    running: bool = False
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
    _latest: Optional[dict] = field(default=None, repr=False)
    _pending_sdr_config: Optional[dict] = field(default=None, repr=False)

    def set_latest(self, result: dict) -> None:
        """Store ``result`` as the most recent result, replacing any prior one."""
        with self._lock:
            self._latest = result

    def get_latest(self) -> Optional[dict]:
        """Return the most recently stored result, or ``None`` if none was set."""
        with self._lock:
            snapshot = self._latest
        return snapshot

    def set_pending_config(self, config: dict) -> None:
        """Stash an SDR config until ``pop_pending_config`` collects it."""
        with self._lock:
            self._pending_sdr_config = config

    def pop_pending_config(self) -> Optional[dict]:
        """Take and clear the stashed SDR config; ``None`` when nothing is pending."""
        with self._lock:
            # Read and reset under one lock hold so a config is handed out once.
            taken, self._pending_sdr_config = self._pending_sdr_config, None
        return taken

    def set_running(self, value: bool) -> None:
        """Update the ``running`` flag while holding the lock."""
        with self._lock:
            self.running = value

    def get_running(self) -> bool:
        """Read the ``running`` flag while holding the lock."""
        with self._lock:
            flag = self.running
        return flag
2026-03-11 10:27:18 -04:00
# ---------------------------------------------------------------------------
# Module-level stores
# ---------------------------------------------------------------------------
# Registry of campaign records keyed by campaign_id. The accessor functions
# in this module hold _campaigns_lock while reading or writing it.
_campaigns: dict[str, CampaignState] = {}
_campaigns_lock = threading.Lock()
# The single current inference state, or None when none has been set. The
# accessor functions in this module hold _inference_lock while touching it.
_inference: Optional[InferenceState] = None
_inference_lock = threading.Lock()
def get_campaign(campaign_id: str) -> Optional[CampaignState]:
    """Look up a campaign record by id.

    Returns the stored ``CampaignState`` or ``None`` when the id is unknown.
    The lookup runs while holding ``_campaigns_lock``.
    """
    with _campaigns_lock:
        found = _campaigns.get(campaign_id)
    return found
def set_campaign(state: CampaignState) -> None:
    """Register ``state`` under its own ``campaign_id``.

    An existing entry with the same id is overwritten. The write runs while
    holding ``_campaigns_lock``.
    """
    key = state.campaign_id
    with _campaigns_lock:
        _campaigns[key] = state
def update_campaign(campaign_id: str, **kwargs) -> None:
    """Assign the given keyword values onto a stored campaign record.

    Nothing happens when ``campaign_id`` is not registered. Values are applied
    with ``setattr`` and no validation, so a keyword that is not an existing
    field simply becomes a new attribute on the record. The whole update runs
    while holding ``_campaigns_lock``.
    """
    with _campaigns_lock:
        target = _campaigns.get(campaign_id)
        if not target:
            return
        for attr_name, new_value in kwargs.items():
            setattr(target, attr_name, new_value)
def get_inference() -> Optional[InferenceState]:
    """Return the module's current inference state, or ``None`` if unset.

    The read runs while holding ``_inference_lock``.
    """
    with _inference_lock:
        current = _inference
    return current
def set_inference(state: Optional[InferenceState]) -> None:
    """Replace the module's current inference state.

    Passing ``None`` clears it. The assignment runs while holding
    ``_inference_lock``.
    """
    global _inference
    _inference_lock.acquire()
    try:
        _inference = state
    finally:
        _inference_lock.release()