├── .gitignore ├── LA-Transformer Testing.html ├── LA-Transformer Testing.ipynb ├── LA-Transformer Training.html ├── LA-Transformer Training.ipynb ├── LATransformer │   ├── metrics.py │   ├── model.py │   └── utils.py ├── LICENSE └── Readme.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LA-Transformer Testing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from __future__ import print_function\n", 10 | "\n", 11 | "import os\n", 12 | "import time\n", 13 | "import glob\n", 14 | "import random\n", 15 | "import zipfile\n", 16 | "from itertools import chain\n", 17 | "\n", 18 | "import timm\n", 19 | "import numpy as np\n", 20 | "import pandas as pd\n", 21 | "from PIL import Image\n", 22 | "from tqdm.notebook import tqdm\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "from collections import OrderedDict\n", 25 | "from sklearn.model_selection import train_test_split\n", 26 | "\n", 27 | "import torch\n", 28 | "import torch.nn as nn\n", 29 | "from torch.nn import init\n", 30 | "import torch.optim as optim\n", 31 | "from torchvision import models\n", 32 | "import torch.nn.functional as F\n", 33 | "from torch.autograd import Variable\n", 34 | "from torch.optim.lr_scheduler import StepLR\n", 35 | "from torchvision import datasets, transforms\n", 36 | "from torch.utils.data import DataLoader, Dataset\n", 37 | "\n", 38 | "from LATransformer.model import ClassBlock, LATransformer, LATransformerTest\n", 39 | "from LATransformer.utils import save_network, update_summary, get_id\n", 40 | "from LATransformer.metrics import rank1, rank5, rank10, calc_map\n", 41 | "\n", 42 | "os.environ['CUDA_VISIBLE_DEVICES']='1'\n", 43 | "device = \"cuda\"" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "## Config Parameters" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "batch_size = 8\n", 60 | "gamma = 0.7\n", 61 | "seed = 42" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Load Model" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "LATransformerTest(\n", 80 | " (model): VisionTransformer(\n", 81 | " (patch_embed): PatchEmbed(\n", 82 | " (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n", 83 | " )\n", 84 | " (pos_drop): Dropout(p=0.0, inplace=False)\n", 85 | " (blocks): ModuleList(\n", 86 | " (0): Block(\n", 87 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 88 | " (attn): Attention(\n", 89 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 90 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 91 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 92 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 93 | " )\n", 94 | " (drop_path): Identity()\n", 95 | " (norm2): LayerNorm((768,), eps=1e-06, 
elementwise_affine=True)\n", 96 | " (mlp): Mlp(\n", 97 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 98 | " (act): GELU()\n", 99 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 100 | " (drop): Dropout(p=0.0, inplace=False)\n", 101 | " )\n", 102 | " )\n", 103 | " (1): Block(\n", 104 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 105 | " (attn): Attention(\n", 106 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 107 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 108 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 109 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 110 | " )\n", 111 | " (drop_path): Identity()\n", 112 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 113 | " (mlp): Mlp(\n", 114 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 115 | " (act): GELU()\n", 116 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 117 | " (drop): Dropout(p=0.0, inplace=False)\n", 118 | " )\n", 119 | " )\n", 120 | " (2): Block(\n", 121 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 122 | " (attn): Attention(\n", 123 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 124 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 125 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 126 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 127 | " )\n", 128 | " (drop_path): Identity()\n", 129 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 130 | " (mlp): Mlp(\n", 131 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 132 | " (act): GELU()\n", 133 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 134 | " (drop): Dropout(p=0.0, inplace=False)\n", 135 | " )\n", 136 | " )\n", 137 | " (3): Block(\n", 138 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 139 | " (attn): Attention(\n", 140 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 141 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 142 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 143 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 144 | " )\n", 145 | " (drop_path): Identity()\n", 146 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 147 | " (mlp): Mlp(\n", 148 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 149 | " (act): GELU()\n", 150 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 151 | " (drop): Dropout(p=0.0, inplace=False)\n", 152 | " )\n", 153 | " )\n", 154 | " (4): Block(\n", 155 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 156 | " (attn): Attention(\n", 157 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 158 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 159 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 160 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 161 | " )\n", 162 | " (drop_path): Identity()\n", 163 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 164 | " (mlp): Mlp(\n", 165 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 166 | " (act): GELU()\n", 167 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 168 | " (drop): Dropout(p=0.0, inplace=False)\n", 169 | " )\n", 170 | " )\n", 171 | " (5): Block(\n", 172 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 173 | " (attn): 
Attention(\n", 174 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 175 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 176 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 177 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 178 | " )\n", 179 | " (drop_path): Identity()\n", 180 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 181 | " (mlp): Mlp(\n", 182 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 183 | " (act): GELU()\n", 184 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 185 | " (drop): Dropout(p=0.0, inplace=False)\n", 186 | " )\n", 187 | " )\n", 188 | " (6): Block(\n", 189 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 190 | " (attn): Attention(\n", 191 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 192 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 193 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 194 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 195 | " )\n", 196 | " (drop_path): Identity()\n", 197 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 198 | " (mlp): Mlp(\n", 199 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 200 | " (act): GELU()\n", 201 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 202 | " (drop): Dropout(p=0.0, inplace=False)\n", 203 | " )\n", 204 | " )\n", 205 | " (7): Block(\n", 206 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 207 | " (attn): Attention(\n", 208 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 209 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 210 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 211 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 212 | " )\n", 213 | " (drop_path): Identity()\n", 214 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 215 | " (mlp): Mlp(\n", 216 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 217 | " (act): GELU()\n", 218 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 219 | " (drop): Dropout(p=0.0, inplace=False)\n", 220 | " )\n", 221 | " )\n", 222 | " (8): Block(\n", 223 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 224 | " (attn): Attention(\n", 225 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 226 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 227 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 228 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 229 | " )\n", 230 | " (drop_path): Identity()\n", 231 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 232 | " (mlp): Mlp(\n", 233 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 234 | " (act): GELU()\n", 235 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 236 | " (drop): Dropout(p=0.0, inplace=False)\n", 237 | " )\n", 238 | " )\n", 239 | " (9): Block(\n", 240 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 241 | " (attn): Attention(\n", 242 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 243 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 244 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 245 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 246 | " )\n", 247 | " (drop_path): Identity()\n", 248 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 249 | " (mlp): 
Mlp(\n", 250 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 251 | " (act): GELU()\n", 252 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 253 | " (drop): Dropout(p=0.0, inplace=False)\n", 254 | " )\n", 255 | " )\n", 256 | " (10): Block(\n", 257 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 258 | " (attn): Attention(\n", 259 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 260 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 261 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 262 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 263 | " )\n", 264 | " (drop_path): Identity()\n", 265 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 266 | " (mlp): Mlp(\n", 267 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 268 | " (act): GELU()\n", 269 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 270 | " (drop): Dropout(p=0.0, inplace=False)\n", 271 | " )\n", 272 | " )\n", 273 | " (11): Block(\n", 274 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 275 | " (attn): Attention(\n", 276 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 277 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 278 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 279 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 280 | " )\n", 281 | " (drop_path): Identity()\n", 282 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 283 | " (mlp): Mlp(\n", 284 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 285 | " (act): GELU()\n", 286 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 287 | " (drop): Dropout(p=0.0, inplace=False)\n", 288 | " )\n", 289 | " )\n", 290 | " )\n", 291 | " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 292 | " (head): Linear(in_features=768, out_features=751, bias=True)\n", 293 | " )\n", 294 | " (avgpool): AdaptiveAvgPool2d(output_size=(14, 768))\n", 295 | " (dropout): Dropout(p=0.5, inplace=False)\n", 296 | ")" 297 | ] 298 | }, 299 | "execution_count": 3, 300 | "metadata": {}, 301 | "output_type": "execute_result" 302 | } 303 | ], 304 | "source": [ 305 | "# Load ViT\n", 306 | "vit_base = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=751)\n", 307 | "vit_base= vit_base.to(device)\n", 308 | "\n", 309 | "# Create La-Transformer\n", 310 | "model = LATransformerTest(vit_base, lmbd=8).to(device)\n", 311 | "\n", 312 | "# Load LA-Transformer\n", 313 | "name = \"la_with_lmbd_8\"\n", 314 | "save_path = os.path.join('./model',name,'net_best.pth')\n", 315 | "model.load_state_dict(torch.load(save_path), strict=False)\n", 316 | "model.eval()" 317 | ] 318 | }, 319 | { 320 | "cell_type": "markdown", 321 | "metadata": {}, 322 | "source": [ 323 | "\n", 324 | "\n", 325 | "### DataLoader" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 4, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "transform_query_list = [\n", 335 | " transforms.Resize((224,224), interpolation=3),\n", 336 | " transforms.RandomHorizontalFlip(),\n", 337 | " transforms.ToTensor(),\n", 338 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 339 | " ]\n", 340 | "transform_gallery_list = [\n", 341 | " transforms.Resize(size=(224,224),interpolation=3), #Image.BICUBIC\n", 342 | " transforms.ToTensor(),\n", 343 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 
0.224, 0.225])\n", 344 | " ]\n", 345 | "data_transforms = {\n", 346 | "'query': transforms.Compose( transform_query_list ),\n", 347 | "'gallery': transforms.Compose(transform_gallery_list),\n", 348 | "}" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": 5, 354 | "metadata": {}, 355 | "outputs": [ 356 | { 357 | "name": "stdout", 358 | "output_type": "stream", 359 | "text": [ 360 | "750\n" 361 | ] 362 | } 363 | ], 364 | "source": [ 365 | "image_datasets = {}\n", 366 | "data_dir = \"data/Market-Pytorch/Market/\"\n", 367 | "\n", 368 | "image_datasets['query'] = datasets.ImageFolder(os.path.join(data_dir, 'query'),\n", 369 | " data_transforms['query'])\n", 370 | "image_datasets['gallery'] = datasets.ImageFolder(os.path.join(data_dir, 'gallery'),\n", 371 | " data_transforms['gallery'])\n", 372 | "query_loader = DataLoader(dataset = image_datasets['query'], batch_size=batch_size, shuffle=False )\n", 373 | "gallery_loader = DataLoader(dataset = image_datasets['gallery'], batch_size=batch_size, shuffle=False)\n", 374 | "\n", 375 | "class_names = image_datasets['query'].classes\n", 376 | "print(len(class_names))" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": {}, 382 | "source": [ 383 | "### Extract Features" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 6, 389 | "metadata": {}, 390 | "outputs": [], 391 | "source": [ 392 | "activation = {}\n", 393 | "def get_activation(name):\n", 394 | " def hook(model, input, output):\n", 395 | " activation[name] = output.detach()\n", 396 | " return hook" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 7, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "def extract_feature(model,dataloaders):\n", 406 | " \n", 407 | " features = torch.FloatTensor()\n", 408 | " count = 0\n", 409 | " idx = 0\n", 410 | " for data in tqdm(dataloaders):\n", 411 | " img, label = data\n", 412 | " img, label = img.to(device), label.to(device)\n", 413 | "\n", 414 | " output = model(img)\n", 415 | "\n", 416 | " n, c, h, w = img.size()\n", 417 | " \n", 418 | " count += n\n", 419 | " features = torch.cat((features, output.detach().cpu()), 0)\n", 420 | " idx += 1\n", 421 | " return features" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 8, 427 | "metadata": { 428 | "scrolled": true 429 | }, 430 | "outputs": [ 431 | { 432 | "data": { 433 | "application/vnd.jupyter.widget-view+json": { 434 | "model_id": "febb2a07ac2f42178b9fdec40350e415", 435 | "version_major": 2, 436 | "version_minor": 0 437 | }, 438 | "text/plain": [ 439 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=421.0), HTML(value='')))" 440 | ] 441 | }, 442 | "metadata": {}, 443 | "output_type": "display_data" 444 | }, 445 | { 446 | "name": "stdout", 447 | "output_type": "stream", 448 | "text": [ 449 | "\n" 450 | ] 451 | }, 452 | { 453 | "data": { 454 | "application/vnd.jupyter.widget-view+json": { 455 | "model_id": "7a1342f8f990420e90d234818e474955", 456 | "version_major": 2, 457 | "version_minor": 0 458 | }, 459 | "text/plain": [ 460 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2467.0), HTML(value='')))" 461 | ] 462 | }, 463 | "metadata": {}, 464 | "output_type": "display_data" 465 | }, 466 | { 467 | "name": "stdout", 468 | "output_type": "stream", 469 | "text": [ 470 | "\n" 471 | ] 472 | } 473 | ], 474 | "source": [ 475 | "# Extract Query Features\n", 476 | "query_feature= extract_feature(model, query_loader)\n", 477 | "\n", 478 | "# Extract 
Gallery Features\n", 479 | "gallery_feature = extract_feature(model, gallery_loader)" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": 9, 485 | "metadata": {}, 486 | "outputs": [], 487 | "source": [ 488 | "# Retrieve labels\n", 489 | "gallery_path = image_datasets['gallery'].imgs\n", 490 | "query_path = image_datasets['query'].imgs\n", 491 | "\n", 492 | "gallery_cam,gallery_label = get_id(gallery_path)\n", 493 | "query_cam,query_label = get_id(query_path)" 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "metadata": {}, 499 | "source": [ 500 | "## Concat Averaged GELTs" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": 10, 506 | "metadata": {}, 507 | "outputs": [ 508 | { 509 | "data": { 510 | "application/vnd.jupyter.widget-view+json": { 511 | "model_id": "07fbd13be6e943e7ab65d3a7354c5d87", 512 | "version_major": 2, 513 | "version_minor": 0 514 | }, 515 | "text/plain": [ 516 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3368.0), HTML(value='')))" 517 | ] 518 | }, 519 | "metadata": {}, 520 | "output_type": "display_data" 521 | }, 522 | { 523 | "name": "stdout", 524 | "output_type": "stream", 525 | "text": [ 526 | "\n" 527 | ] 528 | }, 529 | { 530 | "data": { 531 | "application/vnd.jupyter.widget-view+json": { 532 | "model_id": "1afc1486a57544ec9a2365265fab4281", 533 | "version_major": 2, 534 | "version_minor": 0 535 | }, 536 | "text/plain": [ 537 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=19732.0), HTML(value='')))" 538 | ] 539 | }, 540 | "metadata": {}, 541 | "output_type": "display_data" 542 | }, 543 | { 544 | "name": "stdout", 545 | "output_type": "stream", 546 | "text": [ 547 | "\n" 548 | ] 549 | } 550 | ], 551 | "source": [ 552 | "concatenated_query_vectors = []\n", 553 | "for query in tqdm(query_feature):\n", 554 | " \n", 555 | " fnorm = torch.norm(query, p=2, dim=1, keepdim=True)*np.sqrt(14)\n", 556 | " \n", 557 | " query_norm = query.div(fnorm.expand_as(query))\n", 558 | " \n", 559 | " concatenated_query_vectors.append(query_norm.view((-1))) # 14*768 -> 10752\n", 560 | "\n", 561 | "concatenated_gallery_vectors = []\n", 562 | "for gallery in tqdm(gallery_feature):\n", 563 | " \n", 564 | " fnorm = torch.norm(gallery, p=2, dim=1, keepdim=True) *np.sqrt(14)\n", 565 | " \n", 566 | " gallery_norm = gallery.div(fnorm.expand_as(gallery))\n", 567 | " \n", 568 | " concatenated_gallery_vectors.append(gallery_norm.view((-1))) # 14*768 -> 10752\n", 569 | " " 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "metadata": {}, 575 | "source": [ 576 | "## Calculate Similarity using FAISS" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 11, 582 | "metadata": {}, 583 | "outputs": [], 584 | "source": [ 585 | "import faiss\n", 586 | "import numpy as np\n", 587 | "\n", 588 | "\n", 589 | "index = faiss.IndexIDMap(faiss.IndexFlatIP(10752))\n", 590 | "\n", 591 | "index.add_with_ids(np.array([t.numpy() for t in concatenated_gallery_vectors]),np.array(gallery_label))\n", 592 | "\n", 593 | "# xb = np.array([t.numpy() for t in concatenated_gallery_vectors]).astype(dtype=np.float32)\n", 594 | "# index = faiss.IndexFlatL2(10752) \n", 595 | "# ids = np.array(gallery_label, dtype=np.float32)\n", 596 | "# index2 = faiss.IndexIDMap(index)\n", 597 | "# index2.add_with_ids(xb, ids)\n", 598 | "\n", 599 | "\n", 600 | "def search(query: torch.Tensor, k=1):\n", 601 | " encoded_query = query.unsqueeze(dim=0).numpy()\n", 602 | " top_k = index.search(encoded_query, k)\n", 603 | " return
top_k" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": 12, 609 | "metadata": {}, 610 | "outputs": [ 611 | { 612 | "name": "stdout", 613 | "output_type": "stream", 614 | "text": [ 615 | "Rank1: 0.9833729216152018, Rank5: 0.9973277909738717, Rank10: 0.9982185273159145, mAP: 0.9279050887119389\n" 616 | ] 617 | } 618 | ], 619 | "source": [ 620 | "rank1_score = 0\n", 621 | "rank5_score = 0\n", 622 | "rank10_score = 0\n", 623 | "ap = 0\n", 624 | "count = 0\n", 625 | "for query, label in zip(concatenated_query_vectors, query_label):\n", 626 | " count += 1\n", 627 | " label = label\n", 628 | " output = search(query, k=10)\n", 629 | "# print(output)\n", 630 | " rank1_score += rank1(label, output) \n", 631 | " rank5_score += rank5(label, output) \n", 632 | " rank10_score += rank10(label, output) \n", 633 | " print(\"Correct: {}, Total: {}, Incorrect: {}\".format(rank1_score, count, count-rank1_score), end=\"\\r\")\n", 634 | " ap += calc_map(label, output)\n", 635 | "\n", 636 | "print(\"Rank1: {}, Rank5: {}, Rank10: {}, mAP: {}\".format(rank1_score/len(query_feature), \n", 637 | " rank5_score/len(query_feature), \n", 638 | " rank10_score/len(query_feature), ap/len(query_feature))) " 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": null, 644 | "metadata": {}, 645 | "outputs": [], 646 | "source": [] 647 | } 648 | ], 649 | "metadata": { 650 | "kernelspec": { 651 | "display_name": "Python 3", 652 | "language": "python", 653 | "name": "python3" 654 | }, 655 | "language_info": { 656 | "codemirror_mode": { 657 | "name": "ipython", 658 | "version": 3 659 | }, 660 | "file_extension": ".py", 661 | "mimetype": "text/x-python", 662 | "name": "python", 663 | "nbconvert_exporter": "python", 664 | "pygments_lexer": "ipython3", 665 | "version": "3.7.4" 666 | } 667 | }, 668 | "nbformat": 4, 669 | "nbformat_minor": 4 670 | } 671 | -------------------------------------------------------------------------------- /LA-Transformer Training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Import Libraries" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from __future__ import print_function\n", 17 | "\n", 18 | "import os\n", 19 | "import time\n", 20 | "import random\n", 21 | "import zipfile\n", 22 | "from itertools import chain\n", 23 | "\n", 24 | "import timm\n", 25 | "import numpy as np\n", 26 | "from PIL import Image\n", 27 | "from tqdm.notebook import tqdm\n", 28 | "from collections import OrderedDict\n", 29 | "\n", 30 | "import torch\n", 31 | "import torch.nn as nn\n", 32 | "from torch.nn import init\n", 33 | "import torch.optim as optim\n", 34 | "from torchvision import models\n", 35 | "import torch.nn.functional as F\n", 36 | "from torch.autograd import Variable\n", 37 | "from torch.optim.lr_scheduler import StepLR\n", 38 | "from torchvision import datasets, transforms\n", 39 | "from torch.utils.data import DataLoader, Dataset\n", 40 | "\n", 41 | "from LATransformer.model import ClassBlock, LATransformer\n", 42 | "from LATransformer.utils import save_network, update_summary\n", 43 | "\n", 44 | "os.environ['CUDA_VISIBLE_DEVICES']='1'\n", 45 | "device = \"cuda\"" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "### Set Config Parameters" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | 
"execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "batch_size = 32\n", 62 | "num_epochs = 30\n", 63 | "lr = 3e-4\n", 64 | "gamma = 0.7\n", 65 | "unfreeze_after=2\n", 66 | "lr_decay=.8\n", 67 | "lmbd = 8" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "## Load Data" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 3, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "transform_train_list = [\n", 84 | " transforms.Resize((224,224), interpolation=3),\n", 85 | " transforms.RandomHorizontalFlip(),\n", 86 | " transforms.ToTensor(),\n", 87 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 88 | " ]\n", 89 | "transform_val_list = [\n", 90 | " transforms.Resize(size=(224,224),interpolation=3), #Image.BICUBIC\n", 91 | " transforms.ToTensor(),\n", 92 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", 93 | " ]\n", 94 | "data_transforms = {\n", 95 | "'train': transforms.Compose( transform_train_list ),\n", 96 | "'val': transforms.Compose(transform_val_list),\n", 97 | "}" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 4, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "name": "stdout", 107 | "output_type": "stream", 108 | "text": [ 109 | "751\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "image_datasets = {}\n", 115 | "data_dir = \"data/Market-Pytorch/Market/\"\n", 116 | "\n", 117 | "image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),\n", 118 | " data_transforms['train'])\n", 119 | "image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),\n", 120 | " data_transforms['val'])\n", 121 | "train_loader = DataLoader(dataset = image_datasets['train'], batch_size=batch_size, shuffle=True )\n", 122 | "valid_loader = DataLoader(dataset = image_datasets['val'], batch_size=batch_size, shuffle=True)\n", 123 | "# dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,\n", 124 | "# shuffle=True, num_workers=8, pin_memory=True) # 8 workers may work faster\n", 125 | "# for x in ['train', 'val']}\n", 126 | "# dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n", 127 | "class_names = image_datasets['train'].classes\n", 128 | "print(len(class_names))" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "## Load Model" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 5, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "data": { 145 | "text/plain": [ 146 | "VisionTransformer(\n", 147 | " (patch_embed): PatchEmbed(\n", 148 | " (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n", 149 | " )\n", 150 | " (pos_drop): Dropout(p=0.0, inplace=False)\n", 151 | " (blocks): ModuleList(\n", 152 | " (0): Block(\n", 153 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 154 | " (attn): Attention(\n", 155 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 156 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 157 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 158 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 159 | " )\n", 160 | " (drop_path): Identity()\n", 161 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 162 | " (mlp): Mlp(\n", 163 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 164 | " (act): GELU()\n", 165 | " (fc2): 
Linear(in_features=3072, out_features=768, bias=True)\n", 166 | " (drop): Dropout(p=0.0, inplace=False)\n", 167 | " )\n", 168 | " )\n", 169 | " (1): Block(\n", 170 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 171 | " (attn): Attention(\n", 172 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 173 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 174 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 175 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 176 | " )\n", 177 | " (drop_path): Identity()\n", 178 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 179 | " (mlp): Mlp(\n", 180 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 181 | " (act): GELU()\n", 182 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 183 | " (drop): Dropout(p=0.0, inplace=False)\n", 184 | " )\n", 185 | " )\n", 186 | " (2): Block(\n", 187 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 188 | " (attn): Attention(\n", 189 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 190 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 191 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 192 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 193 | " )\n", 194 | " (drop_path): Identity()\n", 195 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 196 | " (mlp): Mlp(\n", 197 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 198 | " (act): GELU()\n", 199 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 200 | " (drop): Dropout(p=0.0, inplace=False)\n", 201 | " )\n", 202 | " )\n", 203 | " (3): Block(\n", 204 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 205 | " (attn): Attention(\n", 206 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 207 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 208 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 209 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 210 | " )\n", 211 | " (drop_path): Identity()\n", 212 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 213 | " (mlp): Mlp(\n", 214 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 215 | " (act): GELU()\n", 216 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 217 | " (drop): Dropout(p=0.0, inplace=False)\n", 218 | " )\n", 219 | " )\n", 220 | " (4): Block(\n", 221 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 222 | " (attn): Attention(\n", 223 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 224 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 225 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 226 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 227 | " )\n", 228 | " (drop_path): Identity()\n", 229 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 230 | " (mlp): Mlp(\n", 231 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 232 | " (act): GELU()\n", 233 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 234 | " (drop): Dropout(p=0.0, inplace=False)\n", 235 | " )\n", 236 | " )\n", 237 | " (5): Block(\n", 238 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 239 | " (attn): Attention(\n", 240 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 241 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 242 | " (proj): 
Linear(in_features=768, out_features=768, bias=True)\n", 243 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 244 | " )\n", 245 | " (drop_path): Identity()\n", 246 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 247 | " (mlp): Mlp(\n", 248 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 249 | " (act): GELU()\n", 250 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 251 | " (drop): Dropout(p=0.0, inplace=False)\n", 252 | " )\n", 253 | " )\n", 254 | " (6): Block(\n", 255 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 256 | " (attn): Attention(\n", 257 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 258 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 259 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 260 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 261 | " )\n", 262 | " (drop_path): Identity()\n", 263 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 264 | " (mlp): Mlp(\n", 265 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 266 | " (act): GELU()\n", 267 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 268 | " (drop): Dropout(p=0.0, inplace=False)\n", 269 | " )\n", 270 | " )\n", 271 | " (7): Block(\n", 272 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 273 | " (attn): Attention(\n", 274 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 275 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 276 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 277 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 278 | " )\n", 279 | " (drop_path): Identity()\n", 280 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 281 | " (mlp): Mlp(\n", 282 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 283 | " (act): GELU()\n", 284 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 285 | " (drop): Dropout(p=0.0, inplace=False)\n", 286 | " )\n", 287 | " )\n", 288 | " (8): Block(\n", 289 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 290 | " (attn): Attention(\n", 291 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 292 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 293 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 294 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 295 | " )\n", 296 | " (drop_path): Identity()\n", 297 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 298 | " (mlp): Mlp(\n", 299 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 300 | " (act): GELU()\n", 301 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 302 | " (drop): Dropout(p=0.0, inplace=False)\n", 303 | " )\n", 304 | " )\n", 305 | " (9): Block(\n", 306 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 307 | " (attn): Attention(\n", 308 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 309 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 310 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 311 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 312 | " )\n", 313 | " (drop_path): Identity()\n", 314 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 315 | " (mlp): Mlp(\n", 316 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 317 | " (act): GELU()\n", 318 | " (fc2): Linear(in_features=3072, 
out_features=768, bias=True)\n", 319 | " (drop): Dropout(p=0.0, inplace=False)\n", 320 | " )\n", 321 | " )\n", 322 | " (10): Block(\n", 323 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 324 | " (attn): Attention(\n", 325 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 326 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 327 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 328 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 329 | " )\n", 330 | " (drop_path): Identity()\n", 331 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 332 | " (mlp): Mlp(\n", 333 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 334 | " (act): GELU()\n", 335 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 336 | " (drop): Dropout(p=0.0, inplace=False)\n", 337 | " )\n", 338 | " )\n", 339 | " (11): Block(\n", 340 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 341 | " (attn): Attention(\n", 342 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 343 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 344 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 345 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 346 | " )\n", 347 | " (drop_path): Identity()\n", 348 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 349 | " (mlp): Mlp(\n", 350 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 351 | " (act): GELU()\n", 352 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 353 | " (drop): Dropout(p=0.0, inplace=False)\n", 354 | " )\n", 355 | " )\n", 356 | " )\n", 357 | " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 358 | " (head): Linear(in_features=768, out_features=751, bias=True)\n", 359 | ")" 360 | ] 361 | }, 362 | "execution_count": 5, 363 | "metadata": {}, 364 | "output_type": "execute_result" 365 | } 366 | ], 367 | "source": [ 368 | "# Load pre-trained ViT\n", 369 | "vit_base = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=751)\n", 370 | "vit_base= vit_base.to(device)\n", 371 | "vit_base.eval()" 372 | ] 373 | }, 374 | { 375 | "cell_type": "markdown", 376 | "metadata": {}, 377 | "source": [ 378 | "\n", 379 | "\n", 380 | "### Train" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 6, 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [ 389 | "class AverageMeter:\n", 390 | " \"\"\"Computes and stores the average and current value\"\"\"\n", 391 | " def __init__(self):\n", 392 | " self.reset()\n", 393 | "\n", 394 | " def reset(self):\n", 395 | " self.val = 0\n", 396 | " self.avg = 0\n", 397 | " self.sum = 0\n", 398 | " self.count = 0\n", 399 | "\n", 400 | " def update(self, val, n=1):\n", 401 | " self.val = val\n", 402 | " self.sum += val * n\n", 403 | " self.count += n\n", 404 | " self.avg = self.sum / self.count" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": 7, 410 | "metadata": {}, 411 | "outputs": [], 412 | "source": [ 413 | "def validate(model, loader, loss_fn):\n", 414 | " batch_time_m = AverageMeter()\n", 415 | " losses_m = AverageMeter()\n", 416 | " top1_m = AverageMeter()\n", 417 | " top5_m = AverageMeter()\n", 418 | "\n", 419 | " model.eval()\n", 420 | " epoch_accuracy = 0\n", 421 | " epoch_loss = 0\n", 422 | " end = time.time()\n", 423 | " last_idx = len(loader) - 1\n", 424 | " \n", 425 | " running_loss = 0.0\n", 426 | " running_corrects = 0.0\n", 427 | "\n", 
428 | " with torch.no_grad():\n", 429 | " for input, target in tqdm(loader):\n", 430 | "\n", 431 | " input, target = input.to(device), target.to(device)\n", 432 | " \n", 433 | " output = model(input)\n", 434 | " \n", 435 | " score = 0.0\n", 436 | " sm = nn.Softmax(dim=1)\n", 437 | " for k, v in output.items():\n", 438 | " score += sm(output[k])\n", 439 | " _, preds = torch.max(score.data, 1)\n", 440 | "\n", 441 | " loss = 0.0\n", 442 | " for k,v in output.items():\n", 443 | " loss += loss_fn(output[k], target)\n", 444 | "\n", 445 | "\n", 446 | " batch_time_m.update(time.time() - end)\n", 447 | " acc = (preds == target.data).float().mean()\n", 448 | " epoch_loss += loss/len(loader)\n", 449 | " epoch_accuracy += acc / len(loader)\n", 450 | " \n", 451 | " print(f\"Epoch : {epoch+1} - val_loss : {epoch_loss:.4f} - val_acc: {epoch_accuracy:.4f}\", end=\"\\r\")\n", 452 | " print() \n", 453 | " metrics = OrderedDict([('val_loss', epoch_loss.data.item()), (\"val_accuracy\", epoch_accuracy.data.item())])\n", 454 | "\n", 455 | "\n", 456 | " return metrics" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": 8, 462 | "metadata": {}, 463 | "outputs": [], 464 | "source": [ 465 | "def train_one_epoch(\n", 466 | " epoch, model, loader, optimizer, loss_fn,\n", 467 | " lr_scheduler=None, saver=None, output_dir='', \n", 468 | " loss_scaler=None, model_ema=None, mixup_fn=None):\n", 469 | "\n", 470 | " \n", 471 | "\n", 472 | " \n", 473 | " batch_time_m = AverageMeter()\n", 474 | " data_time_m = AverageMeter()\n", 475 | " losses_m = AverageMeter()\n", 476 | "\n", 477 | " model.train()\n", 478 | " epoch_accuracy = 0\n", 479 | " epoch_loss = 0\n", 480 | " end = time.time()\n", 481 | " last_idx = len(loader) - 1\n", 482 | " num_updates = epoch * len(loader)\n", 483 | " running_loss = 0.0\n", 484 | " running_corrects = 0.0\n", 485 | "\n", 486 | " for data, target in tqdm(loader):\n", 487 | " data, target = data.to(device), target.to(device)\n", 488 | "\n", 489 | " \n", 490 | " data_time_m.update(time.time() - end)\n", 491 | "\n", 492 | " optimizer.zero_grad()\n", 493 | " output = model(data)\n", 494 | " score = 0.0\n", 495 | " sm = nn.Softmax(dim=1)\n", 496 | " for k, v in output.items():\n", 497 | " score += sm(output[k])\n", 498 | " _, preds = torch.max(score.data, 1)\n", 499 | " \n", 500 | " loss = 0.0\n", 501 | " for k,v in output.items():\n", 502 | " loss += loss_fn(output[k], target)\n", 503 | " loss.backward()\n", 504 | "\n", 505 | " optimizer.step()\n", 506 | "\n", 507 | " batch_time_m.update(time.time() - end)\n", 508 | " \n", 509 | "# print(preds, target.data)\n", 510 | " acc = (preds == target.data).float().mean()\n", 511 | " \n", 512 | "# print(acc)\n", 513 | " epoch_loss += loss/len(loader)\n", 514 | " epoch_accuracy += acc / len(loader)\n", 515 | "# if acc:\n", 516 | "# print(acc, epreds, target.data)\n", 517 | " print(\n", 518 | " f\"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f}\"\n", 519 | ", end=\"\\r\")\n", 520 | "\n", 521 | " print()\n", 522 | "\n", 523 | " return OrderedDict([('train_loss', epoch_loss.data.item()), (\"train_accuracy\", epoch_accuracy.data.item())])\n" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": 9, 529 | "metadata": {}, 530 | "outputs": [], 531 | "source": [ 532 | "def freeze_all_blocks(model):\n", 533 | " frozen_blocks = 12\n", 534 | " for block in model.model.blocks[:frozen_blocks]:\n", 535 | " for param in block.parameters():\n", 536 | " param.requires_grad=False\n", 537 | " " 538 | ] 539 | }, 
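Aside on the freezing arithmetic: freeze_all_blocks disables gradients for all 12 transformer blocks but leaves the patch embedding, cls/pos embeddings, final norm, ViT head, and the 14 ClassBlocks trainable (6,787,073 parameters in total). Each ViT-base block holds 7,087,872 parameters, and the unfreeze_blocks helper in the next cell slices model.model.blocks[11-amount:], so amount=1 re-enables the last two blocks: 6,787,073 + 2 × 7,087,872 = 20,962,817, exactly the "Trainable Params" figure printed in the training log further down. A minimal sketch of the same pattern, using a bare timm ViT as a stand-in (the notebook's model reaches the blocks via model.model.blocks):

import timm

# Stand-in model: a bare ViT-base rather than the wrapped LATransformer.
vit = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=751)

# Freeze every block, as freeze_all_blocks does.
for block in vit.blocks:
    for param in block.parameters():
        param.requires_grad = False

# Re-enable the trailing blocks; blocks[11 - amount:] with amount=1
# re-enables blocks 10 and 11, i.e. two blocks rather than one.
for block in vit.blocks[10:]:
    for param in block.parameters():
        param.requires_grad = True

# Each ViT-base block: qkv + proj + two-layer MLP + two LayerNorms.
print(sum(p.numel() for p in vit.blocks[-1].parameters()))  # 7087872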
540 | { 541 | "cell_type": "code", 542 | "execution_count": 10, 543 | "metadata": {}, 544 | "outputs": [], 545 | "source": [ 546 | "def unfreeze_blocks(model, amount= 1):\n", 547 | " \n", 548 | " for block in model.model.blocks[11-amount:]:\n", 549 | " for param in block.parameters():\n", 550 | " param.requires_grad=True\n", 551 | " return model" 552 | ] 553 | }, 554 | { 555 | "cell_type": "markdown", 556 | "metadata": {}, 557 | "source": [ 558 | "## Training Loop" 559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "execution_count": 11, 564 | "metadata": { 565 | "scrolled": true 566 | }, 567 | "outputs": [ 568 | { 569 | "name": "stdout", 570 | "output_type": "stream", 571 | "text": [ 572 | "LATransformer(\n", 573 | " (model): VisionTransformer(\n", 574 | " (patch_embed): PatchEmbed(\n", 575 | " (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n", 576 | " )\n", 577 | " (pos_drop): Dropout(p=0.0, inplace=False)\n", 578 | " (blocks): ModuleList(\n", 579 | " (0): Block(\n", 580 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 581 | " (attn): Attention(\n", 582 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 583 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 584 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 585 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 586 | " )\n", 587 | " (drop_path): Identity()\n", 588 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 589 | " (mlp): Mlp(\n", 590 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 591 | " (act): GELU()\n", 592 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 593 | " (drop): Dropout(p=0.0, inplace=False)\n", 594 | " )\n", 595 | " )\n", 596 | " (1): Block(\n", 597 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 598 | " (attn): Attention(\n", 599 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 600 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 601 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 602 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 603 | " )\n", 604 | " (drop_path): Identity()\n", 605 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 606 | " (mlp): Mlp(\n", 607 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 608 | " (act): GELU()\n", 609 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 610 | " (drop): Dropout(p=0.0, inplace=False)\n", 611 | " )\n", 612 | " )\n", 613 | " (2): Block(\n", 614 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 615 | " (attn): Attention(\n", 616 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 617 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 618 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 619 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 620 | " )\n", 621 | " (drop_path): Identity()\n", 622 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 623 | " (mlp): Mlp(\n", 624 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 625 | " (act): GELU()\n", 626 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 627 | " (drop): Dropout(p=0.0, inplace=False)\n", 628 | " )\n", 629 | " )\n", 630 | " (3): Block(\n", 631 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 632 | " (attn): Attention(\n", 633 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 634 | " 
(attn_drop): Dropout(p=0.0, inplace=False)\n", 635 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 636 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 637 | " )\n", 638 | " (drop_path): Identity()\n", 639 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 640 | " (mlp): Mlp(\n", 641 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 642 | " (act): GELU()\n", 643 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 644 | " (drop): Dropout(p=0.0, inplace=False)\n", 645 | " )\n", 646 | " )\n", 647 | " (4): Block(\n", 648 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 649 | " (attn): Attention(\n", 650 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 651 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 652 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 653 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 654 | " )\n", 655 | " (drop_path): Identity()\n", 656 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 657 | " (mlp): Mlp(\n", 658 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 659 | " (act): GELU()\n", 660 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 661 | " (drop): Dropout(p=0.0, inplace=False)\n", 662 | " )\n", 663 | " )\n", 664 | " (5): Block(\n", 665 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 666 | " (attn): Attention(\n", 667 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 668 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 669 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 670 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 671 | " )\n", 672 | " (drop_path): Identity()\n", 673 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 674 | " (mlp): Mlp(\n", 675 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 676 | " (act): GELU()\n", 677 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 678 | " (drop): Dropout(p=0.0, inplace=False)\n", 679 | " )\n", 680 | " )\n", 681 | " (6): Block(\n", 682 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 683 | " (attn): Attention(\n", 684 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 685 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 686 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 687 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 688 | " )\n", 689 | " (drop_path): Identity()\n", 690 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 691 | " (mlp): Mlp(\n", 692 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 693 | " (act): GELU()\n", 694 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 695 | " (drop): Dropout(p=0.0, inplace=False)\n", 696 | " )\n", 697 | " )\n", 698 | " (7): Block(\n", 699 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 700 | " (attn): Attention(\n", 701 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 702 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 703 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 704 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 705 | " )\n", 706 | " (drop_path): Identity()\n", 707 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 708 | " (mlp): Mlp(\n", 709 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 710 | " (act): 
GELU()\n", 711 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 712 | " (drop): Dropout(p=0.0, inplace=False)\n", 713 | " )\n", 714 | " )\n", 715 | " (8): Block(\n", 716 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 717 | " (attn): Attention(\n", 718 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 719 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 720 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 721 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 722 | " )\n", 723 | " (drop_path): Identity()\n", 724 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 725 | " (mlp): Mlp(\n", 726 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 727 | " (act): GELU()\n", 728 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 729 | " (drop): Dropout(p=0.0, inplace=False)\n", 730 | " )\n", 731 | " )\n", 732 | " (9): Block(\n", 733 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 734 | " (attn): Attention(\n", 735 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 736 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 737 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 738 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 739 | " )\n", 740 | " (drop_path): Identity()\n", 741 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 742 | " (mlp): Mlp(\n", 743 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 744 | " (act): GELU()\n", 745 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 746 | " (drop): Dropout(p=0.0, inplace=False)\n", 747 | " )\n", 748 | " )\n", 749 | " (10): Block(\n", 750 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 751 | " (attn): Attention(\n", 752 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 753 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 754 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 755 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 756 | " )\n", 757 | " (drop_path): Identity()\n", 758 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 759 | " (mlp): Mlp(\n", 760 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 761 | " (act): GELU()\n", 762 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 763 | " (drop): Dropout(p=0.0, inplace=False)\n", 764 | " )\n", 765 | " )\n", 766 | " (11): Block(\n", 767 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 768 | " (attn): Attention(\n", 769 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n", 770 | " (attn_drop): Dropout(p=0.0, inplace=False)\n", 771 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n", 772 | " (proj_drop): Dropout(p=0.0, inplace=False)\n", 773 | " )\n", 774 | " (drop_path): Identity()\n", 775 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 776 | " (mlp): Mlp(\n", 777 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", 778 | " (act): GELU()\n", 779 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", 780 | " (drop): Dropout(p=0.0, inplace=False)\n", 781 | " )\n", 782 | " )\n", 783 | " )\n", 784 | " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n", 785 | " (head): Linear(in_features=768, out_features=751, bias=True)\n", 786 | " )\n", 787 | " (avgpool): AdaptiveAvgPool2d(output_size=(14, 768))\n", 788 | " 
(dropout): Dropout(p=0.5, inplace=False)\n", 789 | " (classifier0): ClassBlock(\n", 790 | " (add_block): Sequential(\n", 791 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 792 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 793 | " (2): Dropout(p=0.5, inplace=False)\n", 794 | " )\n", 795 | " (classifier): Sequential(\n", 796 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 797 | " )\n", 798 | " )\n", 799 | " (classifier1): ClassBlock(\n", 800 | " (add_block): Sequential(\n", 801 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 802 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 803 | " (2): Dropout(p=0.5, inplace=False)\n", 804 | " )\n", 805 | " (classifier): Sequential(\n", 806 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 807 | " )\n", 808 | " )\n", 809 | " (classifier2): ClassBlock(\n", 810 | " (add_block): Sequential(\n", 811 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 812 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 813 | " (2): Dropout(p=0.5, inplace=False)\n", 814 | " )\n", 815 | " (classifier): Sequential(\n", 816 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 817 | " )\n", 818 | " )\n", 819 | " (classifier3): ClassBlock(\n", 820 | " (add_block): Sequential(\n", 821 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 822 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 823 | " (2): Dropout(p=0.5, inplace=False)\n", 824 | " )\n", 825 | " (classifier): Sequential(\n", 826 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 827 | " )\n", 828 | " )\n", 829 | " (classifier4): ClassBlock(\n", 830 | " (add_block): Sequential(\n", 831 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 832 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 833 | " (2): Dropout(p=0.5, inplace=False)\n", 834 | " )\n", 835 | " (classifier): Sequential(\n", 836 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 837 | " )\n", 838 | " )\n", 839 | " (classifier5): ClassBlock(\n", 840 | " (add_block): Sequential(\n", 841 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 842 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 843 | " (2): Dropout(p=0.5, inplace=False)\n", 844 | " )\n", 845 | " (classifier): Sequential(\n", 846 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 847 | " )\n", 848 | " )\n", 849 | " (classifier6): ClassBlock(\n", 850 | " (add_block): Sequential(\n", 851 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 852 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 853 | " (2): Dropout(p=0.5, inplace=False)\n", 854 | " )\n", 855 | " (classifier): Sequential(\n", 856 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 857 | " )\n", 858 | " )\n", 859 | " (classifier7): ClassBlock(\n", 860 | " (add_block): Sequential(\n", 861 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 862 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 863 | " (2): Dropout(p=0.5, inplace=False)\n", 864 | " )\n", 865 | " (classifier): Sequential(\n", 866 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 867 | " 
)\n", 868 | " )\n", 869 | " (classifier8): ClassBlock(\n", 870 | " (add_block): Sequential(\n", 871 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 872 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 873 | " (2): Dropout(p=0.5, inplace=False)\n", 874 | " )\n", 875 | " (classifier): Sequential(\n", 876 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 877 | " )\n", 878 | " )\n", 879 | " (classifier9): ClassBlock(\n", 880 | " (add_block): Sequential(\n", 881 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 882 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 883 | " (2): Dropout(p=0.5, inplace=False)\n", 884 | " )\n", 885 | " (classifier): Sequential(\n", 886 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 887 | " )\n", 888 | " )\n", 889 | " (classifier10): ClassBlock(\n", 890 | " (add_block): Sequential(\n", 891 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 892 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 893 | " (2): Dropout(p=0.5, inplace=False)\n", 894 | " )\n", 895 | " (classifier): Sequential(\n", 896 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 897 | " )\n", 898 | " )\n", 899 | " (classifier11): ClassBlock(\n", 900 | " (add_block): Sequential(\n", 901 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 902 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 903 | " (2): Dropout(p=0.5, inplace=False)\n", 904 | " )\n", 905 | " (classifier): Sequential(\n", 906 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 907 | " )\n", 908 | " )\n", 909 | " (classifier12): ClassBlock(\n", 910 | " (add_block): Sequential(\n", 911 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 912 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 913 | " (2): Dropout(p=0.5, inplace=False)\n", 914 | " )\n", 915 | " (classifier): Sequential(\n", 916 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 917 | " )\n", 918 | " )\n", 919 | " (classifier13): ClassBlock(\n", 920 | " (add_block): Sequential(\n", 921 | " (0): Linear(in_features=768, out_features=256, bias=True)\n", 922 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 923 | " (2): Dropout(p=0.5, inplace=False)\n", 924 | " )\n", 925 | " (classifier): Sequential(\n", 926 | " (0): Linear(in_features=256, out_features=751, bias=True)\n", 927 | " )\n", 928 | " )\n", 929 | ")\n" 930 | ] 931 | } 932 | ], 933 | "source": [ 934 | "# Create LA Transformer\n", 935 | "model = LATransformer(vit_base, lmbd).to(device)\n", 936 | "print(model.eval())\n", 937 | "\n", 938 | "# loss function\n", 939 | "criterion = nn.CrossEntropyLoss()\n", 940 | "\n", 941 | "# optimizer\n", 942 | "optimizer = optim.Adam(model.parameters(),weight_decay=5e-4, lr=lr)\n", 943 | "\n", 944 | "# scheduler\n", 945 | "scheduler = StepLR(optimizer, step_size=1, gamma=gamma)\n", 946 | "freeze_all_blocks(model)" 947 | ] 948 | }, 949 | { 950 | "cell_type": "code", 951 | "execution_count": null, 952 | "metadata": { 953 | "scrolled": true 954 | }, 955 | "outputs": [ 956 | { 957 | "name": "stdout", 958 | "output_type": "stream", 959 | "text": [ 960 | "training...\n", 961 | "Unfrozen Blocks: 1, Current lr: 0.00023999999999999998, Trainable Params: 20962817\n" 962 | ] 963 | }, 964 
| {
    "name": "stdout",
    "output_type": "stream",
    "text": [
     "Epoch : 1 - loss : 82.7351 - acc: 0.0880 - val_loss : 77.1901 - val_acc: 0.0497 - SAVED!\n",
     "Epoch : 2 - loss : 59.0334 - acc: 0.2364 - val_loss : 58.8111 - val_acc: 0.1918 - SAVED!\n",
     "Unfrozen Blocks: 2, Current lr: 0.000192, Trainable Params: 28050689\n",
     "Epoch : 3 - loss : 41.1694 - acc: 0.4632 - val_loss : 47.2650 - val_acc: 0.3353 - SAVED!\n",
     "Epoch : 4 - loss : 28.3517 - acc: 0.6674 - val_loss : 33.9487 - val_acc: 0.5391 - SAVED!\n",
     "Unfrozen Blocks: 3, Current lr: 0.00015360000000000002, Trainable Params: 35138561\n",
     "Epoch : 5 - loss : 18.7140 - acc: 0.8141 - val_loss : 25.3060 - val_acc: 0.6617 - SAVED!\n",
     "Epoch : 6 - loss : 12.2253 - acc: 0.9050 - val_loss : 19.0367 - val_acc: 0.7506 - SAVED!\n",
     "Unfrozen Blocks: 4, Current lr: 0.00012288000000000002, Trainable Params: 42226433\n",
     "Epoch : 7 - loss : 8.0031 - acc: 0.9542 - val_loss : 14.0309 - val_acc: 0.8325 - SAVED!\n",
     "Epoch : 8 - loss : 5.4122 - acc: 0.9771 - val_loss : 11.0224 - val_acc: 0.8602 - SAVED!\n",
     "Unfrozen Blocks: 5, Current lr: 9.830400000000001e-05, Trainable Params: 49314305\n",
     "Epoch : 9 - loss : 3.7149 - acc: 0.9906 - val_loss : 8.5832 - val_acc: 0.8944 - SAVED!\n",
     "Epoch : 10 - loss : 2.7142 - acc: 0.9950 - val_loss : 7.6481 - val_acc: 0.9033 - SAVED!\n",
     "Unfrozen Blocks: 6, Current lr: 7.864320000000001e-05, Trainable Params: 56402177\n",
     "Epoch : 11 - loss : 2.0092 - acc: 0.9965 - val_loss : 6.7372 - val_acc: 0.9137 - SAVED!\n",
     "Epoch : 12 - loss : 1.5912 - acc: 0.9977 - val_loss : 6.0404 - val_acc: 0.9189 - SAVED!\n",
     "Unfrozen Blocks: 7, Current lr: 6.291456000000001e-05, Trainable Params: 63490049\n",
     "Epoch : 13 - loss : 1.3100 - acc: 0.9984 - val_loss : 5.8097 - val_acc: 0.9230 - SAVED!\n",
     "Epoch : 14 - loss : 1.0894 - acc: 0.9991 - val_loss : 5.1302 - val_acc: 0.9321 - SAVED!\n",
     "Unfrozen Blocks: 8, Current lr: 5.0331648000000016e-05, Trainable Params: 70577921\n",
     "Epoch : 15 - loss : 0.9347 - acc: 0.9992 - val_loss : 5.5233 - val_acc: 0.9217\n",
     "Epoch : 16 - loss : 0.9086 - acc: 0.9996 - val_loss : 4.4655 - val_acc: 0.9362 - SAVED!\n",
     "Unfrozen Blocks: 9, Current lr: 4.026531840000002e-05, Trainable Params: 77665793\n",
     "Epoch : 17 - loss : 0.7159 - acc: 0.9999 - val_loss : 4.2927 - val_acc: 0.9414 - SAVED!\n",
     "Epoch : 18 - loss : 0.6362 - acc: 0.9998 - val_loss : 4.2925 - val_acc: 0.9453 - SAVED!\n",
     "Unfrozen Blocks: 10, Current lr: 3.221225472000002e-05, Trainable Params: 84753665\n",
     "Epoch : 19 - loss : 0.6389 - acc: 0.9997 - val_loss : 4.5622 - val_acc: 0.9319\n",
     "Epoch : 20 - loss : 0.5667 - acc: 0.9998 - val_loss : 4.6590 - val_acc: 0.9254\n",
     "Unfrozen Blocks: 11, Current lr: 2.5769803776000016e-05, Trainable Params: 91841537\n",
     "Epoch : 21 - loss : 0.5401 - acc: 0.9998 - val_loss : 3.8805 - val_acc: 0.9401\n",
     "Epoch : 22 - loss : 0.6303 - acc: 0.9991 - val_loss : 4.4941 - val_acc: 0.9375\n",
     "Unfrozen Blocks: 12, Current lr: 2.0615843020800013e-05, Trainable Params: 91841537\n",
     "Epoch : 23 - loss : 0.5186 - acc: 0.9997 - val_loss : 4.0348 - val_acc: 0.9435\n",
     "Epoch : 24 - loss : 0.4421 - acc: 0.9999 - val_loss : 3.6783 - val_acc: 0.9464 - SAVED!\n",
     "Unfrozen Blocks: 13, Current lr: 1.649267441664001e-05, Trainable Params: 91841537\n",
     "Epoch : 25 - loss : 0.4184 - acc: 1.0000 - val_loss : 3.9668 - val_acc: 0.9425\n",
     "Epoch : 26 - loss : 0.4113 - acc: 1.0000 - val_loss : 3.9590 - val_acc: 0.9398\n",
     "Unfrozen Blocks: 14, Current lr: 1.319413953331201e-05, Trainable Params: 91841537\n",
     "Epoch : 27 - loss : 0.3976 - acc: 1.0000 - val_loss : 3.8370 - val_acc: 0.9414\n",
     "Epoch : 28 - loss : 0.3917 - acc: 1.0000 - val_loss : 3.8097 - val_acc: 0.9422\n",
     "Unfrozen Blocks: 15, Current lr: 1.0555311626649608e-05, Trainable Params: 91841537\n",
     "Epoch : 29 - loss : 0.3875 - acc: 1.0000\n",
     "Epoch : 29 - val_loss : 0.5044 - val_acc: 0.1576\r"
    ]
   }
  ],
  "source": [
   "best_acc = 0.0\n",
   "y_loss = {} # loss history\n",
   "y_loss['train'] = []\n",
   "y_loss['val'] = []\n",
   "y_err = {}\n",
   "y_err['train'] = []\n",
   "y_err['val'] = []\n",
   "print(\"training...\")\n",
   "output_dir = \"\"\n",
   "best_acc = 0\n",
   "name = \"la_with_lmbd_{}\".format(lmbd)\n",
   "\n",
   "try:\n",
   "    os.mkdir(\"model/\" + name)\n",
   "\n",
   "except:\n",
   "    pass\n",
   "output_dir = \"model/\" + name\n",
   "unfrozen_blocks = 0\n",
   "\n",
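   "# Progressive unfreezing: every `unfreeze_after` epochs one more ViT block\n",
   "# is made trainable and the learning rate is decayed by `lr_decay`.\n",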
| "for epoch in range(num_epochs):\n", 2294 | "\n", 2295 | " if epoch%unfreeze_after==0:\n", 2296 | " unfrozen_blocks += 1\n", 2297 | " model = unfreeze_blocks(model, unfrozen_blocks)\n", 2298 | " optimizer.param_groups[0]['lr'] *= lr_decay \n", 2299 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 2300 | " print(\"Unfrozen Blocks: {}, Current lr: {}, Trainable Params: {}\".format(unfrozen_blocks, \n", 2301 | " optimizer.param_groups[0]['lr'], \n", 2302 | " trainable_params))\n", 2303 | "\n", 2304 | " train_metrics = train_one_epoch(\n", 2305 | " epoch, model, train_loader, optimizer, criterion,\n", 2306 | " lr_scheduler=None, saver=None)\n", 2307 | "\n", 2308 | " eval_metrics = validate(model, valid_loader, criterion)\n", 2309 | "\n", 2310 | "\n", 2311 | " # update summary\n", 2312 | " update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),\n", 2313 | " write_header=True)\n", 2314 | "\n", 2315 | " # deep copy the model\n", 2316 | " last_model_wts = model.state_dict()\n", 2317 | " if eval_metrics['val_accuracy'] > best_acc:\n", 2318 | " best_acc = eval_metrics['val_accuracy']\n", 2319 | " save_network(model, epoch,name)\n", 2320 | " print(\"SAVED!\")" 2321 | ] 2322 | }, 2323 | { 2324 | "cell_type": "code", 2325 | "execution_count": null, 2326 | "metadata": {}, 2327 | "outputs": [], 2328 | "source": [] 2329 | }, 2330 | { 2331 | "cell_type": "code", 2332 | "execution_count": null, 2333 | "metadata": {}, 2334 | "outputs": [], 2335 | "source": [] 2336 | } 2337 | ], 2338 | "metadata": { 2339 | "kernelspec": { 2340 | "display_name": "Python 3", 2341 | "language": "python", 2342 | "name": "python3" 2343 | }, 2344 | "language_info": { 2345 | "codemirror_mode": { 2346 | "name": "ipython", 2347 | "version": 3 2348 | }, 2349 | "file_extension": ".py", 2350 | "mimetype": "text/x-python", 2351 | "name": "python", 2352 | "nbconvert_exporter": "python", 2353 | "pygments_lexer": "ipython3", 2354 | "version": "3.7.4" 2355 | } 2356 | }, 2357 | "nbformat": 4, 2358 | "nbformat_minor": 4 2359 | } 2360 | -------------------------------------------------------------------------------- /LATransformer/metrics.py: -------------------------------------------------------------------------------- 1 | def rank1(label, output): 2 | if label==output[1][0][0]: 3 | return True 4 | return False 5 | 6 | def rank5(label, output): 7 | if label in output[1][0][:5]: 8 | return True 9 | return False 10 | 11 | def rank10(label, output): 12 | if label in output[1][0][:10]: 13 | return True 14 | return False 15 | 16 | def calc_map(label, output): 17 | count = 0 18 | score = 0 19 | good = 0 20 | for out in output[1][0]: 21 | count += 1 22 | if out==label: 23 | good += 1 24 | score += (good/count) 25 | if good==0: 26 | return 0 27 | return score/good -------------------------------------------------------------------------------- /LATransformer/model.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import numpy as np 3 | import pandas as pd 4 | from PIL import Image 5 | from tqdm.notebook import tqdm 6 | import matplotlib.pyplot as plt 7 | from collections import OrderedDict 8 | from sklearn.model_selection import train_test_split 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import init 13 | import torch.optim as optim 14 | from torchvision import models 15 | import torch.nn.functional as F 16 | from torch.autograd import Variable 17 | from torch.optim.lr_scheduler import StepLR 18 
--------------------------------------------------------------------------------
/LATransformer/model.py:
--------------------------------------------------------------------------------
import timm
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from collections import OrderedDict
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset


# weight initialization helpers
def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')  # for old PyTorch, use kaiming_normal
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
        init.constant_(m.bias.data, 0.0)
    elif classname.find('BatchNorm1d') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)

def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        init.normal_(m.weight.data, std=0.001)
        init.constant_(m.bias.data, 0.0)

class ClassBlock(nn.Module):
    """Bottleneck (linear -> batch norm -> dropout) followed by a linear classifier."""
    def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True, return_f=False):
        super(ClassBlock, self).__init__()
        self.return_f = return_f
        add_block = []
        if linear:
            add_block += [nn.Linear(input_dim, num_bottleneck)]
        else:
            num_bottleneck = input_dim
        if bnorm:
            add_block += [nn.BatchNorm1d(num_bottleneck)]
        if relu:
            add_block += [nn.LeakyReLU(0.1)]
        if droprate > 0:
            add_block += [nn.Dropout(p=droprate)]
        add_block = nn.Sequential(*add_block)
        add_block.apply(weights_init_kaiming)

        classifier = []
        classifier += [nn.Linear(num_bottleneck, class_num)]
        classifier = nn.Sequential(*classifier)
        classifier.apply(weights_init_classifier)

        self.add_block = add_block
        self.classifier = classifier

    def forward(self, x):
        x = self.add_block(x)
        if self.return_f:
            f = x
            x = self.classifier(x)
            return [x, f]
        else:
            x = self.classifier(x)
            return x

class LATransformer(nn.Module):
    def __init__(self, model, lmbd):
        super(LATransformer, self).__init__()

        self.class_num = 751
        self.part = 14  # the 14x14 grid of patch tokens is pooled into sqrt(196) = 14 parts
        self.num_blocks = 12
        self.model = model
        self.model.head.requires_grad_(False)  # freeze the ViT classification head
        self.cls_token = self.model.cls_token
        self.pos_embed = self.model.pos_embed
        self.avgpool = nn.AdaptiveAvgPool2d((self.part, 768))
        self.dropout = nn.Dropout(p=0.5)
        self.lmbd = lmbd
        # one locally aware classifier per part
        for i in range(self.part):
            name = 'classifier' + str(i)
            setattr(self, name, ClassBlock(768, self.class_num, droprate=0.5, relu=False, bnorm=True, num_bottleneck=256))

    def forward(self, x):

        # Divide input image into patch embeddings and add position embeddings
        x = self.model.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.model.pos_drop(x + self.pos_embed)

        # Feed forward through transformer blocks
        for i in range(self.num_blocks):
            x = self.model.blocks[i](x)
        x = self.model.norm(x)

        # extract the cls token
        cls_token_out = x[:, 0].unsqueeze(1)

        # Average-pool the 196 patch tokens into self.part local tokens
        x = self.avgpool(x[:, 1:])

        # Blend the global cls token into each local token:
        # (cls + lmbd * local) / (1 + lmbd)
        for i in range(self.part):
            out = torch.mul(x[:, i, :], self.lmbd)
            x[:, i, :] = torch.div(torch.add(cls_token_out.squeeze(1), out), 1 + self.lmbd)

        # Locally aware network: each part goes through its own classifier
        part = {}
        predict = {}
        for i in range(self.part):
            part[i] = x[:, i, :]
            name = 'classifier' + str(i)
            c = getattr(self, name)
            predict[i] = c(part[i])
        return predict
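
# Instantiation sketch (the `lmbd` value here is illustrative; the Training
# notebook builds the ViT backbone with timm and picks the real hyper-parameters):
#
#   import timm, torch
#   vit_base = timm.create_model('vit_base_patch16_224',
#                                pretrained=True, num_classes=751)
#   model = LATransformer(vit_base, lmbd=8.0)
#   logits = model(torch.randn(2, 3, 224, 224))  # dict of 14 part logit tensors, each (2, 751)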
class LATransformerTest(nn.Module):
    """Inference-time variant: returns the 14 local part embeddings instead of logits."""
    def __init__(self, model, lmbd):
        super(LATransformerTest, self).__init__()

        self.class_num = 751
        self.part = 14  # the 14x14 grid of patch tokens is pooled into sqrt(196) = 14 parts
        self.num_blocks = 12
        self.model = model
        self.model.head.requires_grad_(False)  # freeze the ViT classification head
        self.cls_token = self.model.cls_token
        self.pos_embed = self.model.pos_embed
        self.avgpool = nn.AdaptiveAvgPool2d((self.part, 768))
        self.dropout = nn.Dropout(p=0.5)
        self.lmbd = lmbd
        # the per-part classifiers from LATransformer are not needed at test time

    def forward(self, x):

        # Divide input image into patch embeddings and add position embeddings
        x = self.model.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.model.pos_drop(x + self.pos_embed)

        # Feed forward through transformer blocks
        for i in range(self.num_blocks):
            x = self.model.blocks[i](x)
        x = self.model.norm(x)

        # extract the cls token
        cls_token_out = x[:, 0].unsqueeze(1)

        # Average-pool the patch tokens into self.part local embeddings;
        # the cls/local blending used during training is skipped here.
        x = self.avgpool(x[:, 1:])

        return x.cpu()
--------------------------------------------------------------------------------
/LATransformer/utils.py:
--------------------------------------------------------------------------------
import os
import csv
import torch
from collections import OrderedDict

def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
    rowd = OrderedDict(epoch=epoch)
    rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
    with open(filename, mode='a') as cf:
        dw = csv.DictWriter(cf, fieldnames=rowd.keys())
        if write_header:  # first iteration (epoch == 1 can't be used)
            dw.writeheader()
        dw.writerow(rowd)

def save_network(network, epoch_label, name):
    # Always overwrites ./model/{name}/net_best.pth with the current best
    # weights; epoch_label is kept for API compatibility but unused.
    save_filename = 'net_best.pth'
    save_path = os.path.join('./model', name, save_filename)
    torch.save(network.cpu().state_dict(), save_path)

    if torch.cuda.is_available():
        network.cuda()

def get_id(img_path):
    # Parse Market-1501-style file names, e.g. 0002_c1s1_000451_03.jpg:
    # the first four characters are the person label ("-1" marks junk /
    # distractor images) and the digit after 'c' is the camera id.
    camera_id = []
    labels = []
    for path, v in img_path:
        filename = os.path.basename(path)
        label = filename[0:4]
        camera = filename.split('c')[1]
        if label[0:2] == '-1':
            labels.append(-1)
        else:
            labels.append(int(label))
        camera_id.append(int(camera[0]))
    return camera_id, labels
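
# Example with illustrative Market-1501-style paths:
#
#   imgs = [('data/query/0002_c1s1_000451_03.jpg', 0),
#           ('data/query/-1_c3s2_000100_00.jpg', 0)]
#   camera_id, labels = get_id(imgs)   # -> [1, 3], [2, -1]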
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Siddhant Kapil

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
# Person Re-Identification with a Locally Aware Transformer


This code is inspired by:


1) PCB - https://github.com/layumi/Person_reID_baseline_pytorch
2) ViT - https://github.com/lucidrains/vit-pytorch/tree/main/examples
3) Pre-trained models: https://github.com/rwightman/pytorch-image-models

## Release 7/5/21
Demonstrates the working and performance of LA-Transformer using two Jupyter notebooks.

1) LA-Transformer Training: demonstrates the training process. Cell outputs are included in the notebook, and the last cell shows the training results; the same numbers are written to model/{name}/summary.csv if the cell outputs are unclear. To run the notebook, install the requirements, download the dataset using the link provided, and extract it into the data folder.

2) LA-Transformer Testing: demonstrates the testing process. Either download the pre-trained weights from the Google Drive link below and extract them into the model/{name} folder, or train LA-Transformer yourself using the Training notebook. Performance metrics are reported in the last cell of the notebook.

## Requirements:

- torch==1.8.1 & torchvision==0.8.2: [Link](https://pytorch.org/)
- timm==0.3.2: [Link](https://github.com/rwightman/pytorch-image-models)
- faiss==1.6.3: [Link](https://github.com/facebookresearch/faiss)
- tqdm==4.54.0
- numpy==1.19.5

## Read-Only Versions:
LA-Transformer Training.html and LA-Transformer Testing.html are read-only versions containing cell outputs, so the working of LA-Transformer can be verified quickly.

## Google Drive:

Pre-trained weights and the dataset can be found on [this](https://drive.google.com/drive/folders/1CRkfn9iLEItaYur1WGf2abvpd2vT7nRB?usp=sharing) Google Drive. To remain anonymous, we created a temporary Gmail account to host the weights and dataset; it will be changed to an official account later.
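
## Feature Extraction Sketch

A minimal sketch (the checkpoint path and `lmbd` value are placeholders) of loading trained weights into `LATransformerTest` and extracting the 14 local part embeddings; `strict=False` is used because the per-part classifiers saved during training have no counterpart at test time:

```python
import timm
import torch

from LATransformer.model import LATransformerTest

vit = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=751)
model = LATransformerTest(vit, lmbd=8.0).eval()
model.load_state_dict(torch.load("model/{name}/net_best.pth"), strict=False)

with torch.no_grad():
    feats = model(torch.randn(2, 3, 224, 224))  # -> (2, 14, 768) part features
```
--------------------------------------------------------------------------------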