# Forked from qoherent/modrec-workflow (original file: 63 lines, 1.6 KiB, Python).
import os
|
|
from dataclasses import dataclass
|
|
from functools import lru_cache
|
|
|
|
import yaml
|
|
|
|
@dataclass()
class GeneralConfig:
    """Settings parsed from the ``general`` section of the app configuration."""

    # Run mode selector; the set of accepted values is not validated here.
    run_mode: str


@dataclass()
class DataSetConfig:
    """Settings parsed from the ``dataset`` section of the app configuration.

    The field order below is also the positional order of the generated
    ``__init__``; keep it stable if callers construct this positionally.
    """

    input_dir: str
    num_slices: int
    train_split: float  # presumably a fraction in [0, 1]; not validated here
    seed: int
    modulation_types: list  # element type is not enforced by the dataclass
    val_split: float  # presumably a fraction in [0, 1]; not validated here
    output_dir: str


@dataclass()
class TrainingConfig:
    """Settings parsed from the ``training`` section of the app configuration.

    Field order defines the positional order of the generated ``__init__``.
    """

    batch_size: int
    epochs: int
    learning_rate: float
    checkpoint_path: str
    use_gpu: bool  # stored as given; no device availability check happens here


@dataclass()
class InferenceConfig:
    """Settings parsed from the ``inference`` section of the app configuration.

    Field order defines the positional order of the generated ``__init__``.
    """

    model_path: str
    num_classes: int
    output_path: str


@dataclass()
class AppConfig:
    """Settings parsed from the ``app`` section of the app configuration."""

    build_dir: str


class AppSettings:
    """Application settings, initialized from the ``app.yaml`` configuration file.

    The YAML document must be a mapping whose top-level sections ``general``,
    ``dataset``, ``training``, ``inference`` and ``app`` are each unpacked as
    keyword arguments into the matching config dataclass.

    Attributes:
        general: ``GeneralConfig`` built from the ``general`` section.
        dataset: ``DataSetConfig`` built from the ``dataset`` section.
        training: ``TrainingConfig`` built from the ``training`` section.
        inference: ``InferenceConfig`` built from the ``inference`` section.
        app: ``AppConfig`` built from the ``app`` section.
    """

    def __init__(self, config_file: str) -> None:
        """Load *config_file* and parse each section into its dataclass.

        Args:
            config_file: Path to the YAML configuration file.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            TypeError: If the document is not a mapping, or a section is not
                a mapping whose keys match the fields of its dataclass.
            KeyError: If a required top-level section is missing.
        """
        # YAML text is UTF-8 by spec; don't let decoding depend on the
        # platform's locale encoding (e.g. cp1252 on Windows).
        with open(config_file, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f)

        # safe_load returns None for an empty file and a list/scalar for other
        # top-level nodes. Report that clearly instead of an opaque subscript
        # error; the exception type (TypeError) is unchanged.
        if not isinstance(config_data, dict):
            raise TypeError(
                f"{config_file}: expected a mapping at the top level of the "
                f"configuration, got {type(config_data).__name__}"
            )

        # Parse the loaded YAML into dataclass objects.
        self.general = GeneralConfig(**self._section(config_data, "general", config_file))
        self.dataset = DataSetConfig(**self._section(config_data, "dataset", config_file))
        self.training = TrainingConfig(**self._section(config_data, "training", config_file))
        self.inference = InferenceConfig(**self._section(config_data, "inference", config_file))
        self.app = AppConfig(**self._section(config_data, "app", config_file))

    @staticmethod
    def _section(config_data, name, config_file):
        """Return section *name* of *config_data*; raise a descriptive KeyError if absent."""
        try:
            return config_data[name]
        except KeyError:
            # Keep KeyError for callers that catch it, but say which file and section.
            raise KeyError(f"{config_file}: missing required section {name!r}") from None

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(general={self.general!r}, "
            f"dataset={self.dataset!r}, training={self.training!r}, "
            f"inference={self.inference!r}, app={self.app!r})"
        )


@lru_cache()
def get_app_settings() -> AppSettings:
    """Load the application settings once and return the memoized instance.

    The configuration file is ``conf/app.yaml`` inside the parent of the
    directory that contains this module. Because the result is cached, every
    call returns the same ``AppSettings`` object. Changes to the YAML file are
    not seen until ``get_app_settings.cache_clear()`` is called.
    """
    this_file = os.path.abspath(__file__)
    parent_dir = os.path.dirname(os.path.dirname(this_file))
    yaml_path = os.path.join(parent_dir, "conf", "app.yaml")
    return AppSettings(config_file=yaml_path)


if __name__ == "__main__":
    # Smoke check: load the configuration and show each parsed section, so a
    # broken or incomplete app.yaml is immediately visible.
    settings = get_app_settings()
    for section_name, section in vars(settings).items():
        print(f"{section_name}: {section}")