102 lines
2.9 KiB
Python
102 lines
2.9 KiB
Python
|
|
"""In-memory state for running campaigns and inference sessions."""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import threading
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from typing import Any, Optional
|
||
|
|
|
||
|
|
|
||
|
|
class CampaignCancelledError(Exception):
    """Signal that a campaign run should stop because a cancel was requested.

    Raised from the progress callback once cancellation has been requested.
    """
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class CampaignState:
    """Mutable record for one campaign run and the thread executing it.

    Fields are updated in place while the campaign runs. ``to_dict`` builds
    the plain-dict view; it does not include ``cancel_event`` or ``thread``.
    """

    campaign_id: str
    # One of "running" | "completed" | "failed" | "cancelled".
    status: str
    config_name: str
    # Presumably set to request cancellation of the run -- confirm with callers.
    cancel_event: threading.Event
    thread: threading.Thread
    # NOTE(review): presumably progress counts finished steps out of
    # total_steps; both start at 0 -- confirm against the worker.
    total_steps: int = 0
    progress: int = 0
    result: Optional[dict] = None
    error: Optional[str] = None
    # Wall-clock timestamp (time.time()) taken when the record is created.
    started_at: float = field(default_factory=time.time)
    # None until a caller records the end time.
    ended_at: Optional[float] = None

    def to_dict(self) -> dict:
        """Return a dict snapshot of the campaign's public fields.

        The key order is fixed. ``result`` is returned by reference, not copied.
        """
        exported_keys = (
            "campaign_id",
            "status",
            "config_name",
            "progress",
            "total_steps",
            "result",
            "error",
            "started_at",
            "ended_at",
        )
        return {key: getattr(self, key) for key in exported_keys}
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class InferenceState:
    """A loaded inference model plus the most recent result it produced.

    ``set_latest`` and ``get_latest`` access the stored result while holding
    an internal lock.
    """

    model_path: str
    # Maps class_name -> class_index.
    label_map: dict[str, int]
    # Reverse lookup of label_map: class_index -> class_name.
    index_to_label: dict[int, str]
    # Expected to be an onnxruntime.InferenceSession.
    session: Any
    stop_event: threading.Event = field(default_factory=threading.Event)
    thread: Optional[threading.Thread] = None
    running: bool = False
    # Guards _latest; excluded from repr.
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
    # Most recent result; None until set_latest is first called.
    _latest: Optional[dict] = field(default=None, repr=False)

    def set_latest(self, result: dict) -> None:
        """Store *result* (by reference) as the latest result."""
        with self._lock:
            self._latest = result

    def get_latest(self) -> Optional[dict]:
        """Return the latest stored result (not a copy), or None if unset."""
        with self._lock:
            snapshot = self._latest
        return snapshot
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# Module-level stores
# ---------------------------------------------------------------------------

# Campaign registry keyed by campaign_id; the accessor functions in this
# module read and write it while holding _campaigns_lock.
_campaigns: dict[str, CampaignState] = dict()
_campaigns_lock = threading.Lock()

# The current inference state (None when unset); the accessor functions in
# this module read and write it while holding _inference_lock.
_inference: Optional[InferenceState] = None
_inference_lock = threading.Lock()
|
||
|
|
|
||
|
|
|
||
|
|
def get_campaign(campaign_id: str) -> Optional[CampaignState]:
    """Return the campaign registered under *campaign_id*, or None if absent."""
    with _campaigns_lock:
        found = _campaigns.get(campaign_id)
    return found
|
||
|
|
|
||
|
|
|
||
|
|
def set_campaign(state: CampaignState) -> None:
    """Register *state* under its campaign_id, replacing any previous entry."""
    with _campaigns_lock:
        key = state.campaign_id
        _campaigns[key] = state
|
||
|
|
|
||
|
|
|
||
|
|
def update_campaign(campaign_id: str, **kwargs) -> None:
    """Set one or more fields on a registered campaign, atomically.

    Args:
        campaign_id: id of the campaign to update. If no campaign is
            registered under this id, the call does nothing.
        **kwargs: field name -> new value. Every name must be a dataclass
            field of the stored state.

    Raises:
        TypeError: if any keyword is not a field of the stored state. Nothing
            is modified in that case.
    """
    from dataclasses import fields

    with _campaigns_lock:
        state = _campaigns.get(campaign_id)
        if state is None:
            return
        # Validate before mutating. Otherwise a misspelt name (e.g. "progres")
        # would be set as a stray attribute that to_dict never reports, and
        # the intended update would be silently lost.
        allowed = {f.name for f in fields(state)}
        unknown = sorted(set(kwargs) - allowed)
        if unknown:
            raise TypeError(
                f"update_campaign() got unknown {type(state).__name__} "
                f"field(s): {', '.join(unknown)}"
            )
        for name, value in kwargs.items():
            setattr(state, name, value)
|
||
|
|
|
||
|
|
|
||
|
|
def get_inference() -> Optional[InferenceState]:
    """Return the current inference state, or None if none is set."""
    with _inference_lock:
        current = _inference
    return current
|
||
|
|
|
||
|
|
|
||
|
|
def set_inference(state: Optional[InferenceState]) -> None:
    """Replace the current inference state with *state*; pass None to clear it."""
    global _inference  # rebinding the module-level slot, not a local
    with _inference_lock:
        _inference = state
|