├── .gitignore
├── LICENSE
├── README.md
├── images
├── fish_000004249599_07973.png
├── mask_000004249599_07973.png
├── results.png
└── sample_fish.png
├── models
└── unet.pt
├── notebooks
├── inference.ipynb
└── train.ipynb
└── src
├── FishDataset.py
└── model.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch U-Net
2 |
3 |
4 | This repository contains a simple PyTorch implementation of a U-Net for semantic segmentation of fish images, using [this dataset](http://groups.inf.ed.ac.uk/f4k/GROUNDTRUTH/RECOG/) by B. J. Boom, P. X. Huang, J. He and R. B. Fisher [1].
5 |
6 | Here is a sample fish image and its ground truth mask:
7 |
8 |
9 |
10 | The model is very simple and not super accurate, but the results are kinda cute:
11 |
12 |
13 |
14 | The code for the U-Net is partially based on this [Kaggle kernel](https://www.kaggle.com/mlagunas/naive-unet-with-pytorch-tensorboard-logging).
15 |
16 | [[1] B. J. Boom, P. X. Huang, J. He, R. B. Fisher, "Supporting Ground-Truth annotation of image datasets using clustering", 21st Int. Conf. on Pattern Recognition (ICPR), 2012](https://ieeexplore.ieee.org/document/6460437/)
17 |
--------------------------------------------------------------------------------
/images/fish_000004249599_07973.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arturml/pytorch-unet/85d94c5f691fe5f35f1c0ad541d565e63c870947/images/fish_000004249599_07973.png
--------------------------------------------------------------------------------
/images/mask_000004249599_07973.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arturml/pytorch-unet/85d94c5f691fe5f35f1c0ad541d565e63c870947/images/mask_000004249599_07973.png
--------------------------------------------------------------------------------
/images/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arturml/pytorch-unet/85d94c5f691fe5f35f1c0ad541d565e63c870947/images/results.png
--------------------------------------------------------------------------------
/images/sample_fish.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arturml/pytorch-unet/85d94c5f691fe5f35f1c0ad541d565e63c870947/images/sample_fish.png
--------------------------------------------------------------------------------
/models/unet.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arturml/pytorch-unet/85d94c5f691fe5f35f1c0ad541d565e63c870947/models/unet.pt
--------------------------------------------------------------------------------
/notebooks/inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import sys\n",
10 | "sys.path.append('../src')"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 9,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "%matplotlib inline\n",
20 | "import numpy as np\n",
21 | "import torch \n",
22 | "import torch.nn as nn\n",
23 | "import matplotlib.pyplot as plt\n",
24 | "import numpy as np\n",
25 | "from torch.autograd import Variable\n",
26 | "from torch.utils.data.dataset import Dataset\n",
27 | "from torchvision import transforms\n",
28 | "from FishDataset import FishDataset\n",
29 | "from PIL import Image\n",
30 | "from sklearn.model_selection import train_test_split"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 3,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "fish_dataset = FishDataset('../data')"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 4,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "unet = torch.load('../models/unet.pt')"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 5,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "def extract_fish(image, model):\n",
58 | " original_shape = image.size\n",
59 | " image = image.resize((128, 128))\n",
60 | " inputs = Variable(transforms.ToTensor()(image).unsqueeze(0)).cuda()\n",
61 | " outputs = model(inputs).round().squeeze(0).cpu().data\n",
62 | " mask = transforms.ToPILImage()(outputs)\n",
63 | " background = Image.new('RGB', (128, 128), color='white')\n",
64 | " \n",
65 | " return Image.composite(image, background, mask).resize(original_shape)"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 6,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "# use the same random_state to get the same validation set as in training\n",
75 | "_, test_indices = train_test_split(np.arange(len(fish_dataset)), test_size=0.2, random_state=42)"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 7,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "images = [fish_dataset[i][0] for i in test_indices[:10]]"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 8,
90 | "metadata": {},
91 | "outputs": [
92 | {
93 | "data": {
94 | "image/png": "\n",
95 | "text/plain": [
96 | ""
97 | ]
98 | },
99 | "metadata": {},
100 | "output_type": "display_data"
101 | }
102 | ],
103 | "source": [
104 | "fig, axis = plt.subplots(10, 2, figsize=(15,15))\n",
105 | "for image, (ax1, ax2) in zip(images, axis):\n",
106 | " fish = extract_fish(image, unet)\n",
107 | " ax1.imshow(image)\n",
108 | " ax2.imshow(fish)\n",
109 | " ax1.set_xticks([])\n",
110 | " ax1.set_yticks([])\n",
111 | " ax2.set_xticks([])\n",
112 | " ax2.set_yticks([])\n",
113 | "plt.show()"
114 | ]
115 | }
116 | ],
117 | "metadata": {
118 | "kernelspec": {
119 | "display_name": "Environment (conda_pytorch_p36)",
120 | "language": "python",
121 | "name": "conda_pytorch_p36"
122 | },
123 | "language_info": {
124 | "codemirror_mode": {
125 | "name": "ipython",
126 | "version": 3
127 | },
128 | "file_extension": ".py",
129 | "mimetype": "text/x-python",
130 | "name": "python",
131 | "nbconvert_exporter": "python",
132 | "pygments_lexer": "ipython3",
133 | "version": "3.6.4"
134 | }
135 | },
136 | "nbformat": 4,
137 | "nbformat_minor": 2
138 | }
139 |
--------------------------------------------------------------------------------
/notebooks/train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import os\n",
12 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\""
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {
19 | "collapsed": true
20 | },
21 | "outputs": [],
22 | "source": [
23 | "import sys\n",
24 | "sys.path.append('../src')"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 3,
30 | "metadata": {
31 | "collapsed": true
32 | },
33 | "outputs": [],
34 | "source": [
35 | "import numpy as np\n",
36 | "import torch \n",
37 | "import torch.nn as nn\n",
38 | "import os\n",
39 | "from torch.autograd import Variable\n",
40 | "from torch.utils.data.dataset import Dataset\n",
41 | "from torch.utils.data import DataLoader\n",
42 | "from torch.utils.data.sampler import SubsetRandomSampler\n",
43 | "from torchvision import transforms\n",
44 | "from sklearn.model_selection import train_test_split\n",
45 | "from FishDataset import FishDataset"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 4,
51 | "metadata": {
52 | "collapsed": true
53 | },
54 | "outputs": [],
55 | "source": [
56 | "%load_ext autoreload\n",
57 | "%autoreload 2\n",
58 | "from model import UNet"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 5,
64 | "metadata": {
65 | "collapsed": true
66 | },
67 | "outputs": [],
68 | "source": [
69 | "train_transform = transforms.Compose([\n",
70 | " transforms.Resize(size=(128, 128)),\n",
71 | " transforms.RandomHorizontalFlip(),\n",
72 | " transforms.ToTensor()\n",
73 | "])\n",
74 | "\n",
75 | "test_transform = transforms.Compose([\n",
76 | " transforms.Resize(size=(128, 128)),\n",
77 | " transforms.ToTensor()\n",
78 | "])"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 6,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "train_dataset = FishDataset('../data', download=True, transform=train_transform, target_transform=train_transform)"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 7,
93 | "metadata": {
94 | "collapsed": true
95 | },
96 | "outputs": [],
97 | "source": [
98 | "train_indices, test_indices = train_test_split(np.arange(len(train_dataset)), test_size=0.2, random_state=42)"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 8,
104 | "metadata": {
105 | "collapsed": true
106 | },
107 | "outputs": [],
108 | "source": [
109 | "train_loader = DataLoader(\n",
110 | " train_dataset,\n",
111 | " batch_size=32,\n",
112 | " sampler=SubsetRandomSampler(train_indices),\n",
113 | " num_workers=4\n",
114 | ")\n",
115 | "\n",
116 | "val_loader = DataLoader(\n",
117 | " FishDataset('../data', transform=test_transform, target_transform=test_transform),\n",
118 | " batch_size=32,\n",
119 | " sampler=SubsetRandomSampler(train_indices),\n",
120 | " num_workers=4\n",
121 | ")"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 9,
127 | "metadata": {
128 | "collapsed": true
129 | },
130 | "outputs": [],
131 | "source": [
132 | "def jaccard(outputs, targets):\n",
133 | " outputs = outputs.view(outputs.size(0), -1)\n",
134 | " targets = targets.view(targets.size(0), -1)\n",
135 | " intersection = (outputs * targets).sum(1)\n",
136 | " union = (outputs + targets).sum(1) - intersection\n",
137 | " jac = (intersection + 0.001) / (union + 0.001)\n",
138 | " return jac.mean()"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 10,
144 | "metadata": {},
145 | "outputs": [
146 | {
147 | "data": {
148 | "text/plain": [
149 | "UNet(\n",
150 | " (down1): Sequential(\n",
151 | " (0): conv_block(\n",
152 | " (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
153 | " (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)\n",
154 | " (leaky_relu): LeakyReLU(0.01)\n",
155 | " )\n",
156 | " (1): conv_block(\n",
157 | " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
158 | " (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)\n",
159 | " (leaky_relu): LeakyReLU(0.01)\n",
160 | " )\n",
161 | " )\n",
162 | " (down2): Sequential(\n",
163 | " (0): conv_block(\n",
164 | " (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
165 | " (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n",
166 | " (leaky_relu): LeakyReLU(0.01)\n",
167 | " )\n",
168 | " (1): conv_block(\n",
169 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
170 | " (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n",
171 | " (leaky_relu): LeakyReLU(0.01)\n",
172 | " )\n",
173 | " )\n",
174 | " (down3): Sequential(\n",
175 | " (0): conv_block(\n",
176 | " (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
177 | " (batch_norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
178 | " (leaky_relu): LeakyReLU(0.01)\n",
179 | " )\n",
180 | " (1): conv_block(\n",
181 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
182 | " (batch_norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
183 | " (leaky_relu): LeakyReLU(0.01)\n",
184 | " )\n",
185 | " )\n",
186 | " (middle): conv_block(\n",
187 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
188 | " (batch_norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
189 | " (leaky_relu): LeakyReLU(0.01)\n",
190 | " )\n",
191 | " (up3): Sequential(\n",
192 | " (0): conv_block(\n",
193 | " (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
194 | " (batch_norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n",
195 | " (leaky_relu): LeakyReLU(0.01)\n",
196 | " )\n",
197 | " (1): conv_block(\n",
198 | " (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
199 | " (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n",
200 | " (leaky_relu): LeakyReLU(0.01)\n",
201 | " )\n",
202 | " )\n",
203 | " (up2): Sequential(\n",
204 | " (0): conv_block(\n",
205 | " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
206 | " (batch_norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
207 | " (leaky_relu): LeakyReLU(0.01)\n",
208 | " )\n",
209 | " (1): conv_block(\n",
210 | " (conv): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
211 | " (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)\n",
212 | " (leaky_relu): LeakyReLU(0.01)\n",
213 | " )\n",
214 | " )\n",
215 | " (up1): Sequential(\n",
216 | " (0): conv_block(\n",
217 | " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
218 | " (batch_norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)\n",
219 | " (leaky_relu): LeakyReLU(0.01)\n",
220 | " )\n",
221 | " (1): conv_block(\n",
222 | " (conv): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
223 | " (batch_norm): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True)\n",
224 | " (leaky_relu): LeakyReLU(0.01)\n",
225 | " )\n",
226 | " )\n",
227 | ")"
228 | ]
229 | },
230 | "execution_count": 10,
231 | "metadata": {},
232 | "output_type": "execute_result"
233 | }
234 | ],
235 | "source": [
236 | "model = UNet()\n",
237 | "model.cuda()"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": 11,
243 | "metadata": {
244 | "collapsed": true
245 | },
246 | "outputs": [],
247 | "source": [
248 | "criterion = nn.BCELoss()\n",
249 | "optimizer = torch.optim.Adam(model.parameters())"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": 12,
255 | "metadata": {
256 | "collapsed": true
257 | },
258 | "outputs": [],
259 | "source": [
260 | "model_folder = os.path.abspath('../models')\n",
261 | "if not os.path.exists(model_folder):\n",
262 | " os.mkdir(model_folder)\n",
263 | "model_path = os.path.join(model_folder, 'unet.pt')"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 13,
269 | "metadata": {
270 | "scrolled": false
271 | },
272 | "outputs": [
273 | {
274 | "name": "stdout",
275 | "output_type": "stream",
276 | "text": [
277 | "Starting epoch 1/5\n",
278 | " batch 1/685 loss: 0.7597, jaccard 0.1310\n",
279 | " batch 51/685 loss: 0.6546, jaccard 0.6962\n",
280 | " batch 101/685 loss: 0.6477, jaccard 0.7860\n",
281 | " batch 151/685 loss: 0.6474, jaccard 0.7801\n",
282 | " batch 201/685 loss: 0.6405, jaccard 0.8275\n",
283 | " batch 251/685 loss: 0.6364, jaccard 0.8467\n",
284 | " batch 301/685 loss: 0.6396, jaccard 0.8081\n",
285 | " batch 351/685 loss: 0.6308, jaccard 0.8697\n",
286 | " batch 401/685 loss: 0.6329, jaccard 0.8403\n",
287 | " batch 451/685 loss: 0.6293, jaccard 0.8602\n",
288 | " batch 501/685 loss: 0.6279, jaccard 0.8722\n",
289 | " batch 551/685 loss: 0.6297, jaccard 0.8639\n",
290 | " batch 601/685 loss: 0.6294, jaccard 0.8621\n",
291 | " batch 651/685 loss: 0.6205, jaccard 0.8920\n",
292 | "Finished epoch 1, starting evaluation\n",
293 | " loss: 0.6381 jaccard: 0.8154 val_loss: 0.6274 val_jaccard: 0.8766\n",
294 | "\n",
295 | "Starting epoch 2/5\n",
296 | " batch 1/685 loss: 0.6228, jaccard 0.8836\n",
297 | " batch 51/685 loss: 0.6248, jaccard 0.8651\n",
298 | " batch 101/685 loss: 0.6259, jaccard 0.8638\n",
299 | " batch 151/685 loss: 0.6236, jaccard 0.8764\n",
300 | " batch 201/685 loss: 0.6231, jaccard 0.8814\n",
301 | " batch 251/685 loss: 0.6231, jaccard 0.8788\n",
302 | " batch 301/685 loss: 0.6223, jaccard 0.8793\n",
303 | " batch 351/685 loss: 0.6178, jaccard 0.9005\n",
304 | " batch 401/685 loss: 0.6190, jaccard 0.8915\n",
305 | " batch 451/685 loss: 0.6218, jaccard 0.8824\n",
306 | " batch 501/685 loss: 0.6208, jaccard 0.8881\n",
307 | " batch 551/685 loss: 0.6185, jaccard 0.8938\n",
308 | " batch 601/685 loss: 0.6250, jaccard 0.8801\n",
309 | " batch 651/685 loss: 0.6249, jaccard 0.8721\n",
310 | "Finished epoch 2, starting evaluation\n",
311 | " loss: 0.6233 jaccard: 0.8786 val_loss: 0.6207 val_jaccard: 0.8883\n",
312 | "\n",
313 | "Starting epoch 3/5\n",
314 | " batch 1/685 loss: 0.6170, jaccard 0.8874\n",
315 | " batch 51/685 loss: 0.6205, jaccard 0.8924\n",
316 | " batch 101/685 loss: 0.6252, jaccard 0.8714\n",
317 | " batch 151/685 loss: 0.6192, jaccard 0.8835\n",
318 | " batch 201/685 loss: 0.6179, jaccard 0.8774\n",
319 | " batch 251/685 loss: 0.6166, jaccard 0.9028\n",
320 | " batch 301/685 loss: 0.6184, jaccard 0.8931\n",
321 | " batch 351/685 loss: 0.6176, jaccard 0.8855\n",
322 | " batch 401/685 loss: 0.6184, jaccard 0.8910\n",
323 | " batch 451/685 loss: 0.6180, jaccard 0.8913\n",
324 | " batch 501/685 loss: 0.6147, jaccard 0.8963\n",
325 | " batch 551/685 loss: 0.6183, jaccard 0.9003\n",
326 | " batch 601/685 loss: 0.6175, jaccard 0.8936\n",
327 | " batch 651/685 loss: 0.6220, jaccard 0.8853\n",
328 | "Finished epoch 3, starting evaluation\n",
329 | " loss: 0.6194 jaccard: 0.8882 val_loss: 0.6186 val_jaccard: 0.8856\n",
330 | "\n",
331 | "Starting epoch 4/5\n",
332 | " batch 1/685 loss: 0.6154, jaccard 0.9014\n",
333 | " batch 51/685 loss: 0.6153, jaccard 0.9080\n",
334 | " batch 101/685 loss: 0.6146, jaccard 0.9078\n",
335 | " batch 151/685 loss: 0.6188, jaccard 0.8717\n",
336 | " batch 201/685 loss: 0.6168, jaccard 0.8856\n",
337 | " batch 251/685 loss: 0.6151, jaccard 0.8938\n",
338 | " batch 301/685 loss: 0.6171, jaccard 0.8906\n",
339 | " batch 351/685 loss: 0.6155, jaccard 0.8950\n",
340 | " batch 401/685 loss: 0.6106, jaccard 0.9106\n",
341 | " batch 451/685 loss: 0.6208, jaccard 0.8844\n",
342 | " batch 501/685 loss: 0.6158, jaccard 0.8982\n",
343 | " batch 551/685 loss: 0.6194, jaccard 0.8981\n",
344 | " batch 601/685 loss: 0.6161, jaccard 0.8912\n",
345 | " batch 651/685 loss: 0.6158, jaccard 0.8964\n",
346 | "Finished epoch 4, starting evaluation\n",
347 | " loss: 0.6164 jaccard: 0.8924 val_loss: 0.6146 val_jaccard: 0.8952\n",
348 | "\n",
349 | "Starting epoch 5/5\n",
350 | " batch 1/685 loss: 0.6130, jaccard 0.8998\n",
351 | " batch 51/685 loss: 0.6131, jaccard 0.8983\n",
352 | " batch 101/685 loss: 0.6143, jaccard 0.8942\n",
353 | " batch 151/685 loss: 0.6146, jaccard 0.9045\n",
354 | " batch 201/685 loss: 0.6188, jaccard 0.8719\n",
355 | " batch 251/685 loss: 0.6146, jaccard 0.8931\n",
356 | " batch 301/685 loss: 0.6120, jaccard 0.8870\n",
357 | " batch 351/685 loss: 0.6087, jaccard 0.9090\n",
358 | " batch 401/685 loss: 0.6129, jaccard 0.8918\n",
359 | " batch 451/685 loss: 0.6132, jaccard 0.8809\n",
360 | " batch 501/685 loss: 0.6136, jaccard 0.8864\n",
361 | " batch 551/685 loss: 0.6093, jaccard 0.8916\n",
362 | " batch 601/685 loss: 0.6102, jaccard 0.9043\n",
363 | " batch 651/685 loss: 0.6140, jaccard 0.8954\n",
364 | "Finished epoch 5, starting evaluation\n",
365 | " loss: 0.6133 jaccard: 0.8948 val_loss: 0.6111 val_jaccard: 0.8984\n",
366 | "\n"
367 | ]
368 | }
369 | ],
370 | "source": [
371 | "hist = {'loss': [], 'jaccard': [], 'val_loss': [], 'val_jaccard': []}\n",
372 | "num_epochs = 5\n",
373 | "display_steps = 50\n",
374 | "best_jaccard = 0\n",
375 | "for epoch in range(num_epochs):\n",
376 | " print('Starting epoch {}/{}'.format(epoch+1, num_epochs))\n",
377 | " # train\n",
378 | " model.train()\n",
379 | " running_loss = 0.0\n",
380 | " running_jaccard = 0.0\n",
381 | " for batch_idx, (images, masks, _) in enumerate(train_loader):\n",
382 | " images = Variable(images.cuda())\n",
383 | " masks = Variable(masks.cuda())\n",
384 | " \n",
385 | " optimizer.zero_grad()\n",
386 | " outputs = model(images)\n",
387 | " predicted = outputs.round()\n",
388 | " loss = criterion(outputs, masks)\n",
389 | " loss.backward()\n",
390 | " optimizer.step()\n",
391 | " \n",
392 | " jac = jaccard(outputs.round(), masks)\n",
393 | " running_jaccard += jac.data[0]\n",
394 | " running_loss += loss.data[0]\n",
395 | " \n",
396 | " if batch_idx % display_steps == 0:\n",
397 | " print(' ', end='')\n",
398 | " print('batch {:>3}/{:>3} loss: {:.4f}, jaccard {:.4f}\\r'\\\n",
399 | " .format(batch_idx+1, len(train_loader),\n",
400 | " loss.data[0], jac.data[0]))\n",
401 | "\n",
402 | " \n",
403 | " # evaluate\n",
404 | " print('Finished epoch {}, starting evaluation'.format(epoch+1))\n",
405 | " model.eval()\n",
406 | " val_running_loss = 0.0\n",
407 | " val_running_jaccard = 0.0\n",
408 | " for images, masks, _ in val_loader:\n",
409 | " images = Variable(images.cuda())\n",
410 | " masks = Variable(masks.cuda())\n",
411 | " \n",
412 | " outputs = model(images)\n",
413 | " loss = criterion(outputs, masks)\n",
414 | " \n",
415 | " val_running_loss += loss.data[0]\n",
416 | " jac = jaccard(outputs.round(), masks)\n",
417 | " val_running_jaccard += jac.data[0]\n",
418 | "\n",
419 | " train_loss = running_loss / len(train_loader)\n",
420 | " train_jaccard = running_jaccard / len(train_loader)\n",
421 | " val_loss = val_running_loss / len(val_loader)\n",
422 | " val_jaccard = val_running_jaccard / len(val_loader)\n",
423 | " \n",
424 | " hist['loss'].append(train_loss)\n",
425 | " hist['jaccard'].append(train_jaccard)\n",
426 | " hist['val_loss'].append(val_loss)\n",
427 | " hist['val_jaccard'].append(val_jaccard)\n",
428 | " \n",
429 | " if val_jaccard > best_jaccard:\n",
430 | " torch.save(model, model_path)\n",
431 | " print(' ', end='')\n",
432 | " print('loss: {:.4f} jaccard: {:.4f} \\\n",
433 | " val_loss: {:.4f} val_jaccard: {:4.4f}\\n'\\\n",
434 | " .format(train_loss, train_jaccard, val_loss, val_jaccard))"
435 | ]
436 | }
437 | ],
438 | "metadata": {
439 | "kernelspec": {
440 | "display_name": "Python 3",
441 | "language": "python",
442 | "name": "python3"
443 | },
444 | "language_info": {
445 | "codemirror_mode": {
446 | "name": "ipython",
447 | "version": 3
448 | },
449 | "file_extension": ".py",
450 | "mimetype": "text/x-python",
451 | "name": "python",
452 | "nbconvert_exporter": "python",
453 | "pygments_lexer": "ipython3",
454 | "version": "3.6.1"
455 | }
456 | },
457 | "nbformat": 4,
458 | "nbformat_minor": 2
459 | }
460 |
--------------------------------------------------------------------------------
/src/FishDataset.py:
--------------------------------------------------------------------------------
1 |
import os
import random
import re
import tarfile
import urllib
import urllib.request
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
12 |
class FishDataset(Dataset):
    """Fish images paired with their segmentation masks.

    Expects ``root_dir`` to contain ``fish_image/fish_NN/fish_*.png`` and the
    matching ``mask_image/mask_NN/mask_*.png`` files. Each item is an
    ``(image, mask, label)`` tuple where ``label`` is the integer ``NN`` taken
    from the image's class folder name.
    """

    # Class folder directly below fish_image/ in a root-relative path,
    # e.g. "fish_image/fish_07/..." -> "07"; accepts either path separator.
    _LABEL_PATTERN = re.compile(r'fish_image[\\/]fish_(\d+)')

    def __init__(self, root_dir, transform=None, target_transform=None, download=False):
        """
        Args:
            root_dir (string): Data directory containing the fish_image and mask_image folders.
            transform (callable, optional): Optional transform applied to each fish image.
            target_transform (callable, optional): Optional transform applied to each mask;
                its result is then rounded element-wise (e.g. a ``ToTensor`` output in
                [0, 1] becomes a 0/1 mask).
            download (bool, optional): If True, download and extract the dataset into
                ``root_dir`` when the folders are not already present.

        Raises:
            RuntimeError: If the fish_image or mask_image folder is missing.
            ValueError: If an image is not inside a ``fish_image/fish_<digits>`` folder.
        """
        self.root_dir = os.path.abspath(root_dir)
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it.')

        # glob() returns files in arbitrary, filesystem-dependent order. Sorting
        # makes index i refer to the same file on every run, which index-based
        # train/validation splits rely on.
        self.images = sorted(glob(os.path.join(self.root_dir, 'fish_image', '*', '*.png')))
        self.masks = [self._mask_path(image) for image in self.images]
        self.labels = [self._label(image) for image in self.images]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return the ``(image, mask, label)`` tuple for the sample at ``index``."""
        label = self.labels[index]
        image = self._load_image(self.images[index])
        mask = self._load_image(self.masks[index])

        if mask.mode == '1':
            mask = mask.convert('L')

        # Random transforms must make identical choices (e.g. whether to flip)
        # for the image and its mask. Older torchvision draws from Python's
        # `random`, newer releases from torch's generator, so snapshot both
        # generators and rewind them before transforming the mask.
        # https://github.com/pytorch/vision/issues/9
        py_rng_state = random.getstate()
        torch_rng_state = torch.get_rng_state()
        if self.transform is not None:
            image = self.transform(image)

        random.setstate(py_rng_state)
        torch.set_rng_state(torch_rng_state)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
            mask = mask.round()

        return (image, mask, label)

    def download(self):
        """Download and extract the ground-truth archive into ``root_dir``.

        Does nothing if the fish_image and mask_image folders already exist.
        The downloaded .tar file is deleted after a successful extraction.
        """
        if self._check_exists():
            return

        os.makedirs(self.root_dir, exist_ok=True)

        url = 'http://groups.inf.ed.ac.uk/f4k/GROUNDTRUTH/RECOG/Archive/fishRecognition_GT.tar'
        file_path = os.path.join(self.root_dir, 'fishRecognition_GT.tar')
        print('Downloading...', end=' ')
        urllib.request.urlretrieve(url, file_path)
        print('Done!')
        print('Extracting files...', end=' ')
        with tarfile.open(file_path) as tar:
            # The archive arrives over plain HTTP, so its member paths are
            # untrusted. Use the 'data' extraction filter (rejects absolute
            # paths, '..' traversal and out-of-tree links) when this Python
            # provides it.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(self.root_dir, filter='data')
            else:
                tar.extractall(self.root_dir)
        os.remove(file_path)
        print('Done!')

    def _check_exists(self):
        """Return True if both fish_image and mask_image folders exist in root_dir."""
        return os.path.exists(os.path.join(self.root_dir, 'fish_image')) and \
            os.path.exists(os.path.join(self.root_dir, 'mask_image'))

    def _mask_path(self, image_path):
        """Map ``fish_image/fish_NN/fish_X.png`` to ``mask_image/mask_NN/mask_X.png``."""
        # Rewrite only the part below root_dir, so a "fish" that happens to
        # appear in the root directory path itself is left untouched.
        relative = os.path.relpath(image_path, self.root_dir)
        return os.path.join(self.root_dir, relative.replace('fish', 'mask'))

    def _label(self, image_path):
        """Return the integer class id from the ``fish_NN`` folder of ``image_path``."""
        relative = os.path.relpath(image_path, self.root_dir)
        match = self._LABEL_PATTERN.match(relative)
        if match is None:
            raise ValueError('Cannot infer class label from {}'.format(image_path))
        return int(match.group(1))

    @staticmethod
    def _load_image(path):
        """Read the image at ``path`` into memory and close the underlying file."""
        with Image.open(path) as img:
            return img.copy()
85 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.init as init
5 | import numpy as np
6 |
class conv_block(nn.Module):
    """3x3 convolution -> batch norm -> LeakyReLU(0.01).

    The convolution uses stride 1 and padding 1, so the spatial size is
    preserved and only the channel count changes from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        # In-place initializer; the un-suffixed `init.xavier_uniform` is a
        # deprecated alias that emits a warning on every construction.
        init.xavier_uniform_(self.conv.weight, gain=np.sqrt(2))
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.01)

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, out_channels, H, W)."""
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.leaky_relu(x)
        return x
20 |
class UNet(nn.Module):
    """Small three-level U-Net producing a single-channel probability map.

    Input is an (N, 3, H, W) batch; output is (N, 1, H, W) after a final
    sigmoid. Each decoder stage upsamples to the exact spatial size of its
    skip connection, so H and W need not be divisible by 8 (they must be at
    least 8 so the third max-pool still has a non-empty output).
    """

    def __init__(self):
        super().__init__()
        self.down1 = nn.Sequential(
            conv_block(3, 32),
            conv_block(32, 32)
        )
        self.down2 = nn.Sequential(
            conv_block(32, 64),
            conv_block(64, 64)
        )
        self.down3 = nn.Sequential(
            conv_block(64, 128),
            conv_block(128, 128)
        )

        self.middle = conv_block(128, 128)

        # Decoder input channels = skip channels + upsampled channels.
        self.up3 = nn.Sequential(
            conv_block(256, 256),
            conv_block(256, 64)
        )

        self.up2 = nn.Sequential(
            conv_block(128, 128),
            conv_block(128, 32)
        )

        self.up1 = nn.Sequential(
            conv_block(64, 64),
            conv_block(64, 1)
        )

    def forward(self, x):
        # Encoder: keep each stage's pre-pooling features for the skip connections.
        down1 = self.down1(x)
        down2 = self.down2(F.max_pool2d(down1, 2))
        down3 = self.down3(F.max_pool2d(down2, 2))

        out = self.middle(F.max_pool2d(down3, 2))

        # Decoder: upsample to the skip tensor's size (rather than a fixed x2)
        # so sizes floored by max pooling still line up for torch.cat.
        out = self.up3(self._upsample_and_concat(down3, out))
        out = self.up2(self._upsample_and_concat(down2, out))
        out = self.up1(self._upsample_and_concat(down1, out))

        # NOTE: up1 ends with a conv_block, so BatchNorm + LeakyReLU(0.01) run
        # before the sigmoid; negative pre-activations are scaled by 0.01, so
        # outputs near 0 need very large negative values. Kept unchanged so
        # existing checkpoints still load.
        return torch.sigmoid(out)

    @staticmethod
    def _upsample_and_concat(skip, out):
        """Nearest-upsample ``out`` to ``skip``'s H x W and concat [skip, out] on channels."""
        out = F.interpolate(out, size=skip.shape[2:], mode='nearest')
        return torch.cat([skip, out], 1)
81 |
--------------------------------------------------------------------------------