├── .gitignore
├── LA-Transformer Testing.html
├── LA-Transformer Testing.ipynb
├── LA-Transformer Training.html
├── LA-Transformer Training.ipynb
├── LATransformer
│   ├── metrics.py
│   ├── model.py
│   └── utils.py
├── LICENSE
└── Readme.md
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LA-Transformer Testing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from __future__ import print_function\n",
10 | "\n",
11 | "import os\n",
12 | "import time\n",
13 | "import glob\n",
14 | "import random\n",
15 | "import zipfile\n",
16 | "from itertools import chain\n",
17 | "\n",
18 | "import timm\n",
19 | "import numpy as np\n",
20 | "import pandas as pd\n",
21 | "from PIL import Image\n",
22 | "from tqdm.notebook import tqdm\n",
23 | "import matplotlib.pyplot as plt\n",
24 | "from collections import OrderedDict\n",
25 | "from sklearn.model_selection import train_test_split\n",
26 | "\n",
27 | "import torch\n",
28 | "import torch.nn as nn\n",
29 | "from torch.nn import init\n",
30 | "import torch.optim as optim\n",
31 | "from torchvision import models\n",
32 | "import torch.nn.functional as F\n",
33 | "from torch.autograd import Variable\n",
34 | "from torch.optim.lr_scheduler import StepLR\n",
35 | "from torchvision import datasets, transforms\n",
36 | "from torch.utils.data import DataLoader, Dataset\n",
37 | "\n",
38 | "from LATransformer.model import ClassBlock, LATransformer, LATransformerTest\n",
39 | "from LATransformer.utils import save_network, update_summary, get_id\n",
40 | "from LATransformer.metrics import rank1, rank5, rank10, calc_map\n",
41 | "\n",
42 | "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
43 | "device = \"cuda\""
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "## Config Parameters"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": 2,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "batch_size = 8\n",
60 | "gamma = 0.7\n",
61 | "seed = 42"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "metadata": {},
67 | "source": [
68 | "## Load Model"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": 3,
74 | "metadata": {},
75 | "outputs": [
76 | {
77 | "data": {
78 | "text/plain": [
79 | "LATransformerTest(\n",
80 | " (model): VisionTransformer(\n",
81 | " (patch_embed): PatchEmbed(\n",
82 | " (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
83 | " )\n",
84 | " (pos_drop): Dropout(p=0.0, inplace=False)\n",
85 | " (blocks): ModuleList(\n",
86 | " (0): Block(\n",
87 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
88 | " (attn): Attention(\n",
89 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
90 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
91 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
92 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
93 | " )\n",
94 | " (drop_path): Identity()\n",
95 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
96 | " (mlp): Mlp(\n",
97 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
98 | " (act): GELU()\n",
99 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
100 | " (drop): Dropout(p=0.0, inplace=False)\n",
101 | " )\n",
102 | " )\n",
103 | " (1): Block(\n",
104 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
105 | " (attn): Attention(\n",
106 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
107 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
108 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
109 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
110 | " )\n",
111 | " (drop_path): Identity()\n",
112 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
113 | " (mlp): Mlp(\n",
114 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
115 | " (act): GELU()\n",
116 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
117 | " (drop): Dropout(p=0.0, inplace=False)\n",
118 | " )\n",
119 | " )\n",
120 | " (2): Block(\n",
121 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
122 | " (attn): Attention(\n",
123 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
124 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
125 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
126 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
127 | " )\n",
128 | " (drop_path): Identity()\n",
129 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
130 | " (mlp): Mlp(\n",
131 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
132 | " (act): GELU()\n",
133 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
134 | " (drop): Dropout(p=0.0, inplace=False)\n",
135 | " )\n",
136 | " )\n",
137 | " (3): Block(\n",
138 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
139 | " (attn): Attention(\n",
140 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
141 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
142 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
143 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
144 | " )\n",
145 | " (drop_path): Identity()\n",
146 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
147 | " (mlp): Mlp(\n",
148 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
149 | " (act): GELU()\n",
150 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
151 | " (drop): Dropout(p=0.0, inplace=False)\n",
152 | " )\n",
153 | " )\n",
154 | " (4): Block(\n",
155 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
156 | " (attn): Attention(\n",
157 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
158 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
159 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
160 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
161 | " )\n",
162 | " (drop_path): Identity()\n",
163 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
164 | " (mlp): Mlp(\n",
165 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
166 | " (act): GELU()\n",
167 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
168 | " (drop): Dropout(p=0.0, inplace=False)\n",
169 | " )\n",
170 | " )\n",
171 | " (5): Block(\n",
172 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
173 | " (attn): Attention(\n",
174 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
175 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
176 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
177 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
178 | " )\n",
179 | " (drop_path): Identity()\n",
180 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
181 | " (mlp): Mlp(\n",
182 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
183 | " (act): GELU()\n",
184 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
185 | " (drop): Dropout(p=0.0, inplace=False)\n",
186 | " )\n",
187 | " )\n",
188 | " (6): Block(\n",
189 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
190 | " (attn): Attention(\n",
191 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
192 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
193 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
194 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
195 | " )\n",
196 | " (drop_path): Identity()\n",
197 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
198 | " (mlp): Mlp(\n",
199 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
200 | " (act): GELU()\n",
201 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
202 | " (drop): Dropout(p=0.0, inplace=False)\n",
203 | " )\n",
204 | " )\n",
205 | " (7): Block(\n",
206 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
207 | " (attn): Attention(\n",
208 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
209 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
210 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
211 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
212 | " )\n",
213 | " (drop_path): Identity()\n",
214 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
215 | " (mlp): Mlp(\n",
216 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
217 | " (act): GELU()\n",
218 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
219 | " (drop): Dropout(p=0.0, inplace=False)\n",
220 | " )\n",
221 | " )\n",
222 | " (8): Block(\n",
223 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
224 | " (attn): Attention(\n",
225 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
226 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
227 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
228 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
229 | " )\n",
230 | " (drop_path): Identity()\n",
231 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
232 | " (mlp): Mlp(\n",
233 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
234 | " (act): GELU()\n",
235 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
236 | " (drop): Dropout(p=0.0, inplace=False)\n",
237 | " )\n",
238 | " )\n",
239 | " (9): Block(\n",
240 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
241 | " (attn): Attention(\n",
242 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
243 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
244 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
245 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
246 | " )\n",
247 | " (drop_path): Identity()\n",
248 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
249 | " (mlp): Mlp(\n",
250 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
251 | " (act): GELU()\n",
252 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
253 | " (drop): Dropout(p=0.0, inplace=False)\n",
254 | " )\n",
255 | " )\n",
256 | " (10): Block(\n",
257 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
258 | " (attn): Attention(\n",
259 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
260 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
261 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
262 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
263 | " )\n",
264 | " (drop_path): Identity()\n",
265 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
266 | " (mlp): Mlp(\n",
267 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
268 | " (act): GELU()\n",
269 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
270 | " (drop): Dropout(p=0.0, inplace=False)\n",
271 | " )\n",
272 | " )\n",
273 | " (11): Block(\n",
274 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
275 | " (attn): Attention(\n",
276 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
277 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
278 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
279 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
280 | " )\n",
281 | " (drop_path): Identity()\n",
282 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
283 | " (mlp): Mlp(\n",
284 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
285 | " (act): GELU()\n",
286 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
287 | " (drop): Dropout(p=0.0, inplace=False)\n",
288 | " )\n",
289 | " )\n",
290 | " )\n",
291 | " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
292 | " (head): Linear(in_features=768, out_features=751, bias=True)\n",
293 | " )\n",
294 | " (avgpool): AdaptiveAvgPool2d(output_size=(14, 768))\n",
295 | " (dropout): Dropout(p=0.5, inplace=False)\n",
296 | ")"
297 | ]
298 | },
299 | "execution_count": 3,
300 | "metadata": {},
301 | "output_type": "execute_result"
302 | }
303 | ],
304 | "source": [
305 | "# Load ViT\n",
306 | "vit_base = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=751)\n",
307 | "vit_base= vit_base.to(device)\n",
308 | "\n",
309 | "# Create La-Transformer\n",
310 | "model = LATransformerTest(vit_base, lmbd=8).to(device)\n",
311 | "\n",
312 | "# Load LA-Transformer\n",
313 | "name = \"la_with_lmbd_8\"\n",
314 | "save_path = os.path.join('./model',name,'net_best.pth')\n",
315 | "model.load_state_dict(torch.load(save_path), strict=False)\n",
316 | "model.eval()"
317 | ]
318 | },
319 | {
320 | "cell_type": "markdown",
321 | "metadata": {},
322 | "source": [
323 | "\n",
324 | "\n",
325 | "### DataLoader"
326 | ]
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": 4,
331 | "metadata": {},
332 | "outputs": [],
333 | "source": [
334 | "transform_query_list = [\n",
335 | " transforms.Resize((224,224), interpolation=3),\n",
336 | " transforms.RandomHorizontalFlip(),\n",
337 | " transforms.ToTensor(),\n",
338 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
339 | " ]\n",
340 | "transform_gallery_list = [\n",
341 | " transforms.Resize(size=(224,224),interpolation=3), #Image.BICUBIC\n",
342 | " transforms.ToTensor(),\n",
343 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
344 | " ]\n",
345 | "data_transforms = {\n",
346 | "'query': transforms.Compose( transform_query_list ),\n",
347 | "'gallery': transforms.Compose(transform_gallery_list),\n",
348 | "}"
349 | ]
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": 5,
354 | "metadata": {},
355 | "outputs": [
356 | {
357 | "name": "stdout",
358 | "output_type": "stream",
359 | "text": [
360 | "750\n"
361 | ]
362 | }
363 | ],
364 | "source": [
365 | "image_datasets = {}\n",
366 | "data_dir = \"data/Market-Pytorch/Market/\"\n",
367 | "\n",
368 | "image_datasets['query'] = datasets.ImageFolder(os.path.join(data_dir, 'query'),\n",
369 | " data_transforms['query'])\n",
370 | "image_datasets['gallery'] = datasets.ImageFolder(os.path.join(data_dir, 'gallery'),\n",
371 | " data_transforms['gallery'])\n",
372 | "query_loader = DataLoader(dataset = image_datasets['query'], batch_size=batch_size, shuffle=False )\n",
373 | "gallery_loader = DataLoader(dataset = image_datasets['gallery'], batch_size=batch_size, shuffle=False)\n",
374 | "\n",
375 | "class_names = image_datasets['query'].classes\n",
376 | "print(len(class_names))"
377 | ]
378 | },
379 | {
380 | "cell_type": "markdown",
381 | "metadata": {},
382 | "source": [
383 | "### Extract Features"
384 | ]
385 | },
386 | {
387 | "cell_type": "code",
388 | "execution_count": 6,
389 | "metadata": {},
390 | "outputs": [],
391 | "source": [
392 | "activation = {}\n",
393 | "def get_activation(name):\n",
394 | " def hook(model, input, output):\n",
395 | " activation[name] = output.detach()\n",
396 | " return hook"
397 | ]
398 | },
399 | {
400 | "cell_type": "code",
401 | "execution_count": 7,
402 | "metadata": {},
403 | "outputs": [],
404 | "source": [
405 | "def extract_feature(model,dataloaders):\n",
406 | " \n",
407 | " features = torch.FloatTensor()\n",
408 | " count = 0\n",
409 | " idx = 0\n",
410 | " for data in tqdm(dataloaders):\n",
411 | " img, label = data\n",
412 | " img, label = img.to(device), label.to(device)\n",
413 | "\n",
414 | " output = model(img)\n",
415 | "\n",
416 | " n, c, h, w = img.size()\n",
417 | " \n",
418 | " count += n\n",
419 | " features = torch.cat((features, output.detach().cpu()), 0)\n",
420 | " idx += 1\n",
421 | " return features"
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "execution_count": 8,
427 | "metadata": {
428 | "scrolled": true
429 | },
430 | "outputs": [
431 | {
432 | "data": {
433 | "application/vnd.jupyter.widget-view+json": {
434 | "model_id": "febb2a07ac2f42178b9fdec40350e415",
435 | "version_major": 2,
436 | "version_minor": 0
437 | },
438 | "text/plain": [
439 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=421.0), HTML(value='')))"
440 | ]
441 | },
442 | "metadata": {},
443 | "output_type": "display_data"
444 | },
445 | {
446 | "name": "stdout",
447 | "output_type": "stream",
448 | "text": [
449 | "\n"
450 | ]
451 | },
452 | {
453 | "data": {
454 | "application/vnd.jupyter.widget-view+json": {
455 | "model_id": "7a1342f8f990420e90d234818e474955",
456 | "version_major": 2,
457 | "version_minor": 0
458 | },
459 | "text/plain": [
460 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2467.0), HTML(value='')))"
461 | ]
462 | },
463 | "metadata": {},
464 | "output_type": "display_data"
465 | },
466 | {
467 | "name": "stdout",
468 | "output_type": "stream",
469 | "text": [
470 | "\n"
471 | ]
472 | }
473 | ],
474 | "source": [
475 | "# Extract Query Features\n",
476 | "query_feature= extract_feature(model, query_loader)\n",
477 | "\n",
478 | "# Extract Gallery Features\n",
479 | "gallery_feature = extract_feature(model, gallery_loader)"
480 | ]
481 | },
482 | {
483 | "cell_type": "code",
484 | "execution_count": 9,
485 | "metadata": {},
486 | "outputs": [],
487 | "source": [
488 | "# Retrieve labels\n",
489 | "gallery_path = image_datasets['gallery'].imgs\n",
490 | "query_path = image_datasets['query'].imgs\n",
491 | "\n",
492 | "gallery_cam,gallery_label = get_id(gallery_path)\n",
493 | "query_cam,query_label = get_id(query_path)"
494 | ]
495 | },
496 | {
497 | "cell_type": "markdown",
498 | "metadata": {},
499 | "source": [
500 | "## Concat Averaged GELTs"
501 | ]
502 | },
503 | {
504 | "cell_type": "code",
505 | "execution_count": 10,
506 | "metadata": {},
507 | "outputs": [
508 | {
509 | "data": {
510 | "application/vnd.jupyter.widget-view+json": {
511 | "model_id": "07fbd13be6e943e7ab65d3a7354c5d87",
512 | "version_major": 2,
513 | "version_minor": 0
514 | },
515 | "text/plain": [
516 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3368.0), HTML(value='')))"
517 | ]
518 | },
519 | "metadata": {},
520 | "output_type": "display_data"
521 | },
522 | {
523 | "name": "stdout",
524 | "output_type": "stream",
525 | "text": [
526 | "\n"
527 | ]
528 | },
529 | {
530 | "data": {
531 | "application/vnd.jupyter.widget-view+json": {
532 | "model_id": "1afc1486a57544ec9a2365265fab4281",
533 | "version_major": 2,
534 | "version_minor": 0
535 | },
536 | "text/plain": [
537 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=19732.0), HTML(value='')))"
538 | ]
539 | },
540 | "metadata": {},
541 | "output_type": "display_data"
542 | },
543 | {
544 | "name": "stdout",
545 | "output_type": "stream",
546 | "text": [
547 | "\n"
548 | ]
549 | }
550 | ],
551 | "source": [
552 | "concatenated_query_vectors = []\n",
553 | "for query in tqdm(query_feature):\n",
554 | " \n",
555 | " fnorm = torch.norm(query, p=2, dim=1, keepdim=True)*np.sqrt(14)\n",
556 | " \n",
557 | " query_norm = query.div(fnorm.expand_as(query))\n",
558 | " \n",
559 | " concatenated_query_vectors.append(query_norm.view((-1))) # 14*768 -> 10752\n",
560 | "\n",
561 | "concatenated_gallery_vectors = []\n",
562 | "for gallery in tqdm(gallery_feature):\n",
563 | " \n",
564 | " fnorm = torch.norm(gallery, p=2, dim=1, keepdim=True) *np.sqrt(14)\n",
565 | " \n",
566 | " gallery_norm = gallery.div(fnorm.expand_as(gallery))\n",
567 | " \n",
568 | " concatenated_gallery_vectors.append(gallery_norm.view((-1))) # 14*768 -> 10752\n",
569 | " "
570 | ]
571 | },
572 | {
573 | "cell_type": "markdown",
574 | "metadata": {},
575 | "source": [
576 | "## Calculate Similarity using FAISS"
577 | ]
578 | },
579 | {
580 | "cell_type": "code",
581 | "execution_count": 11,
582 | "metadata": {},
583 | "outputs": [],
584 | "source": [
585 | "import faiss\n",
586 | "import numpy as np\n",
587 | "\n",
588 | "\n",
589 | "index = faiss.IndexIDMap(faiss.IndexFlatIP(10752))\n",
590 | "\n",
591 | "index.add_with_ids(np.array([t.numpy() for t in concatenated_gallery_vectors]),np.array(gallery_label))\n",
592 | "\n",
593 | "# xb = np.array([t.numpy() for t in concatenated_gallery_vectors]).astype(dtype=np.float32)\n",
594 | "# index = faiss.IndexFlatL2(10752) \n",
595 | "# ids = np.array(gallery_label, dtype=np.float32)\n",
596 | "# index2 = faiss.IndexIDMap(index)\n",
597 | "# index2.add_with_ids(xb, ids)\n",
598 | "\n",
599 | "\n",
600 | "def search(query: str, k=1):\n",
601 | " encoded_query = query.unsqueeze(dim=0).numpy()\n",
602 | " top_k = index.search(encoded_query, k)\n",
603 | " return top_k"
604 | ]
605 | },
606 | {
607 | "cell_type": "code",
608 | "execution_count": 12,
609 | "metadata": {},
610 | "outputs": [
611 | {
612 | "name": "stdout",
613 | "output_type": "stream",
614 | "text": [
615 | "Rank1: 0.9833729216152018, Rank5: 0.9973277909738717, Rank10: 0.9982185273159145, mAP: 0.9279050887119389\n"
616 | ]
617 | }
618 | ],
619 | "source": [
620 | "rank1_score = 0\n",
621 | "rank5_score = 0\n",
622 | "rank10_score = 0\n",
623 | "ap = 0\n",
624 | "count = 0\n",
625 | "for query, label in zip(concatenated_query_vectors, query_label):\n",
626 | " count += 1\n",
627 | " label = label\n",
628 | " output = search(query, k=10)\n",
629 | "# print(output)\n",
630 | " rank1_score += rank1(label, output) \n",
631 | " rank5_score += rank5(label, output) \n",
632 | " rank10_score += rank10(label, output) \n",
633 | " print(\"Correct: {}, Total: {}, Incorrect: {}\".format(rank1_score, count, count-rank1_score), end=\"\\r\")\n",
634 | " ap += calc_map(label, output)\n",
635 | "\n",
636 | "print(\"Rank1: {}, Rank5: {}, Rank10: {}, mAP: {}\".format(rank1_score/len(query_feature), \n",
637 | " rank5_score/len(query_feature), \n",
638 | " rank10_score/len(query_feature), ap/len(query_feature))) "
639 | ]
640 | },
641 | {
642 | "cell_type": "code",
643 | "execution_count": null,
644 | "metadata": {},
645 | "outputs": [],
646 | "source": []
647 | }
648 | ],
649 | "metadata": {
650 | "kernelspec": {
651 | "display_name": "Python 3",
652 | "language": "python",
653 | "name": "python3"
654 | },
655 | "language_info": {
656 | "codemirror_mode": {
657 | "name": "ipython",
658 | "version": 3
659 | },
660 | "file_extension": ".py",
661 | "mimetype": "text/x-python",
662 | "name": "python",
663 | "nbconvert_exporter": "python",
664 | "pygments_lexer": "ipython3",
665 | "version": "3.7.4"
666 | }
667 | },
668 | "nbformat": 4,
669 | "nbformat_minor": 4
670 | }
671 |
--------------------------------------------------------------------------------
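
Note on the test-time similarity search in the notebook above: each extracted GELT feature is a 14x768 matrix whose rows are L2-normalized and scaled by 1/sqrt(14), so the flattened 10752-d descriptor has unit L2 norm, and the inner-product FAISS index (IndexFlatIP) therefore ranks the gallery by cosine similarity. A minimal, self-contained sketch verifying that normalization property (uses a random stand-in tensor; not part of the repo):

import numpy as np
import torch

feat = torch.randn(14, 768)           # stand-in for one extracted GELT feature
fnorm = torch.norm(feat, p=2, dim=1, keepdim=True) * np.sqrt(14)
vec = (feat / fnorm).view(-1)         # 14*768 -> 10752, as in the notebook
print(float(vec.norm()))              # ~= 1.0: each row has norm 1/sqrt(14), so the
                                      # 14 rows sum to unit norm and IP == cosine
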
/LA-Transformer Training.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Import Libraries"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "from __future__ import print_function\n",
17 | "\n",
18 | "import os\n",
19 | "import time\n",
20 | "import random\n",
21 | "import zipfile\n",
22 | "from itertools import chain\n",
23 | "\n",
24 | "import timm\n",
25 | "import numpy as np\n",
26 | "from PIL import Image\n",
27 | "from tqdm.notebook import tqdm\n",
28 | "from collections import OrderedDict\n",
29 | "\n",
30 | "import torch\n",
31 | "import torch.nn as nn\n",
32 | "from torch.nn import init\n",
33 | "import torch.optim as optim\n",
34 | "from torchvision import models\n",
35 | "import torch.nn.functional as F\n",
36 | "from torch.autograd import Variable\n",
37 | "from torch.optim.lr_scheduler import StepLR\n",
38 | "from torchvision import datasets, transforms\n",
39 | "from torch.utils.data import DataLoader, Dataset\n",
40 | "\n",
41 | "from LATransformer.model import ClassBlock, LATransformer\n",
42 | "from LATransformer.utils import save_network, update_summary\n",
43 | "\n",
44 | "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
45 | "device = \"cuda\""
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {},
51 | "source": [
52 | "### Set Config Parameters"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 2,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "batch_size = 32\n",
62 | "num_epochs = 30\n",
63 | "lr = 3e-4\n",
64 | "gamma = 0.7\n",
65 | "unfreeze_after=2\n",
66 | "lr_decay=.8\n",
67 | "lmbd = 8"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "## Load Data"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 3,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "transform_train_list = [\n",
84 | " transforms.Resize((224,224), interpolation=3),\n",
85 | " transforms.RandomHorizontalFlip(),\n",
86 | " transforms.ToTensor(),\n",
87 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
88 | " ]\n",
89 | "transform_val_list = [\n",
90 | " transforms.Resize(size=(224,224),interpolation=3), #Image.BICUBIC\n",
91 | " transforms.ToTensor(),\n",
92 | " transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
93 | " ]\n",
94 | "data_transforms = {\n",
95 | "'train': transforms.Compose( transform_train_list ),\n",
96 | "'val': transforms.Compose(transform_val_list),\n",
97 | "}"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 4,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "name": "stdout",
107 | "output_type": "stream",
108 | "text": [
109 | "751\n"
110 | ]
111 | }
112 | ],
113 | "source": [
114 | "image_datasets = {}\n",
115 | "data_dir = \"data/Market-Pytorch/Market/\"\n",
116 | "\n",
117 | "image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),\n",
118 | " data_transforms['train'])\n",
119 | "image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),\n",
120 | " data_transforms['val'])\n",
121 | "train_loader = DataLoader(dataset = image_datasets['train'], batch_size=batch_size, shuffle=True )\n",
122 | "valid_loader = DataLoader(dataset = image_datasets['val'], batch_size=batch_size, shuffle=True)\n",
123 | "# dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,\n",
124 | "# shuffle=True, num_workers=8, pin_memory=True) # 8 workers may work faster\n",
125 | "# for x in ['train', 'val']}\n",
126 | "# dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}\n",
127 | "class_names = image_datasets['train'].classes\n",
128 | "print(len(class_names))"
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {},
134 | "source": [
135 | "## Load Model"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 5,
141 | "metadata": {},
142 | "outputs": [
143 | {
144 | "data": {
145 | "text/plain": [
146 | "VisionTransformer(\n",
147 | " (patch_embed): PatchEmbed(\n",
148 | " (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
149 | " )\n",
150 | " (pos_drop): Dropout(p=0.0, inplace=False)\n",
151 | " (blocks): ModuleList(\n",
152 | " (0): Block(\n",
153 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
154 | " (attn): Attention(\n",
155 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
156 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
157 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
158 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
159 | " )\n",
160 | " (drop_path): Identity()\n",
161 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
162 | " (mlp): Mlp(\n",
163 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
164 | " (act): GELU()\n",
165 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
166 | " (drop): Dropout(p=0.0, inplace=False)\n",
167 | " )\n",
168 | " )\n",
169 | " (1): Block(\n",
170 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
171 | " (attn): Attention(\n",
172 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
173 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
174 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
175 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
176 | " )\n",
177 | " (drop_path): Identity()\n",
178 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
179 | " (mlp): Mlp(\n",
180 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
181 | " (act): GELU()\n",
182 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
183 | " (drop): Dropout(p=0.0, inplace=False)\n",
184 | " )\n",
185 | " )\n",
186 | " (2): Block(\n",
187 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
188 | " (attn): Attention(\n",
189 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
190 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
191 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
192 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
193 | " )\n",
194 | " (drop_path): Identity()\n",
195 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
196 | " (mlp): Mlp(\n",
197 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
198 | " (act): GELU()\n",
199 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
200 | " (drop): Dropout(p=0.0, inplace=False)\n",
201 | " )\n",
202 | " )\n",
203 | " (3): Block(\n",
204 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
205 | " (attn): Attention(\n",
206 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
207 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
208 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
209 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
210 | " )\n",
211 | " (drop_path): Identity()\n",
212 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
213 | " (mlp): Mlp(\n",
214 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
215 | " (act): GELU()\n",
216 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
217 | " (drop): Dropout(p=0.0, inplace=False)\n",
218 | " )\n",
219 | " )\n",
220 | " (4): Block(\n",
221 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
222 | " (attn): Attention(\n",
223 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
224 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
225 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
226 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
227 | " )\n",
228 | " (drop_path): Identity()\n",
229 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
230 | " (mlp): Mlp(\n",
231 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
232 | " (act): GELU()\n",
233 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
234 | " (drop): Dropout(p=0.0, inplace=False)\n",
235 | " )\n",
236 | " )\n",
237 | " (5): Block(\n",
238 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
239 | " (attn): Attention(\n",
240 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
241 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
242 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
243 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
244 | " )\n",
245 | " (drop_path): Identity()\n",
246 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
247 | " (mlp): Mlp(\n",
248 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
249 | " (act): GELU()\n",
250 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
251 | " (drop): Dropout(p=0.0, inplace=False)\n",
252 | " )\n",
253 | " )\n",
254 | " (6): Block(\n",
255 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
256 | " (attn): Attention(\n",
257 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
258 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
259 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
260 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
261 | " )\n",
262 | " (drop_path): Identity()\n",
263 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
264 | " (mlp): Mlp(\n",
265 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
266 | " (act): GELU()\n",
267 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
268 | " (drop): Dropout(p=0.0, inplace=False)\n",
269 | " )\n",
270 | " )\n",
271 | " (7): Block(\n",
272 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
273 | " (attn): Attention(\n",
274 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
275 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
276 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
277 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
278 | " )\n",
279 | " (drop_path): Identity()\n",
280 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
281 | " (mlp): Mlp(\n",
282 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
283 | " (act): GELU()\n",
284 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
285 | " (drop): Dropout(p=0.0, inplace=False)\n",
286 | " )\n",
287 | " )\n",
288 | " (8): Block(\n",
289 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
290 | " (attn): Attention(\n",
291 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
292 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
293 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
294 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
295 | " )\n",
296 | " (drop_path): Identity()\n",
297 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
298 | " (mlp): Mlp(\n",
299 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
300 | " (act): GELU()\n",
301 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
302 | " (drop): Dropout(p=0.0, inplace=False)\n",
303 | " )\n",
304 | " )\n",
305 | " (9): Block(\n",
306 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
307 | " (attn): Attention(\n",
308 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
309 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
310 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
311 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
312 | " )\n",
313 | " (drop_path): Identity()\n",
314 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
315 | " (mlp): Mlp(\n",
316 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
317 | " (act): GELU()\n",
318 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
319 | " (drop): Dropout(p=0.0, inplace=False)\n",
320 | " )\n",
321 | " )\n",
322 | " (10): Block(\n",
323 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
324 | " (attn): Attention(\n",
325 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
326 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
327 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
328 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
329 | " )\n",
330 | " (drop_path): Identity()\n",
331 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
332 | " (mlp): Mlp(\n",
333 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
334 | " (act): GELU()\n",
335 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
336 | " (drop): Dropout(p=0.0, inplace=False)\n",
337 | " )\n",
338 | " )\n",
339 | " (11): Block(\n",
340 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
341 | " (attn): Attention(\n",
342 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
343 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
344 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
345 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
346 | " )\n",
347 | " (drop_path): Identity()\n",
348 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
349 | " (mlp): Mlp(\n",
350 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
351 | " (act): GELU()\n",
352 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
353 | " (drop): Dropout(p=0.0, inplace=False)\n",
354 | " )\n",
355 | " )\n",
356 | " )\n",
357 | " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
358 | " (head): Linear(in_features=768, out_features=751, bias=True)\n",
359 | ")"
360 | ]
361 | },
362 | "execution_count": 5,
363 | "metadata": {},
364 | "output_type": "execute_result"
365 | }
366 | ],
367 | "source": [
368 | "# Load pre-trained ViT\n",
369 | "vit_base = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=751)\n",
370 | "vit_base= vit_base.to(device)\n",
371 | "vit_base.eval()"
372 | ]
373 | },
374 | {
375 | "cell_type": "markdown",
376 | "metadata": {},
377 | "source": [
378 | "\n",
379 | "\n",
380 | "### Train"
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "execution_count": 6,
386 | "metadata": {},
387 | "outputs": [],
388 | "source": [
389 | "class AverageMeter:\n",
390 | " \"\"\"Computes and stores the average and current value\"\"\"\n",
391 | " def __init__(self):\n",
392 | " self.reset()\n",
393 | "\n",
394 | " def reset(self):\n",
395 | " self.val = 0\n",
396 | " self.avg = 0\n",
397 | " self.sum = 0\n",
398 | " self.count = 0\n",
399 | "\n",
400 | " def update(self, val, n=1):\n",
401 | " self.val = val\n",
402 | " self.sum += val * n\n",
403 | " self.count += n\n",
404 | " self.avg = self.sum / self.count"
405 | ]
406 | },
407 | {
408 | "cell_type": "code",
409 | "execution_count": 7,
410 | "metadata": {},
411 | "outputs": [],
412 | "source": [
413 | "def validate(model, loader, loss_fn):\n",
414 | " batch_time_m = AverageMeter()\n",
415 | " losses_m = AverageMeter()\n",
416 | " top1_m = AverageMeter()\n",
417 | " top5_m = AverageMeter()\n",
418 | "\n",
419 | " model.eval()\n",
420 | " epoch_accuracy = 0\n",
421 | " epoch_loss = 0\n",
422 | " end = time.time()\n",
423 | " last_idx = len(loader) - 1\n",
424 | " \n",
425 | " running_loss = 0.0\n",
426 | " running_corrects = 0.0\n",
427 | "\n",
428 | " with torch.no_grad():\n",
429 | " for input, target in tqdm(loader):\n",
430 | "\n",
431 | " input, target = input.to(device), target.to(device)\n",
432 | " \n",
433 | " output = model(input)\n",
434 | " \n",
435 | " score = 0.0\n",
436 | " sm = nn.Softmax(dim=1)\n",
437 | " for k, v in output.items():\n",
438 | " score += sm(output[k])\n",
439 | " _, preds = torch.max(score.data, 1)\n",
440 | "\n",
441 | " loss = 0.0\n",
442 | " for k,v in output.items():\n",
443 | " loss += loss_fn(output[k], target)\n",
444 | "\n",
445 | "\n",
446 | " batch_time_m.update(time.time() - end)\n",
447 | " acc = (preds == target.data).float().mean()\n",
448 | " epoch_loss += loss/len(loader)\n",
449 | " epoch_accuracy += acc / len(loader)\n",
450 | " \n",
451 | " print(f\"Epoch : {epoch+1} - val_loss : {epoch_loss:.4f} - val_acc: {epoch_accuracy:.4f}\", end=\"\\r\")\n",
452 | " print() \n",
453 | " metrics = OrderedDict([('val_loss', epoch_loss.data.item()), (\"val_accuracy\", epoch_accuracy.data.item())])\n",
454 | "\n",
455 | "\n",
456 | " return metrics"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": 8,
462 | "metadata": {},
463 | "outputs": [],
464 | "source": [
465 | "def train_one_epoch(\n",
466 | " epoch, model, loader, optimizer, loss_fn,\n",
467 | " lr_scheduler=None, saver=None, output_dir='', \n",
468 | " loss_scaler=None, model_ema=None, mixup_fn=None):\n",
469 | "\n",
470 | " \n",
471 | "\n",
472 | " \n",
473 | " batch_time_m = AverageMeter()\n",
474 | " data_time_m = AverageMeter()\n",
475 | " losses_m = AverageMeter()\n",
476 | "\n",
477 | " model.train()\n",
478 | " epoch_accuracy = 0\n",
479 | " epoch_loss = 0\n",
480 | " end = time.time()\n",
481 | " last_idx = len(loader) - 1\n",
482 | " num_updates = epoch * len(loader)\n",
483 | " running_loss = 0.0\n",
484 | " running_corrects = 0.0\n",
485 | "\n",
486 | " for data, target in tqdm(loader):\n",
487 | " data, target = data.to(device), target.to(device)\n",
488 | "\n",
489 | " \n",
490 | " data_time_m.update(time.time() - end)\n",
491 | "\n",
492 | " optimizer.zero_grad()\n",
493 | " output = model(data)\n",
494 | " score = 0.0\n",
495 | " sm = nn.Softmax(dim=1)\n",
496 | " for k, v in output.items():\n",
497 | " score += sm(output[k])\n",
498 | " _, preds = torch.max(score.data, 1)\n",
499 | " \n",
500 | " loss = 0.0\n",
501 | " for k,v in output.items():\n",
502 | " loss += loss_fn(output[k], target)\n",
503 | " loss.backward()\n",
504 | "\n",
505 | " optimizer.step()\n",
506 | "\n",
507 | " batch_time_m.update(time.time() - end)\n",
508 | " \n",
509 | "# print(preds, target.data)\n",
510 | " acc = (preds == target.data).float().mean()\n",
511 | " \n",
512 | "# print(acc)\n",
513 | " epoch_loss += loss/len(loader)\n",
514 | " epoch_accuracy += acc / len(loader)\n",
515 | "# if acc:\n",
516 | "# print(acc, epreds, target.data)\n",
517 | " print(\n",
518 | " f\"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f}\"\n",
519 | ", end=\"\\r\")\n",
520 | "\n",
521 | " print()\n",
522 | "\n",
523 | " return OrderedDict([('train_loss', epoch_loss.data.item()), (\"train_accuracy\", epoch_accuracy.data.item())])\n"
524 | ]
525 | },
526 | {
527 | "cell_type": "code",
528 | "execution_count": 9,
529 | "metadata": {},
530 | "outputs": [],
531 | "source": [
532 | "def freeze_all_blocks(model):\n",
533 | " frozen_blocks = 12\n",
534 | " for block in model.model.blocks[:frozen_blocks]:\n",
535 | " for param in block.parameters():\n",
536 | " param.requires_grad=False\n",
537 | " "
538 | ]
539 | },
540 | {
541 | "cell_type": "code",
542 | "execution_count": 10,
543 | "metadata": {},
544 | "outputs": [],
545 | "source": [
546 | "def unfreeze_blocks(model, amount= 1):\n",
547 | " \n",
548 | " for block in model.model.blocks[11-amount:]:\n",
549 | " for param in block.parameters():\n",
550 | " param.requires_grad=True\n",
551 | " return model"
552 | ]
553 | },
554 | {
555 | "cell_type": "markdown",
556 | "metadata": {},
557 | "source": [
558 | "## Training Loop"
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "execution_count": 11,
564 | "metadata": {
565 | "scrolled": true
566 | },
567 | "outputs": [
568 | {
569 | "name": "stdout",
570 | "output_type": "stream",
571 | "text": [
572 | "LATransformer(\n",
573 | " (model): VisionTransformer(\n",
574 | " (patch_embed): PatchEmbed(\n",
575 | " (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
576 | " )\n",
577 | " (pos_drop): Dropout(p=0.0, inplace=False)\n",
578 | " (blocks): ModuleList(\n",
579 | " (0): Block(\n",
580 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
581 | " (attn): Attention(\n",
582 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
583 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
584 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
585 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
586 | " )\n",
587 | " (drop_path): Identity()\n",
588 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
589 | " (mlp): Mlp(\n",
590 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
591 | " (act): GELU()\n",
592 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
593 | " (drop): Dropout(p=0.0, inplace=False)\n",
594 | " )\n",
595 | " )\n",
596 | " (1): Block(\n",
597 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
598 | " (attn): Attention(\n",
599 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
600 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
601 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
602 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
603 | " )\n",
604 | " (drop_path): Identity()\n",
605 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
606 | " (mlp): Mlp(\n",
607 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
608 | " (act): GELU()\n",
609 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
610 | " (drop): Dropout(p=0.0, inplace=False)\n",
611 | " )\n",
612 | " )\n",
613 | " (2): Block(\n",
614 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
615 | " (attn): Attention(\n",
616 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
617 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
618 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
619 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
620 | " )\n",
621 | " (drop_path): Identity()\n",
622 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
623 | " (mlp): Mlp(\n",
624 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
625 | " (act): GELU()\n",
626 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
627 | " (drop): Dropout(p=0.0, inplace=False)\n",
628 | " )\n",
629 | " )\n",
630 | " (3): Block(\n",
631 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
632 | " (attn): Attention(\n",
633 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
634 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
635 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
636 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
637 | " )\n",
638 | " (drop_path): Identity()\n",
639 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
640 | " (mlp): Mlp(\n",
641 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
642 | " (act): GELU()\n",
643 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
644 | " (drop): Dropout(p=0.0, inplace=False)\n",
645 | " )\n",
646 | " )\n",
647 | " (4): Block(\n",
648 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
649 | " (attn): Attention(\n",
650 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
651 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
652 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
653 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
654 | " )\n",
655 | " (drop_path): Identity()\n",
656 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
657 | " (mlp): Mlp(\n",
658 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
659 | " (act): GELU()\n",
660 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
661 | " (drop): Dropout(p=0.0, inplace=False)\n",
662 | " )\n",
663 | " )\n",
664 | " (5): Block(\n",
665 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
666 | " (attn): Attention(\n",
667 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
668 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
669 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
670 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
671 | " )\n",
672 | " (drop_path): Identity()\n",
673 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
674 | " (mlp): Mlp(\n",
675 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
676 | " (act): GELU()\n",
677 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
678 | " (drop): Dropout(p=0.0, inplace=False)\n",
679 | " )\n",
680 | " )\n",
681 | " (6): Block(\n",
682 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
683 | " (attn): Attention(\n",
684 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
685 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
686 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
687 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
688 | " )\n",
689 | " (drop_path): Identity()\n",
690 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
691 | " (mlp): Mlp(\n",
692 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
693 | " (act): GELU()\n",
694 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
695 | " (drop): Dropout(p=0.0, inplace=False)\n",
696 | " )\n",
697 | " )\n",
698 | " (7): Block(\n",
699 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
700 | " (attn): Attention(\n",
701 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
702 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
703 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
704 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
705 | " )\n",
706 | " (drop_path): Identity()\n",
707 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
708 | " (mlp): Mlp(\n",
709 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
710 | " (act): GELU()\n",
711 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
712 | " (drop): Dropout(p=0.0, inplace=False)\n",
713 | " )\n",
714 | " )\n",
715 | " (8): Block(\n",
716 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
717 | " (attn): Attention(\n",
718 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
719 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
720 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
721 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
722 | " )\n",
723 | " (drop_path): Identity()\n",
724 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
725 | " (mlp): Mlp(\n",
726 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
727 | " (act): GELU()\n",
728 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
729 | " (drop): Dropout(p=0.0, inplace=False)\n",
730 | " )\n",
731 | " )\n",
732 | " (9): Block(\n",
733 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
734 | " (attn): Attention(\n",
735 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
736 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
737 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
738 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
739 | " )\n",
740 | " (drop_path): Identity()\n",
741 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
742 | " (mlp): Mlp(\n",
743 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
744 | " (act): GELU()\n",
745 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
746 | " (drop): Dropout(p=0.0, inplace=False)\n",
747 | " )\n",
748 | " )\n",
749 | " (10): Block(\n",
750 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
751 | " (attn): Attention(\n",
752 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
753 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
754 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
755 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
756 | " )\n",
757 | " (drop_path): Identity()\n",
758 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
759 | " (mlp): Mlp(\n",
760 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
761 | " (act): GELU()\n",
762 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
763 | " (drop): Dropout(p=0.0, inplace=False)\n",
764 | " )\n",
765 | " )\n",
766 | " (11): Block(\n",
767 | " (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
768 | " (attn): Attention(\n",
769 | " (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
770 | " (attn_drop): Dropout(p=0.0, inplace=False)\n",
771 | " (proj): Linear(in_features=768, out_features=768, bias=True)\n",
772 | " (proj_drop): Dropout(p=0.0, inplace=False)\n",
773 | " )\n",
774 | " (drop_path): Identity()\n",
775 | " (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
776 | " (mlp): Mlp(\n",
777 | " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
778 | " (act): GELU()\n",
779 | " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
780 | " (drop): Dropout(p=0.0, inplace=False)\n",
781 | " )\n",
782 | " )\n",
783 | " )\n",
784 | " (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
785 | " (head): Linear(in_features=768, out_features=751, bias=True)\n",
786 | " )\n",
787 | " (avgpool): AdaptiveAvgPool2d(output_size=(14, 768))\n",
788 | " (dropout): Dropout(p=0.5, inplace=False)\n",
789 | " (classifier0): ClassBlock(\n",
790 | " (add_block): Sequential(\n",
791 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
792 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
793 | " (2): Dropout(p=0.5, inplace=False)\n",
794 | " )\n",
795 | " (classifier): Sequential(\n",
796 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
797 | " )\n",
798 | " )\n",
799 | " (classifier1): ClassBlock(\n",
800 | " (add_block): Sequential(\n",
801 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
802 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
803 | " (2): Dropout(p=0.5, inplace=False)\n",
804 | " )\n",
805 | " (classifier): Sequential(\n",
806 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
807 | " )\n",
808 | " )\n",
809 | " (classifier2): ClassBlock(\n",
810 | " (add_block): Sequential(\n",
811 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
812 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
813 | " (2): Dropout(p=0.5, inplace=False)\n",
814 | " )\n",
815 | " (classifier): Sequential(\n",
816 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
817 | " )\n",
818 | " )\n",
819 | " (classifier3): ClassBlock(\n",
820 | " (add_block): Sequential(\n",
821 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
822 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
823 | " (2): Dropout(p=0.5, inplace=False)\n",
824 | " )\n",
825 | " (classifier): Sequential(\n",
826 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
827 | " )\n",
828 | " )\n",
829 | " (classifier4): ClassBlock(\n",
830 | " (add_block): Sequential(\n",
831 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
832 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
833 | " (2): Dropout(p=0.5, inplace=False)\n",
834 | " )\n",
835 | " (classifier): Sequential(\n",
836 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
837 | " )\n",
838 | " )\n",
839 | " (classifier5): ClassBlock(\n",
840 | " (add_block): Sequential(\n",
841 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
842 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
843 | " (2): Dropout(p=0.5, inplace=False)\n",
844 | " )\n",
845 | " (classifier): Sequential(\n",
846 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
847 | " )\n",
848 | " )\n",
849 | " (classifier6): ClassBlock(\n",
850 | " (add_block): Sequential(\n",
851 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
852 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
853 | " (2): Dropout(p=0.5, inplace=False)\n",
854 | " )\n",
855 | " (classifier): Sequential(\n",
856 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
857 | " )\n",
858 | " )\n",
859 | " (classifier7): ClassBlock(\n",
860 | " (add_block): Sequential(\n",
861 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
862 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
863 | " (2): Dropout(p=0.5, inplace=False)\n",
864 | " )\n",
865 | " (classifier): Sequential(\n",
866 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
867 | " )\n",
868 | " )\n",
869 | " (classifier8): ClassBlock(\n",
870 | " (add_block): Sequential(\n",
871 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
872 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
873 | " (2): Dropout(p=0.5, inplace=False)\n",
874 | " )\n",
875 | " (classifier): Sequential(\n",
876 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
877 | " )\n",
878 | " )\n",
879 | " (classifier9): ClassBlock(\n",
880 | " (add_block): Sequential(\n",
881 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
882 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
883 | " (2): Dropout(p=0.5, inplace=False)\n",
884 | " )\n",
885 | " (classifier): Sequential(\n",
886 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
887 | " )\n",
888 | " )\n",
889 | " (classifier10): ClassBlock(\n",
890 | " (add_block): Sequential(\n",
891 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
892 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
893 | " (2): Dropout(p=0.5, inplace=False)\n",
894 | " )\n",
895 | " (classifier): Sequential(\n",
896 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
897 | " )\n",
898 | " )\n",
899 | " (classifier11): ClassBlock(\n",
900 | " (add_block): Sequential(\n",
901 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
902 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
903 | " (2): Dropout(p=0.5, inplace=False)\n",
904 | " )\n",
905 | " (classifier): Sequential(\n",
906 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
907 | " )\n",
908 | " )\n",
909 | " (classifier12): ClassBlock(\n",
910 | " (add_block): Sequential(\n",
911 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
912 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
913 | " (2): Dropout(p=0.5, inplace=False)\n",
914 | " )\n",
915 | " (classifier): Sequential(\n",
916 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
917 | " )\n",
918 | " )\n",
919 | " (classifier13): ClassBlock(\n",
920 | " (add_block): Sequential(\n",
921 | " (0): Linear(in_features=768, out_features=256, bias=True)\n",
922 | " (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
923 | " (2): Dropout(p=0.5, inplace=False)\n",
924 | " )\n",
925 | " (classifier): Sequential(\n",
926 | " (0): Linear(in_features=256, out_features=751, bias=True)\n",
927 | " )\n",
928 | " )\n",
929 | ")\n"
930 | ]
931 | }
932 | ],
933 | "source": [
934 | "# Create LA Transformer\n",
935 | "model = LATransformer(vit_base, lmbd).to(device)\n",
936 | "print(model.eval())\n",
937 | "\n",
938 | "# loss function\n",
939 | "criterion = nn.CrossEntropyLoss()\n",
940 | "\n",
941 | "# optimizer\n",
942 | "optimizer = optim.Adam(model.parameters(),weight_decay=5e-4, lr=lr)\n",
943 | "\n",
944 | "# scheduler\n",
945 | "scheduler = StepLR(optimizer, step_size=1, gamma=gamma)\n",
946 | "freeze_all_blocks(model)"
947 | ]
948 | },
949 | {
950 | "cell_type": "code",
951 | "execution_count": null,
952 | "metadata": {
953 | "scrolled": true
954 | },
955 | "outputs": [
956 | {
957 | "name": "stdout",
958 | "output_type": "stream",
959 | "text": [
960 | "training...\n",
961 | "Unfrozen Blocks: 1, Current lr: 0.00023999999999999998, Trainable Params: 20962817\n"
962 | ]
963 | },
964 | {
965 | "data": {
966 | "application/vnd.jupyter.widget-view+json": {
967 | "model_id": "f767a92c4eac4258b98b3c9676782209",
968 | "version_major": 2,
969 | "version_minor": 0
970 | },
971 | "text/plain": [
972 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
973 | ]
974 | },
975 | "metadata": {},
976 | "output_type": "display_data"
977 | },
978 | {
979 | "name": "stdout",
980 | "output_type": "stream",
981 | "text": [
982 | "Epoch : 1 - loss : 82.7351 - acc: 0.0880\n",
983 | "\n"
984 | ]
985 | },
986 | {
987 | "data": {
988 | "application/vnd.jupyter.widget-view+json": {
989 | "model_id": "1baf1a35386941478d54f8ee048653bb",
990 | "version_major": 2,
991 | "version_minor": 0
992 | },
993 | "text/plain": [
994 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
995 | ]
996 | },
997 | "metadata": {},
998 | "output_type": "display_data"
999 | },
1000 | {
1001 | "name": "stdout",
1002 | "output_type": "stream",
1003 | "text": [
1004 | "Epoch : 1 - val_loss : 77.1901 - val_acc: 0.0497\n",
1005 | "\n",
1006 | "SAVED!\n"
1007 | ]
1008 | },
1009 | {
1010 | "data": {
1011 | "application/vnd.jupyter.widget-view+json": {
1012 | "model_id": "c349d24f58ee4194b48328d96f9cb2b9",
1013 | "version_major": 2,
1014 | "version_minor": 0
1015 | },
1016 | "text/plain": [
1017 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1018 | ]
1019 | },
1020 | "metadata": {},
1021 | "output_type": "display_data"
1022 | },
1023 | {
1024 | "name": "stdout",
1025 | "output_type": "stream",
1026 | "text": [
1027 | "Epoch : 2 - loss : 59.0334 - acc: 0.2364\n",
1028 | "\n"
1029 | ]
1030 | },
1031 | {
1032 | "data": {
1033 | "application/vnd.jupyter.widget-view+json": {
1034 | "model_id": "030f496ac22f4d08b2fe5b2f13df306c",
1035 | "version_major": 2,
1036 | "version_minor": 0
1037 | },
1038 | "text/plain": [
1039 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1040 | ]
1041 | },
1042 | "metadata": {},
1043 | "output_type": "display_data"
1044 | },
1045 | {
1046 | "name": "stdout",
1047 | "output_type": "stream",
1048 | "text": [
1049 | "Epoch : 2 - val_loss : 58.8111 - val_acc: 0.1918\n",
1050 | "\n",
1051 | "SAVED!\n",
1052 | "Unfrozen Blocks: 2, Current lr: 0.000192, Trainable Params: 28050689\n"
1053 | ]
1054 | },
1055 | {
1056 | "data": {
1057 | "application/vnd.jupyter.widget-view+json": {
1058 | "model_id": "c7284d3fd841487e996202add386669c",
1059 | "version_major": 2,
1060 | "version_minor": 0
1061 | },
1062 | "text/plain": [
1063 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1064 | ]
1065 | },
1066 | "metadata": {},
1067 | "output_type": "display_data"
1068 | },
1069 | {
1070 | "name": "stdout",
1071 | "output_type": "stream",
1072 | "text": [
1073 | "Epoch : 3 - loss : 41.1694 - acc: 0.4632\n",
1074 | "\n"
1075 | ]
1076 | },
1077 | {
1078 | "data": {
1079 | "application/vnd.jupyter.widget-view+json": {
1080 | "model_id": "d77e806e33104a8bbfb85a04483c94f4",
1081 | "version_major": 2,
1082 | "version_minor": 0
1083 | },
1084 | "text/plain": [
1085 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1086 | ]
1087 | },
1088 | "metadata": {},
1089 | "output_type": "display_data"
1090 | },
1091 | {
1092 | "name": "stdout",
1093 | "output_type": "stream",
1094 | "text": [
1095 | "Epoch : 3 - val_loss : 47.2650 - val_acc: 0.3353\n",
1096 | "\n",
1097 | "SAVED!\n"
1098 | ]
1099 | },
1100 | {
1101 | "data": {
1102 | "application/vnd.jupyter.widget-view+json": {
1103 | "model_id": "e082fac2ed5e4952938c6fd3fb5d4f1a",
1104 | "version_major": 2,
1105 | "version_minor": 0
1106 | },
1107 | "text/plain": [
1108 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1109 | ]
1110 | },
1111 | "metadata": {},
1112 | "output_type": "display_data"
1113 | },
1114 | {
1115 | "name": "stdout",
1116 | "output_type": "stream",
1117 | "text": [
1118 | "Epoch : 4 - loss : 28.3517 - acc: 0.6674\n",
1119 | "\n"
1120 | ]
1121 | },
1122 | {
1123 | "data": {
1124 | "application/vnd.jupyter.widget-view+json": {
1125 | "model_id": "57552c299cc54398aa4416c3689fb7b5",
1126 | "version_major": 2,
1127 | "version_minor": 0
1128 | },
1129 | "text/plain": [
1130 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1131 | ]
1132 | },
1133 | "metadata": {},
1134 | "output_type": "display_data"
1135 | },
1136 | {
1137 | "name": "stdout",
1138 | "output_type": "stream",
1139 | "text": [
1140 | "Epoch : 4 - val_loss : 33.9487 - val_acc: 0.5391\n",
1141 | "\n",
1142 | "SAVED!\n",
1143 | "Unfrozen Blocks: 3, Current lr: 0.00015360000000000002, Trainable Params: 35138561\n"
1144 | ]
1145 | },
1146 | {
1147 | "data": {
1148 | "application/vnd.jupyter.widget-view+json": {
1149 | "model_id": "66d3571123f849ccb1d30a186c4fee9d",
1150 | "version_major": 2,
1151 | "version_minor": 0
1152 | },
1153 | "text/plain": [
1154 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1155 | ]
1156 | },
1157 | "metadata": {},
1158 | "output_type": "display_data"
1159 | },
1160 | {
1161 | "name": "stdout",
1162 | "output_type": "stream",
1163 | "text": [
1164 | "Epoch : 5 - loss : 18.7140 - acc: 0.8141\n",
1165 | "\n"
1166 | ]
1167 | },
1168 | {
1169 | "data": {
1170 | "application/vnd.jupyter.widget-view+json": {
1171 | "model_id": "df0b76c5867f4efea8c5f27c4b6b416a",
1172 | "version_major": 2,
1173 | "version_minor": 0
1174 | },
1175 | "text/plain": [
1176 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1177 | ]
1178 | },
1179 | "metadata": {},
1180 | "output_type": "display_data"
1181 | },
1182 | {
1183 | "name": "stdout",
1184 | "output_type": "stream",
1185 | "text": [
1186 | "Epoch : 5 - val_loss : 25.3060 - val_acc: 0.6617\n",
1187 | "\n",
1188 | "SAVED!\n"
1189 | ]
1190 | },
1191 | {
1192 | "data": {
1193 | "application/vnd.jupyter.widget-view+json": {
1194 | "model_id": "32b7d1fa8e024640849763fbbbb5df8d",
1195 | "version_major": 2,
1196 | "version_minor": 0
1197 | },
1198 | "text/plain": [
1199 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1200 | ]
1201 | },
1202 | "metadata": {},
1203 | "output_type": "display_data"
1204 | },
1205 | {
1206 | "name": "stdout",
1207 | "output_type": "stream",
1208 | "text": [
1209 | "Epoch : 6 - loss : 12.2253 - acc: 0.9050\n",
1210 | "\n"
1211 | ]
1212 | },
1213 | {
1214 | "data": {
1215 | "application/vnd.jupyter.widget-view+json": {
1216 | "model_id": "7530e23081ff40838c5c340324fb0bee",
1217 | "version_major": 2,
1218 | "version_minor": 0
1219 | },
1220 | "text/plain": [
1221 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1222 | ]
1223 | },
1224 | "metadata": {},
1225 | "output_type": "display_data"
1226 | },
1227 | {
1228 | "name": "stdout",
1229 | "output_type": "stream",
1230 | "text": [
1231 | "Epoch : 6 - val_loss : 19.0367 - val_acc: 0.7506\n",
1232 | "\n",
1233 | "SAVED!\n",
1234 | "Unfrozen Blocks: 4, Current lr: 0.00012288000000000002, Trainable Params: 42226433\n"
1235 | ]
1236 | },
1237 | {
1238 | "data": {
1239 | "application/vnd.jupyter.widget-view+json": {
1240 | "model_id": "5d2820792b3d4c078a4687f08bd4f92d",
1241 | "version_major": 2,
1242 | "version_minor": 0
1243 | },
1244 | "text/plain": [
1245 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1246 | ]
1247 | },
1248 | "metadata": {},
1249 | "output_type": "display_data"
1250 | },
1251 | {
1252 | "name": "stdout",
1253 | "output_type": "stream",
1254 | "text": [
1255 | "Epoch : 7 - loss : 8.0031 - acc: 0.9542\n",
1256 | "\n"
1257 | ]
1258 | },
1259 | {
1260 | "data": {
1261 | "application/vnd.jupyter.widget-view+json": {
1262 | "model_id": "166716159015465d8430b22cbb0a937f",
1263 | "version_major": 2,
1264 | "version_minor": 0
1265 | },
1266 | "text/plain": [
1267 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1268 | ]
1269 | },
1270 | "metadata": {},
1271 | "output_type": "display_data"
1272 | },
1273 | {
1274 | "name": "stdout",
1275 | "output_type": "stream",
1276 | "text": [
1277 | "Epoch : 7 - val_loss : 14.0309 - val_acc: 0.8325\n",
1278 | "\n",
1279 | "SAVED!\n"
1280 | ]
1281 | },
1282 | {
1283 | "data": {
1284 | "application/vnd.jupyter.widget-view+json": {
1285 | "model_id": "e43fe49292c04b0f83dacbc610d49e58",
1286 | "version_major": 2,
1287 | "version_minor": 0
1288 | },
1289 | "text/plain": [
1290 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1291 | ]
1292 | },
1293 | "metadata": {},
1294 | "output_type": "display_data"
1295 | },
1296 | {
1297 | "name": "stdout",
1298 | "output_type": "stream",
1299 | "text": [
1300 | "Epoch : 8 - loss : 5.4122 - acc: 0.9771\n",
1301 | "\n"
1302 | ]
1303 | },
1304 | {
1305 | "data": {
1306 | "application/vnd.jupyter.widget-view+json": {
1307 | "model_id": "af46f3dff0a548f98b97b3868ce2a783",
1308 | "version_major": 2,
1309 | "version_minor": 0
1310 | },
1311 | "text/plain": [
1312 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1313 | ]
1314 | },
1315 | "metadata": {},
1316 | "output_type": "display_data"
1317 | },
1318 | {
1319 | "name": "stdout",
1320 | "output_type": "stream",
1321 | "text": [
1322 | "Epoch : 8 - val_loss : 11.0224 - val_acc: 0.8602\n",
1323 | "\n",
1324 | "SAVED!\n",
1325 | "Unfrozen Blocks: 5, Current lr: 9.830400000000001e-05, Trainable Params: 49314305\n"
1326 | ]
1327 | },
1328 | {
1329 | "data": {
1330 | "application/vnd.jupyter.widget-view+json": {
1331 | "model_id": "af2d255d73484d968d636523528a596c",
1332 | "version_major": 2,
1333 | "version_minor": 0
1334 | },
1335 | "text/plain": [
1336 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1337 | ]
1338 | },
1339 | "metadata": {},
1340 | "output_type": "display_data"
1341 | },
1342 | {
1343 | "name": "stdout",
1344 | "output_type": "stream",
1345 | "text": [
1346 | "Epoch : 9 - loss : 3.7149 - acc: 0.9906\n",
1347 | "\n"
1348 | ]
1349 | },
1350 | {
1351 | "data": {
1352 | "application/vnd.jupyter.widget-view+json": {
1353 | "model_id": "1b1001230f5842bfae334384b74040ae",
1354 | "version_major": 2,
1355 | "version_minor": 0
1356 | },
1357 | "text/plain": [
1358 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1359 | ]
1360 | },
1361 | "metadata": {},
1362 | "output_type": "display_data"
1363 | },
1364 | {
1365 | "name": "stdout",
1366 | "output_type": "stream",
1367 | "text": [
1368 | "Epoch : 9 - val_loss : 8.5832 - val_acc: 0.8944\n",
1369 | "\n",
1370 | "SAVED!\n"
1371 | ]
1372 | },
1373 | {
1374 | "data": {
1375 | "application/vnd.jupyter.widget-view+json": {
1376 | "model_id": "1e7691744d4142ee940783af59c753e5",
1377 | "version_major": 2,
1378 | "version_minor": 0
1379 | },
1380 | "text/plain": [
1381 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1382 | ]
1383 | },
1384 | "metadata": {},
1385 | "output_type": "display_data"
1386 | },
1387 | {
1388 | "name": "stdout",
1389 | "output_type": "stream",
1390 | "text": [
1391 | "Epoch : 10 - loss : 2.7142 - acc: 0.9950\n",
1392 | "\n"
1393 | ]
1394 | },
1395 | {
1396 | "data": {
1397 | "application/vnd.jupyter.widget-view+json": {
1398 | "model_id": "ceb48891601843519efddf0984ef2e18",
1399 | "version_major": 2,
1400 | "version_minor": 0
1401 | },
1402 | "text/plain": [
1403 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1404 | ]
1405 | },
1406 | "metadata": {},
1407 | "output_type": "display_data"
1408 | },
1409 | {
1410 | "name": "stdout",
1411 | "output_type": "stream",
1412 | "text": [
1413 | "Epoch : 10 - val_loss : 7.6481 - val_acc: 0.9033\n",
1414 | "\n",
1415 | "SAVED!\n",
1416 | "Unfrozen Blocks: 6, Current lr: 7.864320000000001e-05, Trainable Params: 56402177\n"
1417 | ]
1418 | },
1419 | {
1420 | "data": {
1421 | "application/vnd.jupyter.widget-view+json": {
1422 | "model_id": "3a5eff7b69b8422b9da1bfc8aa601193",
1423 | "version_major": 2,
1424 | "version_minor": 0
1425 | },
1426 | "text/plain": [
1427 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1428 | ]
1429 | },
1430 | "metadata": {},
1431 | "output_type": "display_data"
1432 | },
1433 | {
1434 | "name": "stdout",
1435 | "output_type": "stream",
1436 | "text": [
1437 | "Epoch : 11 - loss : 2.0092 - acc: 0.9965\n",
1438 | "\n"
1439 | ]
1440 | },
1441 | {
1442 | "data": {
1443 | "application/vnd.jupyter.widget-view+json": {
1444 | "model_id": "a691c2bbe71f434390741bb0071a3e42",
1445 | "version_major": 2,
1446 | "version_minor": 0
1447 | },
1448 | "text/plain": [
1449 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1450 | ]
1451 | },
1452 | "metadata": {},
1453 | "output_type": "display_data"
1454 | },
1455 | {
1456 | "name": "stdout",
1457 | "output_type": "stream",
1458 | "text": [
1459 | "Epoch : 11 - val_loss : 6.7372 - val_acc: 0.9137\n",
1460 | "\n",
1461 | "SAVED!\n"
1462 | ]
1463 | },
1464 | {
1465 | "data": {
1466 | "application/vnd.jupyter.widget-view+json": {
1467 | "model_id": "6b239464e5bf46908b445345dc0eb523",
1468 | "version_major": 2,
1469 | "version_minor": 0
1470 | },
1471 | "text/plain": [
1472 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1473 | ]
1474 | },
1475 | "metadata": {},
1476 | "output_type": "display_data"
1477 | },
1478 | {
1479 | "name": "stdout",
1480 | "output_type": "stream",
1481 | "text": [
1482 | "Epoch : 12 - loss : 1.5912 - acc: 0.9977\n",
1483 | "\n"
1484 | ]
1485 | },
1486 | {
1487 | "data": {
1488 | "application/vnd.jupyter.widget-view+json": {
1489 | "model_id": "e9828497a70f455188333f870c7eb5ff",
1490 | "version_major": 2,
1491 | "version_minor": 0
1492 | },
1493 | "text/plain": [
1494 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1495 | ]
1496 | },
1497 | "metadata": {},
1498 | "output_type": "display_data"
1499 | },
1500 | {
1501 | "name": "stdout",
1502 | "output_type": "stream",
1503 | "text": [
1504 | "Epoch : 12 - val_loss : 6.0404 - val_acc: 0.9189\n",
1505 | "\n",
1506 | "SAVED!\n",
1507 | "Unfrozen Blocks: 7, Current lr: 6.291456000000001e-05, Trainable Params: 63490049\n"
1508 | ]
1509 | },
1510 | {
1511 | "data": {
1512 | "application/vnd.jupyter.widget-view+json": {
1513 | "model_id": "e46c51355950451a98a10a10d833c441",
1514 | "version_major": 2,
1515 | "version_minor": 0
1516 | },
1517 | "text/plain": [
1518 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1519 | ]
1520 | },
1521 | "metadata": {},
1522 | "output_type": "display_data"
1523 | },
1524 | {
1525 | "name": "stdout",
1526 | "output_type": "stream",
1527 | "text": [
1528 | "Epoch : 13 - loss : 1.3100 - acc: 0.9984\n",
1529 | "\n"
1530 | ]
1531 | },
1532 | {
1533 | "data": {
1534 | "application/vnd.jupyter.widget-view+json": {
1535 | "model_id": "ca1050b552cc41289c74862a47398dab",
1536 | "version_major": 2,
1537 | "version_minor": 0
1538 | },
1539 | "text/plain": [
1540 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1541 | ]
1542 | },
1543 | "metadata": {},
1544 | "output_type": "display_data"
1545 | },
1546 | {
1547 | "name": "stdout",
1548 | "output_type": "stream",
1549 | "text": [
1550 | "Epoch : 13 - val_loss : 5.8097 - val_acc: 0.9230\n",
1551 | "\n",
1552 | "SAVED!\n"
1553 | ]
1554 | },
1555 | {
1556 | "data": {
1557 | "application/vnd.jupyter.widget-view+json": {
1558 | "model_id": "86c1b41c6a4c4f1da8e89f803c675709",
1559 | "version_major": 2,
1560 | "version_minor": 0
1561 | },
1562 | "text/plain": [
1563 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1564 | ]
1565 | },
1566 | "metadata": {},
1567 | "output_type": "display_data"
1568 | },
1569 | {
1570 | "name": "stdout",
1571 | "output_type": "stream",
1572 | "text": [
1573 | "Epoch : 14 - loss : 1.0894 - acc: 0.9991\n",
1574 | "\n"
1575 | ]
1576 | },
1577 | {
1578 | "data": {
1579 | "application/vnd.jupyter.widget-view+json": {
1580 | "model_id": "0b727d9790b64dab897630c639b05a6a",
1581 | "version_major": 2,
1582 | "version_minor": 0
1583 | },
1584 | "text/plain": [
1585 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1586 | ]
1587 | },
1588 | "metadata": {},
1589 | "output_type": "display_data"
1590 | },
1591 | {
1592 | "name": "stdout",
1593 | "output_type": "stream",
1594 | "text": [
1595 | "Epoch : 14 - val_loss : 5.1302 - val_acc: 0.9321\n",
1596 | "\n",
1597 | "SAVED!\n",
1598 | "Unfrozen Blocks: 8, Current lr: 5.0331648000000016e-05, Trainable Params: 70577921\n"
1599 | ]
1600 | },
1601 | {
1602 | "data": {
1603 | "application/vnd.jupyter.widget-view+json": {
1604 | "model_id": "ddf2a4c6d42d4298b84bca0cda2f78df",
1605 | "version_major": 2,
1606 | "version_minor": 0
1607 | },
1608 | "text/plain": [
1609 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1610 | ]
1611 | },
1612 | "metadata": {},
1613 | "output_type": "display_data"
1614 | },
1615 | {
1616 | "name": "stdout",
1617 | "output_type": "stream",
1618 | "text": [
1619 | "Epoch : 15 - loss : 0.9347 - acc: 0.9992\n",
1620 | "\n"
1621 | ]
1622 | },
1623 | {
1624 | "data": {
1625 | "application/vnd.jupyter.widget-view+json": {
1626 | "model_id": "6c69748f472b4d1d9c8c0c2ac0b893d5",
1627 | "version_major": 2,
1628 | "version_minor": 0
1629 | },
1630 | "text/plain": [
1631 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1632 | ]
1633 | },
1634 | "metadata": {},
1635 | "output_type": "display_data"
1636 | },
1637 | {
1638 | "name": "stdout",
1639 | "output_type": "stream",
1640 | "text": [
1641 | "Epoch : 15 - val_loss : 5.5233 - val_acc: 0.9217\n",
1642 | "\n"
1643 | ]
1644 | },
1645 | {
1646 | "data": {
1647 | "application/vnd.jupyter.widget-view+json": {
1648 | "model_id": "be86fea29a08418fa1fed8cb91aee99c",
1649 | "version_major": 2,
1650 | "version_minor": 0
1651 | },
1652 | "text/plain": [
1653 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1654 | ]
1655 | },
1656 | "metadata": {},
1657 | "output_type": "display_data"
1658 | },
1659 | {
1660 | "name": "stdout",
1661 | "output_type": "stream",
1662 | "text": [
1663 | "Epoch : 16 - loss : 0.9086 - acc: 0.9996\n",
1664 | "\n"
1665 | ]
1666 | },
1667 | {
1668 | "data": {
1669 | "application/vnd.jupyter.widget-view+json": {
1670 | "model_id": "1616dcc149cd46ac904968b474f6cfe3",
1671 | "version_major": 2,
1672 | "version_minor": 0
1673 | },
1674 | "text/plain": [
1675 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1676 | ]
1677 | },
1678 | "metadata": {},
1679 | "output_type": "display_data"
1680 | },
1681 | {
1682 | "name": "stdout",
1683 | "output_type": "stream",
1684 | "text": [
1685 | "Epoch : 16 - val_loss : 4.4655 - val_acc: 0.9362\n",
1686 | "\n",
1687 | "SAVED!\n",
1688 | "Unfrozen Blocks: 9, Current lr: 4.026531840000002e-05, Trainable Params: 77665793\n"
1689 | ]
1690 | },
1691 | {
1692 | "data": {
1693 | "application/vnd.jupyter.widget-view+json": {
1694 | "model_id": "3fb8fcd15c544ca4a23474ee01c35e91",
1695 | "version_major": 2,
1696 | "version_minor": 0
1697 | },
1698 | "text/plain": [
1699 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1700 | ]
1701 | },
1702 | "metadata": {},
1703 | "output_type": "display_data"
1704 | },
1705 | {
1706 | "name": "stdout",
1707 | "output_type": "stream",
1708 | "text": [
1709 | "Epoch : 17 - loss : 0.7159 - acc: 0.9999\n",
1710 | "\n"
1711 | ]
1712 | },
1713 | {
1714 | "data": {
1715 | "application/vnd.jupyter.widget-view+json": {
1716 | "model_id": "1af37990febc45fc9a8cb3e6921de68b",
1717 | "version_major": 2,
1718 | "version_minor": 0
1719 | },
1720 | "text/plain": [
1721 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1722 | ]
1723 | },
1724 | "metadata": {},
1725 | "output_type": "display_data"
1726 | },
1727 | {
1728 | "name": "stdout",
1729 | "output_type": "stream",
1730 | "text": [
1731 | "Epoch : 17 - val_loss : 4.2927 - val_acc: 0.9414\n",
1732 | "\n",
1733 | "SAVED!\n"
1734 | ]
1735 | },
1736 | {
1737 | "data": {
1738 | "application/vnd.jupyter.widget-view+json": {
1739 | "model_id": "277af575763647ddb9be0c49cbd7fb4f",
1740 | "version_major": 2,
1741 | "version_minor": 0
1742 | },
1743 | "text/plain": [
1744 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1745 | ]
1746 | },
1747 | "metadata": {},
1748 | "output_type": "display_data"
1749 | },
1750 | {
1751 | "name": "stdout",
1752 | "output_type": "stream",
1753 | "text": [
1754 | "Epoch : 18 - loss : 0.6362 - acc: 0.9998\n",
1755 | "\n"
1756 | ]
1757 | },
1758 | {
1759 | "data": {
1760 | "application/vnd.jupyter.widget-view+json": {
1761 | "model_id": "01e3ff670c0c4613b93381f273ee20a0",
1762 | "version_major": 2,
1763 | "version_minor": 0
1764 | },
1765 | "text/plain": [
1766 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1767 | ]
1768 | },
1769 | "metadata": {},
1770 | "output_type": "display_data"
1771 | },
1772 | {
1773 | "name": "stdout",
1774 | "output_type": "stream",
1775 | "text": [
1776 | "Epoch : 18 - val_loss : 4.2925 - val_acc: 0.9453\n",
1777 | "\n",
1778 | "SAVED!\n",
1779 | "Unfrozen Blocks: 10, Current lr: 3.221225472000002e-05, Trainable Params: 84753665\n"
1780 | ]
1781 | },
1782 | {
1783 | "data": {
1784 | "application/vnd.jupyter.widget-view+json": {
1785 | "model_id": "b98ea55bef3a4cf9acb4b5dee8d2d911",
1786 | "version_major": 2,
1787 | "version_minor": 0
1788 | },
1789 | "text/plain": [
1790 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1791 | ]
1792 | },
1793 | "metadata": {},
1794 | "output_type": "display_data"
1795 | },
1796 | {
1797 | "name": "stdout",
1798 | "output_type": "stream",
1799 | "text": [
1800 | "Epoch : 19 - loss : 0.6389 - acc: 0.9997\n",
1801 | "\n"
1802 | ]
1803 | },
1804 | {
1805 | "data": {
1806 | "application/vnd.jupyter.widget-view+json": {
1807 | "model_id": "f7ec952d59a84c618b9fd97cc14a6923",
1808 | "version_major": 2,
1809 | "version_minor": 0
1810 | },
1811 | "text/plain": [
1812 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1813 | ]
1814 | },
1815 | "metadata": {},
1816 | "output_type": "display_data"
1817 | },
1818 | {
1819 | "name": "stdout",
1820 | "output_type": "stream",
1821 | "text": [
1822 | "Epoch : 19 - val_loss : 4.5622 - val_acc: 0.9319\n",
1823 | "\n"
1824 | ]
1825 | },
1826 | {
1827 | "data": {
1828 | "application/vnd.jupyter.widget-view+json": {
1829 | "model_id": "44f3727d7f5b43b9ae6fd96fd1f66032",
1830 | "version_major": 2,
1831 | "version_minor": 0
1832 | },
1833 | "text/plain": [
1834 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1835 | ]
1836 | },
1837 | "metadata": {},
1838 | "output_type": "display_data"
1839 | },
1840 | {
1841 | "name": "stdout",
1842 | "output_type": "stream",
1843 | "text": [
1844 | "Epoch : 20 - loss : 0.5667 - acc: 0.9998\n",
1845 | "\n"
1846 | ]
1847 | },
1848 | {
1849 | "data": {
1850 | "application/vnd.jupyter.widget-view+json": {
1851 | "model_id": "30bea7568e694900876835ca696d7753",
1852 | "version_major": 2,
1853 | "version_minor": 0
1854 | },
1855 | "text/plain": [
1856 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1857 | ]
1858 | },
1859 | "metadata": {},
1860 | "output_type": "display_data"
1861 | },
1862 | {
1863 | "name": "stdout",
1864 | "output_type": "stream",
1865 | "text": [
1866 | "Epoch : 20 - val_loss : 4.6590 - val_acc: 0.9254\n",
1867 | "\n",
1868 | "Unfrozen Blocks: 11, Current lr: 2.5769803776000016e-05, Trainable Params: 91841537\n"
1869 | ]
1870 | },
1871 | {
1872 | "data": {
1873 | "application/vnd.jupyter.widget-view+json": {
1874 | "model_id": "178dec4505884c3f9dd0716e2e22bef2",
1875 | "version_major": 2,
1876 | "version_minor": 0
1877 | },
1878 | "text/plain": [
1879 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1880 | ]
1881 | },
1882 | "metadata": {},
1883 | "output_type": "display_data"
1884 | },
1885 | {
1886 | "name": "stdout",
1887 | "output_type": "stream",
1888 | "text": [
1889 | "Epoch : 21 - loss : 0.5401 - acc: 0.9998\n",
1890 | "\n"
1891 | ]
1892 | },
1893 | {
1894 | "data": {
1895 | "application/vnd.jupyter.widget-view+json": {
1896 | "model_id": "4d43c07645e94afaa6e63d6142ddf2f1",
1897 | "version_major": 2,
1898 | "version_minor": 0
1899 | },
1900 | "text/plain": [
1901 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1902 | ]
1903 | },
1904 | "metadata": {},
1905 | "output_type": "display_data"
1906 | },
1907 | {
1908 | "name": "stdout",
1909 | "output_type": "stream",
1910 | "text": [
1911 | "Epoch : 21 - val_loss : 3.8805 - val_acc: 0.9401\n",
1912 | "\n"
1913 | ]
1914 | },
1915 | {
1916 | "data": {
1917 | "application/vnd.jupyter.widget-view+json": {
1918 | "model_id": "47d5ad2ffad44135866f5d204f692c65",
1919 | "version_major": 2,
1920 | "version_minor": 0
1921 | },
1922 | "text/plain": [
1923 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1924 | ]
1925 | },
1926 | "metadata": {},
1927 | "output_type": "display_data"
1928 | },
1929 | {
1930 | "name": "stdout",
1931 | "output_type": "stream",
1932 | "text": [
1933 | "Epoch : 22 - loss : 0.6303 - acc: 0.9991\n",
1934 | "\n"
1935 | ]
1936 | },
1937 | {
1938 | "data": {
1939 | "application/vnd.jupyter.widget-view+json": {
1940 | "model_id": "d353c31d9edc454083f96fdca3e3baa4",
1941 | "version_major": 2,
1942 | "version_minor": 0
1943 | },
1944 | "text/plain": [
1945 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1946 | ]
1947 | },
1948 | "metadata": {},
1949 | "output_type": "display_data"
1950 | },
1951 | {
1952 | "name": "stdout",
1953 | "output_type": "stream",
1954 | "text": [
1955 | "Epoch : 22 - val_loss : 4.4941 - val_acc: 0.9375\n",
1956 | "\n",
1957 | "Unfrozen Blocks: 12, Current lr: 2.0615843020800013e-05, Trainable Params: 91841537\n"
1958 | ]
1959 | },
1960 | {
1961 | "data": {
1962 | "application/vnd.jupyter.widget-view+json": {
1963 | "model_id": "8a68e3064ab547538cc3d505c0122c3a",
1964 | "version_major": 2,
1965 | "version_minor": 0
1966 | },
1967 | "text/plain": [
1968 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
1969 | ]
1970 | },
1971 | "metadata": {},
1972 | "output_type": "display_data"
1973 | },
1974 | {
1975 | "name": "stdout",
1976 | "output_type": "stream",
1977 | "text": [
1978 | "Epoch : 23 - loss : 0.5186 - acc: 0.9997\n",
1979 | "\n"
1980 | ]
1981 | },
1982 | {
1983 | "data": {
1984 | "application/vnd.jupyter.widget-view+json": {
1985 | "model_id": "83722d7b0b7d406ca2bab64af3881225",
1986 | "version_major": 2,
1987 | "version_minor": 0
1988 | },
1989 | "text/plain": [
1990 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
1991 | ]
1992 | },
1993 | "metadata": {},
1994 | "output_type": "display_data"
1995 | },
1996 | {
1997 | "name": "stdout",
1998 | "output_type": "stream",
1999 | "text": [
2000 | "Epoch : 23 - val_loss : 4.0348 - val_acc: 0.9435\n",
2001 | "\n"
2002 | ]
2003 | },
2004 | {
2005 | "data": {
2006 | "application/vnd.jupyter.widget-view+json": {
2007 | "model_id": "4b535b57cc6443ef80194950ffc73705",
2008 | "version_major": 2,
2009 | "version_minor": 0
2010 | },
2011 | "text/plain": [
2012 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
2013 | ]
2014 | },
2015 | "metadata": {},
2016 | "output_type": "display_data"
2017 | },
2018 | {
2019 | "name": "stdout",
2020 | "output_type": "stream",
2021 | "text": [
2022 | "Epoch : 24 - loss : 0.4421 - acc: 0.9999\n",
2023 | "\n"
2024 | ]
2025 | },
2026 | {
2027 | "data": {
2028 | "application/vnd.jupyter.widget-view+json": {
2029 | "model_id": "15720d615707427b979d560257ae9c13",
2030 | "version_major": 2,
2031 | "version_minor": 0
2032 | },
2033 | "text/plain": [
2034 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
2035 | ]
2036 | },
2037 | "metadata": {},
2038 | "output_type": "display_data"
2039 | },
2040 | {
2041 | "name": "stdout",
2042 | "output_type": "stream",
2043 | "text": [
2044 | "Epoch : 24 - val_loss : 3.6783 - val_acc: 0.9464\n",
2045 | "\n",
2046 | "SAVED!\n",
2047 | "Unfrozen Blocks: 13, Current lr: 1.649267441664001e-05, Trainable Params: 91841537\n"
2048 | ]
2049 | },
2050 | {
2051 | "data": {
2052 | "application/vnd.jupyter.widget-view+json": {
2053 | "model_id": "d647ef490d5e4397885957e56c46c4db",
2054 | "version_major": 2,
2055 | "version_minor": 0
2056 | },
2057 | "text/plain": [
2058 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
2059 | ]
2060 | },
2061 | "metadata": {},
2062 | "output_type": "display_data"
2063 | },
2064 | {
2065 | "name": "stdout",
2066 | "output_type": "stream",
2067 | "text": [
2068 | "Epoch : 25 - loss : 0.4184 - acc: 1.0000\n",
2069 | "\n"
2070 | ]
2071 | },
2072 | {
2073 | "data": {
2074 | "application/vnd.jupyter.widget-view+json": {
2075 | "model_id": "3c5c37e9cea84f249edea413cad517b3",
2076 | "version_major": 2,
2077 | "version_minor": 0
2078 | },
2079 | "text/plain": [
2080 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
2081 | ]
2082 | },
2083 | "metadata": {},
2084 | "output_type": "display_data"
2085 | },
2086 | {
2087 | "name": "stdout",
2088 | "output_type": "stream",
2089 | "text": [
2090 | "Epoch : 25 - val_loss : 3.9668 - val_acc: 0.9425\n",
2091 | "\n"
2092 | ]
2093 | },
2094 | {
2095 | "data": {
2096 | "application/vnd.jupyter.widget-view+json": {
2097 | "model_id": "aa957165fe3d460c85a58fe67701842e",
2098 | "version_major": 2,
2099 | "version_minor": 0
2100 | },
2101 | "text/plain": [
2102 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
2103 | ]
2104 | },
2105 | "metadata": {},
2106 | "output_type": "display_data"
2107 | },
2108 | {
2109 | "name": "stdout",
2110 | "output_type": "stream",
2111 | "text": [
2112 | "Epoch : 26 - loss : 0.4113 - acc: 1.0000\n",
2113 | "\n"
2114 | ]
2115 | },
2116 | {
2117 | "data": {
2118 | "application/vnd.jupyter.widget-view+json": {
2119 | "model_id": "df1f36d9d95749a0bc03bba147d001ad",
2120 | "version_major": 2,
2121 | "version_minor": 0
2122 | },
2123 | "text/plain": [
2124 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
2125 | ]
2126 | },
2127 | "metadata": {},
2128 | "output_type": "display_data"
2129 | },
2130 | {
2131 | "name": "stdout",
2132 | "output_type": "stream",
2133 | "text": [
2134 | "Epoch : 26 - val_loss : 3.9590 - val_acc: 0.9398\n",
2135 | "\n",
2136 | "Unfrozen Blocks: 14, Current lr: 1.319413953331201e-05, Trainable Params: 91841537\n"
2137 | ]
2138 | },
2139 | {
2140 | "data": {
2141 | "application/vnd.jupyter.widget-view+json": {
2142 | "model_id": "b7a42e21863d4061a7616e79d0d62aeb",
2143 | "version_major": 2,
2144 | "version_minor": 0
2145 | },
2146 | "text/plain": [
2147 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
2148 | ]
2149 | },
2150 | "metadata": {},
2151 | "output_type": "display_data"
2152 | },
2153 | {
2154 | "name": "stdout",
2155 | "output_type": "stream",
2156 | "text": [
2157 | "Epoch : 27 - loss : 0.3976 - acc: 1.0000\n",
2158 | "\n"
2159 | ]
2160 | },
2161 | {
2162 | "data": {
2163 | "application/vnd.jupyter.widget-view+json": {
2164 | "model_id": "27e232c16fbd46fe9c6d093b9f2683bc",
2165 | "version_major": 2,
2166 | "version_minor": 0
2167 | },
2168 | "text/plain": [
2169 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
2170 | ]
2171 | },
2172 | "metadata": {},
2173 | "output_type": "display_data"
2174 | },
2175 | {
2176 | "name": "stdout",
2177 | "output_type": "stream",
2178 | "text": [
2179 | "Epoch : 27 - val_loss : 3.8370 - val_acc: 0.9414\n",
2180 | "\n"
2181 | ]
2182 | },
2183 | {
2184 | "data": {
2185 | "application/vnd.jupyter.widget-view+json": {
2186 | "model_id": "36e13b0ec23445d0ad79294a9772e49e",
2187 | "version_major": 2,
2188 | "version_minor": 0
2189 | },
2190 | "text/plain": [
2191 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
2192 | ]
2193 | },
2194 | "metadata": {},
2195 | "output_type": "display_data"
2196 | },
2197 | {
2198 | "name": "stdout",
2199 | "output_type": "stream",
2200 | "text": [
2201 | "Epoch : 28 - loss : 0.3917 - acc: 1.0000\n",
2202 | "\n"
2203 | ]
2204 | },
2205 | {
2206 | "data": {
2207 | "application/vnd.jupyter.widget-view+json": {
2208 | "model_id": "2ae58768db8a4bc98c8d2586bb563499",
2209 | "version_major": 2,
2210 | "version_minor": 0
2211 | },
2212 | "text/plain": [
2213 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
2214 | ]
2215 | },
2216 | "metadata": {},
2217 | "output_type": "display_data"
2218 | },
2219 | {
2220 | "name": "stdout",
2221 | "output_type": "stream",
2222 | "text": [
2223 | "Epoch : 28 - val_loss : 3.8097 - val_acc: 0.9422\n",
2224 | "\n",
2225 | "Unfrozen Blocks: 15, Current lr: 1.0555311626649608e-05, Trainable Params: 91841537\n"
2226 | ]
2227 | },
2228 | {
2229 | "data": {
2230 | "application/vnd.jupyter.widget-view+json": {
2231 | "model_id": "ead97d19cb7640fa878a0f98ecc59db2",
2232 | "version_major": 2,
2233 | "version_minor": 0
2234 | },
2235 | "text/plain": [
2236 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=381.0), HTML(value='')))"
2237 | ]
2238 | },
2239 | "metadata": {},
2240 | "output_type": "display_data"
2241 | },
2242 | {
2243 | "name": "stdout",
2244 | "output_type": "stream",
2245 | "text": [
2246 | "Epoch : 29 - loss : 0.3875 - acc: 1.0000\n",
2247 | "\n"
2248 | ]
2249 | },
2250 | {
2251 | "data": {
2252 | "application/vnd.jupyter.widget-view+json": {
2253 | "model_id": "64008a2d743a4aa88363b78b330385ca",
2254 | "version_major": 2,
2255 | "version_minor": 0
2256 | },
2257 | "text/plain": [
2258 | "HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))"
2259 | ]
2260 | },
2261 | "metadata": {},
2262 | "output_type": "display_data"
2263 | },
2264 | {
2265 | "name": "stdout",
2266 | "output_type": "stream",
2267 | "text": [
2268 | "Epoch : 29 - val_loss : 0.5044 - val_acc: 0.1576\r"
2269 | ]
2270 | }
2271 | ],
2272 | "source": [
2273 | "best_acc = 0.0\n",
2274 | "y_loss = {} # loss history\n",
2275 | "y_loss['train'] = []\n",
2276 | "y_loss['val'] = []\n",
2277 | "y_err = {}\n",
2278 | "y_err['train'] = []\n",
2279 | "y_err['val'] = []\n",
2280 | "print(\"training...\")\n",
2281 | "output_dir = \"\"\n",
2282 | "best_acc = 0\n",
2283 | "name = \"la_with_lmbd_{}\".format(lmbd)\n",
2284 | "\n",
2285 | "try:\n",
2286 | " os.mkdir(\"model/\" + name)\n",
2287 | "\n",
2288 | "except:\n",
2289 | " pass\n",
2290 | "output_dir = \"model/\" + name\n",
2291 | "unfrozen_blocks = 0\n",
2292 | "\n",
2293 | "for epoch in range(num_epochs):\n",
2294 | "\n",
2295 | " if epoch%unfreeze_after==0:\n",
2296 | " unfrozen_blocks += 1\n",
2297 | " model = unfreeze_blocks(model, unfrozen_blocks)\n",
2298 | " optimizer.param_groups[0]['lr'] *= lr_decay \n",
2299 | " trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
2300 | " print(\"Unfrozen Blocks: {}, Current lr: {}, Trainable Params: {}\".format(unfrozen_blocks, \n",
2301 | " optimizer.param_groups[0]['lr'], \n",
2302 | " trainable_params))\n",
2303 | "\n",
2304 | " train_metrics = train_one_epoch(\n",
2305 | " epoch, model, train_loader, optimizer, criterion,\n",
2306 | " lr_scheduler=None, saver=None)\n",
2307 | "\n",
2308 | " eval_metrics = validate(model, valid_loader, criterion)\n",
2309 | "\n",
2310 | "\n",
2311 | " # update summary\n",
2312 | " update_summary(epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),\n",
2313 | " write_header=True)\n",
2314 | "\n",
2315 | " # deep copy the model\n",
2316 | " last_model_wts = model.state_dict()\n",
2317 | " if eval_metrics['val_accuracy'] > best_acc:\n",
2318 | " best_acc = eval_metrics['val_accuracy']\n",
2319 | " save_network(model, epoch,name)\n",
2320 | " print(\"SAVED!\")"
2321 | ]
2322 | },
2323 | {
2324 | "cell_type": "code",
2325 | "execution_count": null,
2326 | "metadata": {},
2327 | "outputs": [],
2328 | "source": []
2329 | },
2330 | {
2331 | "cell_type": "code",
2332 | "execution_count": null,
2333 | "metadata": {},
2334 | "outputs": [],
2335 | "source": []
2336 | }
2337 | ],
2338 | "metadata": {
2339 | "kernelspec": {
2340 | "display_name": "Python 3",
2341 | "language": "python",
2342 | "name": "python3"
2343 | },
2344 | "language_info": {
2345 | "codemirror_mode": {
2346 | "name": "ipython",
2347 | "version": 3
2348 | },
2349 | "file_extension": ".py",
2350 | "mimetype": "text/x-python",
2351 | "name": "python",
2352 | "nbconvert_exporter": "python",
2353 | "pygments_lexer": "ipython3",
2354 | "version": "3.7.4"
2355 | }
2356 | },
2357 | "nbformat": 4,
2358 | "nbformat_minor": 4
2359 | }
2360 |
--------------------------------------------------------------------------------
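The training cell above drives a progressive-unfreezing schedule: freeze_all_blocks(model) is called once after the model is built, and every unfreeze_after epochs one more transformer block is unfrozen while the learning rate is decayed by lr_decay. The two helpers are defined earlier in the notebook; a minimal sketch of what they plausibly do, assuming they simply toggle requires_grad on the blocks of the wrapped timm ViT backbone (each unfreeze step in the log adds roughly 7.1M trainable parameters, the size of one ViT-Base block):

def freeze_all_blocks(model):
    # freeze every transformer block of the wrapped timm ViT backbone
    for block in model.model.blocks:
        for param in block.parameters():
            param.requires_grad = False

def unfreeze_blocks(model, num_blocks):
    # unfreeze the last num_blocks transformer blocks
    for block in model.model.blocks[-num_blocks:]:
        for param in block.parameters():
            param.requires_grad = True
    return model

Whether unfreezing proceeds from the deepest block backwards (as sketched here, the usual choice when fine-tuning) or from the first block forwards is an implementation detail of the notebook.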
/LATransformer/metrics.py:
--------------------------------------------------------------------------------
1 | def rank1(label, output):  # True if the top-ranked gallery entry matches the query label
2 | if label==output[1][0][0]:
3 | return True
4 | return False
5 |
6 | def rank5(label, output):  # True if the query label appears in the top 5 ranks
7 | if label in output[1][0][:5]:
8 | return True
9 | return False
10 |
11 | def rank10(label, output):  # True if the query label appears in the top 10 ranks
12 | if label in output[1][0][:10]:
13 | return True
14 | return False
15 |
16 | def calc_map(label, output):  # average precision (AP) of a single ranked list
17 | count = 0
18 | score = 0
19 | good = 0
20 | for out in output[1][0]:
21 | count += 1
22 | if out==label:
23 | good += 1
24 |             score += (good / count)  # precision at the k-th correct match
25 | if good==0:
26 | return 0
27 | return score/good
--------------------------------------------------------------------------------
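All four metrics index output[1][0], i.e. they expect output to be the (distances, ranked indices) pair returned by a faiss search for a single query, with the indices already mapped to identity labels. calc_map returns the average precision (AP) of one ranked list; averaging it over all queries gives mAP. A usage sketch with synthetic data (the gallery vectors, labels, and index below are illustrative, not part of the repo):

import numpy as np
import faiss

from LATransformer.metrics import rank1, rank5, rank10, calc_map

# synthetic gallery: 100 vectors of dimension 768 with identity labels 0..9
rng = np.random.default_rng(0)
gallery_vecs = rng.standard_normal((100, 768)).astype('float32')
gallery_labels = rng.integers(0, 10, size=100)

index = faiss.IndexFlatL2(768)   # exact L2 nearest-neighbour search
index.add(gallery_vecs)

query_vec, query_label = gallery_vecs[:1], gallery_labels[0]
dists, ids = index.search(query_vec, 100)   # both shaped (1, 100), nearest first
output = (dists, gallery_labels[ids])       # translate gallery row ids to labels

print(rank1(query_label, output))     # True if the nearest neighbour matches
print(rank5(query_label, output))     # True if a match appears in the top 5
print(calc_map(query_label, output))  # AP of this single query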
/LATransformer/model.py:
--------------------------------------------------------------------------------
1 | import timm
2 | import numpy as np
3 | import pandas as pd
4 | from PIL import Image
5 | from tqdm.notebook import tqdm
6 | import matplotlib.pyplot as plt
7 | from collections import OrderedDict
8 | from sklearn.model_selection import train_test_split
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn import init
13 | import torch.optim as optim
14 | from torchvision import models
15 | import torch.nn.functional as F
16 | from torch.autograd import Variable
17 | from torch.optim.lr_scheduler import StepLR
18 | from torchvision import datasets, transforms
19 | from torch.utils.data import DataLoader, Dataset
20 |
21 |
22 | # weights initialization
23 | def weights_init_kaiming(m):
24 | classname = m.__class__.__name__
25 | # print(classname)
26 | if classname.find('Conv') != -1:
27 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # For old pytorch, you may use kaiming_normal.
28 | elif classname.find('Linear') != -1:
29 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
30 | init.constant_(m.bias.data, 0.0)
31 | elif classname.find('BatchNorm1d') != -1:
32 | init.normal_(m.weight.data, 1.0, 0.02)
33 | init.constant_(m.bias.data, 0.0)
34 |
35 | def weights_init_classifier(m):
36 | classname = m.__class__.__name__
37 | if classname.find('Linear') != -1:
38 | init.normal_(m.weight.data, std=0.001)
39 | init.constant_(m.bias.data, 0.0)
40 |
41 | class ClassBlock(nn.Module):
42 | def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True, return_f = False):
43 | super(ClassBlock, self).__init__()
44 | self.return_f = return_f
45 | add_block = []
46 | if linear:
47 | add_block += [nn.Linear(input_dim, num_bottleneck)]
48 | else:
49 | num_bottleneck = input_dim
50 | if bnorm:
51 | add_block += [nn.BatchNorm1d(num_bottleneck)]
52 | if relu:
53 | add_block += [nn.LeakyReLU(0.1)]
54 | if droprate>0:
55 | add_block += [nn.Dropout(p=droprate)]
56 | add_block = nn.Sequential(*add_block)
57 | add_block.apply(weights_init_kaiming)
58 |
59 | classifier = []
60 | classifier += [nn.Linear(num_bottleneck, class_num)]
61 | classifier = nn.Sequential(*classifier)
62 | classifier.apply(weights_init_classifier)
63 |
64 | self.add_block = add_block
65 | self.classifier = classifier
66 | def forward(self, x):
67 | x = self.add_block(x)
68 | if self.return_f:
69 | f = x
70 | x = self.classifier(x)
71 | return [x,f]
72 | else:
73 | x = self.classifier(x)
74 | return x
75 |
76 | class LATransformer(nn.Module):
77 |     def __init__(self, model, lmbd):
78 | super(LATransformer, self).__init__()
79 |
80 | self.class_num = 751
81 |         self.part = 14  # pool the 196 patch tokens (14 x 14) into 14 locally aware parts
82 | self.num_blocks = 12
83 | self.model = model
84 |         self.model.head.requires_grad_(False)  # freeze the backbone's classification head
85 | self.cls_token = self.model.cls_token
86 | self.pos_embed = self.model.pos_embed
87 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,768))
88 | self.dropout = nn.Dropout(p=0.5)
89 | self.lmbd = lmbd
90 | for i in range(self.part):
91 | name = 'classifier'+str(i)
92 | setattr(self, name, ClassBlock(768, self.class_num, droprate=0.5, relu=False, bnorm=True, num_bottleneck=256))
93 |
94 |
95 |
96 |     def forward(self, x):
97 |
98 | # Divide input image into patch embeddings and add position embeddings
99 | x = self.model.patch_embed(x)
100 | cls_token = self.cls_token.expand(x.shape[0], -1, -1)
101 | x = torch.cat((cls_token, x), dim=1)
102 | x = self.model.pos_drop(x + self.pos_embed)
103 |
104 | # Feed forward through transformer blocks
105 | for i in range(self.num_blocks):
106 | x = self.model.blocks[i](x)
107 | x = self.model.norm(x)
108 |
109 | # extract the cls token
110 | cls_token_out = x[:, 0].unsqueeze(1)
111 |
112 | # Average pool
113 | x = self.avgpool(x[:, 1:])
114 |
115 | # Add global cls token to each local token
116 | for i in range(self.part):
117 | out = torch.mul(x[:, i, :], self.lmbd)
118 |             x[:, i, :] = torch.div(torch.add(cls_token_out.squeeze(), out), 1 + self.lmbd)
119 |
120 | # Locally aware network
121 | part = {}
122 | predict = {}
123 | for i in range(self.part):
124 | part[i] = x[:,i,:]
125 | name = 'classifier'+str(i)
126 | c = getattr(self,name)
127 | predict[i] = c(part[i])
128 | return predict
129 |
130 | class LATransformerTest(nn.Module):
131 |     def __init__(self, model, lmbd):
132 | super(LATransformerTest, self).__init__()
133 |
134 | self.class_num = 751
135 |         self.part = 14  # pool the 196 patch tokens (14 x 14) into 14 locally aware parts
136 | self.num_blocks = 12
137 | self.model = model
138 |         self.model.head.requires_grad_(False)  # freeze the backbone's classification head
139 | self.cls_token = self.model.cls_token
140 | self.pos_embed = self.model.pos_embed
141 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,768))
142 | self.dropout = nn.Dropout(p=0.5)
143 | self.lmbd = lmbd
144 | # for i in range(self.part):
145 | # name = 'classifier'+str(i)
146 | # setattr(self, name, ClassBlock(768, self.class_num, droprate=0.5, relu=False, bnorm=True, num_bottleneck=256))
147 |
148 |
149 |
150 |     def forward(self, x):
151 |
152 | # Divide input image into patch embeddings and add position embeddings
153 | x = self.model.patch_embed(x)
154 | cls_token = self.cls_token.expand(x.shape[0], -1, -1)
155 | x = torch.cat((cls_token, x), dim=1)
156 | x = self.model.pos_drop(x + self.pos_embed)
157 |
158 | # Feed forward through transformer blocks
159 | for i in range(self.num_blocks):
160 | x = self.model.blocks[i](x)
161 | x = self.model.norm(x)
162 |
163 | # extract the cls token
164 | cls_token_out = x[:, 0].unsqueeze(1)
165 |
166 | # Average pool
167 | x = self.avgpool(x[:, 1:])
168 |
169 | # Add global cls token to each local token
170 | # for i in range(self.part):
171 | # out = torch.mul(x[:, i, :], self.lmbd)
172 | # x[:,i,:] = torch.div(torch.add(cls_token_out.squeeze(),out), 1+self.lmbd)
173 |
174 | return x.cpu()
--------------------------------------------------------------------------------
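In LATransformer.forward, the 196 patch tokens (a 14 x 14 grid for a 224 x 224 input with 16 x 16 patches) pass through all twelve transformer blocks, are average-pooled into part = 14 groups, blended with the global cls token under the weighting lmbd, and classified by 14 independent ClassBlock heads; LATransformerTest omits the heads and returns the pooled feature map for retrieval. A minimal sketch of wiring up the training model, assuming the vit_base_patch16_224 backbone (the module tree printed in the notebook above matches it) and an illustrative lmbd value:

import timm
import torch

from LATransformer.model import LATransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 224x224 input, 16x16 patches -> 196 patch tokens = 14 x 14, matching part = 14
vit_base = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=751)
model = LATransformer(vit_base, lmbd=8.0).to(device).eval()

with torch.no_grad():
    predict = model(torch.randn(2, 3, 224, 224, device=device))

print(len(predict), predict[0].shape)   # 14 part-wise logit tensors, each (2, 751)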
/LATransformer/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import torch
4 | from collections import OrderedDict
5 |
6 | def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
7 | rowd = OrderedDict(epoch=epoch)
8 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
9 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
10 | with open(filename, mode='a') as cf:
11 | dw = csv.DictWriter(cf, fieldnames=rowd.keys())
12 | if write_header: # first iteration (epoch == 1 can't be used)
13 | dw.writeheader()
14 | dw.writerow(rowd)
15 |
16 | def save_network(network, epoch_label, name):
17 |     save_filename = 'net_best.pth'  # epoch_label is unused; the best checkpoint is overwritten in place
18 | save_path = os.path.join('./model',name,save_filename)
19 | torch.save(network.cpu().state_dict(), save_path)
20 |
21 | if torch.cuda.is_available():
22 | network.cuda()
23 |
24 | def get_id(img_path):
25 | camera_id = []
26 | labels = []
27 | for path, v in img_path:
28 | #filename = path.split('/')[-1]
29 | filename = os.path.basename(path)
30 | label = filename[0:4]
31 | camera = filename.split('c')[1]
32 | if label[0:2]=='-1':
33 | labels.append(-1)
34 | else:
35 | labels.append(int(label))
36 | camera_id.append(int(camera[0]))
37 | return camera_id, labels
--------------------------------------------------------------------------------
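get_id assumes Market-1501-style filenames: the first four characters encode the identity (a leading -1 marks junk/distractor images) and the digit after the first 'c' encodes the camera. A quick check with hypothetical paths (the second tuple element is the dataset class index, which get_id ignores):

from LATransformer.utils import get_id

paths = [
    ('data/query/0002_c1s1_000451_03.jpg', 0),   # identity 2, camera 1
    ('data/query/-1_c3s2_000100_01.jpg', 0),     # distractor image, camera 3
]
camera_id, labels = get_id(paths)
print(camera_id, labels)   # [1, 3] [2, -1]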
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Siddhant Kapil
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # Person Re-Identification with a Locally Aware Transformer
2 |
3 |
4 | This code is inspired by:
5 |
6 |
7 | 1) PCB - https://github.com/layumi/Person_reID_baseline_pytorch
8 | 2) ViT - https://github.com/lucidrains/vit-pytorch/tree/main/examples
9 | 3) Pre-trained models: https://github.com/rwightman/pytorch-image-models
10 |
11 | ## Release 7/5/21
12 | Demonstrates the working and performance of the LA-Transformer using two Jupyter notebooks.
13 |
14 | 1) LA-Transformer Training: Demonstrates the training process. We have included cell outputs in the Jupyter notebook; the
15 | last cell shows the training results. One can also refer to model/{name}/summary.csv if the cell outputs are unclear. To
16 | run the notebook, install the requirements, download the dataset using the link provided, and extract it into the data folder.
17 |
18 | 2) LA-Transformer Testing: Demonstrates the testing process. You can download the weights using the link below or train
19 | LA-Transformer using the Training notebook. To use pre-trained weights, download them from the Google Drive link below, extract
20 | them into the model/{name} folder, and run the Testing notebook. Performance metrics are reported in the last cell of the notebook.
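
A minimal sketch of loading the downloaded weights for feature extraction (the folder name follows the notebook's la_with_lmbd_{lmbd} pattern and lmbd=8 is illustrative; strict=False is needed because the saved state dict contains the 14 classifier heads that LATransformerTest does not define):

```python
import timm
import torch

from LATransformer.model import LATransformerTest

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vit_base = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=751)
model = LATransformerTest(vit_base, lmbd=8).to(device)
model.load_state_dict(torch.load('model/la_with_lmbd_8/net_best.pth'), strict=False)
model.eval()
```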
21 |
22 | ## Requirements:
23 |
24 | - Torch==1.8.1 & torchvision==0.8.2: [Link](https://pytorch.org/)
25 | - timm==0.3.2: [Link](https://github.com/rwightman/pytorch-image-models)
26 | - faiss==1.6.3: [Link](https://github.com/facebookresearch/faiss)
27 | - tqdm==4.54.0
28 | - numpy==1.19.5
29 |
30 | ## Read-Only Versions:
31 | LA-Transformer Training.html and LA-Transformer Testing.html are read-only versions that include outputs, so the working of LA-Transformer can be verified quickly.
32 |
33 | ## Google Drive:
34 |
35 | Pretrained weights and the dataset can be found on [this](https://drive.google.com/drive/folders/1CRkfn9iLEItaYur1WGf2abvpd2vT7nRB?usp=sharing) Google Drive. To remain anonymous, we created a temporary Gmail account to host the weights and datasets; it will be changed to an official account later.
36 |
--------------------------------------------------------------------------------