From ba796961a36c78907767f322dab0dbe01e59f7a0 Mon Sep 17 00:00:00 2001 From: liyuxiao2 Date: Thu, 22 May 2025 14:11:18 -0400 Subject: [PATCH] reorganized file structure --- data/models/interference_recognition.ipynb | 3428 ----------------- data/{models => training}/cm_plotter.py | 30 +- data/{models => training}/mobilenetv3.py | 78 +- .../modulation_dataset.py | 30 +- data/training/train.py | 153 + 5 files changed, 225 insertions(+), 3494 deletions(-) delete mode 100644 data/models/interference_recognition.ipynb rename data/{models => training}/cm_plotter.py (73%) rename data/{models => training}/mobilenetv3.py (74%) rename data/{models => training}/modulation_dataset.py (81%) create mode 100644 data/training/train.py diff --git a/data/models/interference_recognition.ipynb b/data/models/interference_recognition.ipynb deleted file mode 100644 index d617ef7..0000000 --- a/data/models/interference_recognition.ipynb +++ /dev/null @@ -1,3428 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['bpsk', 'qpsk', 'qam16', 'qam64']\n" - ] - } - ], - "source": [ - "import sys, os\n", - "project_root = os.path.abspath(os.path.join(os.getcwd(), \"../..\"))\n", - "if project_root not in sys.path:\n", - " sys.path.insert(0, project_root)\n", - "import lightning as L\n", - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from torch.utils.data import Dataset\n", - "import torchmetrics\n", - "from sklearn.preprocessing import LabelEncoder\n", - "from sklearn.metrics import classification_report\n", - "import matplotlib.pyplot as plt\n", - "from sklearn.metrics import RocCurveDisplay\n", - "import h5py\n", - "from helpers.app_settings import get_app_settings\n", - "from modulation_dataset import ModulationH5Dataset\n", - "import torch.nn as nn\n", - "\n", - "\n", - "settings = get_app_settings()\n", - "dataset = settings.dataset.modulation_types\n", - "\n", - "print(dataset)\n", - "\n", - "\n", - "\n", - "import mobilenetv3\n", - "from cm_plotter import plot_confusion_matrix" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Training Settings" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "# Set this to true to train a new model from scratch, otherwise it will load an existing checkpoint\n", - "# Training will overwrite an existing checkpoint\n", - "train_flag = True\n", - "\n", - "batch_size = 128\n", - "epochs = 100\n", - "\n", - "checkpoint_filename = f'/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model'\n", - "\n", - "train_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5'\n", - "val_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5'\n", - "\n", - "\n", - "dataset_name = 'Modulation Inference - Initial Model'\n", - "metadata_names = 'Modulation'\n", - "label = 'modulation'" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Precision Settings" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "torch.set_float32_matmul_precision('high')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "X shape: torch.Size([128, 2, 128])\n", - "Y values: tensor([1, 3, 1, 0, 0, 3, 2, 1, 2, 1])\n", - "Train Dataset: 160 examples, each of shape torch.Size([2, 128])\n", - "Labels found in metadata: modulation\n", - "Example label values (first 5): ['qam64', 'qam64', 'qam64', 'qam64', 'qam64']\n", - "Unique labels in training data: ['bpsk', 'qam64', 'qpsk', 'qam16']\n", - "\n", - "Label being used: modulation\n", - "Number of classes: 4\n", - "Class index mapping: {'bpsk': 0, 'qam16': 1, 'qam64': 2, 'qpsk': 3}\n" - ] - } - ], - "source": [ - "\n", - "ds_train = ModulationH5Dataset(train_data, label, data_key= \"training_data\")\n", - "ds_val = ModulationH5Dataset(val_data, label, data_key=\"validation_data\")\n", - "\n", - "train_loader = torch.utils.data.DataLoader(\n", - " dataset=ds_train,\n", - " batch_size=batch_size,\n", - " shuffle=True,\n", - " num_workers=8,\n", - " )\n", - "val_loader = torch.utils.data.DataLoader(\n", - " dataset=ds_val,\n", - " batch_size=2048,\n", - " shuffle=False,\n", - " num_workers=8,\n", - " )\n", - "\n", - "for x, y in train_loader:\n", - " print(\"X shape:\", x.shape)\n", - " print(\"Y values:\", y[:10])\n", - " break\n", - "\n", - "\n", - "print(f'Train Dataset: {len(ds_train)} examples, each of shape {ds_train[0][0].shape}')\n", - "print(f'Labels found in metadata: {label}')\n", - "print(f'Example label values (first 5): {[row[label].decode(\"utf-8\") for row in ds_train.metadata[:5]]}')\n", - "\n", - "unique_labels = list(set([row[label].decode(\"utf-8\") for row in ds_train.metadata]))\n", - "print(f'Unique labels in training data: {unique_labels}')\n", - "\n", - "\n", - "\n", - "num_classes = len(ds_train.label_encoder.classes_)\n", - "print(f'\\nLabel being used: {label}')\n", - "print(f'Number of classes: {num_classes}')\n", - "print(f'Class index mapping: {dict(zip(ds_train.label_encoder.classes_, range(num_classes)))}')" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Validation Dataset: 64 examples of length 128\n", - "Annotation fields: ('rec_id', 'snippet_idx', 'modulation', 'snr', 'beta', 'sps')\n", - "Labels:\n", - "rec_id\n", - "\trec_id: ['491c4575bc4b3605997a42031f886c2bcc1dfcb36f7cd5c9991958dd234fd150', '66766004b23abcb02301e70dd8bbb21f802bc46f7c4ae2c65f36d4b4afdacfd0', '6d85f3e640d771e2106eb3fcf2c277def16292edbd37de62d5bd8f6b41962e95', '940988edd5df81d1170c5794ec259629cddefa674049e871f5229a9a24cb94bc', 'a4a6ba686c081920f6f6c7218a300939fc75bdf39d24044c0e084f75d264e9e3', 'dd021f7cce295866dfd8e7cfd109577a0b5f8dbbe25c2624ca3bb2d8fafc18e5', 'e61d9bfb32bb2d4bfa0666b764e32a018356bf8517d820b731d911bd1b67e4f6', 'f024082e859fc2b044bc410eeb8285ebae297c25808e596b310d11ef40f9b70c']\n", - "snippet_idx\n", - "\tsnippet_idx: [0, 1, 2, 3, 4, 5, 6, 7]\n", - "modulation\n", - "\tmodulation: ['bpsk', 'qam16', 'qam64', 'qpsk']\n", - "snr\n", - "\tsnr: [-3, 0, 3, 6]\n", - "beta\n", - "\tbeta: [0.3]\n", - "sps\n", - "\tsps: [4]\n", - "\n", - "Label being used: modulation\n", - "\tqam64: 16\n", - "\tqam16: 16\n", - "\tbpsk: 16\n", - "\tqpsk: 16\n" - ] - } - ], - "source": [ - "print(f'Validation Dataset: {len(ds_val)} examples of length {ds_val[0][0].shape[1]}')\n", - "print(f'Annotation fields: {ds_val.metadata.dtype.names}')\n", - "print('Labels:')\n", - "\n", - "\n", - "for a in ds_val.metadata.dtype.names:\n", - " print(a)\n", - " # decode only if dtype is bytes\n", - " values = [row[a].decode('utf-8') if isinstance(row[a], bytes) else row[a] for row in ds_val.metadata]\n", - " unique_vals = sorted(set(values))\n", - " print(f'\\t{a}: {unique_vals}')\n", - "\n", - "print(f'\\nLabel being used: {label}')\n", - "\n", - "# Print value counts for the label field\n", - "from collections import Counter\n", - "label_values = [row[label].decode('utf-8') for row in ds_val.metadata]\n", - "value_counts = Counter(label_values)\n", - "for k, v in value_counts.items():\n", - " print(f'\\t{k}: {v}')\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "hparams = {\n", - " 'drop_path_rate': 0.2,\n", - " 'drop_rate': 0.5,\n", - " 'learning_rate': 3e-4,\n", - " 'wd': 0.2}\n", - "class RFClassifier(L.LightningModule):\n", - " def __init__(self, model):\n", - " super().__init__()\n", - "\n", - " self.model = model\n", - " self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)\n", - " \n", - "\n", - " def forward(self,x):\n", - " return self.model(x)\n", - " \n", - " def configure_optimizers(self):\n", - " optimizer = torch.optim.AdamW(self.parameters(),\n", - " lr = hparams['learning_rate'],\n", - " weight_decay = hparams['wd'],\n", - " )\n", - " lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(\n", - " optimizer, \n", - " T_0=len(train_loader), \n", - " )\n", - " return {\n", - " 'optimizer': optimizer,\n", - " 'lr_scheduler': {\n", - " 'scheduler': lr_scheduler,\n", - " 'interval': 'step'\n", - " }\n", - " }\n", - " # return optimizer\n", - " \n", - " def training_step(self, batch, batch_idx):\n", - " x, y = batch\n", - " print(x.shape)\n", - " y_hat = self(x)\n", - " loss = F.cross_entropy(y_hat, y)\n", - " self.log('train_loss', loss,\n", - " on_epoch=True,\n", - " prog_bar=True,\n", - " )\n", - " return loss\n", - "\n", - " def validation_step(self, batch, batch_idx):\n", - " x, y = batch\n", - " \n", - " y_hat = self(x)\n", - " loss = F.cross_entropy(y_hat, y)\n", - " self.accuracy(y_hat, y)\n", - " self.log('val_loss', loss, prog_bar=True)\n", - " self.log('val_acc', self.accuracy, prog_bar=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Train" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using bfloat16 Automatic Mixed Precision (AMP)\n", - "GPU available: True (mps), used: True\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n" - ] - } - ], - "source": [ - "model = RFClassifier(mobilenetv3.mobilenetv3(\n", - " model_size='mobilenetv3_small_050', \n", - " num_classes=num_classes, \n", - " drop_rate = hparams['drop_rate'], \n", - " drop_path_rate = hparams['drop_path_rate']\n", - " )\n", - " )\n", - "\n", - "\n", - "\n", - "checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(\n", - " filename=checkpoint_filename,\n", - " save_top_k=True,\n", - " verbose=True,\n", - " monitor='val_acc',\n", - " mode='max',\n", - " enable_version_counter = False,\n", - ")\n", - "trainer = L.Trainer(\n", - " max_epochs=epochs,\n", - " callbacks = [checkpoint_callback,\n", - " ],\n", - " accelerator='gpu',\n", - " devices=1,\n", - " benchmark=True,\n", - " precision='bf16-mixed',\n", - " logger=False\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - " | Name | Type | Params | Mode \n", - "--------------------------------------------------------\n", - "0 | model | MobileNetV3 | 550 K | train\n", - "1 | accuracy | MulticlassAccuracy | 0 | train\n", - "--------------------------------------------------------\n", - "550 K Trainable params\n", - "0 Non-trainable params\n", - "550 K Total params\n", - "2.201 Total estimated model params size (MB)\n", - "259 Modules in train mode\n", - "0 Modules in eval mode\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7bf6aa152c7748b8a2cc4e0411949154", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: | | 0/? [00:00\n", - "Traceback (most recent call last):\n", - " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py\", line 1663, in __del__\n", - " self._shutdown_workers()\n", - " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py\", line 1627, in _shutdown_workers\n", - " w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)\n", - " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/multiprocessing/process.py\", line 149, in join\n", - " res = self._popen.wait(timeout)\n", - " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/multiprocessing/popen_fork.py\", line 40, in wait\n", - " if not wait([self.sentinel], timeout):\n", - " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/multiprocessing/connection.py\", line 931, in wait\n", - " ready = selector.select(timeout)\n", - " File \"/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/selectors.py\", line 416, in select\n", - " fd_event_list = self._selector.poll(timeout)\n", - "KeyboardInterrupt: \n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c8e35b4f64e84ab39f20b0a36dc0116d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: | | 0/? [00:00 4\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_loader\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:561\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 559\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 560\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshould_stop \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m--> 561\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 562\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 563\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:48\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 48\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 51\u001b[0m _call_teardown_hook(trainer)\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:599\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 592\u001b[0m download_model_from_registry(ckpt_path, \u001b[38;5;28mself\u001b[39m)\n\u001b[1;32m 593\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 594\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 595\u001b[0m ckpt_path,\n\u001b[1;32m 596\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 597\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 598\u001b[0m )\n\u001b[0;32m--> 599\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 601\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 602\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1012\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1007\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 1009\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 1010\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 1011\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m-> 1012\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1014\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 1015\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 1016\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 1017\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1056\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1054\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1055\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1056\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1057\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1058\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:216\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 215\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 216\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 217\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 218\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:455\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 454\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 455\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:151\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madvance(data_fetcher)\n\u001b[0;32m--> 151\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mon_advance_end\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 153\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:370\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.on_advance_end\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 366\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_should_accumulate():\n\u001b[1;32m 367\u001b[0m \u001b[38;5;66;03m# clear gradients to not leave any unused memory during validation\u001b[39;00m\n\u001b[1;32m 368\u001b[0m call\u001b[38;5;241m.\u001b[39m_call_lightning_module_hook(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_validation_model_zero_grad\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 370\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 371\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 372\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39m_logger_connector\u001b[38;5;241m.\u001b[39m_first_loop_iter \u001b[38;5;241m=\u001b[39m first_loop_iter\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:179\u001b[0m, in \u001b[0;36m_no_grad_context.._decorator\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 177\u001b[0m context_manager \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mno_grad\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context_manager():\n\u001b[0;32m--> 179\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mloop_run\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py:138\u001b[0m, in \u001b[0;36m_EvaluationLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 137\u001b[0m dataloader_iter \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 138\u001b[0m batch, batch_idx, dataloader_idx \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m previous_dataloader_idx \u001b[38;5;241m!=\u001b[39m dataloader_idx:\n\u001b[1;32m 140\u001b[0m \u001b[38;5;66;03m# the dataloader has changed, notify the logger connector\u001b[39;00m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_store_dataloader_outputs()\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/fetchers.py:134\u001b[0m, in \u001b[0;36m_PrefetchDataFetcher.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone \u001b[38;5;241m=\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatches\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 133\u001b[0m \u001b[38;5;66;03m# this will run only when no pre-fetching was done.\u001b[39;00m\n\u001b[0;32m--> 134\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__next__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 135\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 136\u001b[0m \u001b[38;5;66;03m# the iterator is empty\u001b[39;00m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/loops/fetchers.py:61\u001b[0m, in \u001b[0;36m_DataFetcher.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_start_profiler()\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 61\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miterator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/utilities/combined_loader.py:341\u001b[0m, in \u001b[0;36mCombinedLoader.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__next__\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m _ITERATOR_RETURN:\n\u001b[1;32m 340\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_iterator \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 341\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_iterator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 342\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_iterator, _Sequential):\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/lightning/pytorch/utilities/combined_loader.py:142\u001b[0m, in \u001b[0;36m_Sequential.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 142\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43miterators\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 144\u001b[0m \u001b[38;5;66;03m# try the next iterator\u001b[39;00m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_use_next_iterator()\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py:733\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 730\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 731\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 732\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 733\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 734\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 735\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 736\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable\n\u001b[1;32m 737\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 738\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called\n\u001b[1;32m 739\u001b[0m ):\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1491\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1488\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_data(data, worker_id)\n\u001b[1;32m 1490\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_shutdown \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m-> 1491\u001b[0m idx, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1492\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_tasks_outstanding \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1493\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable:\n\u001b[1;32m 1494\u001b[0m \u001b[38;5;66;03m# Check for _IterableDatasetStopIteration\u001b[39;00m\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1453\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1449\u001b[0m \u001b[38;5;66;03m# In this case, `self._data_queue` is a `queue.Queue`,. But we don't\u001b[39;00m\n\u001b[1;32m 1450\u001b[0m \u001b[38;5;66;03m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[1;32m 1451\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1452\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m-> 1453\u001b[0m success, data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_try_get_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1454\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m success:\n\u001b[1;32m 1455\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/utils/data/dataloader.py:1284\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_try_get_data\u001b[39m(\u001b[38;5;28mself\u001b[39m, timeout\u001b[38;5;241m=\u001b[39m_utils\u001b[38;5;241m.\u001b[39mMP_STATUS_CHECK_INTERVAL):\n\u001b[1;32m 1272\u001b[0m \u001b[38;5;66;03m# Tries to fetch data from `self._data_queue` once for a given timeout.\u001b[39;00m\n\u001b[1;32m 1273\u001b[0m \u001b[38;5;66;03m# This can also be used as inner loop of fetching without timeout, with\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1281\u001b[0m \u001b[38;5;66;03m# Returns a 2-tuple:\u001b[39;00m\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;66;03m# (bool: whether successfully get data, any: data if successful else None)\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1284\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_queue\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1285\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[38;5;28;01mTrue\u001b[39;00m, data)\n\u001b[1;32m 1286\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 1287\u001b[0m \u001b[38;5;66;03m# At timeout and error, we manually check whether any worker has\u001b[39;00m\n\u001b[1;32m 1288\u001b[0m \u001b[38;5;66;03m# failed. Note that this is the only mechanism for Windows to detect\u001b[39;00m\n\u001b[1;32m 1289\u001b[0m \u001b[38;5;66;03m# worker failures.\u001b[39;00m\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/multiprocessing/queues.py:122\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_rlock\u001b[38;5;241m.\u001b[39mrelease()\n\u001b[1;32m 121\u001b[0m \u001b[38;5;66;03m# unserialize the data after having released the lock\u001b[39;00m\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_ForkingPickler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[43mres\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/multiprocessing/reductions.py:560\u001b[0m, in \u001b[0;36mrebuild_storage_filename\u001b[0;34m(cls, manager, handle, size, dtype)\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m storage\u001b[38;5;241m.\u001b[39m_shared_decref()\n\u001b[1;32m 559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 560\u001b[0m storage \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mUntypedStorage\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_new_shared_filename_cpu\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmanager\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhandle\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 561\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 562\u001b[0m byte_size \u001b[38;5;241m=\u001b[39m size \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39m_utils\u001b[38;5;241m.\u001b[39m_element_size(dtype)\n", - "\u001b[0;31mRuntimeError\u001b[0m: Connection refused" - ] - } - ], - "source": [ - "train_flag = True\n", - "if train_flag:\n", - " \n", - " trainer.fit(model, train_loader, val_loader)" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "checkpoint = torch.load(checkpoint_filename+\".ckpt\", map_location=lambda storage, loc: storage)\n", - "model.load_state_dict(checkpoint['state_dict'])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Results" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/liyuxiao/micromamba/envs/qoherent-env/lib/python3.10/site-packages/torch/amp/autocast_mode.py:266: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", - "\n", - "model.eval()\n", - "y_raw_preds = []\n", - "y_preds = []\n", - "y_true = []\n", - "loss = 0\n", - "\n", - "model = model.to(device)\n", - "with torch.inference_mode():\n", - " with torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n", - " for sample in val_loader:\n", - " data, target = sample[0].to(device), sample[1].to(device)\n", - " means = torch.mean(data, axis=-1, keepdims=True)\n", - " stds = torch.std(data, axis=-1, keepdims=True)\n", - " data = (data - means) / stds\n", - " y_true += target.tolist()\n", - " output = model(data)\n", - " y_raw_preds.append(output) # logits\n", - " for pred in output.argmax(dim=1, keepdim=True):\n", - " y_preds.append(int(pred))\n", - " loss += F.cross_entropy(output, target)/output.shape[0]\n", - " y_raw_preds = torch.vstack(y_raw_preds).cpu().to(dtype=torch.float32) # convert list of outputs to tensor" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'le' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[30], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m acc \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39msum(np\u001b[38;5;241m.\u001b[39masarray(y_preds)\u001b[38;5;241m==\u001b[39mnp\u001b[38;5;241m.\u001b[39masarray(y_true))\u001b[38;5;241m/\u001b[39m\u001b[38;5;28mlen\u001b[39m(y_true)\n\u001b[1;32m 2\u001b[0m plot_confusion_matrix(\n\u001b[1;32m 3\u001b[0m y_true, \n\u001b[1;32m 4\u001b[0m y_preds, \n\u001b[0;32m----> 5\u001b[0m classes\u001b[38;5;241m=\u001b[39m\u001b[43mle\u001b[49m\u001b[38;5;241m.\u001b[39mclasses_,\n\u001b[1;32m 6\u001b[0m normalize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 7\u001b[0m title\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdataset_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m Validation Set\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mTotal Accuracy: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00macc\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m100\u001b[39m\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m%\u001b[39m\u001b[38;5;124m'\u001b[39m,\n\u001b[1;32m 8\u001b[0m text\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 9\u001b[0m rotate_x_text\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m90\u001b[39m,\n\u001b[1;32m 10\u001b[0m figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m12\u001b[39m,\u001b[38;5;241m6.75\u001b[39m),\n\u001b[1;32m 11\u001b[0m )\n", - "\u001b[0;31mNameError\u001b[0m: name 'le' is not defined" - ] - } - ], - "source": [ - "acc = np.sum(np.asarray(y_preds)==np.asarray(y_true))/len(y_true)\n", - "plot_confusion_matrix(\n", - " y_true, \n", - " y_preds, \n", - " classes=le.classes_,\n", - " normalize=True,\n", - " title=f'{dataset_name} Validation Set\\nTotal Accuracy: {acc*100:.2f}%',\n", - " text=True,\n", - " rotate_x_text=90,\n", - " figsize=(12,6.75),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(classification_report(y_true, y_preds, target_names=le.classes_))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "y_onehot_test = F.one_hot(torch.tensor(y_true), num_classes)\n", - "\n", - "fig, ax = plt.subplots(figsize=(6, 6))\n", - "colors = ['#e6194B', '#f58231', '#ffe119', '#3cb44b', '#42d4f4', '#4363d8', '#f032e6', \n", - " '#a9a9a9', '#800000', '#9A6324']\n", - "for class_id, color in zip(range(num_classes), colors):\n", - " RocCurveDisplay.from_predictions(\n", - " y_onehot_test[:, class_id],\n", - " y_raw_preds[:, class_id],\n", - " name=f\"{le.classes_[class_id]}\",\n", - " color=color,\n", - " ax=ax,\n", - " plot_chance_level=(class_id == 7),\n", - " )\n", - "\n", - "_ = ax.set(\n", - " xlabel=\"False Positive Rate\",\n", - " ylabel=\"True Positive Rate\",\n", - " title=f'{dataset_name} ROC Curves',\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ds_test = C2_H5_Dataset(test_data, label)\n", - "\n", - "test_loader = torch.utils.data.DataLoader(\n", - " dataset=ds_test,\n", - " batch_size=2048,\n", - " shuffle=False,\n", - " num_workers=8,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(f'Test Dataset, {len(ds_test)} examples of length {ds_test[0][0].shape[1]}')\n", - "print(f'Annotations: {list(ds_test.annotations.columns)}')\n", - "print(f'Labels:')\n", - "for a in ds_test.annotations.columns:\n", - " print(f'\\t{a}: {ds_test.annotations[a].unique()}')\n", - "\n", - "print(f'\\nLabel being used: {label}')\n", - "\n", - "print(f'\\n{ds_test.annotations[label].value_counts()}')\n", - "\n", - "ds_test.encode_labels(le)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", - "\n", - "model.eval()\n", - "y_raw_preds = []\n", - "y_preds = []\n", - "y_true = []\n", - "loss = 0\n", - "\n", - "model = model.to(device)\n", - "with torch.inference_mode():\n", - " with torch.autocast(device_type='cuda', dtype=torch.bfloat16):\n", - " for sample in test_loader:\n", - " data, target = sample[0].to(device), sample[1].to(device)\n", - "\n", - " y_true += target.tolist()\n", - " output = model(data)\n", - " y_raw_preds.append(output) # logits\n", - " for pred in output.argmax(dim=1, keepdim=True):\n", - " y_preds.append(int(pred))\n", - " # print(F.cross_entropy(output, target))\n", - " loss += F.cross_entropy(output, target)/output.shape[0]\n", - " y_raw_preds = torch.vstack(y_raw_preds).cpu().to(dtype=torch.float32) # convert list of outputs to tensor" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "acc = np.sum(np.asarray(y_preds)==np.asarray(y_true))/len(y_true)\n", - "plot_confusion_matrix(\n", - " y_true, \n", - " y_preds, \n", - " classes=le.classes_,\n", - " normalize=True,\n", - " title=f'{dataset_name} Inference Test\\nTotal Accuracy: {acc*100:.2f}%',\n", - " text=True,\n", - " rotate_x_text=90,\n", - " figsize=(12,6.75),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(classification_report(y_true, y_preds, target_names=le.classes_))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "y_onehot_test = F.one_hot(torch.tensor(y_true), num_classes)\n", - "\n", - "fig, ax = plt.subplots(figsize=(6, 6))\n", - "colors = ['#332288', '#88CCEE', '#44AA99', '#117733', '#999933', '#DDCC77', '#CC6677', '#882255']\n", - "for class_id, color in zip(range(num_classes), colors):\n", - " RocCurveDisplay.from_predictions(\n", - " y_onehot_test[:, class_id],\n", - " y_raw_preds[:, class_id],\n", - " name=f\"{le.classes_[class_id]}\",\n", - " color=color,\n", - " ax=ax,\n", - " plot_chance_level=(class_id == 7),\n", - " )\n", - "\n", - "_ = ax.set(\n", - " xlabel=\"False Positive Rate\",\n", - " ylabel=\"True Positive Rate\",\n", - " title=f'{dataset_name} ROC Curves',\n", - ")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "qoherent-env", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.16" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/data/models/cm_plotter.py b/data/training/cm_plotter.py similarity index 73% rename from data/models/cm_plotter.py rename to data/training/cm_plotter.py index b2b4437..429293c 100644 --- a/data/models/cm_plotter.py +++ b/data/training/cm_plotter.py @@ -3,6 +3,7 @@ from typing import Optional from matplotlib import pyplot as plt from sklearn.metrics import confusion_matrix + def plot_confusion_matrix( y_true: np.array, y_pred: np.array, @@ -11,11 +12,11 @@ def plot_confusion_matrix( title: Optional[str] = None, text: bool = True, rotate_x_text: int = 90, - figsize: tuple = (16,9), + figsize: tuple = (16, 9), cmap: plt.cm = plt.cm.Blues, ): """Function to help plot confusion matrices - + https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html """ if not title: @@ -32,13 +33,13 @@ def plot_confusion_matrix( fig, ax = plt.subplots() im = ax.imshow(cm, interpolation="none", cmap=cmap) ax.figure.colorbar(im, ax=ax) - ax.set( - xticks=np.arange(cm.shape[1]), - yticks=np.arange(cm.shape[0]), - xticklabels=classes, - yticklabels=classes, - title=title, - ylabel="True label", + ax.set( + xticks=np.arange(cm.shape[1]), + yticks=np.arange(cm.shape[0]), + xticklabels=classes, + yticklabels=classes, + title=title, + ylabel="True label", xlabel="Predicted label", ) ax.set_xticklabels(classes, rotation=rotate_x_text) @@ -50,9 +51,16 @@ def plot_confusion_matrix( for i in range(cm.shape[0]): for j in range(cm.shape[1]): if text: - ax.text(j, i, format(cm[i,j], fmt), ha="center", va="center", color="white" if cm[i,j] > thresh else "black") + ax.text( + j, + i, + format(cm[i, j], fmt), + ha="center", + va="center", + color="white" if cm[i, j] > thresh else "black", + ) if len(classes) == 2: plt.axis([-0.5, 1.5, 1.5, -0.5]) fig.tight_layout() - return ax \ No newline at end of file + return ax diff --git a/data/models/mobilenetv3.py b/data/training/mobilenetv3.py similarity index 74% rename from data/models/mobilenetv3.py rename to data/training/mobilenetv3.py index 549bc51..754296a 100644 --- a/data/models/mobilenetv3.py +++ b/data/training/mobilenetv3.py @@ -2,21 +2,23 @@ import numpy as np import torch import timm from torch import nn +import lightning as L sizes = [ - 'mobilenetv3_large_075', - 'mobilenetv3_large_100', - 'mobilenetv3_rw', - 'mobilenetv3_small_050', - 'mobilenetv3_small_075', - 'mobilenetv3_small_100', - 'tf_mobilenetv3_large_075', - 'tf_mobilenetv3_large_100', - 'tf_mobilenetv3_large_minimal_100', - 'tf_mobilenetv3_small_075', - 'tf_mobilenetv3_small_100', - 'tf_mobilenetv3_small_minimal_100' - ] + "mobilenetv3_large_075", + "mobilenetv3_large_100", + "mobilenetv3_rw", + "mobilenetv3_small_050", + "mobilenetv3_small_075", + "mobilenetv3_small_100", + "tf_mobilenetv3_large_075", + "tf_mobilenetv3_large_100", + "tf_mobilenetv3_large_minimal_100", + "tf_mobilenetv3_small_075", + "tf_mobilenetv3_small_100", + "tf_mobilenetv3_small_minimal_100", +] + class SqueezeExcite(nn.Module): def __init__( @@ -54,10 +56,11 @@ class FastGlobalAvgPool1d(nn.Module): in_size = x.size() return x.view((in_size[0], in_size[1], -1)).mean(dim=2) else: - return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1) + return ( + x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1) + ) - class GBN(torch.nn.Module): """ Ghost Batch Normalization @@ -87,7 +90,7 @@ class GBN(torch.nn.Module): def replace_bn(parent): for n, m in parent.named_children(): if type(m) is timm.layers.norm_act.BatchNormAct2d: - # if type(m) is nn.BatchNorm2d: + # if type(m) is nn.BatchNorm2d: # print(type(m)) setattr( parent, @@ -97,6 +100,7 @@ def replace_bn(parent): else: replace_bn(m) + def replace_se(parent): for n, m in parent.named_children(): if type(m) is timm.models._efficientnet_blocks.SqueezeExcite: @@ -111,6 +115,7 @@ def replace_se(parent): else: replace_se(m) + def replace_conv(parent, ds_rate): for n, m in parent.named_children(): if type(m) is nn.Conv2d: @@ -145,6 +150,7 @@ def replace_conv(parent, ds_rate): else: replace_conv(m, ds_rate) + def create_mobilenetv3(network, ds_rate=2, in_chans=2): replace_se(network) replace_bn(network) @@ -152,19 +158,20 @@ def create_mobilenetv3(network, ds_rate=2, in_chans=2): network.global_pool = FastGlobalAvgPool1d() network.conv_stem = nn.Conv1d( - in_channels=in_chans, - out_channels=network.conv_stem.out_channels, - kernel_size=network.conv_stem.kernel_size, - stride=network.conv_stem.stride, - padding=network.conv_stem.padding, - bias=network.conv_stem.kernel_size, - groups=network.conv_stem.groups, - ) + in_channels=in_chans, + out_channels=network.conv_stem.out_channels, + kernel_size=network.conv_stem.kernel_size, + stride=network.conv_stem.stride, + padding=network.conv_stem.padding, + bias=network.conv_stem.kernel_size, + groups=network.conv_stem.groups, + ) return network + def mobilenetv3( - model_size = 'mobilenetv3_small_050', + model_size="mobilenetv3_small_050", num_classes: int = 10, drop_rate: float = 0, drop_path_rate: float = 0, @@ -183,24 +190,11 @@ def mobilenetv3( ) return mdl -import torch.nn as nn -class Simple1DCNN(nn.Module): - def __init__(self, in_chans=2, num_classes=4): +class RFClassifier(L.LightningModule): + def __init__(self, model): super().__init__() - self.net = nn.Sequential( - nn.Conv1d(in_chans, 32, kernel_size=3, padding=1), - nn.ReLU(), - nn.MaxPool1d(2), - nn.Conv1d(32, 64, kernel_size=3, padding=1), - nn.ReLU(), - nn.AdaptiveAvgPool1d(1), - nn.Flatten(), - nn.Linear(64, num_classes) - ) + self.model = model def forward(self, x): - return self.net(x) # x shape: [B, 2, 128] - -def simple_cnn(in_chans=2, num_classes=4): - return Simple1DCNN(in_chans, num_classes) + return self.model(x) \ No newline at end of file diff --git a/data/models/modulation_dataset.py b/data/training/modulation_dataset.py similarity index 81% rename from data/models/modulation_dataset.py rename to data/training/modulation_dataset.py index 9c60764..5f65dc5 100644 --- a/data/models/modulation_dataset.py +++ b/data/training/modulation_dataset.py @@ -1,4 +1,5 @@ import sys, os + sys.path.insert(0, os.path.abspath("../..")) # or ".." if needed import numpy as np import torch @@ -11,39 +12,43 @@ dataset = settings.dataset.modulation_types class ModulationH5Dataset(Dataset): - def __init__(self, hdf5_path, label_name, data_key="training_data", label_encoder=None, transform=None): - self.hdf5_path = hdf5_path - self.data_key = data_key + def __init__( + self, + hdf5_path, + label_name, + data_key="training_data", + label_encoder=None, + transform=None, + ): + self.hdf5_path = hdf5_path + self.data_key = data_key self.label_name = label_name self.label_encoder = label_encoder self.transform = transform - - with h5py.File(hdf5_path, 'r') as f: + with h5py.File(hdf5_path, "r") as f: self.length = f[data_key].shape[0] self.metadata = f["metadata"]["metadata"][:] - - + settings = get_app_settings() dataset_cfg = settings.dataset all_labels = dataset_cfg.modulation_types - if self.label_encoder is None: from sklearn.preprocessing import LabelEncoder + self.label_encoder = LabelEncoder() self.label_encoder.fit(all_labels) - + # Get per-sample labels from metadata raw_labels = [row["modulation"].decode("utf-8") for row in self.metadata] self.encoded_labels = self.label_encoder.transform(raw_labels) - def __len__(self): return self.length - + def __getitem__(self, idx): - with h5py.File(self.hdf5_path, 'r') as f: + with h5py.File(self.hdf5_path, "r") as f: x = f[self.data_key][idx] # shape (1, 128) or similar # Normalize @@ -54,4 +59,3 @@ class ModulationH5Dataset(Dataset): label = torch.tensor(self.encoded_labels[idx], dtype=torch.long) return x, label - diff --git a/data/training/train.py b/data/training/train.py new file mode 100644 index 0000000..3b63a81 --- /dev/null +++ b/data/training/train.py @@ -0,0 +1,153 @@ +import sys, os + +script_dir = os.path.dirname(os.path.abspath(__file__)) +data_dir = os.path.abspath(os.path.join(script_dir, "..")) +project_root = os.path.abspath(os.path.join(script_dir, "../..")) + +if project_root not in sys.path: + sys.path.insert(0, project_root) + +from helpers.app_settings import get_app_settings + +project_root = os.path.abspath(os.path.join(os.getcwd(), "..")) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +import lightning as L +import torch +import torch.nn.functional as F +import torchmetrics + +from helpers.app_settings import get_app_settings +from modulation_dataset import ModulationH5Dataset + +import mobilenetv3 + + + +def train_model(): + settings = get_app_settings() + dataset = settings.dataset.modulation_types + + train_flag = True + batch_size = 128 + epochs = 1 + + checkpoint_filename = f'/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model' + + train_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5' + val_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5' + + dataset_name = 'Modulation Inference - Initial Model' + metadata_names = 'Modulation' + label = 'modulation' + + torch.set_float32_matmul_precision('high') + + ds_train = ModulationH5Dataset(train_data, label, data_key="training_data") + ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data") + + train_loader = torch.utils.data.DataLoader( + dataset=ds_train, + batch_size=batch_size, + shuffle=True, + num_workers=8, + ) + val_loader = torch.utils.data.DataLoader( + dataset=ds_val, + batch_size=2048, + shuffle=False, + num_workers=8, + ) + + for x, y in train_loader: + print("X shape:", x.shape) + print("Y values:", y[:10]) + break + + unique_labels = list(set([row[label].decode("utf-8") for row in ds_train.metadata])) + num_classes = len(ds_train.label_encoder.classes_) + + hparams = { + 'drop_path_rate': 0.2, + 'drop_rate': 0.5, + 'learning_rate': 3e-4, + 'wd': 0.2 + } + + class RFClassifier(L.LightningModule): + def __init__(self, model): + super().__init__() + self.model = model + self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes) + + def forward(self, x): + return self.model(x) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=hparams['learning_rate'], + weight_decay=hparams['wd'], + ) + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, + T_0=len(train_loader), + ) + return { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': lr_scheduler, + 'interval': 'step' + } + } + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.log('train_loss', loss, on_epoch=True, prog_bar=True) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + self.accuracy(y_hat, y) + self.log('val_loss', loss, prog_bar=True) + self.log('val_acc', self.accuracy, prog_bar=True) + + model = RFClassifier( + mobilenetv3.mobilenetv3( + model_size='mobilenetv3_small_050', + num_classes=num_classes, + drop_rate=hparams['drop_rate'], + drop_path_rate=hparams['drop_path_rate'] + ) + ) + + checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint( + filename=checkpoint_filename, + save_top_k=True, + verbose=True, + monitor='val_acc', + mode='max', + enable_version_counter=False, + ) + + trainer = L.Trainer( + max_epochs=epochs, + callbacks=[checkpoint_callback], + accelerator='gpu', + devices=1, + benchmark=True, + precision='bf16-mixed', + logger=False + ) + + if train_flag: + trainer.fit(model, train_loader, val_loader) + + +if __name__ == '__main__': + train_model() \ No newline at end of file