├── requirements_dev.txt
├── tox.ini
├── pyproject.toml
├── .gitignore
├── requirements.txt
├── LICENSE
├── .github
│   └── workflows
│       └── format.yaml
├── download_dataset.py
├── README.md
└── spectrogram_segmentation.ipynb
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | black[jupyter]==24.4.2
2 | flake8==7.1.0
3 | flake8-nb==0.5.3
4 | isort==5.13.2
5 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 119
3 | extend-ignore = W503, E203, E701
4 | exclude = .git, .github, venv, .venv, env, .env, .idea, .vscode
5 | max-complexity = 10
6 |
7 | [flake8_nb]
8 | max-line-length = 119
9 | extend-ignore = W503, E402, E203, E701
10 | exclude = .git, .github, venv, .venv, env, .env, .idea, .vscode, .ipynb_checkpoints
11 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 119
3 | target-version = ["py310"]
4 | exclude = '''
5 | /(
6 | \.git
7 | | \.github
8 | | venv
9 | | \.venv
10 | | env
11 | | \.env
12 | | \.idea
13 | | \.vscode
14 | | \.ipynb_checkpoints
15 | )/
16 | '''
17 |
18 | [tool.isort]
19 | profile = "black"
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Jupyter Notebook
7 | .ipynb_checkpoints
8 |
9 | # IPython
10 | profile_default/
11 | ipython_config.py
12 |
13 | # Virtual environments
14 | .env
15 | .venv
16 | env/
17 | venv/
18 |
19 | # PyCharm specific settings
20 | .idea/
21 |
22 | # Visual Studio Code specific file
23 | .vscode/
24 |
25 | # Lightning
26 | /lightning_logs/
27 |
28 | # Project dataset
29 | /spectrum_sensing_dataset.hdf5
30 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu121 # CUDA 12.1 support for torch and torchvision.
2 | # --extra-index-url https://download.pytorch.org/whl/cu118 # CUDA 11.8 support for torch and torchvision.
3 |
4 | h5py==3.11.0
5 | jupyter==1.0.0 # Jupyter system, including the notebook and the IPython kernel, all in one go.
6 | lightning==2.3.3
7 | matplotlib==3.9.1
8 | numpy==1.26.4
9 | pandas==2.2.2
10 | pillow==10.4.0
11 | requests==2.32.3
12 | scikit-learn==1.5.1
13 | scipy==1.14.0
14 | tabulate==0.9.0
15 | torch==2.3.1
16 | torchmetrics==1.4.0
17 | torchvision==0.18.1
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Qoherent
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/format.yaml:
--------------------------------------------------------------------------------
1 | name: Lint and Format Checks
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | workflow_dispatch:
10 |
11 | jobs:
12 | lint-and-format-check:
13 | runs-on: ubuntu-latest
14 | name: Lint and Format Checks
15 | steps:
16 | - name: Check out
17 | uses: actions/checkout@v4
18 |
19 | - name: Set up Python environment
20 | uses: actions/setup-python@v5
21 | with:
22 | python-version: "3.10"
23 |
24 |       - name: Black Formatting Check
25 | uses: psf/black@stable
26 | with:
27 | options: "--check --verbose"
28 | src: "./"
29 | jupyter: true
30 |
31 | - name: Setup Flake8 Annotations
32 | uses: rbialon/flake8-annotations@v1.1
33 |
34 | - name: Lint with Flake8
35 | uses: py-actions/flake8@v2
36 |
37 | - name: Install Flake8 for Jupyter Notebooks
38 | run: |
39 | pip install --upgrade pip
40 | pip install flake8-nb
41 |
42 | - name: Run Flake8 for Jupyter Notebooks
43 | run: |
44 | flake8-nb .
45 |
46 | - name: Ensure Clean Notebooks
47 | uses: ResearchSoftwareActions/EnsureCleanNotebooksAction@1.1
48 |
--------------------------------------------------------------------------------
/download_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Download MathWorks' Spectrum Sensing dataset, if it isn't already downloaded.
3 | """
4 |
5 | import hashlib
6 | import os
7 |
8 | import requests
9 | from torch.utils.model_zoo import tqdm
10 |
11 |
12 | def sha256(target: str) -> str:
13 | """Calculates the SHA256 hash of the target resource.
14 |
15 | :param target: The full path, including the filename, of the resource to hash.
16 | :type target: str
17 |
18 | :return: The SHA256 hash of the target resource.
19 | """
20 | sha256_hash = hashlib.sha256()
21 |
22 | with open(target, "rb") as file:
23 | for chunk in iter(lambda: file.read(4096), b""):
24 | sha256_hash.update(chunk)
25 |
26 | return sha256_hash.hexdigest()
27 |
28 |
29 | mirror = "https://storage.googleapis.com/qoherent_external_drive/general_dataset_library/"
30 | resource = "spectrum_sensing_dataset_v1.0.hdf5"
31 | file_url = "{}{}".format(mirror, resource)
32 | sha256_checksum = "8a93aa14145ea1a35cbc191defbbcf90c49ecdb89e6e93f3e55357f182d184c6"
33 |
34 | target = os.path.join(os.path.dirname(os.path.abspath(__file__)), "spectrum_sensing_dataset.hdf5")
35 |
36 | # Check if the dataset source file already exists.
37 | if os.path.exists(target):
38 | print(f"{target} already exists. Aborting download.")
39 | exit(1)
40 |
41 | # Download the dataset source file into the target directory.
42 | print(f"Downloading {file_url}")
43 | n_bytes = int(requests.head(file_url).headers.get("Content-Length", 0))
44 |
45 | with (
46 | requests.get(file_url, stream=True, timeout=3) as r,
47 | open(target, "wb") as out_file,
48 | tqdm(desc="Downloading MathWorks' Spectrum Sensing Dataset", total=n_bytes, unit="B", unit_scale=True) as pbar,
49 | ):
50 | for chunk in r.iter_content(chunk_size=1024):
51 | if chunk:
52 | out_file.write(chunk)
53 | pbar.update(len(chunk))
54 |
55 | if sha256_checksum != sha256(target=target):
56 | raise RuntimeError(
57 | f"Checksum of {target} does not match expected.\n"
58 | f"The download may be corrupted, please remove the corrupted resource and try again."
59 | )
60 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spectrogram Segmentation
2 |
3 | The successful application of [semantic segmentation](https://www.ibm.com/topics/semantic-segmentation) to radiofrequency (RF) spectrograms has significant
4 | implications for [spectrum sensing](https://iopscience.iop.org/article/10.1088/1742-6596/2261/1/012016#:~:text=In%20cognitive%20radio%2C%20spectrum%20sensing,user%20can%20use%20the%20spectrum.), and serves as a foundational example showcasing the near-term feasibility of
5 | [intelligent radio](https://www.qoherent.ai/intelligentradio/) technology.
6 |
7 | In this example, we use [PyTorch](https://pytorch.org/) and [Lightning](https://lightning.ai/docs/pytorch/stable/) to train a segmentation model to identify and
8 | differentiate between 5G NR and 4G LTE signals within wideband spectrograms.
9 |
10 | Qoherent's mission to drive the creation of intelligent radio technology requires a combination of open-source and
11 | proprietary tools. This example, which leverages open-source tools and machine learning frameworks to train on
12 | synthetic radio data generated using MATLAB's powerful 5G and LTE toolboxes, showcases our commitment to
13 | interoperability and our tool-agnostic approach to innovation.
14 |
15 | Classification results are comparable to those achieved by MathWorks' custom network, albeit with more learnable parameters.
16 | For more information, please refer to the following article by MathWorks:
17 | [Spectrum Sensing with Deep Learning to Identify 5G and LTE Signals](https://www.mathworks.com/help/comm/ug/spectrum-sensing-with-deep-learning-to-identify-5g-and-lte-signals.html).
18 |
19 | If you found this example interesting or helpful, don't forget to give it a star! ⭐
20 |
21 |
22 | ## 🚀 Getting Started
23 |
24 | This example is provided as a Jupyter Notebook. You have the option to either run this example locally or in Google
25 | Colab.
26 |
27 | To run this example locally, you'll need to download the project and dataset and set up a Python
28 | virtual environment. If this seems daunting, we recommend running this example on Google Colab (Coming soon!).
29 |
30 | ### Running this example locally
31 |
32 | Please note that running this example locally requires approximately 6.1 GB of free space. Ensure you have
33 | sufficient space available before proceeding.
34 |
35 | 1. Ensure that [Git](https://git-scm.com/downloads) and [Python](https://www.python.org/downloads/) are installed on the computer where you plan to run this example.
36 | Additionally, if you'd like to accelerate model training with a GPU, you'll need [CUDA](https://docs.nvidia.com/cuda/cuda-quick-start-guide/index.html).
37 |
38 |
39 | 2. Clone this repository to your local computer:
40 | ```commandline
41 | git clone https://github.com/qoherent/spectrogram-segmentation.git
42 | ```
43 |
44 |
45 | 3. Create and activate a Python [virtual environment](https://docs.python.org/3/library/venv.html). This is a best practice for isolating project dependencies.
46 |
47 |
48 | **Windows**
49 |
50 | Use the following command to create a new directory named `venv` within the project directory:
51 | ```commandline
52 | python -m venv venv
53 | ```
54 |
55 | Then, activate the virtual environment with:
56 | ```commandline
57 | venv\Scripts\activate
58 | ```
59 |
60 |
61 |
62 |
63 | **Linux/Mac**
64 |
65 | Use the following command to create a new directory named `venv` within the project directory:
66 | ```commandline
67 | python3 -m venv venv
68 | ```
69 |
70 | Then, activate the virtual environment with:
71 | ```commandline
72 | source venv/bin/activate
73 | ```
74 |
75 |
76 |
77 | Activating the virtual environment should modify the command prompt to show `(venv)` at the beginning, indicating
78 | that the virtual environment is active.
79 |
80 |
81 | 4. Install project dependencies from the provided `requirements.txt` file:
82 | ```commandline
83 | pip install -r requirements.txt
84 | ```
85 |
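   Optionally, once the dependencies are installed, you can check whether PyTorch detects a CUDA-capable GPU
   (a quick sanity check; GPU acceleration is not required to run this example):
   ```commandline
   python -c "import torch; print(torch.cuda.is_available())"
   ```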
86 |
87 | 5. Download the spectrum sensing dataset.
88 |
89 |
90 | **Windows**
91 |
92 | ```commandline
93 | python download_dataset.py
94 | ```
95 |
96 |
97 |
98 |
99 | **Linux/Mac**
100 |
101 | ```commandline
102 | python3 download_dataset.py
103 | ```
104 |
105 |
106 |
107 | This will download the `spectrum_sensing_dataset.hdf5` source file to the project's root directory.
108 |
109 |
110 | 6. Register the environment kernel with Jupyter:
111 | ```commandline
112 | ipython kernel install --user --name=venv --display-name "Spectrogram Segmentation (venv)"
113 | ```
114 |
115 |
116 | 7. Open the notebook, `spectrogram_segmentation.ipynb`, specifying to use the new kernel:
117 | ```commandline
118 | jupyter notebook spectrogram_segmentation.ipynb --MultiKernelManager.default_kernel_name=venv
119 | ```
120 |
121 |
122 | 8. Give yourself a pat on the back - you're all set up and ready to explore the example! For more information on
123 | navigating the Jupyter Notebook interface and executing code cells, please check out this tutorial by the Codecademy
124 | Team: [How To Use Jupyter Notebooks](https://www.codecademy.com/article/how-to-use-jupyter-notebooks).
125 |
126 | Depending on your system specifications and the availability of a CUDA-capable GPU, running this example locally may take
127 | several minutes. If a cell is taking too long to execute, you can interrupt its execution by clicking the "Kernel"
128 | menu and selecting "Interrupt Kernel" or by pressing `Ctrl + C` in the terminal where Jupyter Notebook is running.
129 |
130 |
131 | 9. After you finish exploring, consider removing the dataset from your system and deleting the virtual environment to
132 | free up space. Remember to deactivate the virtual environment using the deactivate command before deleting the folder.
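
   For example, on Linux/Mac the cleanup might look like the following (adjust the paths if you set things up
   differently; the steps are analogous on Windows):
   ```commandline
   deactivate
   rm -rf venv
   rm spectrum_sensing_dataset.hdf5
   ```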
133 |
134 |
135 | ## 🤝 Contribution
136 |
137 | We welcome contributions from the community! Whether it's an enhancement, bug fix, or improved explanation,
138 | your input is valuable. For significant changes, or if you'd like to prepare a separate tutorial, kindly
139 | [contact us](mailto:info@qoherent.ai) beforehand.
140 |
141 | If you encounter any issues or would like to report a security vulnerability, please submit a bug report to the GitHub Issues
142 | page [here](https://github.com/qoherent/spectrogram-segmentation/issues).
143 |
144 | Has this example inspired a project or research initiative related to intelligent radio? Please [get in touch](mailto:info@qoherent.ai);
145 | we'd love to collaborate with you! 📡🚀
146 |
147 | Finally, be sure to check out our open-source project: [RIA Core](https://github.com/qoherent/ria) (Coming soon!).
148 |
149 |
150 | ## 🖊️ Authorship
151 |
152 | This work is a product of the collaborative efforts of the Qoherent team. Of special mention are [Wan](https://github.com/wan-sdr),
153 | [Madrigal](https://github.com/MadrigalDW), [Dimitrios](https://github.com/DimitriosK), and [Michael](https://github.com/mrl280).
154 |
155 |
156 | ## 🙏 Attribution
157 |
158 | The dataset used in this example was prepared by MathWorks using their 5G and LTE toolboxes and is publicly available
159 | [here](https://www.mathworks.com/supportfiles/spc/SpectrumSensing/SpectrumSenseTrainingDataNetwork.tar.gz). For more information on how this dataset was generated or to generate further spectrum data, please refer
160 | to MathWorks' article on spectrum sensing. For more information about Qoherent's use of MATLAB to accelerate
161 | intelligent radio research, check out our [customer story](https://www.mathworks.com/company/user_stories/qoherent-uses-matlab-to-accelerate-research-on-next-generation-ai-for-wireless.html).
162 |
163 | The DeepLabv3 models used in this example were initially proposed by Chen _et al._ and are further discussed
164 | in their 2017 paper titled '[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)'. The MobileNetV3
165 | backbone used in this example was developed by Howard _et al._ and is further discussed in their 2019 paper titled
166 | '[Searching for MobileNetV3](https://arxiv.org/abs/1905.02244)'. Models were accessed through [`torchvision`](https://pytorch.org/vision/stable/models/deeplabv3.html).
167 |
168 | A special thanks to the PyTorch and Lightning teams for providing the foundational machine learning frameworks used in
169 | this example.
170 |
--------------------------------------------------------------------------------
/spectrogram_segmentation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Spectrogram Segmentation\n",
8 | "\n",
9 | "In this example, we use [PyTorch](https://pytorch.org/) and [Lightning](https://lightning.ai/docs/pytorch/stable/) to train a deep learning model to identify and differentiate between 5G NR and 4G LTE signals within wideband spectrograms."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Outline\n",
17 | "\n",
18 | "**[Background](#Background):** Delve into the problem background and learn more about the machine learning frameworks, tools, and datasets used in this example.\n",
19 | "\n",
20 | "**[Set-up](#Set-Up):** Install the libraries necessary to run the code in this notebook.\n",
21 | "\n",
22 | "**[Data Preprocessing](#Data-Preprocessing):** Load and analyze the Spectrum Sensing dataset.\n",
23 | "\n",
24 | "**[Model Training](#Model-Training):** Configure and train a DeepLabV3 model with a MobileNetV3 backbone.\n",
25 | "\n",
26 | "**[Model Validation](#Model-Validation):** Assess the performance of the model using a suite of common machine learning metrics.\n",
27 | "\n",
28 | "**[Challenge Data](#Challenge-Data):** Challenge the model on combined frames containing both LTE and NR signal.\n",
29 | "\n",
30 | "**[Conclusions & Next Steps](#Conclusions-&-Next-Steps):** Interpret the results, summarize key learnings, and identify steps for expanding upon this example."
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "## Background\n",
38 | "\n",
39 | "5G NR (New Radio) and 4G LTE (Long-Term Evolution) are both cellular network technologies, but they represent \n",
40 | "different generations of mobile network standards. The ability to identify and distinguish between the two holds significant \n",
41 | "applications in [spectrum sensing](https://iopscience.iop.org/article/10.1088/1742-6596/2261/1/012016), and serves as a foundational example showcasing the near-term feasibility of \n",
42 | "[intelligent radio](https://www.qoherent.ai/intelligentradio/) technology.\n",
43 | "\n",
44 | "A spectrogram, which depicts the frequency spectrum of a signal over time, is essentially just an image. Therefore, we can\n",
45 | "apply state-of-the-art [semantic segmentation](https://www.ibm.com/topics/semantic-segmentation) techniques from \n",
46 | "the field of computer vision to the problem of spectrogram analysis. Our task is to assign one of the \n",
47 | "following labels to each pixel in the spectrogram: 'LTE', 'NR', or 'Noise'. ('Noise' refers to the absence of signal, representing \n",
48 | "a vacant or empty spectrum, also known as whitespace.)\n",
49 | "\n",
50 |     "The machine learning model utilized in this example is a DeepLabV3 model with a MobileNetV3 (large) backbone. The DeepLabV3 framework was originally introduced by Chen _et al._ in their 2017 paper titled '[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)', and the MobileNetV3 backbone was developed by Howard _et al._ and is further discussed in their 2019 paper titled '[Searching for MobileNetV3](https://arxiv.org/abs/1905.02244)'. For an accessible introduction to the DeepLabV3 framework, please check out Isaac Berrios' article: [DeepLabv3: Building Blocks for Robust Segmentation Models](https://medium.com/@itberrios6/deeplabv3-c0c8c93d25a4).\n",
51 | "\n",
52 | "The dataset used in this example is the Spectrum Sensing dataset, provided by MathWorks. This dataset contains 900 LTE frames, 900 NR frames, and 900 combined frames with both LTE and NR signal. In this example, we train exclusively on the individual LTE and NR examples, excluding the combined frames from the training process."
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "metadata": {},
58 | "source": [
59 | "# Set-Up\n",
60 | "\n",
61 | "In this section, we will install the dependencies required to run the code in this notebook. These dependencies include libraries and packages for tasks such as data manipulation, visualization, and machine learning."
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "%matplotlib inline"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "import os\n",
80 | "import statistics\n",
81 | "from typing import Any, Optional\n",
82 | "\n",
83 | "import h5py\n",
84 | "import lightning as L\n",
85 | "import matplotlib.pyplot as plt\n",
86 | "import numpy as np\n",
87 | "import pandas as pd\n",
88 | "import torch\n",
89 | "from matplotlib.colors import ListedColormap\n",
90 | "from PIL import Image\n",
91 | "from sklearn.metrics import ConfusionMatrixDisplay\n",
92 | "from tabulate import tabulate\n",
93 | "from torch import Tensor, nn\n",
94 | "from torch.utils.data import DataLoader\n",
95 | "from torchmetrics.classification import (\n",
96 | " MulticlassAccuracy,\n",
97 | " MulticlassConfusionMatrix,\n",
98 | " MulticlassF1Score,\n",
99 | " MulticlassJaccardIndex,\n",
100 | " MulticlassPrecision,\n",
101 | " MulticlassRecall,\n",
102 | ")\n",
103 | "from torchvision.datasets import VisionDataset\n",
104 | "from torchvision.models.segmentation import ( # noqa: F401\n",
105 | " deeplabv3_mobilenet_v3_large,\n",
106 | " deeplabv3_resnet50,\n",
107 | " deeplabv3_resnet101,\n",
108 | ")\n",
109 | "from torchvision.transforms.v2 import (\n",
110 | " Compose,\n",
111 | " Normalize,\n",
112 | " PILToTensor,\n",
113 | " ToDtype,\n",
114 | " ToPILImage,\n",
115 | ")"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "title_font_size, label_font_size = 14, 12"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {},
130 | "source": [
131 | "# Data Preprocessing\n",
132 | "\n",
133 | "In semantic segmentation, the input data typically consists of images (in this case, spectrograms), while the output data consists of pixel-wise labels (masks) where each pixel is assigned a category label (in this case, either 'LTE', 'NR', or 'Noise'). \n",
134 | "\n",
135 | "We will use [supervised learning](https://www.ibm.com/topics/supervised-learning) techniques to train our model. Therefore, we need both spectrograms and the corresponding target masks for training. \n",
136 | "\n",
137 | "The dataset used in this example is the MathWorks' Spectrum Sensing dataset. This dataset has been converted into the file format used by [Radio Intelligence Apps](https://qoherent.ai/radiointelligenceapps-project/), which contains both the spectrograms and masks within a single high-performance [HDF5](https://www.hdfgroup.org/solutions/hdf5/) file. The file structure is as follows:"
138 | ]
139 | },
140 | {
141 | "cell_type": "raw",
142 | "metadata": {},
143 | "source": [
144 | "root/\n",
145 | "├── data (Dataset)\n",
146 | "│ ├── [MathWorks dataset licence (Attribute)]\n",
147 | "│ └── [spectrogram images, to use as input to the model]\n",
148 | "│\n",
149 | "├── masks (Dataset)\n",
150 | "│ └── [target masks, to use for training]\n",
151 | "│\n",
152 | "└── metadata (Group)\n",
153 | " ├── metadata (Dataset)\n",
154 | " │ ├── signal_type (Column)\n",
155 | " │ └── ...\n",
156 | " │\n",
157 | " └── about (Dataset)\n",
158 | " ├── author\n",
159 | " ├── name\n",
160 | " └── ..."
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "metadata": {},
166 | "source": [
167 | "The only metadata relevant to this example is `signal_type`, which can be one of the following:\n",
168 | "- `LTE`: This frame contains only LTE signal.\n",
169 | "- `NR`: This frame contains only NR signal.\n",
170 | "- `LTE_NR`: This frame contains both LTE and NR signal."
171 | ]
172 | },
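  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check, the cell below opens the source file with `h5py` and prints the top-level keys, the shapes of the `data` and `masks` datasets, and the per-frame counts of each `signal_type`. This assumes the file has already been downloaded to the project root as `spectrum_sensing_dataset.hdf5`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: inspect the structure of the dataset source file described above.\n",
    "source_path = os.path.join(os.getcwd(), \"spectrum_sensing_dataset.hdf5\")\n",
    "\n",
    "with h5py.File(source_path, \"r\") as f:\n",
    "    print(\"Top-level keys:\", list(f.keys()))\n",
    "    print(\"data:\", f[\"data\"].shape, f[\"data\"].dtype)\n",
    "    print(\"masks:\", f[\"masks\"].shape, f[\"masks\"].dtype)\n",
    "    print(pd.DataFrame(f[\"metadata/metadata\"][:])[\"signal_type\"].value_counts())"
   ]
  },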
173 | {
174 | "cell_type": "markdown",
175 | "metadata": {},
176 | "source": [
177 | "Because spectrogram segmentation is a computer vision task, let's extend the [VisionDataset](https://pytorch.org/vision/main/generated/torchvision.datasets.VisionDataset.html) class."
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {},
184 | "outputs": [],
185 | "source": [
186 | "class SpectrumSensing(VisionDataset):\n",
187 | "\n",
188 | " def __init__(\n",
189 | " self,\n",
190 | " source: str,\n",
191 | " transform: Optional[callable] = None,\n",
192 | " target_transform: Optional[callable] = None,\n",
193 | " combined: Optional[bool] = True,\n",
194 | " ):\n",
195 | " \"\"\"Initialize a SpectrumSensing object from the dataset source file.\"\"\"\n",
196 | " super().__init__()\n",
197 | "\n",
198 | " self.source = source\n",
199 | " self.transform, self.target_transform = transform, target_transform\n",
200 | "\n",
201 | " # Build an index map, mapping the indices of valid data entries to their corresponding indices in the\n",
202 | " # original dataset. This mapping enables the efficient retrieval of a subset of frames in the case\n",
203 | " # where combined == False.\n",
204 | " with h5py.File(self.source, \"r\") as f:\n",
205 | " metadata = f[\"metadata/metadata\"]\n",
206 | "\n",
207 | " if combined:\n",
208 | " # Use all frames.\n",
209 | " valid_indices = range(0, len(metadata))\n",
210 | "\n",
211 | " else:\n",
212 | " # Use frames where the signal type is either 'LTE' or 'NR', but not 'LTE_NR'.\n",
213 | " df = pd.DataFrame(metadata[:])\n",
214 | " signal_types = df[\"signal_type\"]\n",
215 | " valid_indices = [index for index, value in enumerate(signal_types.isin([b\"LTE\", b\"NR\"])) if value]\n",
216 | "\n",
217 | " self.index_map = {idx: valid_indices[idx] for idx in range(len(valid_indices))}\n",
218 | "\n",
219 | " def __len__(self) -> int:\n",
220 | " return len(self.index_map)\n",
221 | "\n",
222 | " def __getitem__(self, idx: int) -> tuple[Image, Image]:\n",
223 | " \"\"\"Return the image-mask pair at idx.\"\"\"\n",
224 | " if idx >= len(self):\n",
225 | " raise IndexError\n",
226 | "\n",
227 | " with h5py.File(self.source, \"r\") as f:\n",
228 | " images, masks = f[\"data\"], f[\"masks\"]\n",
229 | " image_arr, mask_arr = images[self.index_map[idx]], masks[self.index_map[idx]]\n",
230 | " image, mask = Image.fromarray(image_arr), Image.fromarray(mask_arr)\n",
231 | "\n",
232 | " if self.transform is not None:\n",
233 | " image = self.transform(image)\n",
234 | "\n",
235 | " if self.target_transform is not None:\n",
236 | " mask = self.target_transform(mask)\n",
237 | "\n",
238 | " return image, mask"
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {},
244 | "source": [
245 | "Notice the `SpectrumSensing` class accepts two optional functions/transforms: `transform`, which is applied to the spectrogram, \n",
246 | "and `target_transform`, which is applied to the mask. Additionally, the class includes an optional parameter `combined`. When `combined == False`, only individual frames with either LTE or NR are included in the dataset. When `combined == True`, all frames are included, including combined frames containing both LTE and NR signals. \n",
247 | "\n",
248 | "Both the spectrograms and masks are 256 x 256 pixel images. However, the spectrograms are three channeled, while the masks are single-channeled. This is because the spectrograms are full RGB images, whereas the masks are ternary-valued images, where each pixel takes one of three discrete values:\n",
249 | "- `0`: Representing noise.\n",
250 | "- `1`: Representing NR signal.\n",
251 | "- `2`: Representing LTE signal.\n",
252 | "\n",
253 | "To prepare our spectrograms for training, we will convert them from PIL Images to Tensor objects. The images have to be loaded into the range `[0, 1]` and then normalized using a mean of `[0.485, 0.456, 0.406]` and a standard deviation of `[0.229, 0.224, 0.225]`, as required by our model. To prepare our masks for training, we will convert them to Tensor objects and remove the extraneous channel dimension."
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "metadata": {},
260 | "outputs": [],
261 | "source": [
262 | "project_root = os.getcwd()\n",
263 | "source = os.path.join(project_root, \"spectrum_sensing_dataset.hdf5\")\n",
264 | "\n",
265 | "mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]\n",
266 | "\n",
267 | "\n",
268 | "class Squeeze(torch.nn.Module):\n",
269 | " def forward(self, target: Tensor):\n",
270 | " return torch.squeeze(target)\n",
271 | "\n",
272 | "\n",
273 | "transform = Compose(\n",
274 | " [\n",
275 | " PILToTensor(),\n",
276 | " ToDtype(torch.float, scale=True),\n",
277 | " Normalize(mean=mean, std=std),\n",
278 | " ]\n",
279 | ")\n",
280 | "\n",
281 | "target_transform = Compose([PILToTensor(), Squeeze(), ToDtype(torch.long)])"
282 | ]
283 | },
284 | {
285 | "cell_type": "markdown",
286 | "metadata": {},
287 | "source": [
288 | "Now, let's initialize the dataset, and take a closer look at a random training example and its corresponding mask. Due to our transforms, we expect that the image and mask will be returned as Tensor objects."
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "metadata": {},
295 | "outputs": [],
296 | "source": [
297 | "dataset = SpectrumSensing(source=source, transform=transform, target_transform=target_transform, combined=False)"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": null,
303 | "metadata": {
304 | "scrolled": true
305 | },
306 | "outputs": [],
307 | "source": [
308 | "random_index = np.random.randint(len(dataset))\n",
309 | "training_example, corresponding_mask = dataset[random_index]\n",
310 | "\n",
311 | "print(f\"The full dataset has {len(dataset)} examples. Loading example at index {random_index}.\")\n",
312 | "print(f\"Spectrogram: {type(training_example)}, {training_example.dtype}, {training_example.size()}\")\n",
313 | "print(f\"Mask: {type(corresponding_mask)}, {corresponding_mask.dtype}, {corresponding_mask.size()}\")"
314 | ]
315 | },
316 | {
317 | "cell_type": "markdown",
318 | "metadata": {},
319 | "source": [
320 | "The dataset should contain 1,800 samples: 900 NR frames and 900 LTE frames. \n",
321 | "\n",
322 |     "To gain further insight, let's write some transforms to undo the previous normalization and prepare this image-mask pair for viewing. We'll also build a custom colormap for the masks, with noise as cyan, NR signal as blue, and LTE signal as purple."
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": null,
328 | "metadata": {},
329 | "outputs": [],
330 | "source": [
331 | "inv_transform = Compose(\n",
332 | " [\n",
333 | " Normalize(mean=[0.0, 0.0, 0.0], std=[1 / x for x in std]),\n",
334 | " Normalize(mean=[-x for x in mean], std=[1.0, 1.0, 1.0]),\n",
335 | " ToPILImage(),\n",
336 | " ]\n",
337 | ")\n",
338 | "\n",
339 | "inv_target_transform = Compose([ToDtype(dtype=torch.uint8), ToPILImage()])\n",
340 | "\n",
341 | "training_example = inv_transform(training_example)\n",
342 | "corresponding_mask = inv_target_transform(corresponding_mask)\n",
343 | "\n",
344 | "values, labels, colors = [0, 1, 2], [\"Noise\", \"NR\", \"LTE\"], [\"cyan\", \"blue\", \"purple\"]\n",
345 | "mask_cmap = ListedColormap(colors)\n",
346 | "\n",
347 | "print(f\"Spectrogram: {training_example}\")\n",
348 | "print(f\"Mask: {corresponding_mask}\")"
349 | ]
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": null,
354 | "metadata": {},
355 | "outputs": [],
356 | "source": [
357 | "fig, (ax1, ax2) = plt.subplots(figsize=[8, 3.5], nrows=1, ncols=2)\n",
358 | "plt.subplots_adjust(wspace=0.5)\n",
359 | "ax1.set_title(\"\\nRandom Spectrogram\", fontsize=title_font_size)\n",
360 | "ax2.set_title(\"Corresponding Mask\", fontsize=title_font_size)\n",
361 | "ax1.set_ylabel(\"Time [arb. units]\", fontsize=label_font_size)\n",
362 | "ax2.set_ylabel(\"Time [arb. units]\", fontsize=label_font_size)\n",
363 | "ax1.set_xlabel(\"Freq. [arb. units]\", fontsize=label_font_size)\n",
364 | "ax2.set_xlabel(\"Freq. [arb. units]\", fontsize=label_font_size)\n",
365 | "\n",
366 | "spect = ax1.imshow(training_example, vmin=0, vmax=255)\n",
367 | "fig.colorbar(spect, ax=ax1, fraction=0.045, ticks=[0, 255])\n",
368 | "\n",
369 | "mask = ax2.imshow(corresponding_mask, cmap=mask_cmap, vmin=0, vmax=2)\n",
370 | "mask_cbar = fig.colorbar(mask, ax=ax2, cmap=mask_cmap, fraction=0.045, ticks=[0.33, 1, 1.67])\n",
371 | "mask_cbar.ax.set_yticklabels(labels)"
372 | ]
373 | },
374 | {
375 | "cell_type": "markdown",
376 | "metadata": {},
377 | "source": [
378 | "**Note:** You can view different examples from the dataset by rerunning the previous few code cells."
379 | ]
380 | },
381 | {
382 | "cell_type": "markdown",
383 | "metadata": {},
384 | "source": [
385 | "Let's analyze the relative frequencies of the different class labels. This step is critical for identifying imbalance in our dataset. Please note that the following code block might take a few seconds to run."
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": null,
391 | "metadata": {},
392 | "outputs": [],
393 | "source": [
394 | "class_counts = {label: 0 for label in labels}\n",
395 | "\n",
396 | "for _, mask in dataset:\n",
397 | " arr = np.asarray(mask)\n",
398 | " for i, label in enumerate(labels):\n",
399 | " class_counts[label] += np.sum(arr == values[i])\n",
400 | "\n",
401 | "normalized_counts = np.array(list(class_counts.values())) / sum(list(class_counts.values()))\n",
402 | "\n",
403 | "plt.bar(class_counts.keys(), normalized_counts, tick_label=labels, color=colors)\n",
404 | "plt.title(\"Distribution of Pixel Counts by Class\", fontsize=title_font_size)\n",
405 | "plt.xlabel(\"Class\", fontsize=label_font_size)\n",
406 | "plt.ylabel(\"Counts (Normalized)\", fontsize=label_font_size)"
407 | ]
408 | },
409 | {
410 | "cell_type": "markdown",
411 | "metadata": {},
412 | "source": [
413 | "It looks like our dataset is mostly noise! A classification dataset like this—with skewed class proportions—is called imbalanced.\n",
414 | "\n",
415 |     "An imbalanced dataset can result in biased and poorly performing models. Models trained on imbalanced data tend to focus more on the majority classes and may not learn enough about the minority classes. To ensure the development of a fair, accurate, and robust model, we will need to address this class imbalance. \n",
416 | "\n",
417 | "But first, let's split the dataset into separate training and validation sets. The training dataset is the portion of the dataset that will be used to train the model, while the validation dataset will be held in reserve and used to evaluate the performance of the trained model. Let's start with a simple 80/20 split, where 80% of the dataset is used for training and 20% for validation."
418 | ]
419 | },
420 | {
421 | "cell_type": "code",
422 | "execution_count": null,
423 | "metadata": {},
424 | "outputs": [],
425 | "source": [
426 | "train_split = 0.80\n",
427 | "n_train_examples = int(len(dataset) * train_split)\n",
428 | "n_val_examples = len(dataset) - n_train_examples\n",
429 | "\n",
430 | "train_set, val_set = torch.utils.data.random_split(\n",
431 | " dataset=dataset, lengths=[n_train_examples, n_val_examples], generator=torch.Generator().manual_seed(42)\n",
432 | ")\n",
433 | "\n",
434 | "print(f\"The training split contains {len(train_set)} examples.\")\n",
435 | "print(f\"The validation split contains {len(val_set)} examples.\")"
436 | ]
437 | },
438 | {
439 | "cell_type": "markdown",
440 | "metadata": {},
441 | "source": [
442 | "In PyTorch, [DataLoaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#preparing-your-data-for-training-with-dataloaders) efficiently load and batch data and offer numerous other features to streamline data preprocessing, management, and integration within the training loop. Let's create data loaders for both the training and validation datasets.\n",
443 | "\n",
444 | "The `DataLoader` class allows us to pass a `batch_size` argument, which controls the number of samples used in each pass through the network. Using a small number of training examples each pass is called mini-batching, and can improve efficiency, stabilize training dynamics, and enable scalable training on large datasets. Choosing an appropriate mini-batch size depends on several factors, including the available memory on your hardware, training efficiency constraints, and generalization requirements. However, as with everything in machine learning, we ultimately rely on empirical testing to determine the optimal configuration that maximizes model performance for each specific task and dataset. In this example, we'll use mini-batches containing 4 samples each, which will easily fit on any CPU/GPU without issue and provide reasonable generalization performance."
445 | ]
446 | },
447 | {
448 | "cell_type": "code",
449 | "execution_count": null,
450 | "metadata": {},
451 | "outputs": [],
452 | "source": [
453 | "batch_size = 4\n",
454 | "\n",
455 | "train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
456 | "val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)\n",
457 | "\n",
458 | "spects, masks = next(iter(train_loader))\n",
459 | "\n",
460 | "print(f\"Batch of spectrograms: {type(spects)}, {spects.dtype}, {spects.size()}\")\n",
461 | "print(f\"Batch of masks: {type(masks)}, {masks.dtype}, {masks.size()}\")"
462 | ]
463 | },
464 | {
465 | "cell_type": "markdown",
466 | "metadata": {},
467 | "source": [
468 | "Let's examine a batch of spectrograms along with their corresponding masks. Please note that the following plotting code is optimized for small batch sizes and may not render as nicely with larger batch sizes."
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": null,
474 | "metadata": {},
475 | "outputs": [],
476 | "source": [
477 | "def plot_spects(spects: list[Image.Image]) -> None:\n",
478 | " fig, axes = plt.subplots(figsize=[batch_size * 2, 3], nrows=1, ncols=batch_size, sharey=True)\n",
479 | " fig.text(0.5, 0.75, \"Spectrograms\", fontsize=title_font_size, ha=\"center\")\n",
480 | " axes[0].set_ylabel(\"Time [arb. units]\", fontsize=label_font_size)\n",
481 | " fig.text(0.5, 0.12, \"Freq. [arb. units]\", fontsize=label_font_size, ha=\"center\")\n",
482 | "\n",
483 | " for i, ax in enumerate(axes):\n",
484 | " im = ax.imshow(spects[i], vmin=0, vmax=255)\n",
485 | "\n",
486 | " fig.subplots_adjust(right=0.90)\n",
487 | " cbar_ax = fig.add_axes(rect=[0.93, 0.24, 0.02, 0.5])\n",
488 | " fig.colorbar(im, cax=cbar_ax, ticks=[0, 255])\n",
489 | "\n",
490 | "\n",
491 | "def plot_masks(masks: list[Image.Image], prediction: bool = False) -> None:\n",
492 | " fig, axes = plt.subplots(figsize=[batch_size * 2, 3], nrows=1, ncols=batch_size, sharey=True)\n",
493 | " if prediction:\n",
494 | " fig.text(0.5, 0.75, \"Model Predictions\", fontsize=title_font_size, ha=\"center\")\n",
495 | " else:\n",
496 | " fig.text(0.5, 0.75, \"Masks\", fontsize=title_font_size, ha=\"center\")\n",
497 | " axes[0].set_ylabel(\"Time [arb. units]\", fontsize=label_font_size)\n",
498 | " fig.text(0.5, 0.12, \"Freq. [arb. units]\", fontsize=label_font_size, ha=\"center\")\n",
499 | "\n",
500 | " for i, ax in enumerate(axes):\n",
501 | " im = ax.imshow(masks[i], vmin=0, vmax=2, cmap=mask_cmap)\n",
502 | "\n",
503 | " fig.subplots_adjust(right=0.90)\n",
504 | " cbar_ax = fig.add_axes(rect=[0.93, 0.24, 0.02, 0.5])\n",
505 | " cbar = fig.colorbar(im, cax=cbar_ax, ticks=[0.33, 1, 1.66])\n",
506 | " cbar.ax.set_yticklabels(labels)"
507 | ]
508 | },
509 | {
510 | "cell_type": "code",
511 | "execution_count": null,
512 | "metadata": {},
513 | "outputs": [],
514 | "source": [
515 | "plot_spects(spects=[inv_transform(i) for i in spects])\n",
516 | "plot_masks(masks=[inv_target_transform(i) for i in masks])"
517 | ]
518 | },
519 | {
520 | "cell_type": "markdown",
521 | "metadata": {},
522 | "source": [
523 | "**Note:** You can view different batches from the dataset by rerunning the previous few code cells."
524 | ]
525 | },
526 | {
527 | "cell_type": "markdown",
528 | "metadata": {},
529 | "source": [
530 | "# Model Training\n",
531 | "\n",
532 | "In this example, we'll use a DeepLabV3 model with a MobileNetV3 (large) backbone. This model is designed to be lightweight and efficient, making it ideal for edge computing devices and quick proof-of-concept demonstrations."
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": null,
538 | "metadata": {},
539 | "outputs": [],
540 | "source": [
541 | "n_classes = 3 # We are dealing with three classes: Noise, NR, and LTE.\n",
542 | "model = deeplabv3_mobilenet_v3_large(num_classes=n_classes)"
543 | ]
544 | },
545 | {
546 | "cell_type": "markdown",
547 | "metadata": {},
548 | "source": [
549 | "Next, we need a loss function. A loss function, also known as a cost or objective function, measures how well a machine learning \n",
550 | "model's predictions match the actual target values. This quantifies the error between predicted outputs and ground truth labels, providing\n",
551 | "feedback that guides the model's training process. For classification problems, we commonly use the [Cross-Entropy Loss](https://machinelearningmastery.com/cross-entropy-for-machine-learning/), especially for \n",
552 | "multi-class classification problems. Let's use the [`CrossEntropyLoss`](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) class from PyTorch, which allows us to assign different weights to individual classes during the computation of the loss. \n",
553 | "\n",
554 |     "We'll use weights inversely proportional to the relative pixel count for each class. That way, we assign lower weights to overrepresented classes, like noise, and larger weights to underrepresented classes, like NR and LTE. This reduces the impact of noise and allows the model to prioritize learning from NR and LTE signal. Class weighting is not the only way to address data imbalance, but it is one of the more straightforward methods."
555 | ]
556 | },
557 | {
558 | "cell_type": "code",
559 | "execution_count": null,
560 | "metadata": {},
561 | "outputs": [],
562 | "source": [
563 | "median_count = statistics.median(list(class_counts.values()))\n",
564 | "weight = [median_count / class_counts[k] for k in class_counts.keys()]\n",
565 | "loss_function = nn.CrossEntropyLoss(weight=torch.tensor(weight, dtype=torch.float))\n",
566 | "\n",
567 | "print(\"Class weights: \", {k: round(weight[i], 2) for i, k in enumerate(class_counts.keys())})"
568 | ]
569 | },
570 | {
571 | "cell_type": "markdown",
572 | "metadata": {},
573 | "source": [
574 |     "In this example, we will train our model using stochastic gradient descent (SGD). SGD is a variant of the standard [gradient descent](https://builtin.com/data-science/gradient-descent) optimizer where the loss function is computed on mini-batches of data rather than the entire dataset. This helps improve computational efficiency and scalability, particularly for large datasets, by updating model parameters based on the gradients computed on our mini-batches.\n",
575 | "\n",
576 | "We'll define the training and validation process of our segmentation model in a [`LightningModule`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#lightningmodule). "
577 | ]
578 | },
579 | {
580 | "cell_type": "code",
581 | "execution_count": null,
582 | "metadata": {},
583 | "outputs": [],
584 | "source": [
585 | "class SegmentationSGD(L.LightningModule):\n",
586 | "\n",
587 | " def __init__(\n",
588 | " self,\n",
589 | " model: nn.Module,\n",
590 | " loss_function: nn.Module,\n",
591 | " n_classes: int,\n",
592 | " learning_rate: float,\n",
593 | " momentum: float,\n",
594 | " weight_decay: float,\n",
595 | " step_size: int,\n",
596 | " gamma: float,\n",
597 | " ):\n",
598 | " \"\"\"Initializes the SegmentationSGD module.\"\"\"\n",
599 | " super().__init__()\n",
600 | " self.model = model\n",
601 | " self.loss_function = loss_function\n",
602 | " self.n_classes = n_classes\n",
603 | "\n",
604 | " self.learning_rate = learning_rate\n",
605 | " self.momentum = momentum\n",
606 | " self.weight_decay = weight_decay\n",
607 | "\n",
608 | " self.step_size = step_size\n",
609 | " self.gamma = gamma\n",
610 | "\n",
611 | " self.train_accuracy = MulticlassAccuracy(num_classes=self.n_classes)\n",
612 | " self.val_accuracy = MulticlassAccuracy(num_classes=self.n_classes)\n",
613 | "\n",
614 | " def forward(self, x: Tensor) -> Tensor:\n",
615 | " \"\"\"Defines a forward pass through the model.\"\"\"\n",
616 | " return self.model(x)\n",
617 | "\n",
618 | " def training_step(self, batch: Tensor, batch_idx: int) -> Tensor:\n",
619 | " \"\"\"Defines a single training step.\"\"\"\n",
620 | " image, target = batch\n",
621 | " preds = self(image)[\"out\"]\n",
622 | " loss = self.loss_function(preds, target)\n",
623 | " self.train_accuracy(preds, target)\n",
624 | " self.log(name=\"train_accuracy\", value=self.train_accuracy, prog_bar=True)\n",
625 | " self.log(name=\"train_loss\", value=loss, on_epoch=True, on_step=False, prog_bar=True)\n",
626 | " return loss\n",
627 | "\n",
628 | " def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor:\n",
629 | " \"\"\"Defines a single validation step.\"\"\"\n",
630 | " image, target = batch\n",
631 | " preds = self(image)[\"out\"]\n",
632 | " loss = self.loss_function(preds, target)\n",
633 | " self.val_accuracy(preds, target)\n",
634 | " self.log(name=\"val_accuracy\", value=self.val_accuracy, prog_bar=True)\n",
635 | " self.log(name=\"val_loss\", value=loss, on_epoch=True, on_step=False, prog_bar=True)\n",
636 | " return loss\n",
637 | "\n",
638 | " def configure_optimizers(self) -> dict[str, Any]:\n",
639 | " \"\"\"Configure the optimizer and learning rate scheduler.\"\"\"\n",
640 | " optimizer = torch.optim.SGD(\n",
641 | " self.parameters(), lr=self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay\n",
642 | " )\n",
643 | " lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)\n",
644 | "\n",
645 | " return {\"optimizer\": optimizer, \"lr_scheduler\": lr_scheduler}"
646 | ]
647 | },
648 | {
649 | "cell_type": "markdown",
650 | "metadata": {},
651 | "source": [
652 |     "Our `SegmentationSGD` module is initialized with several configuration settings that influence the behavior and performance of the machine learning algorithm or model. These parameters are called hyperparameters, and unlike model parameters, which are learned from the data during training, hyperparameters are set prior to training and influence the learning process.\n",
653 | "\n",
654 | "The following hyperparameters are used to configure the optimizer:\n",
655 | "- **Momentum:** A parameter that accelerates SGD in the relevant direction and dampens oscillations.\n",
656 | "- **Learning Rate:** The rate at which the model parameters are updated during optimization.\n",
657 | "- **Weight Decay:** A regularization term added to the loss function to penalize large weights in the model to prevent overfitting.\n",
658 | "\n",
659 | "By gradually reducing the learning rate over epochs, the scheduler can help improve the convergence and stability of the optimization process.\n",
660 | "We need to provide the following two parameters, which will be used by the learning rate scheduler to dynamically adjust the learning rate during training:\n",
661 | "- **Step Size:** The number of epochs after which the learning rate is reduced.\n",
662 | "- **Gamma:** The factor by which the learning rate is reduced after every step-size epochs.\n",
663 | "\n",
664 | "Adjusting these hyperparameters can significantly impact the training process and the final performance of the model, for better or for worse!"
665 | ]
666 | },
667 | {
668 | "cell_type": "code",
669 | "execution_count": null,
670 | "metadata": {},
671 | "outputs": [],
672 | "source": [
673 | "segmentation_module = SegmentationSGD(\n",
674 | " model=model,\n",
675 | " loss_function=loss_function,\n",
676 | " n_classes=n_classes,\n",
677 | " learning_rate=0.02, # Represents the initial learning rate.\n",
678 | " momentum=0.9,\n",
679 | " weight_decay=1.0e-04,\n",
680 | " step_size=10,\n",
681 | " gamma=0.1,\n",
682 | ")"
683 | ]
684 | },
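  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the schedule configured above: `StepLR` multiplies the learning rate by `gamma` after every `step_size` epochs, so the initial rate of 0.02 drops to 0.002 at epoch 10. The optional cell below prints a few points of this schedule."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: preview how StepLR(step_size=10, gamma=0.1) decays the initial learning rate of 0.02.\n",
    "for epoch in [0, 5, 9, 10, 15, 20]:\n",
    "    print(f\"Epoch {epoch}: learning rate = {0.02 * 0.1 ** (epoch // 10):.5f}\")"
   ]
  },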
685 | {
686 | "cell_type": "markdown",
687 | "metadata": {},
688 | "source": [
689 | "Now that we have our model, weighted loss function, and Lightning Module, we are prepared to train our model. If available, we will leverage GPU acceleration. Otherwise, the training process will default to using the CPU. Please be patient; model training time may vary from a few minutes to over an hour depending on the current hardware configuration.\n",
690 | "\n",
691 | "The number of epochs determines how many times the entire dataset will be used to train the model. For this specific model and dataset, 10 epochs should be more than sufficient. However, if you are training on the CPU, you might want to consider reducing the number of training epochs to 4 to save on training time."
692 | ]
693 | },
694 | {
695 | "cell_type": "code",
696 | "execution_count": null,
697 | "metadata": {},
698 | "outputs": [],
699 | "source": [
700 | "n_epochs = 10\n",
701 | "# n_epochs = 4 # Suggested for CPU training.\n",
702 | "\n",
703 | "if torch.cuda.is_available():\n",
704 | " print(\"Training model on GPU.\")\n",
705 | " trainer = L.Trainer(accelerator=\"gpu\", max_epochs=n_epochs, logger=True)\n",
706 | " device = \"cuda\"\n",
707 | "else:\n",
708 | " print(\"Training model on CPU.\")\n",
709 | " trainer = L.Trainer(max_epochs=n_epochs, logger=True)\n",
710 | " device = \"cpu\"\n",
711 | "\n",
712 | "trainer.fit(model=segmentation_module, train_dataloaders=train_loader, val_dataloaders=val_loader)"
713 | ]
714 | },
715 | {
716 | "cell_type": "markdown",
717 | "metadata": {},
718 | "source": [
719 | "If you are running this example locally, you can refer to the `metrics.csv` file located in the `lightning_logs` directory for more information regarding training and validation loss and accuracy across training epochs. Please remember to close the `metrics.csv` file before proceeding."
720 | ]
721 | },
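  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, assuming the CSV logs landed under `lightning_logs/version_*/metrics.csv` as described above, the following cell loads the most recent run with pandas and plots the epoch-level loss curves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "# Optional sketch: locate the metrics.csv from the most recent Lightning run and plot the\n",
    "# epoch-level training and validation loss. Assumes the default CSV logger layout described above.\n",
    "log_files = glob.glob(os.path.join(project_root, \"lightning_logs\", \"version_*\", \"metrics.csv\"))\n",
    "if log_files:\n",
    "    metrics_df = pd.read_csv(max(log_files, key=os.path.getmtime))\n",
    "    metrics_df.groupby(\"epoch\")[[\"train_loss\", \"val_loss\"]].mean().plot(marker=\"o\")\n",
    "    plt.title(\"Loss Across Training Epochs\", fontsize=title_font_size)\n",
    "    plt.xlabel(\"Epoch\", fontsize=label_font_size)\n",
    "    plt.ylabel(\"Loss\", fontsize=label_font_size)"
   ]
  },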
722 | {
723 | "cell_type": "markdown",
724 | "metadata": {},
725 | "source": [
726 | "# Model Validation\n",
727 | "\n",
728 | "Having trained our model, the next step is to evaluate its capabilities. To accomplish this, we'll use a suite of standard machine learning metrics. But first, let's take a look at a random batch of predictions.\n",
729 | "\n",
730 |     "Because the model returns the probabilities corresponding to the predictions of each class, we need to use `argmax()` to obtain the class with the highest prediction probability. The result is a single-channel image for each example in the batch, which can be compared directly to the corresponding target mask."
731 | ]
732 | },
733 | {
734 | "cell_type": "code",
735 | "execution_count": null,
736 | "metadata": {},
737 | "outputs": [],
738 | "source": [
739 | "model.eval()\n",
740 | "model.to(device)\n",
741 | "\n",
742 | "spects, masks = next(iter(train_loader))\n",
743 | "spects = spects.to(device)\n",
744 | "\n",
745 | "with torch.inference_mode():\n",
746 | " preds = (model(spects)[\"out\"]).argmax(1)\n",
747 | "\n",
748 | "print(\"Predictions:\", preds.size())"
749 | ]
750 | },
751 | {
752 | "cell_type": "code",
753 | "execution_count": null,
754 | "metadata": {},
755 | "outputs": [],
756 | "source": [
757 | "plot_spects(spects=[inv_transform(i.cpu()) for i in spects])\n",
758 | "plot_masks(masks=[inv_target_transform(i) for i in masks])\n",
759 | "plot_masks(masks=preds.cpu(), prediction=True)"
760 | ]
761 | },
762 | {
763 | "cell_type": "markdown",
764 | "metadata": {},
765 | "source": [
766 | "Looks pretty good! To get a more objective sense, let's turn to the metrics. We'll start with model accuracy, calculated as the ratio of correctly predicted pixels to the total number of pixels.\n",
767 | "\n",
768 | "**Note:** You can view different predictions by rerunning the previous few code cells."
769 | ]
770 | },
771 | {
772 | "cell_type": "code",
773 | "execution_count": null,
774 | "metadata": {},
775 | "outputs": [],
776 | "source": [
777 | "scores = trainer.validate(model=segmentation_module, dataloaders=val_loader)"
778 | ]
779 | },
780 | {
781 | "cell_type": "markdown",
782 | "metadata": {},
783 | "source": [
784 | "The accuracy can give us a quick sense of the model's overall performance. However, accuracy alone doesn't tell the whole story. In fact, because of the imbalance in our dataset, a reasonably high accuracy could be achieved by always predicting noise. The simple accuracy provided above is an unweighted mean of the accuracies over each class.\n",
785 | "\n",
786 | "To gain a better understanding of our model's ability to predict specific classes, let's take a look at the confusion matrix, which provides a more comprehensive overview of model capability. The diagonal elements represent the correct predictions and off-diagonal elements indicate prediction errors."
787 | ]
788 | },
789 | {
790 | "cell_type": "code",
791 | "execution_count": null,
792 | "metadata": {},
793 | "outputs": [],
794 | "source": [
795 | "model.to(device)\n",
796 | "confusion_matrix = MulticlassConfusionMatrix(num_classes=n_classes, normalize=\"true\").to(device)\n",
797 | "\n",
798 | "with torch.inference_mode():\n",
799 | " for spect, mask in val_loader:\n",
800 | " spect, mask = spect.to(device), mask.to(device)\n",
801 | " pred = (model(spect)[\"out\"]).argmax(dim=1)\n",
802 | " confusion_matrix.update(pred, mask)\n",
803 | "\n",
804 | "confusion_matrix = np.round(confusion_matrix.compute().cpu().numpy(), decimals=2)\n",
805 | "ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=labels).plot()"
806 | ]
807 | },
808 | {
809 | "cell_type": "markdown",
810 | "metadata": {},
811 | "source": [
812 | "Finally, let's generate a more comprehensive report, complete with the following metrics:\n",
813 | "\n",
814 | "- **Recall:** The recall (sensitivity) measures the ability of the model to identify the relevant pixels. A higher recall indicates that the model is better at identifying signal.\n",
815 | "\n",
816 | "- **Precision:** The precision assesses the accuracy of positive predictions. A higher precision indicates that when the model predicts signal, it is more likely to be correct.\n",
817 | "\n",
818 | "- **F1 Score:** The F1 score combines both recall and precision into a single value, providing a more balanced measure of the model's performance. A higher F1 indicates a model with both good precision and recall (fewer false positives and false negatives overall).\n",
819 | "\n",
820 | "- **Intersection over Union (IoU):** The IoU quantifies the overlap between the predicted bounding box or segmented region and the ground truth. A higher IoU value indicates a better alignment between the predicted and actual regions, reflecting a more accurate model."
821 | ]
822 | },
823 | {
824 | "cell_type": "code",
825 | "execution_count": null,
826 | "metadata": {},
827 | "outputs": [],
828 | "source": [
829 | "def print_metrics_report(dataloader: DataLoader) -> None:\n",
830 | " \"\"\"Compute accuracy, recall, precision, F1 score, and IoU (Intersection over Union),\n",
831 | " and print a report containing these metrics.\"\"\"\n",
832 | " metrics = [\n",
833 | " MulticlassAccuracy(num_classes=n_classes, average=None),\n",
834 | " MulticlassRecall(num_classes=n_classes, average=None),\n",
835 | " MulticlassPrecision(num_classes=n_classes, average=None),\n",
836 | " MulticlassF1Score(num_classes=n_classes, average=None),\n",
837 | " # The IoU is commonly referred to as Jaccard's Index\n",
838 | " MulticlassJaccardIndex(num_classes=n_classes, average=None),\n",
839 | " ]\n",
840 | " metrics = [m.to(device) for m in metrics]\n",
841 | "\n",
842 | " with torch.inference_mode():\n",
843 | " for spect, mask in dataloader:\n",
844 | " spect, mask = spect.to(device), mask.to(device)\n",
845 | " pred = (model(spect)[\"out\"]).argmax(dim=1)\n",
846 | " for m in metrics:\n",
847 | " m.update(pred, mask)\n",
848 | "\n",
849 | " metrics = [m.compute().cpu().numpy() for m in metrics]\n",
850 | " metrics = [np.append(m, statistics.mean(m)) for m in metrics]\n",
851 | "\n",
852 | " df = pd.DataFrame(\n",
853 | " {\n",
854 | " \"Class\": np.append(np.asarray(labels), \"Mean\"),\n",
855 | " \"Accuracy\": metrics[0],\n",
856 | " \"Recall\": metrics[1],\n",
857 | " \"Precision\": metrics[2],\n",
858 | " \"F1 Score\": metrics[3],\n",
859 | " \"IoU\": metrics[4],\n",
860 | " }\n",
861 | " )\n",
862 | " print(\n",
863 | " tabulate(\n",
864 | " df, headers=\"keys\", tablefmt=\"grid\", showindex=False, numalign=\"center\", stralign=\"center\", floatfmt=\".2f\"\n",
865 | " )\n",
866 | " )"
867 | ]
868 | },
869 | {
870 | "cell_type": "code",
871 | "execution_count": null,
872 | "metadata": {},
873 | "outputs": [],
874 | "source": [
875 | "print_metrics_report(dataloader=val_loader)"
876 | ]
877 | },
878 | {
879 | "cell_type": "markdown",
880 | "metadata": {},
881 | "source": [
882 | "# Challenge Data\n",
883 | "\n",
884 | "In machine learning, out-of-distribution data refers to examples that deviate from those used during training. For example, recall the Spectrogram Sensing dataset comprises 900 combined frames featuring both LTE and NR signals. As we excluded the combined frames from the training process, they represent out-of-distribution data. To get a quick sense of how our model performs on these combined frames, let's take a look at a random batch of predictions."
885 | ]
886 | },
887 | {
888 | "cell_type": "code",
889 | "execution_count": null,
890 | "metadata": {},
891 | "outputs": [],
892 | "source": [
893 | "challenge_dataset = SpectrumSensing(\n",
894 | " source=source, transform=transform, target_transform=target_transform, combined=True\n",
895 | ")\n",
896 | "\n",
897 | "challenge_loader = DataLoader(challenge_dataset, batch_size=batch_size, shuffle=True)"
898 | ]
899 | },
900 | {
901 | "cell_type": "code",
902 | "execution_count": null,
903 | "metadata": {},
904 | "outputs": [],
905 | "source": [
906 | "spects, masks = next(iter(challenge_loader))\n",
907 | "spects = spects.to(device)\n",
908 | "\n",
909 | "with torch.inference_mode():\n",
910 | " preds = (model(spects)[\"out\"]).argmax(1)\n",
911 | "\n",
912 | "print(\"Predictions:\", preds.size())"
913 | ]
914 | },
915 | {
916 | "cell_type": "code",
917 | "execution_count": null,
918 | "metadata": {},
919 | "outputs": [],
920 | "source": [
921 | "plot_spects(spects=[inv_transform(i.cpu()) for i in spects])\n",
922 | "plot_masks(masks=[inv_target_transform(i) for i in masks])\n",
923 | "plot_masks(masks=preds.cpu(), prediction=True)"
924 | ]
925 | },
926 | {
927 | "cell_type": "markdown",
928 | "metadata": {},
929 | "source": [
930 | "**Note:** You can view different examples by rerunning the previous few code cells.\n",
931 | "\n",
932 | "Now, let's evaluate the same metrics as we did above in the [Model Validation](#Model-Validation) section, but now for the challenge dataset. Given the model's lack of exposure to these combined frames during the training process, we anticipate the model's capabilities to be somewhat diminished. Yet, we still anticipate reasonable results."
933 | ]
934 | },
935 | {
936 | "cell_type": "code",
937 | "execution_count": null,
938 | "metadata": {},
939 | "outputs": [],
940 | "source": [
941 | "print_metrics_report(dataloader=challenge_loader)"
942 | ]
943 | },
944 | {
945 | "cell_type": "markdown",
946 | "metadata": {},
947 | "source": [
948 | "# Conclusions & Next Steps\n",
949 | "\n",
950 | "In this example, we used PyTorch and Lightning to train a DeepLabV3 model to identify and differentiate between 5G NR and 4G LTE signals within wideband spectrograms, showcasing one of the ways we leverage machine learning to identify things in the wireless spectrum. This involved data analysis and preprocessing, choosing a loss function and optimizer, model training, model performance validation, and finally testing the model on out-of-distribution frames containing both NR and LTE signal.\n",
951 | "\n",
952 | "The capability to differentiate and recognize various signals finds direct applications in spectrum sensing, which is fundamental to autonomous spectrum management, and brings us one step closer to more holistic intelligent radio solutions! 📡🚀"
953 | ]
954 | },
955 | {
956 | "cell_type": "markdown",
957 | "metadata": {},
958 | "source": [
959 | "We hope this example was informative. Here are some next steps you can take to further explore and expand upon what you've learned:\n",
960 | "\n",
961 | "- **Experiment with the Hyperparameters:** Adjust the values of hyperparameters such as the number of training epochs, batch size, and learning rate, and observe how these configurations influence model training and performance. After gaining insights through manual hyperparameter tuning, explore automated approaches using tools like [Ray Tune](https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html) or [Optuna](https://optuna.org/).\n",
962 | "\n",
963 | "- **Experiment with DeepLabV3's ResNet Models:** DeepLabV3 also provides models with ResNet-50 and ResNet-101 backbones. These ResNet models are deeper and more complex, and generally offer better model performance than MobileNetV3, which is designed to be lightweight and efficient. Because all DeepLabV3 models implement the same interface, no code changes are required. However, some hyperparameter tuning and/or a larger dataset may be required to train these models effectively. These models have already been imported into this notebook for your convenience.\n",
964 | "\n",
965 | "- **Explore Alternative Solutions to Class Imbalance:** In this example, we addressed class imbalance in our dataset using a weighted cross-entropy loss function. Research and implement alternative strategies or loss functions designed to address imbalance in image datasets.\n",
966 | "\n",
967 | "- **Integrate Combined Frames:** In this example, we trained exclusively on the individual NR and LTE frames. Try integrating the combined frames that contain both the NR and LTE signals into the training process, and evaluate the effect on model performance.\n",
968 | "\n",
969 | "- **Test your Model on Captured Radio Data:** If you have radio hardware available, consider testing your model on real recordings of live radio data. Check out this article from MathWorks for more information on how to capture NR and LTE signals: [Capture and Label NR and LTE Signals for AI Training](https://www.mathworks.com/help/wireless-testbench/ug/capture-and-label-nr-and-lte-signals-for-ai-training.html).\n",
970 | "\n",
971 | "- **Explore RIA Core on GitHub (Coming soon!):** At Qoherent, we're building [Radio Intelligence Applications](https://qoherent.ai/radiointelligenceapps-project/) (RIA) to drive the creation of intelligent radios. Check out [RIA Core](https://github.com/qoherent/ria)—the free and open-source foundation of RIA—and consider contributing to the project. ⭐"
972 | ]
973 | }
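  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Appendix: Sketches for the Next Steps\n",
    "\n",
    "The cells below are minimal, optional sketches for some of the next steps listed above. They are starting points rather than polished implementations, and they assume that the objects defined earlier in this notebook (`n_classes`, `source`, `transform`, `target_transform`, `batch_size`, `device`, and the `SpectrumSensing` dataset class) are still available."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Hyperparameter tuning with Optuna (sketch):** The cell below shows the general Optuna pattern. Optuna is not listed in `requirements.txt`, so install it first, and note that `train_and_validate` is a hypothetical placeholder: in practice it would wrap the Lightning training and validation steps from this notebook and return a validation metric such as the mean IoU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only: `train_and_validate` is a hypothetical helper that would wrap the\n",
    "# Lightning training/validation loop from this notebook and return a validation\n",
    "# metric (e.g., mean IoU) for the given hyperparameters.\n",
    "import optuna\n",
    "\n",
    "\n",
    "def objective(trial: optuna.Trial) -> float:\n",
    "    learning_rate = trial.suggest_float(\"learning_rate\", 1e-5, 1e-2, log=True)\n",
    "    trial_batch_size = trial.suggest_categorical(\"batch_size\", [8, 16, 32])\n",
    "    return train_and_validate(learning_rate=learning_rate, batch_size=trial_batch_size)\n",
    "\n",
    "\n",
    "study = optuna.create_study(direction=\"maximize\")  # Maximize the validation metric.\n",
    "study.optimize(objective, n_trials=20)\n",
    "print(study.best_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Swapping in a ResNet-50 backbone (sketch):** Because the torchvision DeepLabV3 variants share the same interface, changing the backbone is a one-line change. The constructor call below assumes the model was originally built with `num_classes=n_classes`; match whatever arguments were used for the MobileNetV3 variant earlier, then re-run the training cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only: this assumes the original model was constructed with `num_classes=n_classes`;\n",
    "# match whatever arguments were used for the MobileNetV3 variant earlier in this notebook.\n",
    "# Note that this replaces the trained model, so the training cells must be re-run.\n",
    "from torchvision.models.segmentation import deeplabv3_resnet50\n",
    "\n",
    "model = deeplabv3_resnet50(num_classes=n_classes).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**An alternative to weighted cross-entropy, Dice loss (sketch):** Dice loss optimizes the overlap between the predicted and true masks directly, and is often less sensitive to class imbalance than unweighted cross-entropy. The implementation below is a generic multiclass version, not the loss used in this notebook; to try it, swap it in for the weighted cross-entropy criterion in the Lightning module (or combine the two losses)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class DiceLoss(torch.nn.Module):\n",
    "    \"\"\"Multiclass soft Dice loss: 1 minus the mean per-class Dice coefficient.\"\"\"\n",
    "\n",
    "    def __init__(self, smooth: float = 1.0):\n",
    "        super().__init__()\n",
    "        self.smooth = smooth\n",
    "\n",
    "    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:\n",
    "        # logits: (N, C, H, W) raw model outputs; target: (N, H, W) integer class labels.\n",
    "        probs = logits.softmax(dim=1)\n",
    "        target_one_hot = F.one_hot(target, num_classes=logits.shape[1]).permute(0, 3, 1, 2).float()\n",
    "        dims = (0, 2, 3)  # Sum over the batch and spatial dimensions, keeping the class dimension.\n",
    "        intersection = (probs * target_one_hot).sum(dims)\n",
    "        cardinality = probs.sum(dims) + target_one_hot.sum(dims)\n",
    "        dice = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)\n",
    "        return 1.0 - dice.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Training on the combined frames as well (sketch):** One simple approach is to concatenate the combined frames with the individual-signal frames using `torch.utils.data.ConcatDataset` and rebuild the dataloaders. The 80/20 `random_split` below is for illustration and may not match how the train/validation split was constructed earlier in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only: `combined=False` is assumed to select the individual NR and LTE frames, and the\n",
    "# 80/20 random split below may differ from the split used earlier in this notebook.\n",
    "from torch.utils.data import ConcatDataset, DataLoader, random_split\n",
    "\n",
    "individual_dataset = SpectrumSensing(\n",
    "    source=source, transform=transform, target_transform=target_transform, combined=False\n",
    ")\n",
    "combined_dataset = SpectrumSensing(\n",
    "    source=source, transform=transform, target_transform=target_transform, combined=True\n",
    ")\n",
    "full_dataset = ConcatDataset([individual_dataset, combined_dataset])\n",
    "\n",
    "train_dataset, val_dataset = random_split(full_dataset, lengths=[0.8, 0.2])\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# Re-run the training and validation cells above with these new dataloaders."
   ]
  }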
974 | ],
975 | "metadata": {
976 | "kernelspec": {
977 | "display_name": "Python 3 (ipykernel)",
978 | "language": "python",
979 | "name": "python3"
980 | },
981 | "language_info": {
982 | "codemirror_mode": {
983 | "name": "ipython",
984 | "version": 3
985 | },
986 | "file_extension": ".py",
987 | "mimetype": "text/x-python",
988 | "name": "python",
989 | "nbconvert_exporter": "python",
990 | "pygments_lexer": "ipython3",
991 | "version": "3.10.13"
992 | }
993 | },
994 | "nbformat": 4,
995 | "nbformat_minor": 4
996 | }
997 |
--------------------------------------------------------------------------------