Added an app.yaml config file, as well as a Python script that creates a settings object that can be shared across all scripts

liyuxiao2 2025-05-20 10:26:03 -04:00
parent 92c8c4678e
commit 6c1164b466
2 changed files with 81 additions and 13 deletions

helpers/app_settings.py (new file, +62)

@@ -0,0 +1,62 @@
import os
from dataclasses import dataclass
from functools import lru_cache

import yaml


@dataclass
class GeneralConfig:
    run_mode: str


@dataclass
class DataSetConfig:
    input_dir: str
    num_slices: int
    train_split: float
    seed: int
    val_split: float
    output_dir: str


@dataclass
class TrainingConfig:
    batch_size: int
    num_epochs: int
    learning_rate: float
    checkpoint_path: str
    use_gpu: bool


@dataclass
class InferenceConfig:
    model_path: str
    num_classes: int
    output_dir: str


@dataclass
class AppConfig:
    build_dir: str


class AppSettings:
    """Application settings, to be initialized from app.yaml configuration file."""

    def __init__(self, config_file: str):
        # Load the YAML configuration file
        with open(config_file, "r") as f:
            config_data = yaml.safe_load(f)

        # Parse the loaded YAML into dataclass objects
        self.general = GeneralConfig(**config_data["general"])
        self.dataset = DataSetConfig(**config_data["dataset"])
        self.training = TrainingConfig(**config_data["training"])
        self.inference = InferenceConfig(**config_data["inference"])
        self.app = AppConfig(**config_data["app"])


@lru_cache
def get_app_settings() -> AppSettings:
    """Return application configuration settings."""
    module_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    config_file = os.path.join(module_path, "conf", "app.yaml")
    return AppSettings(config_file=config_file)


if __name__ == "__main__":
    s = get_app_settings()
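
For reference, the five dataclasses above imply an app.yaml shaped roughly like the sketch below. The dataset values are taken from the hard-coded defaults this commit replaces (8 slices, 0.8 train split, seed 42, "recordings"/"data" directories); every other value is an illustrative assumption, not part of the commit.

general:
  run_mode: train                         # assumed
dataset:
  input_dir: recordings                   # old hard-coded input folder
  num_slices: 8                           # old hard-coded snippet count
  train_split: 0.8                        # old train_frac default
  seed: 42                                # old seed default
  val_split: 0.2                          # assumed complement of train_split
  output_dir: data                        # old hard-coded output folder
training:
  batch_size: 32                          # assumed
  num_epochs: 10                          # assumed
  learning_rate: 0.001                    # assumed
  checkpoint_path: checkpoints/model.pt   # assumed
  use_gpu: true                           # assumed
inference:
  model_path: checkpoints/model.pt        # assumed
  num_classes: 10                         # assumed
  output_dir: out                         # assumed
app:
  build_dir: build                        # assumed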

(second changed file, filename not shown in this view: +19, -13)

@@ -1,6 +1,7 @@
 import os, h5py, numpy as np
 from utils.io import from_npy
 from split_dataset import split
+from helpers.app_settings import get_app_settings

 meta_dtype = np.dtype(
     [
@@ -76,7 +77,7 @@ def write_hdf5_file(records, output_path, dataset_name="data"):
     return output_path

-def split_recording(recording_list):
+def split_recording(recording_list, num_snippets):
     """
     Splits a list of recordings into smaller chunks.
@@ -90,9 +91,8 @@ def split_recording(recording_list):
     for data, md in recording_list:
         C, N = data.shape
-        L = N // 8
-        rec_id = md["rec_id"]
-        for i in range(8):
+        L = N // num_snippets
+        for i in range(num_snippets):
             start = i * L
             end = (i + 1) * L
             snippet = data[:, start:end]
@@ -103,7 +103,7 @@ def split_recording(recording_list):
     return snippet_list

-def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
+def generate_datasets(cfg):
     """
     Generates a dataset from a folder of .npy files and saves it to an HDF5 file
@@ -116,18 +116,18 @@ def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
         dset (h5py.Dataset): The created dataset object
     """
-    parent = os.path.dirname(output_path)
+    parent = os.path.dirname(cfg.output_dir)
     if parent:
         os.makedirs(parent, exist_ok=True)

     # we assume the recordings are in .npy format
-    files = os.listdir(path_to_recordings)
+    files = os.listdir(cfg.input_dir)
     if not files:
         raise ValueError("No files found in the specified directory.")

     records = []
     for fname in files:
-        rec = from_npy(os.path.join(path_to_recordings, fname))
+        rec = from_npy(os.path.join(cfg.input_dir, fname))
         data = rec.data
@@ -136,18 +136,24 @@ def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
         records.append((data, md))

     # split each recording into 8 snippets each
-    records = split_recording(records)
-    train_records, val_records = split(records, train_frac=0.8, seed=42)
+    records = split_recording(records, cfg.num_slices)
+    train_records, val_records = split(records, cfg.train_split, cfg.seed)

-    train_path = os.path.join(output_path, "train.h5")
-    val_path = os.path.join(output_path, "val.h5")
+    train_path = os.path.join(cfg.output_dir, "train.h5")
+    val_path = os.path.join(cfg.output_dir, "val.h5")

     write_hdf5_file(train_records, train_path, "training_data")
     write_hdf5_file(val_records, val_path, "validation_data")

     return train_path, val_path

+def main():
+    settings = get_app_settings()
+    dataset_cfg = settings.dataset
+    train_path, val_path = generate_datasets(dataset_cfg)
+    print(f"✅ Train: {train_path}\n✅ Val: {val_path}")
+
 if __name__ == "__main__":
-    print(generate_datasets("recordings", "data"))
+    main()
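
A minimal usage sketch for any downstream script that wants these settings. It assumes conf/app.yaml exists at the path get_app_settings computes (three dirname calls up from helpers/app_settings.py, then conf/app.yaml) and that the repo root is on sys.path:

from helpers.app_settings import get_app_settings

settings = get_app_settings()  # cached after the first call via @lru_cache
cfg = settings.dataset         # typed DataSetConfig dataclass
print(cfg.input_dir, cfg.num_slices, cfg.train_split)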