forked from qoherent/modrec-workflow
formatted all files
parent 3557f854e8
commit ff3d45653d
@@ -3,4 +3,4 @@ import os
 CHECKPOINTS_DIR = os.path.dirname(os.path.abspath(__file__))

 if __name__ == "__main__":
     print(CHECKPOINTS_DIR)
@@ -8,7 +8,6 @@ from onnx_files import ONNX_DIR
 from helpers.app_settings import get_app_settings


-
 def convert_to_onnx(ckpt_path, fp16=False):
     """
     Convert a PyTorch model to ONNX format.
@@ -19,16 +18,15 @@ def convert_to_onnx(ckpt_path, fp16=False):
         output_path (str): The path to save the converted ONNX model.
     """
     settings = get_app_settings()

     inference_cfg = settings.inference
     dataset_cfg = settings.dataset

-
     in_channels = 2
     batch_size = 1
-    slice_length = int(1024/dataset_cfg.num_slices)
+    slice_length = int(1024 / dataset_cfg.num_slices)
     num_classes = inference_cfg.num_classes

     model = RFClassifier(
         model=mobilenetv3(
             model_size="mobilenetv3_small_050",
@@ -36,18 +34,17 @@ def convert_to_onnx(ckpt_path, fp16=False):
             in_chans=in_channels,
         )
     )

     checkpoint = torch.load(
-        ckpt_path, weights_only = True, map_location=torch.device("cpu")
+        ckpt_path, weights_only=True, map_location=torch.device("cpu")
     )
     model.load_state_dict(checkpoint["state_dict"])

     if fp16:
         model.half()

     model.eval()

-
     # Generate random sample data
     base, ext = os.path.splitext(os.path.basename(ckpt_path))
     if fp16:
@@ -58,8 +55,7 @@ def convert_to_onnx(ckpt_path, fp16=False):
     else:
         output_path = os.path.join(ONNX_DIR, f"{base}_fp32.onnx")
     sample_input = torch.rand(batch_size, in_channels, slice_length)

-
     torch.onnx.export(
         model=model,
         args=sample_input,
@@ -71,14 +67,17 @@ def convert_to_onnx(ckpt_path, fp16=False):
         output_names=["output"],
         dynamo=False,  # Requires onnxscript
     )


 if __name__ == "__main__":
     from checkpoint_files import CHECKPOINTS_DIR

     model_checkpoint = "interference_recognition_model.ckpt"

     print("Converting to ONNX...")

-    convert_to_onnx(ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False)
-    print("Conversion complete stored at: ", os.path.join(ONNX_DIR, model_checkpoint))
+    convert_to_onnx(
+        ckpt_path=os.path.join(CHECKPOINTS_DIR, model_checkpoint), fp16=False
+    )
+
+    print("Conversion complete stored at: ", os.path.join(ONNX_DIR, model_checkpoint))
@@ -23,7 +23,6 @@ info_dtype = np.dtype(
 )


-
 def write_hdf5_file(records, output_path, dataset_name="data"):
     """
     Writes a list of records to an HDF5 file.
@@ -77,11 +76,12 @@ def write_hdf5_file(records, output_path, dataset_name="data"):

     return output_path

+
 def complex_to_channel(data):
     """
     Convert complex-valued IQ data of shape (1, N) to 2-channel real array of shape (2, N).
     """
-    assert np.iscomplexobj(data) #check if the data is in the form a+bi
+    assert np.iscomplexobj(data)  # check if the data is in the form a+bi
     real = np.real(data[0])  # (N,)
     imag = np.imag(data[0])  # (N,)
     stacked = np.stack([real, imag], axis=0)  # shape (2, N)
@@ -114,21 +114,17 @@ def generate_datasets(cfg):
     for fname in files:
         rec = from_npy(os.path.join(cfg.input_dir, fname))

-        data = rec.data #here data is a numpy array with the shape (1, N)
+        data = rec.data  # here data is a numpy array with the shape (1, N)

         data = complex_to_channel(data)  # convert to 2-channel real array

-
         md = rec.metadata  # pull metadata from the recording
         md.setdefault("recid", len(records))
         records.append((data, md))

     # split each recording into <num_slices> snippets each
-

     records = split_recording(records, cfg.num_slices)

-
-
     train_records, val_records = split(records, cfg.train_split, cfg.seed)

@@ -147,6 +143,6 @@ def main():
     train_path, val_path = generate_datasets(dataset_cfg)
     print(f"✅ Train: {train_path}\n✅ Val: {val_path}")


 if __name__ == "__main__":
     main()
@@ -2,7 +2,7 @@ import random
 from collections import defaultdict


-def split(dataset, train_frac=0.8, seed=42, label_key = "modulation"):
+def split(dataset, train_frac=0.8, seed=42, label_key="modulation"):
     """
     Splits a dataset into smaller datasets based on the specified lengths.

