Some checks failed
Build Sphinx Docs Set / Build Docs (pull_request) Failing after 1s
Build Project / Build Project (3.10) (pull_request) Successful in 57s
Build Project / Build Project (3.11) (pull_request) Successful in 1m7s
Build Project / Build Project (3.12) (pull_request) Successful in 56s
Test with tox / Test with tox (3.12) (pull_request) Failing after 5m13s
Test with tox / Test with tox (3.11) (pull_request) Failing after 5m48s
Test with tox / Test with tox (3.10) (pull_request) Failing after 8m46s
122 lines
3.5 KiB
Python
122 lines
3.5 KiB
Python
"""In-memory state for running campaigns and inference sessions."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Optional
|
|
|
|
|
|
class CampaignCancelledError(Exception):
    """Raised by the progress callback when a cancel is requested."""
|
|
|
|
|
|
@dataclass
class CampaignState:
    """Mutable in-memory record for a single campaign.

    The first five fields are required; the rest default to an empty
    "just started" state. ``started_at`` is stamped with ``time.time()`` when
    the instance is created unless a value is passed explicitly.
    """

    campaign_id: str
    status: str  # "running" | "completed" | "failed" | "cancelled"
    config_name: str
    # NOTE(review): presumably set to ask the worker to stop -- confirm with
    # the code that runs the campaign.
    cancel_event: threading.Event
    thread: threading.Thread
    total_steps: int = 0
    progress: int = 0
    result: Optional[dict] = None
    error: Optional[str] = None
    started_at: float = field(default_factory=time.time)
    ended_at: Optional[float] = None

    def to_dict(self) -> dict:
        """Return the campaign's plain-data fields as a new dict.

        ``cancel_event`` and ``thread`` are deliberately left out; only the
        identifier, status, progress counters, result/error and timestamps
        are included. ``result`` is returned by reference, not copied.
        """
        return {
            "campaign_id": self.campaign_id,
            "status": self.status,
            "config_name": self.config_name,
            "progress": self.progress,
            "total_steps": self.total_steps,
            "result": self.result,
            "error": self.error,
            "started_at": self.started_at,
            "ended_at": self.ended_at,
        }
|
|
|
|
|
|
@dataclass
class InferenceState:
    """State for one loaded inference model and its optional live run.

    ``_latest``, ``_pending_sdr_config`` and ``running`` are meant to be
    accessed through the accessor methods below, each of which holds the
    per-instance ``_lock`` for the duration of the read or write.
    """

    model_path: str
    label_map: dict[str, int]  # class_name -> class_index
    index_to_label: dict[int, str]  # reverse: class_index -> class_name
    session: Any  # onnxruntime.InferenceSession
    stop_event: threading.Event = field(default_factory=threading.Event)
    thread: Optional[threading.Thread] = None
    sdr: Any = None  # live SDR object while inference is running
    running: bool = False
    # Private bookkeeping, hidden from the generated repr.
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
    _latest: Optional[dict] = field(default=None, repr=False)
    _pending_sdr_config: Optional[dict] = field(default=None, repr=False)

    def set_latest(self, result: dict) -> None:
        """Store ``result`` as the most recent result, under the lock."""
        with self._lock:
            self._latest = result

    def get_latest(self) -> Optional[dict]:
        """Return the most recently stored result, or ``None`` if none yet.

        The stored dict itself is returned, not a copy.
        """
        with self._lock:
            return self._latest

    def set_pending_config(self, config: dict) -> None:
        """Record ``config`` as the pending SDR config, replacing any prior one."""
        with self._lock:
            self._pending_sdr_config = config

    def pop_pending_config(self) -> Optional[dict]:
        """Return the pending SDR config and clear it in one locked step.

        Returns ``None`` when nothing is pending.
        """
        with self._lock:
            cfg = self._pending_sdr_config
            self._pending_sdr_config = None
            return cfg

    def set_running(self, value: bool) -> None:
        """Set the ``running`` flag under the lock."""
        with self._lock:
            self.running = value

    def get_running(self) -> bool:
        """Return the ``running`` flag, read under the lock."""
        with self._lock:
            return self.running
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Module-level stores
# ---------------------------------------------------------------------------

# Campaign records keyed by campaign_id. The accessor functions in this module
# hold _campaigns_lock around every read and write of this dict.
_campaigns: dict[str, CampaignState] = {}
_campaigns_lock = threading.Lock()

# The single current inference state, or None when nothing is loaded.
# Read and replaced only while holding _inference_lock.
_inference: Optional[InferenceState] = None
_inference_lock = threading.Lock()
|
|
|
|
|
|
def get_campaign(campaign_id: str) -> Optional[CampaignState]:
    """Return the stored campaign with this ID, or ``None`` if unknown.

    The lookup happens under ``_campaigns_lock``; the returned object is the
    stored instance itself, not a copy.
    """
    with _campaigns_lock:
        return _campaigns.get(campaign_id)
|
|
|
|
|
|
def set_campaign(state: CampaignState) -> None:
    """Store ``state`` under its ``campaign_id``, replacing any existing entry.

    The write happens under ``_campaigns_lock``.
    """
    with _campaigns_lock:
        _campaigns[state.campaign_id] = state
|
|
|
|
|
|
def update_campaign(campaign_id: str, **kwargs: Any) -> None:
    """Set attributes on a stored campaign while holding ``_campaigns_lock``.

    Each keyword argument is applied with ``setattr`` on the stored
    ``CampaignState``. An unknown ``campaign_id`` is ignored silently.

    Keyword names are not validated: a name that is not an existing field is
    still set as a new attribute on the instance.
    """
    with _campaigns_lock:
        state = _campaigns.get(campaign_id)
        if state is None:
            return
        # Apply all updates inside the same critical section so readers going
        # through the locked accessors never see a half-applied update.
        for name, value in kwargs.items():
            setattr(state, name, value)
|
|
|
|
|
|
def get_inference() -> Optional[InferenceState]:
    """Return the current module-level inference state, or ``None`` if unset.

    The read happens under ``_inference_lock``.
    """
    with _inference_lock:
        return _inference
|
|
|
|
|
|
def set_inference(state: Optional[InferenceState]) -> None:
    """Replace the module-level inference state; pass ``None`` to clear it.

    The assignment happens under ``_inference_lock``. The previous state, if
    any, is simply dropped here -- nothing on it is stopped or closed.
    """
    global _inference
    with _inference_lock:
        _inference = state
|