Compare commits

No commits in common. "53d0552fd468596fb1bae754c41ed3598480e931" and "77179d38f352ad7ed881e00c964050462eec2f0e" have entirely different histories.

53d0552fd4 ... 77179d38f3
.riahub/workflows/workflow.yaml

@@ -1,4 +1,4 @@
-name: Modulation Recognition Demo
+name: RIA Hub Workflow Demo

 on:
   push:
@@ -11,6 +11,9 @@ on:
 jobs:
   ria-demo:
     runs-on: ubuntu-latest-2080
+    env:
+      RIAGIT_USERNAME: ${{ secrets.USERNAME }}
+      RIAGIT_TOKEN: ${{ secrets.TOKEN }}
     steps:
       - name: Print GPU information
         run: |
@@ -21,7 +24,7 @@ jobs:
             echo "⚠️ No NVIDIA GPU found"
           fi

-      - name: Checkout project code
+      - name: Checkout code
         uses: actions/checkout@v4
         with:
           lfs: true
@@ -39,10 +42,13 @@ jobs:
             utils \
             -r requirements.txt

       - name: 1. Generate Recordings
         run: |
           mkdir -p data/recordings
-          PYTHONPATH=. python scripts/dataset_manager/data_gen.py --output-dir data/recordings
+          PYTHONPATH=. python scripts/dataset_building/data_gen.py --output-dir data/recordings
           echo "recordings produced successfully"

       - name: ⬆️ Upload recordings
         uses: actions/upload-artifact@v3
@@ -53,10 +59,11 @@ jobs:
       - name: 2. Build HDF5 Dataset
         run: |
           mkdir -p data/dataset
-          PYTHONPATH=. python scripts/dataset_manager/produce_dataset.py
+          PYTHONPATH=. python scripts/dataset_building/produce_dataset.py
           echo "datasets produced successfully"
         shell: bash

-      - name: ⬆️ Upload Dataset
+      - name: 📤 Upload Dataset
         uses: actions/upload-artifact@v3
         with:
           name: dataset
@@ -68,30 +75,34 @@ jobs:
           PYTORCH_NO_NNPACK: 1
         run: |
           mkdir -p checkpoint_files
-          PYTHONPATH=. python scripts/model_builder/train.py 2>/dev/null
+          PYTHONPATH=. python scripts/training/train.py 2>/dev/null
           echo "training model"

       - name: 4. Plot Model
         env:
           NO_NNPACK: 1
           PYTORCH_NO_NNPACK: 1
         run: |
-          PYTHONPATH=. python scripts/model_builder/plot_data.py 2>/dev/null
+          PYTHONPATH=. python scripts/training/plot_data.py 2>/dev/null

-      - name: ⬆️ Upload Checkpoints
+      - name: Upload Checkpoints
         uses: actions/upload-artifact@v3
         with:
           name: checkpoints
           path: checkpoint_files/*

-      - name: 5. Export model to ONNX graph
+      - name: 5. Convert to ONNX file
         env:
           NO_NNPACK: 1
           PYTORCH_NO_NNPACK: 1
         run: |
           mkdir -p onnx_files
-          MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python scripts/application_packager/convert_to_onnx.py 2>/dev/null
+          MKL_DISABLE_FAST_MM=1 PYTHONPATH=. python scripts/onnx/convert_to_onnx.py 2>/dev/null
           echo "building inference app"

-      - name: ⬆️ Upload ONNX file
+      - name: Upload ONNX file
         uses: actions/upload-artifact@v3
         with:
           name: onnx-file
@@ -99,20 +110,21 @@ jobs:

       - name: 6. Profile ONNX model
         run: |
-          PYTHONPATH=. python scripts/application_packager/profile_onnx.py
+          PYTHONPATH=. python scripts/onnx/profile_onnx.py

-      - name: ⬆️ Upload JSON trace
+      - name: Upload JSON profiling data
         uses: actions/upload-artifact@v3
         with:
           name: profile-data
           path: '**/onnxruntime_profile_*.json'

-      - name: 7. Convert ONNX graph to an ORT file
+      - name: 7. Convert to ORT file
         run: |
-          PYTHONPATH=. python scripts/application_packager/convert_to_ort.py
+          PYTHONPATH=. python scripts/ort/convert_to_ort.py

-      - name: ⬆️ Upload ORT file
+      - name: Upload ORT file
         uses: actions/upload-artifact@v3
         with:
           name: ort-file
           path: ort_files/inference_recognition_model.ort
README.md
@@ -1,7 +1,8 @@
 # Modulation Recognition Demo

-RIA Hub Workflows is an automation platform integrated into RIA Hub. This project provides an example machine learning
-workflow for signal modulation classification, offering a practical introduction to RIA Hub Workflows.
+RIA Hub Workflows is an automation platform built into RIA Hub. This project contains an example machine learning
+workflow for the problem of signal modulation classification. It also serves as an excellent introduction to
+RIA Hub Workflows.

 ## 📡 The machine learning development workflow
@@ -9,24 +10,24 @@ workflow for signal modulation classification, offering a practical introduction
 The development of intelligent radio solutions involves multiple steps:

 1. First, we need to prepare a machine learning-ready dataset. This involves signal synthesis or capture, followed by
-dataset curation to extract and qualify examples. Finally, we need to perform any required data preprocessing—such as
-augmentation—and split the dataset into training and test sets.
+dataset curation to extract and qualify training examples. Finally, we need to perform any required data preprocessing
+—such as augmentation—and split the dataset into training and test sets.

-2. Secondly, we need to design and train a model. This is typically an iterative process, often accelerated using
-techniques such as Neural Architecture Search (NAS) and hyperparameter optimization (HPO), which help automate the
-discovery of an effective model structure and optimal hyperparameter settings.
+2. Secondly, we need to design and train a model. This is often an iterative process and can leverage techniques like
+Neural Architecture Search (NAS) and hyperparameter optimization to automate finding a suitable model structure and
+optimal hyperparameter configuration, respectively.

 3. Once a machine learning model has been trained and validated, the next step is to build an inference application.
 This step transforms the model from a research artifact into a practical tool capable of making predictions in
-real-world conditions. Building an inference application typically involves several steps including model
+real-world conditions. Building an inference application typically involves several substeps including model
 optimization, packaging and integration, and monitoring and logging.

-This is a lot of work, and much of it involves tedious software development and repetitive tasks, like setting up and
+This is a lot of work, and much of it involves tedious software development and repetitive tasks like setting up and
 configuring infrastructure. What's more? There is a shortage of domain expertise in ML and MLOps for radio. That's
-where we come in. RIA Hub offers a no-code and low-code solution for automating the end-to-end development of
-intelligent radio systems.
+where we come in. RIA Hub offers a no- and low-code solution for the end-to-end development of intelligent radio
+systems, allowing for a sharper focus on innovation.

 ## ▶️ RIA Hub Workflows
@@ -34,34 +35,25 @@ intelligent radio systems.
 One of the core principles of RIA Hub is Workflows, which allow users to run jobs in isolated Docker containers.

 You can create workflows in one of two ways:
 - Writing YAML and placing it in the special `.riahub/workflows/` directory in your repository.
 - Using RIA Hub's built-in tools for Dataset Management, Model Building, and Application Development, which will
   automatically generate the YAML workflow file(s) for you.

 Workflows can be configured to run automatically on push and pull request events. You can monitor and manage running
 workflows in the 'Workflows' tab in your repository.

 Workflows require a _runner_, which retrieves job definitions from RIA Hub, executes them in isolated containers, and
 reports the results back to RIA Hub. The next section outlines the convenience and advantage of using Qoherent-hosted
 runners. The workflow configuration defines the specifications and settings of the available job containers.

 The best part? RIA Hub Workflows are built on [Gitea Actions](https://docs.gitea.com/usage/actions/overview) (similar to [GitHub Actions](https://github.com/features/actions)), providing a
 familiar syntax and allowing you to leverage a wide range of third-party Actions.

 ## ⚙️ Qoherent-hosted runners

-Qoherent-hosted runners are workflow runners that Qoherent provides and manages to run your workflows and jobs in
-RIA Hub Workflows.
+Qoherent-hosted runners are job containers that Qoherent provides and manages to run your workflows and jobs in RIA Hub
+Workflows.

-Why use Qoherent-hosted runners?
-- Start running workflows right away, without the need to set up your own infrastructure.
-- Qoherent maintains runners equipped with access to hardware and tools common for radio ML development, including
+Why use Qoherent-hosted runners?
+- Easy to set up and start running workflows quickly, without the need to set up your own infrastructure.
+- Qoherent maintains runners equipped with access to common hardware and tools for radio ML development, including
   SDR testbeds and common embedded targets.

-If you want to learn more about the runners we have available, [contact us](https://www.qoherent.ai/contact/) directly. We can also provide
+If you want to learn more about the runners we have available, please feel free to reach out. We can also provide
 custom runners equipped with specific radio hardware and RAN software upon request.

 Want to register your own runner? No problem! Please refer to the RIA Hub documentation for more details.
@@ -69,18 +61,6 @@ Want to register your own runner? No problem! Please refer to the RIA Hub documentation for more details.

-## 🔍 Modulation Recognition
-
-In radio, the modulation scheme refers to the method used to encode information onto a carrier signal. Common schemes
-such as BPSK, QPSK, and QAM vary the amplitude, phase, or frequency of the signal in structured ways to represent
-digital data. These schemes are fundamental to nearly all wireless communication systems, enabling efficient and
-reliable transmission over different channels and under various noise conditions.
-
-Machine learning-based modulation classification helps identify which modulation scheme is being used, especially
-in scenarios where prior knowledge of the signal format is unavailable or unreliable. Traditional methods often rely
-on expert-designed features and rule-based algorithms, which can struggle in real-world environments with multipath,
-interference, or hardware impairments. In contrast, ML-based approaches can learn complex patterns directly from
-raw signal data, offering higher robustness and adaptability. This is particularly valuable in applications like
-cognitive radio, spectrum monitoring, electronic warfare, and autonomous communication systems, where accurate and
-fast modulation recognition is critical.
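For readers new to these schemes, here is a minimal, self-contained numpy sketch (illustrative only, not code from this project) that maps random bits onto BPSK and QPSK constellations and adds noise at a chosen SNR:

```python
import numpy as np

rng = np.random.default_rng(42)

def bpsk_symbols(n_bits: int) -> np.ndarray:
    """Map bits {0,1} to BPSK symbols {-1,+1} on the real axis."""
    bits = rng.integers(0, 2, n_bits)
    return 2.0 * bits - 1.0 + 0j

def qpsk_symbols(n_bits: int) -> np.ndarray:
    """Map bit pairs to the four QPSK constellation points."""
    bits = rng.integers(0, 2, (n_bits // 2, 2))
    i = 2.0 * bits[:, 0] - 1.0  # in-phase component
    q = 2.0 * bits[:, 1] - 1.0  # quadrature component
    return (i + 1j * q) / np.sqrt(2)  # normalize to unit average power

def add_awgn(symbols: np.ndarray, snr_db: float) -> np.ndarray:
    """Add complex white Gaussian noise at the requested SNR (dB)."""
    noise_power = 10 ** (-snr_db / 10)
    noise = np.sqrt(noise_power / 2) * (
        rng.standard_normal(symbols.shape) + 1j * rng.standard_normal(symbols.shape)
    )
    return symbols + noise

rx = add_awgn(qpsk_symbols(2048), snr_db=6)  # noisy QPSK at 6 dB SNR
```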
 ## 🚀 Getting started

@@ -89,61 +69,44 @@ fast modulation recognition is critical.

 2. Enable Workflows (*Settings → Advanced Settings → Enable Repository Actions*).
    _TODO: Remove this point once default units have been updated to include actions in forks_

-3. Check for available runners. The runner management tab can be found at the top of the 'Workflows' tab in your
-repository. If no runners are available, you'll need to register one before proceeding.
+3. Check for available runners. The runner management tab can be found at the top of the 'Workflows' tab. If no runners
+are available, you'll need to register one before proceeding.

-4. Configure Git API credentials, if no suitable credentials are already set. This is required for accessing Utils
-in the job container. This requires three steps:
-
-   - Create a personal access token with the following permissions: `read:packages` (*User Settings → Applications → Manage Access Tokens*).
-   - Create a Workflow Variable `RIAHUB_USER` with your RIA Hub username (*Repo Settings → Actions → Variables Management*).
-   - Create a Workflow Secret `RIAHUB_TOKEN` with the token created above (*Repo Settings → Actions → Secrets Management*).
-
-   _TODO: Remove this point once the Utils wheel file has been added to this project._

-5. Clone down the project. For example:
+4. Clone down the project. For example:
    ```commandline
    git clone https://git.riahub.ai/user/modrec-workflow.git
    cd modrec-workflow
    ```

-6. Set the workflow runner in `.riahub/workflows/workflow.yaml`. The runner is set on line 13:
+5. Set the workflow runner in `.riahub/workflows/workflow.yaml`. The runner is set on line 13:
    ```yaml
-   runs-on: ubuntu-latest-2080
+   runs-on: ubuntu-latest
    ```
+   **Note:** We recommend running this demo on a GPU-enabled runner. If a GPU runner is not available, you can still run
+   the workflow, but we suggest reducing the number of training epochs to keep runtime reasonable.

-7. (Optional) Configure the workflow. All parameters—including file paths, model architecture, and training
-settings—are set in `conf/app.yaml`. Want to jump right in? No problem, the default configuration is suitable.
+6. (Optional) Configure the workflow. All parameters—including file paths, model architecture, and training
+settings—are set in `conf/app.yaml`. Want to jump right in? The default configuration is suitable for getting started.

-8. Push changes. This will automatically trigger the workflow. You can monitor workflow progress under the 'Workflows'
-tab in the repository.
+7. Push changes. This will start the workflow automatically.

-9. Inspect the workflow output. You can expand and collapse individual steps to view terminal output. A check
+8. Inspect the workflow output. You can expand and collapse individual steps to view their terminal output. A check
 mark indicates that the step completed successfully.

-10. Inspect the workflow artifacts. Additional information on workflow artifacts can be found in the next section.
+9. Inspect the workflow artifacts. Additional information on workflow artifacts can be found in the next section.

 ## Workflow artifacts

-This workflow generates several artifacts, including:
+The example generates several workflow artifacts, including:

-- `recordings`: Folder of synthetic signal recordings.
 - `dataset`: The training and validation datasets: `train.h5` and `val.h5`, respectively.

@@ -158,22 +121,18 @@ stages of training.
 by [ONNX Runtime](https://onnxruntime.ai/) for more efficient loading and execution.)

-- `profile-data`: Model execution traces, in JSON format. See the section below for instructions on how to inspect the
-trace using Perfetto.
+- `profile-data`: Model execution traces, in JSON format.

-## 📊 Inspecting the model trace using Perfetto
+- `recordings`: Folder of synthesised signal recordings.

-[Perfetto](https://ui.perfetto.dev/) is an open-source trace visualization tool developed by Google. It provides a powerful web-based
-interface for inspecting model execution traces. Perfetto is especially useful for identifying bottlenecks.
-
-To inspect a model trace, navigate to Perfetto. Select *Navigation → Open trace file*, and choose the JSON trace file
-included in the `profile-data` artifact.

 ## 🤝 Contribution

-We welcome contributions from the community! Whether it's an enhancement, bug fix, or new tutorial, your
+We welcome contributions from the community! Whether it's an enhancement, bug fix, or new how-to guide, your
 input is valuable. To get started, please [contact us](https://www.qoherent.ai/contact/) directly, we're looking forward to collaborating with
 you. 🚀
|
|||
|
||||
Alternative licensing options are available. Alternative licensing options are available. Please [contact us](https://www.qoherent.ai/contact/)
|
||||
for further details.
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
### Configure GitHub Secrets
|
||||
|
||||
Before running the pipeline, add the following repository secrets in GitHub (Settings → Secrets and variables → Actions):
|
||||
|
||||
- **RIAHUB_USER**: Your RIA Hub username.
|
||||
- **RIAHUB_TOKEN**: RIA Hub access token with `read:packages` scope (from your RIA Hub account **Settings → Access Tokens**).
|
||||
- **CLONER_TOKEN**: Personal access token for `stark_cloner_bot` with `read_repository` scope (from your on-prem Git server user settings).
|
||||
|
||||
Once secrets are configured, you can run the pipeline:
|
||||
|
||||
|
||||
3.
|
||||
|
||||
|
||||
## How to View the JSON Trace File
|
||||
|
||||
- Captures a full trace of model training and inference performance for profiling and debugging
|
||||
- Useful for identifying performance bottlenecks, optimizing resource usage, and tracking metadata
|
||||
-
|
||||
Access this [link](https://ui.perfetto.dev/)
|
||||
Click on Open Trace File -> Select your specific JSON trace file
|
||||
Explore detailed visualizations of performance metrics, timelines, and resource usage to diagnose bottlenecks and optimize your workflow.
|
||||
|
||||
|
||||
|
||||
## Submiting Issues
|
||||
Found a bug or have a feature request?
|
||||
Please submit an issue via the GitHub Issues page.
|
||||
When reporting bugs, include:
|
||||
Steps to reproduce
|
||||
- Error logs and screenshots (if applicable)
|
||||
- Your app.yaml configuration (if relevant)
|
||||
|
||||
|
||||
|
||||
## Developer Details
|
||||
Coding Guidelines:
|
||||
Follow PEP 8 for Python code style.
|
||||
Include type annotations for all public functions and methods.
|
||||
Write clear docstrings for modules, classes, and functions.
|
||||
Use descriptive commit messages and reference issue numbers when relevant.
|
||||
Contributing
|
||||
All contributions must be reviewed via pull requests.
|
||||
Run all tests and ensure code passes lint checks before submission.
|
|
conf/app.yaml

@@ -1,16 +1,20 @@
 general:
   # Run mode. Options are 'prod' or 'dev'.
   run_mode: prod

 dataset:
-  # Seed for the random number generator, used for signal generation
-  seed: 42
+  # Number of slices you want to split each recording into
+  num_slices: 8

-  # Number of samples per recording
-  recording_length: 1024
+  # Training/validation split between the two datasets
+  train_split: 0.8
+  val_split: 0.2

-  # List of signal modulation schemes to include in the dataset
-  modulation_types:
-    - bpsk
-    - qpsk
-    - qam16
-    - qam64
+  # Used to initialize a random number generator
+  seed: 25
+
+  # Multiple modulations to contain in the dataset
+  modulation_types: [bpsk, qpsk, qam16, qam64]

   # Rolloff factor for pulse shaping filter (0 < beta <= 1)
   beta: 0.3
@@ -19,18 +23,20 @@ dataset:
   sps: 4

-  # SNR sweep range: start, stop (exclusive), and step (in dB)
-  snr_start: -6
-  snr_stop: 13
-  snr_step: 3
+  snr_start: -6 # Start value of SNR sweep (in dB)
+  snr_stop: 13  # Stop value (exclusive) of SNR sweep (in dB)
+  snr_step: 3   # Step size for SNR sweep (in dB)

-  # Number of iterations (signal recordings) per modulation and SNR combination
+  # Number of iterations (samples) per modulation and SNR combination
   num_iterations: 3

-  # Modulation scheme settings; keys must match the `modulation_types` list above
-  # Each entry includes:
-  #   - num_bits_per_symbol: bits encoded per symbol (e.g., 1 for BPSK, 4 for 16-QAM)
-  #   - constellation_type: modulation category (e.g., "psk", "qam", "fsk", "ofdm")
-  # TODO: Combine entries for 'modulation_types' and 'modulation_settings'
+  # Number of samples per generated recording
+  recording_length: 1024
+
+  # Settings for each modulation scheme
+  # Keys must match entries in `modulation_types`
+  #   - `num_bits_per_symbol`: how many bits each symbol encodes (e.g., 1 for BPSK, 4 for 16-QAM)
+  #   - `constellation_type`: type of modulation (e.g., "psk", "qam", "fsk", "ofdm")
   modulation_settings:
     bpsk:
       num_bits_per_symbol: 1
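As a quick sanity check of the sweep this configuration describes, the following snippet (hypothetical, mirroring how the generator presumably consumes `snr_start`, `snr_stop`, and `snr_step`) lists the resulting SNR levels:

```python
import numpy as np

# stop is exclusive, matching the comment in the config
snrs = np.arange(-6, 13, 3)
print(snrs)  # [-6 -3  0  3  6  9 12] -> 7 SNR levels per modulation
```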
@@ -45,25 +51,20 @@ dataset:
       num_bits_per_symbol: 6
       constellation_type: qam

-  # Number of slices to cut from each recording
-  num_slices: 8
-
-  # Training and validation split ratios; must sum to 1
-  train_split: 0.8
-  val_split: 0.2

 training:
-  # Number of training examples processed together before the model updates its weights
+  # Number of training samples processed together before the model updates its weights
   batch_size: 256

   # Number of complete passes through the training dataset during training
   epochs: 5

-  # Learning rate: step size for weight updates after each batch
-  # Recommended range for fine-tuning: 1e-6 to 1e-4
+  # Learning rate: how much weights are updated after every batch
+  # Suggested range for fine-tuning: 1e-6 to 1e-4
   learning_rate: 1e-4

-  # Enable GPU acceleration for training if available
+  # Whether to use GPU acceleration for training (if available)
   use_gpu: true

   # Dropout rate for individual neurons/layers (probability of dropping out a unit)
@@ -72,12 +73,13 @@ training:
   # Drop path rate: probability of dropping entire residual paths (stochastic depth)
   drop_path_rate: 0.2

-  # Weight decay (L2 regularization) coefficient to help prevent overfitting
+  # Weight decay (L2 regularization) to help prevent overfitting
   wd: 0.01

-app:
-  # Optimization style for ORT conversion; options: 'Fixed', 'None'
-  optimization_style: "Fixed"
-
-  # Target platform architecture; common options: 'amd64', 'arm64'
-  target_platform: "amd64"
+app:
+  # Optimization style for ORT conversion. Options: 'Fixed', 'None'
+  optimization_style: Fixed
+
+  # Target platform architecture. Common options: 'amd64', 'arm64'
+  target_platform: amd64
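For reference, the project's `get_app_settings()` helper presumably wraps a read of this file along the lines of the PyYAML sketch below (a minimal illustration; the helper's actual attribute-style access is an assumption):

```python
import yaml

with open("conf/app.yaml") as f:
    cfg = yaml.safe_load(f)  # nested dicts mirroring the YAML structure

print(cfg["dataset"]["modulation_types"])  # ['bpsk', 'qpsk', 'qam16', 'qam64']
print(cfg["training"]["batch_size"])       # 256
```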
scripts/dataset_building/data_gen.py

@@ -1,9 +1,7 @@
-import argparse
-
-import numpy as np
 from utils.data import Recording
+import numpy as np
 from utils.signal import block_generator
-
+import argparse
 from helpers.app_settings import get_app_settings

 settings = get_app_settings().dataset
scripts/dataset_building/produce_dataset.py

@@ -1,11 +1,7 @@
-import os
+import os, h5py, numpy as np
 from typing import List

-import h5py
-import numpy as np
-from split_dataset import split, split_recording
 from utils.io import from_npy

+from split_dataset import split, split_recording
 from helpers.app_settings import DataSetConfig, get_app_settings

 meta_dtype = np.dtype(
@@ -50,6 +46,8 @@ def write_hdf5_file(records: List, output_path: str, dataset_name: str = "data")
     )

+    first_rec, _ = records[0]  # records[0] is a tuple of (data, md)
+    sample = first_rec
     shape, dtype = sample.shape, sample.dtype

     with h5py.File(output_path, "w") as hf:
         data_arr = np.stack([rec[0] for rec in records])
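A trimmed-down sketch of what `write_hdf5_file` appears to do, based on the lines above (shapes and metadata handling are simplified assumptions):

```python
from typing import List, Tuple

import h5py
import numpy as np

def write_hdf5(records: List[Tuple[np.ndarray, dict]], output_path: str) -> None:
    """Stack per-recording arrays and store them in a single HDF5 dataset."""
    data_arr = np.stack([rec[0] for rec in records])  # (num_records, ...) array
    with h5py.File(output_path, "w") as hf:
        hf.create_dataset("data", data=data_arr)

# One dummy IQ recording: 2 channels (I and Q) x 1024 samples
write_hdf5([(np.zeros((2, 1024), dtype=np.float32), {"snr": 0})], "data/dataset/train.h5")
```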
scripts/dataset_building/split_dataset.py

@@ -1,7 +1,6 @@
 import random
 from collections import defaultdict
-from typing import Dict, List, Tuple
-
+from typing import List, Tuple, Dict
 import numpy as np
scripts/onnx/convert_to_onnx.py

@@ -2,8 +2,8 @@ import os

 import numpy as np
 import torch
-from scripts.training.mobilenetv3 import RFClassifier, mobilenetv3
-
+from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
 from helpers.app_settings import get_app_settings

@@ -12,8 +12,8 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:
     Convert a PyTorch model to ONNX format.

     Parameters:
-        ckpt_path (str): The path to save the converted ONNX model.
-        fp16 (bool): 16-bit floating point precision
+        output_path (str): The path to save the converted ONNX model.
+        fp16 (bool): 16-bit floating point precision
     """
     settings = get_app_settings()

@@ -68,6 +68,8 @@ def convert_to_onnx(ckpt_path: str, fp16: bool = False) -> None:

 if __name__ == "__main__":

+    settings = get_app_settings()
+
     model_checkpoint = "inference_recognition_model.ckpt"

     print("Converting to ONNX...")
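For context, exporting a PyTorch model to ONNX generally follows the pattern below. This is a generic sketch with a stand-in model, not the project's exact `convert_to_onnx` implementation; the input shape and class count are assumptions:

```python
import torch
from torch import nn

# Stand-in model: the real script loads RFClassifier/mobilenetv3 from a checkpoint.
model = nn.Sequential(
    nn.Conv1d(2, 16, kernel_size=3),  # 2 channels: I and Q
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 1022, 4),          # 4 modulation classes
)
model.eval()

dummy_input = torch.randn(1, 2, 1024)  # (batch, IQ channels, samples)

torch.onnx.export(
    model,
    dummy_input,
    "onnx_files/inference_recognition_model.onnx",
    input_names=["input"],
    output_names=["logits"],
    opset_version=17,
    dynamic_axes={"input": {0: "batch"}},  # allow variable batch size at inference
)
```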
scripts/onnx/profile_onnx.py

@@ -1,9 +1,9 @@
-import json
-import onnxruntime as ort
-import numpy as np
-from helpers.app_settings import get_app_settings
 import os
 import time

+import numpy as np
+import onnxruntime as ort
+import json


 def profile_onnx_model(
@@ -84,5 +84,6 @@ def profile_onnx_model(

 if __name__ == "__main__":
     settings = get_app_settings()
     output_path = os.path.join("onnx_files", "inference_recognition_model.onnx")
     profile_onnx_model(output_path)
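The `onnxruntime_profile_*.json` files the workflow uploads are produced by ONNX Runtime's built-in profiler. A minimal sketch of enabling it (generic ONNX Runtime usage; the input shape is an assumption):

```python
import numpy as np
import onnxruntime as ort

opts = ort.SessionOptions()
opts.enable_profiling = True  # writes onnxruntime_profile_*.json on end_profiling()

session = ort.InferenceSession("onnx_files/inference_recognition_model.onnx", opts)
input_name = session.get_inputs()[0].name

# Run a few inferences so the trace has something to show
for _ in range(10):
    session.run(None, {input_name: np.random.randn(1, 2, 1024).astype(np.float32)})

trace_path = session.end_profiling()  # path to the JSON trace, viewable in Perfetto
print(trace_path)
```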
scripts/ort/convert_to_ort.py

@@ -1,5 +1,4 @@
 import subprocess
-
 from helpers.app_settings import get_app_settings

 settings = get_app_settings()
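Conversion to ORT format is usually delegated to the converter bundled with ONNX Runtime. A sketch of how a subprocess-based `convert_to_ort.py` might invoke it, using the `optimization_style` value from `conf/app.yaml` (the project's exact flags and paths may differ):

```python
import subprocess
import sys

# Values mirroring conf/app.yaml; the real script presumably reads them
# via get_app_settings() instead of hard-coding.
optimization_style = "Fixed"
onnx_dir = "onnx_files"

subprocess.run(
    [
        sys.executable, "-m", "onnxruntime.tools.convert_onnx_models_to_ort",
        onnx_dir,  # converts every .onnx file found in the directory
        "--output_dir", "ort_files",
        "--optimization_style", optimization_style,
    ],
    check=True,
)
```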
scripts/training/cm_plotter.py

@@ -1,6 +1,5 @@
-from typing import Optional
-
 import numpy as np
+from typing import Optional
 from matplotlib import pyplot as plt
 from sklearn.metrics import confusion_matrix
scripts/training/mobilenetv3.py

@@ -1,8 +1,8 @@
-import lightning as L
 import numpy as np
-import timm
 import torch
+import timm
 from torch import nn
+import lightning as L

 sizes = [
     "mobilenetv3_large_075",
@@ -24,9 +24,11 @@ class SqueezeExcite(nn.Module):
     def __init__(
         self,
         in_chs,
         se_ratio=0.25,
         reduced_base_chs=None,
         act_layer=nn.SiLU,
         gate_fn=torch.sigmoid,
         divisor=1,
         **_,
     ):
         super(SqueezeExcite, self).__init__()
@@ -75,6 +77,13 @@ class GBN(torch.nn.Module):
         self.act = act

     def forward(self, x):
+        # chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0)
+        # res = [self.bn(x_) for x_ in chunks]
+        # return self.drop(self.act(torch.cat(res, dim=0)))
+        # x = self.bn(x)
+        # x = self.act(x)
+        # x = self.drop(x)
+        # return x
         return self.drop(self.act(self.bn(x)))
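The commented-out lines above sketch the "ghost batch" idea behind the GBN class: normalize small virtual sub-batches independently instead of the whole batch at once. A standalone version of that idea, for reference (an illustrative reimplementation, not the project's class):

```python
import math

import torch
from torch import nn

class GhostBatchNorm(nn.Module):
    """Apply batch norm to virtual sub-batches independently."""

    def __init__(self, num_features: int, virtual_batch_size: int = 128):
        super().__init__()
        self.virtual_batch_size = virtual_batch_size
        self.bn = nn.BatchNorm1d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_chunks = math.ceil(x.shape[0] / self.virtual_batch_size)
        chunks = x.chunk(n_chunks, dim=0)  # split along the batch axis
        return torch.cat([self.bn(c) for c in chunks], dim=0)
```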
scripts/training/modulation_dataset.py

@@ -1,12 +1,10 @@
-import os
-import sys
+import sys, os

+sys.path.insert(0, os.path.abspath("../.."))  # or ".." if needed
+import h5py
 import numpy as np
 import torch
 from torch.utils.data import Dataset

-import h5py
 from helpers.app_settings import get_app_settings

 settings = get_app_settings()
scripts/training/plot_data.py

@@ -1,16 +1,15 @@
 import os

-import numpy as np
 import torch
+import numpy as np
 from sklearn.metrics import classification_report

 os.environ["NNPACK"] = "0"
-from cm_plotter import plot_confusion_matrix
 from matplotlib import pyplot as plt
-from scripts.training.mobilenetv3 import RFClassifier, mobilenetv3
-from scripts.training.modulation_dataset import ModulationH5Dataset
+from scripts.training.mobilenetv3 import mobilenetv3, RFClassifier
+from helpers.app_settings import get_app_settings
+from cm_plotter import plot_confusion_matrix
+from scripts.training.modulation_dataset import ModulationH5Dataset


 def load_validation_data():
@@ -142,4 +141,5 @@ def plot_confusion_matrix_with_counts(

 if __name__ == "__main__":
     settings = get_app_settings()
-    evaluate_checkpoint(os.path.join("checkpoint_files", "inference_recognition_model.ckpt"))
+    ckpt_path = os.path.join("checkpoint_files", "inference_recognition_model.ckpt")
+    evaluate_checkpoint(ckpt_path)
scripts/training/train.py

@@ -1,16 +1,14 @@
-import os
-import sys
+import sys, os

 os.environ["NNPACK"] = "0"
 import lightning as L
-import mobilenetv3
-from lightning.pytorch.callbacks import ModelCheckpoint
 import torch
 import torch.nn.functional as F
 import torchmetrics
+from lightning.pytorch.callbacks import ModelCheckpoint
 from modulation_dataset import ModulationH5Dataset

 from helpers.app_settings import get_app_settings
+import mobilenetv3

 script_dir = os.path.dirname(os.path.abspath(__file__))
 data_dir = os.path.abspath(os.path.join(script_dir, ".."))