@@ -16,31 +16,29 @@ def split(dataset, train_frac=0.8, seed=42, label_key = "modulation"):
     rec_buckets = defaultdict(list)
     for data, md in dataset:
         rec_buckets[md["recid"]].append((data, md))

-    rec_labels = {} #store labels for each recording
+    rec_labels = {}  # store labels for each recording
     for rec_id, group in rec_buckets.items():
         label = group[0][1][label_key]
-        if isinstance(label, bytes): #if the label is a byte string
+        if isinstance(label, bytes):  # if the label is a byte string
             label = label.decode("utf-8")
         rec_labels[rec_id] = label

-    label_rec_ids = defaultdict(list) #group rec_ids by label
+    label_rec_ids = defaultdict(list)  # group rec_ids by label
     for rec_id, label in rec_labels.items():
         label_rec_ids[label].append(rec_id)

     random.seed(seed)
     train_recs, val_recs = set(), set()

     for label, rec_ids in label_rec_ids.items():
         random.shuffle(rec_ids)
         split_idx = int(len(rec_ids) * train_frac)
-        train_recs.update(rec_ids[:split_idx]) #pulls train_frac or rec_ids per label, guarantees all modulations are represented
+        train_recs.update(
+            rec_ids[:split_idx]
+        )  # pulls train_frac or rec_ids per label, guarantees all modulations are represented
         val_recs.update(rec_ids[split_idx:])

-
-
-
-
     # add the assigned recordings to the train and val datasets
     train_dataset, val_dataset = [], []
     for rec_id, group in rec_buckets.items():
@@ -48,10 +46,10 @@ def split(dataset, train_frac=0.8, seed=42, label_key = "modulation"):
             train_dataset.extend(group)
         elif rec_id in val_recs:
             val_dataset.extend(group)

-
     return train_dataset, val_dataset

