├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── assets └── screenshot.png ├── codebase ├── 01.02_PyTorch installation and setup.ipynb ├── 02.01 What is tensor and Type Conversions.ipynb ├── 02.02 Mathematical Operations.ipynb ├── 02.03 Indexing, Slicing, Concatenation and Reshaping Ops.ipynb ├── 03.01 Derivatives, Partial derivative, and Successive Differentiation.ipynb ├── 04.01 Simple ANN implementation.ipynb ├── 05.01_Custom data loading for structured.ipynb ├── 05.02_Custom data loading for Ustructured.ipynb ├── 06.01_CNN_create_data_loader.ipynb ├── 06.02_CNN_architecture.ipynb ├── 06.03_train_CNN.ipynb ├── 06.04_evaluate_CNN.ipynb ├── 06.05_predict_using_CNN.ipynb ├── 06_03_session_dir │ └── CNN_model.pth ├── 07.01_Transfer_learning_Download_data_create_dataloader.ipynb ├── 07.02_Transfer_learning_using_pretrained_model.ipynb ├── 07.03_Transfer_learning_Training_model.ipynb ├── 07.04_Transfer_learning_evaluate.ipynb ├── 07.05_Transfer_learning_prediction.ipynb └── Data │ ├── img_data │ ├── train │ │ ├── Cat │ │ │ ├── 0.jpg │ │ │ ├── 1.jpg │ │ │ ├── 10.jpg │ │ │ ├── 11.jpg │ │ │ ├── 12.jpg │ │ │ ├── 13.jpg │ │ │ ├── 14.jpg │ │ │ ├── 15.jpg │ │ │ ├── 16.jpg │ │ │ ├── 17.jpg │ │ │ ├── 18.jpg │ │ │ ├── 19.jpg │ │ │ ├── 2.jpg │ │ │ ├── 20.jpg │ │ │ ├── 21.jpg │ │ │ ├── 22.jpg │ │ │ ├── 24.jpg │ │ │ ├── 25.jpg │ │ │ ├── 26.jpg │ │ │ ├── 27.jpg │ │ │ ├── 28.jpg │ │ │ ├── 3.jpg │ │ │ ├── 4.jpg │ │ │ ├── 5.jpg │ │ │ ├── 6.jpg │ │ │ ├── 7.jpg │ │ │ ├── 8.jpg │ │ │ └── 9.jpg │ │ └── Dog │ │ │ ├── 0.jpg │ │ │ ├── 1.jpg │ │ │ ├── 10.jpg │ │ │ ├── 11.jpg │ │ │ ├── 12.jpg │ │ │ ├── 13.jpg │ │ │ ├── 14.jpg │ │ │ ├── 15.jpg │ │ │ ├── 16.jpg │ │ │ ├── 17.jpg │ │ │ ├── 18.jpg │ │ │ ├── 19.jpg │ │ │ ├── 2.jpg │ │ │ ├── 20.jpg │ │ │ ├── 21.jpg │ │ │ ├── 22.jpg │ │ │ ├── 23.jpg │ │ │ ├── 24.jpg │ │ │ ├── 25.jpg │ │ │ ├── 26.jpg │ │ │ ├── 27.jpg │ │ │ ├── 3.jpg │ │ │ ├── 4.jpg │ │ │ ├── 5.jpg │ │ │ ├── 6.jpg │ │ │ ├── 7.jpg │ │ │ ├── 8.jpg │ │ │ └── 9.jpg │ └── validation │ │ ├── Cat │ │ ├── 0.jpg │ │ ├── 1.jpg │ │ ├── 10.jpg │ │ ├── 11.jpg │ │ ├── 12.jpg │ │ ├── 13.jpg │ │ ├── 14.jpg │ │ ├── 15.jpg │ │ ├── 16.jpg │ │ ├── 17.jpg │ │ ├── 18.jpg │ │ ├── 19.jpg │ │ ├── 2.jpg │ │ ├── 20.jpg │ │ ├── 21.jpg │ │ ├── 22.jpg │ │ ├── 24.jpg │ │ ├── 25.jpg │ │ ├── 26.jpg │ │ ├── 27.jpg │ │ ├── 28.jpg │ │ ├── 3.jpg │ │ ├── 4.jpg │ │ ├── 5.jpg │ │ ├── 6.jpg │ │ ├── 7.jpg │ │ ├── 8.jpg │ │ └── 9.jpg │ │ └── Dog │ │ ├── 0.jpg │ │ ├── 1.jpg │ │ ├── 10.jpg │ │ ├── 11.jpg │ │ ├── 12.jpg │ │ ├── 13.jpg │ │ ├── 14.jpg │ │ ├── 15.jpg │ │ ├── 16.jpg │ │ ├── 17.jpg │ │ ├── 18.jpg │ │ ├── 19.jpg │ │ ├── 2.jpg │ │ ├── 20.jpg │ │ ├── 21.jpg │ │ ├── 22.jpg │ │ ├── 23.jpg │ │ ├── 24.jpg │ │ ├── 25.jpg │ │ ├── 26.jpg │ │ ├── 27.jpg │ │ ├── 3.jpg │ │ ├── 4.jpg │ │ ├── 5.jpg │ │ ├── 6.jpg │ │ ├── 7.jpg │ │ ├── 8.jpg │ │ └── 9.jpg │ └── iris.csv ├── docs ├── Section_001_PyTorch_Introduction.md ├── Section_002_PyTorch_Tensors_and_Operations.md ├── Section_003_AutoGrad.md ├── Section_004_PyTorch_First_NN.md ├── Section_005_Custom_data_loading.md ├── Section_006_CNN.md ├── Section_007_Transfer_learning.md ├── img │ └── .gitkeep ├── index.md └── references.md ├── mkdocs.yml └── requirements.txt /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - main 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions/setup-python@v2 13 | 
with: 14 | python-version: 3.x 15 | - run: pip install mkdocs-material 16 | - run: mkdocs gh-deploy --force -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | Rough* 132 | temp* 133 | MNIST/ 134 | # img_data/ 135 | IMG_DATA/ 136 | Fashion* 137 | codebase/hymenoptera_data/ 138 | codebase/Session_07_models/Trans_model.pth 139 | codebase/07.01_main_Transfer_learning_.ipynb 140 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-Basics 2 | 3 | [![Pytorch-basics](https://socialify.git.ci/c17hawke/Pytorch-basics/image?forks=1&issues=1&language=1&name=1&owner=1&pattern=Brick%20Wall&pulls=1&stargazers=1&theme=Dark)](https://c17hawke.github.io/Pytorch-basics/) 4 | 6 | 7 | * ## This repository contains an introduction to PyTorch. 
8 | * ## [Click here](https://c17hawke.github.io/Pytorch-basics/) to refer to the documentation associated with this repository -------------------------------------------------------------------------------- /assets/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/assets/screenshot.png -------------------------------------------------------------------------------- /codebase/01.02_PyTorch installation and setup.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "1fe917a6", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "id": "d2e01c8b", 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "1.10.2+cu113\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "print(torch.__version__)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 3, 34 | "id": "397f7cb1", 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "data": { 39 | "text/plain": [ 40 | "True" 41 | ] 42 | }, 43 | "execution_count": 3, 44 | "metadata": {}, 45 | "output_type": "execute_result" 46 | } 47 | ], 48 | "source": [ 49 | "torch.cuda.is_available()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "3367ae36", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [] 59 | } 60 | ], 61 | "metadata": { 62 | "kernelspec": { 63 | "display_name": "Python 3 (ipykernel)", 64 | "language": "python", 65 | "name": "python3" 66 | }, 67 | "language_info": { 68 | "codemirror_mode": { 69 | "name": "ipython", 70 | "version": 3 71 | }, 72 | "file_extension": ".py", 73 | "mimetype": "text/x-python", 74 | "name": "python", 75 | "nbconvert_exporter": "python", 76 | "pygments_lexer": "ipython3", 77 | "version": "3.7.11" 78 | } 79 | }, 80 | "nbformat": 4, 81 | "nbformat_minor": 5 82 | } 83 | -------------------------------------------------------------------------------- /codebase/02.01 What is tensor and Type Conversions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "8a538038", 6 | "metadata": {}, 7 | "source": [ 8 | "## Tensors" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "e6deae41", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "import numpy as np\n", 20 | "import os" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "id": "a2a83b54", 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "data": { 31 | "text/plain": [ 32 | "'cuda'" 33 | ] 34 | }, 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "output_type": "execute_result" 38 | } 39 | ], 40 | "source": [ 41 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 42 | "\n", 43 | "device" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "id": "99744240", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "tensor([[ 1, 2, 3],\n", 56 | " [11, 22, 33]])" 57 | ] 58 | }, 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "output_type": "execute_result" 62 | } 63 | ], 64 | "source": [ 65 | "basic_tensor = 
torch.tensor([[1,2,3],[11,22,33]])\n", 66 | "\n", 67 | "basic_tensor" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "id": "48af8f4c", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "data": { 78 | "text/plain": [ 79 | "torch.int64" 80 | ] 81 | }, 82 | "execution_count": 4, 83 | "metadata": {}, 84 | "output_type": "execute_result" 85 | } 86 | ], 87 | "source": [ 88 | "basic_tensor.dtype" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 6, 94 | "id": "89efc402", 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "data": { 99 | "text/plain": [ 100 | "device(type='cpu')" 101 | ] 102 | }, 103 | "execution_count": 6, 104 | "metadata": {}, 105 | "output_type": "execute_result" 106 | } 107 | ], 108 | "source": [ 109 | "basic_tensor.device" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 7, 115 | "id": "ec6aa14f", 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "data": { 120 | "text/plain": [ 121 | "torch.Size([2, 3])" 122 | ] 123 | }, 124 | "execution_count": 7, 125 | "metadata": {}, 126 | "output_type": "execute_result" 127 | } 128 | ], 129 | "source": [ 130 | "basic_tensor.shape" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 8, 136 | "id": "46f469dc", 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "data": { 141 | "text/plain": [ 142 | "False" 143 | ] 144 | }, 145 | "execution_count": 8, 146 | "metadata": {}, 147 | "output_type": "execute_result" 148 | } 149 | ], 150 | "source": [ 151 | "basic_tensor.requires_grad" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 9, 157 | "id": "5aa932a9", 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "tensor([[ 1., 2., 3.],\n", 164 | " [11., 22., 33.]], device='cuda:0', requires_grad=True)" 165 | ] 166 | }, 167 | "execution_count": 9, 168 | "metadata": {}, 169 | "output_type": "execute_result" 170 | } 171 | ], 172 | "source": [ 173 | "tensor = torch.tensor([[1,2,3],[11,22,33]],\n", 174 | " dtype=torch.float,\n", 175 | " device=device,\n", 176 | " requires_grad=True)\n", 177 | "\n", 178 | "tensor" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 10, 184 | "id": "d1fa9d57", 185 | "metadata": {}, 186 | "outputs": [ 187 | { 188 | "data": { 189 | "text/plain": [ 190 | "torch.float32" 191 | ] 192 | }, 193 | "execution_count": 10, 194 | "metadata": {}, 195 | "output_type": "execute_result" 196 | } 197 | ], 198 | "source": [ 199 | "tensor.dtype" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 11, 205 | "id": "9df2cce8", 206 | "metadata": {}, 207 | "outputs": [ 208 | { 209 | "data": { 210 | "text/plain": [ 211 | "device(type='cuda', index=0)" 212 | ] 213 | }, 214 | "execution_count": 11, 215 | "metadata": {}, 216 | "output_type": "execute_result" 217 | } 218 | ], 219 | "source": [ 220 | "tensor.device" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 12, 226 | "id": "530d1b7e", 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "data": { 231 | "text/plain": [ 232 | "True" 233 | ] 234 | }, 235 | "execution_count": 12, 236 | "metadata": {}, 237 | "output_type": "execute_result" 238 | } 239 | ], 240 | "source": [ 241 | "tensor.requires_grad" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "id": "cff827f1", 247 | "metadata": {}, 248 | "source": [ 249 | "### Other commonly used tensors" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 14, 
255 | "id": "a5917c97", 256 | "metadata": {}, 257 | "outputs": [ 258 | { 259 | "data": { 260 | "text/plain": [ 261 | "tensor([[0., 0., 0.],\n", 262 | " [0., 0., 0.],\n", 263 | " [0., 0., 0.]])" 264 | ] 265 | }, 266 | "execution_count": 14, 267 | "metadata": {}, 268 | "output_type": "execute_result" 269 | } 270 | ], 271 | "source": [ 272 | "x = torch.empty(size=(3,3))\n", 273 | "x" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 15, 279 | "id": "55e8a6af", 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "data": { 284 | "text/plain": [ 285 | "tensor([[0., 0., 0.],\n", 286 | " [0., 0., 0.],\n", 287 | " [0., 0., 0.]])" 288 | ] 289 | }, 290 | "execution_count": 15, 291 | "metadata": {}, 292 | "output_type": "execute_result" 293 | } 294 | ], 295 | "source": [ 296 | "x = torch.zeros(size=(3,3))\n", 297 | "x" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 16, 303 | "id": "c388b231", 304 | "metadata": {}, 305 | "outputs": [ 306 | { 307 | "data": { 308 | "text/plain": [ 309 | "tensor([[1., 1., 1.],\n", 310 | " [1., 1., 1.],\n", 311 | " [1., 1., 1.]])" 312 | ] 313 | }, 314 | "execution_count": 16, 315 | "metadata": {}, 316 | "output_type": "execute_result" 317 | } 318 | ], 319 | "source": [ 320 | "x = torch.ones(size=(3,3))\n", 321 | "x" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": 29, 327 | "id": "702c42b9", 328 | "metadata": {}, 329 | "outputs": [ 330 | { 331 | "data": { 332 | "text/plain": [ 333 | "tensor([[0.7332, 0.3748, 0.0849, 0.9105],\n", 334 | " [0.2788, 0.3333, 0.6220, 0.6664],\n", 335 | " [0.3703, 0.9297, 0.6921, 0.1396]])" 336 | ] 337 | }, 338 | "execution_count": 29, 339 | "metadata": {}, 340 | "output_type": "execute_result" 341 | } 342 | ], 343 | "source": [ 344 | "x = torch.rand(size=(3,4))\n", 345 | "x" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 28, 351 | "id": "e9d64de8", 352 | "metadata": {}, 353 | "outputs": [ 354 | { 355 | "data": { 356 | "text/plain": [ 357 | "tensor([[1., 0., 0.],\n", 358 | " [0., 1., 0.],\n", 359 | " [0., 0., 1.]])" 360 | ] 361 | }, 362 | "execution_count": 28, 363 | "metadata": {}, 364 | "output_type": "execute_result" 365 | } 366 | ], 367 | "source": [ 368 | "x = torch.eye(3)\n", 369 | "x" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 37, 375 | "id": "55c51827", 376 | "metadata": {}, 377 | "outputs": [ 378 | { 379 | "data": { 380 | "text/plain": [ 381 | "tensor([0, 2, 4, 6])" 382 | ] 383 | }, 384 | "execution_count": 37, 385 | "metadata": {}, 386 | "output_type": "execute_result" 387 | } 388 | ], 389 | "source": [ 390 | "x = torch.arange(start=0, end=7, step=2)\n", 391 | "x" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": 41, 397 | "id": "4cf79a41", 398 | "metadata": {}, 399 | "outputs": [ 400 | { 401 | "data": { 402 | "text/plain": [ 403 | "tensor([0.0000, 0.7778, 1.5556, 2.3333, 3.1111, 3.8889, 4.6667, 5.4444, 6.2222,\n", 404 | " 7.0000])" 405 | ] 406 | }, 407 | "execution_count": 41, 408 | "metadata": {}, 409 | "output_type": "execute_result" 410 | } 411 | ], 412 | "source": [ 413 | "x = torch.linspace(start=0, end=7, steps=10)\n", 414 | "x" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 42, 420 | "id": "15499249", 421 | "metadata": {}, 422 | "outputs": [ 423 | { 424 | "data": { 425 | "text/plain": [ 426 | "tensor([[ 1.0478, -1.0514, 0.5596, -1.2438],\n", 427 | " [ 0.5222, 2.4026, 0.6896, 1.0098],\n", 428 | " [-1.0985, 0.5391, 1.9458, 
-1.8787]])" 429 | ] 430 | }, 431 | "execution_count": 42, 432 | "metadata": {}, 433 | "output_type": "execute_result" 434 | } 435 | ], 436 | "source": [ 437 | "x = torch.rand(size=(3,4)).normal_(mean=0, std=1)\n", 438 | "x" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": 46, 444 | "id": "9662201e", 445 | "metadata": {}, 446 | "outputs": [ 447 | { 448 | "data": { 449 | "text/plain": [ 450 | "tensor([[4.5163, 5.3036, 4.8373, 5.9569],\n", 451 | " [4.8600, 5.1942, 5.3013, 5.2837],\n", 452 | " [5.7229, 5.4198, 5.7625, 4.3776]])" 453 | ] 454 | }, 455 | "execution_count": 46, 456 | "metadata": {}, 457 | "output_type": "execute_result" 458 | } 459 | ], 460 | "source": [ 461 | "x = torch.rand(size=(3,4)).uniform_(3, 6)\n", 462 | "x" 463 | ] 464 | }, 465 | { 466 | "cell_type": "code", 467 | "execution_count": 47, 468 | "id": "a128483f", 469 | "metadata": {}, 470 | "outputs": [ 471 | { 472 | "data": { 473 | "text/plain": [ 474 | "tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 475 | " [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 476 | " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", 477 | " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n", 478 | " [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n", 479 | " [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n", 480 | " [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n", 481 | " [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n", 482 | " [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],\n", 483 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])" 484 | ] 485 | }, 486 | "execution_count": 47, 487 | "metadata": {}, 488 | "output_type": "execute_result" 489 | } 490 | ], 491 | "source": [ 492 | "x = torch.diag(torch.ones(10))\n", 493 | "x" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 48, 499 | "id": "3a53ca87", 500 | "metadata": {}, 501 | "outputs": [ 502 | { 503 | "data": { 504 | "text/plain": [ 505 | "torch.Size([10, 10])" 506 | ] 507 | }, 508 | "execution_count": 48, 509 | "metadata": {}, 510 | "output_type": "execute_result" 511 | } 512 | ], 513 | "source": [ 514 | "x.shape" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": 49, 520 | "id": "cf6939d9", 521 | "metadata": {}, 522 | "outputs": [ 523 | { 524 | "data": { 525 | "text/plain": [ 526 | "tensor([[5., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 527 | " [0., 5., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 528 | " [0., 0., 5., 0., 0., 0., 0., 0., 0., 0.],\n", 529 | " [0., 0., 0., 5., 0., 0., 0., 0., 0., 0.],\n", 530 | " [0., 0., 0., 0., 5., 0., 0., 0., 0., 0.],\n", 531 | " [0., 0., 0., 0., 0., 5., 0., 0., 0., 0.],\n", 532 | " [0., 0., 0., 0., 0., 0., 5., 0., 0., 0.],\n", 533 | " [0., 0., 0., 0., 0., 0., 0., 5., 0., 0.],\n", 534 | " [0., 0., 0., 0., 0., 0., 0., 0., 5., 0.],\n", 535 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 5.]])" 536 | ] 537 | }, 538 | "execution_count": 49, 539 | "metadata": {}, 540 | "output_type": "execute_result" 541 | } 542 | ], 543 | "source": [ 544 | "x = torch.diag(5*torch.ones(10))\n", 545 | "x" 546 | ] 547 | }, 548 | { 549 | "cell_type": "markdown", 550 | "id": "ad0d7b60", 551 | "metadata": {}, 552 | "source": [ 553 | "## Conversions" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": 51, 559 | "id": "bb29f1ad", 560 | "metadata": {}, 561 | "outputs": [ 562 | { 563 | "data": { 564 | "text/plain": [ 565 | "tensor([0, 1, 2, 3])" 566 | ] 567 | }, 568 | "execution_count": 51, 569 | "metadata": {}, 570 | "output_type": "execute_result" 571 | } 572 | ], 573 | "source": [ 574 | "x = torch.arange(4)\n", 575 | "x" 576 | 
] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": 52, 581 | "id": "6f0e6073", 582 | "metadata": {}, 583 | "outputs": [ 584 | { 585 | "data": { 586 | "text/plain": [ 587 | "tensor([False, True, True, True])" 588 | ] 589 | }, 590 | "execution_count": 52, 591 | "metadata": {}, 592 | "output_type": "execute_result" 593 | } 594 | ], 595 | "source": [ 596 | "x.bool()" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": 53, 602 | "id": "f9b3e5ae", 603 | "metadata": {}, 604 | "outputs": [ 605 | { 606 | "data": { 607 | "text/plain": [ 608 | "tensor([0, 1, 2, 3], dtype=torch.int32)" 609 | ] 610 | }, 611 | "execution_count": 53, 612 | "metadata": {}, 613 | "output_type": "execute_result" 614 | } 615 | ], 616 | "source": [ 617 | "x.int()" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": 54, 623 | "id": "66d7a8bd", 624 | "metadata": {}, 625 | "outputs": [ 626 | { 627 | "data": { 628 | "text/plain": [ 629 | "tensor([0, 1, 2, 3], dtype=torch.int16)" 630 | ] 631 | }, 632 | "execution_count": 54, 633 | "metadata": {}, 634 | "output_type": "execute_result" 635 | } 636 | ], 637 | "source": [ 638 | "x.short() # int16" 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": 55, 644 | "id": "66209493", 645 | "metadata": {}, 646 | "outputs": [ 647 | { 648 | "data": { 649 | "text/plain": [ 650 | "tensor([0, 1, 2, 3])" 651 | ] 652 | }, 653 | "execution_count": 55, 654 | "metadata": {}, 655 | "output_type": "execute_result" 656 | } 657 | ], 658 | "source": [ 659 | "x.long() # int64" 660 | ] 661 | }, 662 | { 663 | "cell_type": "code", 664 | "execution_count": 56, 665 | "id": "5e15f749", 666 | "metadata": {}, 667 | "outputs": [ 668 | { 669 | "data": { 670 | "text/plain": [ 671 | "tensor([0., 1., 2., 3.], dtype=torch.float16)" 672 | ] 673 | }, 674 | "execution_count": 56, 675 | "metadata": {}, 676 | "output_type": "execute_result" 677 | } 678 | ], 679 | "source": [ 680 | "x.half() # float16" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": 58, 686 | "id": "facc1875", 687 | "metadata": {}, 688 | "outputs": [ 689 | { 690 | "data": { 691 | "text/plain": [ 692 | "tensor([0., 1., 2., 3.])" 693 | ] 694 | }, 695 | "execution_count": 58, 696 | "metadata": {}, 697 | "output_type": "execute_result" 698 | } 699 | ], 700 | "source": [ 701 | "x.float() # float32" 702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "execution_count": 59, 707 | "id": "a78b5c15", 708 | "metadata": {}, 709 | "outputs": [ 710 | { 711 | "data": { 712 | "text/plain": [ 713 | "tensor([0., 1., 2., 3.], dtype=torch.float64)" 714 | ] 715 | }, 716 | "execution_count": 59, 717 | "metadata": {}, 718 | "output_type": "execute_result" 719 | } 720 | ], 721 | "source": [ 722 | "x.double() # float64" 723 | ] 724 | }, 725 | { 726 | "cell_type": "code", 727 | "execution_count": 60, 728 | "id": "a9f46b70", 729 | "metadata": {}, 730 | "outputs": [ 731 | { 732 | "data": { 733 | "text/plain": [ 734 | "array([[1, 2, 3],\n", 735 | " [1, 2, 3]])" 736 | ] 737 | }, 738 | "execution_count": 60, 739 | "metadata": {}, 740 | "output_type": "execute_result" 741 | } 742 | ], 743 | "source": [ 744 | "np_array = np.array([[1,2,3], [1,2,3]])\n", 745 | "\n", 746 | "np_array" 747 | ] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "execution_count": 63, 752 | "id": "a3200135", 753 | "metadata": {}, 754 | "outputs": [ 755 | { 756 | "data": { 757 | "text/plain": [ 758 | "tensor([[1, 2, 3],\n", 759 | " [1, 2, 3]], dtype=torch.int32)" 760 | ] 761 | }, 762 | 
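[Editorial aside on the NumPy interop shown here, standard documented PyTorch behavior though the original cells don't demonstrate it: torch.from_numpy produces a tensor that shares memory with the NumPy array, so an in-place change on either side is visible through the other; torch.tensor(np_array) makes an independent copy. A minimal sketch:]

import numpy as np
import torch

np_array = np.array([[1, 2, 3], [1, 2, 3]])
shared = torch.from_numpy(np_array)  # zero-copy: the tensor reuses the NumPy buffer
copied = torch.tensor(np_array)      # independent copy of the same data

np_array[0, 0] = 99
print(shared[0, 0])  # reflects 99 -- the mutation shows through the shared tensor
print(copied[0, 0])  # still 1 -- the copy is unaffected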
"execution_count": 63, 763 | "metadata": {}, 764 | "output_type": "execute_result" 765 | } 766 | ], 767 | "source": [ 768 | "tensor = torch.from_numpy(np_array)\n", 769 | "tensor" 770 | ] 771 | }, 772 | { 773 | "cell_type": "code", 774 | "execution_count": 64, 775 | "id": "0eceda01", 776 | "metadata": {}, 777 | "outputs": [ 778 | { 779 | "data": { 780 | "text/plain": [ 781 | "array([[1, 2, 3],\n", 782 | " [1, 2, 3]])" 783 | ] 784 | }, 785 | "execution_count": 64, 786 | "metadata": {}, 787 | "output_type": "execute_result" 788 | } 789 | ], 790 | "source": [ 791 | "tensor.numpy()" 792 | ] 793 | }, 794 | { 795 | "cell_type": "code", 796 | "execution_count": null, 797 | "id": "797dbfe8", 798 | "metadata": {}, 799 | "outputs": [], 800 | "source": [] 801 | } 802 | ], 803 | "metadata": { 804 | "kernelspec": { 805 | "display_name": "Python 3 (ipykernel)", 806 | "language": "python", 807 | "name": "python3" 808 | }, 809 | "language_info": { 810 | "codemirror_mode": { 811 | "name": "ipython", 812 | "version": 3 813 | }, 814 | "file_extension": ".py", 815 | "mimetype": "text/x-python", 816 | "name": "python", 817 | "nbconvert_exporter": "python", 818 | "pygments_lexer": "ipython3", 819 | "version": "3.7.11" 820 | } 821 | }, 822 | "nbformat": 4, 823 | "nbformat_minor": 5 824 | } 825 | -------------------------------------------------------------------------------- /codebase/02.03 Indexing, Slicing, Concatenation and Reshaping Ops.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "c1913a37", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "f32723b4", 17 | "metadata": {}, 18 | "source": [ 19 | "## Indexing, Slicing" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 9, 25 | "id": "b5ca9cc3", 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "data": { 30 | "text/plain": [ 31 | "tensor([[0.4662, 0.9302, 0.3527, 0.4858],\n", 32 | " [0.4701, 0.0616, 0.0107, 0.8433],\n", 33 | " [0.4257, 0.2782, 0.6458, 0.5032]])" 34 | ] 35 | }, 36 | "execution_count": 9, 37 | "metadata": {}, 38 | "output_type": "execute_result" 39 | } 40 | ], 41 | "source": [ 42 | "x = torch.rand(size=(3,4))\n", 43 | "x" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "id": "777f7bf7", 49 | "metadata": {}, 50 | "source": [ 51 | "x[row, cols]\n", 52 | "\n", 53 | "row -> start:stop:step\n", 54 | "\n", 55 | "col -> start:stop:step " 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 10, 61 | "id": "16bc3400", 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | "data": { 66 | "text/plain": [ 67 | "tensor([0.4662, 0.9302, 0.3527, 0.4858])" 68 | ] 69 | }, 70 | "execution_count": 10, 71 | "metadata": {}, 72 | "output_type": "execute_result" 73 | } 74 | ], 75 | "source": [ 76 | "x[0,:]" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 11, 82 | "id": "eea6a943", 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "data": { 87 | "text/plain": [ 88 | "tensor([0.4662, 0.4701, 0.4257])" 89 | ] 90 | }, 91 | "execution_count": 11, 92 | "metadata": {}, 93 | "output_type": "execute_result" 94 | } 95 | ], 96 | "source": [ 97 | "x[:,0]" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 14, 103 | "id": "5a9c221d", 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "tensor([[0.4662, 
0.9302, 0.3527],\n", 110 | " [0.4701, 0.0616, 0.0107]])" 111 | ] 112 | }, 113 | "execution_count": 14, 114 | "metadata": {}, 115 | "output_type": "execute_result" 116 | } 117 | ], 118 | "source": [ 119 | "x[0:2, 0:3]" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 15, 125 | "id": "76fbc871", 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "tensor([[0.4662, 0.3527],\n", 132 | " [0.4701, 0.0107]])" 133 | ] 134 | }, 135 | "execution_count": 15, 136 | "metadata": {}, 137 | "output_type": "execute_result" 138 | } 139 | ], 140 | "source": [ 141 | "x[0:2, 0:3:2]" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 16, 147 | "id": "a13bc486", 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "data": { 152 | "text/plain": [ 153 | "tensor(0.4662)" 154 | ] 155 | }, 156 | "execution_count": 16, 157 | "metadata": {}, 158 | "output_type": "execute_result" 159 | } 160 | ], 161 | "source": [ 162 | "x[0,0]" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 19, 168 | "id": "06f1c7dd", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "x[0,0] = 11" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 20, 178 | "id": "1d6c5fda", 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "data": { 183 | "text/plain": [ 184 | "tensor([[1.1000e+01, 9.3021e-01, 3.5273e-01, 4.8576e-01],\n", 185 | " [4.7009e-01, 6.1568e-02, 1.0703e-02, 8.4334e-01],\n", 186 | " [4.2567e-01, 2.7815e-01, 6.4583e-01, 5.0318e-01]])" 187 | ] 188 | }, 189 | "execution_count": 20, 190 | "metadata": {}, 191 | "output_type": "execute_result" 192 | } 193 | ], 194 | "source": [ 195 | "x" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 21, 201 | "id": "eac166e9", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "x = torch.rand(size=(3,4))\n", 206 | "y = torch.rand(size=(3,4))\n" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "id": "f25062ff", 212 | "metadata": {}, 213 | "source": [ 214 | "## Concatenation" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 22, 220 | "id": "f820c7ac", 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "text/plain": [ 226 | "tensor([[0.8759, 0.6237, 0.0273, 0.6163, 0.4174, 0.9742, 0.8510, 0.7403],\n", 227 | " [0.9567, 0.2370, 0.8376, 0.2125, 0.5866, 0.5525, 0.3482, 0.1966],\n", 228 | " [0.1307, 0.5164, 0.8591, 0.0752, 0.0532, 0.3344, 0.0439, 0.7195]])" 229 | ] 230 | }, 231 | "execution_count": 22, 232 | "metadata": {}, 233 | "output_type": "execute_result" 234 | } 235 | ], 236 | "source": [ 237 | "torch.cat((x, y), dim=1)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 23, 243 | "id": "167773f4", 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "data": { 248 | "text/plain": [ 249 | "tensor([[0.8759, 0.6237, 0.0273, 0.6163],\n", 250 | " [0.9567, 0.2370, 0.8376, 0.2125],\n", 251 | " [0.1307, 0.5164, 0.8591, 0.0752],\n", 252 | " [0.4174, 0.9742, 0.8510, 0.7403],\n", 253 | " [0.5866, 0.5525, 0.3482, 0.1966],\n", 254 | " [0.0532, 0.3344, 0.0439, 0.7195]])" 255 | ] 256 | }, 257 | "execution_count": 23, 258 | "metadata": {}, 259 | "output_type": "execute_result" 260 | } 261 | ], 262 | "source": [ 263 | "torch.cat((x, y), dim=0)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 24, 269 | "id": "21234b0d", 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "data": { 274 | 
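[Editorial aside on the two torch.cat calls above, not a cell from the original notebook: cat joins tensors along an existing dimension, while torch.stack adds a new one, which is why stack is the usual choice for batching equally shaped samples. A shape-only sketch with the same 3x4 inputs:]

import torch

x = torch.rand(3, 4)
y = torch.rand(3, 4)

print(torch.cat((x, y), dim=0).shape)    # torch.Size([6, 4])   -> rows appended
print(torch.cat((x, y), dim=1).shape)    # torch.Size([3, 8])   -> columns appended
print(torch.stack((x, y), dim=0).shape)  # torch.Size([2, 3, 4]) -> new batch dim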
"text/plain": [ 275 | "torch.Size([3, 4])" 276 | ] 277 | }, 278 | "execution_count": 24, 279 | "metadata": {}, 280 | "output_type": "execute_result" 281 | } 282 | ], 283 | "source": [ 284 | "x.shape" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "id": "031a4225", 290 | "metadata": {}, 291 | "source": [ 292 | "## Reshaping" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 25, 298 | "id": "c5d1b634", 299 | "metadata": {}, 300 | "outputs": [ 301 | { 302 | "data": { 303 | "text/plain": [ 304 | "tensor([[0.8759, 0.6237, 0.0273, 0.6163, 0.9567, 0.2370],\n", 305 | " [0.8376, 0.2125, 0.1307, 0.5164, 0.8591, 0.0752]])" 306 | ] 307 | }, 308 | "execution_count": 25, 309 | "metadata": {}, 310 | "output_type": "execute_result" 311 | } 312 | ], 313 | "source": [ 314 | "x.reshape(2,6)" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 26, 320 | "id": "eba402b9", 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "ename": "RuntimeError", 325 | "evalue": "shape '[2, 3]' is invalid for input of size 12", 326 | "output_type": "error", 327 | "traceback": [ 328 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 329 | "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", 330 | "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_17316\\2983344640.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", 331 | "\u001b[1;31mRuntimeError\u001b[0m: shape '[2, 3]' is invalid for input of size 12" 332 | ] 333 | } 334 | ], 335 | "source": [ 336 | "x.reshape(2,3) # no. 
of elements should match" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 27, 342 | "id": "c5217cbe", 343 | "metadata": {}, 344 | "outputs": [ 345 | { 346 | "data": { 347 | "text/plain": [ 348 | "tensor([[0.8759, 0.6237, 0.0273, 0.6163, 0.9567, 0.2370, 0.8376, 0.2125, 0.1307,\n", 349 | " 0.5164, 0.8591, 0.0752]])" 350 | ] 351 | }, 352 | "execution_count": 27, 353 | "metadata": {}, 354 | "output_type": "execute_result" 355 | } 356 | ], 357 | "source": [ 358 | "x.reshape(1,12)" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": 28, 364 | "id": "f31de262", 365 | "metadata": {}, 366 | "outputs": [ 367 | { 368 | "data": { 369 | "text/plain": [ 370 | "tensor([[0.8759],\n", 371 | " [0.6237],\n", 372 | " [0.0273],\n", 373 | " [0.6163],\n", 374 | " [0.9567],\n", 375 | " [0.2370],\n", 376 | " [0.8376],\n", 377 | " [0.2125],\n", 378 | " [0.1307],\n", 379 | " [0.5164],\n", 380 | " [0.8591],\n", 381 | " [0.0752]])" 382 | ] 383 | }, 384 | "execution_count": 28, 385 | "metadata": {}, 386 | "output_type": "execute_result" 387 | } 388 | ], 389 | "source": [ 390 | "x.reshape(12,1)" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": 31, 396 | "id": "03441687", 397 | "metadata": {}, 398 | "outputs": [ 399 | { 400 | "data": { 401 | "text/plain": [ 402 | "torch.Size([3, 4])" 403 | ] 404 | }, 405 | "execution_count": 31, 406 | "metadata": {}, 407 | "output_type": "execute_result" 408 | } 409 | ], 410 | "source": [ 411 | "x.shape" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 32, 417 | "id": "5a8dd9ab", 418 | "metadata": {}, 419 | "outputs": [ 420 | { 421 | "data": { 422 | "text/plain": [ 423 | "tensor([0.8759, 0.6237, 0.0273, 0.6163, 0.9567, 0.2370, 0.8376, 0.2125, 0.1307,\n", 424 | " 0.5164, 0.8591, 0.0752])" 425 | ] 426 | }, 427 | "execution_count": 32, 428 | "metadata": {}, 429 | "output_type": "execute_result" 430 | } 431 | ], 432 | "source": [ 433 | "x.view(-1) # Flattening" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 33, 439 | "id": "9ee987ea", 440 | "metadata": {}, 441 | "outputs": [ 442 | { 443 | "data": { 444 | "text/plain": [ 445 | "torch.Size([12])" 446 | ] 447 | }, 448 | "execution_count": 33, 449 | "metadata": {}, 450 | "output_type": "execute_result" 451 | } 452 | ], 453 | "source": [ 454 | "x.view(-1).shape # flattening a tensor" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": 34, 460 | "id": "4ccccefe", 461 | "metadata": {}, 462 | "outputs": [ 463 | { 464 | "data": { 465 | "text/plain": [ 466 | "tensor([[0.5898, 0.8485, 0.6827, 0.4938, 0.5183, 0.9858, 0.5632, 0.4546, 0.6970,\n", 467 | " 0.9603],\n", 468 | " [0.0669, 0.4628, 0.9424, 0.0092, 0.0878, 0.9861, 0.9675, 0.8619, 0.2528,\n", 469 | " 0.1191],\n", 470 | " [0.6260, 0.0551, 0.4902, 0.4019, 0.2533, 0.5351, 0.0208, 0.0644, 0.4734,\n", 471 | " 0.1473],\n", 472 | " [0.5578, 0.7696, 0.9104, 0.2832, 0.4618, 0.6799, 0.0138, 0.6154, 0.2280,\n", 473 | " 0.6078],\n", 474 | " [0.4576, 0.9747, 0.4234, 0.8570, 0.5523, 0.9100, 0.3069, 0.8513, 0.7162,\n", 475 | " 0.6048],\n", 476 | " [0.6103, 0.4901, 0.7255, 0.4219, 0.4437, 0.9195, 0.0497, 0.0241, 0.9140,\n", 477 | " 0.4657],\n", 478 | " [0.1688, 0.0681, 0.9507, 0.0024, 0.7920, 0.7637, 0.1157, 0.6318, 0.4604,\n", 479 | " 0.1113],\n", 480 | " [0.9648, 0.4391, 0.2547, 0.3421, 0.2013, 0.1122, 0.9653, 0.2397, 0.7509,\n", 481 | " 0.2061],\n", 482 | " [0.2421, 0.7438, 0.0903, 0.2956, 0.1293, 0.6893, 0.8242, 0.3162, 0.8461,\n", 483 | " 
0.1161],\n", 484 | " [0.3650, 0.5304, 0.2917, 0.3597, 0.2389, 0.0235, 0.7417, 0.5334, 0.4782,\n", 485 | " 0.8095],\n", 486 | " [0.4757, 0.8335, 0.0165, 0.4111, 0.8916, 0.5094, 0.3302, 0.0553, 0.0875,\n", 487 | " 0.4874],\n", 488 | " [0.5583, 0.1159, 0.2653, 0.1260, 0.3124, 0.9583, 0.0977, 0.3883, 0.6024,\n", 489 | " 0.3513],\n", 490 | " [0.2737, 0.8206, 0.8189, 0.1035, 0.6452, 0.4267, 0.7023, 0.5829, 0.1668,\n", 491 | " 0.5769],\n", 492 | " [0.3219, 0.7873, 0.1554, 0.4559, 0.4183, 0.5106, 0.8702, 0.8929, 0.7462,\n", 493 | " 0.6766],\n", 494 | " [0.6762, 0.3115, 0.6406, 0.0345, 0.5024, 0.5300, 0.1036, 0.5054, 0.7688,\n", 495 | " 0.6348],\n", 496 | " [0.4287, 0.6326, 0.5973, 0.7432, 0.5106, 0.8773, 0.8123, 0.1154, 0.7038,\n", 497 | " 0.1973]])" 498 | ] 499 | }, 500 | "execution_count": 34, 501 | "metadata": {}, 502 | "output_type": "execute_result" 503 | } 504 | ], 505 | "source": [ 506 | "# batch flattening ops\n", 507 | "batch = 16\n", 508 | "\n", 509 | "torch.rand((batch, 2,5)).view((batch, -1))" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "id": "4ca5e840", 516 | "metadata": {}, 517 | "outputs": [], 518 | "source": [] 519 | } 520 | ], 521 | "metadata": { 522 | "kernelspec": { 523 | "display_name": "Python 3 (ipykernel)", 524 | "language": "python", 525 | "name": "python3" 526 | }, 527 | "language_info": { 528 | "codemirror_mode": { 529 | "name": "ipython", 530 | "version": 3 531 | }, 532 | "file_extension": ".py", 533 | "mimetype": "text/x-python", 534 | "name": "python", 535 | "nbconvert_exporter": "python", 536 | "pygments_lexer": "ipython3", 537 | "version": "3.7.11" 538 | } 539 | }, 540 | "nbformat": 4, 541 | "nbformat_minor": 5 542 | } 543 | -------------------------------------------------------------------------------- /codebase/03.01 Derivatives, Partial derivative, and Successive Differentiation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "7a365fdd", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "282b233b", 17 | "metadata": {}, 18 | "source": [ 19 | "## Derivatives" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 9, 25 | "id": "77f3c29f", 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "data": { 30 | "text/plain": [ 31 | "tensor(5., requires_grad=True)" 32 | ] 33 | }, 34 | "execution_count": 9, 35 | "metadata": {}, 36 | "output_type": "execute_result" 37 | } 38 | ], 39 | "source": [ 40 | "x = torch.tensor(5.0, requires_grad=True)\n", 41 | "x" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "6ae84c1f", 47 | "metadata": {}, 48 | "source": [ 49 | "$x = 5.0$\n", 50 | "\n", 51 | "$y = x^2 => f(x) = x^2$\n", 52 | "\n" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 10, 58 | "id": "d40fbe0d", 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/plain": [ 64 | "tensor(25., grad_fn=)" 65 | ] 66 | }, 67 | "execution_count": 10, 68 | "metadata": {}, 69 | "output_type": "execute_result" 70 | } 71 | ], 72 | "source": [ 73 | "y = x ** 2\n", 74 | "y" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "id": "6f474ae0", 80 | "metadata": {}, 81 | "source": [ 82 | "$\\frac{dy}{dx} = 2x$\n", 83 | "\n", 84 | "$f'(x=5.0) = 2 * 5.0 = 10$" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 11, 90 | 
"id": "dd8db750", 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "text/plain": [ 96 | "tensor(10.)" 97 | ] 98 | }, 99 | "execution_count": 11, 100 | "metadata": {}, 101 | "output_type": "execute_result" 102 | } 103 | ], 104 | "source": [ 105 | "y.backward()\n", 106 | "\n", 107 | "x.grad" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "id": "27d0695a", 113 | "metadata": {}, 114 | "source": [ 115 | "## Partial derivative" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 12, 121 | "id": "6abe730e", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "x = torch.tensor(5.0, requires_grad=True)\n", 126 | "y = torch.tensor(5.0, requires_grad=True)\n" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 13, 132 | "id": "93926cf7", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "f = x**2 + y**2" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 14, 142 | "id": "4f329c02", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "f.backward()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 15, 152 | "id": "d1cb5deb", 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "data": { 157 | "text/plain": [ 158 | "" 159 | ] 160 | }, 161 | "execution_count": 15, 162 | "metadata": {}, 163 | "output_type": "execute_result" 164 | } 165 | ], 166 | "source": [ 167 | "f.grad_fn" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 16, 173 | "id": "d80038fb", 174 | "metadata": {}, 175 | "outputs": [ 176 | { 177 | "data": { 178 | "text/plain": [ 179 | "tensor(10.)" 180 | ] 181 | }, 182 | "execution_count": 16, 183 | "metadata": {}, 184 | "output_type": "execute_result" 185 | } 186 | ], 187 | "source": [ 188 | "x.grad # partial derivative wrt x at x = 5 and y =5" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 17, 194 | "id": "6ab20700", 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "data": { 199 | "text/plain": [ 200 | "tensor(10.)" 201 | ] 202 | }, 203 | "execution_count": 17, 204 | "metadata": {}, 205 | "output_type": "execute_result" 206 | } 207 | ], 208 | "source": [ 209 | "y.grad # partial derivative wrt y at x = 5 and y =5" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 22, 215 | "id": "0d44cfc1", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "x = torch.tensor(5.0, requires_grad=True)\n", 220 | "y = torch.tensor(5.0, requires_grad=True)\n" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 23, 226 | "id": "9396e8bf", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "f2 = x**2 * y**2" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "id": "fdad968c", 236 | "metadata": {}, 237 | "source": [ 238 | "$f2(x, y) = x^2 . 
y^2$\n", 239 | "\n", 240 | "$\\frac{\\partial f2}{\\partial x} = 2x.y^2$\n", 241 | "\n", 242 | "$\\frac{\\partial f2}{\\partial y} = x^2.2y$" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 24, 248 | "id": "4bd23358", 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "f2.backward()" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 25, 258 | "id": "6b101352", 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "data": { 263 | "text/plain": [ 264 | "" 265 | ] 266 | }, 267 | "execution_count": 25, 268 | "metadata": {}, 269 | "output_type": "execute_result" 270 | } 271 | ], 272 | "source": [ 273 | "f2.grad_fn" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": 26, 279 | "id": "4731a225", 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "data": { 284 | "text/plain": [ 285 | "tensor(250.)" 286 | ] 287 | }, 288 | "execution_count": 26, 289 | "metadata": {}, 290 | "output_type": "execute_result" 291 | } 292 | ], 293 | "source": [ 294 | "x.grad" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "id": "33fba02c", 300 | "metadata": {}, 301 | "source": [ 302 | "## Successive Differentiation" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 27, 308 | "id": "a3f4c644", 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [ 312 | "from torch.autograd import grad\n", 313 | "\n", 314 | "def nth_derivative(f, wrt, n=2):\n", 315 | " \n", 316 | " for i in range(n):\n", 317 | " grads = grad(f, wrt, create_graph=True)[0]\n", 318 | " f = grads.sum()\n", 319 | " \n", 320 | " return grads" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 28, 326 | "id": "50612387", 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "x = torch.tensor(5.0, requires_grad=True)" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "id": "064bbc40", 336 | "metadata": {}, 337 | "source": [ 338 | "$f(x) = x^2 + x^3$\n", 339 | "\n", 340 | "$f'(x) = 2x + 3x^2$\n", 341 | "\n", 342 | "$f''(x) = 2 + 6x$\n", 343 | "\n", 344 | "$f''(x=5) = 2 + 6*5 = 32$" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 29, 350 | "id": "4bd61626", 351 | "metadata": {}, 352 | "outputs": [ 353 | { 354 | "data": { 355 | "text/plain": [ 356 | "tensor(32., grad_fn=)" 357 | ] 358 | }, 359 | "execution_count": 29, 360 | "metadata": {}, 361 | "output_type": "execute_result" 362 | } 363 | ], 364 | "source": [ 365 | "f = x**2 + x**3\n", 366 | "\n", 367 | "# double derivative\n", 368 | "nth_derivative(f, x)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "id": "99b5e49d", 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [] 378 | } 379 | ], 380 | "metadata": { 381 | "kernelspec": { 382 | "display_name": "Python 3 (ipykernel)", 383 | "language": "python", 384 | "name": "python3" 385 | }, 386 | "language_info": { 387 | "codemirror_mode": { 388 | "name": "ipython", 389 | "version": 3 390 | }, 391 | "file_extension": ".py", 392 | "mimetype": "text/x-python", 393 | "name": "python", 394 | "nbconvert_exporter": "python", 395 | "pygments_lexer": "ipython3", 396 | "version": "3.7.11" 397 | } 398 | }, 399 | "nbformat": 4, 400 | "nbformat_minor": 5 401 | } 402 | -------------------------------------------------------------------------------- /codebase/05.01_Custom data loading for structured.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 
| { 4 | "cell_type": "code", 5 | "execution_count": 9, 6 | "id": "dc72eb3f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch \n", 11 | "from torch.utils.data import DataLoader, Dataset\n", 12 | "import numpy as np\n", 13 | "import pandas as pd" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "id": "800db13a", 19 | "metadata": {}, 20 | "source": [ 21 | "# Loading structured dataset" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 10, 27 | "id": "db81ed75", 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "data": { 32 | "text/html": [ 33 | "
\n", 34 | "\n", 47 | "\n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | "
sepal.lengthsepal.widthpetal.lengthpetal.widthspecies
244.83.41.90.2Setosa
686.22.24.51.5Versicolor
145.84.01.20.2Setosa
1366.33.45.62.4Virginica
586.62.94.61.3Versicolor
\n", 101 | "
" 102 | ], 103 | "text/plain": [ 104 | " sepal.length sepal.width petal.length petal.width species\n", 105 | "24 4.8 3.4 1.9 0.2 Setosa\n", 106 | "68 6.2 2.2 4.5 1.5 Versicolor\n", 107 | "14 5.8 4.0 1.2 0.2 Setosa\n", 108 | "136 6.3 3.4 5.6 2.4 Virginica\n", 109 | "58 6.6 2.9 4.6 1.3 Versicolor" 110 | ] 111 | }, 112 | "execution_count": 10, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "df = pd.read_csv(\"Data/iris.csv\")\n", 119 | "df.sample(5)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 11, 125 | "id": "de3be177", 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/plain": [ 131 | "array(['Setosa', 'Versicolor', 'Virginica'], dtype=object)" 132 | ] 133 | }, 134 | "execution_count": 11, 135 | "metadata": {}, 136 | "output_type": "execute_result" 137 | } 138 | ], 139 | "source": [ 140 | "df[\"species\"].unique()" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 14, 146 | "id": "48d87531", 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "data": { 151 | "text/plain": [ 152 | "{'Setosa': 0, 'Versicolor': 1, 'Virginica': 2}" 153 | ] 154 | }, 155 | "execution_count": 14, 156 | "metadata": {}, 157 | "output_type": "execute_result" 158 | } 159 | ], 160 | "source": [ 161 | "{val: ind for ind, val in enumerate(df[\"species\"].unique())}" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 15, 167 | "id": "59f61df1", 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "class Iris(Dataset):\n", 172 | " def __init__(self, target_col_name=\"species\"):\n", 173 | " self.df = pd.read_csv(\"Data/iris.csv\")\n", 174 | " x = self.df.drop(target_col_name, axis=1).to_numpy()\n", 175 | " self.x = torch.from_numpy(x)\n", 176 | " \n", 177 | " replacement_dict = {'Setosa': 0, 'Versicolor': 1, 'Virginica': 2}\n", 178 | " y = self.df[target_col_name].replace(replacement_dict).to_numpy()\n", 179 | " self.y = torch.from_numpy(y)\n", 180 | "\n", 181 | " def __getitem__(self, index):\n", 182 | " return self.x[index], self.y[index]\n", 183 | "\n", 184 | " def __len__(self):\n", 185 | " return self.df.shape[0]" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 17, 191 | "id": "7d8f32ee", 192 | "metadata": {}, 193 | "outputs": [ 194 | { 195 | "data": { 196 | "text/plain": [ 197 | "(150, 5)" 198 | ] 199 | }, 200 | "execution_count": 17, 201 | "metadata": {}, 202 | "output_type": "execute_result" 203 | } 204 | ], 205 | "source": [ 206 | "df.shape" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 16, 212 | "id": "c477c964", 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "data": { 217 | "text/plain": [ 218 | "150" 219 | ] 220 | }, 221 | "execution_count": 16, 222 | "metadata": {}, 223 | "output_type": "execute_result" 224 | } 225 | ], 226 | "source": [ 227 | "iris_data = Iris()\n", 228 | "len(iris_data)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 18, 234 | "id": "1bd4c953", 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "iris_data_loader = DataLoader(iris_data, batch_size=8)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 20, 244 | "id": "bb4b7e85", 245 | "metadata": {}, 246 | "outputs": [ 247 | { 248 | "name": "stdout", 249 | "output_type": "stream", 250 | "text": [ 251 | "independent col data: tensor([[5.1000, 3.5000, 1.4000, 0.2000],\n", 252 | " [4.9000, 3.0000, 1.4000, 0.2000],\n", 253 | " 
[4.7000, 3.2000, 1.3000, 0.2000],\n", 254 | " [4.6000, 3.1000, 1.5000, 0.2000],\n", 255 | " [5.0000, 3.6000, 1.4000, 0.2000],\n", 256 | " [5.4000, 3.9000, 1.7000, 0.4000],\n", 257 | " [4.6000, 3.4000, 1.4000, 0.3000],\n", 258 | " [5.0000, 3.4000, 1.5000, 0.2000]], dtype=torch.float64), \n", 259 | "target_col: tensor([0, 0, 0, 0, 0, 0, 0, 0])\n" 260 | ] 261 | } 262 | ], 263 | "source": [ 264 | "for data in iris_data_loader:\n", 265 | " x, y = data\n", 266 | " print(f\"independent col data: {x}, \\ntarget_col: {y}\")\n", 267 | " break" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 21, 273 | "id": "ac6e22af", 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "iris_data_loader = DataLoader(iris_data, batch_size=8, shuffle=True)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 22, 283 | "id": "37f39384", 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "independent col data: tensor([[4.8000, 3.4000, 1.6000, 0.2000],\n", 291 | " [5.7000, 2.8000, 4.5000, 1.3000],\n", 292 | " [5.8000, 2.7000, 5.1000, 1.9000],\n", 293 | " [6.3000, 2.8000, 5.1000, 1.5000],\n", 294 | " [7.3000, 2.9000, 6.3000, 1.8000],\n", 295 | " [6.9000, 3.1000, 4.9000, 1.5000],\n", 296 | " [4.7000, 3.2000, 1.6000, 0.2000],\n", 297 | " [6.5000, 2.8000, 4.6000, 1.5000]], dtype=torch.float64), \n", 298 | "target_col: tensor([0, 1, 2, 2, 2, 1, 0, 1])\n" 299 | ] 300 | } 301 | ], 302 | "source": [ 303 | "for data in iris_data_loader:\n", 304 | " x, y = data\n", 305 | " print(f\"independent col data: {x}, \\ntarget_col: {y}\")\n", 306 | " break" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "id": "08f1211a", 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [] 316 | } 317 | ], 318 | "metadata": { 319 | "kernelspec": { 320 | "display_name": "Python 3 (ipykernel)", 321 | "language": "python", 322 | "name": "python3" 323 | }, 324 | "language_info": { 325 | "codemirror_mode": { 326 | "name": "ipython", 327 | "version": 3 328 | }, 329 | "file_extension": ".py", 330 | "mimetype": "text/x-python", 331 | "name": "python", 332 | "nbconvert_exporter": "python", 333 | "pygments_lexer": "ipython3", 334 | "version": "3.7.11" 335 | } 336 | }, 337 | "nbformat": 4, 338 | "nbformat_minor": 5 339 | } 340 | -------------------------------------------------------------------------------- /codebase/06.01_CNN_create_data_loader.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "09726bcc-9410-4397-8306-6e4eff4b9b5e", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "D:\\oneNeuron\\Pytorch\\Pytorch-basics\\env\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 14 | " from .autonotebook import tqdm as notebook_tqdm\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import os\n", 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "import pandas as pd\n", 23 | "import seaborn as sns\n", 24 | "import torch\n", 25 | "import torch.nn as nn\n", 26 | "from torch.utils.data import DataLoader\n", 27 | "from torchvision import transforms, datasets\n", 28 | "import torch.nn.functional as F\n", 29 | "from sklearn.metrics import confusion_matrix\n", 30 | "from tqdm import tqdm" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "id": "445fafab-6fa0-4279-b9de-1dcf0483be2a", 36 | "metadata": {}, 37 | "source": [ 38 | "## Download dataset" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "id": "60f60c74-9ca3-4514-a42f-b50389b70e27", 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "name": "stdout", 49 | "output_type": "stream", 50 | "text": [ 51 | "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n", 52 | "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to FashionMNISTDir\\FashionMNIST\\raw\\train-images-idx3-ubyte.gz\n" 53 | ] 54 | }, 55 | { 56 | "name": "stderr", 57 | "output_type": "stream", 58 | "text": [ 59 | "26422272it [00:04, 6030366.11it/s] \n" 60 | ] 61 | }, 62 | { 63 | "name": "stdout", 64 | "output_type": "stream", 65 | "text": [ 66 | "Extracting FashionMNISTDir\\FashionMNIST\\raw\\train-images-idx3-ubyte.gz to FashionMNISTDir\\FashionMNIST\\raw\n", 67 | "\n", 68 | "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n", 69 | "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to FashionMNISTDir\\FashionMNIST\\raw\\train-labels-idx1-ubyte.gz\n" 70 | ] 71 | }, 72 | { 73 | "name": "stderr", 74 | "output_type": "stream", 75 | "text": [ 76 | "29696it [00:00, 179147.96it/s] \n" 77 | ] 78 | }, 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "Extracting FashionMNISTDir\\FashionMNIST\\raw\\train-labels-idx1-ubyte.gz to FashionMNISTDir\\FashionMNIST\\raw\n", 84 | "\n", 85 | "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n", 86 | "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to FashionMNISTDir\\FashionMNIST\\raw\\t10k-images-idx3-ubyte.gz\n" 87 | ] 88 | }, 89 | { 90 | "name": "stderr", 91 | "output_type": "stream", 92 | "text": [ 93 | "4422656it [00:23, 184390.73it/s] \n" 94 | ] 95 | }, 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "Extracting FashionMNISTDir\\FashionMNIST\\raw\\t10k-images-idx3-ubyte.gz to FashionMNISTDir\\FashionMNIST\\raw\n", 101 | "\n", 102 | "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n", 103 | "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to FashionMNISTDir\\FashionMNIST\\raw\\t10k-labels-idx1-ubyte.gz\n" 104 | ] 105 | }, 106 | { 107 | "name": "stderr", 108 | "output_type": "stream", 109 | "text": [ 110 | "6144it [00:00, ?it/s] " 111 | ] 112 | }, 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "Extracting FashionMNISTDir\\FashionMNIST\\raw\\t10k-labels-idx1-ubyte.gz to FashionMNISTDir\\FashionMNIST\\raw\n", 118 | 
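[Editorial aside, not a cell from the original notebook: since this notebook's goal is a data loader, the step that naturally follows the download (not visible in this excerpt, so sketched here with an assumed batch size of 64) is wrapping train_data and test_data in DataLoader objects:]

from torch.utils.data import DataLoader

BATCH_SIZE = 64  # assumed value for illustration; the notebook's actual choice is not shown here

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 1, 28, 28]) -- ToTensor() adds the channel dim
print(labels.shape)  # torch.Size([64])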
"\n" 119 | ] 120 | }, 121 | { 122 | "name": "stderr", 123 | "output_type": "stream", 124 | "text": [ 125 | "\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "ROOT_DATA_DIR = \"FashionMNISTDir\"\n", 131 | "\n", 132 | "train_data = datasets.FashionMNIST(\n", 133 | " root = ROOT_DATA_DIR,\n", 134 | " train = True,\n", 135 | " download = True,\n", 136 | " transform = transforms.ToTensor()\n", 137 | " )\n", 138 | "\n", 139 | "\n", 140 | "test_data = datasets.FashionMNIST(\n", 141 | " root = ROOT_DATA_DIR,\n", 142 | " train = False, ## <<< Test data\n", 143 | " download = True,\n", 144 | " transform = transforms.ToTensor()\n", 145 | " )" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 3, 151 | "id": "c2f04852-f194-4812-a6fe-8557fb145ee1", 152 | "metadata": {}, 153 | "outputs": [ 154 | { 155 | "data": { 156 | "text/plain": [ 157 | "torch.Size([60000, 28, 28])" 158 | ] 159 | }, 160 | "execution_count": 3, 161 | "metadata": {}, 162 | "output_type": "execute_result" 163 | } 164 | ], 165 | "source": [ 166 | "train_data.data.shape" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 4, 172 | "id": "adf58954-e425-47a6-97de-3674a2e209da", 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "data": { 177 | "text/plain": [ 178 | "torch.Size([10000, 28, 28])" 179 | ] 180 | }, 181 | "execution_count": 4, 182 | "metadata": {}, 183 | "output_type": "execute_result" 184 | } 185 | ], 186 | "source": [ 187 | "test_data.data.shape" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 5, 193 | "id": "50d627a6-55c8-445f-86ce-2f2202bcdf3e", 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "label_map = {\n", 198 | " 0: 'T-shirt/top',\n", 199 | " 1: 'Trouser',\n", 200 | " 2: 'Pullover',\n", 201 | " 3:' Dress',\n", 202 | " 4: 'Coat',\n", 203 | " 5: 'Sandal',\n", 204 | " 6: 'Shirt',\n", 205 | " 7: 'Sneaker',\n", 206 | " 8: 'Bag',\n", 207 | " 9: 'Ankle boot',\n", 208 | " }" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "id": "a5ee85fb-9e21-4b3c-a5c6-2a088e80f2e7", 214 | "metadata": {}, 215 | "source": [ 216 | "## Visualize one sample" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 7, 222 | "id": "1cc1b28d-a533-4566-81ae-a3f37471bcfd", 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "def view_sample_img(data, index, label_map):\n", 227 | " plt.imshow(data.data[index], cmap=\"gray\")\n", 228 | " plt.title(f\"data label: {label_map[data.targets[index].item()]}\")\n", 229 | " plt.axis(\"off\")" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 10, 235 | "id": "76adb871-bff6-4292-b94c-94e47e8f07b1", 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "data": { 240 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAOcAAAD3CAYAAADmIkO7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAPRUlEQVR4nO3dfYxc5XXH8d+xvWt718FrTAHbrbFqwJZrAQXUmigukZuCkga1hDSmJUj9wxLhD6xWoUWt1BZVcZMKKiWVFZWqEm/uS2RSqBoRUJFqlYogp7Ud4VBiagm8YHDtrdfvxmt4+sdcK8N27znLXo/3zOT7kVbs+sxz7527+9s7O4fnuVZKEYB8Zkz3AQCYGOEEkiKcQFKEE0iKcAJJEU4gqZ4Pp5k9ZmZfybYvM9tmZhumuJ8pj53i/h40sy1O/Ydm9skLdTw/KXo+nB/Fhf6h7zQz+yszO159nDGzsbavv3u+9lNK+blSyjbnOCYMt5n1m9khM5vXa+f+fCCcPayU8qVSyrxSyjxJfybpW+e+LqV8+kIcg5nNcsq/JGlXKeX4hTiWbtNz4TSznzezHWZ2zMy+JWlOW22BmX3HzA6a2eHq85+uapskrZW0ubqybK7+/RtmNmxmR83sP81s7SSPo3ZfbZab2fZq2/9kZhe3jV9jZi+Z2aiZ/eBCvGw0swfM7O3q3P3IzH65rdxvZk9UtR+a2Y1t494ws09Vnz9oZk+Z2RYzOyrpS5L+UNL66rz+oG2bn5H0rHPuP25m3zezI9V/P962z21m9tW689cTSik98yGpX9Kbkn5XUp+kz0sak/SVqr5Q0h2SBiR9TNJWSc+0jd8macO4bX6xGjdL0pclvStpTs3+H/uI+3pb0mpJg5K+LWlLVVsiaUStH94Zkn6l+vqnxh+npKWSRiUtDc7Ng+e2X1NfIWlY0uLq62WSlreNPV0dz0xJX5X0ctvYNyR9qu2xY5J+vTr2uXX7lvSapBUTnXtJF0s6LOnu6tz/ZvX1wuj89cpHr10516gVyq+XUsZKKU9J+v65YillpJTy7VLKyVLKMUmbJN3sbbCUsqUad7aU8heSZqv1g+ya5L6eLKXsLqWckPRHkr5gZjPV+oXwbCnl2VLKB6WUf5H0H2qFY/x+9pVShkop+6JjCrxfPbdVZtZXSnmjlLK3rf7v1fG8L+lJSdc62/peKeWZ6thPTfQAM1suaVYp5Uc12/hVSa+XUp6szv3fqxXm29oeU3f+ekKvhXOxpLdL9au18ua5T8xswMweMbM3q5dc/yZpyPuGmtn9ZvZf1UurUUnzJV0SHcgk9zU87jj7qm1fIek3qpe0o9V+PyFpUbTfyTKz77a9OXRXKeW/Jf2OWle5/zGzfzCzxW1D3m37/KSkOc7fk8M1/97uM5K8N6UWq+17V3lTrVcVE+2n/fz1hF4L5zuSlpiZtf3b0rbPv6zWVe8XSykXqfWGhCSde/yHpuhUf1/+vqQvSFpQShmSdKTt8Z5oX5L0M+OOc0zSIbV+6J6srojnPgZLKV+bxH4npZTy6fLjN4f+tvq3vyulfEKtXw5F0p9PdfPB11L196bzmP3VcbRbqtZL2XPqzl9P6LVwfk/SWUkbzazPzD4n6Rfa6h+TdErSaPXmwZ+MG39A0s+Oe/xZSQclzTKzP5Z00SSPJdqXJH3RzFaZ2YCkP5X0VPWycYuk28zsVjObaWZzzOyTE7yhdN6Y2QozW2dms9X6+/KUpA/O0+YPSFpmZjOqfQ2o9X3513GPaT/3z0q62sx+y8xmmdl6SaskfaftMXXnryf0VDhLKWckfU7Sb0v6X0nrJf1j20O+rtYbFIckvSzpuXGb+Iakz1fvrv6lpOerx+xR62XTaU3uJdtk9iW1/nZ7TNWbTJI2Vs9jWNKvqfUu58Fqn7+nCb5fZra0emm6dHztI5ot6WvV8b4r6VJJf9Bwm+dsrf47YmY7JK1T6+/S022P+dC5L6WMSPqsWq9ARtR6BfPZUkr7lXHC89cr7MN/ngGdZ2bflLS7lPLNBtvYpta7s39z3g4sGa9BDHTKLkn/PN0HkR3hxAVXSvnr6T6GbsDLWiCpnnpDCOgl7staM0t7Wf1wK/P/m85XBCtXrnTrmzdvrq1t3bq1tiZJO3fudOtnzpxx62NjY2599erVtbXbb7/dHbt37163/tBDD7n10dFRt96rSikT/jBz5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpNz/Q6iTfc7p7FNed911bv3OO+9063fccYdbf/99f9bS4OBgbW3u3Lnu2IULF7r1TtqzZ49b/+ADf4bZihX+AhIHDhyorT3//PPu2Icfftit7969261PJ/qcQJchnEBShBNIinACSRFOICnCCSRFOIGkpq3P2dRFF/krVD7xxBO1tWuuucYdO2OG/zvr2LFjbv306dNu3ZtTGfVI+/r63Pr8+fPd+okTJ9y616vs9BzZOXPm1Nai/m9/f79bf/HFF9363Xff7dY7iT4n0GUIJ5AU4QSSIpxAUoQTSIpwAkl1bSvlhRdecOtXXDH+7nE/NjIy4o6Npj7NmuUvlH/27Fm3Hk2X80RtnmhpzJkzp35v2WjfndR0iuGiRf6tTW+99Va3/tprr7n1JmilAF2GcAJJEU4gKcIJJEU4gaQIJ5AU4QSSSntn6xtuuMGte31MSTp06FBtLepTRr1Ab2qTJC1ZssStDwwM1NaiXmJ0C7/ouUVT0rx+YjRdLervRlPt3nrrrSlvOxI97w0bNrj1+++/v9H+p4IrJ5AU4QSSIpxAUoQTSIpwAkkRTiApwgkklXY+Z9RX2rhxo1v3+pzRfM2ozxn1zB555BG3vn///tqa1+uTpMWLF7v1d955x603mQ86e/Zsd+y8efPc+vXXX+/W77vvvtqa9/2U4v5utJRqNH7ZsmVuvQnmcwJdhnACSRFOICnCCSRFOIGkCCeQFOEEkkrb53z55Zfd+qWXXurWvbmD0dquUb/uyJEjbn3NmjVu/ZZbbqmtRXNBH330Ubd+zz33uPXdu3e7de9We1H/98CBA259165dbv3111+vrUVzQaM5ttF80JUrV7r11atX19b27Nnjjo3Q5wS6DOEEkiKcQFKEE0iKcAJJEU4gqbRLY1577bVufXh42K17U6OiqU+RaPpR5LnnnqutnThxwh27atUqtx5NtXv66afd+m233VZbi6ZV7dixw61Hy5167Y7BwUF3bDSNL5omuG/fPrd+00031daatlLqcOUEkiKcQFKEE0iKcAJJEU4gKcIJJEU4gaSmrc/pTcGRpIMHD7r1aAqQN73Ju82d5E+bkqSRkRG3HvGe+3vvveeOXbRokVvftGmTW4+eu3eLwWis1wucDG/J0GgqXdM+56lTp9z62rVra2uPP/64O3aquHICSRFOICnCCSRFOIGkCCeQFOEEkiKcQFLT1ud84IEH3HrUazx+/Lhb9/pe0bZPnz7t1qMe64033ujWFy5cWFu7+OKL3bF9fX1u/bLLLnPrXh9T8p97f3+/O3ZoaMitr1+/3q0vWLCgthb1Ie
fPn+/Wo/HRc4u+p53AlRNIinACSRFOICnCCSRFOIGkCCeQFOEEkpq2PudLL73k1i+//HK3fuWVV7p1b23ZaA1U71Z0Ujx3MLp9oTe3MJp3GO07uk1ftPasN2cz2re3VrAU38bPW/91YGDAHRs97+jYvLmkkvTMM8+49U7gygkkRTiBpAgnkBThBJIinEBShBNIinACSVkppb5oVl+cZt7cP0m66qqramv33nuvO/bmm29269G9QaO5haOjo7W1aL5m1M/rpGjd2qiXGM2T9c7bK6+84o6966673HpmpZQJTyxXTiApwgkkRTiBpAgnkBThBJIinEBS0zZlrKnDhw+79e3bt9fWotvsrVu3zq177ScpXmbRm7IWtUqiKWWRqB3i1aN9z549262fOXPGrc+ZM6e2Fk0x7EVcOYGkCCeQFOEEkiKcQFKEE0iKcAJJEU4gqbR9zqgfF02t8npqUZ/y6NGjbj3qRUZLSEb790Tnpcm2O63JdDdvmt352HfUw52O88qVE0iKcAJJEU4gKcIJJEU4gaQIJ5AU4QSSStvnjPpKY2NjU9723r173XrU54xuoxfNW/REz7vTfc5o+57oeUe9aU/0PYlEy3ZGvenpwJUTSIpwAkkRTiApwgkkRTiBpAgnkBThBJJK2+eMNOlbnTp1yh0b9eui9VnPnj3r1r0+adM+ZpN1aSX/vEb7jtYDHhgYcOvesUXntBdx5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpLq2z9lk3mK0RmnTdWejetSj9UTH3mRtWMnvNUbHHT3v6Nib9FgjmdfzrcOVE0iKcAJJEU4gKcIJJEU4gaQIJ5BU17ZSOmnJkiVu/fDhw249amd4b+tH7YomS1d2WnTs0XKm3nNr2iLqRlw5gaQIJ5AU4QSSIpxAUoQTSIpwAkkRTiCpru1zdnIKUNNlGPv7+926NyWt6dKWnVxaM5ryFd3iL1o60zu2JrcPjLadFVdOICnCCSRFOIGkCCeQFOEEkiKcQFKEE0iqa/ucnRT146K5hVGf1Bsf9RKjfl10bNHtDb3te7cujMZK0smTJ926Z2hoaMpjuxVXTiApwgkkRTiBpAgnkBThBJIinEBShBNIij7nBKJeY1PenMmm8w47ue5tk7mgkxnv9Yfnzp3rjo0wnxPAeUM4gaQIJ5AU4QSSIpxAUoQTSIpWygSidkRTnXxbfzpbKdG+m7RSBgYG3LG9iCsnkBThBJIinEBShBNIinACSRFOICnCCSTVtX3O6ZwCFC0/2UTTaVmRJsfe6els3q0RO3nOs+LKCSRFOIGkCCeQFOEEkiKcQFKEE0iKcAJJdW2fs+kyjJ7oNnmdnFsYLcvZ9PaDnTxvTXWyz8nSmADOG8IJJEU4gaQIJ5AU4QSSIpxAUoQTSKpr+5zTqcm8RMnvNUbbblqP+qjTua6th/mcANIgnEBShBNIinACSRFOICnCCSRFOIGkurbP2cn5efv373frV199tVuP5lR6vcaoD9nX1zflbU+m7p3XqH87a1azHydv38znBJAG4QSSIpxAUoQTSIpwAkkRTiCprm2ldNLQ0JBbHxwcdOtRS+GSSy6prTWdEha1WpqIWilRu2N4eNite0uOLl++3B0baTqVbjpw5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpLq2z9nJW9nt3LnTrb/66qtufXR01K036UVG/brjx4+79ei8eOe1yVQ4Kb614oIFC2pr27dvd8dGMvYxI1w5gaQIJ5AU4QSSIpxAUoQTSIpwAkkRTiAp68YlA4GfBFw5gaQIJ5AU4QSSIpxAUoQTSIpwAkn9H5vkccLt/ncCAAAAAElFTkSuQmCC\n", 241 | "text/plain": [ 242 | "
" 243 | ] 244 | }, 245 | "metadata": { 246 | "needs_background": "light" 247 | }, 248 | "output_type": "display_data" 249 | } 250 | ], 251 | "source": [ 252 | "view_sample_img(train_data, index=1, label_map=label_map)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "id": "13642d17-c534-40fb-86e7-e5234a6127d0", 258 | "metadata": {}, 259 | "source": [ 260 | "## Create the dataloader" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": 11, 266 | "id": "4b7a8c70-14d0-4a5f-9700-0fd317929074", 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "BATCH_SIZE = 64\n", 271 | "\n", 272 | "train_data_loader = DataLoader(\n", 273 | " dataset = train_data,\n", 274 | " batch_size = BATCH_SIZE,\n", 275 | " shuffle = True\n", 276 | " )\n", 277 | "\n", 278 | "test_data_loader = DataLoader(\n", 279 | " dataset = test_data,\n", 280 | " batch_size = BATCH_SIZE,\n", 281 | " shuffle = True\n", 282 | " )" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": 12, 288 | "id": "c216a67f-9aab-4013-a270-5dbc8cbd8ff5", 289 | "metadata": {}, 290 | "outputs": [ 291 | { 292 | "name": "stdout", 293 | "output_type": "stream", 294 | "text": [ 295 | "torch.Size([64, 1, 28, 28])\n", 296 | "torch.Size([64])\n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "for data, label in test_data_loader:\n", 302 | " print(data.shape) \n", 303 | " print(label.shape)\n", 304 | " break" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": null, 310 | "id": "7215cb7d-302d-428b-bf8a-46811bba8181", 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [] 314 | } 315 | ], 316 | "metadata": { 317 | "kernelspec": { 318 | "display_name": "Python 3 (ipykernel)", 319 | "language": "python", 320 | "name": "python3" 321 | }, 322 | "language_info": { 323 | "codemirror_mode": { 324 | "name": "ipython", 325 | "version": 3 326 | }, 327 | "file_extension": ".py", 328 | "mimetype": "text/x-python", 329 | "name": "python", 330 | "nbconvert_exporter": "python", 331 | "pygments_lexer": "ipython3", 332 | "version": "3.7.11" 333 | } 334 | }, 335 | "nbformat": 4, 336 | "nbformat_minor": 5 337 | } 338 | -------------------------------------------------------------------------------- /codebase/06.02_CNN_architecture.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "09726bcc-9410-4397-8306-6e4eff4b9b5e", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "D:\\oneNeuron\\Pytorch\\Pytorch-basics\\env\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 14 | " from .autonotebook import tqdm as notebook_tqdm\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import os\n", 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "import pandas as pd\n", 23 | "import seaborn as sns\n", 24 | "import torch\n", 25 | "import torch.nn as nn\n", 26 | "from torch.utils.data import DataLoader\n", 27 | "from torchvision import transforms, datasets\n", 28 | "import torch.nn.functional as F\n", 29 | "from sklearn.metrics import confusion_matrix\n", 30 | "from tqdm import tqdm" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "id": "445fafab-6fa0-4279-b9de-1dcf0483be2a", 36 | "metadata": {}, 37 | "source": [ 38 | "## Download dataset" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "id": "60f60c74-9ca3-4514-a42f-b50389b70e27", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "ROOT_DATA_DIR = \"FashionMNISTDir\"\n", 49 | "\n", 50 | "train_data = datasets.FashionMNIST(\n", 51 | " root = ROOT_DATA_DIR,\n", 52 | " train = True,\n", 53 | " download = True,\n", 54 | " transform = transforms.ToTensor()\n", 55 | " )\n", 56 | "\n", 57 | "\n", 58 | "test_data = datasets.FashionMNIST(\n", 59 | " root = ROOT_DATA_DIR,\n", 60 | " train = False, ## <<< Test data\n", 61 | " download = True,\n", 62 | " transform = transforms.ToTensor()\n", 63 | " )" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "id": "c2f04852-f194-4812-a6fe-8557fb145ee1", 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "data": { 74 | "text/plain": [ 75 | "torch.Size([60000, 28, 28])" 76 | ] 77 | }, 78 | "execution_count": 3, 79 | "metadata": {}, 80 | "output_type": "execute_result" 81 | } 82 | ], 83 | "source": [ 84 | "train_data.data.shape" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 4, 90 | "id": "adf58954-e425-47a6-97de-3674a2e209da", 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "text/plain": [ 96 | "torch.Size([10000, 28, 28])" 97 | ] 98 | }, 99 | "execution_count": 4, 100 | "metadata": {}, 101 | "output_type": "execute_result" 102 | } 103 | ], 104 | "source": [ 105 | "test_data.data.shape" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 5, 111 | "id": "50d627a6-55c8-445f-86ce-2f2202bcdf3e", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "label_map = {\n", 116 | " 0: 'T-shirt/top',\n", 117 | " 1: 'Trouser',\n", 118 | " 2: 'Pullover',\n", 119 | " 3:' Dress',\n", 120 | " 4: 'Coat',\n", 121 | " 5: 'Sandal',\n", 122 | " 6: 'Shirt',\n", 123 | " 7: 'Sneaker',\n", 124 | " 8: 'Bag',\n", 125 | " 9: 'Ankle boot',\n", 126 | " }" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "a5ee85fb-9e21-4b3c-a5c6-2a088e80f2e7", 132 | "metadata": {}, 133 | "source": [ 134 | "## Visualize one sample" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 6, 140 | "id": "1cc1b28d-a533-4566-81ae-a3f37471bcfd", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "def view_sample_img(data, index, label_map):\n", 145 | " plt.imshow(data.data[index], cmap=\"gray\")\n", 146 | " plt.title(f\"data label: {label_map[data.targets[index].item()]}\")\n", 147 | " plt.axis(\"off\")" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 7, 153 | "id": "76adb871-bff6-4292-b94c-94e47e8f07b1", 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "data": { 158 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAOcAAAD3CAYAAADmIkO7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAPRUlEQVR4nO3dfYxc5XXH8d+xvWt718FrTAHbrbFqwJZrAQXUmigukZuCkga1hDSmJUj9wxLhD6xWoUWt1BZVcZMKKiWVFZWqEm/uS2RSqBoRUJFqlYogp7Ud4VBiagm8YHDtrdfvxmt4+sdcK8N27znLXo/3zOT7kVbs+sxz7527+9s7O4fnuVZKEYB8Zkz3AQCYGOEEkiKcQFKEE0iKcAJJEU4gqZ4Pp5k9ZmZfybYvM9tmZhumuJ8pj53i/h40sy1O/Ydm9skLdTw/KXo+nB/Fhf6h7zQz+yszO159nDGzsbavv3u+9lNK+blSyjbnOCYMt5n1m9khM5vXa+f+fCCcPayU8qVSyrxSyjxJfybpW+e+LqV8+kIcg5nNcsq/JGlXKeX4hTiWbtNz4TSznzezHWZ2zMy+JWlOW22BmX3HzA6a2eHq85+uapskrZW0ubqybK7+/RtmNmxmR83sP81s7SSPo3ZfbZab2fZq2/9kZhe3jV9jZi+Z2aiZ/eBCvGw0swfM7O3q3P3IzH65rdxvZk9UtR+a2Y1t494ws09Vnz9oZk+Z2RYzOyrpS5L+UNL66rz+oG2bn5H0rHPuP25m3zezI9V/P962z21m9tW689cTSik98yGpX9Kbkn5XUp+kz0sak/SVqr5Q0h2SBiR9TNJWSc+0jd8macO4bX6xGjdL0pclvStpTs3+H/uI+3pb0mpJg5K+LWlLVVsiaUStH94Zkn6l+vqnxh+npKWSRiUtDc7Ng+e2X1NfIWlY0uLq62WSlreNPV0dz0xJX5X0ctvYNyR9qu2xY5J+vTr2uXX7lvSapBUTnXtJF0s6LOnu6tz/ZvX1wuj89cpHr10516gVyq+XUsZKKU9J+v65YillpJTy7VLKyVLKMUmbJN3sbbCUsqUad7aU8heSZqv1g+ya5L6eLKXsLqWckPRHkr5gZjPV+oXwbCnl2VLKB6WUf5H0H2qFY/x+9pVShkop+6JjCrxfPbdVZtZXSnmjlLK3rf7v1fG8L+lJSdc62/peKeWZ6thPTfQAM1suaVYp5Uc12/hVSa+XUp6szv3fqxXm29oeU3f+ekKvhXOxpLdL9au18ua5T8xswMweMbM3q5dc/yZpyPuGmtn9ZvZf1UurUUnzJV0SHcgk9zU87jj7qm1fIek3qpe0o9V+PyFpUbTfyTKz77a9OXRXKeW/Jf2OWle5/zGzfzCzxW1D3m37/KSkOc7fk8M1/97uM5K8N6UWq+17V3lTrVcVE+2n/fz1hF4L5zuSlpiZtf3b0rbPv6zWVe8XSykXqfWGhCSde/yHpuhUf1/+vqQvSFpQShmSdKTt8Z5oX5L0M+OOc0zSIbV+6J6srojnPgZLKV+bxH4npZTy6fLjN4f+tvq3vyulfEKtXw5F0p9PdfPB11L196bzmP3VcbRbqtZL2XPqzl9P6LVwfk/SWUkbzazPzD4n6Rfa6h+TdErSaPXmwZ+MG39A0s+Oe/xZSQclzTKzP5Z00SSPJdqXJH3RzFaZ2YCkP5X0VPWycYuk28zsVjObaWZzzOyTE7yhdN6Y2QozW2dms9X6+/KUpA/O0+YPSFpmZjOqfQ2o9X3513GPaT/3z0q62sx+y8xmmdl6SaskfaftMXXnryf0VDhLKWckfU7Sb0v6X0nrJf1j20O+rtYbFIckvSzpuXGb+Iakz1fvrv6lpOerx+xR62XTaU3uJdtk9iW1/nZ7TNWbTJI2Vs9jWNKvqfUu58Fqn7+nCb5fZra0emm6dHztI5ot6WvV8b4r6VJJf9Bwm+dsrf47YmY7JK1T6+/S022P+dC5L6WMSPqsWq9ARtR6BfPZUkr7lXHC89cr7MN/ngGdZ2bflLS7lPLNBtvYpta7s39z3g4sGa9BDHTKLkn/PN0HkR3hxAVXSvnr6T6GbsDLWiCpnnpDCOgl7staM0t7Wf1wK/P/m85XBCtXrnTrmzdvrq1t3bq1tiZJO3fudOtnzpxx62NjY2599erVtbXbb7/dHbt37163/tBDD7n10dFRt96rSikT/jBz5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpNz/Q6iTfc7p7FNed911bv3OO+9063fccYdbf/99f9bS4OBgbW3u3Lnu2IULF7r1TtqzZ49b/+ADf4bZihX+AhIHDhyorT3//PPu2Icfftit7969261PJ/qcQJchnEBShBNIinACSRFOICnCCSRFOIGkpq3P2dRFF/krVD7xxBO1tWuuucYdO2OG/zvr2LFjbv306dNu3ZtTGfVI+/r63Pr8+fPd+okTJ9y616vs9BzZOXPm1Nai/m9/f79bf/HFF9363Xff7dY7iT4n0GUIJ5AU4QSSIpxAUoQTSIpwAkl1bSvlhRdecOtXXDH+7nE/NjIy4o6Npj7NmuUvlH/27Fm3Hk2X80RtnmhpzJkzp35v2WjfndR0iuGiRf6tTW+99Va3/tprr7n1JmilAF2GcAJJEU4gKcIJJEU4gaQIJ5AU4QSSSntn6xtuuMGte31MSTp06FBtLepTRr1Ab2qTJC1ZssStDwwM1NaiXmJ0C7/ouUVT0rx+YjRdLervRlPt3nrrrSlvOxI97w0bNrj1+++/v9H+p4IrJ5AU4QSSIpxAUoQTSIpwAkkRTiApwgkklXY+Z9RX2rhxo1v3+pzRfM2ozxn1zB555BG3vn///tqa1+uTpMWLF7v1d955x603mQ86e/Zsd+y8efPc+vXXX+/W77vvvtqa9/2U4v5utJRqNH7ZsmVuvQnmcwJdhnACSRFOICnCCSRFOIGkCCeQFOEEkkrb53z55Zfd+qWXXurWvbmD0dquUb/uyJEjbn3NmjVu/ZZbbqmtRXNBH330Ubd+zz33uPXdu3e7de9We1H/98CBA259165dbv3111+vrUVzQaM5ttF80JUrV7r11atX19b27Nnjjo3Q5wS6DOEEkiKcQFKEE0iKcAJJEU4gqbRLY1577bVufXh42K17U6OiqU+RaPpR5LnnnqutnThxwh27atUqtx5NtXv66afd+m233VZbi6ZV7dixw61Hy5167Y7BwUF3bDSNL5omuG/fPrd+00031daatlLqcOUEkiKcQFKEE0iKcAJJEU4gKcIJJEU4gaSmrc/pTcGRpIMHD7r1aAqQN73Ju82d5E+bkqSRkRG3HvGe+3vvveeOXbRokVvftGmTW4+eu3eLwWis1wucDG/J0GgqXdM+56lTp9z62rVra2uPP/64O3aquHICSRFOICnCCSRFOIGkCCeQFOEEkiKcQFLT1ud84IEH3HrUazx+/Lhb9/pe0bZPnz7t1qMe64033ujWFy5cWFu7+OKL3bF9fX1u/bLLLnPrXh9T8p97f3+/O3ZoaMitr1+/3q0vWLCgthb1Ie
fPn+/Wo/HRc4u+p53AlRNIinACSRFOICnCCSRFOIGkCCeQFOEEkpq2PudLL73k1i+//HK3fuWVV7p1b23ZaA1U71Z0Ujx3MLp9oTe3MJp3GO07uk1ftPasN2cz2re3VrAU38bPW/91YGDAHRs97+jYvLmkkvTMM8+49U7gygkkRTiBpAgnkBThBJIinEBShBNIinACSVkppb5oVl+cZt7cP0m66qqramv33nuvO/bmm29269G9QaO5haOjo7W1aL5m1M/rpGjd2qiXGM2T9c7bK6+84o6966673HpmpZQJTyxXTiApwgkkRTiBpAgnkBThBJIinEBS0zZlrKnDhw+79e3bt9fWotvsrVu3zq177ScpXmbRm7IWtUqiKWWRqB3i1aN9z549262fOXPGrc+ZM6e2Fk0x7EVcOYGkCCeQFOEEkiKcQFKEE0iKcAJJEU4gqbR9zqgfF02t8npqUZ/y6NGjbj3qRUZLSEb790Tnpcm2O63JdDdvmt352HfUw52O88qVE0iKcAJJEU4gKcIJJEU4gaQIJ5AU4QSSStvnjPpKY2NjU9723r173XrU54xuoxfNW/REz7vTfc5o+57oeUe9aU/0PYlEy3ZGvenpwJUTSIpwAkkRTiApwgkkRTiBpAgnkBThBJJK2+eMNOlbnTp1yh0b9eui9VnPnj3r1r0+adM+ZpN1aSX/vEb7jtYDHhgYcOvesUXntBdx5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpLq2z9lk3mK0RmnTdWejetSj9UTH3mRtWMnvNUbHHT3v6Nib9FgjmdfzrcOVE0iKcAJJEU4gKcIJJEU4gaQIJ5BU17ZSOmnJkiVu/fDhw249amd4b+tH7YomS1d2WnTs0XKm3nNr2iLqRlw5gaQIJ5AU4QSSIpxAUoQTSIpwAkkRTiCpru1zdnIKUNNlGPv7+926NyWt6dKWnVxaM5ryFd3iL1o60zu2JrcPjLadFVdOICnCCSRFOIGkCCeQFOEEkiKcQFKEE0iqa/ucnRT146K5hVGf1Bsf9RKjfl10bNHtDb3te7cujMZK0smTJ926Z2hoaMpjuxVXTiApwgkkRTiBpAgnkBThBJIinEBShBNIij7nBKJeY1PenMmm8w47ue5tk7mgkxnv9Yfnzp3rjo0wnxPAeUM4gaQIJ5AU4QSSIpxAUoQTSIpWygSidkRTnXxbfzpbKdG+m7RSBgYG3LG9iCsnkBThBJIinEBShBNIinACSRFOICnCCSTVtX3O6ZwCFC0/2UTTaVmRJsfe6els3q0RO3nOs+LKCSRFOIGkCCeQFOEEkiKcQFKEE0iKcAJJdW2fs+kyjJ7oNnmdnFsYLcvZ9PaDnTxvTXWyz8nSmADOG8IJJEU4gaQIJ5AU4QSSIpxAUoQTSKpr+5zTqcm8RMnvNUbbblqP+qjTua6th/mcANIgnEBShBNIinACSRFOICnCCSRFOIGkurbP2cn5efv373frV199tVuP5lR6vcaoD9nX1zflbU+m7p3XqH87a1azHydv38znBJAG4QSSIpxAUoQTSIpwAkkRTiCprm2ldNLQ0JBbHxwcdOtRS+GSSy6prTWdEha1WpqIWilRu2N4eNite0uOLl++3B0baTqVbjpw5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpLq2z9nJW9nt3LnTrb/66qtufXR01K036UVG/brjx4+79ei8eOe1yVQ4Kb614oIFC2pr27dvd8dGMvYxI1w5gaQIJ5AU4QSSIpxAUoQTSIpwAkkRTiAp68YlA4GfBFw5gaQIJ5AU4QSSIpxAUoQTSIpwAkn9H5vkccLt/ncCAAAAAElFTkSuQmCC\n", 159 | "text/plain": [ 160 | "
" 161 | ] 162 | }, 163 | "metadata": { 164 | "needs_background": "light" 165 | }, 166 | "output_type": "display_data" 167 | } 168 | ], 169 | "source": [ 170 | "view_sample_img(train_data, index=1, label_map=label_map)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "id": "13642d17-c534-40fb-86e7-e5234a6127d0", 176 | "metadata": {}, 177 | "source": [ 178 | "## Create the dataloader" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 8, 184 | "id": "4b7a8c70-14d0-4a5f-9700-0fd317929074", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "BATCH_SIZE = 64\n", 189 | "\n", 190 | "train_data_loader = DataLoader(\n", 191 | " dataset = train_data,\n", 192 | " batch_size = BATCH_SIZE,\n", 193 | " shuffle = True\n", 194 | " )\n", 195 | "\n", 196 | "test_data_loader = DataLoader(\n", 197 | " dataset = test_data,\n", 198 | " batch_size = BATCH_SIZE,\n", 199 | " shuffle = True\n", 200 | " )" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 9, 206 | "id": "c216a67f-9aab-4013-a270-5dbc8cbd8ff5", 207 | "metadata": {}, 208 | "outputs": [ 209 | { 210 | "name": "stdout", 211 | "output_type": "stream", 212 | "text": [ 213 | "torch.Size([64, 1, 28, 28])\n", 214 | "torch.Size([64])\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "for data, label in test_data_loader:\n", 220 | " print(data.shape) \n", 221 | " print(label.shape)\n", 222 | " break" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "id": "697498b5-fbd2-47e1-8823-c983cc9c85c2", 228 | "metadata": {}, 229 | "source": [ 230 | "## CNN architecture\n", 231 | "\n", 232 | "pytorch doc - [reference](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html?highlight=conv2d#torch.nn.Conv2d)" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 12, 238 | "id": "a837c0ec-c0ef-4bc9-96e4-e1cbc52a44d2", 239 | "metadata": {}, 240 | "outputs": [ 241 | { 242 | "data": { 243 | "text/plain": [ 244 | "'cuda'" 245 | ] 246 | }, 247 | "execution_count": 12, 248 | "metadata": {}, 249 | "output_type": "execute_result" 250 | } 251 | ], 252 | "source": [ 253 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 254 | "device" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 13, 260 | "id": "94510485-d718-4912-95c9-1f06f22257ad", 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "class CNN(nn.Module):\n", 265 | " def __init__(self, in_, out_):\n", 266 | " super(CNN, self).__init__()\n", 267 | " \n", 268 | " self.conv_pool_01 = nn.Sequential(\n", 269 | " nn.Conv2d(in_channels=in_, out_channels=8, kernel_size=5, stride=1, padding=0),\n", 270 | " nn.ReLU(),\n", 271 | " nn.MaxPool2d(kernel_size=2, stride=2)\n", 272 | " )\n", 273 | " \n", 274 | " self.conv_pool_02 = nn.Sequential(\n", 275 | " nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=0),\n", 276 | " nn.ReLU(),\n", 277 | " nn.MaxPool2d(kernel_size=2, stride=2)\n", 278 | " )\n", 279 | " \n", 280 | " self.Flatten = nn.Flatten()\n", 281 | " self.FC_01 = nn.Linear(in_features=16*4*4, out_features=128)\n", 282 | " self.FC_02 = nn.Linear(in_features=128, out_features=64)\n", 283 | " self.FC_03 = nn.Linear(in_features=64, out_features=out_)\n", 284 | " \n", 285 | " \n", 286 | " def forward(self, x):\n", 287 | " x = self.conv_pool_01(x)\n", 288 | " x = self.conv_pool_02(x)\n", 289 | " x = self.Flatten(x)\n", 290 | " x = self.FC_01(x)\n", 291 | " x = self.FC_02(x) \n", 292 | " x = self.FC_03(x)\n", 293 | " return x" 294 | ] 
295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 14, 299 | "id": "62ee8c3b-01e2-4c06-b3a0-25ababc457e6", 300 | "metadata": {}, 301 | "outputs": [ 302 | { 303 | "data": { 304 | "text/plain": [ 305 | "CNN(\n", 306 | " (conv_pool_01): Sequential(\n", 307 | " (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))\n", 308 | " (1): ReLU()\n", 309 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 310 | " )\n", 311 | " (conv_pool_02): Sequential(\n", 312 | " (0): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1))\n", 313 | " (1): ReLU()\n", 314 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 315 | " )\n", 316 | " (Flatten): Flatten(start_dim=1, end_dim=-1)\n", 317 | " (FC_01): Linear(in_features=256, out_features=128, bias=True)\n", 318 | " (FC_02): Linear(in_features=128, out_features=64, bias=True)\n", 319 | " (FC_03): Linear(in_features=64, out_features=10, bias=True)\n", 320 | ")" 321 | ] 322 | }, 323 | "execution_count": 14, 324 | "metadata": {}, 325 | "output_type": "execute_result" 326 | } 327 | ], 328 | "source": [ 329 | "model = CNN(1, 10)\n", 330 | "model.to(device)" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "id": "1ce4336d-8d13-4a8d-a586-393229fe248d", 336 | "metadata": {}, 337 | "source": [ 338 | "## Count no. of trainable params" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 16, 344 | "id": "81b2554a-c504-442f-9363-7a6b51ed1364", 345 | "metadata": {}, 346 | "outputs": [ 347 | { 348 | "data": { 349 | "text/html": [ 350 | "\n", 352 | "\n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | "
Total trainable parameters: 45226

     Modules                  Parameters
  0  conv_pool_01.0.weight           200
  1  conv_pool_01.0.bias               8
  2  conv_pool_02.0.weight          3200
  3  conv_pool_02.0.bias              16
  4  FC_01.weight                  32768
  5  FC_01.bias                      128
  6  FC_02.weight                   8192
  7  FC_02.bias                       64
  8  FC_03.weight                    640
  9  FC_03.bias                       10
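The 45226 in the caption can be cross-checked by hand: a Conv2d layer holds out_channels * in_channels * k * k weights plus out_channels biases, and a Linear layer holds out_features * in_features weights plus out_features biases. A minimal sketch of that arithmetic (not a notebook cell):

conv_params = (8 * 1 * 5 * 5 + 8) + (16 * 8 * 5 * 5 + 16)          # 208 + 3216 = 3424
fc_params = (128 * 256 + 128) + (64 * 128 + 64) + (10 * 64 + 10)   # 32896 + 8256 + 650
print(conv_params + fc_params)                                     # 45226, matching the table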
\n" 414 | ], 415 | "text/plain": [ 416 | "" 417 | ] 418 | }, 419 | "execution_count": 16, 420 | "metadata": {}, 421 | "output_type": "execute_result" 422 | } 423 | ], 424 | "source": [ 425 | "def count_params(model):\n", 426 | " model_params = {\"Modules\": list(), \"Parameters\": list()}\n", 427 | " total = 0\n", 428 | " for name, parameters in model.named_parameters():\n", 429 | " if not parameters.requires_grad:\n", 430 | " continue\n", 431 | " param = parameters.numel()\n", 432 | " model_params[\"Modules\"].append(name)\n", 433 | " model_params[\"Parameters\"].append(param)\n", 434 | " total += param\n", 435 | " df = pd.DataFrame(model_params)\n", 436 | " df = df.style.set_caption(f\"Total trainable parameters: {total}\")\n", 437 | " return df\n", 438 | "\n", 439 | "count_params(model)" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": null, 445 | "id": "f9c79c55-2bc0-416b-adfd-c5f6d05ad02d", 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [] 449 | } 450 | ], 451 | "metadata": { 452 | "kernelspec": { 453 | "display_name": "Python 3 (ipykernel)", 454 | "language": "python", 455 | "name": "python3" 456 | }, 457 | "language_info": { 458 | "codemirror_mode": { 459 | "name": "ipython", 460 | "version": 3 461 | }, 462 | "file_extension": ".py", 463 | "mimetype": "text/x-python", 464 | "name": "python", 465 | "nbconvert_exporter": "python", 466 | "pygments_lexer": "ipython3", 467 | "version": "3.7.11" 468 | } 469 | }, 470 | "nbformat": 4, 471 | "nbformat_minor": 5 472 | } 473 | -------------------------------------------------------------------------------- /codebase/06.03_train_CNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "09726bcc-9410-4397-8306-6e4eff4b9b5e", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import numpy as np\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import pandas as pd\n", 14 | "import seaborn as sns\n", 15 | "import torch\n", 16 | "import torch.nn as nn\n", 17 | "from torch.utils.data import DataLoader\n", 18 | "from torchvision import transforms, datasets\n", 19 | "import torch.nn.functional as F\n", 20 | "from sklearn.metrics import confusion_matrix\n", 21 | "from tqdm import tqdm" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "445fafab-6fa0-4279-b9de-1dcf0483be2a", 27 | "metadata": {}, 28 | "source": [ 29 | "## Download dataset" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "60f60c74-9ca3-4514-a42f-b50389b70e27", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "ROOT_DATA_DIR = \"FashionMNISTDir\"\n", 40 | "\n", 41 | "train_data = datasets.FashionMNIST(\n", 42 | " root = ROOT_DATA_DIR,\n", 43 | " train = True,\n", 44 | " download = True,\n", 45 | " transform = transforms.ToTensor()\n", 46 | " )\n", 47 | "\n", 48 | "\n", 49 | "test_data = datasets.FashionMNIST(\n", 50 | " root = ROOT_DATA_DIR,\n", 51 | " train = False, ## <<< Test data\n", 52 | " download = True,\n", 53 | " transform = transforms.ToTensor()\n", 54 | " )" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "id": "c2f04852-f194-4812-a6fe-8557fb145ee1", 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "torch.Size([60000, 28, 28])" 67 | ] 68 | }, 69 | "execution_count": 3, 70 | "metadata": {}, 71 | "output_type": "execute_result" 72 | } 73 | ], 74 | "source": [ 
75 | "train_data.data.shape" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 4, 81 | "id": "adf58954-e425-47a6-97de-3674a2e209da", 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "torch.Size([10000, 28, 28])" 88 | ] 89 | }, 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "output_type": "execute_result" 93 | } 94 | ], 95 | "source": [ 96 | "test_data.data.shape" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "id": "50d627a6-55c8-445f-86ce-2f2202bcdf3e", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "label_map = {\n", 107 | " 0: 'T-shirt/top',\n", 108 | " 1: 'Trouser',\n", 109 | " 2: 'Pullover',\n", 110 | " 3:' Dress',\n", 111 | " 4: 'Coat',\n", 112 | " 5: 'Sandal',\n", 113 | " 6: 'Shirt',\n", 114 | " 7: 'Sneaker',\n", 115 | " 8: 'Bag',\n", 116 | " 9: 'Ankle boot',\n", 117 | " }" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "a5ee85fb-9e21-4b3c-a5c6-2a088e80f2e7", 123 | "metadata": {}, 124 | "source": [ 125 | "## Visualize one sample" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 6, 131 | "id": "1cc1b28d-a533-4566-81ae-a3f37471bcfd", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "def view_sample_img(data, index, label_map):\n", 136 | " plt.imshow(data.data[index], cmap=\"gray\")\n", 137 | " plt.title(f\"data label: {label_map[data.targets[index].item()]}\")\n", 138 | " plt.axis(\"off\")" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 7, 144 | "id": "76adb871-bff6-4292-b94c-94e47e8f07b1", 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "data": { 149 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAAD3CAYAAADmIkO7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAPRUlEQVR4nO3dfYxc5XXH8d+xvWt718FrTAHbrbFqwJZrAQXUmigukZuCkga1hDSmJUj9wxLhD6xWoUWt1BZVcZMKKiWVFZWqEm/uS2RSqBoRUJFqlYogp7Ud4VBiagm8YHDtrdfvxmt4+sdcK8N27znLXo/3zOT7kVbs+sxz7527+9s7O4fnuVZKEYB8Zkz3AQCYGOEEkiKcQFKEE0iKcAJJEU4gqZ4Pp5k9ZmZfybYvM9tmZhumuJ8pj53i/h40sy1O/Ydm9skLdTw/KXo+nB/Fhf6h7zQz+yszO159nDGzsbavv3u+9lNK+blSyjbnOCYMt5n1m9khM5vXa+f+fCCcPayU8qVSyrxSyjxJfybpW+e+LqV8+kIcg5nNcsq/JGlXKeX4hTiWbtNz4TSznzezHWZ2zMy+JWlOW22BmX3HzA6a2eHq85+uapskrZW0ubqybK7+/RtmNmxmR83sP81s7SSPo3ZfbZab2fZq2/9kZhe3jV9jZi+Z2aiZ/eBCvGw0swfM7O3q3P3IzH65rdxvZk9UtR+a2Y1t494ws09Vnz9oZk+Z2RYzOyrpS5L+UNL66rz+oG2bn5H0rHPuP25m3zezI9V/P962z21m9tW689cTSik98yGpX9Kbkn5XUp+kz0sak/SVqr5Q0h2SBiR9TNJWSc+0jd8macO4bX6xGjdL0pclvStpTs3+H/uI+3pb0mpJg5K+LWlLVVsiaUStH94Zkn6l+vqnxh+npKWSRiUtDc7Ng+e2X1NfIWlY0uLq62WSlreNPV0dz0xJX5X0ctvYNyR9qu2xY5J+vTr2uXX7lvSapBUTnXtJF0s6LOnu6tz/ZvX1wuj89cpHr10516gVyq+XUsZKKU9J+v65YillpJTy7VLKyVLKMUmbJN3sbbCUsqUad7aU8heSZqv1g+ya5L6eLKXsLqWckPRHkr5gZjPV+oXwbCnl2VLKB6WUf5H0H2qFY/x+9pVShkop+6JjCrxfPbdVZtZXSnmjlLK3rf7v1fG8L+lJSdc62/peKeWZ6thPTfQAM1suaVYp5Uc12/hVSa+XUp6szv3fqxXm29oeU3f+ekKvhXOxpLdL9au18ua5T8xswMweMbM3q5dc/yZpyPuGmtn9ZvZf1UurUUnzJV0SHcgk9zU87jj7qm1fIek3qpe0o9V+PyFpUbTfyTKz77a9OXRXKeW/Jf2OWle5/zGzfzCzxW1D3m37/KSkOc7fk8M1/97uM5K8N6UWq+17V3lTrVcVE+2n/fz1hF4L5zuSlpiZtf3b0rbPv6zWVe8XSykXqfWGhCSde/yHpuhUf1/+vqQvSFpQShmSdKTt8Z5oX5L0M+OOc0zSIbV+6J6srojnPgZLKV+bxH4npZTy6fLjN4f+tvq3vyulfEKtXw5F0p9PdfPB11L196bzmP3VcbRbqtZL2XPqzl9P6LVwfk/SWUkbzazPzD4n6Rfa6h+TdErSaPXmwZ+MG39A0s+Oe/xZSQclzTKzP5Z00SSPJdqXJH3RzFaZ2YCkP5X0VPWycYuk28zsVjObaWZzzOyTE7yhdN6Y2QozW2dms9X6+/KUpA/O0+YPSFpmZjOqfQ2o9X3513GPaT/3z0q62sx+y8xmmdl6SaskfaftMXXnryf0VDhLKWckfU7Sb0v6X0nrJf1j20O+
rtYbFIckvSzpuXGb+Iakz1fvrv6lpOerx+xR62XTaU3uJdtk9iW1/nZ7TNWbTJI2Vs9jWNKvqfUu58Fqn7+nCb5fZra0emm6dHztI5ot6WvV8b4r6VJJf9Bwm+dsrf47YmY7JK1T6+/S022P+dC5L6WMSPqsWq9ARtR6BfPZUkr7lXHC89cr7MN/ngGdZ2bflLS7lPLNBtvYpta7s39z3g4sGa9BDHTKLkn/PN0HkR3hxAVXSvnr6T6GbsDLWiCpnnpDCOgl7staM0t7Wf1wK/P/m85XBCtXrnTrmzdvrq1t3bq1tiZJO3fudOtnzpxx62NjY2599erVtbXbb7/dHbt37163/tBDD7n10dFRt96rSikT/jBz5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpNz/Q6iTfc7p7FNed911bv3OO+9063fccYdbf/99f9bS4OBgbW3u3Lnu2IULF7r1TtqzZ49b/+ADf4bZihX+AhIHDhyorT3//PPu2Icfftit7969261PJ/qcQJchnEBShBNIinACSRFOICnCCSRFOIGkpq3P2dRFF/krVD7xxBO1tWuuucYdO2OG/zvr2LFjbv306dNu3ZtTGfVI+/r63Pr8+fPd+okTJ9y616vs9BzZOXPm1Nai/m9/f79bf/HFF9363Xff7dY7iT4n0GUIJ5AU4QSSIpxAUoQTSIpwAkl1bSvlhRdecOtXXDH+7nE/NjIy4o6Npj7NmuUvlH/27Fm3Hk2X80RtnmhpzJkzp35v2WjfndR0iuGiRf6tTW+99Va3/tprr7n1JmilAF2GcAJJEU4gKcIJJEU4gaQIJ5AU4QSSSntn6xtuuMGte31MSTp06FBtLepTRr1Ab2qTJC1ZssStDwwM1NaiXmJ0C7/ouUVT0rx+YjRdLervRlPt3nrrrSlvOxI97w0bNrj1+++/v9H+p4IrJ5AU4QSSIpxAUoQTSIpwAkkRTiApwgkklXY+Z9RX2rhxo1v3+pzRfM2ozxn1zB555BG3vn///tqa1+uTpMWLF7v1d955x603mQ86e/Zsd+y8efPc+vXXX+/W77vvvtqa9/2U4v5utJRqNH7ZsmVuvQnmcwJdhnACSRFOICnCCSRFOIGkCCeQFOEEkkrb53z55Zfd+qWXXurWvbmD0dquUb/uyJEjbn3NmjVu/ZZbbqmtRXNBH330Ubd+zz33uPXdu3e7de9We1H/98CBA259165dbv3111+vrUVzQaM5ttF80JUrV7r11atX19b27Nnjjo3Q5wS6DOEEkiKcQFKEE0iKcAJJEU4gqbRLY1577bVufXh42K17U6OiqU+RaPpR5LnnnqutnThxwh27atUqtx5NtXv66afd+m233VZbi6ZV7dixw61Hy5167Y7BwUF3bDSNL5omuG/fPrd+00031daatlLqcOUEkiKcQFKEE0iKcAJJEU4gKcIJJEU4gaSmrc/pTcGRpIMHD7r1aAqQN73Ju82d5E+bkqSRkRG3HvGe+3vvveeOXbRokVvftGmTW4+eu3eLwWis1wucDG/J0GgqXdM+56lTp9z62rVra2uPP/64O3aquHICSRFOICnCCSRFOIGkCCeQFOEEkiKcQFLT1ud84IEH3HrUazx+/Lhb9/pe0bZPnz7t1qMe64033ujWFy5cWFu7+OKL3bF9fX1u/bLLLnPrXh9T8p97f3+/O3ZoaMitr1+/3q0vWLCgthb1IefPn+/Wo/HRc4u+p53AlRNIinACSRFOICnCCSRFOIGkCCeQFOEEkpq2PudLL73k1i+//HK3fuWVV7p1b23ZaA1U71Z0Ujx3MLp9oTe3MJp3GO07uk1ftPasN2cz2re3VrAU38bPW/91YGDAHRs97+jYvLmkkvTMM8+49U7gygkkRTiBpAgnkBThBJIinEBShBNIinACSVkppb5oVl+cZt7cP0m66qqramv33nuvO/bmm29269G9QaO5haOjo7W1aL5m1M/rpGjd2qiXGM2T9c7bK6+84o6966673HpmpZQJTyxXTiApwgkkRTiBpAgnkBThBJIinEBS0zZlrKnDhw+79e3bt9fWotvsrVu3zq177ScpXmbRm7IWtUqiKWWRqB3i1aN9z549262fOXPGrc+ZM6e2Fk0x7EVcOYGkCCeQFOEEkiKcQFKEE0iKcAJJEU4gqbR9zqgfF02t8npqUZ/y6NGjbj3qRUZLSEb790Tnpcm2O63JdDdvmt352HfUw52O88qVE0iKcAJJEU4gKcIJJEU4gaQIJ5AU4QSSStvnjPpKY2NjU9723r173XrU54xuoxfNW/REz7vTfc5o+57oeUe9aU/0PYlEy3ZGvenpwJUTSIpwAkkRTiApwgkkRTiBpAgnkBThBJJK2+eMNOlbnTp1yh0b9eui9VnPnj3r1r0+adM+ZpN1aSX/vEb7jtYDHhgYcOvesUXntBdx5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpLq2z9lk3mK0RmnTdWejetSj9UTH3mRtWMnvNUbHHT3v6Nib9FgjmdfzrcOVE0iKcAJJEU4gKcIJJEU4gaQIJ5BU17ZSOmnJkiVu/fDhw249amd4b+tH7YomS1d2WnTs0XKm3nNr2iLqRlw5gaQIJ5AU4QSSIpxAUoQTSIpwAkkRTiCpru1zdnIKUNNlGPv7+926NyWt6dKWnVxaM5ryFd3iL1o60zu2JrcPjLadFVdOICnCCSRFOIGkCCeQFOEEkiKcQFKEE0iqa/ucnRT146K5hVGf1Bsf9RKjfl10bNHtDb3te7cujMZK0smTJ926Z2hoaMpjuxVXTiApwgkkRTiBpAgnkBThBJIinEBShBNIij7nBKJeY1PenMmm8w47ue5tk7mgkxnv9Yfnzp3rjo0wnxPAeUM4gaQIJ5AU4QSSIpxAUoQTSIpWygSidkRTnXxbfzpbKdG+m7RSBgYG3LG9iCsnkBThBJIinEBShBNIinACSRFOICnCCSTVtX3O6ZwCFC0/2UTTaVmRJsfe6els3q0RO3nOs+LKCSRFOIGkCCeQFOEEkiKcQFKEE0iKcAJJdW2fs+kyjJ7oNnmdnFsYLcvZ9PaDnTxvTXWyz8nSmADOG8IJJEU4gaQIJ5AU4QSSIpxAUoQTSKpr+5zTqcm8RMnvNUbbblqP+qjTua6th/mcANIgnEBShBNIinACSRFOICnCCSRFOIGkurbP2cn5efv373frV199tVuP5lR6vcaoD9nX1zflbU+m7p3XqH87a1azHydv38znBJAG4QSSIpxAUoQTSIpwAkkRTiCprm2ldNLQ0JBbHxwcdOtRS+GSSy6prTWdEha1WpqIWilRu2N4eNite0uOLl++3B0baTqVbjpw5QSSIpxAUoQTSIpwAkkRTiApwgkkRTiBpLq2z9nJW9nt3LnTrb/66qtufXR01K036UVG/brjx4+79ei8eOe1yVQ4Kb614oIFC2pr27dvd8dGMvYxI1w5gaQIJ5AU4QSSIpxAUoQTSIpwAkkRTiAp68YlA4GfBFw5gaQIJ5AU4QSSIpxAUoQTSIpwAkn9H5vkccL
t/ncCAAAAAElFTkSuQmCC", 150 | "text/plain": [ 151 | "
" 152 | ] 153 | }, 154 | "metadata": { 155 | "needs_background": "light" 156 | }, 157 | "output_type": "display_data" 158 | } 159 | ], 160 | "source": [ 161 | "view_sample_img(train_data, index=1, label_map=label_map)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "id": "13642d17-c534-40fb-86e7-e5234a6127d0", 167 | "metadata": {}, 168 | "source": [ 169 | "## Create the dataloader" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 8, 175 | "id": "4b7a8c70-14d0-4a5f-9700-0fd317929074", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "BATCH_SIZE = 64\n", 180 | "\n", 181 | "train_data_loader = DataLoader(\n", 182 | " dataset = train_data,\n", 183 | " batch_size = BATCH_SIZE,\n", 184 | " shuffle = True\n", 185 | " )\n", 186 | "\n", 187 | "test_data_loader = DataLoader(\n", 188 | " dataset = test_data,\n", 189 | " batch_size = BATCH_SIZE,\n", 190 | " shuffle = True\n", 191 | " )" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 9, 197 | "id": "c216a67f-9aab-4013-a270-5dbc8cbd8ff5", 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "torch.Size([64, 1, 28, 28])\n", 205 | "torch.Size([64])\n" 206 | ] 207 | } 208 | ], 209 | "source": [ 210 | "for data, label in test_data_loader:\n", 211 | " print(data.shape) \n", 212 | " print(label.shape)\n", 213 | " break" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "id": "697498b5-fbd2-47e1-8823-c983cc9c85c2", 219 | "metadata": {}, 220 | "source": [ 221 | "## CNN architecture\n", 222 | "\n", 223 | "pytorch doc - [reference](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html?highlight=conv2d#torch.nn.Conv2d)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 10, 229 | "id": "a837c0ec-c0ef-4bc9-96e4-e1cbc52a44d2", 230 | "metadata": {}, 231 | "outputs": [ 232 | { 233 | "data": { 234 | "text/plain": [ 235 | "'cuda'" 236 | ] 237 | }, 238 | "execution_count": 10, 239 | "metadata": {}, 240 | "output_type": "execute_result" 241 | } 242 | ], 243 | "source": [ 244 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 245 | "device" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 11, 251 | "id": "94510485-d718-4912-95c9-1f06f22257ad", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "class CNN(nn.Module):\n", 256 | " def __init__(self, in_, out_):\n", 257 | " super(CNN, self).__init__()\n", 258 | " \n", 259 | " self.conv_pool_01 = nn.Sequential(\n", 260 | " nn.Conv2d(in_channels=in_, out_channels=8, kernel_size=5, stride=1, padding=0),\n", 261 | " nn.ReLU(),\n", 262 | " nn.MaxPool2d(kernel_size=2, stride=2)\n", 263 | " )\n", 264 | " \n", 265 | " self.conv_pool_02 = nn.Sequential(\n", 266 | " nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=0),\n", 267 | " nn.ReLU(),\n", 268 | " nn.MaxPool2d(kernel_size=2, stride=2)\n", 269 | " )\n", 270 | " \n", 271 | " self.Flatten = nn.Flatten()\n", 272 | " self.FC_01 = nn.Linear(in_features=16*4*4, out_features=128)\n", 273 | " self.FC_02 = nn.Linear(in_features=128, out_features=64)\n", 274 | " self.FC_03 = nn.Linear(in_features=64, out_features=out_)\n", 275 | " \n", 276 | " \n", 277 | " def forward(self, x):\n", 278 | " x = self.conv_pool_01(x)\n", 279 | " x = self.conv_pool_02(x)\n", 280 | " x = self.Flatten(x)\n", 281 | " x = self.FC_01(x)\n", 282 | " x = self.FC_02(x) \n", 283 | " x = self.FC_03(x)\n", 284 | " return x" 285 | ] 
286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 12, 290 | "id": "62ee8c3b-01e2-4c06-b3a0-25ababc457e6", 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "data": { 295 | "text/plain": [ 296 | "CNN(\n", 297 | " (conv_pool_01): Sequential(\n", 298 | " (0): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))\n", 299 | " (1): ReLU()\n", 300 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 301 | " )\n", 302 | " (conv_pool_02): Sequential(\n", 303 | " (0): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1))\n", 304 | " (1): ReLU()\n", 305 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 306 | " )\n", 307 | " (Flatten): Flatten(start_dim=1, end_dim=-1)\n", 308 | " (FC_01): Linear(in_features=256, out_features=128, bias=True)\n", 309 | " (FC_02): Linear(in_features=128, out_features=64, bias=True)\n", 310 | " (FC_03): Linear(in_features=64, out_features=10, bias=True)\n", 311 | ")" 312 | ] 313 | }, 314 | "execution_count": 12, 315 | "metadata": {}, 316 | "output_type": "execute_result" 317 | } 318 | ], 319 | "source": [ 320 | "model = CNN(1, 10)\n", 321 | "model.to(device)" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "id": "1ce4336d-8d13-4a8d-a586-393229fe248d", 327 | "metadata": {}, 328 | "source": [ 329 | "## Count no. of trainable params" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 13, 335 | "id": "81b2554a-c504-442f-9363-7a6b51ed1364", 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "data": { 340 | "text/html": [ 341 | "\n", 343 | "\n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | "
Total trainable parameters: 45226

     Modules                  Parameters
  0  conv_pool_01.0.weight           200
  1  conv_pool_01.0.bias               8
  2  conv_pool_02.0.weight          3200
  3  conv_pool_02.0.bias              16
  4  FC_01.weight                  32768
  5  FC_01.bias                      128
  6  FC_02.weight                   8192
  7  FC_02.bias                       64
  8  FC_03.weight                    640
  9  FC_03.bias                       10
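Note that count_params skips anything with requires_grad=False, so this table only ever lists trainable weights; here every parameter is trainable. An illustrative sketch of what that filter means in practice (an assumption for demonstration, not part of the notebook): freezing the first conv block drops its 208 parameters from the count.

for p in model.conv_pool_01.parameters():
    p.requires_grad = False          # freeze the first conv block
# count_params(model) would now report 45226 - 208 = 45018 trainable parameters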
\n" 405 | ], 406 | "text/plain": [ 407 | "" 408 | ] 409 | }, 410 | "execution_count": 13, 411 | "metadata": {}, 412 | "output_type": "execute_result" 413 | } 414 | ], 415 | "source": [ 416 | "def count_params(model):\n", 417 | " model_params = {\"Modules\": list(), \"Parameters\": list()}\n", 418 | " total = 0\n", 419 | " for name, parameters in model.named_parameters():\n", 420 | " if not parameters.requires_grad:\n", 421 | " continue\n", 422 | " param = parameters.numel()\n", 423 | " model_params[\"Modules\"].append(name)\n", 424 | " model_params[\"Parameters\"].append(param)\n", 425 | " total += param\n", 426 | " df = pd.DataFrame(model_params)\n", 427 | " df = df.style.set_caption(f\"Total trainable parameters: {total}\")\n", 428 | " return df\n", 429 | "\n", 430 | "count_params(model)" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 14, 436 | "id": "f9c79c55-2bc0-416b-adfd-c5f6d05ad02d", 437 | "metadata": {}, 438 | "outputs": [ 439 | { 440 | "data": { 441 | "text/plain": [ 442 | "True" 443 | ] 444 | }, 445 | "execution_count": 14, 446 | "metadata": {}, 447 | "output_type": "execute_result" 448 | } 449 | ], 450 | "source": [ 451 | "next(model.parameters()).is_cuda" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "id": "27898820-48dc-4d82-a5b3-33d431baf79e", 457 | "metadata": {}, 458 | "source": [ 459 | "## Traning loop" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": 15, 465 | "id": "4a5d7fad-e433-49f1-a06a-fe48fb7684fd", 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "learning_rate = 0.001\n", 470 | "num_epochs = 20" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": 16, 476 | "id": "9124fc31-1b9d-4a16-995d-9d2dd178238f", 477 | "metadata": {}, 478 | "outputs": [], 479 | "source": [ 480 | "criterion = nn.CrossEntropyLoss()\n", 481 | "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)" 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": 17, 487 | "id": "ee45f38a-da2b-4835-9507-3f368198eb41", 488 | "metadata": {}, 489 | "outputs": [ 490 | { 491 | "data": { 492 | "text/plain": [ 493 | "938" 494 | ] 495 | }, 496 | "execution_count": 17, 497 | "metadata": {}, 498 | "output_type": "execute_result" 499 | } 500 | ], 501 | "source": [ 502 | "n_total_steps = len(train_data_loader)\n", 503 | "n_total_steps" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": 18, 509 | "id": "9b3fe689-eaba-4c86-b221-48cc89428248", 510 | "metadata": {}, 511 | "outputs": [ 512 | { 513 | "data": { 514 | "text/plain": [ 515 | "937.5" 516 | ] 517 | }, 518 | "execution_count": 18, 519 | "metadata": {}, 520 | "output_type": "execute_result" 521 | } 522 | ], 523 | "source": [ 524 | "60000/BATCH_SIZE" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": 19, 530 | "id": "5860758d-95b1-42c2-a285-544254fb9936", 531 | "metadata": {}, 532 | "outputs": [ 533 | { 534 | "name": "stderr", 535 | "output_type": "stream", 536 | "text": [ 537 | "Epoch 1/20: 100%|██████████| 938/938 [00:20<00:00, 46.28it/s, loss=0.405]\n", 538 | "Epoch 2/20: 100%|██████████| 938/938 [00:14<00:00, 62.66it/s, loss=0.461]\n", 539 | "Epoch 3/20: 100%|██████████| 938/938 [00:15<00:00, 58.97it/s, loss=0.279]\n", 540 | "Epoch 4/20: 100%|██████████| 938/938 [00:15<00:00, 59.80it/s, loss=0.144] \n", 541 | "Epoch 5/20: 100%|██████████| 938/938 [00:15<00:00, 61.09it/s, loss=0.336]\n", 542 | "Epoch 6/20: 100%|██████████| 938/938 [00:15<00:00, 60.48it/s, 
loss=0.443] \n", 543 | "Epoch 7/20: 100%|██████████| 938/938 [00:15<00:00, 60.35it/s, loss=0.334] \n", 544 | "Epoch 8/20: 100%|██████████| 938/938 [00:15<00:00, 60.15it/s, loss=0.413] \n", 545 | "Epoch 9/20: 100%|██████████| 938/938 [00:15<00:00, 59.85it/s, loss=0.209] \n", 546 | "Epoch 10/20: 100%|██████████| 938/938 [00:15<00:00, 60.09it/s, loss=0.228] \n", 547 | "Epoch 11/20: 100%|██████████| 938/938 [00:15<00:00, 58.82it/s, loss=0.212] \n", 548 | "Epoch 12/20: 100%|██████████| 938/938 [00:15<00:00, 59.74it/s, loss=0.0203]\n", 549 | "Epoch 13/20: 100%|██████████| 938/938 [00:15<00:00, 59.96it/s, loss=0.437] \n", 550 | "Epoch 14/20: 100%|██████████| 938/938 [00:15<00:00, 60.11it/s, loss=0.38] \n", 551 | "Epoch 15/20: 100%|██████████| 938/938 [00:15<00:00, 60.20it/s, loss=0.292] \n", 552 | "Epoch 16/20: 100%|██████████| 938/938 [00:15<00:00, 59.15it/s, loss=0.177] \n", 553 | "Epoch 17/20: 100%|██████████| 938/938 [00:16<00:00, 58.12it/s, loss=0.381] \n", 554 | "Epoch 18/20: 100%|██████████| 938/938 [00:16<00:00, 57.91it/s, loss=0.173] \n", 555 | "Epoch 19/20: 100%|██████████| 938/938 [00:16<00:00, 57.57it/s, loss=0.276] \n", 556 | "Epoch 20/20: 100%|██████████| 938/938 [00:16<00:00, 57.84it/s, loss=0.233] \n" 557 | ] 558 | } 559 | ], 560 | "source": [ 561 | "for epoch in range(num_epochs):\n", 562 | " with tqdm(train_data_loader) as tqdm_epoch:\n", 563 | " for images, labels in tqdm_epoch:\n", 564 | " tqdm_epoch.set_description(f\"Epoch {epoch + 1}/{num_epochs}\")\n", 565 | " \n", 566 | " images = images.to(device)\n", 567 | " labels = labels.to(device) \n", 568 | " \n", 569 | " # forward pass\n", 570 | " outputs = model(images)\n", 571 | " loss = criterion(outputs, labels)\n", 572 | " \n", 573 | " # backward prop\n", 574 | " optimizer.zero_grad()\n", 575 | " loss.backward()\n", 576 | " optimizer.step()\n", 577 | " tqdm_epoch.set_postfix(loss=loss.item())\n", 578 | "\n", 579 | " " 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": 20, 585 | "id": "d79162ba", 586 | "metadata": {}, 587 | "outputs": [ 588 | { 589 | "data": { 590 | "text/plain": [ 591 | "'d:\\\\oneNeuron\\\\Pytorch\\\\Pytorch-basics\\\\codebase'" 592 | ] 593 | }, 594 | "execution_count": 20, 595 | "metadata": {}, 596 | "output_type": "execute_result" 597 | } 598 | ], 599 | "source": [ 600 | "os.getcwd()" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 21, 606 | "id": "226b3c19", 607 | "metadata": {}, 608 | "outputs": [], 609 | "source": [ 610 | "## save trained model -\n", 611 | "os.makedirs(\"06_03_session_dir\", exist_ok=True)\n", 612 | "modle_file = os.path.join(\"06_03_session_dir\", 'CNN_model.pth')\n", 613 | "torch.save(model, modle_file)" 614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": null, 619 | "id": "1a640fbb", 620 | "metadata": {}, 621 | "outputs": [], 622 | "source": [] 623 | } 624 | ], 625 | "metadata": { 626 | "kernelspec": { 627 | "display_name": "Python 3 (ipykernel)", 628 | "language": "python", 629 | "name": "python3" 630 | }, 631 | "language_info": { 632 | "codemirror_mode": { 633 | "name": "ipython", 634 | "version": 3 635 | }, 636 | "file_extension": ".py", 637 | "mimetype": "text/x-python", 638 | "name": "python", 639 | "nbconvert_exporter": "python", 640 | "pygments_lexer": "ipython3", 641 | "version": "3.7.11" 642 | } 643 | }, 644 | "nbformat": 4, 645 | "nbformat_minor": 5 646 | } 647 | -------------------------------------------------------------------------------- /codebase/06_03_session_dir/CNN_model.pth: 
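One portability caveat on the torch.save(model, ...) call at the end of this notebook: it pickles the entire module object, so a later torch.load needs the CNN class importable under the same name. Saving only the state_dict is the more robust idiom. The sketch below is an alternative, not what the notebook does, and the filename CNN_state.pth is hypothetical.

import torch

torch.save(model.state_dict(), "06_03_session_dir/CNN_state.pth")   # weights only

restored = CNN(1, 10)                                               # rebuild the architecture first
restored.load_state_dict(torch.load("06_03_session_dir/CNN_state.pth"))
restored.eval()                                                     # switch to inference mode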
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/06_03_session_dir/CNN_model.pth -------------------------------------------------------------------------------- /codebase/07.01_Transfer_learning_Download_data_create_dataloader.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "id": "9550bd5a-34b5-415d-b6aa-bf2ca24ba3bd", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import torch \n", 12 | "import torch.nn as nn\n", 13 | "import torch.nn.functional as F\n", 14 | "from torchvision import models, datasets, transforms\n", 15 | "from torch.utils.data import DataLoader\n", 16 | "\n", 17 | "import pandas as pd\n", 18 | "import numpy as np \n", 19 | "import matplotlib.pyplot as plt \n", 20 | "import seaborn as sns\n", 21 | "from sklearn.metrics import confusion_matrix\n", 22 | "from tqdm import tqdm\n", 23 | "import urllib.request as req\n", 24 | "plt.style.use('fivethirtyeight')" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "cd966087-46e9-4aaf-8b06-ad6e7d6d470f", 30 | "metadata": {}, 31 | "source": [ 32 | "## Download data" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 6, 38 | "id": "97bf46b5-f3fd-4916-bd8a-2b5d880ec625", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "data_URL = \"https://download.pytorch.org/tutorial/hymenoptera_data.zip\"" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 7, 48 | "id": "83c10a11-2baf-4241-9794-c87208645e5f", 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "hymenoptera_data directory created\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "# create a directory\n", 61 | "def create_dirs(dir_path):\n", 62 | " os.makedirs(dir_path, exist_ok=True)\n", 63 | " print(f\"{dir_path} directory created\")\n", 64 | " \n", 65 | "ROOT_DATA_DIR = \"hymenoptera_data\"\n", 66 | "create_dirs(ROOT_DATA_DIR)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 9, 72 | "id": "c01b0dea-6ff1-4094-a0b8-058ded4c29f1", 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | "downloading data...\n", 80 | "filename: hymenoptera_data\\data.zip created with info \n", 81 | "Content-Type: application/zip\n", 82 | "Content-Length: 47286322\n", 83 | "Connection: close\n", 84 | "Date: Fri, 11 Mar 2022 11:26:26 GMT\n", 85 | "Last-Modified: Wed, 15 Mar 2017 18:46:00 GMT\n", 86 | "ETag: \"5f8c32a6554f6acb4d649776e7735e48\"\n", 87 | "x-amz-version-id: null\n", 88 | "Accept-Ranges: bytes\n", 89 | "Server: AmazonS3\n", 90 | "X-Cache: Miss from cloudfront\n", 91 | "Via: 1.1 ecfda1b7359bd66eb2625616364a7174.cloudfront.net (CloudFront)\n", 92 | "X-Amz-Cf-Pop: BLR50-C1\n", 93 | "X-Amz-Cf-Id: lKA-dKlvyudX3FR72PKWIKQpcq0yjmmbjpebl7WM_IThzeU7NRM-Hg==\n", 94 | "\n", 95 | "\n" 96 | ] 97 | } 98 | ], 99 | "source": [ 100 | "data_zip_file = \"data.zip\"\n", 101 | "data_zip_path = os.path.join(ROOT_DATA_DIR, data_zip_file)\n", 102 | "\n", 103 | "if not os.path.isfile(data_zip_file):\n", 104 | " print(\"downloading data...\")\n", 105 | " filename, headers = req.urlretrieve(data_URL, data_zip_path)\n", 106 | " print(f\"filename: {filename} created with info \\n{headers}\")\n", 107 | "else:\n", 
108 | " print(f\"file is already present\")" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "id": "86e20203-3dc9-4009-8ba0-e0a30fbaf28e", 114 | "metadata": {}, 115 | "source": [ 116 | "## Unzip data" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 15, 122 | "id": "47ad877f-1263-4ed4-9729-dba5b8484ccc", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "from zipfile import ZipFile\n", 127 | "\n", 128 | "unzip_data_dirname = \"unzip_data_dir\"\n", 129 | "unzip_data_dir = os.path.join(ROOT_DATA_DIR, unzip_data_dirname)\n", 130 | "\n", 131 | "if not os.path.exists(unzip_data_dir):\n", 132 | " os.makedirs(unzip_data_dir, exist_ok=True)\n", 133 | " with ZipFile(data_zip_path) as f:\n", 134 | " f.extractall(unzip_data_dir)\n", 135 | "else:\n", 136 | " print(f\"data already extacted\")" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "8667fdf7-3c6a-4676-af54-e51d94e23b58", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "e887bdd9-9e0d-4f60-be98-cf47c0b60185", 150 | "metadata": {}, 151 | "source": [ 152 | "## Create data loaders" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 24, 158 | "id": "6400e8d4-3926-4a7c-9670-8b35a52aca74", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "from pathlib import Path" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 25, 168 | "id": "a7ec05a5-8c31-46ba-9a46-95bbd50dc666", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "train_path = Path(\"hymenoptera_data/unzip_data_dir/hymenoptera_data/train\")\n", 173 | "test_path = Path(\"hymenoptera_data/unzip_data_dir/hymenoptera_data/val\")" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 26, 179 | "id": "b68eff04-125e-43bf-94b3-616cb374aa9d", 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "img_size = (224, 224)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 28, 189 | "id": "177183a7-df46-4d1e-a2d7-43a90a1fbd9d", 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "mean = torch.tensor([0.5, 0.5, 0.5])\n", 194 | "std = torch.tensor([0.5, 0.5, 0.5])" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "id": "a76fcc8e-d435-4106-89b1-04bd5530a7fe", 200 | "metadata": {}, 201 | "source": [ 202 | "### Transformations" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 30, 208 | "id": "1b05a074-50b4-4cde-a9c9-d44706bf8718", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "train_transforms = transforms.Compose([\n", 213 | " transforms.Resize(img_size),\n", 214 | " transforms.RandomRotation(degrees=20),\n", 215 | " transforms.ToTensor(),\n", 216 | " transforms.Normalize(mean, std)\n", 217 | "])\n", 218 | "\n", 219 | "test_transforms = transforms.Compose([\n", 220 | " transforms.Resize(img_size),\n", 221 | " transforms.ToTensor(),\n", 222 | " transforms.Normalize(mean, std)\n", 223 | "])" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 32, 229 | "id": "2d04a935-e8b7-44bd-845c-3880e3a4edbd", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "train_data = datasets.ImageFolder(root=train_path, transform=train_transforms)\n", 234 | "test_data = datasets.ImageFolder(root=test_path, transform=test_transforms)" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | 
"execution_count": 33, 240 | "id": "0883fce9-2ded-4d1c-b81d-1bb6094c6f44", 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "data": { 245 | "text/plain": [ 246 | "{'ants': 0, 'bees': 1}" 247 | ] 248 | }, 249 | "execution_count": 33, 250 | "metadata": {}, 251 | "output_type": "execute_result" 252 | } 253 | ], 254 | "source": [ 255 | "train_data.class_to_idx" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 35, 261 | "id": "fa1c95eb-9b09-4fba-b48d-bf3b00ea95a4", 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "text/plain": [ 267 | "{'ants': 0, 'bees': 1}" 268 | ] 269 | }, 270 | "execution_count": 35, 271 | "metadata": {}, 272 | "output_type": "execute_result" 273 | } 274 | ], 275 | "source": [ 276 | "label_map = train_data.class_to_idx\n", 277 | "label_map" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 36, 283 | "id": "7f2b14d8-8ee1-4ae6-b2f1-63122daffd3c", 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "data": { 288 | "text/plain": [ 289 | "Dataset ImageFolder\n", 290 | " Number of datapoints: 244\n", 291 | " Root location: hymenoptera_data\\unzip_data_dir\\hymenoptera_data\\train\n", 292 | " StandardTransform\n", 293 | "Transform: Compose(\n", 294 | " Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)\n", 295 | " RandomRotation(degrees=[-20.0, 20.0], interpolation=nearest, expand=False, fill=0)\n", 296 | " ToTensor()\n", 297 | " Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))\n", 298 | " )" 299 | ] 300 | }, 301 | "execution_count": 36, 302 | "metadata": {}, 303 | "output_type": "execute_result" 304 | } 305 | ], 306 | "source": [ 307 | "train_data" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 37, 313 | "id": "98f06f1e-72a0-47c7-b0a5-0b8688cab929", 314 | "metadata": {}, 315 | "outputs": [], 316 | "source": [ 317 | "batch_size = 64\n", 318 | "\n", 319 | "train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)\n", 320 | "test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 38, 326 | "id": "e383d78c-1f28-4d5d-8a15-cce068d49258", 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "data = next(iter(train_loader))" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 39, 336 | "id": "aed223e0-715e-4dfe-8faf-ddca6b95b187", 337 | "metadata": {}, 338 | "outputs": [ 339 | { 340 | "data": { 341 | "text/plain": [ 342 | "2" 343 | ] 344 | }, 345 | "execution_count": 39, 346 | "metadata": {}, 347 | "output_type": "execute_result" 348 | } 349 | ], 350 | "source": [ 351 | "len(data)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 40, 357 | "id": "4a842d1a-e6b8-454e-a176-803831b169d7", 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "images, labels = data" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 41, 367 | "id": "67d0d7da-bacc-4732-a704-2b819a5d01ef", 368 | "metadata": {}, 369 | "outputs": [ 370 | { 371 | "data": { 372 | "text/plain": [ 373 | "torch.Size([64, 3, 224, 224])" 374 | ] 375 | }, 376 | "execution_count": 41, 377 | "metadata": {}, 378 | "output_type": "execute_result" 379 | } 380 | ], 381 | "source": [ 382 | "images.shape" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 42, 388 | "id": "2fcef7f8-4b52-470a-91dd-16c0c7722829", 389 | 
"metadata": {}, 390 | "outputs": [ 391 | { 392 | "data": { 393 | "text/plain": [ 394 | "torch.Size([64])" 395 | ] 396 | }, 397 | "execution_count": 42, 398 | "metadata": {}, 399 | "output_type": "execute_result" 400 | } 401 | ], 402 | "source": [ 403 | "labels.shape" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": null, 409 | "id": "2b97c6ac-4d9b-44e3-a768-f794a5762bd7", 410 | "metadata": {}, 411 | "outputs": [], 412 | "source": [] 413 | } 414 | ], 415 | "metadata": { 416 | "kernelspec": { 417 | "display_name": "Python 3 (ipykernel)", 418 | "language": "python", 419 | "name": "python3" 420 | }, 421 | "language_info": { 422 | "codemirror_mode": { 423 | "name": "ipython", 424 | "version": 3 425 | }, 426 | "file_extension": ".py", 427 | "mimetype": "text/x-python", 428 | "name": "python", 429 | "nbconvert_exporter": "python", 430 | "pygments_lexer": "ipython3", 431 | "version": "3.7.11" 432 | } 433 | }, 434 | "nbformat": 4, 435 | "nbformat_minor": 5 436 | } 437 | -------------------------------------------------------------------------------- /codebase/07.02_Transfer_learning_using_pretrained_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "id": "9550bd5a-34b5-415d-b6aa-bf2ca24ba3bd", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import torch \n", 12 | "import torch.nn as nn\n", 13 | "import torch.nn.functional as F\n", 14 | "from torchvision import models, datasets, transforms\n", 15 | "from torch.utils.data import DataLoader\n", 16 | "\n", 17 | "import pandas as pd\n", 18 | "import numpy as np \n", 19 | "import matplotlib.pyplot as plt \n", 20 | "import seaborn as sns\n", 21 | "from sklearn.metrics import confusion_matrix\n", 22 | "from tqdm import tqdm\n", 23 | "import urllib.request as req\n", 24 | "plt.style.use('fivethirtyeight')" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "cd966087-46e9-4aaf-8b06-ad6e7d6d470f", 30 | "metadata": {}, 31 | "source": [ 32 | "## Download data" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 6, 38 | "id": "97bf46b5-f3fd-4916-bd8a-2b5d880ec625", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "data_URL = \"https://download.pytorch.org/tutorial/hymenoptera_data.zip\"" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 7, 48 | "id": "83c10a11-2baf-4241-9794-c87208645e5f", 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "hymenoptera_data directory created\n" 56 | ] 57 | } 58 | ], 59 | "source": [ 60 | "# create a directory\n", 61 | "def create_dirs(dir_path):\n", 62 | " os.makedirs(dir_path, exist_ok=True)\n", 63 | " print(f\"{dir_path} directory created\")\n", 64 | " \n", 65 | "ROOT_DATA_DIR = \"hymenoptera_data\"\n", 66 | "create_dirs(ROOT_DATA_DIR)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 9, 72 | "id": "c01b0dea-6ff1-4094-a0b8-058ded4c29f1", 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | "downloading data...\n", 80 | "filename: hymenoptera_data\\data.zip created with info \n", 81 | "Content-Type: application/zip\n", 82 | "Content-Length: 47286322\n", 83 | "Connection: close\n", 84 | "Date: Fri, 11 Mar 2022 11:26:26 GMT\n", 85 | "Last-Modified: Wed, 15 Mar 2017 18:46:00 GMT\n", 86 | "ETag: 
\"5f8c32a6554f6acb4d649776e7735e48\"\n", 87 | "x-amz-version-id: null\n", 88 | "Accept-Ranges: bytes\n", 89 | "Server: AmazonS3\n", 90 | "X-Cache: Miss from cloudfront\n", 91 | "Via: 1.1 ecfda1b7359bd66eb2625616364a7174.cloudfront.net (CloudFront)\n", 92 | "X-Amz-Cf-Pop: BLR50-C1\n", 93 | "X-Amz-Cf-Id: lKA-dKlvyudX3FR72PKWIKQpcq0yjmmbjpebl7WM_IThzeU7NRM-Hg==\n", 94 | "\n", 95 | "\n" 96 | ] 97 | } 98 | ], 99 | "source": [ 100 | "data_zip_file = \"data.zip\"\n", 101 | "data_zip_path = os.path.join(ROOT_DATA_DIR, data_zip_file)\n", 102 | "\n", 103 | "if not os.path.isfile(data_zip_file):\n", 104 | " print(\"downloading data...\")\n", 105 | " filename, headers = req.urlretrieve(data_URL, data_zip_path)\n", 106 | " print(f\"filename: {filename} created with info \\n{headers}\")\n", 107 | "else:\n", 108 | " print(f\"file is already present\")" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "id": "86e20203-3dc9-4009-8ba0-e0a30fbaf28e", 114 | "metadata": {}, 115 | "source": [ 116 | "## Unzip data" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 15, 122 | "id": "47ad877f-1263-4ed4-9729-dba5b8484ccc", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "from zipfile import ZipFile\n", 127 | "\n", 128 | "unzip_data_dirname = \"unzip_data_dir\"\n", 129 | "unzip_data_dir = os.path.join(ROOT_DATA_DIR, unzip_data_dirname)\n", 130 | "\n", 131 | "if not os.path.exists(unzip_data_dir):\n", 132 | " os.makedirs(unzip_data_dir, exist_ok=True)\n", 133 | " with ZipFile(data_zip_path) as f:\n", 134 | " f.extractall(unzip_data_dir)\n", 135 | "else:\n", 136 | " print(f\"data already extacted\")" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "8667fdf7-3c6a-4676-af54-e51d94e23b58", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "id": "e887bdd9-9e0d-4f60-be98-cf47c0b60185", 150 | "metadata": {}, 151 | "source": [ 152 | "## Create data loaders" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 24, 158 | "id": "6400e8d4-3926-4a7c-9670-8b35a52aca74", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "from pathlib import Path" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 25, 168 | "id": "a7ec05a5-8c31-46ba-9a46-95bbd50dc666", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "train_path = Path(\"hymenoptera_data/unzip_data_dir/hymenoptera_data/train\")\n", 173 | "test_path = Path(\"hymenoptera_data/unzip_data_dir/hymenoptera_data/val\")" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 26, 179 | "id": "b68eff04-125e-43bf-94b3-616cb374aa9d", 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "img_size = (224, 224)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 28, 189 | "id": "177183a7-df46-4d1e-a2d7-43a90a1fbd9d", 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "mean = torch.tensor([0.5, 0.5, 0.5])\n", 194 | "std = torch.tensor([0.5, 0.5, 0.5])" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "id": "a76fcc8e-d435-4106-89b1-04bd5530a7fe", 200 | "metadata": {}, 201 | "source": [ 202 | "### Transformations" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 30, 208 | "id": "1b05a074-50b4-4cde-a9c9-d44706bf8718", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "train_transforms = transforms.Compose([\n", 213 | 
" transforms.Resize(img_size),\n", 214 | " transforms.RandomRotation(degrees=20),\n", 215 | " transforms.ToTensor(),\n", 216 | " transforms.Normalize(mean, std)\n", 217 | "])\n", 218 | "\n", 219 | "test_transforms = transforms.Compose([\n", 220 | " transforms.Resize(img_size),\n", 221 | " transforms.ToTensor(),\n", 222 | " transforms.Normalize(mean, std)\n", 223 | "])" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 32, 229 | "id": "2d04a935-e8b7-44bd-845c-3880e3a4edbd", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "train_data = datasets.ImageFolder(root=train_path, transform=train_transforms)\n", 234 | "test_data = datasets.ImageFolder(root=test_path, transform=test_transforms)" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 33, 240 | "id": "0883fce9-2ded-4d1c-b81d-1bb6094c6f44", 241 | "metadata": {}, 242 | "outputs": [ 243 | { 244 | "data": { 245 | "text/plain": [ 246 | "{'ants': 0, 'bees': 1}" 247 | ] 248 | }, 249 | "execution_count": 33, 250 | "metadata": {}, 251 | "output_type": "execute_result" 252 | } 253 | ], 254 | "source": [ 255 | "train_data.class_to_idx" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 35, 261 | "id": "fa1c95eb-9b09-4fba-b48d-bf3b00ea95a4", 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "text/plain": [ 267 | "{'ants': 0, 'bees': 1}" 268 | ] 269 | }, 270 | "execution_count": 35, 271 | "metadata": {}, 272 | "output_type": "execute_result" 273 | } 274 | ], 275 | "source": [ 276 | "label_map = train_data.class_to_idx\n", 277 | "label_map" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 36, 283 | "id": "7f2b14d8-8ee1-4ae6-b2f1-63122daffd3c", 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "data": { 288 | "text/plain": [ 289 | "Dataset ImageFolder\n", 290 | " Number of datapoints: 244\n", 291 | " Root location: hymenoptera_data\\unzip_data_dir\\hymenoptera_data\\train\n", 292 | " StandardTransform\n", 293 | "Transform: Compose(\n", 294 | " Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)\n", 295 | " RandomRotation(degrees=[-20.0, 20.0], interpolation=nearest, expand=False, fill=0)\n", 296 | " ToTensor()\n", 297 | " Normalize(mean=tensor([0.5000, 0.5000, 0.5000]), std=tensor([0.5000, 0.5000, 0.5000]))\n", 298 | " )" 299 | ] 300 | }, 301 | "execution_count": 36, 302 | "metadata": {}, 303 | "output_type": "execute_result" 304 | } 305 | ], 306 | "source": [ 307 | "train_data" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 37, 313 | "id": "98f06f1e-72a0-47c7-b0a5-0b8688cab929", 314 | "metadata": {}, 315 | "outputs": [], 316 | "source": [ 317 | "batch_size = 64\n", 318 | "\n", 319 | "train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)\n", 320 | "test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 38, 326 | "id": "e383d78c-1f28-4d5d-8a15-cce068d49258", 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "data = next(iter(train_loader))" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 39, 336 | "id": "aed223e0-715e-4dfe-8faf-ddca6b95b187", 337 | "metadata": {}, 338 | "outputs": [ 339 | { 340 | "data": { 341 | "text/plain": [ 342 | "2" 343 | ] 344 | }, 345 | "execution_count": 39, 346 | "metadata": {}, 347 | "output_type": "execute_result" 348 | } 349 | ], 350 | "source": [ 351 | 
"len(data)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 40, 357 | "id": "4a842d1a-e6b8-454e-a176-803831b169d7", 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "images, labels = data" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 41, 367 | "id": "67d0d7da-bacc-4732-a704-2b819a5d01ef", 368 | "metadata": {}, 369 | "outputs": [ 370 | { 371 | "data": { 372 | "text/plain": [ 373 | "torch.Size([64, 3, 224, 224])" 374 | ] 375 | }, 376 | "execution_count": 41, 377 | "metadata": {}, 378 | "output_type": "execute_result" 379 | } 380 | ], 381 | "source": [ 382 | "images.shape" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 42, 388 | "id": "2fcef7f8-4b52-470a-91dd-16c0c7722829", 389 | "metadata": {}, 390 | "outputs": [ 391 | { 392 | "data": { 393 | "text/plain": [ 394 | "torch.Size([64])" 395 | ] 396 | }, 397 | "execution_count": 42, 398 | "metadata": {}, 399 | "output_type": "execute_result" 400 | } 401 | ], 402 | "source": [ 403 | "labels.shape" 404 | ] 405 | }, 406 | { 407 | "cell_type": "markdown", 408 | "id": "676d295f-df27-4f01-b6e4-99a7a642cbf0", 409 | "metadata": {}, 410 | "source": [ 411 | "## Download and use pre-trained model for transfer learning" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 43, 417 | "id": "2b97c6ac-4d9b-44e3-a768-f794a5762bd7", 418 | "metadata": {}, 419 | "outputs": [], 420 | "source": [ 421 | "model = models.alexnet(pretrained=True)" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 44, 427 | "id": "2d6413aa-2da7-468e-b423-62c4a7dc1017", 428 | "metadata": {}, 429 | "outputs": [ 430 | { 431 | "name": "stdout", 432 | "output_type": "stream", 433 | "text": [ 434 | "AlexNet(\n", 435 | " (features): Sequential(\n", 436 | " (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n", 437 | " (1): ReLU(inplace=True)\n", 438 | " (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 439 | " (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", 440 | " (4): ReLU(inplace=True)\n", 441 | " (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 442 | " (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 443 | " (7): ReLU(inplace=True)\n", 444 | " (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 445 | " (9): ReLU(inplace=True)\n", 446 | " (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 447 | " (11): ReLU(inplace=True)\n", 448 | " (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 449 | " )\n", 450 | " (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n", 451 | " (classifier): Sequential(\n", 452 | " (0): Dropout(p=0.5, inplace=False)\n", 453 | " (1): Linear(in_features=9216, out_features=4096, bias=True)\n", 454 | " (2): ReLU(inplace=True)\n", 455 | " (3): Dropout(p=0.5, inplace=False)\n", 456 | " (4): Linear(in_features=4096, out_features=4096, bias=True)\n", 457 | " (5): ReLU(inplace=True)\n", 458 | " (6): Linear(in_features=4096, out_features=1000, bias=True)\n", 459 | " )\n", 460 | ")\n" 461 | ] 462 | } 463 | ], 464 | "source": [ 465 | "print(model)" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 47, 471 | "id": "a41f47b5-e6ed-4369-8d71-b05c215700f3", 472 | "metadata": {}, 473 | "outputs": [ 474 | { 475 | "data": { 476 | "text/html": [ 477 | "\n", 479 | "\n", 480 | " \n", 
Total trainable: 61100840, non trainable: 0

    Modules               Parameters
0   features.0.weight          23232
1   features.0.bias               64
2   features.3.weight         307200
3   features.3.bias              192
4   features.6.weight         663552
5   features.6.bias              384
6   features.8.weight         884736
7   features.8.bias              256
8   features.10.weight        589824
9   features.10.bias             256
10  classifier.1.weight     37748736
11  classifier.1.bias           4096
12  classifier.4.weight     16777216
13  classifier.4.bias           4096
14  classifier.6.weight      4096000
15  classifier.6.bias           1000
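A quick arithmetic check on this output: features.0.weight belongs to Conv2d(3, 64, kernel_size=(11, 11)), so it holds 64 * 3 * 11 * 11 = 23232 weights; the features rows sum to 2469696, the classifier rows to 58631144, and together they reproduce the caption total of 61100840.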
\n" 571 | ], 572 | "text/plain": [ 573 | "" 574 | ] 575 | }, 576 | "execution_count": 47, 577 | "metadata": {}, 578 | "output_type": "execute_result" 579 | } 580 | ], 581 | "source": [ 582 | "def count_both_params(model):\n", 583 | " model_params = {\"Modules\": list(), \"Parameters\": list()}\n", 584 | " total = {\"trainable\": 0, \"non_trainable\": 0}\n", 585 | " \n", 586 | " for name, parameters in model.named_parameters():\n", 587 | " param = parameters.numel()\n", 588 | " if not parameters.requires_grad:\n", 589 | " total[\"non_trainable\"] += param\n", 590 | " continue\n", 591 | " model_params[\"Modules\"].append(name)\n", 592 | " model_params[\"Parameters\"].append(param)\n", 593 | " total[\"trainable\"] += param\n", 594 | " df = pd.DataFrame(model_params)\n", 595 | " df = df.style.set_caption(f\"\"\"Total trainable: {total[\"trainable\"]}, non trainable: {total[\"non_trainable\"]} \"\"\")\n", 596 | " return df\n", 597 | "\n", 598 | "count_both_params(model)" 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": 48, 604 | "id": "f3e57740-50c3-414e-b96f-1e2d913c06e7", 605 | "metadata": {}, 606 | "outputs": [], 607 | "source": [ 608 | "# Freezing all the model layers\n", 609 | "for parameters in model.parameters():\n", 610 | " parameters.requires_grad = False" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": 49, 616 | "id": "2334e231-6e2d-45f8-bd62-3036ebbead34", 617 | "metadata": {}, 618 | "outputs": [ 619 | { 620 | "data": { 621 | "text/html": [ 622 | "\n", 624 | "\n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | "
Total trainable: 0, non trainable: 61100840

    Modules               Parameters
    (empty: every parameter is frozen)
\n" 636 | ], 637 | "text/plain": [ 638 | "" 639 | ] 640 | }, 641 | "execution_count": 49, 642 | "metadata": {}, 643 | "output_type": "execute_result" 644 | } 645 | ], 646 | "source": [ 647 | "count_both_params(model)" 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": 50, 653 | "id": "7e2b8357-15ba-4c30-962c-4b41b4bc2c94", 654 | "metadata": {}, 655 | "outputs": [ 656 | { 657 | "data": { 658 | "text/plain": [ 659 | "Sequential(\n", 660 | " (0): Dropout(p=0.5, inplace=False)\n", 661 | " (1): Linear(in_features=9216, out_features=4096, bias=True)\n", 662 | " (2): ReLU(inplace=True)\n", 663 | " (3): Dropout(p=0.5, inplace=False)\n", 664 | " (4): Linear(in_features=4096, out_features=4096, bias=True)\n", 665 | " (5): ReLU(inplace=True)\n", 666 | " (6): Linear(in_features=4096, out_features=1000, bias=True)\n", 667 | ")" 668 | ] 669 | }, 670 | "execution_count": 50, 671 | "metadata": {}, 672 | "output_type": "execute_result" 673 | } 674 | ], 675 | "source": [ 676 | "model.classifier" 677 | ] 678 | }, 679 | { 680 | "cell_type": "code", 681 | "execution_count": 51, 682 | "id": "eb732e18-bf6b-4824-b2f6-a7401c8ac632", 683 | "metadata": {}, 684 | "outputs": [], 685 | "source": [ 686 | "model.classifier = nn.Sequential(\n", 687 | " nn.Linear(in_features=9216, out_features=100, bias=True),\n", 688 | " nn.ReLU(inplace=True),\n", 689 | " nn.Dropout(p=0.5, inplace=False),\n", 690 | " nn.Linear(in_features=100, out_features=2, bias=True)\n", 691 | ")" 692 | ] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "execution_count": 52, 697 | "id": "8dba5858-9c0f-4859-8241-0d0d055b6b0b", 698 | "metadata": {}, 699 | "outputs": [ 700 | { 701 | "name": "stdout", 702 | "output_type": "stream", 703 | "text": [ 704 | "AlexNet(\n", 705 | " (features): Sequential(\n", 706 | " (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))\n", 707 | " (1): ReLU(inplace=True)\n", 708 | " (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 709 | " (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n", 710 | " (4): ReLU(inplace=True)\n", 711 | " (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 712 | " (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 713 | " (7): ReLU(inplace=True)\n", 714 | " (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 715 | " (9): ReLU(inplace=True)\n", 716 | " (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", 717 | " (11): ReLU(inplace=True)\n", 718 | " (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 719 | " )\n", 720 | " (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))\n", 721 | " (classifier): Sequential(\n", 722 | " (0): Linear(in_features=9216, out_features=100, bias=True)\n", 723 | " (1): ReLU(inplace=True)\n", 724 | " (2): Dropout(p=0.5, inplace=False)\n", 725 | " (3): Linear(in_features=100, out_features=2, bias=True)\n", 726 | " )\n", 727 | ")\n" 728 | ] 729 | } 730 | ], 731 | "source": [ 732 | "print(model)" 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": 53, 738 | "id": "462abbd3-c6e0-4173-9c99-6d1a67d6e5a8", 739 | "metadata": {}, 740 | "outputs": [ 741 | { 742 | "data": { 743 | "text/html": [ 744 | "\n", 746 | "\n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 
Total trainable: 921902, non trainable: 2469696

    Modules               Parameters
0   classifier.0.weight       921600
1   classifier.0.bias            100
2   classifier.3.weight          200
3   classifier.3.bias              2
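The numbers match the replacement head: the new classifier.0 is Linear(9216, 100) with 9216 * 100 + 100 = 921700 parameters, and classifier.3 is Linear(100, 2) with 100 * 2 + 2 = 202, giving 921902 trainable in total, while the 2469696 frozen convolutional parameters are all that remain non-trainable (the original 1000-way classifier was discarded along with its weights). When this model is trained, only the new head should reach the optimizer; a minimal sketch, assuming Adam with lr=1e-3 (both are illustrative choices, not taken from this repo's training notebook):

    # Uses the torch / nn imports and the `model` defined earlier in this notebook.
    # Only parameters that still require gradients are handed to the optimizer;
    # for this model that is exactly the new two-class classifier head.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=1e-3)
    criterion = nn.CrossEntropyLoss()  # matches the 2-logit output with integer class labels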
\n" 778 | ], 779 | "text/plain": [ 780 | "" 781 | ] 782 | }, 783 | "execution_count": 53, 784 | "metadata": {}, 785 | "output_type": "execute_result" 786 | } 787 | ], 788 | "source": [ 789 | "count_both_params(model)" 790 | ] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "execution_count": 54, 795 | "id": "adbe0f09-e319-4591-9fa2-b90b08b0bd03", 796 | "metadata": {}, 797 | "outputs": [ 798 | { 799 | "data": { 800 | "text/plain": [ 801 | "'cuda'" 802 | ] 803 | }, 804 | "execution_count": 54, 805 | "metadata": {}, 806 | "output_type": "execute_result" 807 | } 808 | ], 809 | "source": [ 810 | "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 811 | "device" 812 | ] 813 | }, 814 | { 815 | "cell_type": "code", 816 | "execution_count": null, 817 | "id": "accbd115-9d36-4fca-9a08-f28d607df0bb", 818 | "metadata": {}, 819 | "outputs": [], 820 | "source": [] 821 | } 822 | ], 823 | "metadata": { 824 | "kernelspec": { 825 | "display_name": "Python 3 (ipykernel)", 826 | "language": "python", 827 | "name": "python3" 828 | }, 829 | "language_info": { 830 | "codemirror_mode": { 831 | "name": "ipython", 832 | "version": 3 833 | }, 834 | "file_extension": ".py", 835 | "mimetype": "text/x-python", 836 | "name": "python", 837 | "nbconvert_exporter": "python", 838 | "pygments_lexer": "ipython3", 839 | "version": "3.7.11" 840 | } 841 | }, 842 | "nbformat": 4, 843 | "nbformat_minor": 5 844 | } 845 | -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/0.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/1.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/10.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/11.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/12.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/13.jpg -------------------------------------------------------------------------------- 
/codebase/Data/img_data/train/Cat/14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/14.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/15.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/16.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/17.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/18.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/19.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/2.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/20.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/21.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/22.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/24.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/24.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/25.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/25.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/26.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/26.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/27.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/27.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/28.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/28.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/3.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/4.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/5.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/6.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/7.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/8.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/8.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Cat/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Cat/9.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/0.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/1.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/10.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/11.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/12.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/13.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/14.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/15.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/16.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/16.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/17.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/18.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/19.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/2.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/20.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/21.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/22.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/23.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/23.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/24.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/24.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/25.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/25.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/26.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/26.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/27.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/27.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/3.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/4.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/5.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/6.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/7.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/8.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/train/Dog/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/train/Dog/9.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/0.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/0.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/1.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/10.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/11.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/12.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/13.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/14.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/15.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/16.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/17.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/18.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/18.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/19.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/2.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/20.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/21.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/22.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/24.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/24.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/25.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/25.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/26.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/26.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/27.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/27.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/28.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/28.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/3.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/4.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/5.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/6.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/7.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/8.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Cat/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Cat/9.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/0.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/1.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/10.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/10.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/11.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/12.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/13.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/14.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/15.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/16.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/17.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/18.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/19.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/2.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/20.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/21.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/21.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/22.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/22.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/23.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/23.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/24.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/24.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/25.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/25.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/26.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/26.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/27.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/27.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/3.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/4.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/4.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/5.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/6.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/7.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/8.jpg -------------------------------------------------------------------------------- /codebase/Data/img_data/validation/Dog/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/codebase/Data/img_data/validation/Dog/9.jpg -------------------------------------------------------------------------------- /codebase/Data/iris.csv: -------------------------------------------------------------------------------- 1 | sepal.length,sepal.width,petal.length,petal.width,species 2 | 5.1,3.5,1.4,0.2,Setosa 3 | 4.9,3,1.4,0.2,Setosa 4 | 4.7,3.2,1.3,0.2,Setosa 5 | 4.6,3.1,1.5,0.2,Setosa 6 | 5,3.6,1.4,0.2,Setosa 7 | 5.4,3.9,1.7,0.4,Setosa 8 | 4.6,3.4,1.4,0.3,Setosa 9 | 5,3.4,1.5,0.2,Setosa 10 | 4.4,2.9,1.4,0.2,Setosa 11 | 4.9,3.1,1.5,0.1,Setosa 12 | 5.4,3.7,1.5,0.2,Setosa 13 | 4.8,3.4,1.6,0.2,Setosa 14 | 4.8,3,1.4,0.1,Setosa 15 | 4.3,3,1.1,0.1,Setosa 16 | 5.8,4,1.2,0.2,Setosa 17 | 5.7,4.4,1.5,0.4,Setosa 18 | 5.4,3.9,1.3,0.4,Setosa 19 | 5.1,3.5,1.4,0.3,Setosa 20 | 5.7,3.8,1.7,0.3,Setosa 21 | 5.1,3.8,1.5,0.3,Setosa 22 | 5.4,3.4,1.7,0.2,Setosa 23 | 5.1,3.7,1.5,0.4,Setosa 24 | 4.6,3.6,1,0.2,Setosa 25 | 5.1,3.3,1.7,0.5,Setosa 26 | 4.8,3.4,1.9,0.2,Setosa 27 | 5,3,1.6,0.2,Setosa 28 | 5,3.4,1.6,0.4,Setosa 29 | 5.2,3.5,1.5,0.2,Setosa 30 | 5.2,3.4,1.4,0.2,Setosa 31 | 4.7,3.2,1.6,0.2,Setosa 32 | 4.8,3.1,1.6,0.2,Setosa 33 | 5.4,3.4,1.5,0.4,Setosa 34 | 5.2,4.1,1.5,0.1,Setosa 35 | 5.5,4.2,1.4,0.2,Setosa 36 | 4.9,3.1,1.5,0.2,Setosa 37 | 5,3.2,1.2,0.2,Setosa 38 | 5.5,3.5,1.3,0.2,Setosa 39 | 4.9,3.6,1.4,0.1,Setosa 40 | 4.4,3,1.3,0.2,Setosa 41 | 5.1,3.4,1.5,0.2,Setosa 42 | 5,3.5,1.3,0.3,Setosa 43 | 4.5,2.3,1.3,0.3,Setosa 44 | 4.4,3.2,1.3,0.2,Setosa 45 | 5,3.5,1.6,0.6,Setosa 46 | 5.1,3.8,1.9,0.4,Setosa 47 | 4.8,3,1.4,0.3,Setosa 48 | 5.1,3.8,1.6,0.2,Setosa 49 | 4.6,3.2,1.4,0.2,Setosa 50 | 5.3,3.7,1.5,0.2,Setosa 51 | 5,3.3,1.4,0.2,Setosa 52 | 7,3.2,4.7,1.4,Versicolor 53 | 
6.4,3.2,4.5,1.5,Versicolor 54 | 6.9,3.1,4.9,1.5,Versicolor 55 | 5.5,2.3,4,1.3,Versicolor 56 | 6.5,2.8,4.6,1.5,Versicolor 57 | 5.7,2.8,4.5,1.3,Versicolor 58 | 6.3,3.3,4.7,1.6,Versicolor 59 | 4.9,2.4,3.3,1,Versicolor 60 | 6.6,2.9,4.6,1.3,Versicolor 61 | 5.2,2.7,3.9,1.4,Versicolor 62 | 5,2,3.5,1,Versicolor 63 | 5.9,3,4.2,1.5,Versicolor 64 | 6,2.2,4,1,Versicolor 65 | 6.1,2.9,4.7,1.4,Versicolor 66 | 5.6,2.9,3.6,1.3,Versicolor 67 | 6.7,3.1,4.4,1.4,Versicolor 68 | 5.6,3,4.5,1.5,Versicolor 69 | 5.8,2.7,4.1,1,Versicolor 70 | 6.2,2.2,4.5,1.5,Versicolor 71 | 5.6,2.5,3.9,1.1,Versicolor 72 | 5.9,3.2,4.8,1.8,Versicolor 73 | 6.1,2.8,4,1.3,Versicolor 74 | 6.3,2.5,4.9,1.5,Versicolor 75 | 6.1,2.8,4.7,1.2,Versicolor 76 | 6.4,2.9,4.3,1.3,Versicolor 77 | 6.6,3,4.4,1.4,Versicolor 78 | 6.8,2.8,4.8,1.4,Versicolor 79 | 6.7,3,5,1.7,Versicolor 80 | 6,2.9,4.5,1.5,Versicolor 81 | 5.7,2.6,3.5,1,Versicolor 82 | 5.5,2.4,3.8,1.1,Versicolor 83 | 5.5,2.4,3.7,1,Versicolor 84 | 5.8,2.7,3.9,1.2,Versicolor 85 | 6,2.7,5.1,1.6,Versicolor 86 | 5.4,3,4.5,1.5,Versicolor 87 | 6,3.4,4.5,1.6,Versicolor 88 | 6.7,3.1,4.7,1.5,Versicolor 89 | 6.3,2.3,4.4,1.3,Versicolor 90 | 5.6,3,4.1,1.3,Versicolor 91 | 5.5,2.5,4,1.3,Versicolor 92 | 5.5,2.6,4.4,1.2,Versicolor 93 | 6.1,3,4.6,1.4,Versicolor 94 | 5.8,2.6,4,1.2,Versicolor 95 | 5,2.3,3.3,1,Versicolor 96 | 5.6,2.7,4.2,1.3,Versicolor 97 | 5.7,3,4.2,1.2,Versicolor 98 | 5.7,2.9,4.2,1.3,Versicolor 99 | 6.2,2.9,4.3,1.3,Versicolor 100 | 5.1,2.5,3,1.1,Versicolor 101 | 5.7,2.8,4.1,1.3,Versicolor 102 | 6.3,3.3,6,2.5,Virginica 103 | 5.8,2.7,5.1,1.9,Virginica 104 | 7.1,3,5.9,2.1,Virginica 105 | 6.3,2.9,5.6,1.8,Virginica 106 | 6.5,3,5.8,2.2,Virginica 107 | 7.6,3,6.6,2.1,Virginica 108 | 4.9,2.5,4.5,1.7,Virginica 109 | 7.3,2.9,6.3,1.8,Virginica 110 | 6.7,2.5,5.8,1.8,Virginica 111 | 7.2,3.6,6.1,2.5,Virginica 112 | 6.5,3.2,5.1,2,Virginica 113 | 6.4,2.7,5.3,1.9,Virginica 114 | 6.8,3,5.5,2.1,Virginica 115 | 5.7,2.5,5,2,Virginica 116 | 5.8,2.8,5.1,2.4,Virginica 117 | 6.4,3.2,5.3,2.3,Virginica 118 | 6.5,3,5.5,1.8,Virginica 119 | 7.7,3.8,6.7,2.2,Virginica 120 | 7.7,2.6,6.9,2.3,Virginica 121 | 6,2.2,5,1.5,Virginica 122 | 6.9,3.2,5.7,2.3,Virginica 123 | 5.6,2.8,4.9,2,Virginica 124 | 7.7,2.8,6.7,2,Virginica 125 | 6.3,2.7,4.9,1.8,Virginica 126 | 6.7,3.3,5.7,2.1,Virginica 127 | 7.2,3.2,6,1.8,Virginica 128 | 6.2,2.8,4.8,1.8,Virginica 129 | 6.1,3,4.9,1.8,Virginica 130 | 6.4,2.8,5.6,2.1,Virginica 131 | 7.2,3,5.8,1.6,Virginica 132 | 7.4,2.8,6.1,1.9,Virginica 133 | 7.9,3.8,6.4,2,Virginica 134 | 6.4,2.8,5.6,2.2,Virginica 135 | 6.3,2.8,5.1,1.5,Virginica 136 | 6.1,2.6,5.6,1.4,Virginica 137 | 7.7,3,6.1,2.3,Virginica 138 | 6.3,3.4,5.6,2.4,Virginica 139 | 6.4,3.1,5.5,1.8,Virginica 140 | 6,3,4.8,1.8,Virginica 141 | 6.9,3.1,5.4,2.1,Virginica 142 | 6.7,3.1,5.6,2.4,Virginica 143 | 6.9,3.1,5.1,2.3,Virginica 144 | 5.8,2.7,5.1,1.9,Virginica 145 | 6.8,3.2,5.9,2.3,Virginica 146 | 6.7,3.3,5.7,2.5,Virginica 147 | 6.7,3,5.2,2.3,Virginica 148 | 6.3,2.5,5,1.9,Virginica 149 | 6.5,3,5.2,2,Virginica 150 | 6.2,3.4,5.4,2.3,Virginica 151 | 5.9,3,5.1,1.8,Virginica 152 | -------------------------------------------------------------------------------- /docs/Section_001_PyTorch_Introduction.md: -------------------------------------------------------------------------------- 1 | # Section: 1 PyTorch Introduction 2 | 3 | ## Introduction to PyTorch 4 | 5 | * PyTorch official docs - [pytorch.org](https://pytorch.org) 6 | 7 | ## PyTorch installation and setup 8 | 9 | * PyTorch installation reference - [Click here](https://pytorch.org/get-started/locally/) 10 | 
* Notebook installation command - 11 | ```bash 12 | pip install notebook 13 | ``` 14 | 15 | ## Demo Notebooks - 16 | 17 | * PyTorch installation and setup first demo - [nbviewer](https://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/01.02_PyTorch%20installation%20and%20setup.ipynb) 18 | 19 | -------------------------------------------------------------------------------- /docs/Section_002_PyTorch_Tensors_and_Operations.md: -------------------------------------------------------------------------------- 1 | # Section: 2 PyTorch Tensors and Operations 2 | 3 | ## What is a tensor? 4 | 5 | - A kind of data structure: multidimensional arrays or matrices. 6 | - With tensors you encode all of your model's parameters. 7 | 8 | ## Type Conversions 9 | 10 | - Conversions from one datatype to another. 11 | - Conversions from torch tensors to NumPy arrays and vice versa. 12 | 13 | 14 | --- 15 | 16 | ## Demo Notebooks - 17 | 18 | * What is a tensor? & Type Conversions - [nbviewer](https://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/02.01%20What%20is%20tensor%20and%20Type%20Conversions.ipynb) 19 | 20 | * Mathematical Operations - [nbviewer](https://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/02.02%20Mathematical%20Operations.ipynb) 21 | 22 | * Indexing, Slicing, Concatenation, Reshaping Ops - [nbviewer](https://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/02.03%20Indexing%2C%20Slicing%2C%20Concatenation%20and%20Reshaping%20Ops.ipynb) 23 | 24 | -------------------------------------------------------------------------------- /docs/Section_003_AutoGrad.md: -------------------------------------------------------------------------------- 1 | # Section: 3 AutoGrad 2 | 3 | PyTorch has the capability of automatic gradient calculation! 4 | 5 | ## Why do we require AutoGrad? 6 | 7 | When we do backpropagation we need to calculate the gradient of the loss function w.r.t. each weight. 8 | 9 | Doing these gradient calculations by hand would take time, and it would not be dynamic, since every derivative would have to be written out manually. 10 | 11 | To resolve this issue, PyTorch can calculate the derivative of a function automatically; this capability is known as AutoGrad. 12 | 13 | !!! Info 14 | 15 | A simplified model of a PyTorch tensor is an object containing the following properties: 16 | 17 | 1. **data** — a self-reference. 18 | 2. **requires_grad** — whether or not this tensor is/should be connected to the computational graph. 19 | 3. **grad** — if requires_grad is true, this property will be a tensor that accumulates the gradients computed for this tensor during backpropagation. 20 | 4. **grad_fn** — a reference to the most recent operation which generated this tensor. PyTorch performs automatic differentiation by walking backward through the chain of grad_fn references. 21 | 5. **is_leaf** — whether or not this is a leaf node.
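A minimal sketch of these properties in action (a generic example for illustration, not taken from this repository's notebooks):

```python
import torch

x = torch.tensor(3.0, requires_grad=True)  # leaf tensor tracked by autograd
y = x**2 + 2*x                             # y.grad_fn records the op that produced y

y.backward()                 # traverse the graph backward from y
print(x.grad)                # dy/dx = 2x + 2 -> tensor(8.)
print(y.grad_fn)             # e.g. <AddBackward0 object at ...>
print(x.is_leaf, y.is_leaf)  # True False
```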
22 | 23 | 24 | ## Demo Notebooks - 25 | 26 | * Derivatives, Partial derivative, & Successive Differentiation - [nbviewer](https://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/03.01%20Derivatives%2C%20Partial%20derivative%2C%20and%20Successive%20Differentiation.ipynb) -------------------------------------------------------------------------------- /docs/Section_004_PyTorch_First_NN.md: -------------------------------------------------------------------------------- 1 | # Section 4: First Neural Network 2 | 3 | ## Demo Notebooks - 4 | 5 | * Simple ANN Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/04.01%20Simple%20ANN%20implementation.ipynb) 6 | 7 | ??? info "Alternative link" 8 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/04.01%20Simple%20ANN%20implementation.ipynb) -------------------------------------------------------------------------------- /docs/Section_005_Custom_data_loading.md: -------------------------------------------------------------------------------- 1 | # Section 5: Custom Data Loading 2 | 3 | ## Structured dataset - 4 | 5 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/05.01_Custom%20data%20loading%20for%20structured.ipynb?flush_cache=true) 6 | 7 | * Data - [iris](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/Data/iris.csv) 8 | 9 | ??? info "Alternative link" 10 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/05.01_Custom%20data%20loading%20for%20structured.ipynb) 11 | 12 | ## Unstructured dataset - 13 | 14 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/05.02_Custom%20data%20loading%20for%20Ustructured.ipynb?flush_cache=true) 15 | 16 | * Data - [Image_data](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/Data/img_data) 17 | 18 | ??? info "Alternative link" 19 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/05.02_Custom%20data%20loading%20for%20Ustructured.ipynb) -------------------------------------------------------------------------------- /docs/Section_006_CNN.md: -------------------------------------------------------------------------------- 1 | # Section 6: Convolutional Neural Network 2 | 3 | ## Create data loader - 4 | 5 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/06.01_CNN_create_data_loader.ipynb?flush_cache=true) 6 | 7 | ??? info "Alternative link" 8 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/06.01_CNN_create_data_loader.ipynb) 9 | 10 | 11 | * Data - [Image_data](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/Data/img_data) 12 | 13 | 14 | ## Define CNN model architecture - 15 | 16 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/06.02_CNN_architecture.ipynb?flush_cache=true) 17 | 18 | ??? info "Alternative link" 19 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/06.02_CNN_architecture.ipynb) 20 | 21 | 22 | !!! 
note "Update: add relu in forward method" 23 | 24 | ```python hl_lines="6-9" 25 | def forward(self, x): 26 | x = self.conv_pool_01(x) 27 | x = self.conv_pool_02(x) 28 | x = self.Flatten(x) 29 | x = self.FC_01(x) 30 | x = F.relu(x) 31 | x = self.FC_02(x) 32 | x = F.relu(x) 33 | x = self.FC_03(x) 34 | return x 35 | ``` 36 | 37 | 38 | * Data - [Image_data](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/Data/img_data) 39 | 40 | 41 | 42 | 43 | 44 | ## Train CNN model - 45 | 46 | 47 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/06.03_train_CNN.ipynb?flush_cache=true) 48 | 49 | ??? info "Alternative link" 50 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/06.03_train_CNN.ipynb) 51 | 52 | * Data - [Image_data](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/Data/img_data) 53 | 54 | 55 | 56 | 57 | ## Evaluate CNN model - 58 | 59 | 60 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/06.04_evaluate_CNN.ipynb?flush_cache=true) 61 | 62 | ??? info "Alternative link" 63 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/06.04_evaluate_CNN.ipynb) 64 | 65 | * Data - [Image_data](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/Data/img_data) 66 | 67 | 68 | 69 | 70 | ## Predict using CNN model - 71 | 72 | 73 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/06.05_predict_using_CNN.ipynb?flush_cache=true) 74 | 75 | ??? info "Alternative link" 76 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/06.05_predict_using_CNN.ipynb) 77 | 78 | * Data - [Image_data](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/Data/img_data) 79 | 80 | 81 | -------------------------------------------------------------------------------- /docs/Section_007_Transfer_learning.md: -------------------------------------------------------------------------------- 1 | # Section 7: Transfer learning 2 | 3 | ## Download data, unzip and create data loader- 4 | 5 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/07.01_Transfer_learning_Download_data_create_dataloader.ipynb?flush_cache=true) 6 | 7 | ??? info "Alternative link" 8 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/07.01_Transfer_learning_Download_data_create_dataloader.ipynb) 9 | 10 | 11 | ## Download and use pretrained model- 12 | 13 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/07.02_Transfer_learning_using_pretrained_model.ipynb?flush_cache=true) 14 | 15 | ??? info "Alternative link" 16 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/07.02_Transfer_learning_using_pretrained_model.ipynb) 17 | 18 | 19 | ## Train our model- 20 | 21 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/07.03_Transfer_learning_Training_model.ipynb?flush_cache=true) 22 | 23 | ??? 
info "Alternative link" 24 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/07.03_Transfer_learning_Training_model.ipynb) 25 | 26 | 27 | ## Evaluate our model- 28 | 29 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/07.04_Transfer_learning_evaluate.ipynb?flush_cache=true) 30 | 31 | ??? info "Alternative link" 32 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/07.04_Transfer_learning_evaluate.ipynb) 33 | 34 | 35 | ## Prediction and visualizing prediction outcome- 36 | 37 | * Implementation - [nbviewer](http://nbviewer.org/github/c17hawke/Pytorch-basics/blob/main/codebase/07.05_Transfer_learning_prediction.ipynb?flush_cache=true) 38 | 39 | ??? info "Alternative link" 40 | Alternative link - [source repository](https://github.com/c17hawke/Pytorch-basics/blob/main/codebase/07.05_Transfer_learning_prediction.ipynb) 41 | 42 | -------------------------------------------------------------------------------- /docs/img/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/c17hawke/Pytorch-basics/b6c6035edcd9c33f384b4e1fa2ed299bff2db2ac/docs/img/.gitkeep -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # PyTorch 2 | This contains PyTorch basic demo notebooks and scripts 3 | 4 | ## Paper 5 | [PyTorch: An Imperative Style, High-Performance Deep Learning Library](https://arxiv.org/abs/1912.01703) 6 | 7 | ## Content 8 | 9 | |Topic|SubTopic| 10 | |-|-| 11 | |[Section: 1 PyTorch Introduction](./Section_001_PyTorch_Introduction/) | Introduction to PyTorch | 12 | || PyTorch installation and setup | 13 | |[Section: 2 PyTorch Tensors and Operations](./Section_002_PyTorch_Tensors_and_Operations/) | What is tensor? 
& Type Conversions| 14 | || Mathematical Operations | 15 | || Indexing, Slicing, Concatenation, Reshaping Ops | 16 | |[Section: 3 AutoGrad](./Section_003_AutoGrad/) | Derivatives, Partial derivative, & Successive Differentiation | 17 | |[Section: 4 First Neural Network](./Section_004_PyTorch_First_NN/) | Simple ANN Implementation | 18 | |[Section: 5 Custom data loading](./Section_005_Custom_data_loading) | Structured data | 19 | ||Unstructured data| 20 | |[Section 6 Convolutional Neural Network](./Section_006_CNN) | Create data loader | 21 | ||Define CNN model architecture| 22 | ||Train CNN model| 23 | ||Evaluate CNN model| 24 | ||Predict using CNN model| 25 | |[Section 7 Transfer learning](./Section_007_Transfer_learning) | Download data and create data loader | 26 | ||Download and use pretrained model| 27 | ||Train our model| 28 | ||Evaluate our model| 29 | ||Prediction and visualizing prediction outcome| 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | # Project information 2 | site_name: PyTorch basics 3 | # defeine TOC or order of website- 4 | nav: 5 | - Home: index.md 6 | - Section 1 PyTorch Introduction: Section_001_PyTorch_Introduction.md 7 | - Section 2 PyTorch Tensors and Operations: Section_002_PyTorch_Tensors_and_Operations.md 8 | - Section 3 AutoGrad: Section_003_AutoGrad.md 9 | - Section 4 First Neural Network: Section_004_PyTorch_First_NN.md 10 | - Section 5 Custom data loading: Section_005_Custom_data_loading.md 11 | - Section 6 Convolutional Neural Network: Section_006_CNN.md 12 | - Section 7 Transfer learning: Section_007_Transfer_learning.md 13 | - References: references.md 14 | 15 | site_author: Sunny Bhaveen Chandra 16 | site_description: >- 17 | This page is meant to host the code used in the PyTorch tutorial of oneNeuron. 
18 | 19 | # Repository 20 | repo_name: c17hawke/Pytorch-basics 21 | repo_url: https://github.com/c17hawke/Pytorch-basics 22 | 23 | # Configuration 24 | theme: 25 | name: material 26 | features: 27 | # - navigation.tabs 28 | - navigation.sections 29 | - toc.integrate 30 | - navigation.top 31 | language: en 32 | palette: 33 | - scheme: default 34 | toggle: 35 | icon: material/toggle-switch-off-outline 36 | name: Switch to dark mode 37 | primary: teal 38 | accent: purple 39 | - scheme: slate 40 | toggle: 41 | icon: material/toggle-switch 42 | name: Switch to light mode 43 | primary: teal 44 | accent: lime 45 | # logo: PyTorch_logo_icon.svg 46 | 47 | extra: 48 | social: 49 | - icon: fontawesome/brands/github-alt 50 | link: https://github.com/c17hawke 51 | - icon: fontawesome/brands/twitter 52 | link: https://twitter.com/c17hawke 53 | - icon: fontawesome/brands/linkedin 54 | link: https://linkedin.com/in/c17hawke 55 | - icon: fontawesome/brands/instagram 56 | link: https://www.instagram.com/c17hawke/ 57 | 58 | extra_javascript: 59 | - https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_CHTML 60 | markdown_extensions: 61 | - pymdownx.highlight: 62 | anchor_linenums: true 63 | - pymdownx.inlinehilite 64 | - pymdownx.snippets 65 | - pymdownx.superfences 66 | - admonition 67 | - pymdownx.arithmatex 68 | - footnotes 69 | - pymdownx.details 70 | - pymdownx.mark 71 | 72 | 73 | 74 | plugins: 75 | - search 76 | 77 | copyright: | 78 | © 2022 Sunny Bhaveen Chandra 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argon2-cffi==21.3.0 2 | argon2-cffi-bindings==21.2.0 3 | attrs==21.4.0 4 | backcall==0.2.0 5 | bleach==4.1.0 6 | certifi==2021.10.8 7 | cffi==1.15.0 8 | click==8.0.3 9 | colorama==0.4.4 10 | cycler==0.11.0 11 | debugpy==1.5.1 12 | decorator==5.1.1 13 | defusedxml==0.7.1 14 | entrypoints==0.3 15 | fonttools==4.29.0 16 | ghp-import==2.0.2 17 | importlib-metadata==4.10.1 18 | importlib-resources==5.4.0 19 | ipykernel==6.7.0 20 | ipython==7.31.1 21 | ipython-genutils==0.2.0 22 | jedi==0.18.1 23 | Jinja2==3.0.3 24 | jsonschema==4.4.0 25 | jupyter-client==7.1.2 26 | jupyter-core==4.9.1 27 | jupyterlab-pygments==0.1.2 28 | kiwisolver==1.3.2 29 | Markdown==3.3.6 30 | MarkupSafe==2.0.1 31 | matplotlib==3.5.1 32 | matplotlib-inline==0.1.3 33 | mergedeep==1.3.4 34 | mistune==0.8.4 35 | mkdocs==1.2.3 36 | mkdocs-material==8.1.8 37 | mkdocs-material-extensions==1.0.3 38 | nbclient==0.5.10 39 | nbconvert==6.4.1 40 | nbformat==5.1.3 41 | nest-asyncio==1.5.4 42 | notebook==6.4.8 43 | numpy==1.21.5 44 | packaging==21.3 45 | pandas==1.3.5 46 | pandocfilters==1.5.0 47 | parso==0.8.3 48 | pickleshare==0.7.5 49 | Pillow==9.0.0 50 | prometheus-client==0.13.0 51 | prompt-toolkit==3.0.26 52 | pycparser==2.21 53 | Pygments==2.11.2 54 | pymdown-extensions==9.1 55 | pyparsing==3.0.7 56 | pyrsistent==0.18.1 57 | python-dateutil==2.8.2 58 | pytz==2021.3 59 | pywin32==303 60 | pywinpty==2.0.1 61 | PyYAML==6.0 62 | pyyaml_env_tag==0.1 63 | pyzmq==22.3.0 64 | scipy==1.7.3 65 | seaborn==0.11.2 66 | Send2Trash==1.8.0 67 | six==1.16.0 68 | terminado==0.13.1 69 | testpath==0.5.0 70 | torch==1.10.2+cu113 71 | torchaudio==0.10.2+cu113 72 | torchvision==0.11.3+cu113 73 | tornado==6.1 74 | traitlets==5.1.1 75 | typing_extensions==4.0.1 76 | watchdog==2.1.6 77 | wcwidth==0.2.5 78 | webencodings==0.5.1 79 | 
wincertstore==0.2 80 | zipp==3.7.0 81 | prettytable==3.2.0 82 | --------------------------------------------------------------------------------