├── .gitignore ├── 01_InpaintingImageWang ├── 01_Inpainting_ImageWang.ipynb ├── 02_DownstreamTask_ImageWang.ipynb ├── 03_ImageWang_Leadboard_128.ipynb ├── 03_ImageWang_Leadboard_192.ipynb ├── 04_ActivationStats.ipynb ├── 04_InvestigateProblemsWithLargerImages.ipynb ├── README.md ├── RandomCutout.py └── config.py ├── 02_InpaintingVaryDatasetSize ├── 01_InpaintingWithVariedDatasetSize.ipynb ├── README.md ├── RandomCutout.py └── config.py ├── 03_PretextTrainingTime ├── 03_PretextTrainingTime.ipynb ├── RandomCutout.py └── config.py ├── 04_ImprovingPretextTask.ipynb ├── LICENSE ├── README.md └── misc ├── 00_BaselineApproaches.ipynb ├── 02_Inpainting_ImageWang.ipynb └── 03_DownstreamTask_Pascal.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # PyCharm settings 10 | .idea/ 11 | 12 | #Model checkpoints 13 | *.pth 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /01_InpaintingImageWang/README.md: -------------------------------------------------------------------------------- 1 | # Image Inpainting 2 | 3 | **Hypothesis**: By training a network on the task of image inpainting, we are left with a set of weights that outperform randomly initialized weights on a downstream task. 4 | 5 | **Result**: True 6 | 7 | - Random weights baseline: **54.0%** accuracy 8 | - Pretext weights with head fine-tuning: **56.3%** accuracy 9 | 10 | 11 | **Methodology**: 12 | 13 | We train a U-Net with an `xresnet34` backbone on the pretext task of image inpainting, in which the network must "fill in" patches that have been cut out of an image. 14 | 15 | ![](https://joshvarty.files.wordpress.com/2020/02/inpainting-3.png) 16 | 17 | For the downstream task, we take the `xresnet34` backbone (initialized either with the pretext weights or randomly, for the baseline), add a `torch.nn.Linear()` head to it, and train/validate it on the ImageWang dataset.
18 | 19 | -------------------------------------------------------------------------------- /01_InpaintingImageWang/RandomCutout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fastai2.vision.all import PILImage, Image 3 | from fastai2.vision.augment import RandTransform 4 | 5 | 6 | # We create this dummy class in order to create a transform that ONLY operates on images of this type 7 | # We will use it to create all input images 8 | class PILImageInput(PILImage): 9 | pass 10 | 11 | 12 | class RandomCutout(RandTransform): 13 | "Picks a random scaled crop of an image and resize it to `size`" 14 | split_idx = None 15 | 16 | def __init__(self, min_n_holes=5, max_n_holes=10, min_length=5, max_length=50, **kwargs): 17 | super().__init__(**kwargs) 18 | self.min_n_holes = min_n_holes 19 | self.max_n_holes = max_n_holes 20 | self.min_length = min_length 21 | self.max_length = max_length 22 | 23 | def encodes(self, x: PILImageInput): 24 | """ 25 | Note that we're accepting our dummy PILImageInput class 26 | fastai2 will only pass images of this type to our encoder. 27 | This means that our transform will only be applied to input images and won't 28 | be run against output images. 29 | """ 30 | 31 | n_holes = np.random.randint(self.min_n_holes, self.max_n_holes) 32 | pixels = np.array(x) # Convert to mutable numpy array. 
FeelsBadMan 33 | h, w = pixels.shape[:2] 34 | 35 | for n in range(n_holes): 36 | h_length = np.random.randint(self.min_length, self.max_length) 37 | w_length = np.random.randint(self.min_length, self.max_length) 38 | h_y = np.random.randint(0, h) 39 | h_x = np.random.randint(0, w) 40 | y1 = int(np.clip(h_y - h_length / 2, 0, h)) 41 | y2 = int(np.clip(h_y + h_length / 2, 0, h)) 42 | x1 = int(np.clip(h_x - w_length / 2, 0, w)) 43 | x2 = int(np.clip(h_x + w_length / 2, 0, w)) 44 | 45 | pixels[y1:y2, x1:x2, :] = 0 46 | 47 | return Image.fromarray(pixels, mode='RGB') -------------------------------------------------------------------------------- /01_InpaintingImageWang/config.py: -------------------------------------------------------------------------------- 1 | from fastai2.layers import Mish, MaxPool 2 | from fastai2.vision.models.xresnet import xresnet34 3 | 4 | config = { 5 | 'lr': 8e-3, 6 | 'size': 128, 7 | 'sqrmom': 0.99, 8 | 'mom': 0.9, 9 | 'eps': 1e-6, 10 | 'epochs': 15, 11 | 'bs': 64, 12 | 'opt': 'ranger', 13 | 'sh': 0., 14 | 'sa': 0, 15 | 'sym': 0, 16 | 'beta': 0., 17 | 'act_fn': Mish, 18 | 'fp16': 0, 19 | 'pool': MaxPool, 20 | 'runs': 1, 21 | 'model': xresnet34 22 | } 23 | -------------------------------------------------------------------------------- /02_InpaintingVaryDatasetSize/01_InpaintingWithVariedDatasetSize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Inpainting with Variable Dataset Size" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Recall that the Image网 dataset consists of:\n", 15 | "\n", 16 | "1. A `/val` folder with 10 classes.\n", 17 | "2. A `/train` folder with 20 classes. \n", 18 | " - There are ~125 images in each class that exists in `/val`. There are \n", 19 | " - There are ~1,300 images in each class that does not exist in `/val`\n", 20 | "3. 
An `/unsup` folder with 7,750 unlabelled images." 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "The question we would like to answer with this notebook is:\n", 28 | "\n", 29 | "> What is the effect of dataset size during pretext training on downstream task performance?\n", 30 | "\n", 31 | "To answer this question we will consider four different datasets, each built from ImageWang.\n", 32 | "\n", 33 | "They are:\n", 34 | "\n", 35 | "1. All data in `/train`, `/unsup` and `/val`\n", 36 | "2. All data in `/train`, `/unsup`\n", 37 | "3. All data in `/train`\n", 38 | "4. Only Data in `/train` that has a corresponding class in `/val`" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 1, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "import gc\n", 48 | "import json\n", 49 | "import torch\n", 50 | "import numpy as np\n", 51 | "\n", 52 | "from config import config\n", 53 | "from RandomCutout import RandomCutout, PILImageInput\n", 54 | "\n", 55 | "from fastai2.basics import *\n", 56 | "from fastai2.vision.all import *\n", 57 | "\n", 58 | "from torch.nn import MSELoss\n", 59 | "from functools import partial" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 2, 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "# Default parameters\n", 69 | "lr=config['lr']\n", 70 | "size=config['size']\n", 71 | "sqrmom=config['sqrmom']\n", 72 | "mom=config['mom']\n", 73 | "eps=config['eps']\n", 74 | "epochs=config['epochs']\n", 75 | "bs=config['bs']\n", 76 | "opt=config['opt']\n", 77 | "sh=config['sh']\n", 78 | "sa=config['sa']\n", 79 | "sym=config['sym']\n", 80 | "beta=config['beta']\n", 81 | "act_fn=config['act_fn']\n", 82 | "fp16=config['fp16']\n", 83 | "pool=config['pool']\n", 84 | "runs=config['runs']\n", 85 | "\n", 86 | "model = config['model']\n", 87 | "\n", 88 | "if opt=='adam' : opt_func = partial(Adam, mom=mom, sqr_mom=sqrmom, eps=eps)\n", 89 | "elif opt=='rms' 
: opt_func = partial(RMSProp, sqr_mom=sqrmom)\n", 90 | "elif opt=='sgd' : opt_func = partial(SGD, mom=mom)\n", 91 | "elif opt=='ranger': opt_func = partial(ranger, mom=mom, sqr_mom=sqrmom, eps=eps, beta=beta)\n", 92 | " \n", 93 | "size = 128\n", 94 | "bs = 64\n", 95 | "runs=3" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 3, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "lr 0.008\n", 108 | "size 128\n", 109 | "sqrmom 0.99\n", 110 | "mom 0.9\n", 111 | "eps 1e-06\n", 112 | "epochs 15\n", 113 | "bs 64\n", 114 | "opt ranger\n", 115 | "sh 0.0\n", 116 | "sa 0\n", 117 | "sym 0\n", 118 | "beta 0.0\n", 119 | "act_fn \n", 120 | "fp16 0\n", 121 | "pool \n", 122 | "runs 3\n", 123 | "model \n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "print(\"lr\", lr)\n", 129 | "print(\"size\", size)\n", 130 | "print(\"sqrmom\", sqrmom)\n", 131 | "print(\"mom\", mom)\n", 132 | "print(\"eps\", eps)\n", 133 | "print(\"epochs\", epochs)\n", 134 | "print(\"bs\", bs)\n", 135 | "print(\"opt\", opt)\n", 136 | "print(\"sh\", sh)\n", 137 | "print(\"sa\", sa)\n", 138 | "print(\"sym\", sym)\n", 139 | "print(\"beta\", beta)\n", 140 | "print(\"act_fn\", act_fn)\n", 141 | "print(\"fp16\", fp16)\n", 142 | "print(\"pool\", pool)\n", 143 | "print(\"runs\", runs)\n", 144 | "print(\"model\", model)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "## Get Items From Folder" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "So before we do anything, let's create some helper methods that will give us only the training sets that we would like." 
159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 4, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "source = untar_data(URLs.IMAGEWANG_160)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 5, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "# transforms are the same for each experiment\n", 177 | "item_tfms=[RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5), RandomCutout]\n", 178 | "batch_tfms=RandomErasing(p=0.9, max_count=3, sh=sh) if sh else None" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 6, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "name": "stdout", 188 | "output_type": "stream", 189 | "text": [ 190 | "All Files:\t26348\n", 191 | "Train Files:\t14669\n", 192 | "Unsup Files:\t7750\n", 193 | "Valid Files:\t3929\n", 194 | "\n", 195 | "Train+Unsup Files: 22419\n", 196 | "Train(in validation set)+Unsup Files: 1275\n" 197 | ] 198 | } 199 | ], 200 | "source": [ 201 | "def get_all_items(path):\n", 202 | " return get_files(path, extensions='.JPEG', recurse=True)\n", 203 | "\n", 204 | "def get_train_items(path):\n", 205 | " return get_files(path/'train', extensions='.JPEG', recurse=True)\n", 206 | "\n", 207 | "def get_unsup_items(path):\n", 208 | " return get_files(path/'unsup', extensions='.JPEG', recurse=True)\n", 209 | "\n", 210 | "def get_valid_items(path):\n", 211 | " return get_files(path/'val', extensions='.JPEG', recurse=True)\n", 212 | "\n", 213 | "def get_train_and_unsup(path):\n", 214 | " return get_train_items(path) + get_unsup_items(path)\n", 215 | "\n", 216 | "def get_train_items_that_are_present_in_val(path):\n", 217 | " \"\"\"\n", 218 | " We first get a list of all classes in /val\n", 219 | " Then we use that list to get all the examples of each class from /train\n", 220 | " \"\"\"\n", 221 | " val = source/'val'\n", 222 | " validation_classes = [path.name for path in val.iterdir()]\n", 223 | " \n", 224 | " 
train_files = L()\n", 225 | " for class_name in validation_classes:\n", 226 | " items = get_files(path/'train'/class_name, extensions='.JPEG', recurse=True)\n", 227 | " train_files = train_files + items\n", 228 | " \n", 229 | " return train_files\n", 230 | "\n", 231 | "all_items = get_all_items(untar_data(URLs.IMAGEWANG_160))\n", 232 | "train_items = get_train_items(untar_data(URLs.IMAGEWANG_160))\n", 233 | "unsup_items = get_unsup_items(untar_data(URLs.IMAGEWANG_160))\n", 234 | "valid_items = get_valid_items(untar_data(URLs.IMAGEWANG_160))\n", 235 | "\n", 236 | "print(\"All Files:\\t{}\".format(len(all_items)))\n", 237 | "print(\"Train Files:\\t{}\".format(len(train_items)))\n", 238 | "print(\"Unsup Files:\\t{}\".format(len(unsup_items)))\n", 239 | "print(\"Valid Files:\\t{}\".format(len(valid_items)))\n", 240 | "print()\n", 241 | "\n", 242 | "train_and_unsup_items = get_train_and_unsup(untar_data(URLs.IMAGEWANG_160))\n", 243 | "print(\"Train+Unsup Files: {}\".format(len(train_and_unsup_items)))\n", 244 | "train_in_valid_items = get_train_items_that_are_present_in_val(untar_data(URLs.IMAGEWANG_160))\n", 245 | "print(\"Train(in validation set)+Unsup Files: {}\".format(len(train_in_valid_items)))" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": {}, 251 | "source": [ 252 | "## Train with all data in `/train`, `/unsup` and `/val`" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 7, 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "name": "stdout", 262 | "output_type": "stream", 263 | "text": [ 264 | "Training Size: 26348\n", 265 | "Validation Size: 0\n" 266 | ] 267 | } 268 | ], 269 | "source": [ 270 | "dblock = DataBlock(blocks=(ImageBlock(cls=PILImageInput), ImageBlock),\n", 271 | " splitter=RandomSplitter(valid_pct=0),\n", 272 | " get_items=get_all_items, \n", 273 | " get_y=lambda o: o,\n", 274 | " item_tfms=item_tfms,\n", 275 | " batch_tfms=batch_tfms)\n", 276 | "\n", 277 | "dbunch = 
dblock.dataloaders(source, path=source, bs=bs)\n", 278 | "\n", 279 | "#CHANGE: We're predicting pixel values, so we're just going to predict an output for each RGB channel\n", 280 | "dbunch.vocab = ['R', 'G', 'B']\n", 281 | "\n", 282 | "print(\"Training Size:\", len(dbunch.train_ds))\n", 283 | "print(\"Validation Size:\", len(dbunch.valid_ds))" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 8, 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "data": { 293 | "text/html": [ 294 | "\n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | "
epochtrain_lossvalid_losstime
00.016735None02:34
10.005480None02:30
20.004872None02:30
30.004522None02:30
40.004380None02:30
50.004298None02:30
60.004284None02:30
70.004252None02:30
80.004163None02:30
90.004135None02:30
100.004204None02:30
110.004075None02:30
120.003972None02:30
130.003911None02:30
140.003798None02:30
" 396 | ], 397 | "text/plain": [ 398 | "" 399 | ] 400 | }, 401 | "metadata": {}, 402 | "output_type": "display_data" 403 | }, 404 | { 405 | "name": "stderr", 406 | "output_type": "stream", 407 | "text": [ 408 | "/home/josh/anaconda3/envs/fastai2/lib/python3.7/site-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.\n", 409 | " warn(\"Your generator is empty.\")\n" 410 | ] 411 | }, 412 | { 413 | "data": { 414 | "text/html": [ 415 | "\n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | "
epochtrain_lossvalid_losstime
00.016650None02:30
10.005430None02:30
20.004920None02:30
30.004689None02:30
40.004529None02:30
50.004477None02:30
60.004304None02:30
70.004278None02:30
80.004193None02:30
90.004143None02:30
100.004176None02:30
110.004125None02:30
120.003944None02:30
130.003806None02:30
140.003859None02:30
" 517 | ], 518 | "text/plain": [ 519 | "" 520 | ] 521 | }, 522 | "metadata": {}, 523 | "output_type": "display_data" 524 | }, 525 | { 526 | "data": { 527 | "text/html": [ 528 | "\n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | "
epochtrain_lossvalid_losstime
00.017128None02:30
10.005302None02:30
20.004838None02:30
30.004739None02:30
40.004560None02:30
50.004536None02:30
60.004400None02:30
70.004351None02:30
80.004241None02:30
90.004143None02:30
100.004150None02:30
110.004077None02:30
120.004096None02:30
130.003935None02:30
140.003805None02:30
" 630 | ], 631 | "text/plain": [ 632 | "" 633 | ] 634 | }, 635 | "metadata": {}, 636 | "output_type": "display_data" 637 | } 638 | ], 639 | "source": [ 640 | "for run in range(runs):\n", 641 | " learn = unet_learner(dbunch, model, pretrained=False, opt_func=opt_func, metrics=[], loss_func=MSELoss())\n", 642 | "\n", 643 | " if fp16: learn = learn.to_fp16()\n", 644 | " cbs = []\n", 645 | " learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)\n", 646 | "\n", 647 | " #Save model backbone\n", 648 | " torch.save(learn.model[0].state_dict(), 'all_train_unsup_val_pretext_{}.pth'.format(run))\n", 649 | " \n", 650 | " del learn\n", 651 | " torch.cuda.empty_cache() \n", 652 | " gc.collect() " 653 | ] 654 | }, 655 | { 656 | "cell_type": "markdown", 657 | "metadata": {}, 658 | "source": [ 659 | "## Train with all data in `/train` and `/unsup`" 660 | ] 661 | }, 662 | { 663 | "cell_type": "code", 664 | "execution_count": 9, 665 | "metadata": {}, 666 | "outputs": [ 667 | { 668 | "name": "stdout", 669 | "output_type": "stream", 670 | "text": [ 671 | "Training Size: 22419\n", 672 | "Validation Size: 0\n" 673 | ] 674 | } 675 | ], 676 | "source": [ 677 | "dblock = DataBlock(blocks=(ImageBlock(cls=PILImageInput), ImageBlock),\n", 678 | " splitter=RandomSplitter(valid_pct=0),\n", 679 | " get_items=get_train_and_unsup, \n", 680 | " get_y=lambda o: o,\n", 681 | " item_tfms=item_tfms, \n", 682 | " batch_tfms=batch_tfms)\n", 683 | "\n", 684 | "dbunch = dblock.dataloaders(source, path=source, bs=bs)\n", 685 | "\n", 686 | "#CHANGE: We're predicting pixel values, so we're just going to predict an output for each RGB channel\n", 687 | "dbunch.vocab = ['R', 'G', 'B']\n", 688 | "\n", 689 | "print(\"Training Size:\", len(dbunch.train_ds))\n", 690 | "print(\"Validation Size:\", len(dbunch.valid_ds))" 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": 10, 696 | "metadata": {}, 697 | "outputs": [ 698 | { 699 | "data": { 700 | "text/html": [ 701 | "\n", 702 | " \n", 703 | " \n", 
704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | "
epochtrain_lossvalid_losstime
00.039221None02:08
10.005747None02:08
20.005353None02:08
30.005068None02:08
40.004822None02:08
50.004581None02:08
60.004599None02:08
70.004477None02:08
80.004373None02:08
90.004322None02:08
100.004392None02:08
110.004314None02:08
120.004105None02:08
130.003903None02:08
140.003880None02:08
" 803 | ], 804 | "text/plain": [ 805 | "" 806 | ] 807 | }, 808 | "metadata": {}, 809 | "output_type": "display_data" 810 | }, 811 | { 812 | "data": { 813 | "text/html": [ 814 | "\n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | "
epochtrain_lossvalid_losstime
00.038902None02:08
10.005771None02:08
20.005489None02:08
30.005071None02:08
40.004872None02:08
50.004596None02:08
60.004538None02:08
70.004442None02:08
80.004372None02:08
90.004342None02:08
100.004376None02:08
110.004244None02:08
120.004197None02:08
130.003948None02:08
140.003870None02:08
" 916 | ], 917 | "text/plain": [ 918 | "" 919 | ] 920 | }, 921 | "metadata": {}, 922 | "output_type": "display_data" 923 | }, 924 | { 925 | "data": { 926 | "text/html": [ 927 | "\n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | " \n", 1001 | " \n", 1002 | " \n", 1003 | " \n", 1004 | " \n", 1005 | " \n", 1006 | " \n", 1007 | " \n", 1008 | " \n", 1009 | " \n", 1010 | " \n", 1011 | " \n", 1012 | " \n", 1013 | " \n", 1014 | " \n", 1015 | " \n", 1016 | " \n", 1017 | " \n", 1018 | " \n", 1019 | " \n", 1020 | " \n", 1021 | " \n", 1022 | " \n", 1023 | " \n", 1024 | " \n", 1025 | " \n", 1026 | " \n", 1027 | " \n", 1028 | "
epochtrain_lossvalid_losstime
00.039105None02:08
10.005674None02:08
20.005298None02:08
30.004939None02:08
40.004763None02:08
50.004596None02:08
60.004536None02:08
70.004641None02:08
80.004362None02:08
90.004354None02:08
100.004228None02:08
110.004231None02:08
120.004023None02:08
130.004008None02:08
140.003962None02:08
" 1029 | ], 1030 | "text/plain": [ 1031 | "" 1032 | ] 1033 | }, 1034 | "metadata": {}, 1035 | "output_type": "display_data" 1036 | } 1037 | ], 1038 | "source": [ 1039 | "for run in range(runs):\n", 1040 | " learn = unet_learner(dbunch, model, pretrained=False, opt_func=opt_func, metrics=[], loss_func=MSELoss())\n", 1041 | "\n", 1042 | " if fp16: learn = learn.to_fp16()\n", 1043 | " cbs = []\n", 1044 | " learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)\n", 1045 | "\n", 1046 | " #Save model backbone\n", 1047 | " torch.save(learn.model[0].state_dict(), 'all_train_unsup_pretext_{}.pth'.format(run))\n", 1048 | " \n", 1049 | " del learn\n", 1050 | " torch.cuda.empty_cache() \n", 1051 | " gc.collect() " 1052 | ] 1053 | }, 1054 | { 1055 | "cell_type": "markdown", 1056 | "metadata": {}, 1057 | "source": [ 1058 | "## Train with all data in `/train`" 1059 | ] 1060 | }, 1061 | { 1062 | "cell_type": "code", 1063 | "execution_count": 11, 1064 | "metadata": {}, 1065 | "outputs": [ 1066 | { 1067 | "name": "stdout", 1068 | "output_type": "stream", 1069 | "text": [ 1070 | "Training Size: 14669\n", 1071 | "Validation Size: 0\n" 1072 | ] 1073 | } 1074 | ], 1075 | "source": [ 1076 | "dblock = DataBlock(blocks=(ImageBlock(cls=PILImageInput), ImageBlock),\n", 1077 | " splitter=RandomSplitter(valid_pct=0),\n", 1078 | " get_items=get_train_items, \n", 1079 | " get_y=lambda o: o,\n", 1080 | " item_tfms=item_tfms, \n", 1081 | " batch_tfms=batch_tfms)\n", 1082 | "\n", 1083 | "dbunch = dblock.dataloaders(source, path=source, bs=bs)\n", 1084 | "\n", 1085 | "#CHANGE: We're predicting pixel values, so we're just going to predict an output for each RGB channel\n", 1086 | "dbunch.vocab = ['R', 'G', 'B']\n", 1087 | "\n", 1088 | "print(\"Training Size:\", len(dbunch.train_ds))\n", 1089 | "print(\"Validation Size:\", len(dbunch.valid_ds))" 1090 | ] 1091 | }, 1092 | { 1093 | "cell_type": "code", 1094 | "execution_count": 12, 1095 | "metadata": {}, 1096 | "outputs": [ 1097 | { 1098 | "data": { 1099 | 
"text/html": [ 1100 | "\n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | " \n", 1118 | " \n", 1119 | " \n", 1120 | " \n", 1121 | " \n", 1122 | " \n", 1123 | " \n", 1124 | " \n", 1125 | " \n", 1126 | " \n", 1127 | " \n", 1128 | " \n", 1129 | " \n", 1130 | " \n", 1131 | " \n", 1132 | " \n", 1133 | " \n", 1134 | " \n", 1135 | " \n", 1136 | " \n", 1137 | " \n", 1138 | " \n", 1139 | " \n", 1140 | " \n", 1141 | " \n", 1142 | " \n", 1143 | " \n", 1144 | " \n", 1145 | " \n", 1146 | " \n", 1147 | " \n", 1148 | " \n", 1149 | " \n", 1150 | " \n", 1151 | " \n", 1152 | " \n", 1153 | " \n", 1154 | " \n", 1155 | " \n", 1156 | " \n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 1172 | " \n", 1173 | " \n", 1174 | " \n", 1175 | " \n", 1176 | " \n", 1177 | " \n", 1178 | " \n", 1179 | " \n", 1180 | " \n", 1181 | " \n", 1182 | " \n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | " \n", 1192 | " \n", 1193 | " \n", 1194 | " \n", 1195 | " \n", 1196 | " \n", 1197 | " \n", 1198 | " \n", 1199 | " \n", 1200 | " \n", 1201 | "
epochtrain_lossvalid_losstime
00.185480None01:24
10.010443None01:24
20.005977None01:24
30.006075None01:24
40.005620None01:24
50.005190None01:24
60.005272None01:24
70.005040None01:24
80.004982None01:24
90.004757None01:24
100.004697None01:24
110.004739None01:24
120.004572None01:24
130.004483None01:24
140.004370None01:24
" 1202 | ], 1203 | "text/plain": [ 1204 | "" 1205 | ] 1206 | }, 1207 | "metadata": {}, 1208 | "output_type": "display_data" 1209 | }, 1210 | { 1211 | "data": { 1212 | "text/html": [ 1213 | "\n", 1214 | " \n", 1215 | " \n", 1216 | " \n", 1217 | " \n", 1218 | " \n", 1219 | " \n", 1220 | " \n", 1221 | " \n", 1222 | " \n", 1223 | " \n", 1224 | " \n", 1225 | " \n", 1226 | " \n", 1227 | " \n", 1228 | " \n", 1229 | " \n", 1230 | " \n", 1231 | " \n", 1232 | " \n", 1233 | " \n", 1234 | " \n", 1235 | " \n", 1236 | " \n", 1237 | " \n", 1238 | " \n", 1239 | " \n", 1240 | " \n", 1241 | " \n", 1242 | " \n", 1243 | " \n", 1244 | " \n", 1245 | " \n", 1246 | " \n", 1247 | " \n", 1248 | " \n", 1249 | " \n", 1250 | " \n", 1251 | " \n", 1252 | " \n", 1253 | " \n", 1254 | " \n", 1255 | " \n", 1256 | " \n", 1257 | " \n", 1258 | " \n", 1259 | " \n", 1260 | " \n", 1261 | " \n", 1262 | " \n", 1263 | " \n", 1264 | " \n", 1265 | " \n", 1266 | " \n", 1267 | " \n", 1268 | " \n", 1269 | " \n", 1270 | " \n", 1271 | " \n", 1272 | " \n", 1273 | " \n", 1274 | " \n", 1275 | " \n", 1276 | " \n", 1277 | " \n", 1278 | " \n", 1279 | " \n", 1280 | " \n", 1281 | " \n", 1282 | " \n", 1283 | " \n", 1284 | " \n", 1285 | " \n", 1286 | " \n", 1287 | " \n", 1288 | " \n", 1289 | " \n", 1290 | " \n", 1291 | " \n", 1292 | " \n", 1293 | " \n", 1294 | " \n", 1295 | " \n", 1296 | " \n", 1297 | " \n", 1298 | " \n", 1299 | " \n", 1300 | " \n", 1301 | " \n", 1302 | " \n", 1303 | " \n", 1304 | " \n", 1305 | " \n", 1306 | " \n", 1307 | " \n", 1308 | " \n", 1309 | " \n", 1310 | " \n", 1311 | " \n", 1312 | " \n", 1313 | " \n", 1314 | "
epochtrain_lossvalid_losstime
00.184486None01:24
10.010486None01:24
20.006268None01:24
30.005693None01:24
40.005413None01:24
50.005430None01:24
60.005210None01:24
70.005036None01:24
80.005011None01:24
90.004929None01:24
100.004888None01:24
110.004755None01:24
120.004522None01:24
130.004333None01:24
140.004323None01:24
" 1315 | ], 1316 | "text/plain": [ 1317 | "" 1318 | ] 1319 | }, 1320 | "metadata": {}, 1321 | "output_type": "display_data" 1322 | }, 1323 | { 1324 | "data": { 1325 | "text/html": [ 1326 | "\n", 1327 | " \n", 1328 | " \n", 1329 | " \n", 1330 | " \n", 1331 | " \n", 1332 | " \n", 1333 | " \n", 1334 | " \n", 1335 | " \n", 1336 | " \n", 1337 | " \n", 1338 | " \n", 1339 | " \n", 1340 | " \n", 1341 | " \n", 1342 | " \n", 1343 | " \n", 1344 | " \n", 1345 | " \n", 1346 | " \n", 1347 | " \n", 1348 | " \n", 1349 | " \n", 1350 | " \n", 1351 | " \n", 1352 | " \n", 1353 | " \n", 1354 | " \n", 1355 | " \n", 1356 | " \n", 1357 | " \n", 1358 | " \n", 1359 | " \n", 1360 | " \n", 1361 | " \n", 1362 | " \n", 1363 | " \n", 1364 | " \n", 1365 | " \n", 1366 | " \n", 1367 | " \n", 1368 | " \n", 1369 | " \n", 1370 | " \n", 1371 | " \n", 1372 | " \n", 1373 | " \n", 1374 | " \n", 1375 | " \n", 1376 | " \n", 1377 | " \n", 1378 | " \n", 1379 | " \n", 1380 | " \n", 1381 | " \n", 1382 | " \n", 1383 | " \n", 1384 | " \n", 1385 | " \n", 1386 | " \n", 1387 | " \n", 1388 | " \n", 1389 | " \n", 1390 | " \n", 1391 | " \n", 1392 | " \n", 1393 | " \n", 1394 | " \n", 1395 | " \n", 1396 | " \n", 1397 | " \n", 1398 | " \n", 1399 | " \n", 1400 | " \n", 1401 | " \n", 1402 | " \n", 1403 | " \n", 1404 | " \n", 1405 | " \n", 1406 | " \n", 1407 | " \n", 1408 | " \n", 1409 | " \n", 1410 | " \n", 1411 | " \n", 1412 | " \n", 1413 | " \n", 1414 | " \n", 1415 | " \n", 1416 | " \n", 1417 | " \n", 1418 | " \n", 1419 | " \n", 1420 | " \n", 1421 | " \n", 1422 | " \n", 1423 | " \n", 1424 | " \n", 1425 | " \n", 1426 | " \n", 1427 | "
epochtrain_lossvalid_losstime
00.184661None01:24
10.010222None01:24
20.006018None01:24
30.005746None01:24
40.005596None01:24
50.005327None01:24
60.005002None01:24
70.004956None01:24
80.004934None01:24
90.004931None01:24
100.004784None01:24
110.004607None01:24
120.004510None01:24
130.004388None01:24
140.004330None01:24
" 1428 | ], 1429 | "text/plain": [ 1430 | "" 1431 | ] 1432 | }, 1433 | "metadata": {}, 1434 | "output_type": "display_data" 1435 | } 1436 | ], 1437 | "source": [ 1438 | "for run in range(runs):\n", 1439 | " learn = unet_learner(dbunch, model, pretrained=False, opt_func=opt_func, metrics=[], loss_func=MSELoss())\n", 1440 | "\n", 1441 | " if fp16: learn = learn.to_fp16()\n", 1442 | " cbs = []\n", 1443 | " learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)\n", 1444 | "\n", 1445 | " #Save model backbone\n", 1446 | " torch.save(learn.model[0].state_dict(), 'all_train_pretext_{}.pth'.format(run))\n", 1447 | " \n", 1448 | " del learn\n", 1449 | " torch.cuda.empty_cache() \n", 1450 | " gc.collect() " 1451 | ] 1452 | }, 1453 | { 1454 | "cell_type": "markdown", 1455 | "metadata": {}, 1456 | "source": [ 1457 | "## Train with partial data from `/train`" 1458 | ] 1459 | }, 1460 | { 1461 | "cell_type": "code", 1462 | "execution_count": 10, 1463 | "metadata": {}, 1464 | "outputs": [ 1465 | { 1466 | "name": "stdout", 1467 | "output_type": "stream", 1468 | "text": [ 1469 | "Training Size: 1275\n", 1470 | "Validation Size: 0\n" 1471 | ] 1472 | } 1473 | ], 1474 | "source": [ 1475 | "dblock = DataBlock(blocks=(ImageBlock(cls=PILImageInput), ImageBlock),\n", 1476 | " splitter=RandomSplitter(valid_pct=0),\n", 1477 | " get_items=get_train_items_that_are_present_in_val, \n", 1478 | " get_y=lambda o: o,\n", 1479 | " item_tfms=item_tfms, \n", 1480 | " batch_tfms=batch_tfms)\n", 1481 | "\n", 1482 | "dbunch = dblock.dataloaders(source, path=source, bs=bs)\n", 1483 | "\n", 1484 | "#CHANGE: We're predicting pixel values, so we're just going to predict an output for each RGB channel\n", 1485 | "dbunch.vocab = ['R', 'G', 'B']\n", 1486 | "\n", 1487 | "print(\"Training Size:\", len(dbunch.train_ds))\n", 1488 | "print(\"Validation Size:\", len(dbunch.valid_ds))" 1489 | ] 1490 | }, 1491 | { 1492 | "cell_type": "code", 1493 | "execution_count": 11, 1494 | "metadata": { 1495 | "scrolled": false 1496 | }, 
1497 | "outputs": [ 1498 | { 1499 | "data": { 1500 | "text/html": [ 1501 | "\n", 1502 | " \n", 1503 | " \n", 1504 | " \n", 1505 | " \n", 1506 | " \n", 1507 | " \n", 1508 | " \n", 1509 | " \n", 1510 | " \n", 1511 | " \n", 1512 | " \n", 1513 | " \n", 1514 | " \n", 1515 | " \n", 1516 | " \n", 1517 | " \n", 1518 | " \n", 1519 | " \n", 1520 | " \n", 1521 | " \n", 1522 | " \n", 1523 | " \n", 1524 | " \n", 1525 | " \n", 1526 | " \n", 1527 | " \n", 1528 | " \n", 1529 | " \n", 1530 | " \n", 1531 | " \n", 1532 | " \n", 1533 | " \n", 1534 | " \n", 1535 | " \n", 1536 | " \n", 1537 | " \n", 1538 | " \n", 1539 | " \n", 1540 | " \n", 1541 | " \n", 1542 | " \n", 1543 | " \n", 1544 | " \n", 1545 | " \n", 1546 | " \n", 1547 | " \n", 1548 | " \n", 1549 | " \n", 1550 | " \n", 1551 | " \n", 1552 | " \n", 1553 | " \n", 1554 | " \n", 1555 | " \n", 1556 | " \n", 1557 | " \n", 1558 | " \n", 1559 | " \n", 1560 | " \n", 1561 | " \n", 1562 | " \n", 1563 | " \n", 1564 | " \n", 1565 | " \n", 1566 | " \n", 1567 | " \n", 1568 | " \n", 1569 | " \n", 1570 | " \n", 1571 | " \n", 1572 | " \n", 1573 | " \n", 1574 | " \n", 1575 | " \n", 1576 | " \n", 1577 | " \n", 1578 | " \n", 1579 | " \n", 1580 | " \n", 1581 | " \n", 1582 | " \n", 1583 | " \n", 1584 | " \n", 1585 | " \n", 1586 | " \n", 1587 | " \n", 1588 | " \n", 1589 | " \n", 1590 | " \n", 1591 | " \n", 1592 | " \n", 1593 | " \n", 1594 | " \n", 1595 | " \n", 1596 | " \n", 1597 | " \n", 1598 | " \n", 1599 | " \n", 1600 | " \n", 1601 | " \n", 1602 | "
epochtrain_lossvalid_losstime
00.928153None00:07
10.813562None00:07
20.741672None00:07
30.675520None00:07
40.610155None00:07
50.543690None00:07
60.477957None00:07
70.414091None00:07
80.354247None00:07
90.298382None00:07
100.247550None00:07
110.202207None00:07
120.164292None00:07
130.134360None00:07
140.113019None00:07
" 1603 | ], 1604 | "text/plain": [ 1605 | "" 1606 | ] 1607 | }, 1608 | "metadata": {}, 1609 | "output_type": "display_data" 1610 | }, 1611 | { 1612 | "data": { 1613 | "text/html": [ 1614 | "\n", 1615 | " \n", 1616 | " \n", 1617 | " \n", 1618 | " \n", 1619 | " \n", 1620 | " \n", 1621 | " \n", 1622 | " \n", 1623 | " \n", 1624 | " \n", 1625 | " \n", 1626 | " \n", 1627 | " \n", 1628 | " \n", 1629 | " \n", 1630 | " \n", 1631 | " \n", 1632 | " \n", 1633 | " \n", 1634 | " \n", 1635 | " \n", 1636 | " \n", 1637 | " \n", 1638 | " \n", 1639 | " \n", 1640 | " \n", 1641 | " \n", 1642 | " \n", 1643 | " \n", 1644 | " \n", 1645 | " \n", 1646 | " \n", 1647 | " \n", 1648 | " \n", 1649 | " \n", 1650 | " \n", 1651 | " \n", 1652 | " \n", 1653 | " \n", 1654 | " \n", 1655 | " \n", 1656 | " \n", 1657 | " \n", 1658 | " \n", 1659 | " \n", 1660 | " \n", 1661 | " \n", 1662 | " \n", 1663 | " \n", 1664 | " \n", 1665 | " \n", 1666 | " \n", 1667 | " \n", 1668 | " \n", 1669 | " \n", 1670 | " \n", 1671 | " \n", 1672 | " \n", 1673 | " \n", 1674 | " \n", 1675 | " \n", 1676 | " \n", 1677 | " \n", 1678 | " \n", 1679 | " \n", 1680 | " \n", 1681 | " \n", 1682 | " \n", 1683 | " \n", 1684 | " \n", 1685 | " \n", 1686 | " \n", 1687 | " \n", 1688 | " \n", 1689 | " \n", 1690 | " \n", 1691 | " \n", 1692 | " \n", 1693 | " \n", 1694 | " \n", 1695 | " \n", 1696 | " \n", 1697 | " \n", 1698 | " \n", 1699 | " \n", 1700 | " \n", 1701 | " \n", 1702 | " \n", 1703 | " \n", 1704 | " \n", 1705 | " \n", 1706 | " \n", 1707 | " \n", 1708 | " \n", 1709 | " \n", 1710 | " \n", 1711 | " \n", 1712 | " \n", 1713 | " \n", 1714 | " \n", 1715 | "
epochtrain_lossvalid_losstime
00.925984None00:07
10.812690None00:07
20.739564None00:07
30.674601None00:07
40.608644None00:07
50.542074None00:07
60.477865None00:08
70.414007None00:07
80.353446None00:07
90.297540None00:07
100.246756None00:07
110.202051None00:07
120.164051None00:07
130.134142None00:08
140.112815None00:08
" 1716 | ], 1717 | "text/plain": [ 1718 | "" 1719 | ] 1720 | }, 1721 | "metadata": {}, 1722 | "output_type": "display_data" 1723 | }, 1724 | { 1725 | "data": { 1726 | "text/html": [ 1727 | "\n", 1728 | " \n", 1729 | " \n", 1730 | " \n", 1731 | " \n", 1732 | " \n", 1733 | " \n", 1734 | " \n", 1735 | " \n", 1736 | " \n", 1737 | " \n", 1738 | " \n", 1739 | " \n", 1740 | " \n", 1741 | " \n", 1742 | " \n", 1743 | " \n", 1744 | " \n", 1745 | " \n", 1746 | " \n", 1747 | " \n", 1748 | " \n", 1749 | " \n", 1750 | " \n", 1751 | " \n", 1752 | " \n", 1753 | " \n", 1754 | " \n", 1755 | " \n", 1756 | " \n", 1757 | " \n", 1758 | " \n", 1759 | " \n", 1760 | " \n", 1761 | " \n", 1762 | " \n", 1763 | " \n", 1764 | " \n", 1765 | " \n", 1766 | " \n", 1767 | " \n", 1768 | " \n", 1769 | " \n", 1770 | " \n", 1771 | " \n", 1772 | " \n", 1773 | " \n", 1774 | " \n", 1775 | " \n", 1776 | " \n", 1777 | " \n", 1778 | " \n", 1779 | " \n", 1780 | " \n", 1781 | " \n", 1782 | " \n", 1783 | " \n", 1784 | " \n", 1785 | " \n", 1786 | " \n", 1787 | " \n", 1788 | " \n", 1789 | " \n", 1790 | " \n", 1791 | " \n", 1792 | " \n", 1793 | " \n", 1794 | " \n", 1795 | " \n", 1796 | " \n", 1797 | " \n", 1798 | " \n", 1799 | " \n", 1800 | " \n", 1801 | " \n", 1802 | " \n", 1803 | " \n", 1804 | " \n", 1805 | " \n", 1806 | " \n", 1807 | " \n", 1808 | " \n", 1809 | " \n", 1810 | " \n", 1811 | " \n", 1812 | " \n", 1813 | " \n", 1814 | " \n", 1815 | " \n", 1816 | " \n", 1817 | " \n", 1818 | " \n", 1819 | " \n", 1820 | " \n", 1821 | " \n", 1822 | " \n", 1823 | " \n", 1824 | " \n", 1825 | " \n", 1826 | " \n", 1827 | " \n", 1828 | "
epochtrain_lossvalid_losstime
00.921843None00:08
10.812477None00:07
20.741308None00:07
30.675774None00:08
40.610134None00:08
50.543442None00:07
60.477867None00:07
70.414159None00:07
80.353886None00:07
90.298228None00:07
100.247207None00:08
110.202228None00:07
120.164021None00:07
130.134206None00:07
140.112829None00:07
" 1829 | ], 1830 | "text/plain": [ 1831 | "" 1832 | ] 1833 | }, 1834 | "metadata": {}, 1835 | "output_type": "display_data" 1836 | } 1837 | ], 1838 | "source": [ 1839 | "for run in range(runs):\n", 1840 | " learn = unet_learner(dbunch, model, pretrained=False, opt_func=opt_func, metrics=[], loss_func=MSELoss())\n", 1841 | "\n", 1842 | " if fp16: learn = learn.to_fp16()\n", 1843 | " cbs = []\n", 1844 | " learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)\n", 1845 | "\n", 1846 | " #Save model backbone\n", 1847 | " torch.save(learn.model[0].state_dict(), 'partial_train_pretext_{}.pth'.format(run))\n", 1848 | " \n", 1849 | " del learn\n", 1850 | " torch.cuda.empty_cache() \n", 1851 | " gc.collect()" 1852 | ] 1853 | }, 1854 | { 1855 | "cell_type": "markdown", 1856 | "metadata": {}, 1857 | "source": [ 1858 | "# Downstream Task: Image网" 1859 | ] 1860 | }, 1861 | { 1862 | "cell_type": "markdown", 1863 | "metadata": {}, 1864 | "source": [ 1865 | "Now that we've trained models on our pretext tasks, let's compare the performance of each model against one another." 
1866 | ] 1867 | }, 1868 | { 1869 | "cell_type": "code", 1870 | "execution_count": 12, 1871 | "metadata": {}, 1872 | "outputs": [], 1873 | "source": [ 1874 | "def get_dbunch(size, bs, sh=0., workers=None):\n", 1875 | " if size<=224: \n", 1876 | " path = URLs.IMAGEWANG_160\n", 1877 | " else: \n", 1878 | " path = URLs.IMAGEWANG\n", 1879 | " source = untar_data(path)\n", 1880 | " if workers is None: workers = min(8, num_cpus())\n", 1881 | " item_tfms=[RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]\n", 1882 | " batch_tfms=RandomErasing(p=0.9, max_count=3, sh=sh) if sh else None\n", 1883 | " \n", 1884 | " dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),\n", 1885 | " splitter=GrandparentSplitter(valid_name='val'),\n", 1886 | " get_items=get_image_files, get_y=parent_label,\n", 1887 | " item_tfms=item_tfms, batch_tfms=batch_tfms)\n", 1888 | " \n", 1889 | " return dblock.dataloaders(source, path=source, bs=bs, num_workers=workers)" 1890 | ] 1891 | }, 1892 | { 1893 | "cell_type": "code", 1894 | "execution_count": 14, 1895 | "metadata": {}, 1896 | "outputs": [], 1897 | "source": [ 1898 | "dbunch = get_dbunch(size, bs)" 1899 | ] 1900 | }, 1901 | { 1902 | "cell_type": "markdown", 1903 | "metadata": {}, 1904 | "source": [ 1905 | "## Random Baseline" 1906 | ] 1907 | }, 1908 | { 1909 | "cell_type": "code", 1910 | "execution_count": 15, 1911 | "metadata": {}, 1912 | "outputs": [ 1913 | { 1914 | "name": "stdout", 1915 | "output_type": "stream", 1916 | "text": [ 1917 | "Run: 0\n" 1918 | ] 1919 | }, 1920 | { 1921 | "data": { 1922 | "text/html": [ 1923 | "\n", 1924 | " \n", 1925 | " \n", 1926 | " \n", 1927 | " \n", 1928 | " \n", 1929 | " \n", 1930 | " \n", 1931 | " \n", 1932 | " \n", 1933 | " \n", 1934 | " \n", 1935 | " \n", 1936 | " \n", 1937 | " \n", 1938 | " \n", 1939 | " \n", 1940 | " \n", 1941 | " \n", 1942 | " \n", 1943 | " \n", 1944 | " \n", 1945 | " \n", 1946 | " \n", 1947 | " \n", 1948 | " \n", 1949 | " \n", 1950 | " \n", 1951 | " \n", 1952 | " \n", 1953 | " 
\n", 1954 | " \n", 1955 | " \n", 1956 | " \n", 1957 | " \n", 1958 | " \n", 1959 | " \n", 1960 | " \n", 1961 | " \n", 1962 | " \n", 1963 | " \n", 1964 | " \n", 1965 | " \n", 1966 | " \n", 1967 | " \n", 1968 | " \n", 1969 | " \n", 1970 | " \n", 1971 | " \n", 1972 | " \n", 1973 | " \n", 1974 | " \n", 1975 | " \n", 1976 | " \n", 1977 | " \n", 1978 | " \n", 1979 | " \n", 1980 | " \n", 1981 | " \n", 1982 | " \n", 1983 | " \n", 1984 | " \n", 1985 | " \n", 1986 | " \n", 1987 | " \n", 1988 | " \n", 1989 | " \n", 1990 | " \n", 1991 | " \n", 1992 | " \n", 1993 | " \n", 1994 | " \n", 1995 | " \n", 1996 | " \n", 1997 | " \n", 1998 | " \n", 1999 | " \n", 2000 | " \n", 2001 | " \n", 2002 | " \n", 2003 | " \n", 2004 | " \n", 2005 | " \n", 2006 | " \n", 2007 | " \n", 2008 | " \n", 2009 | " \n", 2010 | " \n", 2011 | " \n", 2012 | " \n", 2013 | " \n", 2014 | " \n", 2015 | " \n", 2016 | " \n", 2017 | " \n", 2018 | " \n", 2019 | " \n", 2020 | " \n", 2021 | " \n", 2022 | " \n", 2023 | " \n", 2024 | " \n", 2025 | " \n", 2026 | " \n", 2027 | " \n", 2028 | " \n", 2029 | " \n", 2030 | " \n", 2031 | " \n", 2032 | " \n", 2033 | " \n", 2034 | " \n", 2035 | " \n", 2036 | " \n", 2037 | " \n", 2038 | " \n", 2039 | " \n", 2040 | " \n", 2041 | " \n", 2042 | " \n", 2043 | " \n", 2044 | " \n", 2045 | " \n", 2046 | " \n", 2047 | " \n", 2048 | " \n", 2049 | " \n", 2050 | " \n", 2051 | " \n", 2052 | " \n", 2053 | " \n", 2054 | " \n", 2055 | " \n", 2056 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.7966873.6364460.0122170.29676800:18
11.5500673.1812290.0567570.46780400:18
21.4267394.1021860.0053450.19165200:18
31.3246693.0962110.1074060.50012700:18
41.2615783.2087630.0837360.47340300:18
51.2049702.5676710.2524820.73988300:18
61.1511862.3903530.3059300.77144300:18
71.1187672.6366420.2425550.69330600:18
81.1008042.6979140.2578260.69254300:18
91.0483022.6738720.2657160.69407000:18
101.0232412.5096430.3441080.75082700:18
111.0033342.1479380.4219900.84779800:18
120.9304201.9942790.4718760.87503200:18
130.8642671.9156070.5128530.87859500:18
140.8074401.8744880.5237970.88648500:19
" 2057 | ], 2058 | "text/plain": [ 2059 | "" 2060 | ] 2061 | }, 2062 | "metadata": {}, 2063 | "output_type": "display_data" 2064 | }, 2065 | { 2066 | "name": "stdout", 2067 | "output_type": "stream", 2068 | "text": [ 2069 | "Run: 1\n" 2070 | ] 2071 | }, 2072 | { 2073 | "data": { 2074 | "text/html": [ 2075 | "\n", 2076 | " \n", 2077 | " \n", 2078 | " \n", 2079 | " \n", 2080 | " \n", 2081 | " \n", 2082 | " \n", 2083 | " \n", 2084 | " \n", 2085 | " \n", 2086 | " \n", 2087 | " \n", 2088 | " \n", 2089 | " \n", 2090 | " \n", 2091 | " \n", 2092 | " \n", 2093 | " \n", 2094 | " \n", 2095 | " \n", 2096 | " \n", 2097 | " \n", 2098 | " \n", 2099 | " \n", 2100 | " \n", 2101 | " \n", 2102 | " \n", 2103 | " \n", 2104 | " \n", 2105 | " \n", 2106 | " \n", 2107 | " \n", 2108 | " \n", 2109 | " \n", 2110 | " \n", 2111 | " \n", 2112 | " \n", 2113 | " \n", 2114 | " \n", 2115 | " \n", 2116 | " \n", 2117 | " \n", 2118 | " \n", 2119 | " \n", 2120 | " \n", 2121 | " \n", 2122 | " \n", 2123 | " \n", 2124 | " \n", 2125 | " \n", 2126 | " \n", 2127 | " \n", 2128 | " \n", 2129 | " \n", 2130 | " \n", 2131 | " \n", 2132 | " \n", 2133 | " \n", 2134 | " \n", 2135 | " \n", 2136 | " \n", 2137 | " \n", 2138 | " \n", 2139 | " \n", 2140 | " \n", 2141 | " \n", 2142 | " \n", 2143 | " \n", 2144 | " \n", 2145 | " \n", 2146 | " \n", 2147 | " \n", 2148 | " \n", 2149 | " \n", 2150 | " \n", 2151 | " \n", 2152 | " \n", 2153 | " \n", 2154 | " \n", 2155 | " \n", 2156 | " \n", 2157 | " \n", 2158 | " \n", 2159 | " \n", 2160 | " \n", 2161 | " \n", 2162 | " \n", 2163 | " \n", 2164 | " \n", 2165 | " \n", 2166 | " \n", 2167 | " \n", 2168 | " \n", 2169 | " \n", 2170 | " \n", 2171 | " \n", 2172 | " \n", 2173 | " \n", 2174 | " \n", 2175 | " \n", 2176 | " \n", 2177 | " \n", 2178 | " \n", 2179 | " \n", 2180 | " \n", 2181 | " \n", 2182 | " \n", 2183 | " \n", 2184 | " \n", 2185 | " \n", 2186 | " \n", 2187 | " \n", 2188 | " \n", 2189 | " \n", 2190 | " \n", 2191 | " \n", 2192 | " \n", 2193 | " \n", 2194 | " \n", 2195 | " \n", 2196 
| " \n", 2197 | " \n", 2198 | " \n", 2199 | " \n", 2200 | " \n", 2201 | " \n", 2202 | " \n", 2203 | " \n", 2204 | " \n", 2205 | " \n", 2206 | " \n", 2207 | " \n", 2208 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.7858603.4575670.0302880.33774500:18
11.5373493.5157960.0208700.39297500:18
21.4180702.8667190.1030800.63018600:19
31.3275883.0352410.1142780.56630200:19
41.2406703.1187590.1323490.65640100:18
51.2390332.4982540.2435730.76609800:19
61.1617902.4398760.2873500.78060600:19
71.1112432.6155220.2550270.70272300:18
81.0718622.7761830.2361920.68948800:18
91.0478922.6403350.2669890.68999700:18
101.0252562.8889090.2031050.70272300:18
111.0068022.1472260.4158820.84423500:19
120.9444832.0186850.4558410.85568800:17
130.8623981.9393110.4993640.88012200:18
140.8175261.8983630.5148890.88750300:18
" 2209 | ], 2210 | "text/plain": [ 2211 | "" 2212 | ] 2213 | }, 2214 | "metadata": {}, 2215 | "output_type": "display_data" 2216 | }, 2217 | { 2218 | "name": "stdout", 2219 | "output_type": "stream", 2220 | "text": [ 2221 | "Run: 2\n" 2222 | ] 2223 | }, 2224 | { 2225 | "data": { 2226 | "text/html": [ 2227 | "\n", 2228 | " \n", 2229 | " \n", 2230 | " \n", 2231 | " \n", 2232 | " \n", 2233 | " \n", 2234 | " \n", 2235 | " \n", 2236 | " \n", 2237 | " \n", 2238 | " \n", 2239 | " \n", 2240 | " \n", 2241 | " \n", 2242 | " \n", 2243 | " \n", 2244 | " \n", 2245 | " \n", 2246 | " \n", 2247 | " \n", 2248 | " \n", 2249 | " \n", 2250 | " \n", 2251 | " \n", 2252 | " \n", 2253 | " \n", 2254 | " \n", 2255 | " \n", 2256 | " \n", 2257 | " \n", 2258 | " \n", 2259 | " \n", 2260 | " \n", 2261 | " \n", 2262 | " \n", 2263 | " \n", 2264 | " \n", 2265 | " \n", 2266 | " \n", 2267 | " \n", 2268 | " \n", 2269 | " \n", 2270 | " \n", 2271 | " \n", 2272 | " \n", 2273 | " \n", 2274 | " \n", 2275 | " \n", 2276 | " \n", 2277 | " \n", 2278 | " \n", 2279 | " \n", 2280 | " \n", 2281 | " \n", 2282 | " \n", 2283 | " \n", 2284 | " \n", 2285 | " \n", 2286 | " \n", 2287 | " \n", 2288 | " \n", 2289 | " \n", 2290 | " \n", 2291 | " \n", 2292 | " \n", 2293 | " \n", 2294 | " \n", 2295 | " \n", 2296 | " \n", 2297 | " \n", 2298 | " \n", 2299 | " \n", 2300 | " \n", 2301 | " \n", 2302 | " \n", 2303 | " \n", 2304 | " \n", 2305 | " \n", 2306 | " \n", 2307 | " \n", 2308 | " \n", 2309 | " \n", 2310 | " \n", 2311 | " \n", 2312 | " \n", 2313 | " \n", 2314 | " \n", 2315 | " \n", 2316 | " \n", 2317 | " \n", 2318 | " \n", 2319 | " \n", 2320 | " \n", 2321 | " \n", 2322 | " \n", 2323 | " \n", 2324 | " \n", 2325 | " \n", 2326 | " \n", 2327 | " \n", 2328 | " \n", 2329 | " \n", 2330 | " \n", 2331 | " \n", 2332 | " \n", 2333 | " \n", 2334 | " \n", 2335 | " \n", 2336 | " \n", 2337 | " \n", 2338 | " \n", 2339 | " \n", 2340 | " \n", 2341 | " \n", 2342 | " \n", 2343 | " \n", 2344 | " \n", 2345 | " \n", 2346 | " \n", 2347 | " \n", 2348 
| " \n", 2349 | " \n", 2350 | " \n", 2351 | " \n", 2352 | " \n", 2353 | " \n", 2354 | " \n", 2355 | " \n", 2356 | " \n", 2357 | " \n", 2358 | " \n", 2359 | " \n", 2360 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.7877733.5167800.0076360.36625100:18
11.5679643.5884760.0292700.30643900:17
21.4301033.2784020.0761010.52634300:18
31.3580133.4610020.0534490.43318900:18
41.2653563.0489520.1193690.54161400:18
51.1959572.2952390.3405450.79155000:17
61.1588752.4544800.2865870.75744500:17
71.1281012.1982370.3738860.82565500:18
81.0662332.8452310.2191400.64978400:19
91.0521862.8751700.2402650.62305900:18
101.0259883.0820220.1969970.59099000:18
110.9981132.1816390.4079920.82031100:18
120.9404522.1459260.4224990.84347200:19
130.8640491.9390760.5016540.87630400:18
140.8219201.8712610.5296510.88750300:18
" 2361 | ], 2362 | "text/plain": [ 2363 | "" 2364 | ] 2365 | }, 2366 | "metadata": {}, 2367 | "output_type": "display_data" 2368 | } 2369 | ], 2370 | "source": [ 2371 | "for run in range(runs):\n", 2372 | " print(f'Run: {run}')\n", 2373 | " learn = Learner(dbunch, model(c_out=20, pretrained=False, act_cls=torch.nn.ReLU, sa=sa, sym=sym, pool=pool), opt_func=opt_func, \\\n", 2374 | " metrics=[accuracy,top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())\n", 2375 | "\n", 2376 | " if fp16: learn = learn.to_fp16()\n", 2377 | " cbs = []\n", 2378 | " learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)\n", 2379 | " \n", 2380 | " del learn\n", 2381 | " torch.cuda.empty_cache() \n", 2382 | " gc.collect() " 2383 | ] 2384 | }, 2385 | { 2386 | "cell_type": "markdown", 2387 | "metadata": {}, 2388 | "source": [ 2389 | "Results:\n", 2390 | "- Run 1: 0.523797\n", 2391 | "- Run 2: 0.514889\n", 2392 | "- Run 3: 0.529651\n", 2393 | "\n", 2394 | "Average: **52.3%**\n" 2395 | ] 2396 | }, 2397 | { 2398 | "cell_type": "markdown", 2399 | "metadata": {}, 2400 | "source": [ 2401 | "## All data in `/train`, `/unsup` and `/val`" 2402 | ] 2403 | }, 2404 | { 2405 | "cell_type": "code", 2406 | "execution_count": 16, 2407 | "metadata": {}, 2408 | "outputs": [ 2409 | { 2410 | "name": "stdout", 2411 | "output_type": "stream", 2412 | "text": [ 2413 | "Run: 0\n" 2414 | ] 2415 | }, 2416 | { 2417 | "data": { 2418 | "text/html": [ 2419 | "\n", 2420 | " \n", 2421 | " \n", 2422 | " \n", 2423 | " \n", 2424 | " \n", 2425 | " \n", 2426 | " \n", 2427 | " \n", 2428 | " \n", 2429 | " \n", 2430 | " \n", 2431 | " \n", 2432 | " \n", 2433 | " \n", 2434 | " \n", 2435 | " \n", 2436 | " \n", 2437 | " \n", 2438 | " \n", 2439 | " \n", 2440 | " \n", 2441 | " \n", 2442 | " \n", 2443 | " \n", 2444 | " \n", 2445 | " \n", 2446 | " \n", 2447 | " \n", 2448 | " \n", 2449 | " \n", 2450 | " \n", 2451 | " \n", 2452 | " \n", 2453 | " \n", 2454 | " \n", 2455 | " \n", 2456 | " \n", 2457 | " \n", 2458 | " \n", 2459 | " \n", 2460 | " 
\n", 2461 | " \n", 2462 | " \n", 2463 | " \n", 2464 | " \n", 2465 | " \n", 2466 | " \n", 2467 | " \n", 2468 | " \n", 2469 | " \n", 2470 | " \n", 2471 | " \n", 2472 | " \n", 2473 | " \n", 2474 | " \n", 2475 | " \n", 2476 | " \n", 2477 | " \n", 2478 | " \n", 2479 | " \n", 2480 | " \n", 2481 | " \n", 2482 | " \n", 2483 | " \n", 2484 | " \n", 2485 | " \n", 2486 | " \n", 2487 | " \n", 2488 | " \n", 2489 | " \n", 2490 | " \n", 2491 | " \n", 2492 | " \n", 2493 | " \n", 2494 | " \n", 2495 | " \n", 2496 | " \n", 2497 | " \n", 2498 | " \n", 2499 | " \n", 2500 | " \n", 2501 | " \n", 2502 | " \n", 2503 | " \n", 2504 | " \n", 2505 | " \n", 2506 | " \n", 2507 | " \n", 2508 | " \n", 2509 | " \n", 2510 | " \n", 2511 | " \n", 2512 | " \n", 2513 | " \n", 2514 | " \n", 2515 | " \n", 2516 | " \n", 2517 | " \n", 2518 | " \n", 2519 | " \n", 2520 | " \n", 2521 | " \n", 2522 | " \n", 2523 | " \n", 2524 | " \n", 2525 | " \n", 2526 | " \n", 2527 | " \n", 2528 | " \n", 2529 | " \n", 2530 | " \n", 2531 | " \n", 2532 | " \n", 2533 | " \n", 2534 | " \n", 2535 | " \n", 2536 | " \n", 2537 | " \n", 2538 | " \n", 2539 | " \n", 2540 | " \n", 2541 | " \n", 2542 | " \n", 2543 | " \n", 2544 | " \n", 2545 | " \n", 2546 | " \n", 2547 | " \n", 2548 | " \n", 2549 | " \n", 2550 | " \n", 2551 | " \n", 2552 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.6257743.0959020.0786460.48969200:17
11.4442912.8378940.1277680.62865900:18
21.3233503.9419190.0422500.54288600:18
31.2675783.1000820.1644180.58844500:18
41.1908202.6585770.2270300.71417700:18
51.1464192.3361390.3285820.79511300:19
61.1188362.8502870.1891070.71214100:19
71.0784082.1454680.4072280.84703500:19
81.0480372.5268600.2766610.74191900:19
91.0513192.3740880.3158560.79740400:19
100.9996482.1334270.4062100.82972800:19
110.9781071.8973810.4840930.88953900:19
120.9066502.0318210.4614410.87146900:18
130.8366701.8646370.5237970.88674000:18
140.8008431.7441350.5670650.90633700:19
" 2553 | ], 2554 | "text/plain": [ 2555 | "" 2556 | ] 2557 | }, 2558 | "metadata": {}, 2559 | "output_type": "display_data" 2560 | }, 2561 | { 2562 | "name": "stdout", 2563 | "output_type": "stream", 2564 | "text": [ 2565 | "Run: 1\n" 2566 | ] 2567 | }, 2568 | { 2569 | "data": { 2570 | "text/html": [ 2571 | "\n", 2572 | " \n", 2573 | " \n", 2574 | " \n", 2575 | " \n", 2576 | " \n", 2577 | " \n", 2578 | " \n", 2579 | " \n", 2580 | " \n", 2581 | " \n", 2582 | " \n", 2583 | " \n", 2584 | " \n", 2585 | " \n", 2586 | " \n", 2587 | " \n", 2588 | " \n", 2589 | " \n", 2590 | " \n", 2591 | " \n", 2592 | " \n", 2593 | " \n", 2594 | " \n", 2595 | " \n", 2596 | " \n", 2597 | " \n", 2598 | " \n", 2599 | " \n", 2600 | " \n", 2601 | " \n", 2602 | " \n", 2603 | " \n", 2604 | " \n", 2605 | " \n", 2606 | " \n", 2607 | " \n", 2608 | " \n", 2609 | " \n", 2610 | " \n", 2611 | " \n", 2612 | " \n", 2613 | " \n", 2614 | " \n", 2615 | " \n", 2616 | " \n", 2617 | " \n", 2618 | " \n", 2619 | " \n", 2620 | " \n", 2621 | " \n", 2622 | " \n", 2623 | " \n", 2624 | " \n", 2625 | " \n", 2626 | " \n", 2627 | " \n", 2628 | " \n", 2629 | " \n", 2630 | " \n", 2631 | " \n", 2632 | " \n", 2633 | " \n", 2634 | " \n", 2635 | " \n", 2636 | " \n", 2637 | " \n", 2638 | " \n", 2639 | " \n", 2640 | " \n", 2641 | " \n", 2642 | " \n", 2643 | " \n", 2644 | " \n", 2645 | " \n", 2646 | " \n", 2647 | " \n", 2648 | " \n", 2649 | " \n", 2650 | " \n", 2651 | " \n", 2652 | " \n", 2653 | " \n", 2654 | " \n", 2655 | " \n", 2656 | " \n", 2657 | " \n", 2658 | " \n", 2659 | " \n", 2660 | " \n", 2661 | " \n", 2662 | " \n", 2663 | " \n", 2664 | " \n", 2665 | " \n", 2666 | " \n", 2667 | " \n", 2668 | " \n", 2669 | " \n", 2670 | " \n", 2671 | " \n", 2672 | " \n", 2673 | " \n", 2674 | " \n", 2675 | " \n", 2676 | " \n", 2677 | " \n", 2678 | " \n", 2679 | " \n", 2680 | " \n", 2681 | " \n", 2682 | " \n", 2683 | " \n", 2684 | " \n", 2685 | " \n", 2686 | " \n", 2687 | " \n", 2688 | " \n", 2689 | " \n", 2690 | " \n", 2691 | " \n", 2692 
| " \n", 2693 | " \n", 2694 | " \n", 2695 | " \n", 2696 | " \n", 2697 | " \n", 2698 | " \n", 2699 | " \n", 2700 | " \n", 2701 | " \n", 2702 | " \n", 2703 | " \n", 2704 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.5689263.2014400.0389410.48103800:19
11.4301293.0791700.0834820.52048900:18
21.3149563.4553480.0478490.50114500:18
31.2378333.6308590.0628660.41002800:18
41.2023483.7065860.1188600.51998000:18
51.1471392.6493550.2377200.72741200:19
61.1089602.3394260.3107660.81343900:18
71.0679042.2633190.3723590.82871000:18
81.0421592.0217850.4441330.87503200:19
91.0150932.6711520.2631710.72333900:19
100.9974222.6215750.2545180.75769900:18
110.9794422.1908450.4344620.83787200:18
120.9123221.9776280.4907100.86205100:18
130.8338661.7819700.5576480.90099300:19
140.7896751.7604210.5673200.90124700:18
" 2705 | ], 2706 | "text/plain": [ 2707 | "" 2708 | ] 2709 | }, 2710 | "metadata": {}, 2711 | "output_type": "display_data" 2712 | }, 2713 | { 2714 | "name": "stdout", 2715 | "output_type": "stream", 2716 | "text": [ 2717 | "Run: 2\n" 2718 | ] 2719 | }, 2720 | { 2721 | "data": { 2722 | "text/html": [ 2723 | "\n", 2724 | " \n", 2725 | " \n", 2726 | " \n", 2727 | " \n", 2728 | " \n", 2729 | " \n", 2730 | " \n", 2731 | " \n", 2732 | " \n", 2733 | " \n", 2734 | " \n", 2735 | " \n", 2736 | " \n", 2737 | " \n", 2738 | " \n", 2739 | " \n", 2740 | " \n", 2741 | " \n", 2742 | " \n", 2743 | " \n", 2744 | " \n", 2745 | " \n", 2746 | " \n", 2747 | " \n", 2748 | " \n", 2749 | " \n", 2750 | " \n", 2751 | " \n", 2752 | " \n", 2753 | " \n", 2754 | " \n", 2755 | " \n", 2756 | " \n", 2757 | " \n", 2758 | " \n", 2759 | " \n", 2760 | " \n", 2761 | " \n", 2762 | " \n", 2763 | " \n", 2764 | " \n", 2765 | " \n", 2766 | " \n", 2767 | " \n", 2768 | " \n", 2769 | " \n", 2770 | " \n", 2771 | " \n", 2772 | " \n", 2773 | " \n", 2774 | " \n", 2775 | " \n", 2776 | " \n", 2777 | " \n", 2778 | " \n", 2779 | " \n", 2780 | " \n", 2781 | " \n", 2782 | " \n", 2783 | " \n", 2784 | " \n", 2785 | " \n", 2786 | " \n", 2787 | " \n", 2788 | " \n", 2789 | " \n", 2790 | " \n", 2791 | " \n", 2792 | " \n", 2793 | " \n", 2794 | " \n", 2795 | " \n", 2796 | " \n", 2797 | " \n", 2798 | " \n", 2799 | " \n", 2800 | " \n", 2801 | " \n", 2802 | " \n", 2803 | " \n", 2804 | " \n", 2805 | " \n", 2806 | " \n", 2807 | " \n", 2808 | " \n", 2809 | " \n", 2810 | " \n", 2811 | " \n", 2812 | " \n", 2813 | " \n", 2814 | " \n", 2815 | " \n", 2816 | " \n", 2817 | " \n", 2818 | " \n", 2819 | " \n", 2820 | " \n", 2821 | " \n", 2822 | " \n", 2823 | " \n", 2824 | " \n", 2825 | " \n", 2826 | " \n", 2827 | " \n", 2828 | " \n", 2829 | " \n", 2830 | " \n", 2831 | " \n", 2832 | " \n", 2833 | " \n", 2834 | " \n", 2835 | " \n", 2836 | " \n", 2837 | " \n", 2838 | " \n", 2839 | " \n", 2840 | " \n", 2841 | " \n", 2842 | " \n", 2843 | " \n", 2844 
| " \n", 2845 | " \n", 2846 | " \n", 2847 | " \n", 2848 | " \n", 2849 | " \n", 2850 | " \n", 2851 | " \n", 2852 | " \n", 2853 | " \n", 2854 | " \n", 2855 | " \n", 2856 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.5815553.0890380.0470860.46525800:18
11.4251983.2322370.0587940.47849300:18
21.3392472.7510870.1682360.62051400:18
31.2722822.6902380.2132860.66480000:18
41.2006653.4455630.0783910.57317400:18
51.1429992.4438230.2954950.75286300:18
61.1153492.2371380.3690510.80936600:18
71.0673072.6067950.2669890.72003100:19
81.0532322.2960610.3553070.79282300:18
91.0290872.1126650.4186820.85594300:18
100.9996952.2006570.3985750.79536800:18
110.9706002.0903510.4459150.83838100:18
120.9074202.0112560.4848560.84958000:18
130.8484901.8280820.5408500.89335700:18
140.7955421.7854250.5556120.89361200:18
" 2857 | ], 2858 | "text/plain": [ 2859 | "" 2860 | ] 2861 | }, 2862 | "metadata": {}, 2863 | "output_type": "display_data" 2864 | } 2865 | ], 2866 | "source": [ 2867 | "for run in range(runs):\n", 2868 | " print(f'Run: {run}')\n", 2869 | " learn = Learner(dbunch, model(c_out=20, pretrained=False, act_cls=torch.nn.ReLU, sa=sa, sym=sym, pool=pool), opt_func=opt_func, \\\n", 2870 | " metrics=[accuracy,top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())\n", 2871 | "\n", 2872 | " if fp16: learn = learn.to_fp16()\n", 2873 | " cbs = []\n", 2874 | "\n", 2875 | " # Load weights generated from training on our pretext task\n", 2876 | " model_path = 'all_train_unsup_val_pretext_' + str(run) + '.pth'\n", 2877 | " state_dict = torch.load(model_path)\n", 2878 | " # HACK: If we don't have all of the parameters for our learner, we get an error\n", 2879 | " linear_layer = learn.model[-1]\n", 2880 | " state_dict['11.weight'] = linear_layer.weight\n", 2881 | " state_dict['11.bias'] = linear_layer.bias\n", 2882 | "\n", 2883 | " learn.model.load_state_dict(state_dict)\n", 2884 | "\n", 2885 | " learn.freeze()\n", 2886 | " learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)\n", 2887 | "\n", 2888 | " del learn\n", 2889 | " torch.cuda.empty_cache() \n", 2890 | " gc.collect() " 2891 | ] 2892 | }, 2893 | { 2894 | "cell_type": "markdown", 2895 | "metadata": {}, 2896 | "source": [ 2897 | "Results:\n", 2898 | "- Run 1: 0.567065\n", 2899 | "- Run 2: 0.567320\n", 2900 | "- Run 3: 0.555612\n", 2901 | "\n", 2902 | "Average: **56.3%**\n" 2903 | ] 2904 | }, 2905 | { 2906 | "cell_type": "markdown", 2907 | "metadata": {}, 2908 | "source": [ 2909 | "## All data in `/train` and `/unsup`" 2910 | ] 2911 | }, 2912 | { 2913 | "cell_type": "code", 2914 | "execution_count": 17, 2915 | "metadata": {}, 2916 | "outputs": [ 2917 | { 2918 | "name": "stdout", 2919 | "output_type": "stream", 2920 | "text": [ 2921 | "Run: 0\n" 2922 | ] 2923 | }, 2924 | { 2925 | "data": { 2926 | "text/html": [ 2927 | "\n", 2928 | " 
\n", 2929 | " \n", 2930 | " \n", 2931 | " \n", 2932 | " \n", 2933 | " \n", 2934 | " \n", 2935 | " \n", 2936 | " \n", 2937 | " \n", 2938 | " \n", 2939 | " \n", 2940 | " \n", 2941 | " \n", 2942 | " \n", 2943 | " \n", 2944 | " \n", 2945 | " \n", 2946 | " \n", 2947 | " \n", 2948 | " \n", 2949 | " \n", 2950 | " \n", 2951 | " \n", 2952 | " \n", 2953 | " \n", 2954 | " \n", 2955 | " \n", 2956 | " \n", 2957 | " \n", 2958 | " \n", 2959 | " \n", 2960 | " \n", 2961 | " \n", 2962 | " \n", 2963 | " \n", 2964 | " \n", 2965 | " \n", 2966 | " \n", 2967 | " \n", 2968 | " \n", 2969 | " \n", 2970 | " \n", 2971 | " \n", 2972 | " \n", 2973 | " \n", 2974 | " \n", 2975 | " \n", 2976 | " \n", 2977 | " \n", 2978 | " \n", 2979 | " \n", 2980 | " \n", 2981 | " \n", 2982 | " \n", 2983 | " \n", 2984 | " \n", 2985 | " \n", 2986 | " \n", 2987 | " \n", 2988 | " \n", 2989 | " \n", 2990 | " \n", 2991 | " \n", 2992 | " \n", 2993 | " \n", 2994 | " \n", 2995 | " \n", 2996 | " \n", 2997 | " \n", 2998 | " \n", 2999 | " \n", 3000 | " \n", 3001 | " \n", 3002 | " \n", 3003 | " \n", 3004 | " \n", 3005 | " \n", 3006 | " \n", 3007 | " \n", 3008 | " \n", 3009 | " \n", 3010 | " \n", 3011 | " \n", 3012 | " \n", 3013 | " \n", 3014 | " \n", 3015 | " \n", 3016 | " \n", 3017 | " \n", 3018 | " \n", 3019 | " \n", 3020 | " \n", 3021 | " \n", 3022 | " \n", 3023 | " \n", 3024 | " \n", 3025 | " \n", 3026 | " \n", 3027 | " \n", 3028 | " \n", 3029 | " \n", 3030 | " \n", 3031 | " \n", 3032 | " \n", 3033 | " \n", 3034 | " \n", 3035 | " \n", 3036 | " \n", 3037 | " \n", 3038 | " \n", 3039 | " \n", 3040 | " \n", 3041 | " \n", 3042 | " \n", 3043 | " \n", 3044 | " \n", 3045 | " \n", 3046 | " \n", 3047 | " \n", 3048 | " \n", 3049 | " \n", 3050 | " \n", 3051 | " \n", 3052 | " \n", 3053 | " \n", 3054 | " \n", 3055 | " \n", 3056 | " \n", 3057 | " \n", 3058 | " \n", 3059 | " \n", 3060 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.6028503.2119280.0618480.52023400:17
11.4288632.7905000.1639090.61109700:18
21.3361363.6661060.0290150.35072500:17
31.2413022.8054840.1664550.61746000:18
41.1779332.8452490.1857980.64087600:18
51.1620772.4216940.2898960.76151700:18
61.1036012.2329950.3519980.80478500:18
71.0695772.2696750.3624330.84067200:18
81.0458622.4529790.3301090.73937400:18
91.0161412.5193020.3005850.73581100:18
101.0044042.6984580.2911680.69890600:18
110.9712251.9301250.4739120.87732200:18
120.9119821.8833060.5021630.88190400:18
130.8291621.7758150.5517940.89717500:18
140.7971531.7502550.5640110.90150200:18
" 3061 | ], 3062 | "text/plain": [ 3063 | "" 3064 | ] 3065 | }, 3066 | "metadata": {}, 3067 | "output_type": "display_data" 3068 | }, 3069 | { 3070 | "name": "stdout", 3071 | "output_type": "stream", 3072 | "text": [ 3073 | "Run: 1\n" 3074 | ] 3075 | }, 3076 | { 3077 | "data": { 3078 | "text/html": [ 3079 | "\n", 3080 | " \n", 3081 | " \n", 3082 | " \n", 3083 | " \n", 3084 | " \n", 3085 | " \n", 3086 | " \n", 3087 | " \n", 3088 | " \n", 3089 | " \n", 3090 | " \n", 3091 | " \n", 3092 | " \n", 3093 | " \n", 3094 | " \n", 3095 | " \n", 3096 | " \n", 3097 | " \n", 3098 | " \n", 3099 | " \n", 3100 | " \n", 3101 | " \n", 3102 | " \n", 3103 | " \n", 3104 | " \n", 3105 | " \n", 3106 | " \n", 3107 | " \n", 3108 | " \n", 3109 | " \n", 3110 | " \n", 3111 | " \n", 3112 | " \n", 3113 | " \n", 3114 | " \n", 3115 | " \n", 3116 | " \n", 3117 | " \n", 3118 | " \n", 3119 | " \n", 3120 | " \n", 3121 | " \n", 3122 | " \n", 3123 | " \n", 3124 | " \n", 3125 | " \n", 3126 | " \n", 3127 | " \n", 3128 | " \n", 3129 | " \n", 3130 | " \n", 3131 | " \n", 3132 | " \n", 3133 | " \n", 3134 | " \n", 3135 | " \n", 3136 | " \n", 3137 | " \n", 3138 | " \n", 3139 | " \n", 3140 | " \n", 3141 | " \n", 3142 | " \n", 3143 | " \n", 3144 | " \n", 3145 | " \n", 3146 | " \n", 3147 | " \n", 3148 | " \n", 3149 | " \n", 3150 | " \n", 3151 | " \n", 3152 | " \n", 3153 | " \n", 3154 | " \n", 3155 | " \n", 3156 | " \n", 3157 | " \n", 3158 | " \n", 3159 | " \n", 3160 | " \n", 3161 | " \n", 3162 | " \n", 3163 | " \n", 3164 | " \n", 3165 | " \n", 3166 | " \n", 3167 | " \n", 3168 | " \n", 3169 | " \n", 3170 | " \n", 3171 | " \n", 3172 | " \n", 3173 | " \n", 3174 | " \n", 3175 | " \n", 3176 | " \n", 3177 | " \n", 3178 | " \n", 3179 | " \n", 3180 | " \n", 3181 | " \n", 3182 | " \n", 3183 | " \n", 3184 | " \n", 3185 | " \n", 3186 | " \n", 3187 | " \n", 3188 | " \n", 3189 | " \n", 3190 | " \n", 3191 | " \n", 3192 | " \n", 3193 | " \n", 3194 | " \n", 3195 | " \n", 3196 | " \n", 3197 | " \n", 3198 | " \n", 3199 | " \n", 3200 
| " \n", 3201 | " \n", 3202 | " \n", 3203 | " \n", 3204 | " \n", 3205 | " \n", 3206 | " \n", 3207 | " \n", 3208 | " \n", 3209 | " \n", 3210 | " \n", 3211 | " \n", 3212 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.6382953.4375570.0218890.34589000:18
11.4470402.9523830.1180960.65640100:18
21.3407242.8928930.1458390.61338800:18
31.2644873.1034020.1313310.50623600:18
41.2037614.0015340.0455590.27996900:17
51.1592792.3803090.3206920.78900500:18
61.1057782.5423870.2781880.76380800:18
71.0803902.5034950.3010940.74802800:19
81.0281602.7095620.2547720.67752600:18
91.0363372.1568380.3965390.82718200:19
101.0184572.1201520.4227540.82641900:19
110.9759782.1721450.4090100.82921900:19
120.8996631.9870480.4940190.85212500:18
130.8381591.8154430.5444130.90022900:18
140.7970201.7738570.5561210.90582800:18
" 3213 | ], 3214 | "text/plain": [ 3215 | "" 3216 | ] 3217 | }, 3218 | "metadata": {}, 3219 | "output_type": "display_data" 3220 | }, 3221 | { 3222 | "name": "stdout", 3223 | "output_type": "stream", 3224 | "text": [ 3225 | "Run: 2\n" 3226 | ] 3227 | }, 3228 | { 3229 | "data": { 3230 | "text/html": [ 3231 | "\n", 3232 | " \n", 3233 | " \n", 3234 | " \n", 3235 | " \n", 3236 | " \n", 3237 | " \n", 3238 | " \n", 3239 | " \n", 3240 | " \n", 3241 | " \n", 3242 | " \n", 3243 | " \n", 3244 | " \n", 3245 | " \n", 3246 | " \n", 3247 | " \n", 3248 | " \n", 3249 | " \n", 3250 | " \n", 3251 | " \n", 3252 | " \n", 3253 | " \n", 3254 | " \n", 3255 | " \n", 3256 | " \n", 3257 | " \n", 3258 | " \n", 3259 | " \n", 3260 | " \n", 3261 | " \n", 3262 | " \n", 3263 | " \n", 3264 | " \n", 3265 | " \n", 3266 | " \n", 3267 | " \n", 3268 | " \n", 3269 | " \n", 3270 | " \n", 3271 | " \n", 3272 | " \n", 3273 | " \n", 3274 | " \n", 3275 | " \n", 3276 | " \n", 3277 | " \n", 3278 | " \n", 3279 | " \n", 3280 | " \n", 3281 | " \n", 3282 | " \n", 3283 | " \n", 3284 | " \n", 3285 | " \n", 3286 | " \n", 3287 | " \n", 3288 | " \n", 3289 | " \n", 3290 | " \n", 3291 | " \n", 3292 | " \n", 3293 | " \n", 3294 | " \n", 3295 | " \n", 3296 | " \n", 3297 | " \n", 3298 | " \n", 3299 | " \n", 3300 | " \n", 3301 | " \n", 3302 | " \n", 3303 | " \n", 3304 | " \n", 3305 | " \n", 3306 | " \n", 3307 | " \n", 3308 | " \n", 3309 | " \n", 3310 | " \n", 3311 | " \n", 3312 | " \n", 3313 | " \n", 3314 | " \n", 3315 | " \n", 3316 | " \n", 3317 | " \n", 3318 | " \n", 3319 | " \n", 3320 | " \n", 3321 | " \n", 3322 | " \n", 3323 | " \n", 3324 | " \n", 3325 | " \n", 3326 | " \n", 3327 | " \n", 3328 | " \n", 3329 | " \n", 3330 | " \n", 3331 | " \n", 3332 | " \n", 3333 | " \n", 3334 | " \n", 3335 | " \n", 3336 | " \n", 3337 | " \n", 3338 | " \n", 3339 | " \n", 3340 | " \n", 3341 | " \n", 3342 | " \n", 3343 | " \n", 3344 | " \n", 3345 | " \n", 3346 | " \n", 3347 | " \n", 3348 | " \n", 3349 | " \n", 3350 | " \n", 3351 | " \n", 3352 
| " \n", 3353 | " \n", 3354 | " \n", 3355 | " \n", 3356 | " \n", 3357 | " \n", 3358 | " \n", 3359 | " \n", 3360 | " \n", 3361 | " \n", 3362 | " \n", 3363 | " \n", 3364 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.6371823.3371580.0394500.44489700:18
11.4399492.8758550.1145330.58971800:19
21.3331742.8605750.1527110.58844500:19
31.2390652.4086610.2771700.76864300:18
41.2239652.9960970.1432930.57750100:18
51.1620422.4239230.2812420.78976800:18
61.1139722.2104070.3624330.83278200:18
71.0912372.0872660.4230080.84321700:18
81.0470782.5467930.2743700.71290400:18
91.0142502.4014780.3013490.81089300:18
101.0042902.9889670.2257570.64520200:18
110.9876002.0085820.4744210.85670700:18
120.9234781.9598920.4866380.86536000:18
130.8520611.8312910.5444130.88852100:18
140.8052991.7744600.5645200.89921100:18
" 3365 | ], 3366 | "text/plain": [ 3367 | "" 3368 | ] 3369 | }, 3370 | "metadata": {}, 3371 | "output_type": "display_data" 3372 | } 3373 | ], 3374 | "source": [ 3375 | "for run in range(runs):\n", 3376 | " print(f'Run: {run}')\n", 3377 | " learn = Learner(dbunch, model(c_out=20, pretrained=False, act_cls=torch.nn.ReLU, sa=sa, sym=sym, pool=pool), opt_func=opt_func, \\\n", 3378 | " metrics=[accuracy,top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())\n", 3379 | "\n", 3380 | " if fp16: learn = learn.to_fp16()\n", 3381 | " cbs = []\n", 3382 | "\n", 3383 | " # Load weights generated from training on our pretext task\n", 3384 | " model_path = 'all_train_unsup_pretext_' + str(run) + '.pth'\n", 3385 | " state_dict = torch.load(model_path)\n", 3386 | " # HACK: If we don't have all of the parameters for our learner, we get an error\n", 3387 | " linear_layer = learn.model[-1]\n", 3388 | " state_dict['11.weight'] = linear_layer.weight\n", 3389 | " state_dict['11.bias'] = linear_layer.bias\n", 3390 | "\n", 3391 | " learn.model.load_state_dict(state_dict)\n", 3392 | "\n", 3393 | " learn.freeze()\n", 3394 | " learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)\n", 3395 | " \n", 3396 | " del learn\n", 3397 | " torch.cuda.empty_cache() \n", 3398 | " gc.collect() " 3399 | ] 3400 | }, 3401 | { 3402 | "cell_type": "markdown", 3403 | "metadata": {}, 3404 | "source": [ 3405 | "Results:\n", 3406 | "- Run 1: 0.564011\n", 3407 | "- Run 2: 0.556121\n", 3408 | "- Run 3: 0.564520\n", 3409 | "\n", 3410 | "Average: **56.2%**\n" 3411 | ] 3412 | }, 3413 | { 3414 | "cell_type": "markdown", 3415 | "metadata": {}, 3416 | "source": [ 3417 | "## All data in `/train`" 3418 | ] 3419 | }, 3420 | { 3421 | "cell_type": "code", 3422 | "execution_count": 18, 3423 | "metadata": {}, 3424 | "outputs": [ 3425 | { 3426 | "name": "stdout", 3427 | "output_type": "stream", 3428 | "text": [ 3429 | "Run: 0\n" 3430 | ] 3431 | }, 3432 | { 3433 | "data": { 3434 | "text/html": [ 3435 | "\n", 3436 | " \n", 3437 | " 
\n", 3438 | " \n", 3439 | " \n", 3440 | " \n", 3441 | " \n", 3442 | " \n", 3443 | " \n", 3444 | " \n", 3445 | " \n", 3446 | " \n", 3447 | " \n", 3448 | " \n", 3449 | " \n", 3450 | " \n", 3451 | " \n", 3452 | " \n", 3453 | " \n", 3454 | " \n", 3455 | " \n", 3456 | " \n", 3457 | " \n", 3458 | " \n", 3459 | " \n", 3460 | " \n", 3461 | " \n", 3462 | " \n", 3463 | " \n", 3464 | " \n", 3465 | " \n", 3466 | " \n", 3467 | " \n", 3468 | " \n", 3469 | " \n", 3470 | " \n", 3471 | " \n", 3472 | " \n", 3473 | " \n", 3474 | " \n", 3475 | " \n", 3476 | " \n", 3477 | " \n", 3478 | " \n", 3479 | " \n", 3480 | " \n", 3481 | " \n", 3482 | " \n", 3483 | " \n", 3484 | " \n", 3485 | " \n", 3486 | " \n", 3487 | " \n", 3488 | " \n", 3489 | " \n", 3490 | " \n", 3491 | " \n", 3492 | " \n", 3493 | " \n", 3494 | " \n", 3495 | " \n", 3496 | " \n", 3497 | " \n", 3498 | " \n", 3499 | " \n", 3500 | " \n", 3501 | " \n", 3502 | " \n", 3503 | " \n", 3504 | " \n", 3505 | " \n", 3506 | " \n", 3507 | " \n", 3508 | " \n", 3509 | " \n", 3510 | " \n", 3511 | " \n", 3512 | " \n", 3513 | " \n", 3514 | " \n", 3515 | " \n", 3516 | " \n", 3517 | " \n", 3518 | " \n", 3519 | " \n", 3520 | " \n", 3521 | " \n", 3522 | " \n", 3523 | " \n", 3524 | " \n", 3525 | " \n", 3526 | " \n", 3527 | " \n", 3528 | " \n", 3529 | " \n", 3530 | " \n", 3531 | " \n", 3532 | " \n", 3533 | " \n", 3534 | " \n", 3535 | " \n", 3536 | " \n", 3537 | " \n", 3538 | " \n", 3539 | " \n", 3540 | " \n", 3541 | " \n", 3542 | " \n", 3543 | " \n", 3544 | " \n", 3545 | " \n", 3546 | " \n", 3547 | " \n", 3548 | " \n", 3549 | " \n", 3550 | " \n", 3551 | " \n", 3552 | " \n", 3553 | " \n", 3554 | " \n", 3555 | " \n", 3556 | " \n", 3557 | " \n", 3558 | " \n", 3559 | " \n", 3560 | " \n", 3561 | " \n", 3562 | " \n", 3563 | " \n", 3564 | " \n", 3565 | " \n", 3566 | " \n", 3567 | " \n", 3568 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.6277933.3627360.0226520.57062900:19
11.4519933.4746790.0481040.40061100:19
21.3441813.0088030.1163150.54059600:19
31.2649433.0611570.1074060.56477500:19
41.2123863.2329840.1150420.48180200:18
51.1549602.2059710.3654870.83761800:19
61.1156932.2661360.3257830.80249400:19
71.0804172.6507320.2618990.72664800:18
81.0531032.7639430.2127770.72410300:19
91.0218552.3846330.3407990.78213300:19
101.0002202.8185200.2303390.63553100:19
110.9766862.2114830.4146090.81572900:19
120.9191761.8338790.5319420.89004800:19
130.8338641.8004080.5520490.89539300:18
140.8140291.7835750.5571390.90277400:19
" 3569 | ], 3570 | "text/plain": [ 3571 | "" 3572 | ] 3573 | }, 3574 | "metadata": {}, 3575 | "output_type": "display_data" 3576 | }, 3577 | { 3578 | "name": "stdout", 3579 | "output_type": "stream", 3580 | "text": [ 3581 | "Run: 1\n" 3582 | ] 3583 | }, 3584 | { 3585 | "data": { 3586 | "text/html": [ 3587 | "\n", 3588 | " \n", 3589 | " \n", 3590 | " \n", 3591 | " \n", 3592 | " \n", 3593 | " \n", 3594 | " \n", 3595 | " \n", 3596 | " \n", 3597 | " \n", 3598 | " \n", 3599 | " \n", 3600 | " \n", 3601 | " \n", 3602 | " \n", 3603 | " \n", 3604 | " \n", 3605 | " \n", 3606 | " \n", 3607 | " \n", 3608 | " \n", 3609 | " \n", 3610 | " \n", 3611 | " \n", 3612 | " \n", 3613 | " \n", 3614 | " \n", 3615 | " \n", 3616 | " \n", 3617 | " \n", 3618 | " \n", 3619 | " \n", 3620 | " \n", 3621 | " \n", 3622 | " \n", 3623 | " \n", 3624 | " \n", 3625 | " \n", 3626 | " \n", 3627 | " \n", 3628 | " \n", 3629 | " \n", 3630 | " \n", 3631 | " \n", 3632 | " \n", 3633 | " \n", 3634 | " \n", 3635 | " \n", 3636 | " \n", 3637 | " \n", 3638 | " \n", 3639 | " \n", 3640 | " \n", 3641 | " \n", 3642 | " \n", 3643 | " \n", 3644 | " \n", 3645 | " \n", 3646 | " \n", 3647 | " \n", 3648 | " \n", 3649 | " \n", 3650 | " \n", 3651 | " \n", 3652 | " \n", 3653 | " \n", 3654 | " \n", 3655 | " \n", 3656 | " \n", 3657 | " \n", 3658 | " \n", 3659 | " \n", 3660 | " \n", 3661 | " \n", 3662 | " \n", 3663 | " \n", 3664 | " \n", 3665 | " \n", 3666 | " \n", 3667 | " \n", 3668 | " \n", 3669 | " \n", 3670 | " \n", 3671 | " \n", 3672 | " \n", 3673 | " \n", 3674 | " \n", 3675 | " \n", 3676 | " \n", 3677 | " \n", 3678 | " \n", 3679 | " \n", 3680 | " \n", 3681 | " \n", 3682 | " \n", 3683 | " \n", 3684 | " \n", 3685 | " \n", 3686 | " \n", 3687 | " \n", 3688 | " \n", 3689 | " \n", 3690 | " \n", 3691 | " \n", 3692 | " \n", 3693 | " \n", 3694 | " \n", 3695 | " \n", 3696 | " \n", 3697 | " \n", 3698 | " \n", 3699 | " \n", 3700 | " \n", 3701 | " \n", 3702 | " \n", 3703 | " \n", 3704 | " \n", 3705 | " \n", 3706 | " \n", 3707 | " \n", 3708 
| " \n", 3709 | " \n", 3710 | " \n", 3711 | " \n", 3712 | " \n", 3713 | " \n", 3714 | " \n", 3715 | " \n", 3716 | " \n", 3717 | " \n", 3718 | " \n", 3719 | " \n", 3720 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.6288273.3182130.0435230.44769700:18
11.4476262.7678230.1476200.63756700:19
21.3554822.8410800.1560190.59658900:18
31.2712042.9492380.1641640.57953700:19
41.2180512.3507540.3229830.75617200:19
51.1669772.3915980.3178930.76838900:19
61.1364292.3957360.3242560.76762500:19
71.0777982.5959340.2621530.70475900:19
81.0543192.5768590.2733520.70374100:19
91.0221322.6587960.2606260.69636000:18
101.0087742.5497930.3168750.73708300:18
110.9826862.0039190.4764570.86459700:18
120.9144501.8949430.5179440.87452300:19
130.8412791.7774620.5614660.89793800:19
140.8029851.7923790.5545940.89564800:19
" 3721 | ], 3722 | "text/plain": [ 3723 | "" 3724 | ] 3725 | }, 3726 | "metadata": {}, 3727 | "output_type": "display_data" 3728 | }, 3729 | { 3730 | "name": "stdout", 3731 | "output_type": "stream", 3732 | "text": [ 3733 | "Run: 2\n" 3734 | ] 3735 | }, 3736 | { 3737 | "data": { 3738 | "text/html": [ 3739 | "\n", 3740 | " \n", 3741 | " \n", 3742 | " \n", 3743 | " \n", 3744 | " \n", 3745 | " \n", 3746 | " \n", 3747 | " \n", 3748 | " \n", 3749 | " \n", 3750 | " \n", 3751 | " \n", 3752 | " \n", 3753 | " \n", 3754 | " \n", 3755 | " \n", 3756 | " \n", 3757 | " \n", 3758 | " \n", 3759 | " \n", 3760 | " \n", 3761 | " \n", 3762 | " \n", 3763 | " \n", 3764 | " \n", 3765 | " \n", 3766 | " \n", 3767 | " \n", 3768 | " \n", 3769 | " \n", 3770 | " \n", 3771 | " \n", 3772 | " \n", 3773 | " \n", 3774 | " \n", 3775 | " \n", 3776 | " \n", 3777 | " \n", 3778 | " \n", 3779 | " \n", 3780 | " \n", 3781 | " \n", 3782 | " \n", 3783 | " \n", 3784 | " \n", 3785 | " \n", 3786 | " \n", 3787 | " \n", 3788 | " \n", 3789 | " \n", 3790 | " \n", 3791 | " \n", 3792 | " \n", 3793 | " \n", 3794 | " \n", 3795 | " \n", 3796 | " \n", 3797 | " \n", 3798 | " \n", 3799 | " \n", 3800 | " \n", 3801 | " \n", 3802 | " \n", 3803 | " \n", 3804 | " \n", 3805 | " \n", 3806 | " \n", 3807 | " \n", 3808 | " \n", 3809 | " \n", 3810 | " \n", 3811 | " \n", 3812 | " \n", 3813 | " \n", 3814 | " \n", 3815 | " \n", 3816 | " \n", 3817 | " \n", 3818 | " \n", 3819 | " \n", 3820 | " \n", 3821 | " \n", 3822 | " \n", 3823 | " \n", 3824 | " \n", 3825 | " \n", 3826 | " \n", 3827 | " \n", 3828 | " \n", 3829 | " \n", 3830 | " \n", 3831 | " \n", 3832 | " \n", 3833 | " \n", 3834 | " \n", 3835 | " \n", 3836 | " \n", 3837 | " \n", 3838 | " \n", 3839 | " \n", 3840 | " \n", 3841 | " \n", 3842 | " \n", 3843 | " \n", 3844 | " \n", 3845 | " \n", 3846 | " \n", 3847 | " \n", 3848 | " \n", 3849 | " \n", 3850 | " \n", 3851 | " \n", 3852 | " \n", 3853 | " \n", 3854 | " \n", 3855 | " \n", 3856 | " \n", 3857 | " \n", 3858 | " \n", 3859 | " \n", 3860 
| " \n", 3861 | " \n", 3862 | " \n", 3863 | " \n", 3864 | " \n", 3865 | " \n", 3866 | " \n", 3867 | " \n", 3868 | " \n", 3869 | " \n", 3870 | " \n", 3871 | " \n", 3872 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.6202823.3674760.0320690.37566800:18
11.4419243.0959980.0712650.54034100:19
21.3530133.0000890.1606010.54746800:19
31.2667112.8980400.1524560.67243600:19
41.2045632.6733550.2137950.65792800:18
51.1550932.3524020.3176380.78060600:18
61.1137022.0990770.3863580.85848800:18
71.0985392.1003490.4197000.84805300:19
81.0302182.4287870.3374900.75464500:19
91.0136812.9548270.2237210.67345400:18
100.9971922.6664050.2537540.70806800:18
110.9752692.0701580.4377700.85136200:18
120.9088411.9208640.5016540.87350500:18
130.8376991.8031320.5484860.89539300:18
140.7983811.7655940.5579030.90506500:18
" 3873 | ], 3874 | "text/plain": [ 3875 | "" 3876 | ] 3877 | }, 3878 | "metadata": {}, 3879 | "output_type": "display_data" 3880 | } 3881 | ], 3882 | "source": [ 3883 | "for run in range(runs):\n", 3884 | " print(f'Run: {run}')\n", 3885 | " learn = Learner(dbunch, model(c_out=20, pretrained=False, act_cls=torch.nn.ReLU, sa=sa, sym=sym, pool=pool), opt_func=opt_func, \\\n", 3886 | " metrics=[accuracy,top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())\n", 3887 | "\n", 3888 | " if fp16: learn = learn.to_fp16()\n", 3889 | " cbs = []\n", 3890 | "\n", 3891 | " # Load weights generated from training on our pretext task\n", 3892 | " model_path = 'all_train_pretext_' + str(run) + '.pth'\n", 3893 | " state_dict = torch.load(model_path)\n", 3894 | " # HACK: If we don't have all of the parameters for our learner, we get an error\n", 3895 | " linear_layer = learn.model[-1]\n", 3896 | " state_dict['11.weight'] = linear_layer.weight\n", 3897 | " state_dict['11.bias'] = linear_layer.bias\n", 3898 | "\n", 3899 | " learn.model.load_state_dict(state_dict)\n", 3900 | "\n", 3901 | " learn.freeze()\n", 3902 | " learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)\n", 3903 | "\n", 3904 | " del learn\n", 3905 | " torch.cuda.empty_cache() \n", 3906 | " gc.collect() \n" 3907 | ] 3908 | }, 3909 | { 3910 | "cell_type": "markdown", 3911 | "metadata": {}, 3912 | "source": [ 3913 | "Results:\n", 3914 | "- Run 1: 0.557139\n", 3915 | "- Run 2: 0.554594\n", 3916 | "- Run 3: 0.557903\n", 3917 | "\n", 3918 | "Average: **55.7%**\n" 3919 | ] 3920 | }, 3921 | { 3922 | "cell_type": "markdown", 3923 | "metadata": {}, 3924 | "source": [ 3925 | "## Partial data from `/train`" 3926 | ] 3927 | }, 3928 | { 3929 | "cell_type": "code", 3930 | "execution_count": 19, 3931 | "metadata": { 3932 | "scrolled": false 3933 | }, 3934 | "outputs": [ 3935 | { 3936 | "name": "stdout", 3937 | "output_type": "stream", 3938 | "text": [ 3939 | "Run: 0\n" 3940 | ] 3941 | }, 3942 | { 3943 | "data": { 3944 | "text/html": [ 3945 
| "\n", 3946 | " \n", 3947 | " \n", 3948 | " \n", 3949 | " \n", 3950 | " \n", 3951 | " \n", 3952 | " \n", 3953 | " \n", 3954 | " \n", 3955 | " \n", 3956 | " \n", 3957 | " \n", 3958 | " \n", 3959 | " \n", 3960 | " \n", 3961 | " \n", 3962 | " \n", 3963 | " \n", 3964 | " \n", 3965 | " \n", 3966 | " \n", 3967 | " \n", 3968 | " \n", 3969 | " \n", 3970 | " \n", 3971 | " \n", 3972 | " \n", 3973 | " \n", 3974 | " \n", 3975 | " \n", 3976 | " \n", 3977 | " \n", 3978 | " \n", 3979 | " \n", 3980 | " \n", 3981 | " \n", 3982 | " \n", 3983 | " \n", 3984 | " \n", 3985 | " \n", 3986 | " \n", 3987 | " \n", 3988 | " \n", 3989 | " \n", 3990 | " \n", 3991 | " \n", 3992 | " \n", 3993 | " \n", 3994 | " \n", 3995 | " \n", 3996 | " \n", 3997 | " \n", 3998 | " \n", 3999 | " \n", 4000 | " \n", 4001 | " \n", 4002 | " \n", 4003 | " \n", 4004 | " \n", 4005 | " \n", 4006 | " \n", 4007 | " \n", 4008 | " \n", 4009 | " \n", 4010 | " \n", 4011 | " \n", 4012 | " \n", 4013 | " \n", 4014 | " \n", 4015 | " \n", 4016 | " \n", 4017 | " \n", 4018 | " \n", 4019 | " \n", 4020 | " \n", 4021 | " \n", 4022 | " \n", 4023 | " \n", 4024 | " \n", 4025 | " \n", 4026 | " \n", 4027 | " \n", 4028 | " \n", 4029 | " \n", 4030 | " \n", 4031 | " \n", 4032 | " \n", 4033 | " \n", 4034 | " \n", 4035 | " \n", 4036 | " \n", 4037 | " \n", 4038 | " \n", 4039 | " \n", 4040 | " \n", 4041 | " \n", 4042 | " \n", 4043 | " \n", 4044 | " \n", 4045 | " \n", 4046 | " \n", 4047 | " \n", 4048 | " \n", 4049 | " \n", 4050 | " \n", 4051 | " \n", 4052 | " \n", 4053 | " \n", 4054 | " \n", 4055 | " \n", 4056 | " \n", 4057 | " \n", 4058 | " \n", 4059 | " \n", 4060 | " \n", 4061 | " \n", 4062 | " \n", 4063 | " \n", 4064 | " \n", 4065 | " \n", 4066 | " \n", 4067 | " \n", 4068 | " \n", 4069 | " \n", 4070 | " \n", 4071 | " \n", 4072 | " \n", 4073 | " \n", 4074 | " \n", 4075 | " \n", 4076 | " \n", 4077 | " \n", 4078 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.7971303.7786690.0076360.34003600:18
11.5244883.3604740.0307970.46042200:18
21.4068903.5244800.0860270.56223000:19
31.3373263.1711070.0880630.44922400:19
41.2598282.5223740.2407740.71824900:19
51.2018982.2795330.3270550.80733000:18
61.1469012.4747580.2820060.78849600:18
71.1224962.5444700.2468820.73097500:18
81.0726123.0573130.1773990.55688500:18
91.0632442.6059190.2850600.75948100:18
101.0281912.5120520.2970220.75769900:18
111.0032042.1843870.3917030.82031100:18
120.9364902.0319550.4550780.85212500:18
130.8578101.8694830.5291420.88190400:18
140.8083241.8500180.5319420.89310300:18
" 4079 | ], 4080 | "text/plain": [ 4081 | "" 4082 | ] 4083 | }, 4084 | "metadata": {}, 4085 | "output_type": "display_data" 4086 | }, 4087 | { 4088 | "name": "stdout", 4089 | "output_type": "stream", 4090 | "text": [ 4091 | "Run: 1\n" 4092 | ] 4093 | }, 4094 | { 4095 | "data": { 4096 | "text/html": [ 4097 | "\n", 4098 | " \n", 4099 | " \n", 4100 | " \n", 4101 | " \n", 4102 | " \n", 4103 | " \n", 4104 | " \n", 4105 | " \n", 4106 | " \n", 4107 | " \n", 4108 | " \n", 4109 | " \n", 4110 | " \n", 4111 | " \n", 4112 | " \n", 4113 | " \n", 4114 | " \n", 4115 | " \n", 4116 | " \n", 4117 | " \n", 4118 | " \n", 4119 | " \n", 4120 | " \n", 4121 | " \n", 4122 | " \n", 4123 | " \n", 4124 | " \n", 4125 | " \n", 4126 | " \n", 4127 | " \n", 4128 | " \n", 4129 | " \n", 4130 | " \n", 4131 | " \n", 4132 | " \n", 4133 | " \n", 4134 | " \n", 4135 | " \n", 4136 | " \n", 4137 | " \n", 4138 | " \n", 4139 | " \n", 4140 | " \n", 4141 | " \n", 4142 | " \n", 4143 | " \n", 4144 | " \n", 4145 | " \n", 4146 | " \n", 4147 | " \n", 4148 | " \n", 4149 | " \n", 4150 | " \n", 4151 | " \n", 4152 | " \n", 4153 | " \n", 4154 | " \n", 4155 | " \n", 4156 | " \n", 4157 | " \n", 4158 | " \n", 4159 | " \n", 4160 | " \n", 4161 | " \n", 4162 | " \n", 4163 | " \n", 4164 | " \n", 4165 | " \n", 4166 | " \n", 4167 | " \n", 4168 | " \n", 4169 | " \n", 4170 | " \n", 4171 | " \n", 4172 | " \n", 4173 | " \n", 4174 | " \n", 4175 | " \n", 4176 | " \n", 4177 | " \n", 4178 | " \n", 4179 | " \n", 4180 | " \n", 4181 | " \n", 4182 | " \n", 4183 | " \n", 4184 | " \n", 4185 | " \n", 4186 | " \n", 4187 | " \n", 4188 | " \n", 4189 | " \n", 4190 | " \n", 4191 | " \n", 4192 | " \n", 4193 | " \n", 4194 | " \n", 4195 | " \n", 4196 | " \n", 4197 | " \n", 4198 | " \n", 4199 | " \n", 4200 | " \n", 4201 | " \n", 4202 | " \n", 4203 | " \n", 4204 | " \n", 4205 | " \n", 4206 | " \n", 4207 | " \n", 4208 | " \n", 4209 | " \n", 4210 | " \n", 4211 | " \n", 4212 | " \n", 4213 | " \n", 4214 | " \n", 4215 | " \n", 4216 | " \n", 4217 | " \n", 4218 
| " \n", 4219 | " \n", 4220 | " \n", 4221 | " \n", 4222 | " \n", 4223 | " \n", 4224 | " \n", 4225 | " \n", 4226 | " \n", 4227 | " \n", 4228 | " \n", 4229 | " \n", 4230 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.7803233.4109280.0600660.55866600:18
11.5395623.1490610.0679560.47391200:18
21.4120053.0666400.0867910.50725400:18
31.3507273.6258880.0503950.33876300:19
41.2454073.1598990.1351490.53525100:18
51.2083712.6048920.2547720.75311800:18
61.1664242.3757580.3242560.76584400:18
71.1078682.2701390.3525070.79460400:18
81.0970992.4245990.3204380.75668100:18
91.0495742.2680820.3802490.81852900:18
101.0261682.4839290.3046580.75439000:18
110.9983692.1442630.4293710.83329100:18
120.9328112.0800140.4571140.84499900:19
130.8523011.9103140.5001270.87757700:18
140.8124251.8975080.5153980.88674000:18
" 4231 | ], 4232 | "text/plain": [ 4233 | "" 4234 | ] 4235 | }, 4236 | "metadata": {}, 4237 | "output_type": "display_data" 4238 | }, 4239 | { 4240 | "name": "stdout", 4241 | "output_type": "stream", 4242 | "text": [ 4243 | "Run: 2\n" 4244 | ] 4245 | }, 4246 | { 4247 | "data": { 4248 | "text/html": [ 4249 | "\n", 4250 | " \n", 4251 | " \n", 4252 | " \n", 4253 | " \n", 4254 | " \n", 4255 | " \n", 4256 | " \n", 4257 | " \n", 4258 | " \n", 4259 | " \n", 4260 | " \n", 4261 | " \n", 4262 | " \n", 4263 | " \n", 4264 | " \n", 4265 | " \n", 4266 | " \n", 4267 | " \n", 4268 | " \n", 4269 | " \n", 4270 | " \n", 4271 | " \n", 4272 | " \n", 4273 | " \n", 4274 | " \n", 4275 | " \n", 4276 | " \n", 4277 | " \n", 4278 | " \n", 4279 | " \n", 4280 | " \n", 4281 | " \n", 4282 | " \n", 4283 | " \n", 4284 | " \n", 4285 | " \n", 4286 | " \n", 4287 | " \n", 4288 | " \n", 4289 | " \n", 4290 | " \n", 4291 | " \n", 4292 | " \n", 4293 | " \n", 4294 | " \n", 4295 | " \n", 4296 | " \n", 4297 | " \n", 4298 | " \n", 4299 | " \n", 4300 | " \n", 4301 | " \n", 4302 | " \n", 4303 | " \n", 4304 | " \n", 4305 | " \n", 4306 | " \n", 4307 | " \n", 4308 | " \n", 4309 | " \n", 4310 | " \n", 4311 | " \n", 4312 | " \n", 4313 | " \n", 4314 | " \n", 4315 | " \n", 4316 | " \n", 4317 | " \n", 4318 | " \n", 4319 | " \n", 4320 | " \n", 4321 | " \n", 4322 | " \n", 4323 | " \n", 4324 | " \n", 4325 | " \n", 4326 | " \n", 4327 | " \n", 4328 | " \n", 4329 | " \n", 4330 | " \n", 4331 | " \n", 4332 | " \n", 4333 | " \n", 4334 | " \n", 4335 | " \n", 4336 | " \n", 4337 | " \n", 4338 | " \n", 4339 | " \n", 4340 | " \n", 4341 | " \n", 4342 | " \n", 4343 | " \n", 4344 | " \n", 4345 | " \n", 4346 | " \n", 4347 | " \n", 4348 | " \n", 4349 | " \n", 4350 | " \n", 4351 | " \n", 4352 | " \n", 4353 | " \n", 4354 | " \n", 4355 | " \n", 4356 | " \n", 4357 | " \n", 4358 | " \n", 4359 | " \n", 4360 | " \n", 4361 | " \n", 4362 | " \n", 4363 | " \n", 4364 | " \n", 4365 | " \n", 4366 | " \n", 4367 | " \n", 4368 | " \n", 4369 | " \n", 4370 
| " \n", 4371 | " \n", 4372 | " \n", 4373 | " \n", 4374 | " \n", 4375 | " \n", 4376 | " \n", 4377 | " \n", 4378 | " \n", 4379 | " \n", 4380 | " \n", 4381 | " \n", 4382 | "
epochtrain_lossvalid_lossaccuracytop_k_accuracytime
01.7889613.7213980.0246880.39450200:19
11.5442013.6587490.0328330.46424000:18
21.4238483.3406560.0768640.48816500:18
31.3301472.8864020.1476200.59709800:18
41.2658483.3538010.1012980.59888000:19
51.1814822.5805870.2608810.75464500:19
61.1531282.5395720.2616440.77144300:19
71.1138992.3445620.3415630.78798700:18
81.0706942.7171440.2555360.68261600:18
91.0605043.0529060.1735810.63018600:18
101.0274052.4833550.3306180.75795400:18
110.9985072.2726110.4039200.79460400:18
120.9211362.1108250.4482060.84092600:18
130.8519841.8672430.5293970.88903000:18
140.8115031.8496990.5444130.89030300:18
" 4383 | ], 4384 | "text/plain": [ 4385 | "" 4386 | ] 4387 | }, 4388 | "metadata": {}, 4389 | "output_type": "display_data" 4390 | } 4391 | ], 4392 | "source": [ 4393 | "for run in range(runs):\n", 4394 | " print(f'Run: {run}')\n", 4395 | " learn = Learner(dbunch, model(c_out=20, pretrained=False, act_cls=torch.nn.ReLU, sa=sa, sym=sym, pool=pool), opt_func=opt_func, \\\n", 4396 | " metrics=[accuracy,top_k_accuracy], loss_func=LabelSmoothingCrossEntropy())\n", 4397 | "\n", 4398 | " if fp16: learn = learn.to_fp16()\n", 4399 | " cbs = []\n", 4400 | "\n", 4401 | " # Load weights generated from training on our pretext task\n", 4402 | " model_path = 'partial_train_pretext_' + str(run) + '.pth'\n", 4403 | " state_dict = torch.load(model_path)\n", 4404 | " # HACK: If we don't have all of the parameters for our learner, we get an error\n", 4405 | " linear_layer = learn.model[-1]\n", 4406 | " state_dict['11.weight'] = linear_layer.weight\n", 4407 | " state_dict['11.bias'] = linear_layer.bias\n", 4408 | "\n", 4409 | " learn.model.load_state_dict(state_dict)\n", 4410 | "\n", 4411 | " learn.freeze()\n", 4412 | " learn.fit_flat_cos(epochs, lr, wd=1e-2, cbs=cbs)\n", 4413 | "\n", 4414 | " del learn\n", 4415 | " torch.cuda.empty_cache() \n", 4416 | " gc.collect() " 4417 | ] 4418 | }, 4419 | { 4420 | "cell_type": "markdown", 4421 | "metadata": {}, 4422 | "source": [ 4423 | "Results:\n", 4424 | "- Run 1: 0.531942\n", 4425 | "- Run 2: 0.515398\n", 4426 | "- Run 3: 0.544413\n", 4427 | "\n", 4428 | "Average: **53.1%**\n" 4429 | ] 4430 | }, 4431 | { 4432 | "cell_type": "markdown", 4433 | "metadata": {}, 4434 | "source": [ 4435 | "## Results:" 4436 | ] 4437 | }, 4438 | { 4439 | "cell_type": "markdown", 4440 | "metadata": {}, 4441 | "source": [ 4442 | "- Random: **52.3%**\n", 4443 | "- Partial `/train`: **53.1%**\n", 4444 | "- All `/train`: **55.7%**\n", 4445 | "- All `/train` and `/unsup` : **56.2%**\n", 4446 | "- All `/train`,`/unsup` and `/val` : **56.3%**" 4447 | ] 4448 | }, 4449 | { 
4450 | "cell_type": "code", 4451 | "execution_count": 26, 4452 | "metadata": {}, 4453 | "outputs": [ 4454 | { 4455 | "data": { 4456 | "text/plain": [ 4457 | "" 4458 | ] 4459 | }, 4460 | "execution_count": 26, 4461 | "metadata": {}, 4462 | "output_type": "execute_result" 4463 | }, 4464 | { 4465 | "data": { 4466 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3debxVdb3/8ddbBMUUcTh5ZVDUHHLK4WhpOZaiVkhqN9NKLDMroizt6m3yml2vcb0NVxv06nWolDQjNAuHzKE0OSSKQCRy9ccBUsQwNULBz++P73frYrPOYR08+4zv5+OxH3ut75o+a9jrs9f0XYoIzMzM6q3X3QGYmVnP5ARhZmalnCDMzKyUE4SZmZVygjAzs1JOEGZmVsoJYh1IOlnSbV08zU9KekrSC5K26Mpp9yaSrpJ0QXfHUSNpnKT7ujuO/kjSYEk3S3pO0g3dHU9v5ATRBknvkPT7vHE9K+l3kvYDiIgfR8SRXRjLQOC/gCMjYuOIWNrg6Y2SFJLW76TxPSHpXe10P1RSa2dMy9q2tvXQB50AbAVsERHv7+5geqNO2QH0NZKGALcAnwR+CgwCDgJWdFNIWwEbArM6Y2SS1o+IlZ0xLus7+tJ2IWkAsC3w53WZp760LF6XiPCn7gM0A8va6T4OuC83fxF4ofB5Gbgqd9sUuAJYDCwELgAGtDHODYBvA4vy59u5bCfgRSDy+H9TMuyo3P30POxi4AuF7ucBNwI/Av4GnEY6ejwHeBxYSkqEm+f+/19hei8AB+TyjwJzgL8CU4Ftc/mBwDPAyNz+FmAZsAtwLfAKsDyP64t1sb8hd3ulML1hwP7A/Xk8i4FLgEF5GAHfAp4GngMeAXbP3a4CLsjNmwB3Ad8FVDfdE4GWurIzgSm5+RhgNvB8XndnVdx2tgCm5OX8IPD12rZSWFbTctzTgANz+WHAzEJ/dwAPFtrvA8bm5ieAs/J8PwdMAjbM3bYk/blZBjwL3JvX9RrrobDdfCyv83vyON4G/D6P42Hg0EIcp+Zt4HlgPvCJQrdDgdY87qfzehubl+Wfczz/2s6yuwr4AXB7Hv/d5G0sd98ld3sWmAv8c92w3wduJf1efge8RPo9vpDncT3gy8CTOb5rgE3rfkOvLotC2anAAtJ2fwawX172y4BLCjHsAPyG9Ht6BvgxMLTQvc31lrsfC8wgbTuPA0d1dD/S6fvCrphIb/sAQ/JKvho4Gtisrvs4Cj/6QvlI0g76mNw+GfghaSf4RtIO4xNtTPN84IHcX1P+gX69buNdv41ha92vy9PaA1gCvCt3Py//UMbmH8lg4HN5eiNIieiHwHVtTS8POw94M+nI88vA7wvdv5F/HIPzD2B83Q/jXe0s70OB1rqyfUk7qvVzPHOAz+Vuo4HpwFBSsngzsHXudlX+AW2Rl/cFbUxzI9JOaMdC2TTgxNy8GDgoN28G7FNx27melGzfAOyef9C1PxObk3YyH87z9cHcvgXpCHE5aQe/PvCXvC1tkpfpctKpktryfJCUSDfPy+aM3O1C0k52YP4cRE6O9euhsJ6vyfEOBoaTtv1j8rZyRG5vysO8m7QjFHAI8PfassnrcSXw1Tztj5O2w5/k+dgN+AewfRvL7qq8Tg4mbZPfKSy7N5B20qfm5bMP
aSe8W2HY54C357g3JG33PyqM/6OkbXh7YGPgJuDadpZFrewHeXxH5vgnk36nw0mJ5pA8jjfl5bUB6Td8D/Dtut9BW+tt/xz/ETn+4cAuHd2PdPq+sCsm0hs/pJ3OVaR/RCtJ/wq3yt3GUZcg8gY1HfiX3L4V6ZTU4EI/HwTuamN6j5MTS24fDTxRt/GuLUHsUij7JnBFbj6P/O+w0H0O8M5C+9akJLJ+2fSAXwEfK7SvR9o5bJvbB+b5nwn8msI/dtYhQZT08zng57n5cNI/0rcB69X1dxVwJfAocPZaxvkj4Ku5eUfSzmmj3P7/gE8AQzqwzQzIy7C4Hv6d13ZyH6ZwVJDL7gfG5eZ7gePyfN1GSjRHkY4uHqlbnh+qW9c/yM3nA78A3lQS32rrobCety+U/Qt5p1komwqc0sY8TwY+W1iPy8n/bklJIYC3FvqfTj4SKhnXVcD1hfaNgVWkP14fAO6t6/+HwNcKw15T1/08Vk8QdwKfKrTvzJrbfHFZ1MqGF8qWAh8otP+M/MelZH7GAg9VXG8/BL5VMo4O7Uc6++OL1G2IiDkRMS4iRpD+CQ4jnfZpyxXA3Ii4KLdvS9ppLpa0TNIy0kbwxjaGH0Y69K15Mpd1xIJ2hl9Q1++2wM8Lsc0h/Ri3amPc2wLfKfT/LOlf5HCAiHiZ9CPdHbg48pa8riTtJOkWSX+R9DfSjnbLPK3fkE45XQo8JemyfN2o5t2khP2DtUzmJ6QfG8BJwOSI+HtuP570L/pJSXdLOqBC2E2knU39eqipX8e17sNz892knezBufm3pH/ph+T2or8Umv9O2pkCTCT9S75N0nxJ51SIuxjvtsD7a+s5r+t3kP5AIOloSQ/kGzeWkZbRloXhl0bEqty8PH8/Vei+vBBru7FExAuk7WxYjuutdXGdDPxTG/NRpuw3tj6rb/Nl46iPv3R+JL1R0vWSFuZt9kesvmyg7fU2kvQnsV5H9yOdygmigoj4E6/t/NaQf4Q7k85f1iwgZf4tI2Jo/gyJiN3amMwi0sZQs00u64iR7Qxfv8NeABxdiG1oRGwYEQtL+q31/4m6/gdHxO8BJA0Hvgb8L3CxpA3amXa9su7fB/5EOgU0BPhXUkJKA0R8NyL2JZ222Ak4uzDs5aSjmFslvaGd6d4GbClpL1Ki+Elh/NMi4ljSD3Ey6d/82iwhHW3Wr4ea+nVc674wN9cniLtpO0GUiojnI+ILEbE98F7g85LeWevc1mCF5gWkI4jien5DRPxHXqc/A/6TdDQ9lHTOXyXjXFevLjtJG5NOxSzKcd1dF9fGEfHJNuajTNlvbCWr7/Bfzx+bC/Pwe+Zt9kNUXzYLSKfuyso7sh/pVE4QJSTtIukLkkbk9pGkHcgDJf0eDUwgHTbX/jEREYtJO6CLJQ2RtJ6kHSQd0sZkrwO+LKlJ0pak87g/6mDoX5G0kaTdSOdqJ7XT7w+Ab0jaNs9Hk6Rjc7clpAua29f1f24eN5I2lfT+3CxSAr2ClCQXky7O1jxVN656TwFbSNq0ULYJ6WLdC5J2Id1RRp7efpLemm//fZF0XngVqxtPupB5i6TBZRONdJfKjaR/3ZuTLoAiaVB+1mXTfGT0t5Lxl41vFem89nl5PewKnFLo5VZgJ0knSVpf0geAXUkXlSFdd9qZdD76wYiYRf7nTDqfvVaS3iPpTXmd1OKuxb629QBpm3uvpNGSBkjaMN+GPIJ0N98G5ESYt/3Ovt37mHyL+SDSNvSHiFhAWkY7SfqwpIH5s5+kN3dg3NcBZ0raLieffwcmRefdrbQJ6YL4svyH6ey19F90BXCqpHfmfcVwSbusw36kUzlBlHue9KP8g6QXSYnhUeALJf1+gHRqYY7SQ2wvSKqd2vgI6Uc1m3Qx8kbyoXqJC4AW0gXemcAfc1lH3E06vXAn8J8R0d7DfN8hXVe5TdLzpHl8K0A+zfIN4Hf5sPZtEfFz4CLg+nz4/CjpAj6kBLkV8JV8aulU0sZ+
UO5+ISn5LZN0Vn0g+QjtOmB+7mcY6W6Pk0jr4nJWT3ZDctlfSacJlpL+1RbHGaS7uhYAv5C0YRvL4SfAu4Ab6nYUHwaeyPN6BunfIJK2yet4mzVHBaTEtDHpVMJVpCOqWkxLgfeQtqOlpLt93hMRz+TuL5LW+6yIeCkPdj/wZEQ83cb06u1IugPqhTzs9yLit7lbu+shx7CAdDfNv5ISwQLSjm69iHietK5/Slr2J5G2oc70E9KR6LOkGxVOznE9T0pGJ5KOBP5C2h43KB9NqStJd3PdA/wf6Y/FZzorcODfSBfPnwN+SfqzUElEPEj63XwrD383rx3tdGQ/0qlqdzdYLyZpFGmDH9iJ/4bMupSkq0g3K3y5u2OxxEcQZmZWygnCzMxKNTRBSDpK0lxJ88put1OqyGyJpBn5c1qh2zclzZI0R9J380U3KxERT0SEfHrJerN8W7lPL/UgDauLSakulEtJTwa2AtMkTYmI2XW9ToqI8XXDHkh6InLPXHQf6Va/3zYqXjMzW10jK+vbH5gXEfMBJF1PujuiPkGUCdKj7YNI9xEPZPV7ldew5ZZbxqhRo15PvGZm/c706dOfiYimsm6NTBDDWf2pxFbybZR1jpd0MKnqhDMjYkFE3C/pLtL99CJViDWnfkBJp5NuZWSbbbahpaWls+fBzKxPk1T/dP+rGnkNouyaQf09tTcDoyJiT9K921cDSHoTqS6kEaREc3hOIquPLOKyiGiOiOamptIEaGZm66iRCaKV1ascGEFd1RERsTQiau9YuJz0YAzA+4AHIuKFXB/Lr0gVmJmZWRdpZIKYBuyYH2sfRHoCcrWnLiUVnwYcQ6owDlJNmofk6ggGki5Qr3GKyczMGqdh1yAiYqWk8aSqggcAV0bELEnnk17UMgWYIGkMqcKsZ0nVaEN6lPxwUpUTAfw6Im5uVKxmZramPlPVRnNzc/gitZlZx0iaHhHNZd38JLWZmZVygjAzs1JOEGZmVsoJwszMSjlBmJlZKScIMzMr5QRhZmalnCDMzKyUE4SZmZVygjAzs1JOEGZmVsoJwszMSjlBmJlZKScIMzMr5QRhZmalnCDMzKyUE4SZmZVygjAzs1JOEGZmVsoJwszMSjU0QUg6StJcSfMknVPSfZykJZJm5M9pufywQtkMSf+QNLaRsZqZ2erWb9SIJQ0ALgWOAFqBaZKmRMTsul4nRcT4YkFE3AXslcezOTAPuK1RsZqZ2ZoaeQSxPzAvIuZHxEvA9cCx6zCeE4BfRcTfOzU6MzNrVyMTxHBgQaG9NZfVO17SI5JulDSypPuJwHVlE5B0uqQWSS1Llix5/RGbmdmrGpkgVFIWde03A6MiYk/gDuDq1UYgbQ3sAUwtm0BEXBYRzRHR3NTU1Akhm5lZTSMTRCtQPCIYASwq9hARSyNiRW69HNi3bhz/DPw8Il5uWJRmZlaqkQliGrCjpO0kDSKdKppS7CEfIdSMAebUjeODtHF6yczMGqthdzFFxEpJ40mnhwYAV0bELEnnAy0RMQWYIGkMsBJ4FhhXG17SKNIRyN2NitHMzNqmiPrLAr1Tc3NztLS0dHcYZma9iqTpEdFc1s1PUpuZWSknCDMzK+UEYWZmpZwgzMyslBOEmZmVcoIwM7NSThBmZlbKCcLMzEo5QZiZWSknCDMzK+UEYWZmpZwgzMysVMNqczUzs8aa/NBCJk6dy6Jlyxk2dDBnj96ZsXuXvbhz3ThBmJn1QpMfWsi5N81k+curAFi4bDnn3jQToNOShE8xmZn1QhOnzn01OdQsf3kVE6fO7bRp+AjCzIzGn67pbIuWLe9Q+brwEYSZ9Xu10zULly0neO10zeSHFnZ3aG0aNnRwh8rXhROEmfV7XXG6prOdPXpnBg8csFrZ4IEDOHv0zp02DZ9iMrN+rytO13S22ukv38VkZtZAw4YOZmFJMujM0zWNMHbv4Q29TtLQU0ySjpI0V9I8SeeUdB8naYmkGflzWqHbNpJu
kzRH0mxJoxoZq5n1X11xuqY3atgRhKQBwKXAEUArME3SlIiYXdfrpIgYXzKKa4BvRMTtkjYGXmlUrGbWv3XF6ZreaK0JQtLPgCuBX0VER3bS+wPzImJ+Hs/1wLFAfYIom+auwPoRcTtARLzQgemamXVYo0/X9EZVTjF9HzgJeEzSf0japeK4hwMLCu2tuaze8ZIekXSjpJG5bCdgmaSbJD0kaWI+IlmNpNMltUhqWbJkScWwzMysirUmiIi4IyJOBvYBngBul/R7SadKGtjOoCobXV37zcCoiNgTuAO4OpevDxwEnAXsB2wPjCuJ7bKIaI6I5qamprXNipmZdUCli9SStiDtoE8DHgK+Q0oYt7czWCswstA+AlhU7CEilkbEitx6ObBvYdiHImJ+RKwEJufpmZlZF1lrgpB0E3AvsBHw3ogYExGTIuIzwMbtDDoN2FHSdpIGAScCU+rGvXWhdQwwpzDsZpJqhwWHU+HahZmZdZ4qdzFdEhG/KesQEc1tDRQRKyWNB6YCA4ArI2KWpPOBloiYAkyQNAZYCTxLPo0UEasknQXcKUnAdNIRhpmZdRFF1F8WKOlJ2h3YFdiwVhYR1zQwrg5rbm6OlpaW7g7DzKxXkTS9rT/7VW5z/RpwKClB3AocDdxHek7BzBqst9Uyan1HlYvUJwDvBP4SEacCbwE2aGhUZgb0zlpGre+okiCW5wfkVkoaAjxNuu3UzBqsN9Yyan1HlYvULZKGki4STwdeAB5saFRmBvTOWkat72g3QeQ7iC6MiGXADyT9GhgSEY90SXRm/VxvrWXU+oZ2TzFFusVpcqH9CScHs67jWkatO1W5BvGApP0aHomZrWHs3sO58Lg9GD50MAKGDx3Mhcft4buYrEtUuQZxGPAJSU8CL5LqWIpcf5KZNZhrGbXuUiVBHN3wKMzMrMepkiDW/qi1mZn1OVUSxC9JSUKkqja2A+YCuzUwLjMz62ZrTRARsUexXdI+wCcaFpGZmfUIld4HURQRfyS9xMfMzPqwKpX1fb7Quh7pxT1+v6eZWR9X5RrEJoXmlaRrEj9rTDhmZtZTVLkG8W9dEYiZmfUsVV45enuurK/WvpmkqY0Ny8zMuluVi9RNubI+ACLir8AbGxeSmZn1BFUSxCpJ29RaJG2LH54zM+vzqlyk/hJwn6S7c/vBwOmNC8nMzHqCtR5BRMSvSbe2TgJ+CuwbEZWuQUg6StJcSfMknVPSfZykJZJm5M9phW6rCuVTqs+SmZl1hirPQbwP+E1E3JLbh0oaGxGT1zLcAOBS4AigFZgmaUpEzK7rdVJEjC8ZxfKI2KvSXJiZWaercg3iaxHxXK0lX7D+WoXh9gfmRcT8iHgJuB44dt3CNDOzrlYlQZT1U+XaxXBgQaG9NZfVO17SI5JulDSyUL6hpBZJD0gaW2F6ZmbWiaokiBZJ/yVpB0nbS/oWML3CcCopq7/76WZgVH750B3A1YVu20REM3AS8G1JO6wxAen0nERalixx7R9mZp2pSoL4DPAS6SL1DcA/gE9XGK4VKB4RjAAWFXuIiKURsSK3Xg7sW+i2KH/PB34L7F0/gYi4LCKaI6K5qampQkhmZlZVlao2XgTWuAOpgmnAjpK2AxYCJ5KOBl4laeuIWJxbxwBzcvlmwN8jYoWkLYG3A99chxjMzGwdVbmLqQn4IukFQRvWyiPi8PaGi4iVksYDU4EBwJURMUvS+UBLREwBJkgaQ6oE8FlgXB78zcAPJb1COsr5j5K7n8zMrIEU0f5D0ZJuI51eOgs4AzgFWBIR/9L48Kprbm6OlpaW7g7DzKxXkTQ9X+9dQ5VrEFtExBXAyxFxd0R8FHhbp0ZoZmY9TpXbVV/O34slvZt0oXlE40IyM7OeoEqCuEDSpsAXgP8GhgBnNjQqMzPrdlXuYrolNz4HHNbYcMzMrKeocg3CzMz6IScIMzMr5QRhZmalqjwoNxT4CDCq2H9ETGhcWGZm1t2q3MV0K/AAMBN4
pbHhmJlZT1ElQWwYEZ9veCRmZtajVLkGca2kj0vaWtLmtU/DIzMzs25V5QjiJWAi8CVee59DANs3KigzM+t+VRLE54E3RcQzjQ7GzMx6jiqnmGYBf290IGZm1rNUOYJYBcyQdBdQe/ubb3M1M+vjqiSIyfljZmb9SJXK+q7uikDMzKxnqfIk9Y7AhcCurP7KUd/FZGbWh1W5SP2/wPdJ740+DLgGuLaRQZmZWferkiAGR8SdpPdXPxkR5wGHNzYsMzPrblUuUv9D0nrAY5LGAwuBNzY2LDMz625VjiA+B2wETAD2BT4EnFJl5JKOkjRX0jxJ55R0HydpiaQZ+XNaXfchkhZKuqTK9MzMrPNUuYtpGoCkiIhTq45Y0gDgUuAIoBWYJmlKRMyu63VSRIxvYzRfB+6uOk0zM+s8az2CkHSApNnAnNz+FknfqzDu/YF5ETE/Il4CrgeOrRqYpH2BrYDbqg5jZmadp8oppm8Do4GlABHxMHBwheGGAwsK7a25rN7xkh6RdKOkkQD5msfFwNntTUDS6ZJaJLUsWbKkQkhmZlZVpVeORsSCuqJVFQZT2ajq2m8GRkXEnsAdQO2hvE8Bt5ZMtz6uyyKiOSKam5qaKoRkZmZVVbmLaYGkA4GQNIh0sXpOheFagZGF9hHAomIPEbG00Ho5cFFuPgA4SNKngI2BQZJeiIg1LnSbmVljVEkQZwDfIZ0eaiVdE/h0heGmATtK2o50a+yJwEnFHiRtHRGLc+sYcuKJiJML/YwDmp0czMy6VrsJIt+J9OHiDruqiFiZn5uYCgwAroyIWZLOB1oiYgowQdIY0lPazwLjOjodMzNrDEXUXxao60H6bUQc2jXhrLvm5uZoaWnp7jDMzHoVSdMjormsW5VTTL/LD6pNAl6sFUbEHzspPjMz64GqJIgD8/f5hbLA9TGZmfVpVRLExyJifrFAkqv6NjPr46o8B3FjSdkNnR2ImZn1LG0eQUjaBdgN2FTScYVOQyi8OMjMzPqm9k4x7Qy8BxgKvLdQ/jzw8UYGZWZm3a/NBBERvwB+IemAiLi/C2MyM7MeoMo1iPfl9zIMlHSnpGckfajhkZmZWbeqkiCOjIi/kU43tQI7sZZaVs3MrPerkiAG5u9jgOsi4tkGxmNmZj1Elecgbpb0J2A58ClJTcA/GhuWmZl1t7UeQeRaVA8g1aj6Mqm6jcpvhjMzs96pyhEEwJuBUZKK/V/TgHjMzKyHWGuCkHQtsAMwg9feJBc4QZiZ9WlVjiCagV1jbfWCm5lZn1LlLqZHgX9qdCBmZtazVDmC2BKYLelBYEWtMCLGNCwqMzPrdlUSxHmNDqI3mvzQQiZOncuiZcsZNnQwZ4/embF7D+/usMzMOk2VBLEDcG9EPNboYHqLyQ8t5NybZrL85XTNfuGy5Zx700wAJwkz6zOqXIMYBfxQ0uOSfirpM5L2anBcPdrEqXNfTQ41y19excSpc7spIjOzzlflQbmvRsThwO7AfaR6mKZXGbmkoyTNlTRP0jkl3cdJWiJpRv6clsu3lTQ9l82SdEbHZquxFi1b3qFyM7PeqMpzEF8G3g5sDDwEnAXcW2G4AcClwBGkSv6mSZoSEbPrep0UEePryhYDB0bECkkbA4/mYRetdY66wLChg1lYkgyGDR3cDdGYmTVGlVNMxwFbAHcANwFTImJxheH2B+ZFxPyIeAm4nopVdETESxFRu2Nqg4pxdpmzR+/M4IEDVisbPHAAZ4/euZsiMjPrfFVOMe0DvBN4kHQ0MFPSfRXGPRxYUGhvzWX1jpf0iKQbJY2sFUoaKemRPI6Lyo4eJJ0uqUVSy5IlSyqE1DnG7j2cC4/bg+FDByNg+NDBXHjcHr5AbWZ9SpVTTLsDBwGHkJ6qXkCFU0yASsrqn8a+mVSF+Ip8neFq4HCAiFgA7ClpGDBZ0o0R8dRqI4u4DLgMoLm5uUuf9B6793AnBDPr06qcurkIGAJ8F3hzRBwWEV+tMFwrMLLQ
PgJY7SggIpYWTiVdDuxbP5J85DCLlKTMzKyLVDnF9G7gW8DfgJ0lDVzLIDXTgB0lbSdpEHAiMKXYg6StC61jgDm5fISkwbl5M9JFct9DambWhaqcYjqEVHPrE6TTRiMlnRIR97Q3XESslDQemAoMAK6MiFmSzgdaImIKMEHSGGAl8CwwLg/+ZuBiSZGn+Z8RMXNdZtDMzNaN1lZJq6TpwEkRMTe370S6brDG6aDu1NzcHC0tLd0dhplZryJpekQ0l3Wr9E7qWnIAiIg/89p7qs3MrI+qUhdTi6QrgGtz+8lUfJLazMx6ryoJ4pPAp4EJpOsB9wDfa2RQZmbW/daaIPIzCtcC10ZE1z2NZmZm3arNaxBKzpP0DPAnYG6uWK/KMxBmZtbLtXeR+nOk5w/2i4gtImJz4K3A2yWd2SXRmZlZt2kvQXwE+GBE/F+tICLmAx/K3czMrA9rL0EMjIhn6gvzdQjf5mpm1se1lyBeWsduZmbWB7R3F9NbJP2tpFzAhg2Kx8zMeog2E0REDGirm5mZ9X096k1tZmbWc1R5krpPm/zQQiZOncuiZcsZNnQwZ4/e2S8CMjOjnyeIyQ8t5NybZrL85VUALFy2nHNvSrWKO0mYWX/Xr08xTZw699XkULP85VVMnOp3E5mZ9esEsWjZ8g6Vm5n1J/06QQwbOrhD5WZm/Um/ThBnj96ZwQNXv5t38MABnD16526KyMys5+jXF6lrF6J9F5OZ2Zr6dYKAlCScEMzM1tTQU0ySjpI0V9I8SeeUdB+X3zExI39Oy+V7Sbpf0ixJj0j6QCPjNDOzNTXsCELSAOBS4AigFZgmaUpEzK7rdVJEjK8r+zvwkYh4TNIwYLqkqRGxrFHxmpnZ6hp5BLE/MC8i5kfES8D1wLFVBoyIP0fEY7l5EfA00NSwSM3MbA2NTBDDgQWF9tZcVu/4fBrpRkkj6ztK2h8YBDxe0u10SS2SWpYs8euyzcw6UyMThErKoq79ZmBUROwJ3AFcvdoIpK2Ba4FTI+KVNUYWcVlENEdEc1OTDzDMzDpTIxNEK1A8IhgBLCr2EBFLI2JFbr0c2LfWTdIQ4JfAlyPigQbGaWZmJRqZIKYBO0raTtIg4ERgSrGHfIRQMwaYk8sHAT8HromIGxoYo5mZtaFhdzFFxEpJ44GpwADgyoiYJel8oCUipgATJI0BVgLPAuPy4P8MHAxsIalWNi4iZjQqXjMzW50i6i8L9E7Nzc3R0tLS3WGYmfUqkqZHRHNZt35dF5OZmbXNCcLMzEo5QZiZWSknCDMzK+UEYWZmpZwgzMyslBOEmZmVcoIwM7NSThBmZlbKCcLMzEo5QZiZWSknCDMzK+UEYWZmpZwgzMyslBOEmZmVcoIwM7NSThBmZlbKCcLMzEo5QZiZWSknCDMzK9XQBCHpKElzJc2TdE5J93GSlkiakT+nFcxnyWQAAAqfSURBVLr9WtIySbc0MkYzMyu3fqNGLGkAcClwBNAKTJM0JSJm1/U6KSLGl4xiIrAR8IlGxWhmZm1r5BHE/sC8iJgfES8B1wPHVh04Iu4Enm9UcGZm1r5GJojhwIJCe2suq3e8pEck3ShpZAPjMTOzDmhkglBJWdS13wyMiog9gTuAqzs0Ael0SS2SWpYsWbKOYZqZWZlGJohWoHhEMAJYVOwhIpZGxIrcejmwb0cmEBGXRURzRDQ3NTW9rmDNzGx1jUwQ04AdJW0naRBwIjCl2IOkrQutY4A5DYzHzMw6oGF3MUXESknjganAAODKiJgl6XygJSKmABMkjQFWAs8C42rDS7oX2AXYWFIr8LGImNqoeM3MbHWKqL8s0Ds1NzdHS0tLd4dhZtarSJoeEc1l3fwktZmZlXKCMDOzUk4QZmZWygnCzMxKOUGYmVkpJwgzMyvlBGFmZqX6zHMQkpYAT76OUWwJPNNJ4fRU/WEewfPZ1/SH+ezOedw2IkrrKuozCeL1ktTS1sMifUV/mEfwfPY1/WE+e+o8
+hSTmZmVcoIwM7NSThCvuay7A+gC/WEewfPZ1/SH+eyR8+hrEGZmVspHEGZmVsoJwszMSvX7BCHpKElzJc2TdE53x7MuJD0haaakGZJactnmkm6X9Fj+3iyXS9J38/w+ImmfwnhOyf0/JumU7pqfQjxXSnpa0qOFsk6bL0n75uU2Lw9b9h71hmpjHs+TtDCvzxmSjil0OzfHO1fS6EJ56Xac3+j4hzzvk/LbHbucpJGS7pI0R9IsSZ/N5X1mfbYzj713fUZEv/2Q3nT3OLA9MAh4GNi1u+Nah/l4AtiyruybwDm5+Rzgotx8DPArQMDbgD/k8s2B+fl7s9y8WTfP18HAPsCjjZgv4EHggDzMr4Cje8g8ngecVdLvrnkb3QDYLm+7A9rbjoGfAifm5h8An+ymdbk1sE9u3gT4c56fPrM+25nHXrs++/sRxP7AvIiYHxEvAdcDx3ZzTJ3lWODq3Hw1MLZQfk0kDwBDld4NPhq4PSKejYi/ArcDR3V10EURcQ/pVbRFnTJfuduQiLg/0q/tmsK4ukwb89iWY4HrI2JFRPwfMI+0DZdux/kf9OHAjXn44vLqUhGxOCL+mJufJ71/fjh9aH22M49t6fHrs78niOHAgkJ7K+2v0J4qgNskTZd0ei7bKiIWQ9pwgTfm8rbmubcsi86ar+G5ub68pxifT61cWTvtQsfncQtgWUSsrCvvVpJGAXsDf6CPrs+6eYReuj77e4IoO0fZG+/7fXtE7AMcDXxa0sHt9NvWPPf2ZdHR+erJ8/t9YAdgL2AxcHEu7/XzKGlj4GfA5yLib+31WlLWK+a1ZB577frs7wmiFRhZaB8BLOqmWNZZRCzK308DPycdoj6VD7vJ30/n3tua596yLDprvlpzc315t4uIpyJiVUS8AlxOWp/Q8Xl8hnRqZv268m4haSBpx/njiLgpF/ep9Vk2j715ffb3BDEN2DHfGTAIOBGY0s0xdYikN0japNYMHAk8SpqP2h0epwC/yM1TgI/ku0TeBjyXD+2nAkdK2iwfAh+Zy3qaTpmv3O15SW/L53Y/UhhXt6rtMLP3kdYnpHk8UdIGkrYDdiRdmC3djvO5+LuAE/LwxeXVpfIyvgKYExH/VejUZ9ZnW/PYq9dnI6+A94YP6W6JP5PuGvhSd8ezDvFvT7rL4WFgVm0eSOcr7wQey9+b53IBl+b5nQk0F8b1UdKFsnnAqT1g3q4jHZK/TPpX9bHOnC+gmfRjfRy4hFyzQA+Yx2vzPDxC2olsXej/SzneuRTu0mlrO87bx4N53m8ANuimdfkO0umQR4AZ+XNMX1qf7cxjr12frmrDzMxK9fdTTGZm1gYnCDMzK+UEYWZmpZwgzMyslBOEmZmVcoKwhpN0oaRDJY1VB2vMldSUa698SNJBdd3+R9KunRtt55DULOm7XTzNQyUd+DqGH9vW8myvW4XxjpN0ybrGZd3HCcK6wltJddIcAtzbwWHfCfwpIvaOiNWGjYjTImJ2J8XYqSKiJSImdPZ4C0/RljkUWOcEQar4ra0k0F4366OcIKxhJE2U9AiwH3A/cBrwfUlfLel3W0l35grN7pS0jaS9SNVBH5Pr0R9cN8xvJTXn5hckXZQrLLxD0v65+3xJY3I/oyTdK+mP+XNgLl9P0veU6vC/RdKtkk7I3faVdHce79RCtRATJM3O8V5fMj+HSrolN5+nVElbLZ7SxJHn4eIc252Smgrz+e+S7gY+m4+qfiZpWv68XalyuDOAM/OyOqisvzy+79bWgaTRku7Jy2IMMDEPv0MhrjW6Sfp4HufDeRob5X7fL+nRXH5PyTy+W9L9krYsWwbWw3THU5X+9J8Pqd6Z/wYGAr9rp7+bgVNy80eBybl5HHBJG8P8lvyELekJ1qNz88+B2/I03wLMyOUbARvm5h2Bltx8AnAr6Q/TPwF/zWUDgd8DTbm/DwBX5uZF5KdYgaElsR0K3JKbz8vj2QDYElgKDCwZJoCTc/NXa/Od5/N7
hf5+ArwjN29DqtqhNp2zKvS3Eemp+8NIT/DukMuvAk5oY1mv1g3YotB8AfCZ3DwTGF5cLrV1SKpm4l66+T0j/lT/tHe4atYZ9iZVObAL0N7poAOA43LztaQjh454Cfh1bp4JrIiIlyXNBEbl8oHAJfnIZBWwUy5/B3BDpMrU/iLprly+M7A7cHuqZocBpGoxIFWb8GNJk4HJFeL7ZUSsAFZIehrYitWrpwZ4BZiUm38E3FToNqnQ/C5gV732wrQhyvVx1SntLyKel/Rx4B7gzIh4vEL89XaXdAEwFNiY1+rt+h1wlaSf1sV/GKkqjCOj/VpcrQdxgrCGyDvhq0g1Tj5D+tcqSTOAAyJi+VpG0dE6YF6O/HeVtKNdARARrxTO258JPEU6qlgP+Ect3LZmA5gVEQeUdHs36W1wY4CvSNotXqunv8yKQvMqqv32isvgxULzepQsQ635hs3S/rI9SEcywyrEUeYqYGxEPCxpHOmIiYg4Q9JbSctnRt4OIL35bXtSUm5Zx2laF/M1CGuIiJgREXvx2msXfwOMjoi92thh/Z5UayXAycB9DQhrU2BxPlL4MOmIgDyt4/O1iK3IOzvS6ZcmSQdAqspZ0m6S1gNGRsRdwBd57V/067Uer9XUeRJtL4PbgPG1lsJO+HnSqy7b7U/StsAXSEd3R+cdetnwRfXdNgEWK1VvfXJhGjtExB8i4qukPwa1aqufJB0hXiNptzamYT2ME4Q1TL7I+te8Q94l2r/jaAJwar6o/WHgsw0I6XvAKZIeIP2Trf0r/xnpdM+jwA9Jd1w9F+l1jycAF0l6mHSq7EBSYvlRPn31EPCtiFjWCfG9COwmaTrp1ZLnt9HfBKA5XyCfTbo4Dek6zvtqF6nL+pNerZL6rEjvEfkY8D+SNiS92vJspVuKd6ibZn23r5CW0+3Anwr9TZQ0U9KjpFNYD9c6RMRcUjK5oWT81gO5Nlcz0lvAIuIFSVuQqlN+e0T8pYtjeCEiOuNIxKxT+BqEWXKLpKHAIODrXZ0czHoiH0GYmVkpX4MwM7NSThBmZlbKCcLMzEo5QZiZWSknCDMzK/X/AT6Of0H5Nj+rAAAAAElFTkSuQmCC\n", 4467 | "text/plain": [ 4468 | "
" 4469 | ] 4470 | }, 4471 | "metadata": { 4472 | "needs_background": "light" 4473 | }, 4474 | "output_type": "display_data" 4475 | } 4476 | ], 4477 | "source": [ 4478 | "# No pretraining, train subset, all train, train + unsup, train + unsup + val\n", 4479 | "x = [0, 1275, 14669, 22419, 26348]\n", 4480 | "y = [0.523, 0.531, 0.557, 0.562, 0.563]\n", 4481 | "plt.title(\"Size of pretext task vs. downstream performance\")\n", 4482 | "plt.xlabel(\"# of images in pretext task\")\n", 4483 | "plt.ylabel(\"Downstream accuray\")\n", 4484 | "plt.scatter(x,y)" 4485 | ] 4486 | }, 4487 | { 4488 | "cell_type": "markdown", 4489 | "metadata": {}, 4490 | "source": [ 4491 | "In general adding more images to our pretext task seems to help, but the performance gains seem to be saturating." 4492 | ] 4493 | } 4494 | ], 4495 | "metadata": { 4496 | "kernelspec": { 4497 | "display_name": "Python (fastai2)", 4498 | "language": "python", 4499 | "name": "fastai2" 4500 | }, 4501 | "language_info": { 4502 | "codemirror_mode": { 4503 | "name": "ipython", 4504 | "version": 3 4505 | }, 4506 | "file_extension": ".py", 4507 | "mimetype": "text/x-python", 4508 | "name": "python", 4509 | "nbconvert_exporter": "python", 4510 | "pygments_lexer": "ipython3", 4511 | "version": "3.7.6" 4512 | } 4513 | }, 4514 | "nbformat": 4, 4515 | "nbformat_minor": 2 4516 | } 4517 | -------------------------------------------------------------------------------- /02_InpaintingVaryDatasetSize/README.md: -------------------------------------------------------------------------------- 1 | ## Effect of Pretext Dataset Size on Downstream Performance 2 | 3 | **Hypothesis:** By increasing the size of our pretext dataset we can improve downstream performance. 4 | 5 | **Result:** True, but perhaps with saturating improvements. 6 | 7 | **Methodology:** 8 | 9 | Image inpainting pretext task. 
10 | 11 | Using four subsets of the ImageWang dataset: 12 | 13 | - `/train` that has a corresponding class in `/val` 14 | - `1,275` images 15 | - All `/train` data 16 | - `14,669` images 17 | - All `/train` data + all `/unsup` data 18 | - `22,419` images 19 | - All `/train` data + all `/unsup` data + all `/val` data 20 | - `26,348` images 21 | 22 | Results: 23 | 24 | Random: **52.3%** 25 | 26 | Partial /train: **53.1%** 27 | 28 | All /train: **55.7%** 29 | 30 | All /train and /unsup : **56.2%** 31 | 32 | All /train,/unsup and /val : **56.3%** 33 | 34 | 35 | ![test](https://i.imgur.com/ZuuAygJ.png) 36 | -------------------------------------------------------------------------------- /02_InpaintingVaryDatasetSize/RandomCutout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fastai2.vision.all import PILImage, Image 3 | from fastai2.vision.augment import RandTransform 4 | 5 | 6 | # We create this dummy class in order to create a transform that ONLY operates on images of this type 7 | # We will use it to create all input images 8 | class PILImageInput(PILImage): 9 | pass 10 | 11 | 12 | class RandomCutout(RandTransform): 13 | "Zero out a random number of randomly sized rectangular holes in input images (cutout)" 14 | split_idx = None 15 | 16 | def __init__(self, min_n_holes=5, max_n_holes=10, min_length=5, max_length=50, **kwargs): 17 | super().__init__(**kwargs) 18 | self.min_n_holes = min_n_holes 19 | self.max_n_holes = max_n_holes 20 | self.min_length = min_length 21 | self.max_length = max_length 22 | 23 | def encodes(self, x: PILImageInput): 24 | """ 25 | Note that we're accepting our dummy PILImageInput class 26 | fastai2 will only pass images of this type to our encoder. 27 | This means that our transform will only be applied to input images and won't 28 | be run against output images.
29 | """ 30 | 31 | n_holes = np.random.randint(self.min_n_holes, self.max_n_holes) # NOTE(review): np.random.randint excludes its upper bound, so max_n_holes (and max_length below) is never drawn 32 | pixels = np.array(x) # Convert to mutable numpy array. FeelsBadMan 33 | h, w = pixels.shape[:2] 34 | 35 | for n in range(n_holes): 36 | h_length = np.random.randint(self.min_length, self.max_length) 37 | w_length = np.random.randint(self.min_length, self.max_length) 38 | h_y = np.random.randint(0, h) 39 | h_x = np.random.randint(0, w) 40 | y1 = int(np.clip(h_y - h_length / 2, 0, h)) 41 | y2 = int(np.clip(h_y + h_length / 2, 0, h)) 42 | x1 = int(np.clip(h_x - w_length / 2, 0, w)) 43 | x2 = int(np.clip(h_x + w_length / 2, 0, w)) 44 | 45 | pixels[y1:y2, x1:x2, :] = 0 46 | 47 | return Image.fromarray(pixels, mode='RGB') -------------------------------------------------------------------------------- /02_InpaintingVaryDatasetSize/config.py: -------------------------------------------------------------------------------- 1 | from fastai2.layers import Mish, MaxPool 2 | from fastai2.vision.models.xresnet import xresnet34 3 | 4 | config = { 5 | 'lr': 8e-3, 6 | 'size': 128, 7 | 'sqrmom': 0.99, 8 | 'mom': 0.9, 9 | 'eps': 1e-6, 10 | 'epochs': 15, 11 | 'bs': 64, 12 | 'opt': 'ranger', 13 | 'sh': 0., 14 | 'sa': 0, 15 | 'sym': 0, 16 | 'beta': 0., 17 | 'act_fn': Mish, 18 | 'fp16': 0, 19 | 'pool': MaxPool, 20 | 'runs': 1, 21 | 'model': xresnet34 22 | } 23 | -------------------------------------------------------------------------------- /03_PretextTrainingTime/RandomCutout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fastai2.vision.all import PILImage, Image 3 | from fastai2.vision.augment import RandTransform 4 | 5 | 6 | # We create this dummy class in order to create a transform that ONLY operates on images of this type 7 | # We will use it to create all input images 8 | class PILImageInput(PILImage): 9 | pass 10 | 11 | 12 | class RandomCutout(RandTransform): 13 | "Zero out a random number of randomly sized rectangular holes
in input images (cutout)" 14 | split_idx = None 15 | 16 | def __init__(self, min_n_holes=5, max_n_holes=10, min_length=5, max_length=50, **kwargs): 17 | super().__init__(**kwargs) 18 | self.min_n_holes = min_n_holes 19 | self.max_n_holes = max_n_holes 20 | self.min_length = min_length 21 | self.max_length = max_length 22 | 23 | def encodes(self, x: PILImageInput): 24 | """ 25 | Note that we're accepting our dummy PILImageInput class 26 | fastai2 will only pass images of this type to our encoder. 27 | This means that our transform will only be applied to input images and won't 28 | be run against output images. 29 | """ 30 | 31 | n_holes = np.random.randint(self.min_n_holes, self.max_n_holes) # NOTE(review): np.random.randint excludes its upper bound, so max_n_holes (and max_length below) is never drawn 32 | pixels = np.array(x) # Convert to mutable numpy array. FeelsBadMan 33 | h, w = pixels.shape[:2] 34 | 35 | for n in range(n_holes): 36 | h_length = np.random.randint(self.min_length, self.max_length) 37 | w_length = np.random.randint(self.min_length, self.max_length) 38 | h_y = np.random.randint(0, h) 39 | h_x = np.random.randint(0, w) 40 | y1 = int(np.clip(h_y - h_length / 2, 0, h)) 41 | y2 = int(np.clip(h_y + h_length / 2, 0, h)) 42 | x1 = int(np.clip(h_x - w_length / 2, 0, w)) 43 | x2 = int(np.clip(h_x + w_length / 2, 0, w)) 44 | 45 | pixels[y1:y2, x1:x2, :] = 0 46 | 47 | return Image.fromarray(pixels, mode='RGB') -------------------------------------------------------------------------------- /03_PretextTrainingTime/config.py: -------------------------------------------------------------------------------- 1 | from fastai2.layers import Mish, MaxPool 2 | from fastai2.vision.models.xresnet import xresnet34 3 | 4 | config = { 5 | 'lr': 8e-3, 6 | 'size': 128, 7 | 'sqrmom': 0.99, 8 | 'mom': 0.9, 9 | 'eps': 1e-6, 10 | 'epochs': 15, 11 | 'bs': 64, 12 | 'opt': 'ranger', 13 | 'sh': 0., 14 | 'sa': 0, 15 | 'sym': 0, 16 | 'beta': 0., 17 | 'act_fn': Mish, 18 | 'fp16': 0, 19 | 'pool': MaxPool, 20 | 'runs': 1, 21 | 'model': xresnet34 22 | } 23 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Josh Varty 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SelfSupervisedLearning 2 | Experiments with self-supervised learning 3 | 4 | ### Requirements 5 | 6 | You must have [`fastai2`](https://github.com/fastai/fastai2) installed in order to run these notebooks. 7 | 8 | 9 | ### Resources 10 | 11 | [Self-Supervised Learning: Part 1](https://joshvarty.com/2020/02/03/self-supervised-learning-part-1/) 12 | --------------------------------------------------------------------------------