├── .gitignore
├── README.md
├── example
│   ├── exploratory_analysis.ipynb
│   ├── report.pdf
│   ├── training_script.ipynb
│   └── utils
│       ├── create_frames.py
│       ├── download_dataset_from_s3.py
│       └── unpack_dataset.sh
├── requirements.txt
└── segmentation
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   └── mobiactv2.py
    ├── engine
    │   ├── __init__.py
    │   └── engine.py
    ├── logger
    │   ├── __init__.py
    │   └── logger.py
    ├── metrics
    │   ├── __init__.py
    │   └── metrics.py
    ├── models
    │   ├── __init__.py
    │   └── models.py
    ├── transforms
    │   ├── __init__.py
    │   ├── functional.py
    │   └── transforms.py
    └── utils
        ├── __init__.py
        └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.ipynb_checkpoints
3 | *__pycache__
4 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ts-segment
2 | 
3 | ts-segment is a Python library for creating semantic segmentation models for multivariate time series, primarily (but not exclusively) for motion sensor data.
4 | 
5 | The library is built on the PyTorch deep learning framework and on the Ignite package, which helps write compact, full-featured training loops in just a few lines of code, with support for metrics, early stopping, model checkpointing, learning rate scheduling and more.
6 | 
7 | ## Getting started
8 | ts-segment is compatible with Ignite and follows the same core concepts. To get started, define an iterator class for your dataset under `segmentation/datasets/`, load a model, create a training engine with any standard PyTorch optimizer, loss function and device, and run the engine on a DataLoader over your data.
9 | ```
10 | from torch.utils.data import DataLoader
11 | from segmentation.datasets import YourDataset
12 | from segmentation.models import Model
13 | from segmentation.engine import create_trainer
14 | loader = DataLoader(YourDataset(), batch_size=32)
15 | model = Model()
16 | trainer = create_trainer(model, optimizer, loss_fn, device)
17 | trainer.run(loader)
18 | ```
19 | 
20 | ## Time series segmentation metrics
21 | ts-segment allows easy logging of the most common semantic segmentation metrics, including samplewise accuracy, mean accuracy, mean IoU and frequency weighted IoU. To log metrics during training, create an evaluator engine with your metrics, create a logger object with the train and validation data loaders, and attach it to your training engine.
22 | 
23 | ```
24 | from ignite.engine import Events
25 | from ignite.metrics import Loss
26 | from segmentation.engine import create_trainer, create_evaluator
27 | from segmentation.metrics import (
28 |     SamplewiseAccuracy,
29 |     MeanAccuracy,
30 |     MeanIoU,
31 |     FrequencyWeightedIoU,
32 | )
33 | from segmentation.logger import Logger
34 | 
35 | trainer = create_trainer(model, optimizer, loss_fn, device)
36 | evaluator = create_evaluator(
37 |     model,
38 |     device,
39 |     metrics={
40 |         "loss": Loss(loss_fn),
41 |         "samplewise_accuracy": SamplewiseAccuracy(),
42 |         "mean_accuracy": MeanAccuracy(),
43 |         "mean_iou": MeanIoU(),
44 |         "frequency_weighted_iou": FrequencyWeightedIoU(),
45 |     },
46 | )
47 | logger = Logger(evaluator, train_loader, validation_loader)
48 | trainer.add_event_handler(Events.EPOCH_COMPLETED, logger)
49 | trainer.run(train_loader, max_epochs=50)
50 | ```
51 | After training finishes, the logger object contains a dictionary of epoch-wise metrics.
52 | 
53 | ## Example
54 | An example application is included in this repository under `example/`, where ts-segment is used to segment time series of motion sensor data for activity recognition.
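At its core, the example notebook wires the pieces together exactly as in the snippets above; an abridged sketch (the variables `sensor_channels`, `train_users`, `optimizer`, `loss_fn` and `device` are all defined in `example/training_script.ipynb`):
```
from segmentation.datasets import MobiActV2
from segmentation.models import SensorFCN
from segmentation.engine import create_trainer

train_set = MobiActV2("data/MobiActV2/frames", sensor_channels, train_users)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32)
model = SensorFCN(n_input_channels=len(sensor_channels), n_classes=11, input_kernel_size=5)
trainer = create_trainer(model, optimizer, loss_fn, device)
trainer.run(train_loader, max_epochs=50)
```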
Before running the training notebook, the data should first be downloaded and then prepared for modeling. 55 | ``` 56 | cd example/ 57 | bash utils/unpack_dataset.sh 58 | ``` 59 | This will run two auxiliary python scripts to download the [MobiActV2 dataset](https://bmi.teicrete.gr/en/the-mobifall-and-mobiact-datasets-2/) from S3 and to transform the raw data and place it under `data/`. 60 | -------------------------------------------------------------------------------- /example/report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtdo/ts-segment/5b9e233ab69947e544888e7fed56a5edde549900/example/report.pdf -------------------------------------------------------------------------------- /example/training_script.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import pickle\n", 11 | "import numpy as np\n", 12 | "\n", 13 | "import torch\n", 14 | "import torch.nn as nn\n", 15 | "import torch.optim as optim\n", 16 | "from torch.optim.lr_scheduler import StepLR\n", 17 | "\n", 18 | "import matplotlib.pyplot as plt\n", 19 | "\n", 20 | "plt.style.use(\"ggplot\")" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "from ignite.engine import Events\n", 30 | "from ignite.metrics import Loss\n", 31 | "from ignite.contrib.handlers.param_scheduler import LRScheduler\n", 32 | "\n", 33 | "from segmentation.datasets import MobiActV2\n", 34 | "from segmentation.models import SensorFCN\n", 35 | "from segmentation.engine import create_trainer, create_evaluator\n", 36 | "from segmentation.metrics import (\n", 37 | " SamplewiseAccuracy,\n", 38 | " MeanAccuracy,\n", 39 | " MeanIoU,\n", 40 | " FrequencyWeightedIoU,\n", 41 | ")\n", 42 | "from segmentation.logger import Logger\n", 43 | "from segmentation import utils" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "tags": [ 51 | "parameters" 52 | ] 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "# Papermill parameter cell\n", 57 | "# Model params\n", 58 | "sensors = \"ago\"\n", 59 | "input_kernel_size = 5\n", 60 | "n_filters = 16\n", 61 | "smoothing_kernel_size = 0\n", 62 | "\n", 63 | "# Run params\n", 64 | "random_seed = 1234\n", 65 | "experiment = \"input_kernel_size\"\n", 66 | "gpu = \"cuda:3\"" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "sensor_channels = []\n", 76 | "if \"a\" in sensors:\n", 77 | " sensor_channels.extend([\"acc_x\", \"acc_y\", \"acc_z\"])\n", 78 | "if \"g\" in sensors:\n", 79 | " sensor_channels.extend([\"gyro_x\", \"gyro_y\", \"gyro_z\"])\n", 80 | "if \"o\" in sensors:\n", 81 | " sensor_channels.extend([\"azimuth\", \"pitch\", \"roll\"])" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "# Train/validation/test split\n", 91 | "np.random.seed(random_seed)\n", 92 | "users = np.arange(1, 68)\n", 93 | "np.random.shuffle(users)\n", 94 | "\n", 95 | "test_users = users[0:7]\n", 96 | "validation_users = users[7:14]\n", 97 | "train_users = users[14:]\n", 98 | "\n", 99 | "# Load datasets\n", 100 | "train_set = MobiActV2(\"data/MobiActV2/frames\", 
sensor_channels, train_users)\n", 101 | "validation_set = MobiActV2(\"data/MobiActV2/frames\", sensor_channels, validation_users)\n", 102 | "test_set = MobiActV2(\"data/MobiActV2/frames\", sensor_channels, test_users)\n", 103 | "n_classes = len(train_set.label_codes)\n", 104 | "\n", 105 | "# Define data loaders\n", 106 | "batch_size = 32\n", 107 | "train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)\n", 108 | "validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size)\n", 109 | "test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "# Define device\n", 119 | "device = torch.device(gpu if torch.cuda.is_available() else \"cpu\")\n", 120 | "print(device)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "# Initialize network\n", 130 | "model = SensorFCN(\n", 131 | " n_input_channels=len(sensor_channels),\n", 132 | " n_classes=int(n_classes),\n", 133 | " input_kernel_size=int(input_kernel_size),\n", 134 | " n_filters=int(n_filters),\n", 135 | " smoothing_kernel_size=int(smoothing_kernel_size),\n", 136 | ")" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "# Load class weights\n", 146 | "class_weights = np.load(\"data/MobiActV2/class_weights.npy\")\n", 147 | "class_weights = torch.tensor(class_weights).to(device, dtype=torch.float)\n", 148 | "\n", 149 | "# Define loss function and optimizer\n", 150 | "loss_fn = nn.CrossEntropyLoss(weight=class_weights)\n", 151 | "optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "# Create trainer and evaluator engines\n", 161 | "trainer = create_trainer(model, optimizer, loss_fn, device)\n", 162 | "evaluator = create_evaluator(\n", 163 | " model,\n", 164 | " device,\n", 165 | " metrics={\n", 166 | " \"loss\": Loss(loss_fn),\n", 167 | " \"samplewise_accuracy\": SamplewiseAccuracy(),\n", 168 | " \"mean_accuracy\": MeanAccuracy(),\n", 169 | " \"mean_iou\": MeanIoU(),\n", 170 | " \"frequency_weighted_iou\": FrequencyWeightedIoU(),\n", 171 | " },\n", 172 | ")\n", 173 | "\n", 174 | "# Attach LR scheduler\n", 175 | "step_scheduler = StepLR(optimizer, step_size=5, gamma=0.9)\n", 176 | "scheduler = LRScheduler(step_scheduler)\n", 177 | "trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler)\n", 178 | "\n", 179 | "# Attach handler for training logging\n", 180 | "logger = Logger(evaluator, train_loader, validation_loader)\n", 181 | "trainer.add_event_handler(Events.EPOCH_COMPLETED, logger)\n", 182 | "\n", 183 | "# Run trainer engine\n", 184 | "trainer.run(train_loader, max_epochs=50)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "## Training metrics" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "utils.plot_metrics(logger.metrics, \"training\")" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "## Validation metrics" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 
| "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "utils.plot_metrics(logger.metrics, \"validation\")" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "## Training confusion matrix" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "y_true_tr, y_pred_tr = utils.predict_with_model(model, train_set, device)\n", 233 | "utils.plot_confusion_matrix(\n", 234 | " y_true_tr, y_pred_tr, list(train_set.label_codes.keys()), normalize=True\n", 235 | ")" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": {}, 241 | "source": [ 242 | "## Validation confusion matrix" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "y_true_val, y_pred_val = utils.predict_with_model(model, validation_set, device)\n", 252 | "utils.plot_confusion_matrix(\n", 253 | " y_true_val, y_pred_val, list(validation_set.label_codes.keys()), normalize=True\n", 254 | ")" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "## Save data" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "if experiment == \"input_kernel_size\":\n", 271 | " name = str(input_kernel_size)\n", 272 | "if experiment == \"n_filters\":\n", 273 | " name = str(n_filters)\n", 274 | "if experiment == \"sensors\":\n", 275 | " name = sensors\n", 276 | "if experiment == \"smoothing_kernel_size\":\n", 277 | " name = f\"{str(smoothing_kernel_size)}\"\n", 278 | "\n", 279 | "output_dir = f\"output_asd/{random_seed}/{experiment}/{name}\"\n", 280 | "if not os.path.exists(output_dir):\n", 281 | " os.makedirs(output_dir)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "# Save trained model\n", 291 | "torch.save(model.state_dict(), os.path.join(output_dir, \"model.pt\"))\n", 292 | "\n", 293 | "# Save model log/history\n", 294 | "f = open(os.path.join(output_dir, \"hist.p\"), \"wb\")\n", 295 | "pickle.dump(logger.metrics, f)\n", 296 | "f.close()\n", 297 | "\n", 298 | "# Save number of parameters\n", 299 | "f = open(os.path.join(output_dir, \"n_params.p\"), \"wb\")\n", 300 | "pickle.dump(sum(p.numel() for p in model.parameters()), f)\n", 301 | "f.close()\n", 302 | "\n", 303 | "# Save training and validation ground truth and predictions\n", 304 | "np.save(os.path.join(output_dir, \"y_true_tr.npy\"), y_true_tr)\n", 305 | "np.save(os.path.join(output_dir, \"y_pred_tr.npy\"), y_pred_tr)\n", 306 | "np.save(os.path.join(output_dir, \"y_true_val.npy\"), y_true_val)\n", 307 | "np.save(os.path.join(output_dir, \"y_pred_val.npy\"), y_pred_val)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [] 316 | } 317 | ], 318 | "metadata": { 319 | "kernelspec": { 320 | "display_name": "Python 3", 321 | "language": "python", 322 | "name": "python3" 323 | }, 324 | "language_info": { 325 | "codemirror_mode": { 326 | "name": "ipython", 327 | "version": 3 328 | }, 329 | "file_extension": ".py", 330 | "mimetype": "text/x-python", 331 | "name": "python", 332 | "nbconvert_exporter": "python", 333 | "pygments_lexer": "ipython3", 334 | "version": "3.7.4" 
310 | {
311 | "cell_type": "code",
312 | "execution_count": null,
313 | "metadata": {},
314 | "outputs": [],
315 | "source": []
316 | }
317 | ],
318 | "metadata": {
319 | "kernelspec": {
320 | "display_name": "Python 3",
321 | "language": "python",
322 | "name": "python3"
323 | },
324 | "language_info": {
325 | "codemirror_mode": {
326 | "name": "ipython",
327 | "version": 3
328 | },
329 | "file_extension": ".py",
330 | "mimetype": "text/x-python",
331 | "name": "python",
332 | "nbconvert_exporter": "python",
333 | "pygments_lexer": "ipython3",
334 | "version": "3.7.4"
335 | }
336 | },
337 | "nbformat": 4,
338 | "nbformat_minor": 4
339 | }
340 | 
--------------------------------------------------------------------------------
/example/utils/create_frames.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | from scipy.interpolate import interp1d
5 | from tqdm import tqdm
6 | 
7 | SOURCE_DIR = 'data/MobiActV2/sessions/csv'
8 | DESTINATION_DIR = 'data/MobiActV2/frames'
9 | INTERPOLATION_FREQUENCY = 50  # Hz
10 | FRAME_LENGTH = 3  # seconds
11 | N_POINTS_FRAME = int(INTERPOLATION_FREQUENCY*FRAME_LENGTH)
12 | 
13 | 
14 | if not os.path.exists(DESTINATION_DIR):
15 |     print(f"Creating destination dir: {DESTINATION_DIR}")
16 |     os.mkdir(DESTINATION_DIR)
17 | 
18 | print(f"Creating frames from dir: {SOURCE_DIR}")
19 | files = os.listdir(SOURCE_DIR)
20 | files = [f for f in files if f.endswith('.csv')]
21 | for file in tqdm(files):
22 |     filepath = os.path.join(SOURCE_DIR, file)
23 |     df = pd.read_csv(filepath)
24 | 
25 |     # Simple data validity checks
26 |     ##### Check if relative time values start at zero and increase
27 |     if df.rel_time.min() != 0 or df.rel_time.max() <= df.rel_time.min():
28 |         print(f"Found file with invalid 'rel_time' values: {file}")
29 |     ##### Check for missing sensor channels
30 |     if not set(['acc_x', 'acc_y', 'acc_z']).issubset(df.columns):
31 |         print(f"Found file with incomplete accelerometer channels: {file}")
32 |         print("Omitting...")
33 |         continue
34 |     if not set(['gyro_x', 'gyro_y', 'gyro_z']).issubset(df.columns):
35 |         print(f"Found file with incomplete gyroscope channels: {file}")
36 |         print("Omitting...")
37 |         continue
38 |     if not set(['azimuth', 'pitch', 'roll']).issubset(df.columns):
39 |         print(f"Found file with incomplete orientation channels: {file}")
40 |         print("Omitting...")
41 |         continue
42 | 
43 |     # There is a bug in the data where the azimuth angle is sometimes negative (<<1% of the data).
44 |     # This is easily fixed by taking the values modulo 360.
45 |     df.azimuth = df.azimuth % 360
46 | 
47 |     # Interpolate data to the specified interpolation frequency
48 |     ##### Labels and numerical data
49 |     labels = df[['rel_time', 'label']]
50 |     df = df.drop(['timestamp', 'label'], axis=1)
51 | 
52 |     ##### Interpolation of numerical values
53 |     n_points = int((df.rel_time.max() - df.rel_time.min())*INTERPOLATION_FREQUENCY)
54 |     interp_f = interp1d(df.index, df.values, kind='slinear', assume_sorted=True, axis=0)
55 |     x_prime = np.linspace(df.index.min(), df.index.max(), n_points)
56 |     y_prime = interp_f(x_prime).astype(np.float32)
57 |     df = pd.DataFrame(y_prime, columns=df.columns)
58 | 
59 |     ##### Readdition of labels to df
60 |     transitions = []
61 |     transitions.append([labels.rel_time[0], labels.label[0]])
62 |     for i in range(1, len(labels)):
63 |         if labels.label[i] != labels.label[i-1]:
64 |             transitions.append([labels.rel_time[i], labels.label[i]])
65 |     transitions = [[np.where(df.rel_time >= x[0])[0][0], x[1]] for x in transitions]
66 | 
67 |     df['label'] = np.nan
68 |     for transition in transitions:
69 |         df.loc[transition[0], 'label'] = transition[1]
70 |     df = df.fillna(method='ffill')
71 | 
72 |     # Split interpolated df to 'frames' of specified length
73 |     for frame_idx, i in enumerate(range(0, len(df)-N_POINTS_FRAME, N_POINTS_FRAME)):
74 |         sub_df = df.iloc[i:i+N_POINTS_FRAME]
75 |         fn = f"{os.path.join(DESTINATION_DIR, file[:-4])}_{frame_idx}.csv"
76 |         assert len(sub_df) == N_POINTS_FRAME
77 |         sub_df.to_csv(fn)
78 |     sub_df = 
df.iloc[(len(df)-N_POINTS_FRAME):len(df)] 79 | fn = f"{os.path.join(DESTINATION_DIR, file[:-4])}_end.csv" 80 | assert len(sub_df) == N_POINTS_FRAME, str(len(df)) 81 | sub_df.to_csv(fn) 82 | 83 | print("Finished creating frames.") 84 | print(f"Frames saved to dir: {DESTINATION_DIR}") 85 | -------------------------------------------------------------------------------- /example/utils/download_dataset_from_s3.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | from botocore import UNSIGNED 3 | from botocore.client import Config 4 | 5 | s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED)) 6 | s3.download_file('mobiactv2', 'MobiAct_Dataset_v2.0.rar', 'data/MobiActV2/MobiAct_Dataset_v2.0.rar') 7 | -------------------------------------------------------------------------------- /example/utils/unpack_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir data 4 | mkdir data/MobiActV2 5 | python3 utils/download_dataset_from_s3.py 6 | unrar e data/MobiActV2/MobiAct_Dataset_v2.0.rar data/MobiActV2 7 | mkdir data/MobiActV2/sessions 8 | mkdir data/MobiActV2/sessions/csv 9 | mkdir data/MobiActV2/sessions/txt 10 | mv data/MobiActV2/*.csv data/MobiActV2/sessions/csv 11 | mv data/MobiActV2/*.txt data/MobiActV2/sessions/txt 12 | mv data/MobiActV2/sessions/txt/Readme.txt data/MobiActV2/README.txt 13 | python3 utils/create_frames.py 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appnope==0.1.0 2 | attrs==19.3.0 3 | backcall==0.1.0 4 | bleach==3.3.0 5 | boto3==1.9.251 6 | botocore==1.12.251 7 | cycler==0.10.0 8 | decorator==4.4.0 9 | defusedxml==0.6.0 10 | docutils==0.15.2 11 | entrypoints==0.3 12 | importlib-metadata==0.23 13 | ipykernel==5.1.2 14 | ipython==7.16.3 15 | ipython-genutils==0.2.0 16 | ipywidgets==7.5.1 17 | jedi==0.15.1 18 | Jinja2==2.11.3 19 | jmespath==0.9.4 20 | jsonschema==3.1.1 21 | jupyter==1.0.0 22 | jupyter-client==5.3.4 23 | jupyter-console==6.0.0 24 | jupyter-core==4.6.0 25 | kiwisolver==1.1.0 26 | MarkupSafe==1.1.1 27 | matplotlib==3.1.1 28 | mistune==0.8.4 29 | more-itertools==7.2.0 30 | nbconvert==5.6.0 31 | nbformat==4.4.0 32 | notebook==6.4.10 33 | numpy==1.17.2 34 | pandas==0.25.1 35 | pandocfilters==1.4.2 36 | parso==0.5.1 37 | pexpect==4.7.0 38 | pickleshare==0.7.5 39 | prometheus-client==0.7.1 40 | prompt-toolkit==2.0.10 41 | ptyprocess==0.6.0 42 | Pygments==2.7.4 43 | pyparsing==2.4.2 44 | pyrsistent==0.15.4 45 | python-dateutil==2.8.0 46 | pytorch-ignite==0.2.1 47 | pytz==2019.3 48 | pyzmq==18.1.0 49 | qtconsole==4.5.5 50 | s3transfer==0.2.1 51 | scipy==1.3.1 52 | Send2Trash==1.5.0 53 | six==1.12.0 54 | terminado==0.8.2 55 | testpath==0.4.2 56 | torch==1.3.0 57 | tornado==6.0.3 58 | tqdm==4.36.1 59 | traitlets==4.3.3 60 | urllib3==1.26.5 61 | wcwidth==0.1.7 62 | webencodings==0.5.1 63 | widgetsnbextension==3.5.1 64 | zipp==0.6.0 65 | -------------------------------------------------------------------------------- /segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datasets 2 | from . 
import models
--------------------------------------------------------------------------------
/segmentation/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .mobiactv2 import *
2 | 
--------------------------------------------------------------------------------
/segmentation/datasets/mobiactv2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | 
4 | import torch
5 | from torch.utils import data
6 | 
7 | 
8 | ACTIVITY_DESCRIPTIONS = {
9 |     "STD": "Standing",
10 |     "WAL": "Walking",
11 |     "JOG": "Jogging",
12 |     "JUM": "Jumping",
13 |     "STU": "Stairs up",
14 |     "STN": "Stairs down",
15 |     "SCH": "Stand to sit",
16 |     "SIT": "Sitting",
17 |     "CHU": "Sit to stand",
18 |     "CSI": "Car step in",
19 |     "CSO": "Car step out",
20 | }
21 | 
22 | 
23 | SCENARIO_DESCRIPTIONS = {
24 |     "SLH": "Leaving home",
25 |     "SBW": "Being at work",
26 |     "SLW": "Leaving work",
27 |     "SBE": "Exercising",
28 |     "SRH": "Returning home",
29 | }
30 | 
31 | 
32 | LABEL_CODES = {
33 |     "STD": 0,
34 |     "WAL": 1,
35 |     "JOG": 2,
36 |     "JUM": 3,
37 |     "STU": 4,
38 |     "STN": 5,
39 |     "SCH": 6,
40 |     "SIT": 7,
41 |     "CHU": 8,
42 |     "CSI": 9,
43 |     "CSO": 10,
44 | }
45 | 
46 | 
47 | class MobiActV2(data.Dataset):
48 |     """ A PyTorch Dataset class for MobiActV2 frames.
49 | 
50 |     Args:
51 |         root (string): The root directory where the dataset exists.
52 |         sensors (list of str): The sensor channel names to include.
53 |         users (list): List of user IDs to include in the dataset instance.
54 |     """
55 | 
56 |     def __init__(self, root, sensors, users):
57 |         self.root = root
58 |         self.sensors = sensors
59 |         self.users = users
60 |         self.activity_descriptions = ACTIVITY_DESCRIPTIONS
61 |         self.scenario_descriptions = SCENARIO_DESCRIPTIONS
62 |         self.label_codes = LABEL_CODES
63 |         self._cache = {}
64 | 
65 |         # Load files for given users and activities/scenarios of interest
66 |         files = os.listdir(self.root)
67 |         files = [f for f in files if int(f.split("_")[1]) in self.users]
68 |         files = [
69 |             f
70 |             for f in files
71 |             if f.split("_")[0] in ACTIVITY_DESCRIPTIONS
72 |             or f.split("_")[0] in SCENARIO_DESCRIPTIONS
73 |         ]
74 |         self.files = files
75 | 
76 |     def __len__(self):
77 |         return len(self.files)
78 | 
79 |     def __getitem__(self, idx):
80 |         # Return cached frame if available
81 |         if idx in self._cache:
82 |             return self._cache[idx][0], self._cache[idx][1]
83 | 
84 |         # Load frame data
85 |         file_path = os.path.join(self.root, self.files[idx])
86 |         df = pd.read_csv(file_path)
87 | 
88 |         # Normalize sensor data to Android sensor ranges
89 |         # Acceleration: [-20, 20] (m/s**2)
90 |         # Angular velocity: [-10, 10] (rad/s)
91 |         # Rotation:
92 |         # - Azimuth: [0, 360] (degrees)
93 |         # - Pitch: [-180, 180] (degrees)
94 |         # - Roll: [-90, 90] (degrees)
95 |         df[["acc_x", "acc_y", "acc_z"]] = df[["acc_x", "acc_y", "acc_z"]].apply(
96 |             lambda x: -1 + 2 * (x + 20) / 40
97 |         )
98 |         df[["gyro_x", "gyro_y", "gyro_z"]] = df[["gyro_x", "gyro_y", "gyro_z"]].apply(
99 |             lambda x: -1 + 2 * (x + 10) / 20
100 |         )
101 |         df["azimuth"] = df["azimuth"].apply(lambda x: -1 + 2 * x / 360)
102 |         df["pitch"] = df["pitch"].apply(lambda x: -1 + 2 * (x + 180) / 360)
103 |         df["roll"] = df["roll"].apply(lambda x: -1 + 2 * (x + 90) / 180)
104 | 
105 |         # Slice X, y
106 |         X = df[self.sensors]
107 |         y = df["label"]
108 |         ts = df["rel_time"].tolist()
109 | 
110 |         # Tensorize
111 |         X = torch.tensor(X.values, dtype=torch.float64)
112 |         file_id = torch.tensor([idx])
113 |         y = torch.tensor(list(map(lambda x: self.label_codes[x], y)), dtype=torch.int64)
114 | 
115 |         target = {}
116 |         target["file_id"] = file_id
117 |         target["filename"] = self.files[idx]
118 |         target["y"] = y
119 |         target["ts"] = ts
120 | 
121 |         # Add frame to cache
122 |         self._cache[idx] = [X, target]
123 | 
124 |         return X, target
125 | 
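# Usage sketch (hypothetical arguments; the real paths and user splits are
# defined in example/training_script.ipynb):
#
#   dataset = MobiActV2("data/MobiActV2/frames", ["acc_x", "acc_y", "acc_z"], users=[1, 2])
#   X, target = dataset[0]  # X: (n_samples, n_channels) float tensor; target["y"]: label codes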
--------------------------------------------------------------------------------
/segmentation/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .engine import *
--------------------------------------------------------------------------------
/segmentation/engine/engine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ignite.engine.engine import Engine
3 | 
4 | 
5 | def create_trainer(model, optimizer, loss_fn, device):
6 |     """ Creates an ignite Engine instance for model training.
7 | 
8 |     Args:
9 |         model: PyTorch model instance to be trained.
10 |         optimizer: PyTorch optimizer to be used for model training.
11 |         loss_fn: PyTorch loss function used for the model training.
12 |         device: The torch.device (or device string) to train on.
13 |     """
14 |     model.to(device)
15 | 
16 |     def _update(engine, batch):
17 |         model.train()
18 |         optimizer.zero_grad()
19 | 
20 |         # Prepare data
21 |         inputs, labels = batch
22 |         inputs = inputs.permute(0, 2, 1).to(device, dtype=torch.float)
23 |         labels = labels["y"].to(device)
24 | 
25 |         outputs = model(inputs)
26 |         loss = loss_fn(outputs, labels)
27 |         loss.backward()
28 |         optimizer.step()
29 |         return loss.item()
30 | 
31 |     return Engine(_update)
32 | 
33 | 
34 | def create_evaluator(model, device, metrics=None):
35 |     """ Creates an ignite Engine instance for model evaluation.
36 | 
37 |     Args:
38 |         model: PyTorch model instance to be evaluated.
39 |         device: The torch.device (or device string) to evaluate on.
40 |         metrics: A dictionary of the evaluation metrics.
41 |     """
42 |     model.to(device)
43 | 
44 |     def _inference(engine, batch):
45 |         model.eval()
46 |         with torch.no_grad():
47 |             # Prepare data
48 |             inputs, labels = batch
49 |             inputs = inputs.permute(0, 2, 1).to(device, dtype=torch.float)
50 |             labels = labels["y"].to(device)
51 | 
52 |             outputs = model(inputs)
53 |             return outputs, labels
54 | 
55 |     engine = Engine(_inference)
56 |     for name, metric in (metrics or {}).items():
57 |         metric.attach(engine, name)
58 | 
59 |     return engine
60 | 
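# Minimal wiring sketch (assumes a model, optimizer, loss function and a
# training DataLoader are already defined, and that Loss is imported from
# ignite.metrics; see the README and the example notebook):
#
#   trainer = create_trainer(model, optimizer, loss_fn, torch.device("cpu"))
#   evaluator = create_evaluator(model, torch.device("cpu"), metrics={"loss": Loss(loss_fn)})
#   trainer.run(train_loader, max_epochs=50)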
11 | """ 12 | def __init__(self, evaluator, train_loader, validation_loader): 13 | self.evaluator = evaluator 14 | self.train_loader = train_loader 15 | self.validation_loader = validation_loader 16 | self.metrics = {"training": [], "validation": []} 17 | 18 | def __call__(self, engine): 19 | print(f"Epoch {engine.state.epoch} training results") 20 | print("-------------------------") 21 | 22 | # Training 23 | self.evaluator.run(self.train_loader) 24 | self.metrics["training"].append(self.evaluator.state.metrics) 25 | print( 26 | f"Loss: {round(self.metrics['training'][-1]['loss'], 3)},\ 27 | Samplewise accuracy: {round(self.metrics['training'][-1]['samplewise_accuracy'], 3)},\ 28 | Mean IoU: {round(self.metrics['training'][-1]['mean_iou'], 3)},\ 29 | Frequency weighted IoU: {round(self.metrics['training'][-1]['frequency_weighted_iou'], 3)}" 30 | ) 31 | print() 32 | 33 | # Validation 34 | self.evaluator.run(self.validation_loader) 35 | self.metrics["validation"].append(self.evaluator.state.metrics) 36 | -------------------------------------------------------------------------------- /segmentation/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import * -------------------------------------------------------------------------------- /segmentation/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ignite.metrics.metric import Metric 3 | 4 | 5 | class SamplewiseAccuracy(Metric): 6 | """ Segmentation samplewise accuracy. This metric can be attached to 7 | an ignite evaluator engine and will return the samplewise accuracy 8 | for each epoch.""" 9 | 10 | def reset(self): 11 | """ Resets the number of correctly predicted and total samples 12 | at the start of each epoch. """ 13 | self._correct_samples = 0 14 | self._total_samples = 0 15 | 16 | def update(self, data): 17 | # Unpack data, assert shapes and get predictions 18 | outputs, labels = data 19 | assert outputs.shape[0] == labels.shape[0] 20 | outputs = outputs.argmax(1) 21 | 22 | # Update numbers of correctly predicted and total samples 23 | self._correct_samples += (outputs == labels).sum(dtype=torch.float) 24 | self._total_samples += torch.numel(outputs) 25 | 26 | def compute(self): 27 | return self._correct_samples / self._total_samples 28 | 29 | 30 | class MeanAccuracy(Metric): 31 | """ Segmentation mean class accuracy. This metric can be attached to 32 | an ignite evaluator engine and will return the mean class accuracy 33 | for each epoch.""" 34 | 35 | def reset(self): 36 | """ Resets the classwise number of correctly predicted and total samples 37 | at the start of each epoch. 
""" 38 | self._correct_class_samples = {} 39 | self._total_class_samples = {} 40 | 41 | def update(self, data): 42 | # Unpack data, assert shapes and get predictions 43 | outputs, labels = data 44 | assert outputs.shape[0] == labels.shape[0] 45 | outputs = outputs.argmax(1) 46 | 47 | # Update correctly predicted and total precitions for each class in batch 48 | for label in torch.unique(labels): 49 | if not label in self._total_class_samples: 50 | self._correct_class_samples[label] = 0 51 | self._total_class_samples[label] = 0 52 | 53 | # Samples belonging to current class 54 | class_samples = labels == label 55 | 56 | # Correctly predicted samples and total samples for current class in batch 57 | correct_samples = (outputs[class_samples] == label).sum(dtype=torch.float) 58 | total_samples = class_samples.sum(dtype=torch.float) 59 | self._correct_class_samples[label] += correct_samples 60 | self._total_class_samples[label] += total_samples 61 | 62 | def compute(self): 63 | accuracies = [] 64 | for label in self._total_class_samples: 65 | correct_samples = self._correct_class_samples[label] 66 | total_samples = self._total_class_samples[label] 67 | accuracies.append(correct_samples / total_samples) 68 | return torch.mean(torch.tensor(accuracies)) 69 | 70 | 71 | class MeanIoU(Metric): 72 | """ Segmentation mean class IoU. This metric can be attached to 73 | an ignite evaluator engine and will return the mean IoU for each epoch.""" 74 | 75 | def reset(self): 76 | """ Resets the classwise intersection and union at the start of each epoch.""" 77 | self._class_intersection = {} 78 | self._class_union = {} 79 | 80 | def update(self, data): 81 | # Unpack data, assert shapes and get predictions 82 | outputs, labels = data 83 | assert outputs.shape[0] == labels.shape[0] 84 | outputs = outputs.argmax(1) 85 | 86 | # Update intersection and union for each class in batch 87 | for label in torch.unique(labels): 88 | if not label in self._class_intersection: 89 | self._class_intersection[label] = 0 90 | self._class_union[label] = 0 91 | 92 | # Intersection and union of current class 93 | intersection = ( 94 | ((labels == label) & (outputs == label)).sum(dtype=torch.float).item() 95 | ) 96 | union = ( 97 | ((labels == label) | (outputs == label)).sum(dtype=torch.float).item() 98 | ) 99 | self._class_intersection[label] += intersection 100 | self._class_union[label] += union 101 | 102 | def compute(self): 103 | ious = [] 104 | for label in self._class_intersection: 105 | total_intersection = self._class_intersection[label] 106 | total_union = self._class_union[label] 107 | ious.append(total_intersection / total_union) 108 | return torch.mean(torch.tensor(ious)) 109 | 110 | 111 | class FrequencyWeightedIoU(Metric): 112 | """ Segmentation frequency weighted class IoU. 
111 | class FrequencyWeightedIoU(Metric):
112 |     """ Segmentation frequency weighted class IoU. This metric can be attached to
113 |     an ignite evaluator engine and will return the frequency weighted IoU for each epoch."""
114 | 
115 |     def reset(self):
116 |         """ Resets the classwise intersection, union, class samples and total samples at the start of each epoch."""
117 |         self._class_intersection = {}
118 |         self._class_union = {}
119 |         self._class_samples = {}
120 |         self._total_samples = 0
121 | 
122 |     def update(self, data):
123 |         # Unpack data, assert shapes and get predictions
124 |         outputs, labels = data
125 |         assert outputs.shape[0] == labels.shape[0]
126 |         outputs = outputs.argmax(1)
127 | 
128 |         # Update intersection, union, class and total samples (labels as Python ints)
129 |         for label in torch.unique(labels).tolist():
130 |             if label not in self._class_intersection:
131 |                 self._class_intersection[label] = 0
132 |                 self._class_union[label] = 0
133 |                 self._class_samples[label] = 0
134 | 
135 |             # Samples belonging to current class
136 |             class_samples = labels == label
137 | 
138 |             # Total samples, class samples, and intersection and union of current class
139 |             self._total_samples += class_samples.sum(dtype=torch.float).item()
140 |             self._class_samples[label] += class_samples.sum(dtype=torch.float).item()
141 |             intersection = (
142 |                 ((labels == label) & (outputs == label)).sum(dtype=torch.float).item()
143 |             )
144 |             union = (
145 |                 ((labels == label) | (outputs == label)).sum(dtype=torch.float).item()
146 |             )
147 |             self._class_intersection[label] += intersection
148 |             self._class_union[label] += union
149 | 
150 |     def compute(self):
151 |         ious = []
152 |         total_samples = self._total_samples
153 |         for label in self._class_intersection:
154 |             class_samples = self._class_samples[label]
155 |             class_intersection = self._class_intersection[label]
156 |             class_union = self._class_union[label]
157 |             ious.append(class_samples * class_intersection / class_union)
158 |         return torch.tensor(ious).sum().item() / total_samples
--------------------------------------------------------------------------------
/segmentation/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import *
--------------------------------------------------------------------------------
/segmentation/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class SensorFCN(nn.Module):
7 |     """ A fully convolutional network for segmentation of motion sensor data from
8 |     smartphones or wearable sensors.
9 | 
10 |     Args:
11 |         n_input_channels (int): The number of input channels for the network. For example,
12 |             if the input is the data of a triaxial accelerometer the number of input channels is 3.
13 |         n_classes (int): The number of target classes for the model predictions.
14 |         input_kernel_size (int): The kernel size of the first convolution, i.e. the number of samples in its receptive field.
15 |         n_filters (int): The filter parameter of the SensorFCN network architecture. Larger filter
16 |             parameter values result in larger and more powerful models.
17 |         smoothing_kernel_size (int): The kernel size for the smoothing component of the SensorFCN
18 |             architecture. Including a smoothing component generally results in better models,
19 |             especially for the segmentation of sporadically occurring patterns.
20 |     """
21 |     def __init__(
22 |         self,
23 |         n_input_channels,
24 |         n_classes,
25 |         input_kernel_size,
26 |         n_filters=16,
27 |         smoothing_kernel_size=0,
28 |     ):
29 |         super(SensorFCN, self).__init__()
30 |         self.n_input_channels = n_input_channels
31 |         self.n_classes = n_classes
32 |         self.input_kernel_size = input_kernel_size
33 |         self.n_filters = n_filters
34 |         self.smoothing_kernel_size = smoothing_kernel_size
35 | 
36 |         # Encoding layers
37 |         self.conv1 = nn.Conv1d(n_input_channels, n_filters, input_kernel_size)
38 |         self.conv2 = nn.Conv1d(n_filters, 2 * n_filters, 5)
39 |         self.conv3 = nn.Conv1d(2 * n_filters, 4 * n_filters, 3)
40 | 
41 |         # Decoding layer
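        # (Added note: the three encoder convolutions shorten the sequence by
        #  (input_kernel_size - 1) + (5 - 1) + (3 - 1) samples; a transposed
        #  convolution with kernel size 2 + 4 + input_kernel_size adds exactly
        #  that length back, so the output keeps the input's time resolution.)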
20 | """ 21 | def __init__( 22 | self, 23 | n_input_channels, 24 | n_classes, 25 | input_kernel_size, 26 | n_filters=16, 27 | smoothing_kernel_size=0, 28 | ): 29 | super(SensorFCN, self).__init__() 30 | self.n_input_channels = n_input_channels 31 | self.n_classes = n_classes 32 | self.input_kernel_size = input_kernel_size 33 | self.n_filters = n_filters 34 | self.smoothing_kernel_size = smoothing_kernel_size 35 | 36 | # Encoding layers 37 | self.conv1 = nn.Conv1d(n_input_channels, n_filters, input_kernel_size) 38 | self.conv2 = nn.Conv1d(n_filters, 2 * n_filters, 5) 39 | self.conv3 = nn.Conv1d(2 * n_filters, 4 * n_filters, 3) 40 | 41 | # Decoding layer 42 | self.tconv1 = nn.ConvTranspose1d( 43 | 4 * n_filters, n_filters, 2 + 4 + input_kernel_size 44 | ) 45 | 46 | # Scoring layer 47 | self.score_conv = nn.Conv1d(n_filters, n_classes, 1, 1, 0) 48 | 49 | # Post scoring layer 50 | if smoothing_kernel_size > 0: 51 | self.smoothing_conv = nn.Conv1d( 52 | n_classes, 53 | n_classes, 54 | smoothing_kernel_size, 55 | 1, 56 | smoothing_kernel_size // 2, 57 | ) 58 | 59 | def forward(self, x): 60 | x = F.relu(self.conv1(x)) 61 | x = F.relu(self.conv2(x)) 62 | x = F.relu(self.conv3(x)) 63 | 64 | x = self.tconv1(x) 65 | x = self.score_conv(x) 66 | if self.smoothing_kernel_size > 0: 67 | x = self.smoothing_conv(x) 68 | return x 69 | -------------------------------------------------------------------------------- /segmentation/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * -------------------------------------------------------------------------------- /segmentation/transforms/functional.py: -------------------------------------------------------------------------------- 1 | def xflip(frame): 2 | """ Flip the given sensor frame along the x-axis. 3 | 4 | Args: 5 | frame (torch.Tensor): Sensor frame to be flipped. 6 | """ 7 | frame[1, :] = -frame[1, :] 8 | frame[2, :] = -frame[2, :] 9 | return frame 10 | 11 | 12 | def yflip(frame): 13 | """ Flip the given sensor frame along the y-axis. 14 | 15 | Args: 16 | frame (torch.Tensor): Sensor frame to be flipped. 17 | """ 18 | frame[0, :] = -frame[0, :] 19 | frame[2, :] = -frame[2, :] 20 | return frame 21 | 22 | 23 | def zflip(frame): 24 | """ Flip the given sensor frame along the z-axis. 25 | 26 | Args: 27 | frame (torch.Tensor): Sensor frame to be flipped. 28 | """ 29 | frame[0, :] = -frame[0, :] 30 | frame[1, :] = -frame[1, :] 31 | return frame 32 | -------------------------------------------------------------------------------- /segmentation/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | from . import functional as F 2 | import random 3 | 4 | 5 | class Compose(object): 6 | """ Composes multiple transforms together. 7 | 8 | Args: 9 | transforms: Transformation objects to be composed together. 10 | """ 11 | def __init__(self, transforms): 12 | self.transforms = transforms 13 | 14 | def __call__(self, frame): 15 | for t in self.transforms: 16 | frame = t(frame) 17 | return frame 18 | 19 | def __repr__(self): 20 | format_string = self.__class__.__name__ + "(" 21 | for t in self.transforms: 22 | format_string += "\n" 23 | format_string += " {0}".format(t) 24 | format_string += "\n)" 25 | return format_string 26 | 27 | 28 | class RandomXFlip(object): 29 | """ Flip the given sensor frame along the x-axis randomly with a given probability. 30 | 31 | Args: 32 | p (float): probability of the frame being flipped. 
33 | """ 34 | 35 | def __init__(self, p=0.5): 36 | self.p = p 37 | 38 | def __call__(self, frame): 39 | if random.random() < self.p: 40 | return F.xflip(frame) 41 | return frame 42 | 43 | def __repr__(self): 44 | return self.__class__.__name__ + "(p={})".format(self.p) 45 | 46 | 47 | class RandomYFlip(object): 48 | """ Flip the given sensor frame along the y-axis randomly with a given probability. 49 | 50 | Args: 51 | p (float): probability of the frame being flipped. 52 | """ 53 | 54 | def __init__(self, p=0.5): 55 | self.p = p 56 | 57 | def __call__(self, frame): 58 | if random.random() < self.p: 59 | return F.yflip(frame) 60 | return frame 61 | 62 | def __repr__(self): 63 | return self.__class__.__name__ + "(p={})".format(self.p) 64 | 65 | 66 | class RandomZFlip(object): 67 | """ Flip the given sensor frame along the z-axis randomly with a given probability. 68 | 69 | Args: 70 | p (float): probability of the frame being flipped. 71 | """ 72 | 73 | def __init__(self, p=0.5): 74 | self.p = p 75 | 76 | def __call__(self, frame): 77 | if random.random() < self.p: 78 | return F.zflip(frame) 79 | return frame 80 | 81 | def __repr__(self): 82 | return self.__class__.__name__ + "(p={})".format(self.p) 83 | -------------------------------------------------------------------------------- /segmentation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * -------------------------------------------------------------------------------- /segmentation/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from sklearn.metrics import confusion_matrix 5 | from sklearn.utils.multiclass import unique_labels 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | def get_metric(metric, metrics, mode): 11 | return list(map(lambda x: x[metric], metrics[mode])) 12 | 13 | 14 | def plot_metrics(metrics, mode): 15 | 16 | loss = get_metric("loss", metrics, mode) 17 | plt.plot(loss) 18 | plt.ylabel("Loss") 19 | plt.xlabel("Epoch") 20 | plt.show() 21 | 22 | plt.figure(figsize=(12, 10)) 23 | samplewise_accuracy = get_metric("samplewise_accuracy", metrics, mode) 24 | plt.subplot(2, 2, 1) 25 | plt.plot(samplewise_accuracy) 26 | plt.ylabel("Samplewise accuracy") 27 | plt.xlabel("Epoch") 28 | 29 | mean_accuracy = get_metric("mean_accuracy", metrics, mode) 30 | plt.subplot(2, 2, 2) 31 | plt.plot(mean_accuracy) 32 | plt.ylabel("Mean accuracy") 33 | plt.xlabel("Epoch") 34 | 35 | mean_iou = get_metric("mean_iou", metrics, mode) 36 | plt.subplot(2, 2, 3) 37 | plt.plot(mean_iou) 38 | plt.ylabel("Mean IoU") 39 | plt.xlabel("Epoch") 40 | 41 | frequency_weighted_iou = get_metric("frequency_weighted_iou", metrics, mode) 42 | plt.subplot(2, 2, 4) 43 | plt.plot(frequency_weighted_iou) 44 | plt.ylabel("Frequency weighted IoU") 45 | plt.xlabel("Epoch") 46 | 47 | plt.show() 48 | 49 | 50 | def predict_with_model(model, dataset, device): 51 | preds = [] 52 | gts = [] 53 | for i in range(len(dataset)): 54 | data, label = dataset[i] 55 | preds.extend( 56 | model(data.unsqueeze(0).permute(0, 2, 1).to(device, dtype=torch.float)) 57 | .argmax(1) 58 | .squeeze() 59 | .detach() 60 | .cpu() 61 | .numpy() 62 | ) 63 | gts.extend(label["y"].numpy()) 64 | 65 | preds = np.array(preds) 66 | gts = np.array(gts) 67 | 68 | return gts, preds 69 | 70 | 71 | def plot_confusion_matrix( 72 | y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues 73 | ): 74 | """ 75 | This function 
prints and plots the confusion matrix. 76 | Normalization can be applied by setting `normalize=True`. 77 | """ 78 | if not title: 79 | if normalize: 80 | title = "Normalized confusion matrix" 81 | else: 82 | title = "Confusion matrix, without normalization" 83 | 84 | # Compute confusion matrix 85 | cm = confusion_matrix(y_true, y_pred) 86 | # Only use the labels that appear in the data 87 | # classes = classes[unique_labels(y_true, y_pred)] 88 | if normalize: 89 | cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] 90 | print("Normalized confusion matrix") 91 | else: 92 | print("Confusion matrix, without normalization") 93 | 94 | fig, ax = plt.subplots(figsize=(12, 12)) 95 | im = ax.imshow(cm, interpolation="nearest", cmap=cmap) 96 | ax.figure.colorbar(im, ax=ax) 97 | # We want to show all ticks... 98 | ax.set( 99 | xticks=np.arange(cm.shape[1]), 100 | yticks=np.arange(cm.shape[0]), 101 | # ... and label them with the respective list entries 102 | xticklabels=classes, 103 | yticklabels=classes, 104 | title=title, 105 | ylabel="True label", 106 | xlabel="Predicted label", 107 | ) 108 | 109 | # Rotate the tick labels and set their alignment. 110 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") 111 | 112 | # Loop over data dimensions and create text annotations. 113 | fmt = ".2f" if normalize else "d" 114 | thresh = cm.max() / 2.0 115 | for i in range(cm.shape[0]): 116 | for j in range(cm.shape[1]): 117 | ax.text( 118 | j, 119 | i, 120 | format(cm[i, j], fmt), 121 | ha="center", 122 | va="center", 123 | color="white" if cm[i, j] > thresh else "black", 124 | ) 125 | plt.grid(False) 126 | fig.tight_layout() 127 | return ax 128 | --------------------------------------------------------------------------------