forked from qoherent/modrec-workflow

Added an app.yaml file, as well as a Python script that creates a settings object that can be used in all scripts.

parent 92c8c4678e
commit 6c1164b466

helpers/app_settings.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import os
from dataclasses import dataclass
from functools import lru_cache

import yaml


@dataclass
class GeneralConfig:
    run_mode: str


@dataclass
class DataSetConfig:
    input_dir: str
    num_slices: int
    train_split: float
    seed: int
    val_split: float
    output_dir: str


@dataclass
class TrainingConfig:
    batch_size: int
    num_epochs: int
    learning_rate: float
    checkpoint_path: str
    use_gpu: bool


@dataclass
class InferenceConfig:
    model_path: str
    num_classes: int
    output_dir: str


@dataclass
class AppConfig:
    build_dir: str


class AppSettings:
    """Application settings, to be initialized from app.yaml configuration file."""

    def __init__(self, config_file: str):
        # Load the YAML configuration file
        with open(config_file, "r") as f:
            config_data = yaml.safe_load(f)

        # Parse the loaded YAML into dataclass objects
        self.general = GeneralConfig(**config_data["general"])
        self.dataset = DataSetConfig(**config_data["dataset"])
        self.training = TrainingConfig(**config_data["training"])
        self.inference = InferenceConfig(**config_data["inference"])
        self.app = AppConfig(**config_data["app"])


@lru_cache
def get_app_settings() -> AppSettings:
    """Return application configuration settings."""
    module_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    config_file = os.path.join(module_path, "conf", "app.yaml")
    return AppSettings(config_file=config_file)


if __name__ == "__main__":
    s = get_app_settings()
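The dataclasses above fix the schema that conf/app.yaml must follow. Below is a sketch of such a file and of loading it through AppSettings. The section and key names are dictated by the dataclasses; the values for num_slices, train_split, and seed mirror the hard-coded 8, 0.8, and 42 that this commit removes from the dataset script, and every other value is a placeholder.

import os
import tempfile

from helpers.app_settings import AppSettings

# Hypothetical app.yaml matching the dataclass fields above;
# only the keys are fixed by the code, the values are illustrative.
YAML_SKETCH = """\
general:
  run_mode: development
dataset:
  input_dir: recordings
  num_slices: 8
  train_split: 0.8
  seed: 42
  val_split: 0.2
  output_dir: data
training:
  batch_size: 32
  num_epochs: 10
  learning_rate: 0.001
  checkpoint_path: checkpoints/model.pt
  use_gpu: false
inference:
  model_path: checkpoints/model.pt
  num_classes: 10
  output_dir: results
app:
  build_dir: build
"""

with tempfile.TemporaryDirectory() as tmp:
    config_file = os.path.join(tmp, "app.yaml")
    with open(config_file, "w") as f:
        f.write(YAML_SKETCH)
    # Each YAML section becomes a typed dataclass attribute.
    settings = AppSettings(config_file=config_file)
    assert settings.dataset.num_slices == 8
    assert settings.training.use_gpu is False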
@@ -1,6 +1,7 @@
 import os, h5py, numpy as np
 from utils.io import from_npy
 from split_dataset import split
+from helpers.app_settings import get_app_settings
 
 meta_dtype = np.dtype(
     [
@@ -76,7 +77,7 @@ def write_hdf5_file(records, output_path, dataset_name="data"):
     return output_path
 
 
-def split_recording(recording_list):
+def split_recording(recording_list, num_snippets):
     """
     Splits a list of recordings into smaller chunks.
 
@@ -90,9 +91,8 @@ def split_recording(recording_list):
 
     for data, md in recording_list:
         C, N = data.shape
-        L = N // 8
-        rec_id = md["rec_id"]
-        for i in range(8):
+        L = N // num_snippets
+        for i in range(num_snippets):
             start = i * L
             end = (i + 1) * L
             snippet = data[:, start:end]
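As the docstring above says, split_recording chops each recording into smaller chunks; with the new num_snippets parameter the chunk length becomes L = N // num_snippets, so any remainder samples at the tail of a recording are silently dropped. A minimal sketch of that arithmetic, using a made-up 2-channel, 100-sample recording:

import numpy as np

# Hypothetical recording: 2 channels, 100 samples, split 8 ways.
# L = 100 // 8 = 12, giving 8 snippets of 12 samples;
# the trailing 100 - 8 * 12 = 4 samples are discarded.
data = np.zeros((2, 100))
num_snippets = 8
C, N = data.shape
L = N // num_snippets
snippets = [data[:, i * L:(i + 1) * L] for i in range(num_snippets)]

assert len(snippets) == 8
assert all(s.shape == (2, 12) for s in snippets)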
@@ -103,7 +103,7 @@ def split_recording(recording_list):
     return snippet_list
 
 
-def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
+def generate_datasets(cfg):
     """
     Generates a dataset from a folder of .npy files and saves it to an HDF5 file
 
@@ -116,18 +116,18 @@ def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
         dset (h5py.Dataset): The created dataset object
     """
 
-    parent = os.path.dirname(output_path)
+    parent = os.path.dirname(cfg.output_dir)
     if parent:
         os.makedirs(parent, exist_ok=True)
 
     # we assume the recordings are in .npy format
-    files = os.listdir(path_to_recordings)
+    files = os.listdir(cfg.input_dir)
     if not files:
         raise ValueError("No files found in the specified directory.")
 
     records = []
     for fname in files:
-        rec = from_npy(os.path.join(path_to_recordings, fname))
+        rec = from_npy(os.path.join(cfg.input_dir, fname))
 
         data = rec.data
 
@@ -136,18 +136,24 @@ def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
         records.append((data, md))
 
     # split each recording into 8 snippets each
-    records = split_recording(records)
+    records = split_recording(records, cfg.num_slices)
 
-    train_records, val_records = split(records, train_frac=0.8, seed=42)
+    train_records, val_records = split(records, cfg.train_split, cfg.seed)
 
-    train_path = os.path.join(output_path, "train.h5")
-    val_path = os.path.join(output_path, "val.h5")
+    train_path = os.path.join(cfg.output_dir, "train.h5")
+    val_path = os.path.join(cfg.output_dir, "val.h5")
 
     write_hdf5_file(train_records, train_path, "training_data")
     write_hdf5_file(val_records, val_path, "validation_data")
 
     return train_path, val_path
 
 
+def main():
+    settings = get_app_settings()
+    dataset_cfg = settings.dataset
+    train_path, val_path = generate_datasets(dataset_cfg)
+    print(f"✅ Train: {train_path}\n✅ Val: {val_path}")
+
 if __name__ == "__main__":
-    print(generate_datasets("recordings", "data"))
+    main()
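The commit message's point, a settings object usable in all scripts, comes down to get_app_settings being importable from anywhere in the project. A hypothetical consumer script, assuming conf/app.yaml exists at the location the helper expects:

from helpers.app_settings import get_app_settings

settings = get_app_settings()

# Each pipeline stage reads its own section of the shared configuration.
print("train:", settings.training.batch_size, settings.training.learning_rate)
print("data:", settings.dataset.input_dir, "->", settings.dataset.output_dir)

# @lru_cache means app.yaml is parsed once per process; repeat calls
# return the same cached AppSettings instance.
assert get_app_settings() is settings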