├── .gitignore ├── README.md ├── index.ipynb ├── 01a_api.ipynb ├── 07a_nbdev.ipynb ├── 07_fastcore.ipynb ├── 03b_cross_validation.ipynb ├── 01_pets.ipynb ├── 01b_kaggle.ipynb ├── 05_deployment_with_fastai.ipynb ├── 03a_unknown.ipynb ├── 04a_custom_weights.ipynb ├── 02_low_level.ipynb ├── LICENSE ├── 02a_pytorch.ipynb ├── 03_multilabel.ipynb ├── 04_semantic_segmentation.ipynb └── 05a_deployment_no_fastai.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/* 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Source notebook code for the course, stripped of all information and under a Apache license. 2 | 3 | To purchase access please go [here](https://thezachmueller.gumroad.com/l/walkwithfastai) 4 | -------------------------------------------------------------------------------- /index.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [] 9 | } 10 | ], 11 | "metadata": {}, 12 | "nbformat": 4, 13 | "nbformat_minor": 5 14 | } 15 | -------------------------------------------------------------------------------- /01a_api.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "2b9a4bc9-ed9c-4f39-97dd-b657d8d0cf46", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [] 10 | } 11 | ], 12 | "metadata": {}, 13 | "nbformat": 4, 14 | "nbformat_minor": 5 15 | } 16 | -------------------------------------------------------------------------------- /07a_nbdev.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "cdb5a118-eb64-48c0-93e5-62a997733b38", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#| default_exp core" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "2d72ce68", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "#| export\n", 21 | "def addition(a,b):\n", 22 | " \"Add two numbers together\"\n", 23 | " return a+b" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "id": "422331ae-4ca1-4a29-a6db-89a256390b76", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "#| exports\n", 34 | "def subtraction(a,b):\n", 35 | " \"Subtracts two numbers\"\n", 36 | " return a-b" 37 | ] 38 | } 39 | ], 40 | "metadata": {}, 41 | "nbformat": 4, 42 | "nbformat_minor": 5 43 | } 44 | -------------------------------------------------------------------------------- /07_fastcore.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "09da2d67-cfa2-4c28-9cd6-7b81acb2c59a", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from fastcore.basics import store_attr\n", 11 | "class A:\n", 12 | " def __init__(self, b, c):\n", 13 | " store_attr()\n", 14 | "\n", 15 | "# Check they exist\n", 16 | "a = A(1,2)\n", 17 | "assert a.b == 1\n", 18 | "assert a.c == 2" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 5, 24 | "id": "31d9e12f-9920-4f68-8dd7-dfd5941edd59", 25 
| "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "from fastcore.foundation import L\n", 29 | "\n", 30 | "my_list = L(1,2,3,3)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 6, 36 | "id": "bdb0df52-9921-4c34-ab78-b8fd842c1de7", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "my_list.map(lambda x: x+1)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 7, 46 | "id": "b4d28532-7101-4f08-9318-1bc4cf5f498b", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "my_list.count(3)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 10, 56 | "id": "c40b1b91-a676-45cd-a18d-8280589420f5", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "my_list.sum()" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 24, 66 | "id": "24504d00-f231-47ab-8ee2-a91cd8f7cf15", 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "from fastcore.basics import AttrDict\n", 71 | "from fastcore.xtras import dict2obj, obj2dict\n", 72 | "\n", 73 | "o = {\"a\":0, \"b\":1, \"c\":2}\n", 74 | "o_obj = dict2obj(o)\n", 75 | "o_dict = obj2dict(o_obj)\n", 76 | "\n", 77 | "o[\"a\"], o_obj.a, o_obj[\"a\"]" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 25, 83 | "id": "573372f1-53aa-4a38-8fe2-82ba9aad08bc", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "type(o), type(o_obj), type(o_dict)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 12, 93 | "id": "9cd1c947-cb78-4e5e-88e0-9dd1976bf6d8", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "import fastcore.test as fasttest\n", 98 | "\n", 99 | "fasttest.equals(0,0)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 15, 105 | "id": "581377a6-ff79-4435-b22e-dd288ea71a81", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "from fastcore.dispatch import typedispatch\n", 110 | "\n", 111 | "@typedispatch\n", 112 | "def my_transform(value:int):\n", 113 | " print(\"Passed an `int`!\")\n", 114 | "\n", 115 | "@typedispatch\n", 116 | "def my_transform(value:str):\n", 117 | " print(\"Passed a `str`!\")\n", 118 | "\n", 119 | "@typedispatch\n", 120 | "def my_transform(value:bool):\n", 121 | " print(\"Passed a `bool`!\")" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 16, 127 | "id": "8517965c-daaa-4e1a-9f57-8fda501007df", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "my_transform(1)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 17, 137 | "id": "b7e68412-40c7-49b5-8f65-978864a33fba", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "my_transform(\"Hi!\")" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 18, 147 | "id": "ea2c7e56-5b97-408e-a127-36eb1396deea", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "my_transform(False)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 27, 157 | "id": "ee6a9720-9d40-4bd6-a3b9-054cb6a76723", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "from fastai.torch_core import apply\n", 162 | "\n", 163 | "o = [1,2,3]\n", 164 | "apply(lambda x: x+1, o)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 29, 170 | "id": "f893e16c-999b-4ace-b3dc-2595a913a0a2", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "from fastai.torch_core import concat\n", 175 | "\n", 176 | 
"concat([1,2,3], (4, 5), 6)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 31, 182 | "id": "719e3436-9d20-4d1e-af2e-f12d8537a284", 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "from fastai.torch_core import find_bs\n", 187 | "import torch\n", 188 | "\n", 189 | "t = torch.rand(64, 3, 224,224)\n", 190 | "find_bs(t)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 32, 196 | "id": "4c1c0a15-0d0c-48b3-b00f-a551cc5c9a3a", 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "from fastai.torch_core import Module\n", 201 | "\n", 202 | "class MyLayer(Module):\n", 203 | " def __init__(self, arg1):\n", 204 | " self.arg1 = arg1" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "id": "9ca1822f-6819-4548-adfc-89b5c40a5c2a", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [] 214 | } 215 | ], 216 | "metadata": {}, 217 | "nbformat": 4, 218 | "nbformat_minor": 5 219 | } 220 | -------------------------------------------------------------------------------- /03b_cross_validation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "0d219360-e836-48ce-88cc-c2d83a99810c", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from fastai.vision.all import *\n", 11 | "\n", 12 | "from sklearn.model_selection import StratifiedKFold" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "98bf894c-da6c-4090-a70a-9ba73d732c01", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "path = untar_data(URLs.PETS)\n", 23 | "fnames = get_image_files(path/'images')\n", 24 | "pat = r'(.+)_\\d+.jpg$'\n", 25 | "item_tfms = [RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.)), ToTensor()]\n", 26 | "batch_tfms = [IntToFloatTensor(), *aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]\n", 27 | "batch_size = 64" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "b397433d-0129-4ceb-b94f-d60487dba71d", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "random.shuffle(fnames)\n", 38 | "\n", 39 | "train_fnames = [filename for filename in fnames[:int(len(fnames) * .9)]]\n", 40 | "test_fnames = [filename for filename in fnames[int(len(fnames) * .9):]]" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "423582fc-441f-4972-818a-133a31ea55a8", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "vocab = list(map(RegexLabeller(pat=r'/([^/]+)_\\d+.*'), train_fnames))" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "25598c83-ce59-41ac-a35c-8388f6b60aa9", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "pipe = Pipeline([\n", 61 | " RegexLabeller(pat=r'/([^/]+)_\\d+.*'), Categorize(vocab=vocab)\n", 62 | "])" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "74fd0220-87d7-4351-afca-355d851df264", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "labels = list(map(pipe, train_fnames))" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "fc6203ca-0030-4ed9-aeb3-45c737dcb83e", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "splits = []\n", 83 | "skf = StratifiedKFold(n_splits=10, shuffle=True)\n", 84 | "for _, valid_indexes in skf.split(\n", 
85 | " np.zeros(len(labels)), labels\n", 86 | "):\n", 87 | " split = IndexSplitter(valid_indexes)\n", 88 | " splits.append(split)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "8e05e6f3-5217-4906-9aa0-f47d4c62db10", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "valid_pcts = []\n", 99 | "test_preds = []" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "id": "bfdb8abb-6369-4a8f-8c9b-c1023ee0c153", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "def train(splitter:IndexSplitter):\n", 110 | " \"Trains a single model over a set of splits based on `splitter`\"\n", 111 | " dset = Datasets(\n", 112 | " train_fnames,\n", 113 | " tfms = [\n", 114 | " [PILImage.create], \n", 115 | " [RegexLabeller(pat=r'/([^/]+)_\\d+.*'), Categorize]\n", 116 | " ],\n", 117 | " splits = splitter(train_fnames)\n", 118 | " )\n", 119 | " dls = dset.dataloaders(\n", 120 | " bs=batch_size,\n", 121 | " after_item=item_tfms,\n", 122 | " after_batch=batch_tfms\n", 123 | " )\n", 124 | " learn = vision_learner(dls, resnet34, metrics=accuracy)\n", 125 | " learn.fit_one_cycle(1)\n", 126 | " valid_pcts.append(learn.validate()[1])\n", 127 | " dl = learn.dls.test_dl(test_fnames)\n", 128 | " preds, _ = learn.get_preds(dl=dl)\n", 129 | " test_preds.append(preds)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "62b48308-0626-4d6c-af17-55a6240b25f8", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "for splitter in splits:\n", 140 | " train(splitter)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "id": "1e3ee0c5-ef28-4f03-a218-b32f249f1b8b", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "test_labels = torch.stack([pipe(fname) for fname in test_fnames])\n", 151 | "accuracy(test_preds[0], test_labels)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "id": "84fabad1-844d-40a9-8666-ef3a0f2eea57", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "for preds in test_preds:\n", 162 | " print(accuracy(preds, test_labels))" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "id": "1509fd52-9d60-40e3-972b-ca1ee6bc4dfd", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "votes = torch.stack(test_preds, dim=-1).sum(-1) / 5" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "582d50db-9e15-474c-b1ab-3e6553df6449", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "accuracy(votes, test_labels)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "id": "4ca00ee8-5379-4ca2-90e4-df341fd0d234", 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [] 192 | } 193 | ], 194 | "metadata": {}, 195 | "nbformat": 4, 196 | "nbformat_minor": 5 197 | } 198 | -------------------------------------------------------------------------------- /01_pets.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "0ddeef59-1445-4d52-875f-543c628432df", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from fastai.vision.all import *" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "eb717c67-3114-43ee-82fe-239b12673dc8", 17 | "metadata": {}, 18 
| "outputs": [], 19 | "source": [ 20 | "help(untar_data)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "592d756d-fbfb-4038-85a5-03e2bd882dd7", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "from nbdev.showdoc import show_doc\n", 31 | "show_doc(untar_data)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "a4ef5035-6183-4135-aa69-22fe0ee7044d", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "path = untar_data(URLs.PETS)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "06443102-05a9-4bcd-a3af-8074d2ef725c", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "set_seed(42)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "54fe5b92-d8bb-4e50-ba48-ee436187d235", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "path.ls()[:3]" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "id": "dd0ce209-0d20-440f-aed9-11ba68a4f933", 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "(path/'images').ls()[:3]" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "id": "35083bfc-5681-412b-840f-33760169e28f", 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "path = untar_data(URLs.PETS)\n", 82 | "fnames = get_image_files(path/'images')\n", 83 | "pat = r'(.+)_\\d+.jpg$'" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "3816443d-65ff-4cd6-8293-f69c8de4045f", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.))\n", 94 | "batch_tfms = [*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]\n", 95 | "bs=64" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "01ea79c0-b968-45d7-a472-2225d98f859b", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "dls = ImageDataLoaders.from_name_re(\n", 106 | " path,\n", 107 | " fnames,\n", 108 | " pat,\n", 109 | " item_tfms=item_tfms,\n", 110 | " batch_tfms=batch_tfms,\n", 111 | " bs=bs\n", 112 | ")" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "id": "4579aab8-303e-4173-ad79-5e24b0ab1a90", 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "pets = DataBlock(blocks=(ImageBlock, CategoryBlock),\n", 123 | " get_items=get_image_files,\n", 124 | " splitter=RandomSplitter(),\n", 125 | " get_y=RegexLabeller(pat = r'/([^/]+)_\\d+.*'),\n", 126 | " item_tfms=item_tfms,\n", 127 | " batch_tfms=batch_tfms)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "id": "0dfed758-8c36-49e8-9598-cd1040c929b8", 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "path_im = path/'images'" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "6aa8e7ba-b45f-44e0-bfbc-304af3d17fac", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "dls = pets.dataloaders(path_im, bs=bs)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "id": "d1a1f794-8ec9-4df9-99d1-8e9964ebcf5b", 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "dls.show_batch(max_n=9, figsize=(6,7))" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": 
"ad13ee4a-8c18-4499-9999-b410d222259a", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "dls.vocab" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "8cc90173-44db-4ad4-be58-5693be5a5462", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "dls.vocab.o2i" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "id": "37e9408e-c32d-4de3-8cb0-6bbeb3511a23", 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "learn = vision_learner(dls, resnet34, metrics=error_rate)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "id": "68bb7b80-9193-4880-952b-03747a40c755", 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "learn.fit_one_cycle(4)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "2f2877e3-c4bc-485c-97a5-47d31ebb9abe", 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "learn.save('stage_1')" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "id": "01f833bc-e1b3-4cc9-8239-25a7caa5425b", 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "interp = ClassificationInterpretation.from_learner(learn)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "id": "929e1699-1008-49e6-bb7d-be9a67efebf4", 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "interp.plot_top_losses(9, figsize=(15,10))" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "id": "61d949f4-8c71-43e5-9587-27558d228539", 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "interp.plot_confusion_matrix(figsize=(12,12), dpi=60)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "id": "923fae43-578e-4aeb-98d3-fd3fb81397ee", 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "interp.most_confused(min_val=3)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "id": "6a89a4ce-8c64-4899-a0e7-37af05907ecc", 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "learn.load('stage_1');" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "id": "e6bb0cfa-106a-4ea5-9094-2d0ebafe51a2", 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [ 267 | "learn.unfreeze()" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "id": "62efbf4a-5ffb-417e-833a-26e18e97f5e4", 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "learn.lr_find()" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "id": "069933fa-406a-47a1-8f3d-68c0eacdac5b", 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "learn.fit_one_cycle(4, lr_max=slice(1e-6, 1e-4))" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "id": "176664f0-d83c-418b-b073-f2e5410b5e38", 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "learn.save('stage_2')" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "id": "16043dad-5276-49f5-8aa5-6453d4312958", 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [] 307 | } 308 | ], 309 | "metadata": {}, 310 | "nbformat": 4, 311 | "nbformat_minor": 5 312 | } 313 | 
-------------------------------------------------------------------------------- /01b_kaggle.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "b603cf1c-a9dd-43d6-84d4-9fdbdf11d0a8", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install kaggle >> /dev/null" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "b1084a98-1338-45d4-b279-fdadb3a41b7f", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "!kaggle datasets download -d agrigorev/clothing-dataset-full" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "20214cf6-f1a7-4912-96e9-b70222081962", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "from fastcore.xtras import Path\n", 31 | "zip_path = Path(\"../clothing-dataset-full.zip\")\n", 32 | "zip_path.exists()" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "85609d5c-901d-49e9-9899-1a4425af0089", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import zipfile\n", 43 | "with zipfile.ZipFile(zip_path, \"r\") as zip_ref:\n", 44 | " zip_ref.extractall(\"../data\")" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "c7cd1059-cb2f-433a-9baf-25b0aea95367", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "data_path = Path(\"../data\")\n", 55 | "data_path.ls()" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "8907b1fb-e8aa-4f7c-8189-f0307684ec37", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "from fastai.vision.all import *" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "b1474850-e519-489e-8c90-559a9ab098c0", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "(data_path/\"images_compressed\").ls()[:3]" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "dd1267c8-4618-46b3-8906-47ec427490a4", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "im_path = (data_path/\"images_compressed\").ls()[0]\n", 86 | "im = Image.open(im_path)\n", 87 | "im" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "67f8abc8-0b00-4634-827d-3633a51fde75", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "im.shape" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "id": "cdc2f78d-2604-4324-8b21-f20d49a410a7", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "im_path = (data_path/\"images_original\").ls()[0]\n", 108 | "im = Image.open(im_path);" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "d9a7d743-af42-4f74-9bff-630b309812ec", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "im.shape" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "c3dd385a-d2e7-4308-89e0-ebacff0e97b5", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "bad_imgs = []\n", 129 | "for im in (data_path/\"images_compressed\").ls():\n", 130 | " try:\n", 131 | " _ = Image.open(im)\n", 132 | " except:\n", 133 | " bad_imgs.append(im)\n", 134 | " im.unlink()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "62091093-e650-44d5-b149-93a47dd2e5cc", 141 | "metadata": 
{}, 142 | "outputs": [], 143 | "source": [ 144 | "len(bad_imgs)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "3749e090-52b0-4421-a788-e6c5d96a1ca6", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "df = pd.read_csv(data_path/'images.csv')\n", 155 | "df.head()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "c4e29b02-cfea-4c3e-99be-d4ed030ba12d", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "len(df[df[\"label\"] == \"Not sure\"]), len(df)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "188e5872-35ff-494c-8b6d-936cab09b723", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "len(df[~(df[\"label\"] == \"Not sure\")])" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "id": "1f3b8604-ac79-47b3-a408-d8014b4ffb8d", 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "clean_df = df[~(df[\"label\"] == \"Not sure\")]\n", 186 | "clean_df.head()" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "id": "cff67308-aaf7-49b6-8586-7af7859e9ec1", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "for img in bad_imgs:\n", 197 | " clean_df = clean_df[clean_df[\"image\"] != img.stem]" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "896ef810-20f0-447a-921e-8b4f7e836658", 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "len(clean_df)" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "id": "479a61b4-c973-4c10-be2a-07e84848e739", 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "clean_df[\"label\"].unique(), len(clean_df[\"label\"].unique())" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "id": "e810b0b9-4567-4b7e-a321-03f3326f76ba", 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "blocks = (ImageBlock, CategoryBlock)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "id": "2a934e1a-9dc8-4ea2-b9ba-d5381d85c30d", 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "get_x = ColReader(\"image\", pref=(data_path/\"images_compressed\"), suff=\".jpg\")\n", 238 | "get_y = ColReader(\"label\")" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "id": "05a2abb7-8646-46e8-b6b2-4f36924c6301", 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "item_tfms = [Resize(224)]\n", 249 | "batch_tfms = [*aug_transforms(), Normalize.from_stats(*imagenet_stats)]" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "id": "513970ac-e6f4-48c8-9808-6f6ac0d1dca2", 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "dblock = DataBlock(\n", 260 | " blocks=blocks,\n", 261 | " get_x=get_x,\n", 262 | " get_y=get_y,\n", 263 | " item_tfms=item_tfms,\n", 264 | " batch_tfms=batch_tfms\n", 265 | ")" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "id": "96635bc1-d3f6-4cfa-bf37-7a297a64181a", 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "dls = dblock.dataloaders(clean_df)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "id": "e3f01df4-b5a4-4052-9c8c-ad298b902f3d", 282 | "metadata": {}, 283 | 
"outputs": [], 284 | "source": [ 285 | "dls.show_batch()" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "id": "d75f2346-d0ee-4f2b-a6c9-7749e720dfc7", 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "learn = vision_learner(dls, resnet34, metrics=accuracy)\n", 296 | "learn.fine_tune(5)" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "id": "fe9e2d0f-0485-4a75-a47d-5b33bbdfb8cf", 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [] 306 | } 307 | ], 308 | "metadata": {}, 309 | "nbformat": 4, 310 | "nbformat_minor": 5 311 | } 312 | -------------------------------------------------------------------------------- /05_deployment_with_fastai.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "ab6c91c8", 7 | "metadata": { 8 | "id": "ab6c91c8" 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "from fastai.vision.all import *" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "id": "a19efa03-9f6c-4ed4-b8e2-a866b5ddcf76", 19 | "metadata": { 20 | "id": "a19efa03-9f6c-4ed4-b8e2-a866b5ddcf76" 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "path = untar_data(URLs.PETS)/'images'\n", 25 | "fnames = get_image_files(path)\n", 26 | "pat = r'/([^/]+)_\\d+.*'\n", 27 | "batch_tfms = [*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]\n", 28 | "item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.))\n", 29 | "bs=64\n", 30 | "\n", 31 | "pets = DataBlock(\n", 32 | " blocks=(ImageBlock, CategoryBlock),\n", 33 | " get_items=get_image_files,\n", 34 | " splitter=RandomSplitter(),\n", 35 | " get_y=RegexLabeller(pat = r'/([^/]+)_\\d+.*'),\n", 36 | " item_tfms=item_tfms,\n", 37 | " batch_tfms=batch_tfms\n", 38 | ")\n", 39 | "dls = pets.dataloaders(path, bs=bs)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "id": "e3bbb74d-ce46-409b-8ec8-066ef08ab66d", 46 | "metadata": { 47 | "id": "e3bbb74d-ce46-409b-8ec8-066ef08ab66d" 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "learn = vision_learner(dls, \"vit_tiny_patch16_224\")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "id": "501ca71a-71e0-447f-8409-1da7b8281bcb", 58 | "metadata": { 59 | "colab": { 60 | "base_uri": "https://localhost:8080/", 61 | "height": 144 62 | }, 63 | "id": "501ca71a-71e0-447f-8409-1da7b8281bcb", 64 | "outputId": "a399b80f-19d6-4a99-8d1d-7a4353837b4e" 65 | }, 66 | "outputs": [], 67 | "source": [ 68 | "learn.fine_tune(1)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 5, 74 | "id": "af6b19cf-1db1-4cf8-b69e-6d3c752950b3", 75 | "metadata": { 76 | "colab": { 77 | "base_uri": "https://localhost:8080/" 78 | }, 79 | "id": "af6b19cf-1db1-4cf8-b69e-6d3c752950b3", 80 | "outputId": "46f269fa-8661-40ca-95e9-df8b087021ac" 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "learn.export(\"exported_fastai\")\n", 85 | "learn.save(\"exported_model\", with_opt=False)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 6, 91 | "id": "30b4baf4-726a-4f7b-954b-f82d12bb8428", 92 | "metadata": { 93 | "id": "30b4baf4-726a-4f7b-954b-f82d12bb8428" 94 | }, 95 | "outputs": [], 96 | "source": [ 97 | "learn = load_learner(\"exported_fastai\")" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 7, 103 | "id": "5d842d02-863a-42ce-abc2-8c24cf7dcd72", 
104 | "metadata": { 105 | "colab": { 106 | "base_uri": "https://localhost:8080/" 107 | }, 108 | "id": "5d842d02-863a-42ce-abc2-8c24cf7dcd72", 109 | "outputId": "8cacb094-148e-495c-eed8-68fea3690638" 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "fname = fnames[0]; fname" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 8, 119 | "id": "a2b55797-d442-4f3a-8206-9f30d556ba4a", 120 | "metadata": { 121 | "colab": { 122 | "base_uri": "https://localhost:8080/", 123 | "height": 176 124 | }, 125 | "id": "a2b55797-d442-4f3a-8206-9f30d556ba4a", 126 | "outputId": "654430d8-0a19-497e-ddef-6a232626c994" 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "learn.predict(fname)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 9, 136 | "id": "2e8f0866-86d9-42b8-90c7-8e70c7ebcb4a", 137 | "metadata": { 138 | "id": "2e8f0866-86d9-42b8-90c7-8e70c7ebcb4a" 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "dl = learn.dls.test_dl([fname])" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 16, 148 | "id": "91e31268-f365-46a3-8b20-569fa8ecb83b", 149 | "metadata": { 150 | "colab": { 151 | "base_uri": "https://localhost:8080/", 152 | "height": 17 153 | }, 154 | "id": "91e31268-f365-46a3-8b20-569fa8ecb83b", 155 | "outputId": "54dd6cef-5fcb-47df-e143-b12196623ca5" 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "preds = learn.get_preds(dl=dl)[0]" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 17, 165 | "id": "090f3427-9f51-492c-a43b-f6b9a1e5c92e", 166 | "metadata": { 167 | "colab": { 168 | "base_uri": "https://localhost:8080/" 169 | }, 170 | "id": "090f3427-9f51-492c-a43b-f6b9a1e5c92e", 171 | "outputId": "2ccd54e1-a674-4898-b9cf-96d8dfaf02db" 172 | }, 173 | "outputs": [], 174 | "source": [ 175 | "softmax = preds.softmax(dim=1)\n", 176 | "argmax = preds.argmax(dim=1)\n", 177 | "labels = [learn.dls.vocab[pred] for pred in argmax]\n", 178 | "softmax, argmax, labels" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 20, 184 | "id": "1c858747-2240-433e-b97d-02974f7c4e69", 185 | "metadata": { 186 | "colab": { 187 | "base_uri": "https://localhost:8080/", 188 | "height": 35 189 | }, 190 | "id": "1c858747-2240-433e-b97d-02974f7c4e69", 191 | "outputId": "ff01093b-c7d5-46b7-c0d7-50ef813bf82b" 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "%%timeit\n", 196 | "dl = learn.dls.test_dl([fname])\n", 197 | "_ = learn.get_preds(dl=dl)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 21, 203 | "id": "25d8977a-c1a4-44a4-91bc-066b91340ddd", 204 | "metadata": { 205 | "colab": { 206 | "base_uri": "https://localhost:8080/", 207 | "height": 35 208 | }, 209 | "id": "25d8977a-c1a4-44a4-91bc-066b91340ddd", 210 | "outputId": "e13201ee-818f-45a8-a483-94b0e0770914" 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "%%timeit\n", 215 | "_ = learn.predict(fname)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 36, 221 | "id": "lPhjbDVIH2nK", 222 | "metadata": { 223 | "colab": { 224 | "base_uri": "https://localhost:8080/" 225 | }, 226 | "id": "lPhjbDVIH2nK", 227 | "outputId": "412221b0-bf2c-48a4-b43b-86dc496e241e" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "learn.dls.cuda()" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 40, 237 | "id": "bs5fGGmXH6Dz", 238 | "metadata": { 239 | "colab": { 240 | "base_uri": "https://localhost:8080/", 241 | "height": 35 242 | }, 243 | "id": "bs5fGGmXH6Dz", 
244 | "outputId": "ee952a38-fbe6-4403-bd54-5f517430f2df" 245 | }, 246 | "outputs": [], 247 | "source": [ 248 | "%%timeit\n", 249 | "dl = learn.dls.test_dl([fname])\n", 250 | "_ = learn.get_preds(dl=dl)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 41, 256 | "id": "y05DiOtOH7fJ", 257 | "metadata": { 258 | "colab": { 259 | "base_uri": "https://localhost:8080/", 260 | "height": 35 261 | }, 262 | "id": "y05DiOtOH7fJ", 263 | "outputId": "7b3c149d-c14e-4af6-abdc-5b93b57c9bde" 264 | }, 265 | "outputs": [], 266 | "source": [ 267 | "%%timeit\n", 268 | "dl = learn.dls.test_dl([fname], num_workers=0)\n", 269 | "_ = learn.get_preds(dl=dl)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "id": "z5-ifLCnI64w", 276 | "metadata": { 277 | "id": "z5-ifLCnI64w" 278 | }, 279 | "outputs": [], 280 | "source": [] 281 | } 282 | ], 283 | "metadata": { 284 | "kernelspec": { 285 | "display_name": "Python 3", 286 | "language": "python", 287 | "name": "python3" 288 | }, 289 | "language_info": { 290 | "codemirror_mode": { 291 | "name": "ipython", 292 | "version": 3 293 | }, 294 | "file_extension": ".py", 295 | "mimetype": "text/x-python", 296 | "name": "python", 297 | "nbconvert_exporter": "python", 298 | "pygments_lexer": "ipython3", 299 | "version": "3.9.16" 300 | } 301 | }, 302 | "nbformat": 4, 303 | "nbformat_minor": 5 304 | } 305 | -------------------------------------------------------------------------------- /03a_unknown.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "1f1eeb17-bec8-4cf0-80be-5a3034669fb6", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from fastai.vision.all import *" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "0ee21d7c-fbf3-4fc5-8b9d-bc363dd4a666", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import torch\n", 21 | "from torch import tensor\n", 22 | "from torchvision.models.resnet import resnet34\n", 23 | "import requests\n", 24 | "\n", 25 | "import pandas as pd\n", 26 | "from fastcore.transform import Pipeline\n", 27 | "from fastcore.xtras import Path\n", 28 | "\n", 29 | "from fastai.data.core import Datasets\n", 30 | "from fastai.data.block import DataBlock, MultiCategoryBlock\n", 31 | "from fastai.vision.data import ImageBlock\n", 32 | "from fastai.data.external import URLs, untar_data\n", 33 | "from fastai.data.transforms import (\n", 34 | " ColReader,\n", 35 | " IntToFloatTensor, \n", 36 | " MultiCategorize,\n", 37 | " Normalize, \n", 38 | " OneHotEncode, \n", 39 | " RandomSplitter,\n", 40 | " RegexLabeller,\n", 41 | " get_image_files\n", 42 | ")\n", 43 | "\n", 44 | "from fastai.metrics import accuracy_multi\n", 45 | "\n", 46 | "from fastai.vision.augment import aug_transforms, RandomResizedCrop\n", 47 | "from fastai.vision.core import PILImage\n", 48 | "from fastai.vision.learner import vision_learner\n", 49 | "from fastai.learner import Learner\n", 50 | "from fastai.callback.schedule import Learner" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "5ea69753-05bf-4575-8cea-4cd52c0008c7", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "path = untar_data(URLs.PETS)/'images'" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "id": "c8f2e5f5-a4e7-4195-93a6-18fbbcf9e350", 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": 
[ 70 | "fnames = get_image_files(path/'images')\n", 71 | "pat = r'(.+)_\\d+.jpg$'\n", 72 | "item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.))\n", 73 | "batch_tfms = [*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]\n", 74 | "bs=64" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "7c843f6a-cf67-4c55-820d-34aac9c6fd37", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "pets = DataBlock(blocks=(ImageBlock, CategoryBlock),\n", 85 | " get_items=get_image_files,\n", 86 | " splitter=RandomSplitter(),\n", 87 | " get_y=RegexLabeller(pat = r'/([^/]+)_\\d+.*'),\n", 88 | " item_tfms=item_tfms,\n", 89 | " batch_tfms=batch_tfms)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "8f272263-af5d-4b9c-b4d3-8551c1ca6a94", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "def label_to_list(o): return [o]" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "id": "47655d34-0b97-4c9b-a192-52f73887fa94", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "multi_pets = DataBlock(\n", 110 | " blocks=(ImageBlock, MultiCategoryBlock),\n", 111 | " get_items=get_image_files,\n", 112 | " splitter=RandomSplitter(),\n", 113 | " get_y=Pipeline(\n", 114 | " [RegexLabeller(pat = r'/([^/]+)_\\d+.*'), label_to_list]\n", 115 | " ),\n", 116 | " item_tfms=item_tfms,\n", 117 | " batch_tfms=batch_tfms\n", 118 | ")" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "3d579c51-9c07-4f44-9b0c-2f3e72783a4a", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "dls = multi_pets.dataloaders(path, bs=32)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "9d94369c-f6b4-4662-8eb9-359c9c777865", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "dls.show_batch()" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "614a6e4d-d34b-4026-846a-da7ea1af32ab", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "train_idxs, valid_idxs = RandomSplitter()(get_image_files(path))" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "32e190b7-730c-4b99-a48e-ccac26c51608", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "tfms = [\n", 159 | " [PILImage.create],\n", 160 | " [\n", 161 | " RegexLabeller(pat = r'/([^/]+)_\\d+.*'),\n", 162 | " label_to_list,\n", 163 | " MultiCategorize(vocab=list(dls.vocab)),\n", 164 | " OneHotEncode(len(dls.vocab))\n", 165 | " ]\n", 166 | "]" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "a20745f7-12f4-4109-85b0-3686ea61e516", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "dsets = Datasets(get_image_files(path), tfms=tfms, splits=[train_idxs, valid_idxs])" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "id": "7d570e2e-237e-4a33-8468-263a2acd655e", 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "dsets[0]" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "id": "8c83aad4-263c-487c-94ea-da0124d759f9", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "dls = dsets.dataloaders(\n", 197 | " after_item=[ToTensor(), RandomResizedCrop(460, min_scale=.75)],\n", 198 | " after_batch=[IntToFloatTensor(), 
*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)],\n", 199 | " bs=32\n", 200 | ")" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "id": "def5ab39-f6a4-41c1-a6ac-26b0ee392dec", 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "dls.show_batch()" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "id": "80f023a6-aa46-442e-860a-c8d57727287e", 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "learn = vision_learner(dls, resnet34, metrics=[partial(accuracy_multi, thresh=0.95)])" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "id": "4783a8e1-b7ca-4d0b-b9ad-6cce87970ad1", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "learn.fine_tune(4, 2e-3)" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "id": "9a230829-0ac9-4f4e-8153-4192d70d212d", 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "learn.loss_func.thresh = 0.95" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "id": "a8cf8846-55ce-4d91-8299-62dd9e522bec", 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "PERSIAN_CAT_URL = \"https://azure.wgp-cdn.co.uk/app-yourcat/posts/iStock-174776419-1.jpg\"" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "id": "0684f0d0-6738-4e01-b647-9c4e6a0d82c0", 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "response = requests.get(PERSIAN_CAT_URL)\n", 261 | "im = PILImage.create(response.content)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "id": "52e01576-6511-4ac2-b91e-bc4997ce1515", 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "im.show();" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "id": "89abfa26-86cd-4aa6-8010-39a0bf44df3d", 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "learn.predict(im)[0]" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "id": "5cbff0f1-55f7-4eb5-b255-6161c29c4e59", 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "DONKEY_URL = \"https://cdn.britannica.com/68/143568-050-5246474F/Donkey.jpg\"\n", 292 | "response = requests.get(DONKEY_URL)\n", 293 | "learn.predict(response.content)[0]" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "id": "9f9b286e-1584-4cf3-8fae-0670f06b8b00", 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [] 303 | } 304 | ], 305 | "metadata": {}, 306 | "nbformat": 4, 307 | "nbformat_minor": 5 308 | } 309 | -------------------------------------------------------------------------------- /04a_custom_weights.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "7a5c0c7b-5a9e-4957-a943-fee098749aef", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from fastai.vision.all import *" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "f53849ec-071d-4721-8012-84225b52c586", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "path = untar_data(URLs.PETS)/'images'\n", 21 | "fnames = get_image_files(path)\n", 22 | "pat = r'/([^/]+)_\\d+.*'\n", 23 | "batch_tfms = 
[*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]\n", 24 | "item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.))\n", 25 | "bs=64\n", 26 | "\n", 27 | "pets = DataBlock(\n", 28 | " blocks=(ImageBlock, CategoryBlock),\n", 29 | " get_items=get_image_files,\n", 30 | " splitter=RandomSplitter(),\n", 31 | " get_y=RegexLabeller(pat = r'/([^/]+)_\\d+.*'),\n", 32 | " item_tfms=item_tfms,\n", 33 | " batch_tfms=batch_tfms\n", 34 | ")\n", 35 | "dls = pets.dataloaders(path, bs=bs)" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "6a1f5509-418f-4b86-8d26-a90b99806aed", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "!pip install timm >> /dev/null" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "54b463db-543b-4b72-a4f7-3b046f302a53", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "from timm import create_model\n", 56 | "net = create_model(\"vit_tiny_patch16_224\", pretrained=True)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "57af3310-ec5f-4917-829c-d81f073f701f", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "learn = vision_learner(dls, models.resnet18)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "9485cdc4-7a59-4c0e-9132-3e58846a0e01", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "learn.model[-1]" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "37c0a746-befa-4b3e-a01b-efbed1769935", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "net[-1]" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "id": "2bb8d847-408b-47e2-bba5-f2a6bf23ab86", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "len(learn.model)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "id": "0f762dcf-4d71-4fcb-a0c2-5e6fd1c9ca9c", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "len(net)" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "id": "6061994e-1a94-49b5-b9d9-4a007c28e09f", 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "class MyModel(nn.Module):\n", 117 | " def __init__(self):\n", 118 | " self.l1 = nn.Linear(1,1)\n", 119 | " self.l2 = nn.linear(1,1)\n", 120 | " def forward(self, x):\n", 121 | " return self.l2(self.l1(x))" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "905cb189-5a4d-4fef-b0cd-db85e1109ac2", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "class MyModel(nn.Sequential):\n", 132 | " def __init__(self):\n", 133 | " layers = [\n", 134 | " nn.Linear(1,1),\n", 135 | " nn.Linear(1,1),\n", 136 | " ]\n", 137 | " super().__init__(*layers)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "8aec2400-819d-4225-8745-7c865217f857", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "net = MyModel()\n", 148 | "net[0], net[1]" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "736a04d6-e999-4a03-8f81-5bf3e9f83009", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "def custom_cut_model(model:nn.Module, cut:typing.Union[int, typing.Callable]):\n", 159 | " \"\"\"\n", 160 | " Cuts `model` into an `nn.Sequential` based on `cut`. 
\n", 161 | " \"\"\"\n", 162 | " if isinstance(cut, int):\n", 163 | " return nn.Sequential(*list(model.children())[:cut])\n", 164 | " elif callable(cut):\n", 165 | " return cut(model)\n", 166 | " else:\n", 167 | " raise NameError(\"`cut` must either be an integer or a function\")" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "59a6edbe-0579-456b-a0d8-9d418957f796", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "class CustomTimmBody(nn.Module):\n", 178 | " \"\"\"\n", 179 | " A small submodule to work with `timm` models more easily\n", 180 | " \"\"\"\n", 181 | " def __init__(\n", 182 | " self, \n", 183 | " model, \n", 184 | " pretrained:bool=True, \n", 185 | " cut=None, \n", 186 | " n_in:int=3\n", 187 | " ):\n", 188 | " super().__init__()\n", 189 | " self.needs_pooling = model.default_cfg.get('pool_size', None)\n", 190 | " if cut is None:\n", 191 | " self.model = model\n", 192 | " else:\n", 193 | " self.model = custom_cut_model(model, cut)\n", 194 | " \n", 195 | " def forward(self, x): \n", 196 | " if self.needs_pooling:\n", 197 | " return self.model.forward_features(x)\n", 198 | " else:\n", 199 | " return self.model(x)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "ec72cb40-b963-4558-8fde-b3b7e87d601b", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "body = CustomTimmBody(\n", 210 | " create_model(\"vit_tiny_patch16_224\", pretrained=True, num_classes=0, in_chans=3)\n", 211 | ").train()" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "id": "8e3c549f-0abc-4f7a-9b98-52d6d36feec1", 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "head = create_head(body.model.num_features, dls.c, pool=None)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "id": "7341eedb-6249-43be-b11c-f51aa7454166", 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "head" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "c2cfe3dc-5db0-442f-858c-4fbc167391b8", 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "x = torch.randn(2,3,224,224)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "id": "e569fb07-fc24-4d9a-84a2-35aaaef227f5", 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "out = head(body(x))\n", 252 | "out, out.shape" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "id": "cfa5b80b-9a96-43fd-8c26-0f4c53105453", 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "apply_init?" 
263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "id": "fce78840-d9d1-4129-ad42-5047c4f6d95c", 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "apply_init(head)" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "id": "69a98b76-91ff-4c2a-a3dd-84f3c129f6f4", 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "head(body(x))" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "id": "44ff7136-c3f3-4464-951b-1bb1f189bed5", 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [ 292 | "def my_split_func(model:nn.Module):\n", 293 | " \"A function that splits layers by their parameters\"\n", 294 | " return L(model[0], model[1:]).map(params)" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "id": "3e7a18a8-8cbc-4f9f-8680-18e321e54c82", 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "def splitter(model):\n", 305 | " \"Splits a model by head and body\"\n", 306 | " return L(model[0], model[1]).map(params)" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "id": "384fde99-5b09-435a-9b08-b4b9e2534e56", 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "learn = Learner(\n", 317 | " dls,\n", 318 | " nn.Sequential(body, head),\n", 319 | " splitter=splitter\n", 320 | ")" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "id": "0300c6e8-4f5a-4862-8fac-1a37070f8892", 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "print(learn.summary()[-250:])" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "id": "cb25c1f5-142d-46ef-8070-2227ec83de91", 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "learn.freeze()" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "id": "04d17ea6-b1a2-4bb9-9917-87e49aa26fc1", 347 | "metadata": {}, 348 | "outputs": [], 349 | "source": [ 350 | "print(learn.summary()[-295:])" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "id": "21379796-a099-4c01-9ada-97e5a56271e0", 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [] 360 | } 361 | ], 362 | "metadata": { 363 | "kernelspec": { 364 | "display_name": "Python 3", 365 | "language": "python", 366 | "name": "python3" 367 | } 368 | }, 369 | "nbformat": 4, 370 | "nbformat_minor": 5 371 | } 372 | -------------------------------------------------------------------------------- /02_low_level.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "24491d4a-7940-44db-8900-c1f72b1b6d31", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from fastai.vision.all import *" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "0e60d5dd-44ab-4df7-bc3d-7b6534d25cbf", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "path = untar_data(URLs.MNIST); path, path.ls()" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "ac0ba66f-32b3-4172-9816-8efd366229a0", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "items = get_image_files(path)\n", 31 | "items[:10]" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": 
"d6d14076-7ee0-457b-945a-a3a42125b6e9", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "im = PILImageBW.create(items[0])" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "9f0d2d75-aafc-4f24-8851-5f323e77f3d7", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "im.show()" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "eb7f02e6-0d1d-41b8-88e0-117b4ee7863f", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "splitter = GrandparentSplitter(\n", 62 | " train_name=\"training\",\n", 63 | " valid_name=\"testing\",\n", 64 | ")" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "id": "4be60568-9f57-4128-ab98-7d36b52bc8ba", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "splits = splitter(items)\n", 75 | "splits[0][:5], splits[1][:5]" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "74ae470c-1cd8-4dc8-8fa9-1ba1366ce063", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "len(splits[0]), len(splits[1])" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "261eda5d-505a-41bc-9fa3-37e23c78dbcf", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "dsrc = Datasets(\n", 96 | " items,\n", 97 | " tfms=[[PILImageBW.create], [parent_label, Categorize]],\n", 98 | " splits=splits\n", 99 | ")" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "id": "c9e5b730-5c21-4080-a536-eeeca646cf90", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "show_at(dsrc.train, 3);" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "81b11e3b-5495-4949-84b3-199170fda8df", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "item_tfms = [CropPad(34), RandomCrop(size=28), ToTensor()]" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "5eeafd6b-6f80-4e23-b299-6c5f38f776e0", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "batch_tfms = [IntToFloatTensor(), Normalize()]" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "b7563354-5188-47cc-b186-8ce36e1abe4f", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "dls = dsrc.dataloaders(\n", 140 | " bs=128,\n", 141 | " after_item=item_tfms,\n", 142 | " after_batch=batch_tfms\n", 143 | ")" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "id": "b3fa8f1b-3447-4729-b458-44e71901d6ed", 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "dls.show_batch()" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "e026e54e-723a-4d44-b6a3-2666c58012ee", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "xb, yb = dls.one_batch()" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "id": "e2e35cbf-9d55-4e1f-a428-54b93c1dcb9f", 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "xb.shape, yb.shape" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "adf7f4af-0383-4780-96ec-33ec83019817", 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "dls.c" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "id": 
"d5d567bc-82fd-437a-bf88-3276d4b5e75f", 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "model = resnet18( num_classes=dls.c).cuda(); model.fc" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "id": "88049901-76b3-48cd-bea7-aea872a5632c", 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "model(xb)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "id": "2419b2bb-c6e4-4164-aea5-0d1ca00ecb92", 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "model.conv1" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "id": "1328f9f9-d70f-42b8-bac9-4d1de41d0588", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "model.conv1 = nn.Conv2d(\n", 224 | " in_channels=1, \n", 225 | " out_channels=64, \n", 226 | " kernel_size=(7,7), \n", 227 | " stride=(2,2), \n", 228 | " padding=(3,3), \n", 229 | " bias=False\n", 230 | ")" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "id": "1a3ecaf1-7e85-4fde-8e9a-39ff0a0ab854", 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "model.cuda();" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "id": "50dcff75-e794-400f-a12a-0ce142146e05", 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "model(xb)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "id": "34a31f94-0c5c-438c-a757-aa44eb72a487", 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "learn = Learner(dls, model, metrics=[accuracy])" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "id": "f484da66-64ed-495b-bb65-d60a9ec1b53f", 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "learn.fit(1)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "id": "4ff7b06e-3703-43d5-812e-9bc898613fa1", 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "items[0]" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "id": "e8ba7546-3b32-4711-8806-23bd4baba706", 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [ 290 | "preds = learn.predict(items[0]); preds" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "id": "7ae9adb4-e26e-4d06-831d-fbf773146d94", 297 | "metadata": {}, 298 | "outputs": [], 299 | "source": [ 300 | "dl = learn.dls.test_dl(items[:1])\n", 301 | "inps, preds, _, decoded_preds = learn.get_preds(dl=dl, with_decoded=True, with_input=True)\n", 302 | "image, class_prediction = learn.dls.decode_batch((inps,) + tuplify(decoded_preds))[0]" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "id": "3e4b30f4-5a75-43d6-ae3d-d0a94187bd33", 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [ 312 | "class_prediction, decoded_preds" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "id": "57488801-3b3c-443f-9f90-abc1e9a02bde", 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "learn.dls.after_item, learn.dls.after_batch" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "id": "bd8eb60d-f443-49b6-a20a-20ca6ac51b97", 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "learn.dls.after_batch[1].mean, learn.dls.after_batch[1].std" 333 | ] 
334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": null, 338 | "id": "11fdb9d9-9e7a-4233-98fd-1c24665dbe84", 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "type_tfms = Pipeline([PILImageBW.create])\n", 343 | "item_tfms = Pipeline([CropPad((34,34)), CropPad((28,28)), ToTensor()])\n", 344 | "batch_tfms = Pipeline([\n", 345 | " IntToFloatTensor(), \n", 346 | " Normalize.from_stats([[[[0.1302]]]], [[[[0.3081]]]])\n", 347 | "])" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "id": "b8720a65-9884-4b97-bad1-42e1fd3879c3", 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [ 357 | "items[0]" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "id": "4e408728-e4a0-41ad-85a8-0d324b3dc828", 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "im = type_tfms(items[0]); im.shape" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "id": "24d51d68-ddb0-48f9-8fc7-84b0f316722f", 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "item_tfms(im).shape" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "id": "20a4a68a-7b96-4519-b698-d8e86773c7ce", 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "batch_tfms(item_tfms(im).cuda()).shape" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "id": "ba799a20-8982-4963-a5e2-4f3960318652", 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "net = learn.model\n", 398 | "net.eval()\n", 399 | "t_im = batch_tfms(item_tfms(im).cuda())" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "id": "bea44f3f-5023-4154-becb-bc244cfc808d", 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "with torch.no_grad():\n", 410 | " out = net(t_im)" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "id": "73bd7abd-5fa1-4d77-8563-8d3989b5658e", 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "out.argmax(dim=-1)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "id": "14a8dd39-8150-4d80-aef0-a72bf847abba", 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "out.softmax(-1)" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "id": "ca8a02a7-09f1-4b5f-9a7a-3ebb2077e13d", 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [] 440 | } 441 | ], 442 | "metadata": {}, 443 | "nbformat": 4, 444 | "nbformat_minor": 5 445 | } 446 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /02a_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "99a9b5b8-b507-411d-a0ab-bb00eb2c475b", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from fastai.data.external import untar_data, URLs\n", 11 | "from fastai.vision.data import imagenet_stats\n", 12 | "from fastcore.xtras import Path\n", 13 | "\n", 14 | "dataset_path = untar_data(URLs.PETS)\n", 15 | "dataset_path.ls()" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "id": "83c9b3b2-afb8-4265-b81d-bfd55ef00cd0", 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "imagenet_stats" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "id": "c8f4f689-b0f0-4674-a328-1331d96a7f52", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "from torch import nn\n", 36 | "from torchvision.transforms import CenterCrop, RandomResizedCrop, ToTensor, Normalize\n", 37 | "\n", 38 | "train_transforms = nn.Sequential(\n", 39 | " RandomResizedCrop((224,224)),\n", 40 | " Normalize(*imagenet_stats)\n", 41 | ")\n", 42 | "\n", 43 | "valid_transforms = nn.Sequential(\n", 44 | " CenterCrop((224,224)),\n", 45 | " Normalize(*imagenet_stats)\n", 46 | ")" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": "82db4e5b-ba00-4aed-9f8a-0716c7459a6e", 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "import re\n", 57 | "from PIL import Image\n", 58 | "from torch.utils.data import Dataset\n", 59 | "\n", 60 | "# This example is highly based on the work of Sylvain Gugger\n", 61 | "# for the Accelerate notebook example which can be found here: \n", 62 | "# https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_cv_example.ipynb\n", 63 | "class PetsDataset(Dataset):\n", 64 | " \"A basic dataset that will return a tuple of (image, label)\"\n", 65 | " def __init__(self, filenames:list, transforms:nn.Sequential, label_to_int:dict):\n", 66 | " self.filenames = filenames\n", 67 | " self.transforms = transforms\n", 68 | " self.label_to_int = label_to_int\n", 69 | " self.to_tensor = ToTensor()\n", 70 | " \n", 71 | " def __len__(self):\n", 72 | " return len(self.filenames)\n", 73 | " \n", 74 | " def apply_x_transforms(self, filename):\n", 75 | " image = Image.open(filename).convert(\"RGB\")\n", 76 | " tensor_image = self.to_tensor(image)\n", 77 | " return self.transforms(tensor_image)\n", 78 | " \n", 79 | " def apply_y_transforms(self, filename):\n", 80 | " label = re.findall(r\"^(.*)_\\d+\\.jpg$\", filename.name)[0].lower()\n", 81 | " return self.label_to_int[label]\n", 82 | " \n", 83 | " def __getitem__(self, index):\n", 84 | " filename = self.filenames[index]\n", 85 | " x = self.apply_x_transforms(filename)\n", 86 | " y = self.apply_y_transforms(filename)\n", 87 | " return (x,y)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | 
"execution_count": null, 93 | "id": "4ebc33fd-3fb2-4af7-9f7b-850998a744f6", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "label_pat = r\"^(.*)_\\d+\\.jpg$\"\n", 98 | "filenames = (dataset_path/'images').ls(file_exts=\".jpg\")" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "id": "10be2949-af5f-4bdd-977b-1235eedfb764", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "labels = filenames.map(\n", 109 | " lambda x: re.findall(label_pat, x.name)[0].lower()\n", 110 | ").unique()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "4f9d3be0-cb05-493d-8b39-c73620d78465", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "labels" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "id": "7d7003a6-05d3-4254-8f44-14fe57810c88", 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "label_to_int = {index:key for key, index in enumerate(labels)}\n", 131 | "label_to_int.keys(), label_to_int[\"siamese\"]" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "0b0de16b-e7dc-4f11-856f-2195a8f7673b", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "import numpy as np\n", 142 | "shuffled_indexes = np.random.permutation(len(filenames))\n", 143 | "split = int(0.8 * len(filenames))\n", 144 | "train_indexes, valid_indexes = (\n", 145 | " shuffled_indexes[:split], shuffled_indexes[split:]\n", 146 | ")" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "id": "8cc8df81-cee4-434c-acf1-af9f6842d701", 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "train_fnames = filenames[train_indexes]\n", 157 | "valid_fnames = filenames[valid_indexes]" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": "6292ed9f-445b-43d0-98fc-6d5bbb936374", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "train_dataset = PetsDataset(\n", 168 | " train_fnames,\n", 169 | " train_transforms,\n", 170 | " label_to_int\n", 171 | ")\n", 172 | "\n", 173 | "valid_dataset = PetsDataset(\n", 174 | " valid_fnames,\n", 175 | " valid_transforms,\n", 176 | " label_to_int\n", 177 | ")" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "id": "aba8412a-db63-4fd1-bf65-9bd9648eaf39", 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "x,y = train_dataset[0]\n", 188 | "x.shape, y" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "id": "236646dd-6fd4-469e-a8f0-2a165bc924d7", 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "from torch.utils.data import DataLoader" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "e691d373-5817-4a48-9b1c-77933234f83b", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "train_dataloader = DataLoader(\n", 209 | " train_dataset,\n", 210 | " shuffle=True,\n", 211 | " drop_last=True,\n", 212 | " batch_size=64\n", 213 | ")" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "id": "40639473-c2a5-4c1d-8276-53f64e3fc442", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "valid_dataloader = DataLoader(\n", 224 | " valid_dataset,\n", 225 | " batch_size=128\n", 226 | ")" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 
| "id": "16689441-3e66-4434-b7dc-256d71fecd16", 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "from fastai.data.core import DataLoaders" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "id": "574c31aa-c882-416e-933b-c06273ad0553", 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "dls = DataLoaders(train_dataloader, valid_dataloader)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "id": "d4007848-bc01-4fd1-bd2d-5bde58243664", 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "from torchvision.models import resnet34\n", 257 | "\n", 258 | "model = resnet34(pretrained=True)" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "id": "23f9228b-a6d2-4f40-a117-385593b174aa", 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "model.fc = nn.Linear(512, 37, bias=True)" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "id": "c1c43a7b-7a46-4c62-81bb-8ca73f657a95", 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "model.fc" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "id": "3071120c-dfc1-49d2-9ee1-8eeb999048ae", 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "list(model.children())[-1]" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "id": "585c1494-6a2f-4bd3-a810-23588e0621bd", 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "for layer in list(model.children())[:-1]:\n", 299 | " if hasattr(layer, \"requires_grad_\"):\n", 300 | " layer.requires_grad_(False)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "id": "cd7448fc-fdb6-4c9e-b3ea-85e944e3ef72", 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "from torch.optim import AdamW" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "id": "d366e09d-c9e2-4b28-8f30-f088e14e44e7", 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "from functools import partial\n", 321 | "from fastai.optimizer import OptimWrapper" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "id": "01628357-f76d-4467-b8a9-f1380fed315b", 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "opt_func = partial(OptimWrapper, opt=AdamW)" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": null, 337 | "id": "c9c81149-a93b-44d4-af8c-f06b204c5a50", 338 | "metadata": {}, 339 | "outputs": [], 340 | "source": [ 341 | "from fastai.losses import CrossEntropyLossFlat\n", 342 | "from fastai.metrics import accuracy\n", 343 | "from fastai.learner import Learner\n", 344 | "from fastai.callback.schedule import Learner" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "id": "f81ad8ff-80ba-4a1d-a9d7-5cfec97c6508", 351 | "metadata": {}, 352 | "outputs": [], 353 | "source": [ 354 | "model.cuda();" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "id": "c6515ef3-baeb-4a95-a09a-c1a543f7f4ca", 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "learn = Learner(\n", 365 | " dls, \n", 366 | " model, \n", 367 | " opt_func=opt_func, \n", 368 | " loss_func=CrossEntropyLossFlat(), \n", 369 | " metrics=accuracy\n", 370 | ")" 371 | ] 372 | }, 373 | { 374 | "cell_type": 
"code", 375 | "execution_count": null, 376 | "id": "913da4a0-2f49-449c-a585-e040d7693111", 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "learn.lr_find()" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "id": "e7134a0a-3a0f-41cb-abbe-c5c325d8f47e", 387 | "metadata": {}, 388 | "outputs": [], 389 | "source": [ 390 | "learn.fit_one_cycle(5, 1e-3)" 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "id": "2c8ccb91-2029-4f94-be4e-7b8b4466896e", 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "im = Image.open(filenames[0])\n", 401 | "im" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "id": "9e1ba32d-1442-4026-8ae7-16f5c4b4d3e4", 408 | "metadata": {}, 409 | "outputs": [], 410 | "source": [ 411 | "net = learn.model" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "id": "323e789a-e245-468b-bbbe-711615c03aae", 418 | "metadata": {}, 419 | "outputs": [], 420 | "source": [ 421 | "tfm_x = valid_transforms(ToTensor()(im))\n", 422 | "tfm_x = tfm_x.unsqueeze(0); tfm_x.shape" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "id": "2594e87b-2455-4aa5-a7bd-e0878357932b", 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "import torch\n", 433 | "net.eval()\n", 434 | "with torch.no_grad():\n", 435 | " preds = net(tfm_x.cuda())\n", 436 | "pred = preds.argmax(dim=-1)[0]\n", 437 | "label = list(label_to_int.keys())[pred]\n", 438 | "pred, label" 439 | ] 440 | }, 441 | { 442 | "cell_type": "code", 443 | "execution_count": null, 444 | "id": "128835a8-b396-49de-8434-f4ac5fbdac51", 445 | "metadata": {}, 446 | "outputs": [], 447 | "source": [] 448 | } 449 | ], 450 | "metadata": {}, 451 | "nbformat": 4, 452 | "nbformat_minor": 5 453 | } 454 | -------------------------------------------------------------------------------- /03_multilabel.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "14702a64-48c1-4ab8-b202-e5d3784910ed", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "from fastai.vision.all import *" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "bcb93279-3597-4dac-a63f-ac1c1bc85c98", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import torch\n", 23 | "from torch import tensor\n", 24 | "from torchvision.models.resnet import resnet34\n", 25 | "from PIL import Image\n", 26 | "from itertools import compress\n", 27 | "\n", 28 | "import pandas as pd\n", 29 | "from pathlib import Path\n", 30 | "from fastcore.xtras import Path\n", 31 | "\n", 32 | "from fastai.data.core import show_at, Datasets\n", 33 | "from fastai.data.external import URLs, untar_data\n", 34 | "from fastai.data.transforms import (\n", 35 | " ColReader,\n", 36 | " IntToFloatTensor, \n", 37 | " MultiCategorize, \n", 38 | " Normalize,\n", 39 | " OneHotEncode, \n", 40 | " RandomSplitter,\n", 41 | ")\n", 42 | "\n", 43 | "from fastai.metrics import accuracy_multi\n", 44 | "\n", 45 | "from fastai.vision.augment import aug_transforms\n", 46 | "from fastai.vision.core import PILImage\n", 47 | "from fastai.vision.learner import vision_learner\n", 48 | "from fastai.learner import Learner\n", 49 | "from fastai.callback.schedule import Learner" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | 
"execution_count": 2, 55 | "id": "68b5019e-86f7-4750-adb9-745d18ac1adf", 56 | "metadata": { 57 | "tags": [] 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "src = untar_data(URLs.PLANET_SAMPLE)\n", 62 | "df = pd.read_csv(src/'labels.csv')" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "id": "ba23ad72-f155-4627-8775-04203a3c2060", 69 | "metadata": { 70 | "tags": [] 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "df.head()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 4, 80 | "id": "d115ffaf-9fc7-46f6-9f10-7af399acae74", 81 | "metadata": { 82 | "tags": [] 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "all_tags = df[\"tags\"].values\n", 87 | "all_labels = []\n", 88 | "for row in all_tags:\n", 89 | " all_labels += row.split(\" \")\n", 90 | "len(all_labels)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 5, 96 | "id": "382633fb-3b76-4547-b1b6-4441a7403ebb", 97 | "metadata": { 98 | "tags": [] 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "different_labels = set(all_labels)\n", 103 | "len(different_labels)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 6, 109 | "id": "e749245e-db15-4c62-9363-0cad2a59ed86", 110 | "metadata": { 111 | "tags": [] 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "counts = {\n", 116 | " label: all_labels.count(label) \n", 117 | " for label in different_labels\n", 118 | "}\n", 119 | "\n", 120 | "counts = {\n", 121 | " key: value \n", 122 | " for key, value in \n", 123 | " sorted(\n", 124 | " counts.items(), \n", 125 | " key = lambda item: -item[1]\n", 126 | " )\n", 127 | "}" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 7, 133 | "id": "10048a95-6db6-4dfb-8db4-72d1c55769b7", 134 | "metadata": { 135 | "tags": [] 136 | }, 137 | "outputs": [], 138 | "source": [ 139 | "counts" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": 8, 145 | "id": "3283dfcd-f750-4afc-9982-946a1f01e4da", 146 | "metadata": { 147 | "tags": [] 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "len(df)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 9, 157 | "id": "de663a78-3d39-42f5-a8b4-248193ed5a62", 158 | "metadata": { 159 | "tags": [] 160 | }, 161 | "outputs": [], 162 | "source": [ 163 | "for key, count in counts.items():\n", 164 | " if count < 10:\n", 165 | " df = df[df[\"tags\"].str.contains(key) == False]" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 10, 171 | "id": "c2831e62-4a79-4091-b71b-bfe77e8c7538", 172 | "metadata": { 173 | "tags": [] 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "len(df)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 11, 183 | "id": "1c431bda-f8be-4dff-962b-d3fb38c9742b", 184 | "metadata": { 185 | "tags": [] 186 | }, 187 | "outputs": [], 188 | "source": [ 189 | "df[\"image_name\"].head(), src.ls()" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 12, 195 | "id": "6bf44fc9-b12d-48e5-8efd-b9dd3acf7024", 196 | "metadata": { 197 | "tags": [] 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "(src/'train').ls()[:3]" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 13, 207 | "id": "5ad5f7f6-a660-44b0-94f2-47f9f1775fae", 208 | "metadata": { 209 | "tags": [] 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "PILImage.create((src/'train'/'train_2407.jpg'))" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | 
"execution_count": 14, 219 | "id": "4d517878-7b30-4223-9351-aee5e9ac5dee", 220 | "metadata": { 221 | "tags": [] 222 | }, 223 | "outputs": [], 224 | "source": [ 225 | "def get_x(row:pd.Series) -> Path:\n", 226 | " return (src/'train'/row.image_name).with_suffix(\".jpg\")" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 15, 232 | "id": "5748c3b8-9c13-4b73-a7a0-e09ecb9bdf5e", 233 | "metadata": { 234 | "tags": [] 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "def get_y(row:pd.Series) -> List[str]:\n", 239 | " return row.tags.split(\" \")" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 16, 245 | "id": "d923e47b-695b-4111-8c3c-7fb31220cb39", 246 | "metadata": { 247 | "tags": [] 248 | }, 249 | "outputs": [], 250 | "source": [ 251 | "row = df.iloc[0]\n", 252 | "get_x(row), get_y(row)" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": 17, 258 | "id": "7fe37edd-2632-4f78-b304-a87e7c521912", 259 | "metadata": { 260 | "tags": [] 261 | }, 262 | "outputs": [], 263 | "source": [ 264 | "get_x = ColReader(0, pref=f'{src}/train/', suff=\".jpg\")\n", 265 | "get_y = ColReader(1, label_delim=\" \")" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 18, 271 | "id": "a7ac3d9b-55e3-409d-95ec-2fa28bf9c331", 272 | "metadata": { 273 | "tags": [] 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "tfms = [\n", 278 | " [get_x, PILImage.create], \n", 279 | " [\n", 280 | " get_y,\n", 281 | " MultiCategorize(vocab=different_labels), \n", 282 | " OneHotEncode(len(different_labels))\n", 283 | " ]\n", 284 | "]" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 19, 290 | "id": "e15b8f4f-783f-4c2e-a604-c0147bc6e69a", 291 | "metadata": { 292 | "tags": [] 293 | }, 294 | "outputs": [], 295 | "source": [ 296 | "train_idxs, valid_idxs = (\n", 297 | " RandomSplitter(valid_pct=0.2, seed=42)(df)\n", 298 | ")" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 20, 304 | "id": "9a325977-62fa-4ce4-a12c-d4cfa4decf1f", 305 | "metadata": { 306 | "tags": [] 307 | }, 308 | "outputs": [], 309 | "source": [ 310 | "train_idxs, valid_idxs" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 21, 316 | "id": "35cbf350-f37a-4447-ace7-5e931e2084c5", 317 | "metadata": { 318 | "tags": [] 319 | }, 320 | "outputs": [], 321 | "source": [ 322 | "dsets = Datasets(df, tfms=tfms, splits=[train_idxs, valid_idxs])" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 22, 328 | "id": "cabc4923-cad7-42d5-b65a-be04548f2a72", 329 | "metadata": { 330 | "tags": [] 331 | }, 332 | "outputs": [], 333 | "source": [ 334 | "dsets.train[0]" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "id": "80fbe5db-6d13-4e68-8472-ad9a953cd2bf", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "show_at(dsets.train, 0);" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "id": "b641b198-94fe-4a32-a4ff-a4754f8ded02", 351 | "metadata": {}, 352 | "outputs": [], 353 | "source": [ 354 | "batch_tfms = [\n", 355 | " IntToFloatTensor(), \n", 356 | " *aug_transforms(\n", 357 | " flip_vert=True, \n", 358 | " max_lighting=0.1, \n", 359 | " max_zoom=1.05, \n", 360 | " max_warp=0.\n", 361 | " ), \n", 362 | " Normalize.from_stats(*imagenet_stats)\n", 363 | "]" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "id": 
"6f81d974-abb3-4c1d-865a-81544ff0122f", 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [ 373 | "dls = dsets.dataloaders(\n", 374 | " after_item=[ToTensor], \n", 375 | " after_batch=batch_tfms\n", 376 | ")" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "id": "55fe631e-a0af-4632-bd9c-f1720dc04777", 383 | "metadata": {}, 384 | "outputs": [], 385 | "source": [ 386 | "dls.device" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "id": "11e44b40-c147-4c03-b7eb-4d145b66ec34", 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "dls.show_batch()" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "id": "86c44dc1-cc30-4c38-a909-fa673cb532e2", 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [ 406 | "learn = vision_learner(dls, resnet34, metrics=[accuracy_multi])" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "id": "58d5d79b-4a9f-4a92-8808-0ac91c413a6c", 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "learn.model[1]" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "id": "fed6d6bc-6a3b-4bee-ae74-7301eb03d4b5", 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [ 426 | "learn.loss_func" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "id": "5c2c4080-3e7e-4b78-a1a6-80666cdd9cda", 433 | "metadata": {}, 434 | "outputs": [], 435 | "source": [ 436 | "t = tensor([[0.1, 0.5, 0.3, 0.7, 0.2]])\n", 437 | "torch.sigmoid(t)" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "id": "7d614f54-e25e-4782-856b-13ff0279143e", 444 | "metadata": {}, 445 | "outputs": [], 446 | "source": [ 447 | "learn.loss_func.thresh" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "id": "34badbad-84c3-4e2b-a060-1b2885f24af5", 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "learn.lr_find()" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "id": "1193da19-8e6b-4831-b143-485986ad65dd", 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [ 467 | "learn.fit_one_cycle(1, slice(2e-3))" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 | "id": "3163a0d7-2ac5-4c92-811e-c5ca15bed20a", 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "learn.unfreeze()\n", 478 | "learn.fit_one_cycle(5, slice(2e-3/2.6**4, 2e-3))" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": null, 484 | "id": "a69dd9bc-f14a-487c-8d6a-8f67b38d5503", 485 | "metadata": {}, 486 | "outputs": [], 487 | "source": [ 488 | "learn.show_results(figsize=(15,15))" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": 24, 494 | "id": "3c1c332e-89b8-415d-bd72-0995c39d6506", 495 | "metadata": { 496 | "tags": [] 497 | }, 498 | "outputs": [], 499 | "source": [ 500 | "model = learn.model\n", 501 | "fname = get_x(df.iloc[0])" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": 2, 507 | "id": "a5b45d36-a544-4201-a27b-8d54236cf21d", 508 | "metadata": { 509 | "tags": [] 510 | }, 511 | "outputs": [], 512 | "source": [ 513 | "fname = '/home/zach/.fastai/data/planet_sample/train/train_21983.jpg'" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": 5, 519 | "id": "2e062c32-e93f-4611-a9c3-a7edada91546", 520 
| "metadata": { 521 | "tags": [] 522 | }, 523 | "outputs": [], 524 | "source": [ 525 | "from torchvision.transforms import PILToTensor" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 7, 531 | "id": "d3f0e8e2-1f3e-42be-8b33-58c4c515d5a1", 532 | "metadata": { 533 | "tags": [] 534 | }, 535 | "outputs": [], 536 | "source": [ 537 | "im = Image.open(fname)\n", 538 | "im = im.convert(\"RGB\")\n", 539 | "t_im = PILToTensor()(im)" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": 9, 545 | "id": "ecdb8c02-ebc2-443a-9a36-4e7aee88e746", 546 | "metadata": { 547 | "tags": [] 548 | }, 549 | "outputs": [], 550 | "source": [ 551 | "t_im = t_im.unsqueeze(0)\n", 552 | "t_im = t_im.float().div_(255.)" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "id": "57e408c8-1cbf-418a-b9b2-7de4055c6f7f", 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [ 562 | "mean, std = (\n", 563 | " [0.485, 0.456, 0.406], \n", 564 | " [0.229, 0.224, 0.225]\n", 565 | ")\n", 566 | "vector = [1]*4\n", 567 | "vector[1] = -1\n", 568 | "mean = tensor(mean).view(*vector)\n", 569 | "std = tensor(std).view(*vector)" 570 | ] 571 | }, 572 | { 573 | "cell_type": "code", 574 | "execution_count": null, 575 | "id": "c9138bce-ad72-46b5-a239-5dc0cd9c45f1", 576 | "metadata": {}, 577 | "outputs": [], 578 | "source": [ 579 | "mean.shape, std.shape" 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": null, 585 | "id": "90ea047f-4e40-4a52-9c13-d491998fcabe", 586 | "metadata": {}, 587 | "outputs": [], 588 | "source": [ 589 | "t_im = (t_im - mean) / std" 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": null, 595 | "id": "a0beea08-3327-483e-b627-eade2c8a4b06", 596 | "metadata": {}, 597 | "outputs": [], 598 | "source": [ 599 | "t_im.shape" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": null, 605 | "id": "384bfde5-2e5c-444c-918b-029017957cfb", 606 | "metadata": {}, 607 | "outputs": [], 608 | "source": [ 609 | "with torch.inference_mode():\n", 610 | " model.eval()\n", 611 | " preds = model(t_im.cuda())" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "id": "2c761bf2-89c5-4ff4-bbd1-005f38e81dbd", 618 | "metadata": {}, 619 | "outputs": [], 620 | "source": [ 621 | "preds.shape" 622 | ] 623 | }, 624 | { 625 | "cell_type": "code", 626 | "execution_count": null, 627 | "id": "ee3419db-955e-4f85-931f-d5d4febadb8d", 628 | "metadata": {}, 629 | "outputs": [], 630 | "source": [ 631 | "decoded_preds = torch.sigmoid(preds) > 0.5" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": null, 637 | "id": "f58dc6bd-a4f9-475a-a530-5f813db75aa9", 638 | "metadata": {}, 639 | "outputs": [], 640 | "source": [ 641 | "decoded_preds" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": null, 647 | "id": "d449629b-fbe3-4b41-bb50-8365f17b02b9", 648 | "metadata": {}, 649 | "outputs": [], 650 | "source": [ 651 | "from itertools import compress" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": null, 657 | "id": "cd9179e8-62b2-4577-af81-390215e3ebc2", 658 | "metadata": {}, 659 | "outputs": [], 660 | "source": [ 661 | "present_labels = list(compress(\n", 662 | " data=list(different_labels), selectors=decoded_preds[0]\n", 663 | " ))" 664 | ] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "execution_count": null, 669 | "id": "a1949f60-2866-4e7e-aae7-23b29cd366fd", 670 | "metadata": {}, 671 
| "outputs": [], 672 | "source": [ 673 | "present_labels" 674 | ] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "execution_count": null, 679 | "id": "030ce8aa-511e-4337-9fcf-e42c276f4f89", 680 | "metadata": {}, 681 | "outputs": [], 682 | "source": [ 683 | "learn.predict(fname)[0]" 684 | ] 685 | }, 686 | { 687 | "cell_type": "code", 688 | "execution_count": null, 689 | "id": "9463d1a2-6ab0-42ed-89a7-708a3e35b71c", 690 | "metadata": {}, 691 | "outputs": [], 692 | "source": [ 693 | "im = Image.open(fname)\n", 694 | "im = im.convert(\"RGB\")\n", 695 | "t_im = PILToTensor()(im)\n", 696 | "\n", 697 | "mean, std = (\n", 698 | " [0.485, 0.456, 0.406], \n", 699 | " [0.229, 0.224, 0.225]\n", 700 | ")\n", 701 | "vector = [1]*4\n", 702 | "vector[1] = -1\n", 703 | "mean = tensor(mean).view(*vector)\n", 704 | "std = tensor(std).view(*vector)\n", 705 | "t_im = (t_im - mean) / std\n", 706 | "with torch.inference_mode():\n", 707 | " model.eval()\n", 708 | " preds = model(t_im.cuda())\n", 709 | " \n", 710 | "decoded_preds = torch.sigmoid(preds) > 0.5\n", 711 | "\n", 712 | "present_labels = list(compress(\n", 713 | " data=list(different_labels), selectors=decoded_preds[0]\n", 714 | " ))" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "execution_count": null, 720 | "id": "0099cf64-e9b9-411e-b34f-75e9b804755b", 721 | "metadata": {}, 722 | "outputs": [], 723 | "source": [] 724 | } 725 | ], 726 | "metadata": {}, 727 | "nbformat": 4, 728 | "nbformat_minor": 5 729 | } 730 | -------------------------------------------------------------------------------- /04_semantic_segmentation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "24d0f939-0f6c-4904-8824-7c87a9f160c1", 7 | "metadata": { 8 | "tags": [], 9 | "id": "24d0f939-0f6c-4904-8824-7c87a9f160c1" 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "from fastai.vision.all import *" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "ad1a14bc-8e95-4779-a9ee-3c34af86281b", 20 | "metadata": { 21 | "id": "ad1a14bc-8e95-4779-a9ee-3c34af86281b" 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "url = \"https://drive.google.com/uc?id=18xM3jU2dSp1DiDqEM6PVXattNMZvsX4z\"" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "id": "2bd90cf6-e1f3-408b-bd91-3b50544745c9", 32 | "metadata": { 33 | "id": "2bd90cf6-e1f3-408b-bd91-3b50544745c9" 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "!gdown {url}" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 6, 43 | "id": "77f5eb23-f80d-402b-81bc-c67878700f63", 44 | "metadata": { 45 | "id": "77f5eb23-f80d-402b-81bc-c67878700f63" 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "from zipfile import ZipFile\n", 50 | "\n", 51 | "with ZipFile(\"Portrait.zip\", \"r\") as zip_ref:\n", 52 | " zip_ref.extractall(\"data\")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 7, 58 | "id": "718e08fe-7559-4752-b082-eb19d603cb00", 59 | "metadata": { 60 | "id": "718e08fe-7559-4752-b082-eb19d603cb00" 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "path = Path(\"data\")" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "id": "e0612b60-4cdc-423f-ae60-ff3e65db5a00", 71 | "metadata": { 72 | "id": "e0612b60-4cdc-423f-ae60-ff3e65db5a00" 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "for walk in path.ls():\n", 77 | " print(repr(walk), walk.is_file())" 78 | ] 79 | }, 80 
| { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "8c560e41-dbce-4bdd-8d83-c89f8164bb5b", 84 | "metadata": { 85 | "id": "8c560e41-dbce-4bdd-8d83-c89f8164bb5b" 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "(path/\"GT_png\").ls()[0]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 10, 95 | "id": "17f407fa-3eb7-4424-ba37-84787d8697e0", 96 | "metadata": { 97 | "id": "17f407fa-3eb7-4424-ba37-84787d8697e0" 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "mask = Image.open((path/\"GT_png\").ls()[0])" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "id": "a4d8b456-2abb-41c6-bf28-7dac707577b5", 108 | "metadata": { 109 | "id": "a4d8b456-2abb-41c6-bf28-7dac707577b5" 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "mask" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "c99b4bbe-9d3b-4463-9d68-ea0d5a0c9634", 120 | "metadata": { 121 | "id": "c99b4bbe-9d3b-4463-9d68-ea0d5a0c9634" 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "mask = np.asarray(mask); mask" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 13, 131 | "id": "55cda0ab-37ae-43d4-848b-a48b622c1ab5", 132 | "metadata": { 133 | "id": "55cda0ab-37ae-43d4-848b-a48b622c1ab5" 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "def get_codes(fnames) -> Dict[int,int]: \n", 138 | " \"Returns a dictionary of `original_code:new_code` for pixel values in segmentation masks\"\n", 139 | " unique_codes = set()\n", 140 | " for fname in fnames:\n", 141 | " mask = Image.open(fname)\n", 142 | " mask = np.asarray(mask)\n", 143 | " for color in np.unique(mask):\n", 144 | " unique_codes.add(color)\n", 145 | " return {\n", 146 | " i : color\n", 147 | " for i, color in \n", 148 | " enumerate(unique_codes)\n", 149 | " }" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "id": "0925d08f-4f86-4525-b806-01ffa2c83eba", 156 | "metadata": { 157 | "id": "0925d08f-4f86-4525-b806-01ffa2c83eba" 158 | }, 159 | "outputs": [], 160 | "source": [ 161 | "unique_codes = get_codes((path/\"GT_png\").ls()[:20])\n", 162 | "unique_codes" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "id": "ff4f706f-8414-4e1a-9372-cc171d34e74e", 169 | "metadata": { 170 | "id": "ff4f706f-8414-4e1a-9372-cc171d34e74e" 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "mask = mask.copy()\n", 175 | "np.place(mask, mask==255, 1)\n", 176 | "np.unique(mask)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 16, 182 | "id": "a806dc9a-7014-4339-836a-ec785c2809b5", 183 | "metadata": { 184 | "id": "a806dc9a-7014-4339-836a-ec785c2809b5" 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "codes = [\"Background\", \"Face\"]\n", 189 | "blocks = (ImageBlock, MaskBlock(codes=codes))" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "id": "7be26a09-57b5-4e4f-be9c-6dbb5ecbb0dc", 196 | "metadata": { 197 | "id": "7be26a09-57b5-4e4f-be9c-6dbb5ecbb0dc" 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "unique_codes" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 18, 207 | "id": "323e5552-a836-4913-9e33-f372f19334ab", 208 | "metadata": { 209 | "id": "323e5552-a836-4913-9e33-f372f19334ab" 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "def get_y(filename:Path, unique_codes:dict):\n", 214 | " \"Grabs a mask from `filename` and adjusts the pixel values based 
on `unique_codes`\"\n", 215 | " filename = path/\"GT_png\"/f'{filename.stem}_mask.png'\n", 216 | " mask = np.asarray(Image.open(filename)).copy()\n", 217 | " for new_value, old_value in unique_codes.items():\n", 218 | " np.place(mask, mask==old_value, new_value)\n", 219 | " return PILMask.create(mask)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "b13cafa2-def9-4560-a88a-1a0573a366d7", 226 | "metadata": { 227 | "id": "b13cafa2-def9-4560-a88a-1a0573a366d7" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "new_mask = get_y((path/\"images_data_crop\").ls()[0], unique_codes)\n", 232 | "new_mask.show(cmap=\"Blues\");" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 20, 238 | "id": "d261e2ca-fbc4-484b-9703-81ef2e82b08c", 239 | "metadata": { 240 | "id": "d261e2ca-fbc4-484b-9703-81ef2e82b08c" 241 | }, 242 | "outputs": [], 243 | "source": [ 244 | "block = DataBlock(\n", 245 | " blocks=blocks,\n", 246 | " splitter=RandomSplitter(),\n", 247 | " get_y=partial(get_y, unique_codes=unique_codes),\n", 248 | " item_tfms=Resize(224),\n", 249 | " batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)]\n", 250 | ")" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": 21, 256 | "id": "8b61743e-a963-40e6-8bac-63b83a91c4f3", 257 | "metadata": { 258 | "id": "8b61743e-a963-40e6-8bac-63b83a91c4f3" 259 | }, 260 | "outputs": [], 261 | "source": [ 262 | "dls = block.dataloaders(\n", 263 | " get_image_files(path/'images_data_crop'), \n", 264 | " bs=8\n", 265 | ")" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "id": "b8833731-bc07-4e17-a164-be8533b9c264", 272 | "metadata": { 273 | "id": "b8833731-bc07-4e17-a164-be8533b9c264" 274 | }, 275 | "outputs": [], 276 | "source": [ 277 | "dls.show_batch(cmap=\"Blues\", vmin=0, vmax=1)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 23, 283 | "id": "f41f7f4e-0e7e-487c-9a3a-30a1f7090901", 284 | "metadata": { 285 | "id": "f41f7f4e-0e7e-487c-9a3a-30a1f7090901" 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "splitter = RandomSplitter()\n", 290 | "dsets = Datasets(\n", 291 | " get_image_files(path/'images_data_crop'),\n", 292 | " tfms=[\n", 293 | " [PILImage.create], \n", 294 | " [partial(get_y, unique_codes=unique_codes)]\n", 295 | " ],\n", 296 | " splits = splitter(get_image_files(path/'images_data_crop'))\n", 297 | ")" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 24, 303 | "id": "8a1a81ac-ead0-4689-9fdc-f11bd3602f9d", 304 | "metadata": { 305 | "id": "8a1a81ac-ead0-4689-9fdc-f11bd3602f9d" 306 | }, 307 | "outputs": [], 308 | "source": [ 309 | "dls = dsets.dataloaders(\n", 310 | " after_item = [\n", 311 | " Resize(224), \n", 312 | " ToTensor(), \n", 313 | " AddMaskCodes(codes=codes)\n", 314 | " ],\n", 315 | " after_batch = [\n", 316 | " *aug_transforms(), \n", 317 | " IntToFloatTensor(), \n", 318 | " Normalize.from_stats(*imagenet_stats)\n", 319 | " ],\n", 320 | " bs=8\n", 321 | ")" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "id": "7437750b-3628-4f7f-bcee-5ee9ab767ae3", 328 | "metadata": { 329 | "id": "7437750b-3628-4f7f-bcee-5ee9ab767ae3" 330 | }, 331 | "outputs": [], 332 | "source": [ 333 | "dls.show_batch(cmap=\"Blues\", vmin=0, vmax=1)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 27, 339 | "id": "855b54be-61a2-47e0-b958-a76bc1946769", 340 | "metadata": { 341 | "id": 
"855b54be-61a2-47e0-b958-a76bc1946769" 342 | }, 343 | "outputs": [], 344 | "source": [ 345 | "learn = unet_learner(\n", 346 | " dls, \n", 347 | " resnet34, \n", 348 | " metrics=partial(accuracy, axis=1), \n", 349 | " self_attention=True, \n", 350 | " act_cls=Mish,\n", 351 | " loss_func = CrossEntropyLossFlat(axis=1)\n", 352 | ")" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": null, 358 | "id": "b5f43d15-a3d7-434d-84c7-e5b6e9ae6feb", 359 | "metadata": { 360 | "id": "b5f43d15-a3d7-434d-84c7-e5b6e9ae6feb" 361 | }, 362 | "outputs": [], 363 | "source": [ 364 | "learn.summary()" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "id": "b7987b40-577f-4bb6-a390-e5bee32b66bf", 371 | "metadata": { 372 | "id": "b7987b40-577f-4bb6-a390-e5bee32b66bf" 373 | }, 374 | "outputs": [], 375 | "source": [ 376 | "learn.fit_one_cycle(10, 1e-3)" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "id": "b98a4bbe-33a4-42a9-9aba-ac33d69f0f41", 383 | "metadata": { 384 | "id": "b98a4bbe-33a4-42a9-9aba-ac33d69f0f41" 385 | }, 386 | "outputs": [], 387 | "source": [ 388 | "learn.save(\"stage_1\")\n", 389 | "#learn.load(\"stage_1\")" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "id": "cefea428-1bdb-49e5-8645-0d412560bbf2", 396 | "metadata": { 397 | "id": "cefea428-1bdb-49e5-8645-0d412560bbf2" 398 | }, 399 | "outputs": [], 400 | "source": [ 401 | "learn.show_results(max_n=4, figsize=(12,6))" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "id": "b1b58597-a83e-418e-8cad-2c26651368a4", 408 | "metadata": { 409 | "id": "b1b58597-a83e-418e-8cad-2c26651368a4" 410 | }, 411 | "outputs": [], 412 | "source": [ 413 | "learn.unfreeze()\n", 414 | "learn.fit_one_cycle(4, slice(1e-3/400, 1e-3/4))" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": null, 420 | "id": "9240b5d5-079f-43ae-8728-d35285915744", 421 | "metadata": { 422 | "id": "9240b5d5-079f-43ae-8728-d35285915744" 423 | }, 424 | "outputs": [], 425 | "source": [ 426 | "learn.show_results(max_n=4, figsize=(12,6))" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "id": "1910faff-bf16-47a8-83e5-3220f494f4c3", 433 | "metadata": { 434 | "id": "1910faff-bf16-47a8-83e5-3220f494f4c3" 435 | }, 436 | "outputs": [], 437 | "source": [ 438 | "dl = learn.dls.test_dl(\n", 439 | " (path/'images_data_crop').ls()[:5]\n", 440 | ")\n", 441 | "dl.show_batch()" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "id": "f140d661-c31a-480e-82a8-d8dc2286fe44", 448 | "metadata": { 449 | "id": "f140d661-c31a-480e-82a8-d8dc2286fe44" 450 | }, 451 | "outputs": [], 452 | "source": [ 453 | "preds = learn.get_preds(dl=dl)" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "id": "e9f82cad-348b-4598-8371-d3f694088c92", 460 | "metadata": { 461 | "id": "e9f82cad-348b-4598-8371-d3f694088c92" 462 | }, 463 | "outputs": [], 464 | "source": [ 465 | "preds[0].shape" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 45, 471 | "id": "270c4c4c-0783-4af6-9aaf-211136819100", 472 | "metadata": { 473 | "id": "270c4c4c-0783-4af6-9aaf-211136819100" 474 | }, 475 | "outputs": [], 476 | "source": [ 477 | "pred = preds[0][0].argmax(dim=0)" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": null, 483 | "id": "408153fa-7c0a-4913-b927-1b231816a92e", 484 | 
"metadata": { 485 | "id": "408153fa-7c0a-4913-b927-1b231816a92e" 486 | }, 487 | "outputs": [], 488 | "source": [ 489 | "pred.shape" 490 | ] 491 | }, 492 | { 493 | "cell_type": "code", 494 | "execution_count": null, 495 | "id": "f76a3250-4bb3-4b61-b624-d76ff5009882", 496 | "metadata": { 497 | "id": "f76a3250-4bb3-4b61-b624-d76ff5009882" 498 | }, 499 | "outputs": [], 500 | "source": [ 501 | "plt.imshow(pred);" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": 48, 507 | "id": "24c06a87-ce66-49f7-a40f-c8a734a098e7", 508 | "metadata": { 509 | "id": "24c06a87-ce66-49f7-a40f-c8a734a098e7" 510 | }, 511 | "outputs": [], 512 | "source": [ 513 | "pred = pred.numpy()\n", 514 | "rescaled = (255.0 / pred.max() * (pred - pred.min())).astype(np.uint8)\n", 515 | "im = Image.fromarray(rescaled)\n", 516 | "im.save(\"mask.png\")" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "execution_count": null, 522 | "id": "f0a491de-f8ae-42c4-96e2-6bd5bb1766b5", 523 | "metadata": { 524 | "id": "f0a491de-f8ae-42c4-96e2-6bd5bb1766b5" 525 | }, 526 | "outputs": [], 527 | "source": [ 528 | "im" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": 72, 534 | "id": "63f87c78-3a8b-4be2-bdcb-ab1503baca59", 535 | "metadata": { 536 | "id": "63f87c78-3a8b-4be2-bdcb-ab1503baca59" 537 | }, 538 | "outputs": [], 539 | "source": [ 540 | "fnames = (path/'images_data_crop').ls()[:5]\n", 541 | "\n", 542 | "item_tfms = Pipeline([\n", 543 | " PILImage.create, \n", 544 | " RandomResizedCrop(224), \n", 545 | " ToTensor()\n", 546 | "], split_idx=1)\n", 547 | "\n", 548 | "batch_tfms = Pipeline([\n", 549 | " IntToFloatTensor(), \n", 550 | " Normalize.from_stats(*imagenet_stats)\n", 551 | "])\n", 552 | "\n", 553 | "batch = []\n", 554 | "for fname in fnames:\n", 555 | " batch.append(item_tfms(fname))\n", 556 | "batch = torch.stack(batch, dim=0)\n", 557 | "batch = batch_tfms(batch.cuda())\n", 558 | "\n", 559 | "model = learn.model\n", 560 | "model.eval()\n", 561 | "\n", 562 | "with torch.no_grad():\n", 563 | " preds = model(batch)\n", 564 | "\n", 565 | "for i,pred in enumerate(preds):\n", 566 | " pred = pred.argmax(0)\n", 567 | " pred = pred.cpu().numpy()\n", 568 | " rescaled = (255.0 / pred.max() * (pred - pred.min())).astype(np.uint8)\n", 569 | " im = Image.fromarray(rescaled)\n", 570 | " im.save(f'pred_{i}.png')" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": null, 576 | "id": "fc8bcdc1-1ab4-49e7-b67f-b3debddae1d4", 577 | "metadata": { 578 | "id": "fc8bcdc1-1ab4-49e7-b67f-b3debddae1d4" 579 | }, 580 | "outputs": [], 581 | "source": [] 582 | } 583 | ], 584 | "metadata": { 585 | "kernelspec": { 586 | "display_name": "Python 3 (ipykernel)", 587 | "language": "python", 588 | "name": "python3" 589 | }, 590 | "language_info": { 591 | "codemirror_mode": { 592 | "name": "ipython", 593 | "version": 3 594 | }, 595 | "file_extension": ".py", 596 | "mimetype": "text/x-python", 597 | "name": "python", 598 | "nbconvert_exporter": "python", 599 | "pygments_lexer": "ipython3", 600 | "version": "3.9.13" 601 | }, 602 | "colab": { 603 | "provenance": [], 604 | "machine_shape": "hm" 605 | }, 606 | "accelerator": "GPU", 607 | "gpuClass": "premium" 608 | }, 609 | "nbformat": 4, 610 | "nbformat_minor": 5 611 | } -------------------------------------------------------------------------------- /05a_deployment_no_fastai.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 
null, 6 | "id": "396a4e1c-a5b6-4093-aac3-b42f305a613d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import math\n", 11 | "import torch\n", 12 | "import torch.nn.functional as F\n", 13 | "import torchvision.transforms.functional as tvf\n", 14 | "import torchvision.transforms as tvtfms\n", 15 | "import operator as op\n", 16 | "from PIL import Image\n", 17 | "from torch import nn\n", 18 | "from timm import create_model\n", 19 | "\n", 20 | "# For type hinting later on\n", 21 | "import collections\n", 22 | "import typing" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "bf18e089-4375-4aa3-8f62-fac603f458f9", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "net = create_model(\"vit_tiny_patch16_224\", pretrained=False, num_classes=0, in_chans=3)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "e55b7142-69e6-4182-85f6-893a51dee085", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "head = nn.Sequential(\n", 43 | " nn.BatchNorm1d(192),\n", 44 | " nn.Dropout(0.25),\n", 45 | " nn.Linear(192, 512, bias=False),\n", 46 | " nn.ReLU(inplace=True),\n", 47 | " nn.BatchNorm1d(512),\n", 48 | " nn.Dropout(0.5),\n", 49 | " nn.Linear(512, 37, bias=False)\n", 50 | ")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "d6c8c22b-eca7-40d4-97bd-fc28421dad8d", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "model = nn.Sequential(net, head)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "id": "d6868944-1498-422e-994e-6fa15fd2136d", 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "state = torch.load(\"models/MyModel.pth\")" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "a1dbd956-0d28-45ee-aa5d-33eff1a32b09", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "model.load_state_dict(state);" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "ba897ed8-1df7-4240-834c-a39d23ddb11d", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "list(model.state_dict().keys())[:5]" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "25b5f1e0-5605-49ff-8405-d929f102ce7d", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "list(state.keys())[:5]" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "ab5099e3-8379-4946-85a3-aeb672489e08", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "def copy_weight(name, parameter, state_dict):\n", 111 | " \"\"\"\n", 112 | " Takes in a layer `name`, model `parameter`, and `state_dict`\n", 113 | " and loads the weights from `state_dict` into `parameter`\n", 114 | " if it exists.\n", 115 | " \"\"\"\n", 116 | " \n", 117 | " if name[0] == \"0\":\n", 118 | " name = name[:2] + \"model.\" + name[2:]\n", 119 | " if name in state_dict.keys():\n", 120 | " input_parameter = state_dict[name]\n", 121 | " if input_parameter.shape == parameter.shape:\n", 122 | " parameter.copy_(input_parameter)\n", 123 | " else:\n", 124 | " print(f'Shape mismatch at layer: {name}, skipping')\n", 125 | " else:\n", 126 | " print(f'{name} is not in the state_dict, skipping.')" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "23ec54c5-4d27-43ac-aaf0-00eae7b18fa8", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "def 
apply_weights(input_model:nn.Module, input_weights:collections.OrderedDict, application_function:callable):\n", 137 | " \"\"\"\n", 138 | " Takes an input state_dict and applies those weights to the `input_model`, potentially \n", 139 | " with a modifier function.\n", 140 | " \n", 141 | " Args:\n", 142 | " input_model (`nn.Module`):\n", 143 | " The model that weights should be applied to\n", 144 | " input_weights (`collections.OrderedDict`):\n", 145 | " A dictionary of weights, the trained model's `state_dict()`\n", 146 | " application_function (`callable`):\n", 147 | " A function that takes in one parameter and layer name from `input_model`\n", 148 | " and the `input_weights`. Should apply the weights from the state dict into `input_model`.\n", 149 | " \"\"\"\n", 150 | " model_dict = input_model.state_dict()\n", 151 | " for name, parameter in model_dict.items():\n", 152 | " application_function(name, parameter, input_weights)\n", 153 | " input_model.load_state_dict(model_dict)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "5c884993-cb52-412e-9cad-8d1b9dfff70c", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "apply_weights(model, state, copy_weight)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "id": "e5620c43-fb68-42e1-becd-96c74046a031", 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "from fastai.vision.data import PILImage\n", 174 | "from fastai.data.external import untar_data, URLs\n", 175 | "from fastai.data.transforms import get_image_files\n", 176 | "import fastai.vision.augment as fastai_aug\n", 177 | "\n", 178 | "import numpy as np" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "id": "ca0ea83f-8b9b-45c1-baf8-ee43adde1a1b", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "path = untar_data(URLs.PETS)/'images'\n", 189 | "fname = get_image_files(path)[0]\n", 190 | "fname" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "id": "31184db2-06a4-46da-beba-864345c6dc3b", 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "im_pil = Image.open(fname)\n", 201 | "im_fastai = PILImage.create(fname)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "id": "603843e3-5448-47c9-a937-87e0a3bfe886", 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "assert (np.array(im_pil) == np.array(im_fastai)).all()" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "id": "ebe49581-c4e6-4d73-8b74-e7875962f793", 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "crop_fastai = fastai_aug.RandomResizedCrop((460, 460))\n", 222 | "crop_torch = tvtfms.CenterCrop((460,460))" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "522ae350-afb6-41f0-bad3-1ddbf52bb369", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "assert (np.array(crop_fastai(im_fastai, split_idx=1)) == np.array(crop_torch(im_pil))).all()" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "id": "2e2ce70b-d412-47b3-8f8b-d9c1af4c02e1", 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "def crop(image:typing.Union[Image.Image, torch.tensor], size:typing.Tuple[int,int]) -> Image:\n", 243 | " \"\"\"\n", 244 | " Takes a `PIL.Image` and crops it `size` unless one \n", 245 | " dimension is 
larger than the actual image. Padding \n", 246 | " must be performed afterwards if so.\n", 247 | " \n", 248 | " Args:\n", 249 | " image (`PIL.Image`):\n", 250 | " An image to perform cropping on\n", 251 | " size (`tuple` of integers):\n", 252 | " A size to crop to, should be in the form\n", 253 | " of (width, height)\n", 254 | " \n", 255 | " Returns:\n", 256 | " An augmented `PIL.Image`\n", 257 | " \"\"\"\n", 258 | " top = (image.shape[-1] - size[0]) // 2\n", 259 | " left = (image.shape[-2] - size[1]) // 2\n", 260 | " \n", 261 | " top = max(top, 0)\n", 262 | " left = max(left, 0)\n", 263 | " \n", 264 | " height = min(top + size[0], image.shape[-1])\n", 265 | " width = min(left + size[1], image.shape[-2])\n", 266 | " return image.crop((top, left, height, width))" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "id": "e55abb74-5e23-4930-a9f3-b260e6b99aa8", 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "def pad(image:typing.Union[Image.Image, torch.tensor], size:typing.Tuple[int,int]) -> Image:\n", 277 | " \"\"\"\n", 278 | " Takes a `PIL.Image` and pads it to `size` with\n", 279 | " zeros.\n", 280 | " \n", 281 | " Args:\n", 282 | " image (`PIL.Image`):\n", 283 | " An image to perform padding on\n", 284 | " size (`tuple` of integers):\n", 285 | " A size to pad to, should be in the form\n", 286 | " of (width, height)\n", 287 | " \n", 288 | " Returns:\n", 289 | " An augmented `PIL.Image`\n", 290 | " \"\"\"\n", 291 | " top = (image.shape[-1] - size[0]) // 2\n", 292 | " left = (image.shape[-2] - size[1]) // 2\n", 293 | " \n", 294 | " pad_top = max(-top, 0)\n", 295 | " pad_left = max(-left, 0)\n", 296 | " \n", 297 | " height, width = (\n", 298 | " max(size[1] - image.shape[-1] + top, 0), \n", 299 | " max(size[0] - image.shape[-2] + left, 0)\n", 300 | " )\n", 301 | " return tvf.pad(\n", 302 | " image, \n", 303 | " [pad_top, pad_left, height, width], \n", 304 | " padding_mode=\"constant\"\n", 305 | " )" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "id": "b6ccbc19", 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "def resized_crop_pad(\n", 316 | " image: Union[Image.Image, torch.tensor],\n", 317 | " size: Tuple[int, int],\n", 318 | " extra_crop_ratio: float = 0.14,\n", 319 | ") -> Image:\n", 320 | " \"\"\"\n", 321 | " Takes a `PIL.Image`, resize it according to the\n", 322 | " `extra_crop_ratio`, and then crops and pads\n", 323 | " it to `size`.\n", 324 | "\n", 325 | " Args:\n", 326 | " image (`PIL.Image`):\n", 327 | " An image to perform padding on\n", 328 | " size (`tuple` of integers):\n", 329 | " A size to crop and pad to, should be in the form\n", 330 | " of (width, height)\n", 331 | " extra_crop_ratio (float):\n", 332 | " The ratio of size at the edge cropped out. 
Default 0.14\n", 333 | " \"\"\"\n", 334 | "\n", 335 | " maximum_space = max(size[0], size[1])\n", 336 | " extra_space = maximum_space * extra_crop_ratio\n", 337 | " extra_space = math.ceil(extra_space / 8) * 8\n", 338 | " extended_size = (size[0] + extra_space, size[1] + extra_space)\n", 339 | " resized_image = image.resize(extended_size, resample=Image.Resampling.BILINEAR)\n", 340 | "\n", 341 | " if extended_size != size:\n", 342 | " resized_image = pad(crop(resized_image, size), size)\n", 343 | "\n", 344 | " return resized_image" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "id": "35f08caa-723e-4167-8756-ec79b9a7698e", 351 | "metadata": {}, 352 | "outputs": [], 353 | "source": [ 354 | "size = (460,460)\n", 355 | "tfmd_img = resized_crop_pad(im_pil, size)" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "id": "0d384de6-bb93-4296-a664-a5d5a530dc22", 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "(np.array(tfmd_img) == crop_fastai(im_fastai, split_idx=1)).all()" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "id": "32b23700-4b94-4775-b5f9-cee632f39b6a", 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "def gpu_crop(\n", 376 | " batch:torch.tensor, \n", 377 | " size:typing.Tuple[int,int]\n", 378 | "):\n", 379 | " \"\"\"\n", 380 | " Crops each image in `batch` to a particular `size`.\n", 381 | " \n", 382 | " Args:\n", 383 | " batch (array of `torch.Tensor`):\n", 384 | " A batch of images, should be of shape `NxCxWxH`\n", 385 | " size (`tuple` of integers):\n", 386 | " A size to pad to, should be in the form\n", 387 | " of (width, height)\n", 388 | " \n", 389 | " Returns:\n", 390 | " A batch of cropped images\n", 391 | " \"\"\"\n", 392 | " \n", 393 | " affine_matrix = torch.eye(3, device=batch.device).float()\n", 394 | " affine_matrix = affine_matrix.unsqueeze(0)\n", 395 | " affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)\n", 396 | " affine_matrix = affine_matrix.contiguous()[:,:2]\n", 397 | " \n", 398 | " coords = F.affine_grid(\n", 399 | " affine_matrix, batch.shape[:2] + size, align_corners=True\n", 400 | " )\n", 401 | " \n", 402 | " top_range, bottom_range = coords.min(), coords.max()\n", 403 | " zoom = 1/(bottom_range - top_range).item()*2\n", 404 | " \n", 405 | " resizing_limit = min(\n", 406 | " batch.shape[-2]/coords.shape[-2],\n", 407 | " batch.shape[-1]/coords.shape[-1]\n", 408 | " )/2\n", 409 | " \n", 410 | " if resizing_limit > 1 and resizing_limit > zoom:\n", 411 | " batch = F.interpolate(\n", 412 | " batch, \n", 413 | " scale_factor=1/resizing_limit, \n", 414 | " mode='area', \n", 415 | " recompute_scale_factor=True\n", 416 | " )\n", 417 | " return F.grid_sample(batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True)" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": null, 423 | "id": "eb7ca54b-1103-4b90-8995-7dcc9ac56d7d", 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "# fastai augmentations\n", 428 | "tt_fastai = fastai_aug.ToTensor()\n", 429 | "i2f_fastai = fastai_aug.IntToFloatTensor()\n", 430 | "rrc_fastai = fastai_aug.RandomResizedCropGPU((224,224))\n", 431 | "\n", 432 | "# torchvision augmentations\n", 433 | "tt_torch = tvtfms.ToTensor()\n", 434 | "\n", 435 | "# apply fastai augmentations\n", 436 | "base_im_fastai = crop_fastai(im_fastai)\n", 437 | "result_im_fastai = rrc_fastai(\n", 438 | " i2f_fastai(\n", 439 | " 
tt_fastai(base_im_fastai).unsqueeze(0)\n", 440 | " ), split_idx=1\n", 441 | ")\n", 442 | "\n", 443 | "# apply torchvision augmentations\n", 444 | "result_im_tv = gpu_crop(tt_torch(tfmd_img).unsqueeze(0), (224,224))" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": null, 450 | "id": "4cb9ba76-3333-4320-adc9-68cb20c03287", 451 | "metadata": {}, 452 | "outputs": [], 453 | "source": [ 454 | "torch.allclose(result_im_fastai, result_im_tv)" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "id": "a530430a-12fd-4a68-87b2-48518b943e58", 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [ 464 | "norm_torch = tvtfms.Normalize([0.485, 0.456, 0.406], [0.229,0.224,0.225])" 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": null, 470 | "id": "1888d3ca-56d5-4c39-882a-b35fb58e3411", 471 | "metadata": {}, 472 | "outputs": [], 473 | "source": [ 474 | "# fastai augmentations\n", 475 | "norm_fastai = fastai_aug.Normalize.from_stats(*fastai_aug.imagenet_stats, cuda=False)\n", 476 | "# apply fastai augmentations\n", 477 | "base_im_fastai = crop_fastai(im_fastai)\n", 478 | "result_im_fastai = norm_fastai(\n", 479 | " rrc_fastai(\n", 480 | " i2f_fastai(\n", 481 | " tt_fastai(base_im_fastai).unsqueeze(0)\n", 482 | " ), split_idx=1\n", 483 | " )\n", 484 | ")\n", 485 | "\n", 486 | "# apply torchvision augmentations\n", 487 | "result_im_tv = norm_torch(gpu_crop(tt_torch(tfmd_img).unsqueeze(0), (224,224)))" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": null, 493 | "id": "4deffdc8-c0c3-43d2-9732-2d6e8405b772", 494 | "metadata": {}, 495 | "outputs": [], 496 | "source": [ 497 | "torch.allclose(result_im_fastai, result_im_tv)" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "id": "bd97c353-8e08-4ff8-b66d-e2783ec98b4d", 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [ 507 | "import typing\n", 508 | "from PIL import Image\n", 509 | "import torchvision.transforms.functional as tvf" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "id": "f7033592-efab-44f9-9e0d-5640f5d4a05a", 516 | "metadata": {}, 517 | "outputs": [], 518 | "source": [ 519 | "def crop(image:typing.Union[Image.Image, torch.tensor], size:typing.Tuple[int,int]) -> Image:\n", 520 | " \"\"\"\n", 521 | " Takes a `PIL.Image` and crops it `size` unless one \n", 522 | " dimension is larger than the actual image. 
Padding \n", 523 | " must be performed afterwards if so.\n", 524 | " \n", 525 | " Args:\n", 526 | " image (`PIL.Image`):\n", 527 | " An image to perform cropping on\n", 528 | " size (`tuple` of integers):\n", 529 | " A size to crop to, should be in the form\n", 530 | " of (width, height)\n", 531 | " \n", 532 | " Returns:\n", 533 | " An augmented `PIL.Image`\n", 534 | " \"\"\"\n", 535 | " top = (image.shape[-1] - size[0]) // 2\n", 536 | " left = (image.shape[-2] - size[1]) // 2\n", 537 | " \n", 538 | " top = max(top, 0)\n", 539 | " left = max(left, 0)\n", 540 | " \n", 541 | " height = min(top + size[0], image.shape[-1])\n", 542 | " width = min(left + size[1], image.shape[-2])\n", 543 | " return image.crop((top, left, height, width))" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "id": "ff65146c-6caf-4890-9d3c-9cced4620d0d", 550 | "metadata": {}, 551 | "outputs": [], 552 | "source": [ 553 | "def pad(image, size:typing.Tuple[int,int]) -> Image:\n", 554 | " \"\"\"\n", 555 | " Takes a `PIL.Image` and pads it to `size` with\n", 556 | " zeros.\n", 557 | " \n", 558 | " Args:\n", 559 | " image (`PIL.Image`):\n", 560 | " An image to perform padding on\n", 561 | " size (`tuple` of integers):\n", 562 | " A size to pad to, should be in the form\n", 563 | " of (width, height)\n", 564 | " \n", 565 | " Returns:\n", 566 | " An augmented `PIL.Image`\n", 567 | " \"\"\"\n", 568 | " top = (image.shape[-1] - size[0]) // 2\n", 569 | " left = (image.shape[-2] - size[1]) // 2\n", 570 | " \n", 571 | " pad_top = max(-top, 0)\n", 572 | " pad_left = max(-left, 0)\n", 573 | " \n", 574 | " height, width = (\n", 575 | " max(size[1] - image.shape[-1] + top, 0), \n", 576 | " max(size[0] - image.shape[-2] + left, 0)\n", 577 | " )\n", 578 | " return tvf.pad(\n", 579 | " image, \n", 580 | " [pad_top, pad_left, height, width], \n", 581 | " padding_mode=\"constant\"\n", 582 | " )" 583 | ] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "execution_count": null, 588 | "id": "d1406278-0a0e-43e1-b68a-62d8fe54ce1e", 589 | "metadata": {}, 590 | "outputs": [], 591 | "source": [ 592 | "def gpu_crop(\n", 593 | " batch:torch.tensor, \n", 594 | " size:typing.Tuple[int,int]\n", 595 | "):\n", 596 | " \"\"\"\n", 597 | " Crops each image in `batch` to a particular `size`.\n", 598 | " \n", 599 | " Args:\n", 600 | " batch (array of `torch.Tensor`):\n", 601 | " A batch of images, should be of shape `NxCxWxH`\n", 602 | " size (`tuple` of integers):\n", 603 | " A size to pad to, should be in the form\n", 604 | " of (width, height)\n", 605 | " \n", 606 | " Returns:\n", 607 | " A batch of cropped images\n", 608 | " \"\"\"\n", 609 | " \n", 610 | " affine_matrix = torch.eye(3, device=batch.device).float()\n", 611 | " affine_matrix = affine_matrix.unsqueeze(0)\n", 612 | " affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)\n", 613 | " affine_matrix = affine_matrix.contiguous()[:,:2]\n", 614 | " \n", 615 | " coords = F.affine_grid(\n", 616 | " affine_matrix, batch.shape[:2] + size, align_corners=True\n", 617 | " )\n", 618 | " \n", 619 | " top_range, bottom_range = coords.min(), coords.max()\n", 620 | " zoom = 1/(bottom_range - top_range).item()*2\n", 621 | " \n", 622 | " resizing_limit = min(\n", 623 | " batch.shape[-2]/coords.shape[-2],\n", 624 | " batch.shape[-1]/coords.shape[-1]\n", 625 | " )/2\n", 626 | " \n", 627 | " if resizing_limit > 1 and resizing_limit > zoom:\n", 628 | " batch = F.interpolate(\n", 629 | " batch, \n", 630 | " scale_factor=1/resizing_limit, \n", 631 | " mode='area', \n", 632 | 
" recompute_scale_factor=True\n", 633 | " )\n", 634 | " return F.grid_sample(batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True)" 635 | ] 636 | } 637 | ], 638 | "metadata": { 639 | "kernelspec": { 640 | "display_name": "fastai", 641 | "language": "python", 642 | "name": "python3" 643 | }, 644 | "language_info": { 645 | "codemirror_mode": { 646 | "name": "ipython", 647 | "version": 3 648 | }, 649 | "file_extension": ".py", 650 | "mimetype": "text/x-python", 651 | "name": "python", 652 | "nbconvert_exporter": "python", 653 | "pygments_lexer": "ipython3", 654 | "version": "3.10.9" 655 | }, 656 | "vscode": { 657 | "interpreter": { 658 | "hash": "dbeaabf96d056229716848a298cd9413f5c098c5e85ebec7037464305d96e83e" 659 | } 660 | } 661 | }, 662 | "nbformat": 4, 663 | "nbformat_minor": 5 664 | } 665 | --------------------------------------------------------------------------------