├── InceptionTime evaluation.ipynb ├── InceptionTime_full-version_lr-{5e-3,1e-3,2e-4},_bs-512_ks-[5,11,23]_100-epochs_state_dict.pt ├── README.md ├── data ├── sequenced_data_for_VAE_length-160_stride-10_pt1.npy ├── sequenced_data_for_VAE_length-160_stride-10_pt2.npy └── sequenced_data_for_VAE_length-160_stride-10_targets.npy └── inception.py /InceptionTime evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np \n", 10 | "import time\n", 11 | "\n", 12 | "import torch \n", 13 | "import torch.nn as nn\n", 14 | "import torch.nn.functional as F \n", 15 | "\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "from collections import OrderedDict\n", 18 | "\n", 19 | "from sklearn.model_selection import train_test_split\n", 20 | "from sklearn.metrics import confusion_matrix, accuracy_score, f1_score\n", 21 | "from sklearn.preprocessing import RobustScaler\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "from inception import Inception, InceptionBlock" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 7, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "class Flatten(nn.Module):\n", 40 | "\tdef __init__(self, out_features):\n", 41 | "\t\tsuper(Flatten, self).__init__()\n", 42 | "\t\tself.output_dim = out_features\n", 43 | "\n", 44 | "\tdef forward(self, x):\n", 45 | "\t\treturn x.view(-1, self.output_dim)\n", 46 | " \n", 47 | "class Reshape(nn.Module):\n", 48 | "\tdef __init__(self, out_shape):\n", 49 | "\t\tsuper(Reshape, self).__init__()\n", 50 | "\t\tself.out_shape = out_shape\n", 51 | "\n", 52 | "\tdef forward(self, x):\n", 53 | "\t\treturn x.view(-1, *self.out_shape)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | 
"metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "X = np.vstack((np.load(\"data/sequenced_data_for_VAE_length-160_stride-10_pt1.npy\"),\n", 63 | " np.load(\"data/sequenced_data_for_VAE_length-160_stride-10_pt2.npy\")))\n", 64 | "y = np.load(\"data/sequenced_data_for_VAE_length-160_stride-10_targets.npy\")" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 5, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 6, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "scaler = RobustScaler()\n", 83 | "X_train = scaler.fit_transform(X_train)\n", 84 | "X_test = scaler.transform(X_test)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 8, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "InceptionTime = nn.Sequential(\n", 94 | " Reshape(out_shape=(1,160)),\n", 95 | " InceptionBlock(\n", 96 | " in_channels=1, \n", 97 | " n_filters=32, \n", 98 | " kernel_sizes=[5, 11, 23],\n", 99 | " bottleneck_channels=32,\n", 100 | " use_residual=True,\n", 101 | " activation=nn.ReLU()\n", 102 | " ),\n", 103 | " InceptionBlock(\n", 104 | " in_channels=32*4, \n", 105 | " n_filters=32, \n", 106 | " kernel_sizes=[5, 11, 23],\n", 107 | " bottleneck_channels=32,\n", 108 | " use_residual=True,\n", 109 | " activation=nn.ReLU()\n", 110 | " ),\n", 111 | " nn.AdaptiveAvgPool1d(output_size=1),\n", 112 | " Flatten(out_features=32*4*1),\n", 113 | " nn.Linear(in_features=4*32*1, out_features=4)\n", 114 | " )" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 9, 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "data": { 124 | "text/plain": [ 125 | "" 126 | ] 127 | }, 128 | "execution_count": 9, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | } 132 | ], 133 | "source": [ 134 | 
"InceptionTime.load_state_dict(torch.load(\"InceptionTime_full-version_lr-{5e-3,1e-3,2e-4},_bs-512_ks-[5,11,23]_100-epochs_state_dict.pt\"))" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 14, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "data": { 144 | "text/plain": [ 145 | "Sequential(\n", 146 | " (0): Reshape()\n", 147 | " (1): InceptionBlock(\n", 148 | " (activation): ReLU()\n", 149 | " (inception_1): Inception(\n", 150 | " (conv_from_bottleneck_1): Conv1d(1, 32, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", 151 | " (conv_from_bottleneck_2): Conv1d(1, 32, kernel_size=(11,), stride=(1,), padding=(5,), bias=False)\n", 152 | " (conv_from_bottleneck_3): Conv1d(1, 32, kernel_size=(23,), stride=(1,), padding=(11,), bias=False)\n", 153 | " (max_pool): MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n", 154 | " (conv_from_maxpool): Conv1d(1, 32, kernel_size=(1,), stride=(1,), bias=False)\n", 155 | " (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 156 | " (activation): ReLU()\n", 157 | " )\n", 158 | " (inception_2): Inception(\n", 159 | " (bottleneck): Conv1d(128, 32, kernel_size=(1,), stride=(1,), bias=False)\n", 160 | " (conv_from_bottleneck_1): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", 161 | " (conv_from_bottleneck_2): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(5,), bias=False)\n", 162 | " (conv_from_bottleneck_3): Conv1d(32, 32, kernel_size=(23,), stride=(1,), padding=(11,), bias=False)\n", 163 | " (max_pool): MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n", 164 | " (conv_from_maxpool): Conv1d(128, 32, kernel_size=(1,), stride=(1,), bias=False)\n", 165 | " (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 166 | " (activation): ReLU()\n", 167 | " )\n", 168 | " (inception_3): Inception(\n", 169 | " (bottleneck): 
Conv1d(128, 32, kernel_size=(1,), stride=(1,), bias=False)\n", 170 | " (conv_from_bottleneck_1): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", 171 | " (conv_from_bottleneck_2): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(5,), bias=False)\n", 172 | " (conv_from_bottleneck_3): Conv1d(32, 32, kernel_size=(23,), stride=(1,), padding=(11,), bias=False)\n", 173 | " (max_pool): MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n", 174 | " (conv_from_maxpool): Conv1d(128, 32, kernel_size=(1,), stride=(1,), bias=False)\n", 175 | " (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 176 | " (activation): ReLU()\n", 177 | " )\n", 178 | " (residual): Sequential(\n", 179 | " (0): Conv1d(1, 128, kernel_size=(1,), stride=(1,))\n", 180 | " (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 181 | " )\n", 182 | " )\n", 183 | " (2): InceptionBlock(\n", 184 | " (activation): ReLU()\n", 185 | " (inception_1): Inception(\n", 186 | " (bottleneck): Conv1d(128, 32, kernel_size=(1,), stride=(1,), bias=False)\n", 187 | " (conv_from_bottleneck_1): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", 188 | " (conv_from_bottleneck_2): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(5,), bias=False)\n", 189 | " (conv_from_bottleneck_3): Conv1d(32, 32, kernel_size=(23,), stride=(1,), padding=(11,), bias=False)\n", 190 | " (max_pool): MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n", 191 | " (conv_from_maxpool): Conv1d(128, 32, kernel_size=(1,), stride=(1,), bias=False)\n", 192 | " (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 193 | " (activation): ReLU()\n", 194 | " )\n", 195 | " (inception_2): Inception(\n", 196 | " (bottleneck): Conv1d(128, 32, kernel_size=(1,), stride=(1,), bias=False)\n", 197 | " (conv_from_bottleneck_1): 
Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", 198 | " (conv_from_bottleneck_2): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(5,), bias=False)\n", 199 | " (conv_from_bottleneck_3): Conv1d(32, 32, kernel_size=(23,), stride=(1,), padding=(11,), bias=False)\n", 200 | " (max_pool): MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n", 201 | " (conv_from_maxpool): Conv1d(128, 32, kernel_size=(1,), stride=(1,), bias=False)\n", 202 | " (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 203 | " (activation): ReLU()\n", 204 | " )\n", 205 | " (inception_3): Inception(\n", 206 | " (bottleneck): Conv1d(128, 32, kernel_size=(1,), stride=(1,), bias=False)\n", 207 | " (conv_from_bottleneck_1): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", 208 | " (conv_from_bottleneck_2): Conv1d(32, 32, kernel_size=(11,), stride=(1,), padding=(5,), bias=False)\n", 209 | " (conv_from_bottleneck_3): Conv1d(32, 32, kernel_size=(23,), stride=(1,), padding=(11,), bias=False)\n", 210 | " (max_pool): MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)\n", 211 | " (conv_from_maxpool): Conv1d(128, 32, kernel_size=(1,), stride=(1,), bias=False)\n", 212 | " (batch_norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 213 | " (activation): ReLU()\n", 214 | " )\n", 215 | " (residual): Sequential(\n", 216 | " (0): Conv1d(128, 128, kernel_size=(1,), stride=(1,))\n", 217 | " (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 218 | " )\n", 219 | " )\n", 220 | " (3): AdaptiveAvgPool1d(output_size=1)\n", 221 | " (4): Flatten()\n", 222 | " (5): Linear(in_features=128, out_features=4, bias=True)\n", 223 | ")" 224 | ] 225 | }, 226 | "execution_count": 14, 227 | "metadata": {}, 228 | "output_type": "execute_result" 229 | } 230 | ], 231 | "source": [ 232 | "InceptionTime" 233 | ] 
234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 10, 238 | "metadata": {}, 239 | "outputs": [ 240 | { 241 | "data": { 242 | "text/plain": [ 243 | "tensor([1, 1, 1, ..., 3, 1, 0])" 244 | ] 245 | }, 246 | "execution_count": 10, 247 | "metadata": {}, 248 | "output_type": "execute_result" 249 | } 250 | ], 251 | "source": [ 252 | "InceptionTime.eval()\n", 253 | "with torch.no_grad():\n", 254 | " x_pred = np.argmax(InceptionTime(torch.tensor(X_test).float()).detach(), axis=1)\n", 255 | "x_pred" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 11, 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "data": { 265 | "text/plain": [ 266 | "0.9125625555928921" 267 | ] 268 | }, 269 | "execution_count": 11, 270 | "metadata": {}, 271 | "output_type": "execute_result" 272 | } 273 | ], 274 | "source": [ 275 | "f1_score(y_true=y_test, y_pred=x_pred,average=\"macro\")" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 12, 281 | "metadata": {}, 282 | "outputs": [ 283 | { 284 | "data": { 285 | "text/plain": [ 286 | "0.9493307839388145" 287 | ] 288 | }, 289 | "execution_count": 12, 290 | "metadata": {}, 291 | "output_type": "execute_result" 292 | } 293 | ], 294 | "source": [ 295 | "accuracy_score(y_true=y_test, y_pred=x_pred)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 13, 301 | "metadata": { 302 | "scrolled": true 303 | }, 304 | "outputs": [ 305 | { 306 | "data": { 307 | "text/plain": [ 308 | "array([[ 3805, 296, 27, 107],\n", 309 | " [ 91, 10307, 8, 39],\n", 310 | " [ 22, 8, 425, 24],\n", 311 | " [ 169, 37, 20, 1351]], dtype=int64)" 312 | ] 313 | }, 314 | "execution_count": 13, 315 | "metadata": {}, 316 | "output_type": "execute_result" 317 | } 318 | ], 319 | "source": [ 320 | "cf1 = confusion_matrix(y_true=y_test, y_pred=x_pred) # x_axis = predicted, y_axis = ground_truth\n", 321 | "cf1" 322 | ] 323 | } 324 | ], 325 | "metadata": { 326 | "kernelspec": { 327 | 
"display_name": "Pytorch", 328 | "language": "python", 329 | "name": "pytorch" 330 | }, 331 | "language_info": { 332 | "codemirror_mode": { 333 | "name": "ipython", 334 | "version": 3 335 | }, 336 | "file_extension": ".py", 337 | "mimetype": "text/x-python", 338 | "name": "python", 339 | "nbconvert_exporter": "python", 340 | "pygments_lexer": "ipython3", 341 | "version": "3.6.9" 342 | } 343 | }, 344 | "nbformat": 4, 345 | "nbformat_minor": 2 346 | } 347 | -------------------------------------------------------------------------------- /InceptionTime_full-version_lr-{5e-3,1e-3,2e-4},_bs-512_ks-[5,11,23]_100-epochs_state_dict.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheMrGhostman/InceptionTime-Pytorch/ea97517a5ebbc901284225387baedeca695cde26/InceptionTime_full-version_lr-{5e-3,1e-3,2e-4},_bs-512_ks-[5,11,23]_100-epochs_state_dict.pt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # InceptionTime (in PyTorch) 2 | Unofficial PyTorch implementation of the InceptionTime Inception module for time series classification, together with a transposed variant intended for further use in a variational autoencoder (VAE). 3 | 4 | - Fawaz, H. I., Lucas, B., Forestier, G., Pelletier, C., Schmidt, D. F., Weber, J., ... & Petitjean, F. (2019). InceptionTime: Finding AlexNet for Time Series Classification. [arXiv preprint arXiv:1909.04939](https://arxiv.org/abs/1909.04939).
5 | - Official InceptionTime tensorflow implementation: https://github.com/hfawaz/InceptionTime 6 | 7 | 8 | -------------------------------------------------------------------------------- /data/sequenced_data_for_VAE_length-160_stride-10_pt1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheMrGhostman/InceptionTime-Pytorch/ea97517a5ebbc901284225387baedeca695cde26/data/sequenced_data_for_VAE_length-160_stride-10_pt1.npy -------------------------------------------------------------------------------- /data/sequenced_data_for_VAE_length-160_stride-10_pt2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheMrGhostman/InceptionTime-Pytorch/ea97517a5ebbc901284225387baedeca695cde26/data/sequenced_data_for_VAE_length-160_stride-10_pt2.npy -------------------------------------------------------------------------------- /data/sequenced_data_for_VAE_length-160_stride-10_targets.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TheMrGhostman/InceptionTime-Pytorch/ea97517a5ebbc901284225387baedeca695cde26/data/sequenced_data_for_VAE_length-160_stride-10_targets.npy -------------------------------------------------------------------------------- /inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def correct_sizes(sizes): 6 | corrected_sizes = [s if s % 2 != 0 else s - 1 for s in sizes] 7 | return corrected_sizes 8 | 9 | 10 | def pass_through(X): 11 | return X 12 | 13 | 14 | class Inception(nn.Module): 15 | def __init__(self, in_channels, n_filters, kernel_sizes=[9, 19, 39], bottleneck_channels=32, activation=nn.ReLU(), return_indices=False): 16 | """ 17 | : param in_channels Number of input channels (input features) 18 | : param n_filters Number of filters per 
convolution layer => out_channels = 4*n_filters 19 | : param kernel_sizes List of kernel sizes for each convolution. 20 | Each kernel size must be odd number that meets -> "kernel_size % 2 !=0". 21 | This is nessesery because of padding size. 22 | For correction of kernel_sizes use function "correct_sizes". 23 | : param bottleneck_channels Number of output channels in bottleneck. 24 | Bottleneck wont be used if nuber of in_channels is equal to 1. 25 | : param activation Activation function for output tensor (nn.ReLU()). 26 | : param return_indices Indices are needed only if we want to create decoder with InceptionTranspose with MaxUnpool1d. 27 | """ 28 | super(Inception, self).__init__() 29 | self.return_indices=return_indices 30 | if in_channels > 1: 31 | self.bottleneck = nn.Conv1d( 32 | in_channels=in_channels, 33 | out_channels=bottleneck_channels, 34 | kernel_size=1, 35 | stride=1, 36 | bias=False 37 | ) 38 | else: 39 | self.bottleneck = pass_through 40 | bottleneck_channels = 1 41 | 42 | self.conv_from_bottleneck_1 = nn.Conv1d( 43 | in_channels=bottleneck_channels, 44 | out_channels=n_filters, 45 | kernel_size=kernel_sizes[0], 46 | stride=1, 47 | padding=kernel_sizes[0]//2, 48 | bias=False 49 | ) 50 | self.conv_from_bottleneck_2 = nn.Conv1d( 51 | in_channels=bottleneck_channels, 52 | out_channels=n_filters, 53 | kernel_size=kernel_sizes[1], 54 | stride=1, 55 | padding=kernel_sizes[1]//2, 56 | bias=False 57 | ) 58 | self.conv_from_bottleneck_3 = nn.Conv1d( 59 | in_channels=bottleneck_channels, 60 | out_channels=n_filters, 61 | kernel_size=kernel_sizes[2], 62 | stride=1, 63 | padding=kernel_sizes[2]//2, 64 | bias=False 65 | ) 66 | self.max_pool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1, return_indices=return_indices) 67 | self.conv_from_maxpool = nn.Conv1d( 68 | in_channels=in_channels, 69 | out_channels=n_filters, 70 | kernel_size=1, 71 | stride=1, 72 | padding=0, 73 | bias=False 74 | ) 75 | self.batch_norm = nn.BatchNorm1d(num_features=4*n_filters) 76 
| self.activation = activation 77 | 78 | def forward(self, X): 79 | # step 1 80 | Z_bottleneck = self.bottleneck(X) 81 | if self.return_indices: 82 | Z_maxpool, indices = self.max_pool(X) 83 | else: 84 | Z_maxpool = self.max_pool(X) 85 | # step 2 86 | Z1 = self.conv_from_bottleneck_1(Z_bottleneck) 87 | Z2 = self.conv_from_bottleneck_2(Z_bottleneck) 88 | Z3 = self.conv_from_bottleneck_3(Z_bottleneck) 89 | Z4 = self.conv_from_maxpool(Z_maxpool) 90 | # step 3 91 | Z = torch.cat([Z1, Z2, Z3, Z4], axis=1) 92 | Z = self.activation(self.batch_norm(Z)) 93 | if self.return_indices: 94 | return Z, indices 95 | else: 96 | return Z 97 | 98 | 99 | class InceptionBlock(nn.Module): 100 | def __init__(self, in_channels, n_filters=32, kernel_sizes=[9,19,39], bottleneck_channels=32, use_residual=True, activation=nn.ReLU(), return_indices=False): 101 | super(InceptionBlock, self).__init__() 102 | self.use_residual = use_residual 103 | self.return_indices = return_indices 104 | self.activation = activation 105 | self.inception_1 = Inception( 106 | in_channels=in_channels, 107 | n_filters=n_filters, 108 | kernel_sizes=kernel_sizes, 109 | bottleneck_channels=bottleneck_channels, 110 | activation=activation, 111 | return_indices=return_indices 112 | ) 113 | self.inception_2 = Inception( 114 | in_channels=4*n_filters, 115 | n_filters=n_filters, 116 | kernel_sizes=kernel_sizes, 117 | bottleneck_channels=bottleneck_channels, 118 | activation=activation, 119 | return_indices=return_indices 120 | ) 121 | self.inception_3 = Inception( 122 | in_channels=4*n_filters, 123 | n_filters=n_filters, 124 | kernel_sizes=kernel_sizes, 125 | bottleneck_channels=bottleneck_channels, 126 | activation=activation, 127 | return_indices=return_indices 128 | ) 129 | if self.use_residual: 130 | self.residual = nn.Sequential( 131 | nn.Conv1d( 132 | in_channels=in_channels, 133 | out_channels=4*n_filters, 134 | kernel_size=1, 135 | stride=1, 136 | padding=0 137 | ), 138 | nn.BatchNorm1d( 139 | 
num_features=4*n_filters 140 | ) 141 | ) 142 | 143 | def forward(self, X): 144 | if self.return_indices: 145 | Z, i1 = self.inception_1(X) 146 | Z, i2 = self.inception_2(Z) 147 | Z, i3 = self.inception_3(Z) 148 | else: 149 | Z = self.inception_1(X) 150 | Z = self.inception_2(Z) 151 | Z = self.inception_3(Z) 152 | if self.use_residual: 153 | Z = Z + self.residual(X) 154 | Z = self.activation(Z) 155 | if self.return_indices: 156 | return Z,[i1, i2, i3] 157 | else: 158 | return Z 159 | 160 | 161 | 162 | class InceptionTranspose(nn.Module): 163 | def __init__(self, in_channels, out_channels, kernel_sizes=[9, 19, 39], bottleneck_channels=32, activation=nn.ReLU()): 164 | """ 165 | : param in_channels Number of input channels (input features) 166 | : param n_filters Number of filters per convolution layer => out_channels = 4*n_filters 167 | : param kernel_sizes List of kernel sizes for each convolution. 168 | Each kernel size must be odd number that meets -> "kernel_size % 2 !=0". 169 | This is nessesery because of padding size. 170 | For correction of kernel_sizes use function "correct_sizes". 171 | : param bottleneck_channels Number of output channels in bottleneck. 172 | Bottleneck wont be used if nuber of in_channels is equal to 1. 173 | : param activation Activation function for output tensor (nn.ReLU()). 
174 | """ 175 | super(InceptionTranspose, self).__init__() 176 | self.activation = activation 177 | self.conv_to_bottleneck_1 = nn.ConvTranspose1d( 178 | in_channels=in_channels, 179 | out_channels=bottleneck_channels, 180 | kernel_size=kernel_sizes[0], 181 | stride=1, 182 | padding=kernel_sizes[0]//2, 183 | bias=False 184 | ) 185 | self.conv_to_bottleneck_2 = nn.ConvTranspose1d( 186 | in_channels=in_channels, 187 | out_channels=bottleneck_channels, 188 | kernel_size=kernel_sizes[1], 189 | stride=1, 190 | padding=kernel_sizes[1]//2, 191 | bias=False 192 | ) 193 | self.conv_to_bottleneck_3 = nn.ConvTranspose1d( 194 | in_channels=in_channels, 195 | out_channels=bottleneck_channels, 196 | kernel_size=kernel_sizes[2], 197 | stride=1, 198 | padding=kernel_sizes[2]//2, 199 | bias=False 200 | ) 201 | self.conv_to_maxpool = nn.Conv1d( 202 | in_channels=in_channels, 203 | out_channels=out_channels, 204 | kernel_size=1, 205 | stride=1, 206 | padding=0, 207 | bias=False 208 | ) 209 | self.max_unpool = nn.MaxUnpool1d(kernel_size=3, stride=1, padding=1) 210 | self.bottleneck = nn.Conv1d( 211 | in_channels=3*bottleneck_channels, 212 | out_channels=out_channels, 213 | kernel_size=1, 214 | stride=1, 215 | bias=False 216 | ) 217 | self.batch_norm = nn.BatchNorm1d(num_features=out_channels) 218 | 219 | def forward(self, X, indices): 220 | Z1 = self.conv_to_bottleneck_1(X) 221 | Z2 = self.conv_to_bottleneck_2(X) 222 | Z3 = self.conv_to_bottleneck_3(X) 223 | Z4 = self.conv_to_maxpool(X) 224 | 225 | Z = torch.cat([Z1, Z2, Z3], axis=1) 226 | MUP = self.max_unpool(Z4, indices) 227 | BN = self.bottleneck(Z) 228 | # another possibility insted of sum BN and MUP is adding 2nd bottleneck transposed convolution 229 | 230 | return self.activation(self.batch_norm(BN + MUP)) 231 | 232 | 233 | class InceptionTransposeBlock(nn.Module): 234 | def __init__(self, in_channels, out_channels=32, kernel_sizes=[9,19,39], bottleneck_channels=32, use_residual=True, activation=nn.ReLU()): 235 | 
super(InceptionTransposeBlock, self).__init__() 236 | self.use_residual = use_residual 237 | self.activation = activation 238 | self.inception_1 = InceptionTranspose( 239 | in_channels=in_channels, 240 | out_channels=in_channels, 241 | kernel_sizes=kernel_sizes, 242 | bottleneck_channels=bottleneck_channels, 243 | activation=activation 244 | ) 245 | self.inception_2 = InceptionTranspose( 246 | in_channels=in_channels, 247 | out_channels=in_channels, 248 | kernel_sizes=kernel_sizes, 249 | bottleneck_channels=bottleneck_channels, 250 | activation=activation 251 | ) 252 | self.inception_3 = InceptionTranspose( 253 | in_channels=in_channels, 254 | out_channels=out_channels, 255 | kernel_sizes=kernel_sizes, 256 | bottleneck_channels=bottleneck_channels, 257 | activation=activation 258 | ) 259 | if self.use_residual: 260 | self.residual = nn.Sequential( 261 | nn.ConvTranspose1d( 262 | in_channels=in_channels, 263 | out_channels=out_channels, 264 | kernel_size=1, 265 | stride=1, 266 | padding=0 267 | ), 268 | nn.BatchNorm1d( 269 | num_features=out_channels 270 | ) 271 | ) 272 | 273 | def forward(self, X, indices): 274 | assert len(indices)==3 275 | Z = self.inception_1(X, indices[2]) 276 | Z = self.inception_2(Z, indices[1]) 277 | Z = self.inception_3(Z, indices[0]) 278 | if self.use_residual: 279 | Z = Z + self.residual(X) 280 | Z = self.activation(Z) 281 | return Z --------------------------------------------------------------------------------