+
 def split_recording(recording_list, num_snippets):
     """
     Splits a list of recordings into smaller chunks.
@@ -71,7 +69,7 @@ def split_recording(recording_list, num_snippets):
             start = i * L
             end = (i + 1) * L
             snippet = data[:, start:end]

             # copy the metadata, adding a snippet index
             snippet_md = md.copy()
             snippet_md["snippet_idx"] = i
@@ -197,4 +197,4 @@ class RFClassifier(L.LightningModule):
         self.model = model

     def forward(self, x):
         return self.model(x)
@@ -12,7 +12,7 @@ from helpers.app_settings import get_app_settings
 project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
 if project_root not in sys.path:
     sys.path.insert(0, project_root)

 import lightning as L
 import torch
 import torch.nn.functional as F
@@ -24,7 +24,6 @@ from modulation_dataset import ModulationH5Dataset
 import mobilenetv3


-
 def train_model():
     settings = get_app_settings()
     dataset = settings.dataset.modulation_types
@@ -33,16 +32,20 @@ def train_model():
     batch_size = 128
     epochs = 1

-    checkpoint_filename = f'/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model'
+    checkpoint_filename = f"/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/results/interference_recognition_model"

-    train_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5'
-    val_data = '/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5'
+    train_data = (
+        "/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/train.h5"
+    )
+    val_data = (
+        "/Users/liyuxiao/Documents/CS/qoherent/modrec-workflow/data/dataset/val.h5"
+    )

-    dataset_name = 'Modulation Inference - Initial Model'
-    metadata_names = 'Modulation'
-    label = 'modulation'
+    dataset_name = "Modulation Inference - Initial Model"
+    metadata_names = "Modulation"
+    label = "modulation"

-    torch.set_float32_matmul_precision('high')
+    torch.set_float32_matmul_precision("high")

     ds_train = ModulationH5Dataset(train_data, label, data_key="training_data")
     ds_val = ModulationH5Dataset(val_data, label, data_key="validation_data")
@@ -69,17 +72,19 @@ def train_model():
     num_classes = len(ds_train.label_encoder.classes_)

     hparams = {
-        'drop_path_rate': 0.2,
-        'drop_rate': 0.5,
-        'learning_rate': 3e-4,
-        'wd': 0.2
+        "drop_path_rate": 0.2,
+        "drop_rate": 0.5,
+        "learning_rate": 3e-4,
+        "wd": 0.2,
     }

     class RFClassifier(L.LightningModule):
         def __init__(self, model):
             super().__init__()
             self.model = model
-            self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
+            self.accuracy = torchmetrics.Accuracy(
+                task="multiclass", num_classes=num_classes
+            )

         def forward(self, x):
             return self.model(x)
@@ -87,26 +92,23 @@ def train_model():
         def configure_optimizers(self):
             optimizer = torch.optim.AdamW(
                 self.parameters(),
-                lr=hparams['learning_rate'],
-                weight_decay=hparams['wd'],
+                lr=hparams["learning_rate"],
+                weight_decay=hparams["wd"],
             )
             lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                 optimizer,
                 T_0=len(train_loader),
             )
             return {
-                'optimizer': optimizer,
-                'lr_scheduler': {
-                    'scheduler': lr_scheduler,
-                    'interval': 'step'
-                }
+                "optimizer": optimizer,
+                "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
             }

         def training_step(self, batch, batch_idx):
             x, y = batch
             y_hat = self(x)
             loss = F.cross_entropy(y_hat, y)
-            self.log('train_loss', loss, on_epoch=True, prog_bar=True)
+            self.log("train_loss", loss, on_epoch=True, prog_bar=True)
             return loss

         def validation_step(self, batch, batch_idx):
@@ -114,15 +116,15 @@ def train_model():
             y_hat = self(x)
             loss = F.cross_entropy(y_hat, y)
             self.accuracy(y_hat, y)
-            self.log('val_loss', loss, prog_bar=True)
-            self.log('val_acc', self.accuracy, prog_bar=True)
+            self.log("val_loss", loss, prog_bar=True)
+            self.log("val_acc", self.accuracy, prog_bar=True)

     model = RFClassifier(
         mobilenetv3.mobilenetv3(
-            model_size='mobilenetv3_small_050',
+            model_size="mobilenetv3_small_050",
             num_classes=num_classes,
-            drop_rate=hparams['drop_rate'],
-            drop_path_rate=hparams['drop_path_rate']
+            drop_rate=hparams["drop_rate"],
+            drop_path_rate=hparams["drop_path_rate"],
         )
     )

@@ -130,24 +132,24 @@ def train_model():
         filename=checkpoint_filename,
         save_top_k=True,
         verbose=True,
-        monitor='val_acc',
-        mode='max',
+        monitor="val_acc",
+        mode="max",
         enable_version_counter=False,
     )

     trainer = L.Trainer(
         max_epochs=epochs,
         callbacks=[checkpoint_callback],
-        accelerator='gpu',
+        accelerator="gpu",
         devices=1,
         benchmark=True,
-        precision='bf16-mixed',
-        logger=False
+        precision="bf16-mixed",
+        logger=False,
     )

     if train_flag:
         trainer.fit(model, train_loader, val_loader)


-if __name__ == '__main__':
+if __name__ == "__main__":
     train_model()
@@ -4,10 +4,12 @@ from functools import lru_cache

 import yaml

+
 @dataclass
 class GeneralConfig:
     run_mode: str

+
 @dataclass
 class DataSetConfig:
     input_dir: str
@@ -17,7 +19,8 @@ class DataSetConfig:
     modulation_types: list
     val_split: float
     output_dir: str

+
 @dataclass
 class TrainingConfig:
     batch_size: int
@@ -26,16 +29,19 @@ class TrainingConfig:
     checkpoint_path: str
     use_gpu: bool

+
 @dataclass
 class InferenceConfig:
     model_path: str
     num_classes: int
     output_path: str

+
 @dataclass
 class AppConfig:
     build_dir: str

+
 class AppSettings:
     """Application settings, to be initialized from app.yaml configuration file."""

@@ -51,6 +57,7 @@ class AppSettings:
         self.inference = InferenceConfig(**config_data["inference"])
         self.app = AppConfig(**config_data["app"])

+
 @lru_cache
 def get_app_settings() -> AppSettings:
     """Return application configuration settings."""
@@ -60,4 +67,4 @@ def get_app_settings() -> AppSettings:


 if __name__ == "__main__":
     s = get_app_settings()