liyu-dev #3
|
@ -5,9 +5,6 @@ dataset:
|
|||
# Number of samples per recording
|
||||
recording_length: 1024
|
||||
|
||||
# Set this to scale the number of generated recordings
|
||||
mult_factor: 5
|
||||
|
||||
# List of signal modulation schemes to include in the dataset
|
||||
modulation_types:
|
||||
- bpsk
|
||||
|
@ -27,7 +24,7 @@ dataset:
|
|||
snr_step: 3
|
||||
|
||||
# Number of iterations (signal recordings) per modulation and SNR combination
|
||||
num_iterations: 3
|
||||
num_iterations: 100
|
||||
|
||||
# Modulation scheme settings; keys must match the `modulation_types` list above
|
||||
# Each entry includes:
|
||||
|
|
|
@ -9,7 +9,6 @@ import yaml
|
|||
@dataclass
|
||||
class DataSetConfig:
|
||||
num_slices: int
|
||||
mult_factor: int
|
||||
train_split: float
|
||||
seed: int
|
||||
modulation_types: list
|
||||
|
@ -42,7 +41,11 @@ class AppConfig:
|
|||
|
||||
|
||||
class AppSettings:
|
||||
"""Application settings, to be initialized from app.yaml configuration file."""
|
||||
"""
|
||||
Application settings,
|
||||
to be initialized from
|
||||
app.yaml configuration file.
|
||||
"""
|
||||
|
||||
def __init__(self, config_file: str):
|
||||
# Load the YAML configuration file
|
||||
|
|
|
@ -29,7 +29,7 @@ def generate_modulated_signals(output_dir: str) -> None:
|
|||
|
||||
for modulation in settings.modulation_types:
|
||||
for snr in np.arange(settings.snr_start, settings.snr_stop, settings.snr_step):
|
||||
for _ in range(settings.mult_factor):
|
||||
for _ in range(settings.num_iterations):
|
||||
recording_length = settings.recording_length
|
||||
beta = (
|
||||
settings.beta
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import lightning as L
|
||||
import numpy as np
|
||||
import timm
|
||||
import torch
|
||||
from torch import nn
|
||||
|
|
|
@ -2,15 +2,15 @@ import os
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sklearn.metrics import classification_report
|
||||
|
||||
os.environ["NNPACK"] = "0"
|
||||
from matplotlib import pyplot as plt
|
||||
from mobilenetv3 import RFClassifier, mobilenetv3
|
||||
from modulation_dataset import ModulationH5Dataset
|
||||
from sklearn.metrics import classification_report
|
||||
|
||||
from helpers.app_settings import get_app_settings
|
||||
|
||||
os.environ["NNPACK"] = "0"
|
||||
|
||||
|
||||
def load_validation_data():
|
||||
val_dataset = ModulationH5Dataset(
|
||||
|
|
|
@ -1,23 +1,22 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
os.environ["NNPACK"] = "0"
|
||||
import lightning as L
|
||||
import mobilenetv3
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchmetrics
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
|
||||
from modulation_dataset import ModulationH5Dataset
|
||||
|
||||
from helpers.app_settings import get_app_settings
|
||||
|
||||
os.environ["NNPACK"] = "0"
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_dir = os.path.abspath(os.path.join(script_dir, ".."))
|
||||
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
from lightning.pytorch.callbacks import TQDMProgressBar
|
||||
|
||||
|
||||
class CustomProgressBar(TQDMProgressBar):
|
||||
|
@ -59,8 +58,6 @@ def train_model():
|
|||
print("X shape:", x.shape)
|
||||
print("Y values:", y[:10])
|
||||
break
|
||||
|
||||
unique_labels = list(set([row[label].decode("utf-8") for row in ds_train.metadata]))
|
||||
num_classes = len(ds_train.label_encoder.classes_)
|
||||
|
||||
hparams = {
|
||||
|
|
Loading…
Reference in New Issue
Block a user