diff --git a/.gitignore b/.gitignore
index 2bbace6..cee43c1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,5 +2,3 @@
 # Byte-compiled / optimized / DLL files
 __pycache__/
 
-
-data/
\ No newline at end of file
diff --git a/data/dataset/modulation_dataset.py b/data/dataset/modulation_dataset.py
new file mode 100644
index 0000000..9c60764
--- /dev/null
+++ b/data/dataset/modulation_dataset.py
@@ -0,0 +1,54 @@
+import sys, os
+
+# Make the project root importable when this module is run from data/dataset/.
+sys.path.insert(0, os.path.abspath("../.."))
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+import h5py
+from helpers.app_settings import get_app_settings
+
+
+class ModulationH5Dataset(Dataset):
+    def __init__(self, hdf5_path, label_name, data_key="training_data", label_encoder=None, transform=None):
+        self.hdf5_path = hdf5_path
+        self.data_key = data_key
+        self.label_name = label_name
+        self.label_encoder = label_encoder
+        self.transform = transform
+
+        with h5py.File(hdf5_path, 'r') as f:
+            self.length = f[data_key].shape[0]
+            self.metadata = f["metadata"]["metadata"][:]
+
+        # Fit the encoder on the full label vocabulary from the app settings,
+        # so every class gets an index even if a split is missing one.
+        settings = get_app_settings()
+        all_labels = settings.dataset.modulation_types
+
+        if self.label_encoder is None:
+            from sklearn.preprocessing import LabelEncoder
+            self.label_encoder = LabelEncoder()
+            self.label_encoder.fit(all_labels)
+
+        # Get per-sample labels from metadata
+        raw_labels = [row["modulation"].decode("utf-8") for row in self.metadata]
+        self.encoded_labels = self.label_encoder.transform(raw_labels)
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        with h5py.File(self.hdf5_path, 'r') as f:
+            x = f[self.data_key][idx]  # shape (2, 128): I/Q channels
+
+        # Per-example standardization; downstream consumers should not
+        # normalize again.
+        mean = np.mean(x, axis=-1, keepdims=True)
+        std = np.std(x, axis=-1, keepdims=True)
+        x = (x - mean) / (std + 1e-6)
+        x = torch.tensor(x, dtype=torch.float32)
+
+        label = torch.tensor(self.encoded_labels[idx], dtype=torch.long)
+        return x, label
diff --git a/data/dataset/train.h5 b/data/dataset/train.h5
new file mode 100644
index 0000000..ef9a5f8
Binary files /dev/null and b/data/dataset/train.h5 differ
diff --git a/data/dataset/val.h5 b/data/dataset/val.h5
new file mode 100644
index 0000000..b3bb304
Binary files /dev/null and b/data/dataset/val.h5 differ
diff --git a/data/models/cm_plotter.py b/data/models/cm_plotter.py
new file mode 100644
index 0000000..b2b4437
--- /dev/null
+++ b/data/models/cm_plotter.py
@@ -0,0 +1,58 @@
+import numpy as np
+from typing import Optional
+from matplotlib import pyplot as plt
+from sklearn.metrics import confusion_matrix
+
+def plot_confusion_matrix(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    classes: list,
+    normalize: bool = True,
+    title: Optional[str] = None,
+    text: bool = True,
+    rotate_x_text: int = 90,
+    figsize: tuple = (16, 9),
+    cmap=plt.cm.Blues,
+):
+    """Plot an (optionally row-normalized) confusion matrix.
+
+    https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
+    """
+    if not title:
+        if normalize:
+            title = "Normalized confusion matrix"
+        else:
+            title = "Confusion matrix, without normalization"
+
+    # Compute confusion matrix
+    cm = confusion_matrix(y_true, y_pred)
+    if normalize:
+        cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
+
+    fig, ax = plt.subplots()
+    im = ax.imshow(cm, interpolation="none", cmap=cmap)
+    ax.figure.colorbar(im, ax=ax)
+    ax.set(
+        xticks=np.arange(cm.shape[1]),
+        yticks=np.arange(cm.shape[0]),
+        xticklabels=classes,
+        yticklabels=classes,
+        title=title,
+        ylabel="True label",
+        xlabel="Predicted label",
+    )
+    ax.set_xticklabels(classes, rotation=rotate_x_text)
+    ax.figure.set_size_inches(figsize)
+
+    # Loop over data dimensions and create text annotations.
+    fmt = ".2f" if normalize else "d"
+    thresh = cm.max() / 2.0
+    for i in range(cm.shape[0]):
+        for j in range(cm.shape[1]):
+            if text:
+                ax.text(j, i, format(cm[i, j], fmt), ha="center", va="center",
+                        color="white" if cm[i, j] > thresh else "black")
+    if len(classes) == 2:
+        plt.axis([-0.5, 1.5, 1.5, -0.5])
+    fig.tight_layout()
+
+    return ax
\ No newline at end of file
diff --git a/data/models/interference_recognition.ipynb b/data/models/interference_recognition.ipynb
new file mode 100644
index 0000000..49946f9
--- /dev/null
+++ b/data/models/interference_recognition.ipynb
@@ -0,0 +1,3428 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "['bpsk', 'qpsk', 'qam16', 'qam64']\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sys, os\n",
+    "project_root = os.path.abspath(os.path.join(os.getcwd(), \"../..\"))\n",
+    "if project_root not in sys.path:\n",
+    "    sys.path.insert(0, project_root)\n",
+    "\n",
+    "import lightning as L\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "import torch.nn.functional as F\n",
+    "import torchmetrics\n",
+    "from sklearn.metrics import classification_report, RocCurveDisplay\n",
+    "import matplotlib.pyplot as plt\n",
+    "from helpers.app_settings import get_app_settings\n",
+    "from data.dataset.modulation_dataset import ModulationH5Dataset\n",
+    "\n",
+    "import mobilenetv3\n",
+    "from cm_plotter import plot_confusion_matrix\n",
+    "\n",
+    "settings = get_app_settings()\n",
+    "dataset = settings.dataset.modulation_types\n",
+    "print(dataset)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Training Settings"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set this to True to train a new model from scratch; otherwise an existing\n",
+    "# checkpoint is loaded. Training will overwrite an existing checkpoint.\n",
+    "train_flag = True\n",
+    "\n",
+    "batch_size = 128\n",
+    "epochs = 100\n",
+    "\n",
+    "checkpoint_filename = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model'\n",
+    "\n",
+    "train_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5'\n",
+    "val_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5'\n",
+    "\n",
+    "dataset_name = 'Modulation Inference - Initial Model'\n",
+    "label = 'modulation'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Precision Settings"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch.set_float32_matmul_precision('high')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "X shape: torch.Size([128, 2, 128])\n",
+      "Y values: tensor([1, 3, 1,
0, 0, 3, 2, 1, 2, 1])\n", + "Train Dataset: 160 examples, each of shape torch.Size([2, 128])\n", + "Labels found in metadata: modulation\n", + "Example label values (first 5): ['qam64', 'qam64', 'qam64', 'qam64', 'qam64']\n", + "Unique labels in training data: ['bpsk', 'qam64', 'qpsk', 'qam16']\n", + "\n", + "Label being used: modulation\n", + "Number of classes: 4\n", + "Class index mapping: {'bpsk': 0, 'qam16': 1, 'qam64': 2, 'qpsk': 3}\n" + ] + } + ], + "source": [ + "\n", + "ds_train = ModulationH5Dataset(train_data, label, data_key= \"training_data\")\n", + "ds_val = ModulationH5Dataset(val_data, label, data_key=\"validation_data\")\n", + "\n", + "train_loader = torch.utils.data.DataLoader(\n", + " dataset=ds_train,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " num_workers=8,\n", + " )\n", + "val_loader = torch.utils.data.DataLoader(\n", + " dataset=ds_val,\n", + " batch_size=2048,\n", + " shuffle=False,\n", + " num_workers=8,\n", + " )\n", + "\n", + "for x, y in train_loader:\n", + " print(\"X shape:\", x.shape)\n", + " print(\"Y values:\", y[:10])\n", + " break\n", + "\n", + "\n", + "print(f'Train Dataset: {len(ds_train)} examples, each of shape {ds_train[0][0].shape}')\n", + "print(f'Labels found in metadata: {label}')\n", + "print(f'Example label values (first 5): {[row[label].decode(\"utf-8\") for row in ds_train.metadata[:5]]}')\n", + "\n", + "unique_labels = list(set([row[label].decode(\"utf-8\") for row in ds_train.metadata]))\n", + "print(f'Unique labels in training data: {unique_labels}')\n", + "\n", + "\n", + "\n", + "num_classes = len(ds_train.label_encoder.classes_)\n", + "print(f'\\nLabel being used: {label}')\n", + "print(f'Number of classes: {num_classes}')\n", + "print(f'Class index mapping: {dict(zip(ds_train.label_encoder.classes_, range(num_classes)))}')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation Dataset: 64 examples of length 128\n", + "Annotation fields: ('rec_id', 'snippet_idx', 'modulation', 'snr', 'beta', 'sps')\n", + "Labels:\n", + "rec_id\n", + "\trec_id: ['491c4575bc4b3605997a42031f886c2bcc1dfcb36f7cd5c9991958dd234fd150', '66766004b23abcb02301e70dd8bbb21f802bc46f7c4ae2c65f36d4b4afdacfd0', '6d85f3e640d771e2106eb3fcf2c277def16292edbd37de62d5bd8f6b41962e95', '940988edd5df81d1170c5794ec259629cddefa674049e871f5229a9a24cb94bc', 'a4a6ba686c081920f6f6c7218a300939fc75bdf39d24044c0e084f75d264e9e3', 'dd021f7cce295866dfd8e7cfd109577a0b5f8dbbe25c2624ca3bb2d8fafc18e5', 'e61d9bfb32bb2d4bfa0666b764e32a018356bf8517d820b731d911bd1b67e4f6', 'f024082e859fc2b044bc410eeb8285ebae297c25808e596b310d11ef40f9b70c']\n", + "snippet_idx\n", + "\tsnippet_idx: [0, 1, 2, 3, 4, 5, 6, 7]\n", + "modulation\n", + "\tmodulation: ['bpsk', 'qam16', 'qam64', 'qpsk']\n", + "snr\n", + "\tsnr: [-3, 0, 3, 6]\n", + "beta\n", + "\tbeta: [0.3]\n", + "sps\n", + "\tsps: [4]\n", + "\n", + "Label being used: modulation\n", + "\tqam64: 16\n", + "\tqam16: 16\n", + "\tbpsk: 16\n", + "\tqpsk: 16\n" + ] + } + ], + "source": [ + "print(f'Validation Dataset: {len(ds_val)} examples of length {ds_val[0][0].shape[1]}')\n", + "print(f'Annotation fields: {ds_val.metadata.dtype.names}')\n", + "print('Labels:')\n", + "\n", + "\n", + "for a in ds_val.metadata.dtype.names:\n", + " print(a)\n", + " # decode only if dtype is bytes\n", + " values = [row[a].decode('utf-8') if isinstance(row[a], bytes) else row[a] for row in ds_val.metadata]\n", + " unique_vals = sorted(set(values))\n", 
+ " print(f'\\t{a}: {unique_vals}')\n", + "\n", + "print(f'\\nLabel being used: {label}')\n", + "\n", + "# Print value counts for the label field\n", + "from collections import Counter\n", + "label_values = [row[label].decode('utf-8') for row in ds_val.metadata]\n", + "value_counts = Counter(label_values)\n", + "for k, v in value_counts.items():\n", + " print(f'\\t{k}: {v}')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "hparams = {\n", + " 'drop_path_rate': 0.2,\n", + " 'drop_rate': 0.5,\n", + " 'learning_rate': 3e-4,\n", + " 'wd': 0.2}\n", + "class RFClassifier(L.LightningModule):\n", + " def __init__(self, model):\n", + " super().__init__()\n", + "\n", + " self.model = model\n", + " self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)\n", + " \n", + "\n", + " def forward(self,x):\n", + " return self.model(x)\n", + " \n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.AdamW(self.parameters(),\n", + " lr = hparams['learning_rate'],\n", + " weight_decay = hparams['wd'],\n", + " )\n", + " lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(\n", + " optimizer, \n", + " T_0=len(train_loader), \n", + " )\n", + " return {\n", + " 'optimizer': optimizer,\n", + " 'lr_scheduler': {\n", + " 'scheduler': lr_scheduler,\n", + " 'interval': 'step'\n", + " }\n", + " }\n", + " # return optimizer\n", + " \n", + " def training_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " print(x.shape)\n", + " y_hat = self(x)\n", + " loss = F.cross_entropy(y_hat, y)\n", + " self.log('train_loss', loss,\n", + " on_epoch=True,\n", + " prog_bar=True,\n", + " )\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " \n", + " y_hat = self(x)\n", + " loss = F.cross_entropy(y_hat, y)\n", + " self.accuracy(y_hat, y)\n", + " self.log('val_loss', loss, prog_bar=True)\n", + " self.log('val_acc', self.accuracy, prog_bar=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Train" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using bfloat16 Automatic Mixed Precision (AMP)\n", + "GPU available: True (mps), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], + "source": [ + "model = RFClassifier(mobilenetv3.mobilenetv3(\n", + " model_size='mobilenetv3_small_050', \n", + " num_classes=num_classes, \n", + " drop_rate = hparams['drop_rate'], \n", + " drop_path_rate = hparams['drop_path_rate']\n", + " )\n", + " )\n", + "\n", + "\n", + "\n", + "checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(\n", + " filename=checkpoint_filename,\n", + " save_top_k=True,\n", + " verbose=True,\n", + " monitor='val_acc',\n", + " mode='max',\n", + " enable_version_counter = False,\n", + ")\n", + "trainer = L.Trainer(\n", + " max_epochs=epochs,\n", + " callbacks = [checkpoint_callback,\n", + " ],\n", + " accelerator='gpu',\n", + " devices=1,\n", + " benchmark=True,\n", + " precision='bf16-mixed',\n", + " logger=False\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params | Mode \n", + "--------------------------------------------------------\n", + "0 | model | MobileNetV3 | 550 K | train\n", + 
"1 | accuracy | MulticlassAccuracy | 0 | train\n", + "--------------------------------------------------------\n", + "550 K Trainable params\n", + "0 Non-trainable params\n", + "550 K Total params\n", + "2.201 Total estimated model params size (MB)\n", + "259 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7bf6aa152c7748b8a2cc4e0411949154", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00\n", + "Traceback (most recent call last):\n", + " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py\", line 1663, in __del__\n", + " self._shutdown_workers()\n", + " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py\", line 1627, in _shutdown_workers\n", + " w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)\n", + " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/multiprocessing/process.py\", line 149, in join\n", + " res = self._popen.wait(timeout)\n", + " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/multiprocessing/popen_fork.py\", line 40, in wait\n", + " if not wait([self.sentinel], timeout):\n", + " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/multiprocessing/connection.py\", line 931, in wait\n", + " ready = selector.select(timeout)\n", + " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/selectors.py\", line 416, in select\n", + " fd_event_list = self._selector.poll(timeout)\n", + "KeyboardInterrupt: \n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c8e35b4f64e84ab39f20b0a36dc0116d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | | 0/? 
[00:00 4\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_loader\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:561\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 559\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 560\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshould_stop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m--> 561\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 562\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 563\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:48\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 48\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 51\u001b[0m _call_teardown_hook(trainer)\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:599\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 592\u001b[0m download_model_from_registry(ckpt_path, \u001b[38;5;28mself\u001b[39m)\n\u001b[1;32m 593\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 594\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 595\u001b[0m ckpt_path,\n\u001b[1;32m 596\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 597\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 598\u001b[0m )\n\u001b[0;32m--> 599\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 601\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 602\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1012\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1007\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 1009\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 1010\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 1011\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m-> 1012\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1014\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 1015\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 1016\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 1017\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1056\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1054\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1055\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1056\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1057\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1058\u001b[0m 
\u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:216\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 216\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 218\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:455\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 454\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 455\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:151\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madvance(data_fetcher)\n\u001b[0;32m--> 151\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mon_advance_end\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 153\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:370\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.on_advance_end\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 366\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_accumulate():\n\u001b[1;32m 367\u001b[0m \u001b[38;5;66;03m# clear gradients to not leave any unused memory during validation\u001b[39;00m\n\u001b[1;32m 368\u001b[0m 
call\u001b[38;5;241m.\u001b[39m_call_lightning_module_hook(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_validation_model_zero_grad\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 370\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 371\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 372\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39m_logger_connector\u001b[38;5;241m.\u001b[39m_first_loop_iter \u001b[38;5;241m=\u001b[39m first_loop_iter\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:179\u001b[0m, in \u001b[0;36m_no_grad_context.._decorator\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 177\u001b[0m context_manager \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mno_grad\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context_manager():\n\u001b[0;32m--> 179\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mloop_run\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py:138\u001b[0m, in \u001b[0;36m_EvaluationLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 137\u001b[0m dataloader_iter \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 138\u001b[0m batch, batch_idx, dataloader_idx \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m previous_dataloader_idx \u001b[38;5;241m!=\u001b[39m dataloader_idx:\n\u001b[1;32m 140\u001b[0m \u001b[38;5;66;03m# the dataloader has changed, notify the logger connector\u001b[39;00m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_store_dataloader_outputs()\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/fetchers.py:134\u001b[0m, in \u001b[0;36m_PrefetchDataFetcher.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone \u001b[38;5;241m=\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatches\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 133\u001b[0m \u001b[38;5;66;03m# this will run only when no pre-fetching was done.\u001b[39;00m\n\u001b[0;32m--> 134\u001b[0m batch \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__next__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 135\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 136\u001b[0m \u001b[38;5;66;03m# the iterator is empty\u001b[39;00m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/fetchers.py:61\u001b[0m, in \u001b[0;36m_DataFetcher.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_start_profiler()\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 61\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miterator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/utilities/combined_loader.py:341\u001b[0m, in \u001b[0;36mCombinedLoader.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__next__\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m _ITERATOR_RETURN:\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_iterator \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 341\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_iterator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 342\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_iterator, _Sequential):\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/utilities/combined_loader.py:142\u001b[0m, in \u001b[0;36m_Sequential.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 142\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miterators\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 144\u001b[0m \u001b[38;5;66;03m# try the next iterator\u001b[39;00m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_use_next_iterator()\n", + "File 
\u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py:733\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 730\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 731\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 732\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 733\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 734\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 735\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 736\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable\n\u001b[1;32m 737\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 738\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called\n\u001b[1;32m 739\u001b[0m ):\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1491\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1488\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_data(data, worker_id)\n\u001b[1;32m 1490\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_shutdown \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m-> 1491\u001b[0m idx, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1492\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1493\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable:\n\u001b[1;32m 1494\u001b[0m \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1453\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1449\u001b[0m \u001b[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. 
But we don't\u001b[39;00m\n\u001b[1;32m 1450\u001b[0m \u001b[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[1;32m 1451\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1452\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m-> 1453\u001b[0m success, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_try_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1454\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[1;32m 1455\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1284\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_try_get_data\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout\u001b[38;5;241m=\u001b[39m_utils\u001b[38;5;241m.\u001b[39mMP_STATUS_CHECK_INTERVAL):\n\u001b[1;32m 1272\u001b[0m \u001b[38;5;66;03m# Tries to fetch data from `self._data_queue` once for a given timeout.\u001b[39;00m\n\u001b[1;32m 1273\u001b[0m \u001b[38;5;66;03m# This can also be used as inner loop of fetching without timeout, with\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1281\u001b[0m \u001b[38;5;66;03m# Returns a 2-tuple:\u001b[39;00m\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;66;03m# (bool: whether successfully get data, any: data if successful else None)\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1284\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_queue\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1285\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n\u001b[1;32m 1286\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1287\u001b[0m \u001b[38;5;66;03m# At timeout and error, we manually check whether any worker has\u001b[39;00m\n\u001b[1;32m 1288\u001b[0m \u001b[38;5;66;03m# failed. 
Note that this is the only mechanism for Windows to detect\u001b[39;00m\n\u001b[1;32m 1289\u001b[0m \u001b[38;5;66;03m# worker failures.\u001b[39;00m\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/multiprocessing/queues.py:122\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_rlock\u001b[38;5;241m.\u001b[39mrelease()\n\u001b[1;32m 121\u001b[0m \u001b[38;5;66;03m# unserialize the data after having released the lock\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_ForkingPickler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[43mres\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/multiprocessing/reductions.py:560\u001b[0m, in \u001b[0;36mrebuild_storage_filename\u001b[0;34m(cls, manager, handle, size, dtype)\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m storage\u001b[38;5;241m.\u001b[39m_shared_decref()\n\u001b[1;32m 559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 560\u001b[0m storage \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mUntypedStorage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_new_shared_filename_cpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmanager\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhandle\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 561\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 562\u001b[0m byte_size \u001b[38;5;241m=\u001b[39m size \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39m_utils\u001b[38;5;241m.\u001b[39m_element_size(dtype)\n", + "\u001b[0;31mRuntimeError\u001b[0m: Connection refused" + ] + } + ], + "source": [ + "train_flag = True\n", + "if train_flag:\n", + " \n", + " trainer.fit(model, train_loader, val_loader)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "checkpoint = torch.load(checkpoint_filename+\".ckpt\", map_location=lambda storage, loc: storage)\n", + "model.load_state_dict(checkpoint['state_dict'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Results" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/amp/autocast_mode.py:266: UserWarning: User provided device_type of 'cuda', but CUDA is not available. 
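+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optional (added in review, not part of the original run): ModelCheckpoint may write a versioned filename, so the hard-coded path above can drift from the file that was actually saved. The sketch below loads whatever path Lightning recorded as best instead; it assumes checkpoint_callback from the trainer-setup cell is still in scope."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hedged sketch: resolve the best checkpoint via the callback instead of a\n",
+    "# hard-coded filename.\n",
+    "best_path = checkpoint_callback.best_model_path\n",
+    "if best_path:\n",
+    "    checkpoint = torch.load(best_path, map_location='cpu')\n",
+    "    model.load_state_dict(checkpoint['state_dict'])\n",
+    "    print(f'Loaded {best_path}')"
+   ]
+  },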
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
+    "\n",
+    "model.eval()\n",
+    "y_raw_preds = []\n",
+    "y_preds = []\n",
+    "y_true = []\n",
+    "loss = 0\n",
+    "\n",
+    "model = model.to(device)\n",
+    "with torch.inference_mode():\n",
+    "    # Autocast on the device actually in use; hard-coding 'cuda' raised a\n",
+    "    # UserWarning on this machine.\n",
+    "    with torch.autocast(device_type=device.type, dtype=torch.bfloat16):\n",
+    "        for sample in val_loader:\n",
+    "            data, target = sample[0].to(device), sample[1].to(device)\n",
+    "            # No extra normalization here: ModulationH5Dataset already\n",
+    "            # standardizes each example in __getitem__.\n",
+    "            y_true += target.tolist()\n",
+    "            output = model(data)\n",
+    "            y_raw_preds.append(output)  # logits\n",
+    "            for pred in output.argmax(dim=1, keepdim=True):\n",
+    "                y_preds.append(int(pred))\n",
+    "            loss += F.cross_entropy(output, target)  # already a per-batch mean\n",
+    "y_raw_preds = torch.vstack(y_raw_preds).cpu().to(dtype=torch.float32)  # list of logit batches -> tensor"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "le = ds_train.label_encoder  # fitted encoder; `le` was previously undefined here (NameError)\n",
+    "\n",
+    "acc = np.sum(np.asarray(y_preds) == np.asarray(y_true)) / len(y_true)\n",
+    "plot_confusion_matrix(\n",
+    "    y_true,\n",
+    "    y_preds,\n",
+    "    classes=le.classes_,\n",
+    "    normalize=True,\n",
+    "    title=f'{dataset_name} Validation Set\\nTotal Accuracy: {acc*100:.2f}%',\n",
+    "    text=True,\n",
+    "    rotate_x_text=90,\n",
+    "    figsize=(12, 6.75),\n",
+    ")"
+   ]
+  },
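+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optional (added in review, not part of the original run): the validation metadata carries an 'snr' field, so a per-SNR accuracy breakdown is often more informative than the single aggregate number. The sketch below assumes y_true/y_preds from the evaluation cell above and relies on val_loader using shuffle=False, so predictions stay aligned with ds_val.metadata row order."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hedged sketch: accuracy per SNR bucket.\n",
+    "from collections import defaultdict\n",
+    "\n",
+    "correct, total = defaultdict(int), defaultdict(int)\n",
+    "for row, t, p in zip(ds_val.metadata, y_true, y_preds):\n",
+    "    snr = int(row['snr'])\n",
+    "    total[snr] += 1\n",
+    "    correct[snr] += int(t == p)\n",
+    "for snr in sorted(total):\n",
+    "    print(f'SNR {snr:+d} dB: {correct[snr] / total[snr]:.2%} ({total[snr]} examples)')"
+   ]
+  },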
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(classification_report(y_true, y_preds, target_names=le.classes_))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "y_onehot_test = F.one_hot(torch.tensor(y_true), num_classes)\n",
+    "\n",
+    "fig, ax = plt.subplots(figsize=(6, 6))\n",
+    "colors = ['#e6194B', '#f58231', '#ffe119', '#3cb44b', '#42d4f4', '#4363d8', '#f032e6',\n",
+    "          '#a9a9a9', '#800000', '#9A6324']\n",
+    "for class_id, color in zip(range(num_classes), colors):\n",
+    "    RocCurveDisplay.from_predictions(\n",
+    "        y_onehot_test[:, class_id],\n",
+    "        y_raw_preds[:, class_id],\n",
+    "        name=f\"{le.classes_[class_id]}\",\n",
+    "        color=color,\n",
+    "        ax=ax,\n",
+    "        plot_chance_level=(class_id == num_classes - 1),  # was == 7, which never fires with 4 classes\n",
+    "    )\n",
+    "\n",
+    "_ = ax.set(\n",
+    "    xlabel=\"False Positive Rate\",\n",
+    "    ylabel=\"True Positive Rate\",\n",
+    "    title=f'{dataset_name} ROC Curves',\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NOTE: leftover from a template workflow. C2_H5_Dataset and test_data are\n",
+    "# not defined anywhere in this repo, so this cell and the four below will\n",
+    "# fail until a test split and matching dataset class are provided.\n",
+    "ds_test = C2_H5_Dataset(test_data, label)\n",
+    "\n",
+    "test_loader = torch.utils.data.DataLoader(\n",
+    "    dataset=ds_test,\n",
+    "    batch_size=2048,\n",
+    "    shuffle=False,\n",
+    "    num_workers=8,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(f'Test Dataset, {len(ds_test)} examples of length {ds_test[0][0].shape[1]}')\n",
+    "print(f'Annotations: {list(ds_test.annotations.columns)}')\n",
+    "print('Labels:')\n",
+    "for a in ds_test.annotations.columns:\n",
+    "    print(f'\\t{a}: {ds_test.annotations[a].unique()}')\n",
+    "\n",
+    "print(f'\\nLabel being used: {label}')\n",
+    "\n",
+    "print(f'\\n{ds_test.annotations[label].value_counts()}')\n",
+    "\n",
+    "ds_test.encode_labels(le)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
+    "\n",
+    "model.eval()\n",
+    "y_raw_preds = []\n",
+    "y_preds = []\n",
+    "y_true = []\n",
+    "loss = 0\n",
+    "\n",
+    "model = model.to(device)\n",
+    "with torch.inference_mode():\n",
+    "    with torch.autocast(device_type=device.type, dtype=torch.bfloat16):\n",
+    "        for sample in test_loader:\n",
+    "            data, target = sample[0].to(device), sample[1].to(device)\n",
+    "\n",
+    "            y_true += target.tolist()\n",
+    "            output = model(data)\n",
+    "            y_raw_preds.append(output)  # logits\n",
+    "            for pred in output.argmax(dim=1, keepdim=True):\n",
+    "                y_preds.append(int(pred))\n",
+    "            loss += F.cross_entropy(output, target)  # already a per-batch mean\n",
+    "y_raw_preds = torch.vstack(y_raw_preds).cpu().to(dtype=torch.float32)  # list of logit batches -> tensor"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "acc = np.sum(np.asarray(y_preds) == np.asarray(y_true)) / len(y_true)\n",
+    "plot_confusion_matrix(\n",
+    "    y_true,\n",
+    "    y_preds,\n",
+    "    classes=le.classes_,\n",
+    "    normalize=True,\n",
+    "    title=f'{dataset_name} Inference Test\\nTotal Accuracy: {acc*100:.2f}%',\n",
+    "    text=True,\n",
+    "    rotate_x_text=90,\n",
+    "    figsize=(12, 6.75),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(classification_report(y_true, y_preds, target_names=le.classes_))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "y_onehot_test = F.one_hot(torch.tensor(y_true), num_classes)\n",
"\n", + "fig, ax = plt.subplots(figsize=(6, 6))\n", + "colors = ['#332288', '#88CCEE', '#44AA99', '#117733', '#999933', '#DDCC77', '#CC6677', '#882255']\n", + "for class_id, color in zip(range(num_classes), colors):\n", + " RocCurveDisplay.from_predictions(\n", + " y_onehot_test[:, class_id],\n", + " y_raw_preds[:, class_id],\n", + " name=f\"{le.classes_[class_id]}\",\n", + " color=color,\n", + " ax=ax,\n", + " plot_chance_level=(class_id == 7),\n", + " )\n", + "\n", + "_ = ax.set(\n", + " xlabel=\"False Positive Rate\",\n", + " ylabel=\"True Positive Rate\",\n", + " title=f'{dataset_name} ROC Curves',\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "qoherent-env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/data/models/mobilenetv3.py b/data/models/mobilenetv3.py new file mode 100644 index 0000000..549bc51 --- /dev/null +++ b/data/models/mobilenetv3.py @@ -0,0 +1,206 @@ +import numpy as np +import torch +import timm +from torch import nn + +sizes = [ + 'mobilenetv3_large_075', + 'mobilenetv3_large_100', + 'mobilenetv3_rw', + 'mobilenetv3_small_050', + 'mobilenetv3_small_075', + 'mobilenetv3_small_100', + 'tf_mobilenetv3_large_075', + 'tf_mobilenetv3_large_100', + 'tf_mobilenetv3_large_minimal_100', + 'tf_mobilenetv3_small_075', + 'tf_mobilenetv3_small_100', + 'tf_mobilenetv3_small_minimal_100' + ] + +class SqueezeExcite(nn.Module): + def __init__( + self, + in_chs, + se_ratio=0.25, + reduced_base_chs=None, + act_layer=nn.SiLU, + gate_fn=torch.sigmoid, + divisor=1, + **_, + ): + super(SqueezeExcite, self).__init__() + reduced_chs = reduced_base_chs + self.conv_reduce = nn.Conv1d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv1d(reduced_chs, in_chs, 1, bias=True) + self.gate_fn = gate_fn + + def forward(self, x): + x_se = x.mean((2,), keepdim=True) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + return x * self.gate_fn(x_se) + + +class FastGlobalAvgPool1d(nn.Module): + def __init__(self, flatten=False): + super(FastGlobalAvgPool1d, self).__init__() + self.flatten = flatten + + def forward(self, x): + if self.flatten: + in_size = x.size() + return x.view((in_size[0], in_size[1], -1)).mean(dim=2) + else: + return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1) + + + +class GBN(torch.nn.Module): + """ + Ghost Batch Normalization + https://arxiv.org/abs/1705.08741 + """ + + def __init__(self, input_dim, drop, act, virtual_batch_size=32, momentum=0.1): + super(GBN, self).__init__() + + self.input_dim = input_dim + self.virtual_batch_size = virtual_batch_size + self.bn = nn.BatchNorm1d(self.input_dim, momentum=momentum) + self.drop = drop + self.act = act + + def forward(self, x): + # chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0) + # res = [self.bn(x_) for x_ in chunks] + # return self.drop(self.act(torch.cat(res, dim=0))) + # x = self.bn(x) + # x = self.act(x) + # x = self.drop(x) + # return x + return self.drop(self.act(self.bn(x))) + + +def replace_bn(parent): + for n, m in parent.named_children(): + if type(m) is timm.layers.norm_act.BatchNormAct2d: + # if type(m) is nn.BatchNorm2d: + # print(type(m)) + 
+            setattr(
+                parent,
+                n,
+                GBN(m.num_features, m.drop, m.act),
+            )
+        else:
+            replace_bn(m)
+
+def replace_se(parent):
+    for n, m in parent.named_children():
+        if type(m) is timm.models._efficientnet_blocks.SqueezeExcite:
+            setattr(
+                parent,
+                n,
+                SqueezeExcite(
+                    m.conv_reduce.in_channels,
+                    reduced_base_chs=m.conv_reduce.out_channels,
+                ),
+            )
+        else:
+            replace_se(m)
+
+def replace_conv(parent, ds_rate):
+    for n, m in parent.named_children():
+        if type(m) is nn.Conv2d:
+            if ds_rate == 2:
+                setattr(
+                    parent,
+                    n,
+                    nn.Conv1d(
+                        m.in_channels,
+                        m.out_channels,
+                        kernel_size=m.kernel_size[0],
+                        stride=m.stride[0],
+                        padding=m.padding[0],
+                        bias=m.bias is not None,  # was bias=m.kernel_size[0], which forced bias on
+                        groups=m.groups,
+                    ),
+                )
+            else:
+                setattr(
+                    parent,
+                    n,
+                    nn.Conv1d(
+                        m.in_channels,
+                        m.out_channels,
+                        kernel_size=m.kernel_size[0] if m.kernel_size[0] == 1 else 5,
+                        stride=m.stride[0] if m.stride[0] == 1 else ds_rate,
+                        padding=m.padding[0] if m.padding[0] == 0 else 2,
+                        bias=m.bias is not None,  # was bias=m.kernel_size[0]
+                        groups=m.groups,
+                    ),
+                )
+        else:
+            replace_conv(m, ds_rate)
+
+def create_mobilenetv3(network, ds_rate=2, in_chans=2):
+    replace_se(network)
+    replace_bn(network)
+    replace_conv(network, ds_rate)
+    network.global_pool = FastGlobalAvgPool1d()
+
+    network.conv_stem = nn.Conv1d(
+        in_channels=in_chans,
+        out_channels=network.conv_stem.out_channels,
+        kernel_size=network.conv_stem.kernel_size,
+        stride=network.conv_stem.stride,
+        padding=network.conv_stem.padding,
+        bias=network.conv_stem.bias is not None,  # was bias=...kernel_size, which forced bias on
+        groups=network.conv_stem.groups,
+    )
+
+    return network
+
+def mobilenetv3(
+    model_size='mobilenetv3_small_050',
+    num_classes: int = 10,
+    drop_rate: float = 0,
+    drop_path_rate: float = 0,
+    in_chans=2,
+):
+    mdl = create_mobilenetv3(
+        timm.create_model(
+            model_size,
+            num_classes=num_classes,
+            in_chans=in_chans,
+            drop_path_rate=drop_path_rate,
+            drop_rate=drop_rate,
+            exportable=True,
+        ),
+        in_chans=in_chans,
+    )
+    return mdl
+
+
+class Simple1DCNN(nn.Module):
+    def __init__(self, in_chans=2, num_classes=4):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv1d(in_chans, 32, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.MaxPool1d(2),
+            nn.Conv1d(32, 64, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.AdaptiveAvgPool1d(1),
+            nn.Flatten(),
+            nn.Linear(64, num_classes)
+        )
+
+    def forward(self, x):
+        return self.net(x)  # x shape: [B, 2, 128]
+
+def simple_cnn(in_chans=2, num_classes=4):
+    return Simple1DCNN(in_chans, num_classes)
diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_0264b4a.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_0264b4a.npy
new file mode 100644
index 0000000..c099ad4
--- /dev/null
+++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_0264b4a.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b05da5de57afd3e9b70fe69483ba3dfec7c9e5c16509558964de97fb461cdf2
+size 17104
diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_0b3b80f.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_0b3b80f.npy
new file mode 100644
index 0000000..9d962d5
--- /dev/null
+++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_0b3b80f.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34787243d616952741ae11444478f967e9d38851609f006cafdbc0d0ce38d7fe
+size 17103
diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_1effc4c.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_1effc4c.npy
new file mode 100644
index 0000000..deee079
--- /dev/null
+++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_1effc4c.npy
@@ -0,0 +1,3 @@
+version
https://git-lfs.github.com/spec/v1 +oid sha256:c419fe64b7cba90f5323154742aa73ce759abafda9893ddc90d57db832b6b8a5 +size 17103 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_37a73db.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_37a73db.npy new file mode 100644 index 0000000..9c6df67 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_37a73db.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65a25e6375871ab3556900abc7b96c9240827ec6c51d5d2f2d864cf0786128f4 +size 17103 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_3d557a9.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_3d557a9.npy new file mode 100644 index 0000000..ad80691 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_3d557a9.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f38ac9a0f7192325b3cf9c03cc311439bf2b020772f019587a1e338969ae0a1f +size 17107 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_442fcb9.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_442fcb9.npy new file mode 100644 index 0000000..09158a9 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_442fcb9.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f76cae7448b608f024baea6685978929458103bc83196aae2ddd1a47af69ea43 +size 17107 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_491c457.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_491c457.npy new file mode 100644 index 0000000..7af0b08 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_491c457.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9f4f8c00a9937d7245d017b792df0390692a9dd201fcd156d24255a6f0450692 +size 17103 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_4fff84f.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_4fff84f.npy new file mode 100644 index 0000000..32f69ec --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_4fff84f.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47d51d307292089b66018d00ff39ba1e3ccafadc87829f46d5ad45599c5c7d97 +size 17104 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_6676600.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_6676600.npy new file mode 100644 index 0000000..eb3825b --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_6676600.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:108e4fff19e390b1b66da936752e8341b4cf36443fd511d4114b5082f8b20580 +size 17104 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_6d35ff9.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_6d35ff9.npy new file mode 100644 index 0000000..736b9aa --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_6d35ff9.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c62ff017c8666fa06555fdd102f75409781d98679f010e11fca6105fbd8b7e85 +size 17104 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_6d85f3e.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_6d85f3e.npy new file mode 100644 index 0000000..9f70c2d --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_6d85f3e.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2f07fecedf5718257bf7837c2f28d7a6f83d9bfa43473a6d0aa7baa69a1b7ac +size 17104 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_85a8c83.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_85a8c83.npy new file mode 100644 index 0000000..3771ab4 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_85a8c83.npy @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:db2e9425c889ee30d3f7c8e15996d3ce201dfb413a3e864ec95ac1c1e15d56a4 +size 17107 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_940988e.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_940988e.npy new file mode 100644 index 0000000..3e280c6 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_940988e.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9be736990738b2fba0e70a59b920d6ebadfbf7e4d4ba7c01afb4712d4b7b8ad0 +size 17104 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_9f88dc2.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_9f88dc2.npy new file mode 100644 index 0000000..9915d10 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_9f88dc2.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b684a66eb335dc79ec669b9e8a6b2a73babd1c2a6394dc00b27f62762be20e68 +size 17103 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_a4a6ba6.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_a4a6ba6.npy new file mode 100644 index 0000000..06fa5e1 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_a4a6ba6.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e36174699020ff6f5571b43b323ab3daa59cf819e26424cb462d26577637c91 +size 17103 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_a60964b.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_a60964b.npy new file mode 100644 index 0000000..b56a221 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_a60964b.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3dfd1c4ef497a0530235fff62a9b642ae356845eca9e5db25062ce7773e21f96 +size 17104 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_ad350fe.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_ad350fe.npy new file mode 100644 index 0000000..e45671e --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_ad350fe.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ad8f2ec6dae3c1ecf46f5e18a782436ccee6cbd7154340b156a5790d786664d2 +size 17107 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_ae5224a.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_ae5224a.npy new file mode 100644 index 0000000..7f54193 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_ae5224a.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8fba3e786b31e41c624fa01228c81a430b32b562d6ae245840e72d25aa83d556 +size 17103 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_b68f080.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_b68f080.npy new file mode 100644 index 0000000..60bfa88 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_b68f080.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb45074f11187e4e2461efe47536b9704f47bc2ccd59801193fbc5c1a1261c09 +size 17103 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_c00477b.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_c00477b.npy new file mode 100644 index 0000000..3dbc92d --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_c00477b.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2ddc93a5d31248700f4a5152af49d951b80712132c35c6f26f6838b5c5613402 +size 17104 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_cca57ca.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_cca57ca.npy new file mode 100644 index 0000000..d323569 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_cca57ca.npy @@ 
-0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8e8aa5d85ca26b34e47b02022aae3006981f1d386a91a8b6e6afe54d75bacaf +size 17106 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_db8a5b4.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_db8a5b4.npy new file mode 100644 index 0000000..47f65fc --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_db8a5b4.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4cb0bcde7d2ef19ead6bf460198cb8c3b32a4324f0148bb95425ba1bfe96f63c +size 17106 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_dd021f7.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_dd021f7.npy new file mode 100644 index 0000000..7b0edad --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_dd021f7.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7d1a52d4634b04cc6dd0f50f2c6882f139e0165a0a29be1076cb784f0629a10f +size 17104 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_e0cc41d.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_e0cc41d.npy new file mode 100644 index 0000000..ca89984 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_e0cc41d.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f54fe5742a134f08dd21f3fbe8c8f2c72905dcd03ce771550d8b785d9c27aa8d +size 17103 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_e61d9bf.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_e61d9bf.npy new file mode 100644 index 0000000..d6cd072 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_e61d9bf.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4d21bdc3474e20bd951da6e61bebfbd091113594f1f4d5d02fa173ca813a6f06 +size 17106 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_f024082.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_f024082.npy new file mode 100644 index 0000000..4768b30 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_f024082.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cba2de88026b625c954c9c6b875ee764e60ebe76587f00ba5e6a7f7f5dcfa814 +size 17103 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_f2013fa.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_f2013fa.npy new file mode 100644 index 0000000..e71fe04 --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_f2013fa.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f278f5ed086f9edd365200f2498f24ef26aa8259e684b932cf804dc14c09b282 +size 17104 diff --git a/data/recordings/rec_0Hz_2025-05-15_09-45-10_f2ae593.npy b/data/recordings/rec_0Hz_2025-05-15_09-45-10_f2ae593.npy new file mode 100644 index 0000000..192ed4a --- /dev/null +++ b/data/recordings/rec_0Hz_2025-05-15_09-45-10_f2ae593.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e327873b5962e61ca11bd0221e79547f4f69f0f655a6e44450d80b7ac4ec9ff9 +size 17106
diff --git a/data/scripts/data_gen.py b/data/scripts/data_gen.py
new file mode 100644
index 0000000..5ca8563
--- /dev/null
+++ b/data/scripts/data_gen.py
@@ -0,0 +1,69 @@
+from utils.data import Recording
+import numpy as np
+from utils.signal import block_generator
+
+mods = {
+    "bpsk": {"num_bits_per_symbol": 1, "constellation_type": "psk"},
+    "qpsk": {"num_bits_per_symbol": 2, "constellation_type": "psk"},
+    "qam16": {"num_bits_per_symbol": 4, "constellation_type": "qam"},
+    "qam64": {"num_bits_per_symbol": 6, "constellation_type": "qam"},
+}
+
+
+def generate_modulated_signals():
+    for modulation in ["bpsk", "qpsk", "qam16", "qam64"]:
+        for snr in np.arange(-6, 13, 3):
+
+            recording_length = 1024
+            beta = 0.3  # pulse-shaping rolloff factor; can be varied to add diversity
+            sps = 4  # samples per symbol, i.e. the relative bandwidth of the digital signal; can also be varied
+
+            # the blocks don't take the string 'qpsk' directly, so look up the parameters in 'mods'
+            constellation_type = mods[modulation]["constellation_type"]
+            num_bits_per_symbol = mods[modulation]["num_bits_per_symbol"]
+
+            # construct the digital modulation chain with these parameters:
+            # bit source -> mapper -> upsampling -> pulse shaping
+            bit_source = block_generator.RandomBinarySource()
+            mapper = block_generator.Mapper(
+                constellation_type=constellation_type,
+                num_bits_per_symbol=num_bits_per_symbol,
+            )
+            upsampler = block_generator.Upsampling(factor=sps)
+            pulse_shaping_filter = block_generator.RaisedCosineFilter(
+                upsampling_factor=sps, beta=beta
+            )
+
+            pulse_shaping_filter.connect_input([upsampler])
+            upsampler.connect_input([mapper])
+            mapper.connect_input([bit_source])
+
+            modulation_recording = pulse_shaping_filter.record(
+                num_samples=recording_length
+            )
+
+            # add noise: estimate the signal power, then generate AWGN at the requested SNR.
+            # SNR is in dB, so the power ratio is 10 ** (-snr / 10); the variance is split
+            # evenly between the real and imaginary components.
+            signal_power = np.mean(np.abs(modulation_recording.data[0]) ** 2)
+            awgn_source = block_generator.AWGNSource(
+                variance=(signal_power / 2) * (10 ** (-snr / 10))
+            )
+            noise = awgn_source.record(num_samples=recording_length)
+            samples_with_noise = modulation_recording.data + noise.data
+            output_recording = Recording(data=samples_with_noise)
+
+            # add metadata for ML later
+            output_recording.add_to_metadata(key="modulation", value=modulation)
+            output_recording.add_to_metadata(key="snr", value=int(snr))
+            output_recording.add_to_metadata(key="beta", value=beta)
+            output_recording.add_to_metadata(key="sps", value=sps)
+
+            # view if you want
+            # output_recording.view()
+
+            # save to file
+            output_recording.to_npy()  # optionally add path and filename parameters
+
+
+if __name__ == "__main__":
+    generate_modulated_signals()
diff --git a/data/scripts/produce_dataset.py b/data/scripts/produce_dataset.py
new file mode 100644
index 0000000..e64daef
--- /dev/null
+++ b/data/scripts/produce_dataset.py
@@ -0,0 +1,152 @@
+import os, h5py, numpy as np
+from utils.io import from_npy
+from split_dataset import split, split_recording
+from helpers.app_settings import get_app_settings
+
+meta_dtype = np.dtype(
+    [
+        ("rec_id", "S256"),
+        ("snippet_idx", np.int32),
+        ("modulation", "S32"),
+        ("snr", np.int32),
+        ("beta", np.float32),
+        ("sps", np.int32),
+    ]
+)
+
+info_dtype = np.dtype(
+    [
+        ("num_records", np.int32),
+        ("dataset_name", "S64"),  # up to 64-byte UTF-8 strings
+        ("creator", "S64"),
+    ]
+)
+
+
+def write_hdf5_file(records, output_path, dataset_name="data"):
+    """
+    Writes a list of records to an HDF5 file.
+
+    Parameters:
+        records (list): List of (data, metadata) tuples to be written to the file
+        output_path (str): Path to the output HDF5 file
+        dataset_name (str): Name of the dataset in the HDF5 file (default: "data")
+
+    Returns:
+        str: Path to the created HDF5 file
+    """
+    meta_arr = np.empty(len(records), dtype=meta_dtype)
+    for i, (_, md) in enumerate(records):
+        meta_arr[i] = (
+            str(md["rec_id"]).encode("utf-8"),
+            md["snippet_idx"],
+            md["modulation"].encode("utf-8"),
+            int(md["snr"]),
+            float(md["beta"]),
+            int(md["sps"]),
+        )
+
+    first_rec, _ = records[0]  # records[0] is a (data, metadata) tuple
+    shape, dtype = first_rec.shape, first_rec.dtype
+
+    with h5py.File(output_path, "w") as hf:
+        dset = hf.create_dataset(
+            dataset_name, shape=(len(records),) + shape, dtype=dtype, compression="gzip"
+        )
+
+        for idx, (snip, md) in enumerate(records):
+            dset[idx, ...] = snip
+
+        mg = hf.create_group("metadata")
+        mg.create_dataset("metadata", data=meta_arr, compression="gzip")
+
+        print(dset.shape, f"snippets created in {dataset_name}")
+
+        info_arr = np.array(
+            [
+                (
+                    len(records),
+                    dataset_name.encode("utf-8"),
+                    b"generate_dataset.py",  # already bytes
+                )
+            ],
+            dtype=info_dtype,
+        )
+
+        mg.create_dataset("dataset_info", data=info_arr)
+
+    return output_path
+
+
+def complex_to_channel(data):
+    """
+    Convert complex-valued IQ data of shape (1, N) to a 2-channel real array of shape (2, N).
+    """
+    assert np.iscomplexobj(data)  # the data is expected to be complex (a+bi)
+    real = np.real(data[0])  # (N,)
+    imag = np.imag(data[0])  # (N,)
+    stacked = np.stack([real, imag], axis=0)  # shape (2, N)
+    return stacked.astype(np.float32)
+
+
+def generate_datasets(cfg):
+    """
+    Generates train/val datasets from a folder of .npy recordings and saves them to HDF5.
+
+    Parameters:
+        cfg: Dataset configuration with input_dir, output_dir, num_slices,
+             train_split, and seed attributes
+
+    Returns:
+        tuple: Paths to the created train and validation HDF5 files
+    """
+    os.makedirs(cfg.output_dir, exist_ok=True)
+
+    # we assume the recordings are in .npy format
+    files = os.listdir(cfg.input_dir)
+    if not files:
+        raise ValueError("No files found in the specified directory.")
+
+    records = []
+    for fname in files:
+        rec = from_npy(os.path.join(cfg.input_dir, fname))
+
+        data = rec.data  # a complex numpy array with shape (1, N)
+        data = complex_to_channel(data)  # convert to a 2-channel real array
+
+        md = rec.metadata  # pull metadata from the recording
+        md.setdefault("rec_id", len(records))  # key must match meta_dtype's "rec_id" field
+        records.append((data, md))
+
+    # split each recording into snippets
+    records = split_recording(records, cfg.num_slices)
+
+    train_records, val_records = split(records, cfg.train_split, cfg.seed)
+
+    train_path = os.path.join(cfg.output_dir, "train.h5")
+    val_path = os.path.join(cfg.output_dir, "val.h5")
+
+    write_hdf5_file(train_records, train_path, "training_data")
+    write_hdf5_file(val_records, val_path, "validation_data")
+
+    return train_path, val_path
+
+
+def main():
+    settings = get_app_settings()
+    dataset_cfg = settings.dataset
+    train_path, val_path = generate_datasets(dataset_cfg)
+    print(f"✅ Train: {train_path}\n✅ Val: {val_path}")
+
+
+if __name__ == "__main__":
+    main()
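Because both the samples and their labels live inside the HDF5 file, it is worth spot-checking the output before training. A minimal sketch, assuming the file was written by write_hdf5_file above; the path is illustrative and should point at the output_dir used by generate_datasets():

    import h5py
    import numpy as np

    with h5py.File("data/dataset/train.h5", "r") as f:
        x = f["training_data"]
        meta = f["metadata"]["metadata"][:]
        info = f["metadata"]["dataset_info"][:]

        print("data shape:", x.shape)  # e.g. (num_snippets, 2, N)
        print("records:", info["num_records"][0],
              "creator:", info["creator"][0].decode())

        # per-class snippet counts, decoded from the fixed-width byte strings
        mods, counts = np.unique(meta["modulation"], return_counts=True)
        for m, c in zip(mods, counts):
            print(m.decode("utf-8"), c)

If the per-class counts are badly skewed, the stratified split in split_dataset.py (next file) is the first place to look.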
diff --git a/data/scripts/split_dataset.py b/data/scripts/split_dataset.py
new file mode 100644
index 0000000..061d7d5
--- /dev/null
+++ b/data/scripts/split_dataset.py
@@ -0,0 +1,79 @@
+import random
+from collections import defaultdict
+
+
+def split(dataset, train_frac=0.8, seed=42, label_key="modulation"):
+    """
+    Splits a dataset into train and validation sets, stratified by label.
+
+    Whole recordings (not individual snippets) are assigned to a split, so
+    snippets from the same recording never appear in both sets.
+
+    Parameters:
+        dataset (list): List of (data, metadata) tuples to be split
+        train_frac (float): Fraction of recordings per label assigned to training
+        seed (int): Random seed for reproducible shuffling
+        label_key (str): Metadata key used for stratification
+
+    Returns:
+        tuple: (train_dataset, val_dataset) lists of (data, metadata) tuples
+    """
+    rec_buckets = defaultdict(list)
+    for data, md in dataset:
+        rec_buckets[md["rec_id"]].append((data, md))
+
+    rec_labels = {}  # store the label of each recording
+    for rec_id, group in rec_buckets.items():
+        label = group[0][1][label_key]
+        if isinstance(label, bytes):  # decode byte-string labels
+            label = label.decode("utf-8")
+        rec_labels[rec_id] = label
+
+    label_rec_ids = defaultdict(list)  # group rec_ids by label
+    for rec_id, label in rec_labels.items():
+        label_rec_ids[label].append(rec_id)
+
+    random.seed(seed)
+    train_recs, val_recs = set(), set()
+
+    for label, rec_ids in label_rec_ids.items():
+        random.shuffle(rec_ids)
+        split_idx = int(len(rec_ids) * train_frac)
+        # take train_frac of the rec_ids per label, so every modulation is represented
+        train_recs.update(rec_ids[:split_idx])
+        val_recs.update(rec_ids[split_idx:])
+
+    # add the assigned recordings to the train and val datasets
+    train_dataset, val_dataset = [], []
+    for rec_id, group in rec_buckets.items():
+        if rec_id in train_recs:
+            train_dataset.extend(group)
+        elif rec_id in val_recs:
+            val_dataset.extend(group)
+
+    return train_dataset, val_dataset
+
+
+def split_recording(recording_list, num_snippets):
+    """
+    Splits each recording into a number of equal-length snippets.
+
+    Parameters:
+        recording_list (list): List of (data, metadata) tuples to be split
+        num_snippets (int): Number of snippets to cut each recording into
+
+    Returns:
+        list: List of (snippet, metadata) tuples
+    """
+    snippet_list = []
+
+    for data, md in recording_list:
+        C, N = data.shape
+        L = N // num_snippets
+        for i in range(num_snippets):
+            start = i * L
+            end = (i + 1) * L
+            snippet = data[:, start:end]
+
+            # copy the metadata, adding a snippet index
+            snippet_md = md.copy()
+            snippet_md["snippet_idx"] = i
+            snippet_list.append((snippet, snippet_md))
+    return snippet_list
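To see the two helpers working together, here is a self-contained sanity check with synthetic recordings shaped like the (2, N) float arrays that complex_to_channel produces. It is a sketch only, and assumes split_dataset.py is importable (e.g. run from data/scripts/).

    import numpy as np
    from split_dataset import split, split_recording

    # build a few fake recordings: two per modulation class
    rng = np.random.default_rng(0)
    records = []
    for rec_id, modulation in enumerate(["bpsk", "bpsk", "qpsk", "qpsk"]):
        data = rng.standard_normal((2, 1024)).astype(np.float32)
        md = {"rec_id": rec_id, "modulation": modulation,
              "snr": 0, "beta": 0.3, "sps": 4}
        records.append((data, md))

    snippets = split_recording(records, num_snippets=8)
    assert len(snippets) == 4 * 8
    assert snippets[0][0].shape == (2, 128)  # 1024 // 8 samples per snippet

    train, val = split(snippets, train_frac=0.5, seed=42)

    # snippets from one recording must land in exactly one split
    train_ids = {md["rec_id"] for _, md in train}
    val_ids = {md["rec_id"] for _, md in val}
    assert train_ids.isdisjoint(val_ids)
    print("train recordings:", sorted(train_ids), "val recordings:", sorted(val_ids))

With train_frac=0.5 and two recordings per label, each class contributes one recording to each split, which is exactly the stratification guarantee the per-label shuffle in split() is designed to provide.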