clean up code, adjusted default parameters
This commit is contained in:
parent
a298384f7e
commit
04e5f4db8a
|
@ -5,9 +5,6 @@ dataset:
|
||||||
# Number of samples per recording
|
# Number of samples per recording
|
||||||
recording_length: 1024
|
recording_length: 1024
|
||||||
|
|
||||||
# Set this to scale the number of generated recordings
|
|
||||||
mult_factor: 5
|
|
||||||
|
|
||||||
# List of signal modulation schemes to include in the dataset
|
# List of signal modulation schemes to include in the dataset
|
||||||
modulation_types:
|
modulation_types:
|
||||||
- bpsk
|
- bpsk
|
||||||
|
@ -27,7 +24,7 @@ dataset:
|
||||||
snr_step: 3
|
snr_step: 3
|
||||||
|
|
||||||
# Number of iterations (signal recordings) per modulation and SNR combination
|
# Number of iterations (signal recordings) per modulation and SNR combination
|
||||||
num_iterations: 3
|
num_iterations: 100
|
||||||
|
|
||||||
# Modulation scheme settings; keys must match the `modulation_types` list above
|
# Modulation scheme settings; keys must match the `modulation_types` list above
|
||||||
# Each entry includes:
|
# Each entry includes:
|
||||||
|
|
|
@ -9,7 +9,6 @@ import yaml
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataSetConfig:
|
class DataSetConfig:
|
||||||
num_slices: int
|
num_slices: int
|
||||||
mult_factor: int
|
|
||||||
train_split: float
|
train_split: float
|
||||||
seed: int
|
seed: int
|
||||||
modulation_types: list
|
modulation_types: list
|
||||||
|
@ -42,7 +41,11 @@ class AppConfig:
|
||||||
|
|
||||||
|
|
||||||
class AppSettings:
|
class AppSettings:
|
||||||
"""Application settings, to be initialized from app.yaml configuration file."""
|
"""
|
||||||
|
Application settings,
|
||||||
|
to be initialized from
|
||||||
|
app.yaml configuration file.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, config_file: str):
|
def __init__(self, config_file: str):
|
||||||
# Load the YAML configuration file
|
# Load the YAML configuration file
|
||||||
|
|
|
@ -29,7 +29,7 @@ def generate_modulated_signals(output_dir: str) -> None:
|
||||||
|
|
||||||
for modulation in settings.modulation_types:
|
for modulation in settings.modulation_types:
|
||||||
for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step):
|
for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step):
|
||||||
for _ in range(settings.mult_factor):
|
for _ in range(settings.num_iterations):
|
||||||
recording_length = settings.recording_length
|
recording_length = settings.recording_length
|
||||||
beta = (
|
beta = (
|
||||||
settings.beta
|
settings.beta
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import lightning as L
|
import lightning as L
|
||||||
import numpy as np
|
|
||||||
import timm
|
import timm
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
|
@ -2,15 +2,15 @@ import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from sklearn.metrics import classification_report
|
|
||||||
|
|
||||||
os.environ["NNPACK"] = "0"
|
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from mobilenetv3 import RFClassifier, mobilenetv3
|
from mobilenetv3 import RFClassifier, mobilenetv3
|
||||||
from modulation_dataset import ModulationH5Dataset
|
from modulation_dataset import ModulationH5Dataset
|
||||||
|
from sklearn.metrics import classification_report
|
||||||
|
|
||||||
from helpers.app_settings import get_app_settings
|
from helpers.app_settings import get_app_settings
|
||||||
|
|
||||||
|
os.environ["NNPACK"] = "0"
|
||||||
|
|
||||||
|
|
||||||
def load_validation_data():
|
def load_validation_data():
|
||||||
val_dataset = ModulationH5Dataset(
|
val_dataset = ModulationH5Dataset(
|
||||||
|
|
|
@ -1,23 +1,22 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
os.environ["NNPACK"] = "0"
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
import mobilenetv3
|
import mobilenetv3
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
|
||||||
from modulation_dataset import ModulationH5Dataset
|
from modulation_dataset import ModulationH5Dataset
|
||||||
|
|
||||||
from helpers.app_settings import get_app_settings
|
from helpers.app_settings import get_app_settings
|
||||||
|
|
||||||
|
os.environ["NNPACK"] = "0"
|
||||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
|
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
|
||||||
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
|
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
|
||||||
if project_root not in sys.path:
|
if project_root not in sys.path:
|
||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
|
||||||
|
|
||||||
|
|
||||||
class CustomProgressBar(TQDMProgressBar):
|
class CustomProgressBar(TQDMProgressBar):
|
||||||
|
@ -59,8 +58,6 @@ def train_model():
|
||||||
print("X shape:", x.shape)
|
print("X shape:", x.shape)
|
||||||
print("Y values:", y[:10])
|
print("Y values:", y[:10])
|
||||||
break
|
break
|
||||||
|
|
||||||
unique_labels = list(set([row[label].decode("utf-8") for row in ds_train.metadata]))
|
|
||||||
num_classes = len(ds_train.label_encoder.classes_)
|
num_classes = len(ds_train.label_encoder.classes_)
|
||||||
|
|
||||||
hparams = {
|
hparams = {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user