diff --git a/helpers/app_settings.py b/helpers/app_settings.py
new file mode 100644
index 0000000..72fb2d4
--- /dev/null
+++ b/helpers/app_settings.py
@@ -0,0 +1,64 @@
+import os
+from dataclasses import dataclass
+from functools import lru_cache
+
+import yaml
+
+@dataclass
+class GeneralConfig:
+    run_mode: str
+
+@dataclass
+class DataSetConfig:
+    input_dir: str
+    num_slices: int
+    train_split: float
+    seed: int
+    val_split: float
+    output_dir: str
+
+@dataclass
+class TrainingConfig:
+    batch_size: int
+    num_epochs: int
+    learning_rate: float
+    checkpoint_path: str
+    use_gpu: bool
+
+@dataclass
+class InferenceConfig:
+    model_path: str
+    num_classes: int
+    output_dir: str
+
+@dataclass
+class AppConfig:
+    build_dir: str
+
+class AppSettings:
+    """Application settings, initialized from the app.yaml configuration file."""
+
+    def __init__(self, config_file: str):
+        # Load the YAML configuration file
+        with open(config_file, "r") as f:
+            config_data = yaml.safe_load(f)
+
+        # Parse each top-level YAML section into its dataclass
+        self.general = GeneralConfig(**config_data["general"])
+        self.dataset = DataSetConfig(**config_data["dataset"])
+        self.training = TrainingConfig(**config_data["training"])
+        self.inference = InferenceConfig(**config_data["inference"])
+        self.app = AppConfig(**config_data["app"])
+
+@lru_cache
+def get_app_settings() -> AppSettings:
+    """Return application configuration settings (cached after the first load)."""
+    # helpers/ sits one level below the repo root, so go up two directories
+    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    config_file = os.path.join(repo_root, "conf", "app.yaml")
+    return AppSettings(config_file=config_file)
+
+
+if __name__ == "__main__":
+    settings = get_app_settings()
+    print(settings.general)
diff --git a/scripts/produce_dataset.py b/scripts/produce_dataset.py
index 22f2205..4f7d233 100644
--- a/scripts/produce_dataset.py
+++ b/scripts/produce_dataset.py
@@ -1,6 +1,7 @@
 import os, h5py, numpy as np
 from utils.io import from_npy
 from split_dataset import split
+from helpers.app_settings import get_app_settings
 
 meta_dtype = np.dtype(
     [
@@ -76,7 +77,7 @@ def write_hdf5_file(records, output_path, dataset_name="data"):
     return output_path
 
 
-def split_recording(recording_list):
+def split_recording(recording_list, num_snippets):
     """
     Splits a list of recordings into smaller chunks.
 
@@ -90,9 +91,8 @@ def split_recording(recording_list):
 
     for data, md in recording_list:
         C, N = data.shape
-        L = N // 8
-        rec_id = md["rec_id"]
-        for i in range(8):
+        L = N // num_snippets
+        for i in range(num_snippets):
             start = i * L
             end = (i + 1) * L
             snippet = data[:, start:end]
@@ -103,7 +103,7 @@ def split_recording(recording_list):
     return snippet_list
 
 
-def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
+def generate_datasets(cfg):
     """
-    Generates a dataset from a folder of .npy files and saves it to an HDF5 file
+    Generates train/val HDF5 datasets from a folder of .npy recordings using cfg
 
@@ -116,18 +116,17 @@ def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
-        dset (h5py.Dataset): The created dataset object
+        (train_path, val_path): paths to the written train and validation HDF5 files
     """
 
-    parent = os.path.dirname(output_path)
-    if parent:
-        os.makedirs(parent, exist_ok=True)
+    # cfg.output_dir is a directory; create it (not just its parent) up front
+    os.makedirs(cfg.output_dir, exist_ok=True)
 
     # we assume the recordings are in .npy format
-    files = os.listdir(path_to_recordings)
+    files = os.listdir(cfg.input_dir)
     if not files:
         raise ValueError("No files found in the specified directory.")
 
     records = []
     for fname in files:
-        rec = from_npy(os.path.join(path_to_recordings, fname))
+        rec = from_npy(os.path.join(cfg.input_dir, fname))
 
         data = rec.data
 
@@ -136,18 +136,24 @@ def generate_datasets(path_to_recordings, output_path, dataset_name="data"):
         records.append((data, md))
 
-    # split each recording into 8 snippets each
-    records = split_recording(records)
+    # split each recording into cfg.num_slices snippets
+    records = split_recording(records, cfg.num_slices)
 
-    train_records, val_records = split(records, train_frac=0.8, seed=42)
+    train_records, val_records = split(records, cfg.train_split, cfg.seed)
 
-    train_path = os.path.join(output_path, "train.h5")
-    val_path = os.path.join(output_path, "val.h5")
+    train_path = os.path.join(cfg.output_dir, "train.h5")
+    val_path = os.path.join(cfg.output_dir, "val.h5")
 
     write_hdf5_file(train_records, train_path, "training_data")
     write_hdf5_file(val_records, val_path, "validation_data")
 
     return train_path, val_path
 
 
+def main():
+    settings = get_app_settings()
+    dataset_cfg = settings.dataset
+    train_path, val_path = generate_datasets(dataset_cfg)
+    print(f"āœ… Train: {train_path}\nāœ… Val: {val_path}")
+
 if __name__ == "__main__":
-    print(generate_datasets("recordings", "data"))
+    main()
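
Note: conf/app.yaml itself is not part of this diff. A minimal example that the
dataclasses above would accept (section names and keys must match the dataclass
fields exactly) might look like the sketch below; num_slices, train_split,
seed, input_dir, and output_dir mirror the previously hard-coded values, and
every other value is illustrative:

general:
  run_mode: train

dataset:
  input_dir: recordings
  output_dir: data
  num_slices: 8
  train_split: 0.8
  val_split: 0.2
  seed: 42

training:
  batch_size: 32
  num_epochs: 10
  learning_rate: 0.001
  checkpoint_path: checkpoints/model.pt
  use_gpu: true

inference:
  model_path: checkpoints/model.pt
  num_classes: 10
  output_dir: inference_output

app:
  build_dir: build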
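
For reference, the new call site implies the signature split(records,
train_frac, seed) returning (train_records, val_records). A minimal sketch
consistent with that contract (the repo's actual split_dataset.split may
differ) is:

import random

def split(records, train_frac, seed):
    """Deterministically shuffle records, then partition into train/val."""
    rng = random.Random(seed)  # local RNG so global random state is untouched
    shuffled = list(records)   # copy; don't mutate the caller's list
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * train_frac)
    return shuffled[:cut], shuffled[cut:]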