ria-toolkit-oss/src/ria_toolkit_oss/server/state.py

102 lines
2.9 KiB
Python
Raw Normal View History

2026-03-11 10:27:18 -04:00
"""In-memory state for running campaigns and inference sessions."""
from __future__ import annotations

import threading
import time
from dataclasses import dataclass, field, fields
from typing import Any, Optional
class CampaignCancelledError(Exception):
    """Signals that cancellation of a campaign has been requested.

    The progress callback raises this once a cancel request is observed.
    """
@dataclass
class CampaignState:
    """Bookkeeping for a single campaign tracked by this module.

    ``cancel_event`` and ``thread`` are live runtime handles. They are
    deliberately left out of :meth:`to_dict`; every other field is exported.
    """

    campaign_id: str
    status: str  # one of "running" | "completed" | "failed" | "cancelled"
    config_name: str
    cancel_event: threading.Event
    thread: threading.Thread
    total_steps: int = 0
    progress: int = 0
    result: Optional[dict] = None
    error: Optional[str] = None
    started_at: float = field(default_factory=time.time)  # time.time() at creation
    ended_at: Optional[float] = None

    def to_dict(self) -> dict:
        """Return the serialisable fields as a plain dict, in a fixed key order.

        ``result`` is returned by reference, not copied.
        """
        exported = (
            "campaign_id",
            "status",
            "config_name",
            "progress",
            "total_steps",
            "result",
            "error",
            "started_at",
            "ended_at",
        )
        return {key: getattr(self, key) for key in exported}
@dataclass
class InferenceState:
    """State of the currently loaded inference model and its worker controls.

    ``label_map`` and ``index_to_label`` hold the same class-name/class-index
    mapping in opposite directions. The latest result lives behind a private
    lock, so :meth:`set_latest` and :meth:`get_latest` may be called from
    different threads.
    """

    model_path: str
    label_map: dict[str, int]  # class_name -> class_index
    index_to_label: dict[int, str]  # reverse: class_index -> class_name
    session: Any  # onnxruntime.InferenceSession
    stop_event: threading.Event = field(default_factory=threading.Event)
    thread: Optional[threading.Thread] = None
    running: bool = False
    # Private members, hidden from repr: the lock guards _latest.
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
    _latest: Optional[dict] = field(default=None, repr=False)

    def set_latest(self, result: dict) -> None:
        """Store *result* (by reference) as the newest inference output."""
        with self._lock:
            self._latest = result

    def get_latest(self) -> Optional[dict]:
        """Return the output last stored by set_latest, or None if none yet."""
        with self._lock:
            current = self._latest
        return current
# ---------------------------------------------------------------------------
# Module-level stores, each paired with the lock that guards it
# ---------------------------------------------------------------------------
# campaign_id -> CampaignState, populated by set_campaign().
_campaigns: dict[str, CampaignState] = dict()
_campaigns_lock = threading.Lock()
# The single active inference session; None when none has been set.
_inference: Optional[InferenceState] = None
_inference_lock = threading.Lock()
def get_campaign(campaign_id: str) -> Optional[CampaignState]:
    """Look up a campaign by id, returning None if it is not registered."""
    with _campaigns_lock:
        found = _campaigns.get(campaign_id, None)
    return found
def set_campaign(state: CampaignState) -> None:
    """Register *state* under its campaign_id, replacing any previous entry."""
    with _campaigns_lock:
        key = state.campaign_id
        _campaigns[key] = state
def update_campaign(campaign_id: str, **kwargs) -> None:
    """Overwrite fields of a registered campaign while holding the store lock.

    Each keyword names a dataclass field of the stored state, e.g.
    ``update_campaign(cid, status="completed", ended_at=time.time())``.
    A call for an unregistered ``campaign_id`` is a silent no-op.

    Raises:
        AttributeError: if any keyword is not a dataclass field of the stored
            state. Every name is checked before anything is written, so a
            rejected call leaves the state untouched. Without this check a
            typo would be setattr-ed as a stray attribute that to_dict()
            never reports.
    """
    with _campaigns_lock:
        state = _campaigns.get(campaign_id)
        if state is None:
            return
        known = {f.name for f in fields(state)}
        unknown = sorted(set(kwargs) - known)
        if unknown:
            raise AttributeError(
                f"{type(state).__name__} has no field(s): {', '.join(unknown)}"
            )
        for name, value in kwargs.items():
            setattr(state, name, value)
def get_inference() -> Optional[InferenceState]:
    """Return the active inference session, or None if none is set."""
    with _inference_lock:
        current = _inference
    return current
def set_inference(state: Optional[InferenceState]) -> None:
    """Make *state* the active inference session; pass None to clear it."""
    global _inference
    _inference_lock.acquire()
    try:
        _inference = state
    finally:
        _inference_lock.release()