forked from qoherent/modrec-workflow
Added an app.yaml file, as well as a Python script that creates a settings object that can be used in all scripts.
parent 92c8c4678e · commit 6c1164b466
helpers/app_settings.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+import os
+from dataclasses import dataclass
+from functools import lru_cache
+
+import yaml
+
+@dataclass
+class GeneralConfig:
+    run_mode: str
+
+@dataclass
+class DataSetConfig:
+    input_dir: str
+    num_slices: int
+    train_split: float
+    seed: int
+    val_split: float
+    output_dir: str
+
+@dataclass
+class TrainingConfig:
+    batch_size: int
+    num_epochs: int
+    learning_rate: float
+    checkpoint_path: str
+    use_gpu: bool
+
+@dataclass
+class InferenceConfig:
+    model_path: str
+    num_classes: int
+    output_dir: str
+
+@dataclass
+class AppConfig:
+    build_dir: str
+
+class AppSettings:
+    """Application settings, to be initialized from app.yaml configuration file."""
+
+    def __init__(self, config_file: str):
+        # Load the YAML configuration file
+        with open(config_file, "r") as f:
+            config_data = yaml.safe_load(f)
+
+        # Parse the loaded YAML into dataclass objects
+        self.general = GeneralConfig(**config_data["general"])
+        self.dataset = DataSetConfig(**config_data["dataset"])
+        self.training = TrainingConfig(**config_data["training"])
+        self.inference = InferenceConfig(**config_data["inference"])
+        self.app = AppConfig(**config_data["app"])
+
+@lru_cache
+def get_app_settings() -> AppSettings:
+    """Return application configuration settings."""
+    module_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+    config_file = os.path.join(module_path, "conf", "app.yaml")
+    return AppSettings(config_file=config_file)
+
+
+if __name__ == "__main__":
+    s = get_app_settings()
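The conf/app.yaml file itself is not reproduced in this diff. Based on the dataclasses above, a matching configuration would presumably be structured along these lines. The section and key names follow the dataclass fields; the dataset values mirror the previously hard-coded defaults in the generation script (8 snippets, 0.8 train split, seed 42, "recordings" in, "data" out), and everything else is an illustrative placeholder, not a value taken from the commit:

# illustrative sketch of conf/app.yaml; non-dataset values are placeholders
general:
  run_mode: development

dataset:
  input_dir: recordings   # matches the old hard-coded call generate_datasets("recordings", "data")
  num_slices: 8           # matches the old fixed 8-way split
  train_split: 0.8        # matches the old train_frac=0.8
  seed: 42                # matches the old seed=42
  val_split: 0.2
  output_dir: data

training:
  batch_size: 32
  num_epochs: 10
  learning_rate: 0.001
  checkpoint_path: checkpoints/model.pt
  use_gpu: true

inference:
  model_path: checkpoints/model.pt
  num_classes: 10
  output_dir: results

app:
  build_dir: build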
@@ -1,6 +1,7 @@
 import os, h5py, numpy as np
 from utils.io import from_npy
 from split_dataset import split
+from helpers.app_settings import get_app_settings

 meta_dtype = np.dtype(
     [
@@ -76,7 +77,7 @@ def write_hdf5_file(records, output_path, dataset_name="data"):
     return output_path


-def split_recording(recording_list):
+def split_recording(recording_list, num_snippets):
     """
     Splits a list of recordings into smaller chunks.

@@ -90,9 +91,8 @@ def split_recording(recording_list):

     for data, md in recording_list:
         C, N = data.shape
-        L = N // 8
         rec_id = md["rec_id"]
-        for i in range(8):
+        L = N // num_snippets
+        for i in range(num_snippets):
             start = i * L
             end = (i + 1) * L
             snippet = data[:, start:end]
@@ -103,7 +103,7 @@ def split_recording(recording_list):
     return snippet_list


-def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
+def generate_datasets(cfg):
     """
     Generates a dataset from a folder of .npy files and saves it to an HDF5 file

@@ -116,18 +116,18 @@ def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
         dset (h5py.Dataset): The created dataset object
     """

-    parent = os.path.dirname(output_path)
+    parent = os.path.dirname(cfg.output_dir)
     if parent:
         os.makedirs(parent, exist_ok=True)

     # we assume the recordings are in .npy format
-    files = os.listdir(path_to_recordings)
+    files = os.listdir(cfg.input_dir)
     if not files:
         raise ValueError("No files found in the specified directory.")

     records = []
     for fname in files:
-        rec = from_npy(os.path.join(path_to_recordings, fname))
+        rec = from_npy(os.path.join(cfg.input_dir, fname))

         data = rec.data

@@ -136,18 +136,24 @@ def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
         records.append((data, md))

     # split each recording into 8 snippets each
-    records = split_recording(records)
+    records = split_recording(records, cfg.num_slices)

-    train_records, val_records = split(records, train_frac=0.8, seed=42)
+    train_records, val_records = split(records, cfg.train_split, cfg.seed)

-    train_path = os.path.join(output_path, "train.h5")
-    val_path = os.path.join(output_path, "val.h5")
+    train_path = os.path.join(cfg.output_dir, "train.h5")
+    val_path = os.path.join(cfg.output_dir, "val.h5")

     write_hdf5_file(train_records, train_path, "training_data")
     write_hdf5_file(val_records, val_path, "validation_data")

     return train_path, val_path

+def main():
+    settings = get_app_settings()
+    dataset_cfg = settings.dataset
+    train_path, val_path = generate_datasets(dataset_cfg)
+    print(f"✅ Train: {train_path}\n✅ Val: {val_path}")
+

 if __name__ == "__main__":
-    print(generate_datasets("recordings", "data"))
+    main()