├── .gitignore ├── README.md ├── dataset ├── README.md ├── data_loader.py ├── dataset_README.md ├── dataset_exploration.ipynb ├── dataset_generation.ipynb ├── environment.yml ├── masks.py └── visualization.py ├── images ├── masks.png └── scalars.png └── unet_segmentation ├── data ├── augmentation.py └── dataset.py ├── environment.yml ├── metrics.py ├── predict.py ├── prediction ├── display.py └── post_processing.py ├── trainer.py ├── training ├── stats.py └── training.py └── unet.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | deep_fashion 3 | __pycache__ 4 | .ipynb_checkpoints 5 | **.pt 6 | 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unet segmentation 2 | 3 | This repository contains the resources needed for semantic segmentation 4 | using Unet for Pytorch. 5 | 6 | ## Setting environment up 7 | 8 | Make sure conda >=4.8.3 is installed in your system. Then, type: 9 | 10 | ```shell script 11 | conda env create -f unet_segmentation/environment.yml 12 | conda env activate unet 13 | ``` 14 | 15 | # Training 16 | 17 | First of all, we need to ensure we have data to train on. We can generate 18 | fashion data through the Deep Fashion 2 dataset by following the instructions 19 | in the [dataset directory](dataset). 20 | 21 | Once your dataset is generated and your environment is active, you can type: 22 | 23 | ```shell script 24 | python unet_segmentation/trainer.py 25 | ``` 26 | 27 | This script assumes your dataset is located in `segmentation_dataset` and will 28 | start logging Tensorboard data into `logs`. 29 | 30 | You can visualize this data by typing: 31 | 32 | ```shell script 33 | tensorboard --logdir=logs 34 | ``` 35 | 36 | You will be able to visualize the following: 37 | 38 | Metrics in Tensorboard | Predicted and groundtruth masks in Tensorboard 39 | :-------------------------:|:-------------------------: 40 | ![](images/scalars.png) | ![](images/masks.png) 41 | 42 | ## Pre-trained model 43 | 44 | A pre-trained model (Mean IOU 0.63) is available [here](https://drive.google.com/file/d/1sC_puW3pc6P75KTi2hJxgjgUiZo3zl1Q/view?usp=sharing). 45 | 46 | Given model weights are in `unet.pt`, you can predict mask on example `example.png` by: 47 | 48 | ```python 49 | import torch 50 | from unet_segmentation.prediction.display import display_prediction 51 | 52 | model = torch.load('unet.pt') 53 | display_prediction(model, 'example.png') 54 | ``` 55 | 56 | # Future work 57 | 58 | - [ ] Add test evaluation for generated models. 59 | - [ ] Argument parsing in trainer script (as parameters are hard-coded in the trainer script). 60 | - [ ] Port code to Pytorch Ignite. 61 | - [ ] Generate masks in a GAN-fashion. 62 | - [ ] Explore using trained model for clothing retrieval. 63 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | # Deep Fashion segmentation dataset 2 | 3 | - [dataset_exploration.ipynb](dataset_exploration.ipynb): Notebook to have a 4 | general view of the Deep Fashion 2 dataset. 5 | 6 | - [dataset_generation.ipynb](dataset_generation.ipynb): Notebook to generate 7 | a segmentation dataset ready for training, 8 | 9 | ## Setting environment up 10 | 11 | Make sure conda >=4.8.3 is installed in your system. 
Then, type: 12 | 13 | ```shell script 14 | conda env create -f environment.yml 15 | conda env activate segmentation 16 | ``` 17 | 18 | Moreover, make sure you have downloaded the Deep Fashion 2 dataset from their [Github page](https://github.com/switchablenorms/DeepFashion2). 19 | 20 | ## Exploring dataset 21 | 22 | You can execute the dataset exploration notebook by typing: 23 | 24 | ```shell script 25 | papermill dataset_exploration.ipynb -p dataset_path 26 | ``` 27 | 28 | Where: 29 | - ``: path where you want to store the resulting 30 | notebook, where the corresponding output cells have been filled. 31 | - ``: root directory of the Deep Fashion 2 dataset. 32 | 33 | ## Generating the dataset 34 | 35 | The notebook generates two directories: 36 | 37 | - Deepfashion2 dataset for clothing segmentation. 38 | - Deepfashion2 toy dataset (i.e. 10 instances) for clothing segmentation, 39 | The purpose of this dataset is to check the sanity of our code during 40 | development. If we can easily overfit a tiny subset of rows, it means 41 | the model is learning. 42 | 43 | To generate those, you can type: 44 | 45 | ```shell script 46 | papermill dataset_generation.ipynb data_gen.out.ipynb \ 47 | -p dataset_path deep_fashion \ 48 | -p output_dir deep_fashion_seg_dataset \ 49 | -p toy_output_dir deep_fashion_toy_seg_dataset 50 | ``` 51 | 52 | Where: 53 | - ``: path where you want to store the resulting notebook, 54 | where the corresponding output cells have been filled. 55 | - ``: root directory of the Deep Fashion 2 dataset. 56 | - ``: output directory where to store the segmentation dataset. 57 | - ``: output directory where to store the toy segmentation 58 | dataset. 59 | -------------------------------------------------------------------------------- /dataset/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import itertools 4 | 5 | import pandas as pd 6 | 7 | from typing import List 8 | 9 | 10 | def _load_images_df(images_path: str) -> pd.DataFrame: 11 | """ Returns a DataFrame with the path and id of images in the dataset """ 12 | image_files = [ 13 | os.path.join(images_path, path) 14 | for path in os.listdir(images_path) 15 | ] 16 | image_df = pd.DataFrame(image_files, columns=['image_path']) 17 | # Set image filename as id 18 | image_df['id'] = image_df['image_path'].apply( 19 | lambda x: os.path.splitext(os.path.basename(x))[0] 20 | ).astype(int) 21 | return image_df.set_index('id') 22 | 23 | 24 | def _read_annotations(annotation_path: str) -> List[dict]: 25 | """ Returns the list of annotations in the annotation file """ 26 | annotation_id = os.path.splitext(os.path.basename(annotation_path))[0] 27 | 28 | with open(annotation_path, 'r') as f: 29 | metadata = json.load(f) 30 | 31 | # Read each item in the image 32 | items = [metadata[k] for k in metadata.keys() if 'item' in k] 33 | 34 | # Keep other metadata, just in case 35 | image_metadata = {k: metadata[k] 36 | for k in metadata.keys() if 'item' not in k} 37 | 38 | # Add index in each of the items 39 | for item in items: 40 | item.update({'id': int(annotation_id)}) 41 | item.update(image_metadata) 42 | 43 | return items 44 | 45 | 46 | def _load_annotations_df(annotations_path: str) -> pd.DataFrame: 47 | """ Loads a DataFrame containing the """ 48 | # Build absolute paths 49 | annotation_files = [ 50 | os.path.join(annotations_path, path) 51 | for path in os.listdir(annotations_path) 52 | ] 53 | 54 | # Map each file into a list of annotations 55 | 
annotations = [_read_annotations(annotation_path) 56 | for annotation_path in annotation_files] 57 | # Unpack lists 58 | annotations = list(itertools.chain.from_iterable(annotations)) 59 | 60 | return pd.DataFrame(annotations).set_index('id') 61 | 62 | 63 | def load_training_df(dataset_path: str) -> pd.DataFrame: 64 | """ Load data from training set in Deep Fashion 2 into DataFrame """ 65 | training_dir = os.path.join(dataset_path, 'train') 66 | images_dir = os.path.join(training_dir, 'image') 67 | images_df = _load_images_df(images_dir) 68 | annotations_df = _load_annotations_df(os.path.join(training_dir, 'annos')) 69 | # Merge DataFrames 70 | df = annotations_df.join(images_df, how='left') 71 | # Make sure no image or annotation is missing 72 | assert len(df) == len(annotations_df) 73 | return df 74 | -------------------------------------------------------------------------------- /dataset/dataset_README.md: -------------------------------------------------------------------------------- 1 | # Deep Fashion segmentation dataset 2 | 3 | This dataset contains a clean version of the [Deep Fashion 2 dataset](https://github.com/switchablenorms/DeepFashion2) 4 | for clothing segmentation. 5 | 6 | Dataset contains 191961 images (which can be user generated), each of those 7 | mapped to its corresponding segmentation mask. Masks are uint8 (i.e. interval 8 | [0, 255]) images of the same size as the original image that indicate the 9 | clothes from certain categories that appear in the image. Given a pixel of the 10 | mask, it contains no relevant clothing item if it is 0. Otherwise, pixel is 11 | tagged with the corresponding identifier. 12 | 13 | 6 categories have been defined: 14 | 15 | | Category name | # of items | 16 | |---------------|------------| 17 | | Top | 125789 | 18 | | Shorts | 36616 | 19 | | Dress | 49559 | 20 | | Skirt | 30835 | 21 | | Trousers | 55387 | 22 | | Outwear | 14000 | 23 | 24 | 25 | Rows have been randomly split into 3 sets: training (184821 rows), 26 | test(4800 rows) and validation (2340 rows). 27 | 28 | # Folder structure: 29 | 30 | - `data.json`: Contains the path (relative to this folder) to the original image 31 | (i.e. `image_path`) and the path to the mask image (`mask_path`). It also 32 | contains, per each image, the list of labels it has. 33 | - `images`: Directory where original images are. 34 | - `masks`: Directory where the mask for the images are stored. 35 | - `labels.json`: Contains the mapping between category ids and names. 36 | -------------------------------------------------------------------------------- /dataset/dataset_exploration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Deep Fashion 2 visualization\n", 8 | "\n", 9 | "The main purpose of this notebook is to understand the contents of the Deep Fashion dataset and have answers to questions like:\n", 10 | "\n", 11 | "- How many images per category are there?\n", 12 | "- How many clothing objects per image are there?\n", 13 | "- Is the quality of the image good enough?\n", 14 | "- What does the attributes tell us about the images?" 
15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "from typing import List\n", 24 | "\n", 25 | "import skimage.io\n", 26 | "\n", 27 | "import pandas as pd\n", 28 | "import numpy as np\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "\n", 31 | "from masks import get_mask\n", 32 | "from data_loader import load_training_df\n", 33 | "from visualization import display_instances" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": { 40 | "tags": [ 41 | "parameters" 42 | ] 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "dataset_path = 'dataset'" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Read dataset\n", 54 | "\n", 55 | "Read images and annotations in training." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "df = load_training_df(dataset_path)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "Map ordinal categories (e.g. scale, occlusion and viewpoint) into categorical columns for better understanding (values according to documentation)." 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "df['scale_categorical'] = df['scale'].map(\n", 81 | " {1: 'small_scale', 2: 'modest_scale', 3: 'large_scale'}\n", 82 | ")\n", 83 | "df['zoom_in_categorical'] = df['zoom_in'].map(\n", 84 | " {1: 'no_zoom_in', 2: 'medium_zoom_in', 3: 'large_zoom_in'}\n", 85 | ")\n", 86 | "df['viewpoint_categorical'] = df['viewpoint'].map(\n", 87 | " {1: 'no_wear', 2: 'frontal_viewpoint', 3: 'side_or_back_viewpoint'}\n", 88 | ")\n", 89 | "df['occlusion_categorical'] = df['occlusion'].map(\n", 90 | " {1: 'slight_occlusion', 2: 'medium_occlusion', 3: 'heavy_occlusion'}\n", 91 | ")" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "Let's visualize some examples." 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "df.sample(3)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "## Data visualization" 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "Let's observe random examples for the different features." 
122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "### Display image by category id" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "def _samples_per_category(df: pd.DataFrame,\n", 138 | " column: str,\n", 139 | " n_samples: int = 7) -> pd.DataFrame:\n", 140 | " return df.groupby(column)\\\n", 141 | " .apply(lambda x: x.sample(n_samples))\\\n", 142 | " .reset_index(level=0, drop=True)\\\n", 143 | " .reset_index()\n", 144 | " \n", 145 | "\n", 146 | "column = 'category_name'\n", 147 | "display_instances(_samples_per_category(df, column, n_samples=2),\n", 148 | " title_column=column,\n", 149 | " n_cols=9)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "df[column].value_counts().plot.bar(figsize=(20, 4), rot=45)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "We see some categories are highly close to one another (e.g. sling dress and short sleeve dress)." 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "### Display images by source information" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "column = 'source'\n", 182 | "display_instances(_samples_per_category(df, column, n_samples=8),\n", 183 | " title_column=column,\n", 184 | " n_cols=8)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "df[column].value_counts().plot.bar(figsize=(10, 4), rot=45)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "We see the `shop` images have much higher quality than `user` images. We see that there are ~3 times more shop images than user images." 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "### Display image by viewpoint information" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "column = 'viewpoint_categorical'\n", 217 | "display_instances(_samples_per_category(df, column, n_samples=4),\n", 218 | " title_column=column,\n", 219 | " n_cols=6)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [ 228 | "df[column].value_counts().plot.bar(figsize=(10, 4), rot=45)" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "We see that the viewpoint information can be ambiguous as side viewpoint (which most of the time are mostly frontal vies) and back viewpoint are tagged in the same category. Moreover, we observe that most of the images fall in frontal category." 
236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "metadata": {}, 241 | "source": [ 242 | "### Display image by scale information" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "column = 'scale_categorical'\n", 252 | "display_instances(_samples_per_category(df, column, n_samples=4),\n", 253 | " title_column=column,\n", 254 | " n_cols=6)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "df[column].value_counts().plot.bar(figsize=(20, 4), rot=45)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "Again, we see that scale is not very informative, as `small` and `modest` scale refer to very similar kind of images. However, they seem to properly tag those which are zoomed in pictures." 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "### Display image by zoom-in information" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "column = 'zoom_in_categorical'\n", 287 | "display_instances(_samples_per_category(df, column, n_samples=4),\n", 288 | " title_column=column,\n", 289 | " n_cols=6)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "df[column].value_counts().plot.bar(figsize=(20, 4), rot=45)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": {}, 304 | "source": [ 305 | "We observe that this feature gives little information about the content, as we see that similar images appear in different categories such as `no_zoom_in` and `large_zoom_in`." 306 | ] 307 | }, 308 | { 309 | "cell_type": "markdown", 310 | "metadata": {}, 311 | "source": [ 312 | "### Display image by occlusion information" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "metadata": {}, 319 | "outputs": [], 320 | "source": [ 321 | "column = 'occlusion_categorical'\n", 322 | "display_instances(_samples_per_category(df, column, n_samples=4),\n", 323 | " title_column=column,\n", 324 | " n_cols=6)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "df[column].value_counts().plot.bar(figsize=(20, 4), rot=45)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "metadata": {}, 339 | "source": [ 340 | "Again, it is not clear what is the criteria used to tag image occlusion." 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": {}, 346 | "source": [ 347 | "### Display image pairs" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "pair_ids = list(df.sample(1)['pair_id'].values)\n", 357 | "sample_pairs_df = df[df['pair_id'].isin(pair_ids)].drop_duplicates(['image_path'])\n", 358 | "display_instances(sample_pairs_df, title_column='source', n_cols=9)" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "metadata": {}, 364 | "source": [ 365 | "Shop and user image does not need to be from the same size or color. That can be read in the `style` field (see [documentation](https://github.com/switchablenorms/DeepFashion2))." 
366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "metadata": {}, 371 | "source": [ 372 | "## Clothing elements per image: stats" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": {}, 379 | "outputs": [], 380 | "source": [ 381 | "clothes_per_image = df.groupby('image_path')['category_id'].count()\n", 382 | "mean, std = clothes_per_image.mean(), clothes_per_image.std()\n", 383 | "print(f'Clothes per image: {mean:.2f} +- {std:.2f}')" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": {}, 389 | "source": [ 390 | "## Mask generation" 391 | ] 392 | }, 393 | { 394 | "cell_type": "markdown", 395 | "metadata": {}, 396 | "source": [ 397 | "Compute mask image from examples." 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": null, 403 | "metadata": {}, 404 | "outputs": [], 405 | "source": [ 406 | "def _display_masks(image: np.ndarray, masks: List, ax=None):\n", 407 | " # Display image\n", 408 | " if ax is not None:\n", 409 | " axis = ax\n", 410 | " else:\n", 411 | " plt.figure(figsize=(8, 15))\n", 412 | " axis = plt\n", 413 | " axis.imshow(image)\n", 414 | " \n", 415 | " # Display all masks\n", 416 | " for mask in masks:\n", 417 | " axis.imshow(mask, alpha=0.25, vmin=-1.0, vmax=1.0)\n", 418 | "\n", 419 | " \n", 420 | "def instance_to_mask(row: pd.Series) -> np.ndarray:\n", 421 | " image = skimage.io.imread(row['image_path'])\n", 422 | " image_height, image_width = image.shape[:2]\n", 423 | " return get_mask(image_height,\n", 424 | " image_width,\n", 425 | " polygons=row['segmentation'],\n", 426 | " category_id=int(row['category_id']))\n", 427 | "\n", 428 | "\n", 429 | "def display_instance_mask(row: pd.Series, ax) -> None:\n", 430 | " masks = [instance_to_mask(row)]\n", 431 | " image = skimage.io.imread(row['image_path'])\n", 432 | " _display_masks(image, masks, ax=ax)\n", 433 | " ax.set_title(row[\"category_name\"])\n", 434 | " ax.axis('off')\n", 435 | " \n", 436 | "samples = df.sample(12)\n", 437 | "display_instances(samples, display_fn=display_instance_mask, n_cols=6)" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": {}, 443 | "source": [ 444 | "We see that, in many cases, polygons defining the clothing area are quite sharp and do not properly wrap the clothes margin." 445 | ] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "metadata": {}, 450 | "source": [ 451 | "Let's now visualize examples of images with all masks in it." 
452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": null, 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [ 460 | "def display_all_instance_masks(row: pd.Series, ax) -> None:\n", 461 | " items = df[df.index == row.name]\n", 462 | " masks = items.apply(instance_to_mask, axis=1).values.tolist()\n", 463 | " image = skimage.io.imread(row['image_path'])\n", 464 | " _display_masks(image, masks, ax=ax)\n", 465 | " # Displau call categories\n", 466 | " categories = items[\"category_name\"].values.tolist()\n", 467 | " ax.set_title(f'{categories}', fontsize=8)\n", 468 | " ax.axis('off')\n", 469 | "\n", 470 | "samples = df.sample(12)\n", 471 | "display_instances(samples, display_fn=display_all_instance_masks, n_cols=6)" 472 | ] 473 | } 474 | ], 475 | "metadata": { 476 | "celltoolbar": "Tags", 477 | "kernelspec": { 478 | "display_name": "Python 3", 479 | "language": "python", 480 | "name": "python3" 481 | }, 482 | "language_info": { 483 | "codemirror_mode": { 484 | "name": "ipython", 485 | "version": 3 486 | }, 487 | "file_extension": ".py", 488 | "mimetype": "text/x-python", 489 | "name": "python", 490 | "nbconvert_exporter": "python", 491 | "pygments_lexer": "ipython3", 492 | "version": "3.6.10" 493 | } 494 | }, 495 | "nbformat": 4, 496 | "nbformat_minor": 4 497 | } 498 | -------------------------------------------------------------------------------- /dataset/dataset_generation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Deep Fashion 2 dataset generation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os\n", 17 | "import json\n", 18 | "import itertools\n", 19 | "\n", 20 | "import skimage.io\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "import pandas as pd\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "\n", 26 | "from shutil import copyfile\n", 27 | "\n", 28 | "from sklearn.model_selection import train_test_split\n", 29 | "\n", 30 | "from data_loader import load_training_df\n", 31 | "from masks import get_mask" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": { 38 | "tags": [ 39 | "parameters" 40 | ] 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "dataset_path = 'dataset'\n", 45 | "output_dir = 'deep_fashion_segmentation_dataset'\n", 46 | "toy_output_dir = 'deep_fashion_segmentation_dataset_toy'" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "# Generate dataset" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## Read data\n", 61 | "\n", 62 | "Read images and annotations in training." 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "df = load_training_df(dataset_path)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "Show samples." 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "df.sample(3)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "Drop non-relevant columns. Note that we decide to use **both shop and user images**." 
95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# Drop landmarks as we already have segmentation masks\n", 104 | "df = df.drop(columns=['scale', 'viewpoint', 'zoom_in', 'landmarks', 'bounding_box', 'occlusion', 'style', 'pair_id', 'source'])" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "## Redefine categories\n", 112 | "\n", 113 | "We saw that some categories are too much fine grained and it would be much useful for us (and easier for a model) to deal with a coarser categorization. Let's map each class to a super-category:\n", 114 | "\n", 115 | "- Top: Contains `short sleeve top`, `long sleeve top`, `sling` and `vest`.\n", 116 | "- Outwear: Contains `short sleeve outwear` and `long sleeve outwear`.\n", 117 | "- Skirt (`skirt`).\n", 118 | "- Trousers (`trousers`).\n", 119 | "- Shorts (`shorts`).\n", 120 | "- Dress. Contains `vest dress`, `short sleeve dress` and`sling dress`." 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "Map categories to supercategory." 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "supercategory_map = {\n", 137 | " 'short sleeve top': 'top',\n", 138 | " 'long sleeve dress': 'dress',\n", 139 | " 'trousers': 'trousers',\n", 140 | " 'long sleeve top': 'top',\n", 141 | " 'skirt': 'skirt',\n", 142 | " 'shorts': 'shorts',\n", 143 | " 'long sleeve outwear': 'outwear',\n", 144 | " 'vest dress': 'dress',\n", 145 | " 'short sleeve dress': 'dress',\n", 146 | " 'vest': 'top',\n", 147 | " 'sling dress': 'dress',\n", 148 | " 'short sleeve outwear': 'outwear',\n", 149 | " 'sling': 'top'\n", 150 | "}\n", 151 | "\n", 152 | "# Map category to supercategory\n", 153 | "df.loc[:, 'supercategory_name'] = df['category_name'].map(supercategory_map)\n", 154 | "\n", 155 | "# Create id for supercategories\n", 156 | "# Make sure ids start at 1\n", 157 | "supercategory_id_map = {\n", 158 | " name: i + 1\n", 159 | " for i, name in enumerate(set(supercategory_map.values()))\n", 160 | "}\n", 161 | "df['supercategory_id'] = df['supercategory_name'].map(supercategory_id_map)\n", 162 | "\n", 163 | "# Ensure no extra categories exist\n", 164 | "assert set(supercategory_map.values()) == set(df['supercategory_name'].unique())" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": {}, 170 | "source": [ 171 | "Visualize now the distribution of clothing supercategories." 
172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "df['supercategory_name'].value_counts().plot.bar(figsize=(12, 4))" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "## Group clothes per image" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "def _join_all_rows(df: pd.DataFrame) -> pd.Series:\n", 197 | " result = {col: df[col].values for col in df.columns.values}\n", 198 | " # Make sure we keep the id\n", 199 | " result.update({'id': df.name})\n", 200 | " return result\n", 201 | "\n", 202 | "images_series = df.groupby('id').apply(_join_all_rows)\n", 203 | "images_df = pd.DataFrame.from_records(images_series.values)\n", 204 | "print(f'Using {len(images_df)} images')" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "## Store dataset" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "os.makedirs(output_dir, exist_ok=True)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "Prepare output directories." 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "image_dir = 'images'\n", 237 | "mask_dir = 'masks'\n", 238 | "\n", 239 | "os.makedirs(os.path.join(output_dir, 'images'), exist_ok=True)\n", 240 | "os.makedirs(os.path.join(output_dir, 'masks'), exist_ok=True)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": {}, 246 | "source": [ 247 | "Generate path to images and masks in the Dataframe." 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "def build_mask_png_path(files: str) -> str:\n", 257 | " basename = os.path.basename(files[0])\n", 258 | " filename = os.path.splitext(basename)[0]\n", 259 | " return os.path.join(mask_dir, filename + '.png')\n", 260 | "\n", 261 | "\n", 262 | "images_df['source_path'] = images_df['image_path'].apply(\n", 263 | " lambda files: os.path.join(image_dir, os.path.basename(files[0]))\n", 264 | ")\n", 265 | "\n", 266 | "images_df.loc[:, 'mask_path'] = images_df['image_path'].apply(build_mask_png_path)" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "Store picture and mask images for each row." 
274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "n_labels = len(supercategory_id_map)\n", 283 | "\n", 284 | "def get_mask_image(row: pd.Series) -> np.ndarray:\n", 285 | " \n", 286 | " n_objects = len(row['segmentation'])\n", 287 | " image = skimage.io.imread(row['image_path'][0])\n", 288 | " image_height, image_width = image.shape[:2]\n", 289 | " \n", 290 | " # Build mask matrix\n", 291 | " mask = np.zeros((image_height, image_width), dtype=np.uint8)\n", 292 | " \n", 293 | " # IMPORTANT: order of clothes matter in the resulting mask as\n", 294 | " # masks do overlap\n", 295 | " for i in range(n_objects):\n", 296 | " submask = get_mask(image_height,\n", 297 | " image_width,\n", 298 | " row['segmentation'][i],\n", 299 | " category_id=int(row['supercategory_id'][i]))\n", 300 | " mask = np.where(submask != 0, submask, mask)\n", 301 | "\n", 302 | " return mask\n", 303 | "\n", 304 | "\n", 305 | "def store_example(row: pd.Series, folder: str) -> bool:\n", 306 | " \n", 307 | " try:\n", 308 | " # Copy image\n", 309 | " image_dst_path = os.path.join(folder, row['source_path'])\n", 310 | " copyfile(row['image_path'][0], image_dst_path)\n", 311 | " # Generate and store mask\n", 312 | " mask_dst_path = os.path.join(folder, row['mask_path'])\n", 313 | " skimage.io.imsave(mask_dst_path, get_mask_image(row))\n", 314 | " return True\n", 315 | " except Exception as e:\n", 316 | " print(f'Could not store example: {str(e)}')\n", 317 | " return False\n", 318 | "\n", 319 | "# # Ignore warnings generated when storing mask images\n", 320 | "import warnings\n", 321 | "warnings.filterwarnings('ignore')\n", 322 | "images_df['created_ok'] = images_df.apply(\n", 323 | " lambda x: store_example(x, output_dir), axis=1\n", 324 | ")" 325 | ] 326 | }, 327 | { 328 | "cell_type": "markdown", 329 | "metadata": {}, 330 | "source": [ 331 | "## Generate clean CSV\n", 332 | "\n", 333 | "Now let's only keep the path to the image and the mask as well as the mapping between supercategories and their ids." 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "metadata": {}, 339 | "source": [ 340 | "Let's remove the rows that raised errors." 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "dataset_df = images_df[images_df['created_ok']==True]\n", 350 | "print(f'{len(images_df) - len(dataset_df)} images could not be stored')" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": {}, 356 | "source": [ 357 | "Let's keep only columns of interest." 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "dataset_df = dataset_df[['mask_path', 'source_path', 'supercategory_id']]" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": {}, 372 | "source": [ 373 | "Rename columns accordingly." 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [ 382 | "dataset_df = dataset_df.rename(columns={'source_path': 'image_path',\n", 383 | " 'supercategory_id': 'labels'})" 384 | ] 385 | }, 386 | { 387 | "cell_type": "markdown", 388 | "metadata": {}, 389 | "source": [ 390 | "Perform training-validation-test split." 
391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "def _row_to_set(row: pd.Series) -> str:\n", 400 | " if row.name in train_idx:\n", 401 | " return 'train'\n", 402 | " elif row.name in val_idx:\n", 403 | " return 'validation'\n", 404 | " elif row.name in test_idx:\n", 405 | " return 'test'\n", 406 | " else:\n", 407 | " raise ValueError(f'Unknown row index: {row.name}')\n", 408 | "\n", 409 | "train_idx, test_idx = train_test_split(dataset_df.index, test_size=0.025, random_state=500)\n", 410 | "train_idx, val_idx = train_test_split(train_idx, test_size=0.0125, random_state=300)\n", 411 | "\n", 412 | "dataset_df['set'] = dataset_df.apply(_row_to_set, axis=1)" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "metadata": {}, 418 | "source": [ 419 | "Visualize instances per set." 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "dataset_df['set'].value_counts()" 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "metadata": {}, 434 | "source": [ 435 | "Store DataFrame as json." 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "dataset_df.to_json(os.path.join(output_dir, 'data.json'), orient='records')" 445 | ] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "metadata": {}, 450 | "source": [ 451 | "## Store supercategory id mapping" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": null, 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [ 460 | "inverse_supercategory_id_map = {\n", 461 | " category_id: category_name\n", 462 | " for category_name, category_id in supercategory_id_map.items()\n", 463 | "}\n", 464 | "\n", 465 | "labels_path = os.path.join(output_dir, 'labels.json')\n", 466 | "with open(labels_path, 'w') as file:\n", 467 | " json.dump(inverse_supercategory_id_map, file)" 468 | ] 469 | }, 470 | { 471 | "cell_type": "markdown", 472 | "metadata": {}, 473 | "source": [ 474 | "## Copy README into dataset" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [ 483 | "copyfile('dataset_README.md', os.path.join(output_dir, 'README.md'));" 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "metadata": {}, 489 | "source": [ 490 | "## See an example." 
491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": {}, 497 | "outputs": [], 498 | "source": [ 499 | "n_rows = 5\n", 500 | "fig, axs = plt.subplots(n_rows, 2, figsize=(12, n_rows*4.0))\n", 501 | "\n", 502 | "for i in range(n_rows):\n", 503 | "\n", 504 | " sample = dataset_df.sample(1).iloc[0]\n", 505 | "\n", 506 | " # Parse image and mask\n", 507 | " image = skimage.io.imread(os.path.join(output_dir, sample['image_path']))\n", 508 | " mask = skimage.io.imread(os.path.join(output_dir, sample['mask_path']))\n", 509 | " labels = sample['labels']\n", 510 | " # Parse labels\n", 511 | " str_labels = [inverse_supercategory_id_map[category_id] for category_id in labels]\n", 512 | "\n", 513 | " axs[i][0].imshow(image)\n", 514 | " axs[i][1].set_title(f'Categories: {str_labels}')\n", 515 | " axs[i][1].imshow(mask)" 516 | ] 517 | }, 518 | { 519 | "cell_type": "markdown", 520 | "metadata": {}, 521 | "source": [ 522 | "## Extract some stats" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": null, 528 | "metadata": {}, 529 | "outputs": [], 530 | "source": [ 531 | "def _category_count(df: pd.DataFrame) -> dict:\n", 532 | " label_ids = df['labels'].apply(lambda x: x.tolist()).values\n", 533 | " label_ids = list(itertools.chain.from_iterable(label_ids))\n", 534 | " label_ids_count = np.unique(label_ids, return_counts=True)\n", 535 | " return {\n", 536 | " inverse_supercategory_id_map[label_id]: label_count\n", 537 | " for label_id, label_count in zip(*label_ids_count)\n", 538 | " }\n", 539 | "\n", 540 | "print(f'All instances category count: {_category_count(dataset_df)}')\n", 541 | "print(f'Training set category count: {_category_count(dataset_df[dataset_df[\"set\"] == \"train\"])}')\n", 542 | "print(f'Validation set category count: {_category_count(dataset_df[dataset_df[\"set\"] == \"validation\"])}')\n", 543 | "print(f'Test set category count: {_category_count(dataset_df[dataset_df[\"set\"] == \"test\"])}')" 544 | ] 545 | }, 546 | { 547 | "cell_type": "markdown", 548 | "metadata": {}, 549 | "source": [ 550 | "# Generate toy dataset\n", 551 | "\n", 552 | "Build toy dataset to check we can easily overfit." 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "metadata": {}, 559 | "outputs": [], 560 | "source": [ 561 | "os.makedirs(toy_output_dir, exist_ok=True)\n", 562 | "os.makedirs(os.path.join(toy_output_dir, 'images'))\n", 563 | "os.makedirs(os.path.join(toy_output_dir, 'masks'))" 564 | ] 565 | }, 566 | { 567 | "cell_type": "markdown", 568 | "metadata": {}, 569 | "source": [ 570 | "Use only a subset of the images for the toy dataset." 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": null, 576 | "metadata": {}, 577 | "outputs": [], 578 | "source": [ 579 | "toy_images_df = dataset_df[dataset_df['set'] == 'train'].sample(10)\n", 580 | "toy_images_df.apply(\n", 581 | " lambda x: store_example(x, toy_output_dir), axis=1\n", 582 | ");" 583 | ] 584 | }, 585 | { 586 | "cell_type": "markdown", 587 | "metadata": {}, 588 | "source": [ 589 | "Store toy DataFrame." 590 | ] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": null, 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [ 598 | "toy_images_df.to_json(os.path.join(toy_output_dir, 'data.json'), orient='records')" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "metadata": {}, 604 | "source": [ 605 | "Copy labels as well." 
606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": null, 611 | "metadata": {}, 612 | "outputs": [], 613 | "source": [ 614 | "copyfile(labels_path, os.path.join(toy_output_dir, 'labels.json'));" 615 | ] 616 | } 617 | ], 618 | "metadata": { 619 | "celltoolbar": "Tags", 620 | "kernelspec": { 621 | "display_name": "Python 3", 622 | "language": "python", 623 | "name": "python3" 624 | }, 625 | "language_info": { 626 | "codemirror_mode": { 627 | "name": "ipython", 628 | "version": 3 629 | }, 630 | "file_extension": ".py", 631 | "mimetype": "text/x-python", 632 | "name": "python", 633 | "nbconvert_exporter": "python", 634 | "pygments_lexer": "ipython3", 635 | "version": "3.6.10" 636 | } 637 | }, 638 | "nbformat": 4, 639 | "nbformat_minor": 4 640 | } 641 | -------------------------------------------------------------------------------- /dataset/environment.yml: -------------------------------------------------------------------------------- 1 | name: segmentation 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _pytorch_select=0.2=gpu_0 8 | - ansiwrap=0.8.4=py_0 9 | - appdirs=1.4.3=py_1 10 | - async_generator=1.10=py_0 11 | - attrs=19.3.0=py_0 12 | - backcall=0.1.0=py_0 13 | - black=19.10b0=py36_0 14 | - blas=1.0=mkl 15 | - bleach=3.1.4=pyh9f0ad1d_0 16 | - brotlipy=0.7.0=py36h8c4c3a4_1000 17 | - ca-certificates=2020.4.5.1=hecc5488_0 18 | - certifi=2020.4.5.1=py36h9f0ad1d_0 19 | - cffi=1.14.0=py36hd463f26_0 20 | - chardet=3.0.4=py36h9f0ad1d_1006 21 | - click=7.1.1=pyh8c360ce_0 22 | - cloudpickle=1.3.0=py_0 23 | - cryptography=2.8=py36h45558ae_2 24 | - cudatoolkit=10.0.130=0 25 | - cudnn=7.6.5=cuda10.0_0 26 | - cycler=0.10.0=py_2 27 | - cytoolz=0.10.1=py36h516909a_0 28 | - dask-core=2.14.0=py_0 29 | - dataclasses=0.7=py36_0 30 | - decorator=4.4.2=py_0 31 | - defusedxml=0.6.0=py_0 32 | - entrypoints=0.3=py36h9f0ad1d_1001 33 | - freetype=2.10.1=he06d7ca_0 34 | - icu=58.2=hf484d3e_1000 35 | - idna=2.9=py_1 36 | - imageio=2.8.0=py_0 37 | - importlib-metadata=1.6.0=py36h9f0ad1d_0 38 | - importlib_metadata=1.6.0=0 39 | - intel-openmp=2020.0=166 40 | - ipykernel=5.2.0=py36h95af2a2_1 41 | - ipython=7.13.0=py36h9f0ad1d_2 42 | - ipython_genutils=0.2.0=py_1 43 | - jedi=0.17.0=py36h9f0ad1d_0 44 | - jinja2=2.11.2=pyh9f0ad1d_0 45 | - joblib=0.14.1=py_0 46 | - jpeg=9c=h14c3975_1001 47 | - json5=0.9.0=py_0 48 | - jsonschema=3.2.0=py36h9f0ad1d_1 49 | - jupyter_client=6.1.3=py_0 50 | - jupyter_core=4.6.3=py36h9f0ad1d_1 51 | - jupyterlab=2.1.0=py_0 52 | - jupyterlab_server=1.1.1=py_0 53 | - kiwisolver=1.2.0=py36hdb11119_0 54 | - ld_impl_linux-64=2.33.1=h53a641e_7 55 | - libblas=3.8.0=15_mkl 56 | - libcblas=3.8.0=15_mkl 57 | - libedit=3.1.20181209=hc058e9b_0 58 | - libffi=3.2.1=hd88cf55_4 59 | - libgcc-ng=9.1.0=hdf63c60_0 60 | - libgfortran-ng=7.3.0=hdf63c60_5 61 | - libpng=1.6.37=hed695b0_1 62 | - libsodium=1.0.17=h516909a_0 63 | - libstdcxx-ng=9.1.0=hdf63c60_0 64 | - libtiff=4.1.0=hc7e4089_6 65 | - libwebp-base=1.1.0=h516909a_3 66 | - lz4-c=1.9.2=he1b5a44_0 67 | - markupsafe=1.1.1=py36h8c4c3a4_1 68 | - matplotlib-base=3.1.3=py36hef1b27d_0 69 | - mistune=0.8.4=py36h8c4c3a4_1001 70 | - mkl=2020.0=166 71 | - mkl-service=2.3.0=py36he904b0f_0 72 | - mkl_fft=1.0.15=py36ha843d7b_0 73 | - mkl_random=1.1.1=py36h830a2c2_0 74 | - mypy_extensions=0.4.3=py36h9f0ad1d_1 75 | - nbclient=0.2.0=py_0 76 | - nbconvert=5.6.1=py36h9f0ad1d_1 77 | - nbformat=5.0.6=py_0 78 | - ncurses=6.2=he6710b0_0 79 | - nest-asyncio=1.3.2=py_0 80 | - networkx=2.4=py_1 81 | - 
ninja=1.10.0=hc9558a2_0 82 | - notebook=6.0.3=py36_0 83 | - numpy=1.18.1=py36h4f9e942_0 84 | - numpy-base=1.18.1=py36hde5b4d6_1 85 | - olefile=0.46=py_0 86 | - openssl=1.1.1g=h516909a_0 87 | - pandas=1.0.3=py36h830a2c2_1 88 | - pandoc=2.9.2.1=0 89 | - pandocfilters=1.4.2=py_1 90 | - papermill=2.1.0=py36h9f0ad1d_0 91 | - parso=0.7.0=pyh9f0ad1d_0 92 | - pathspec=0.8.0=pyh9f0ad1d_0 93 | - pexpect=4.8.0=py36h9f0ad1d_1 94 | - pickleshare=0.7.5=py36h9f0ad1d_1001 95 | - pillow=7.0.0=py36hb39fc2d_0 96 | - pip=20.0.2=py36_1 97 | - prometheus_client=0.7.1=py_0 98 | - prompt-toolkit=3.0.5=py_0 99 | - ptyprocess=0.6.0=py_1001 100 | - pycparser=2.20=py_0 101 | - pygments=2.6.1=py_0 102 | - pyopenssl=19.1.0=py_1 103 | - pyparsing=2.4.7=pyh9f0ad1d_0 104 | - pyrsistent=0.16.0=py36h8c4c3a4_0 105 | - pysocks=1.7.1=py36h9f0ad1d_1 106 | - python=3.6.10=hcf32534_1 107 | - python-dateutil=2.8.1=py_0 108 | - python_abi=3.6=1_cp36m 109 | - pytorch=1.3.1=cuda100py36h53c1284_0 110 | - pytz=2019.3=py_0 111 | - pywavelets=1.1.1=py36h785e9b2_1 112 | - pyyaml=5.3.1=py36h8c4c3a4_0 113 | - pyzmq=19.0.0=py36h9947dbf_1 114 | - readline=8.0=h7b6447c_0 115 | - regex=2020.4.4=py36h8c4c3a4_0 116 | - requests=2.23.0=pyh8c360ce_2 117 | - scikit-image=0.16.2=py36hb3f55d8_0 118 | - scikit-learn=0.22.2.post1=py36hcdab131_0 119 | - scipy=1.4.1=py36h0b6359f_0 120 | - send2trash=1.5.0=py_0 121 | - setuptools=46.1.3=py36_0 122 | - six=1.14.0=py_1 123 | - sqlite=3.31.1=h7b6447c_0 124 | - tenacity=6.1.0=py36h9f0ad1d_1 125 | - terminado=0.8.3=py36h9f0ad1d_1 126 | - testpath=0.4.4=py_0 127 | - textwrap3=0.9.2=py_0 128 | - tk=8.6.8=hbc83047_0 129 | - toml=0.10.0=py_0 130 | - toolz=0.10.0=py_0 131 | - torchvision=0.4.2=cuda100py36hecfc37a_0 132 | - tornado=6.0.4=py36h8c4c3a4_1 133 | - tqdm=4.45.0=pyh9f0ad1d_0 134 | - traitlets=4.3.3=py36h9f0ad1d_1 135 | - typed-ast=1.4.1=py36h516909a_0 136 | - typing_extensions=3.7.4.1=py36h9f0ad1d_3 137 | - urllib3=1.25.9=py_0 138 | - wcwidth=0.1.9=pyh9f0ad1d_0 139 | - webencodings=0.5.1=py_1 140 | - wheel=0.34.2=py36_0 141 | - xz=5.2.5=h7b6447c_0 142 | - yaml=0.2.4=h516909a_0 143 | - zeromq=4.3.2=he1b5a44_2 144 | - zipp=3.1.0=py_0 145 | - zlib=1.2.11=h7b6447c_3 146 | - zstd=1.4.4=h6597ccf_3 147 | 148 | -------------------------------------------------------------------------------- /dataset/masks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from typing import List 5 | 6 | import skimage.io 7 | 8 | from PIL import ( 9 | Image, 10 | ImageDraw 11 | ) 12 | 13 | 14 | def get_mask(height: int, 15 | width: int, 16 | polygons: List, 17 | category_id: int) -> np.ndarray: 18 | default_value = 0 19 | # See https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes 20 | mask = Image.new(mode='L', size=(width, height), color=default_value) 21 | # Fill each of the input polygons 22 | for polygon in polygons: 23 | ImageDraw.Draw(mask).polygon(polygon, 24 | outline=category_id, 25 | fill=category_id) 26 | return np.array(mask) 27 | 28 | 29 | -------------------------------------------------------------------------------- /dataset/visualization.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | import skimage.io 5 | 6 | from typing import List 7 | 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | def display_instance(row: pd.Series, 12 | ax, 13 | title_column: str = 'category_name', 14 | fontsize: int = 14) -> None: 15 | 
ax.imshow(skimage.io.imread(row['image_path']), aspect='auto') 16 | ax.set_title(row[title_column], fontsize=fontsize) 17 | ax.axis('off') 18 | 19 | 20 | def display_instances(df: pd.DataFrame, 21 | display_fn=display_instance, 22 | n_cols=7, 23 | **display_args): 24 | n = len(df) 25 | n_rows = int(np.ceil(n/n_cols)) 26 | # Estimate figsize 27 | figsize = (3*n_cols, 3.75*n_rows) 28 | fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize) 29 | for i in range(n): 30 | row = df.iloc[i] 31 | 32 | if n_rows == 1: 33 | ax = axes[i] 34 | else: 35 | i_idx, j_idx = np.unravel_index(i, (n_rows, n_cols)) 36 | ax = axes[i_idx][j_idx] 37 | 38 | display_fn(row, ax, **display_args) 39 | 40 | fig.tight_layout() 41 | -------------------------------------------------------------------------------- /images/masks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielMoraDC/pytorch-unet-segmentation/6b7dda434f994c25abc084044eee4cc0827835ad/images/masks.png -------------------------------------------------------------------------------- /images/scalars.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DanielMoraDC/pytorch-unet-segmentation/6b7dda434f994c25abc084044eee4cc0827835ad/images/scalars.png -------------------------------------------------------------------------------- /unet_segmentation/data/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | 4 | from torchvision import transforms 5 | import torchvision.transforms.functional as TF 6 | 7 | 8 | @dataclass 9 | class DataAugmentation(object): 10 | 11 | central_crop_size: int = 512 12 | h_flipping_chance: float = 0.50 13 | brightness_rate: float = 0.10 14 | contrast_rate: float = 0.10 15 | saturation_rate: float = 0.10 16 | hue_rate: float = 0.05 17 | 18 | def augment(self, image, mask): 19 | # Color and illumination changes 20 | image = transforms.ColorJitter(brightness=self.brightness_rate, 21 | contrast=self.contrast_rate, 22 | saturation=self.saturation_rate, 23 | hue=self.hue_rate)(image) 24 | 25 | # Random crop the image 26 | random_crop = PairRandomCrop(image_size=self.central_crop_size) 27 | image, mask = random_crop(image, mask) 28 | 29 | # Random horizontal flipping 30 | if random.random() > self.h_flipping_chance: 31 | image, mask = TF.hflip(image), TF.hflip(mask) 32 | 33 | return image, mask 34 | 35 | 36 | class PairRandomCrop(object): 37 | 38 | def __init__(self, image_size: int): 39 | self._image_size = image_size 40 | 41 | def __call__(self, image, mask): 42 | # Compute paddings for random crop given dimensions 43 | crop_params = transforms.RandomCrop.get_params( 44 | image, 45 | output_size=(self._image_size, self._image_size) 46 | ) 47 | start_y, start_x, new_height, new_width = crop_params 48 | # Apply crop given computed paddings 49 | image = TF.crop(image, start_y, start_x, new_height, new_width) 50 | mask = TF.crop(mask, start_y, start_x, new_height, new_width) 51 | return image, mask 52 | -------------------------------------------------------------------------------- /unet_segmentation/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import skimage.io 5 | import skimage.transform 6 | 7 | import numpy as np 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | from torchvision import 
transforms 12 | import torchvision.transforms.functional as TF 13 | 14 | from unet_segmentation.data.augmentation import DataAugmentation 15 | 16 | 17 | class SegmentationDataset(Dataset): 18 | 19 | def __init__(self, 20 | path: str, 21 | image_resize: int, 22 | subset: str = 'train', 23 | data_augmentation: DataAugmentation = None): 24 | """ 25 | 26 | Args: 27 | path: Path to the root directory of the dataset. Directory is 28 | expected to contain: 29 | - data.json: dataset descriptor, containing at least 30 | `image_path' (i.e. path to the raw image), 'mask_path' 31 | (i.e. path to mask) and 'set' (i.e. train/validation/test). 32 | - any other required files or directories (i.e. image dirs). 33 | image_resize: Size images will be resized into. 34 | subset: Data to retrieve (i.e. train/validation/test). 35 | data_augmentation: Data augmentation parameters. 36 | """ 37 | 38 | self._image_resize = image_resize 39 | 40 | valid_subsets = ['train', 'validation', 'test'] 41 | if subset not in valid_subsets: 42 | raise ValueError( 43 | f'Subset must be one of {valid_subsets}. Is "{subset}"' 44 | ) 45 | 46 | self._data_augmentation = data_augmentation 47 | 48 | with open(os.path.join(path, 'data.json'), 'r') as file: 49 | data = json.load(file) 50 | 51 | images = [os.path.join(path, row['image_path']) 52 | for row in data if row['set'] == subset] 53 | masks = [os.path.join(path, row['mask_path']) 54 | for row in data if row['set'] == subset] 55 | 56 | # Read only existing image + mask pairs 57 | self._image_mask_pairs = np.array([ 58 | (img, mask) 59 | for (img, mask) in zip(images, masks) 60 | if os.path.isfile(img) and os.path.isfile(mask) 61 | ]) 62 | print(f'Read {len(images)} images, built ' + 63 | f'{len(self._image_mask_pairs)} image pairs') 64 | 65 | def _transform(self, image, mask): 66 | image = TF.to_pil_image(_resize_image(image, self._image_resize)) 67 | mask = TF.to_pil_image(_resize_mask(mask, self._image_resize)) 68 | 69 | if self._data_augmentation is not None: 70 | image, mask = self._data_augmentation.augment(image, mask) 71 | 72 | # Image normalization 73 | image = transforms.ToTensor()(image) 74 | image = transforms.Normalize(mean=(0,), std=(1,))(image) 75 | return image, np.array(mask) 76 | 77 | def __len__(self): 78 | return len(self._image_mask_pairs) 79 | 80 | def __getitem__(self, idx): 81 | 82 | if torch.is_tensor(idx): 83 | idx = idx.tolist() 84 | 85 | image_path, mask_path = self._image_mask_pairs[idx] 86 | image = skimage.io.imread(image_path) 87 | mask = skimage.io.imread(mask_path) 88 | return self._transform(image, mask) 89 | 90 | 91 | def _resize_image(img: np.ndarray, new_size: int) -> np.ndarray: 92 | return skimage.transform.resize(img, 93 | (new_size, new_size), 94 | preserve_range=True, 95 | order=0).astype('uint8') 96 | 97 | 98 | def _resize_mask(mask: np.ndarray, new_size: int) -> np.ndarray: 99 | return skimage.transform.resize(mask, 100 | (new_size, new_size), 101 | order=0, # mode=nearest 102 | anti_aliasing=False, 103 | preserve_range=True).astype('uint8') 104 | 105 | 106 | import random 107 | import matplotlib.pyplot as plt 108 | 109 | if __name__ == '__main__': 110 | 111 | # Display some data 112 | 113 | data = SegmentationDataset('segmentation_dataset', 114 | image_resize=580, 115 | data_augmentation=DataAugmentation()) 116 | 117 | n_rows = 5 118 | idxs = [random.randint(0, len(data)) for _ in range(n_rows)] 119 | fig, axs = plt.subplots(n_rows, 2, figsize=(7.0, n_rows*3.0)) 120 | 121 | for i, idx in enumerate(idxs): 122 | image, mask = data[idx] 
123 | axs[i][0].imshow(np.transpose(image, (1, 2, 0))) 124 | axs[i][0].set_title('Sample image #{}'.format(i)) 125 | axs[i][1].imshow(mask) 126 | labels = np.unique(mask) 127 | axs[i][1].set_title('Sample mask #{}. Labels: {}'.format(i, labels)) 128 | 129 | plt.tight_layout() 130 | plt.show() 131 | -------------------------------------------------------------------------------- /unet_segmentation/environment.yml: -------------------------------------------------------------------------------- 1 | name: unet 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _pytorch_select=0.2=gpu_0 8 | - _tflow_select=2.3.0=mkl 9 | - absl-py=0.9.0=py37hc8dfbb8_1 10 | - astor=0.7.1=py_0 11 | - attrs=19.3.0=py_0 12 | - backcall=0.1.0=py_0 13 | - blas=1.0=mkl 14 | - bleach=3.1.4=pyh9f0ad1d_0 15 | - blinker=1.4=py_1 16 | - brotlipy=0.7.0=py37h8f50634_1000 17 | - c-ares=1.15.0=h516909a_1001 18 | - ca-certificates=2020.4.5.1=hecc5488_0 19 | - cachetools=3.1.1=py_0 20 | - certifi=2020.4.5.1=py37hc8dfbb8_0 21 | - cffi=1.14.0=py37hd463f26_0 22 | - chardet=3.0.4=py37hc8dfbb8_1006 23 | - click=7.1.1=pyh8c360ce_0 24 | - cloudpickle=1.3.0=py_0 25 | - cryptography=2.8=py37hb09aad4_2 26 | - cudatoolkit=10.0.130=0 27 | - cudnn=7.6.5=cuda10.0_0 28 | - cycler=0.10.0=py_2 29 | - cytoolz=0.10.1=py37h516909a_0 30 | - dask-core=2.15.0=py_0 31 | - decorator=4.4.2=py_0 32 | - defusedxml=0.6.0=py_0 33 | - entrypoints=0.3=py37hc8dfbb8_1001 34 | - freetype=2.10.1=he06d7ca_0 35 | - gast=0.2.2=py_0 36 | - google-auth=1.14.1=pyh9f0ad1d_0 37 | - google-auth-oauthlib=0.4.1=py_2 38 | - google-pasta=0.2.0=pyh8c360ce_0 39 | - grpcio=1.27.2=py37hf8bcb03_0 40 | - h5py=2.10.0=nompi_py37h513d04c_102 41 | - hdf5=1.10.5=nompi_h3c11f04_1104 42 | - icu=58.2=hf484d3e_1000 43 | - idna=2.9=py_1 44 | - imageio=2.8.0=py_0 45 | - importlib-metadata=1.6.0=py37hc8dfbb8_0 46 | - importlib_metadata=1.6.0=0 47 | - intel-openmp=2020.0=166 48 | - ipykernel=5.2.1=py37h43977f1_0 49 | - ipython=7.13.0=py37hc8dfbb8_2 50 | - ipython_genutils=0.2.0=py_1 51 | - jedi=0.17.0=py37hc8dfbb8_0 52 | - jinja2=2.11.2=pyh9f0ad1d_0 53 | - jpeg=9c=h14c3975_1001 54 | - json5=0.9.0=py_0 55 | - jsonschema=3.2.0=py37hc8dfbb8_1 56 | - jupyter_client=6.1.3=py_0 57 | - jupyter_core=4.6.3=py37hc8dfbb8_1 58 | - jupyterlab=2.1.1=py_0 59 | - jupyterlab_server=1.1.1=py_0 60 | - keras-applications=1.0.8=py_1 61 | - keras-preprocessing=1.1.0=py_0 62 | - kiwisolver=1.2.0=py37h99015e2_0 63 | - ld_impl_linux-64=2.33.1=h53a641e_7 64 | - libedit=3.1.20181209=hc058e9b_0 65 | - libffi=3.2.1=hd88cf55_4 66 | - libgcc-ng=9.1.0=hdf63c60_0 67 | - libgfortran-ng=7.3.0=hdf63c60_5 68 | - libpng=1.6.37=hed695b0_1 69 | - libprotobuf=3.11.4=h8b12597_0 70 | - libsodium=1.0.17=h516909a_0 71 | - libstdcxx-ng=9.1.0=hdf63c60_0 72 | - libtiff=4.1.0=hc7e4089_6 73 | - libwebp-base=1.1.0=h516909a_3 74 | - lz4-c=1.9.2=he1b5a44_0 75 | - markdown=3.2.1=py_0 76 | - markupsafe=1.1.1=py37h8f50634_1 77 | - matplotlib-base=3.1.3=py37hef1b27d_0 78 | - mistune=0.8.4=py37h8f50634_1001 79 | - mkl=2020.0=166 80 | - mkl-service=2.3.0=py37he904b0f_0 81 | - mkl_fft=1.0.15=py37ha843d7b_0 82 | - mkl_random=1.1.1=py37h0da4684_0 83 | - nbconvert=5.6.1=py37hc8dfbb8_1 84 | - nbformat=5.0.6=py_0 85 | - ncurses=6.2=he6710b0_0 86 | - networkx=2.4=py_1 87 | - ninja=1.10.0=hc9558a2_0 88 | - notebook=6.0.3=py37_0 89 | - numpy=1.18.1=py37h4f9e942_0 90 | - numpy-base=1.18.1=py37hde5b4d6_1 91 | - oauthlib=3.0.1=py_0 92 | - olefile=0.46=py_0 93 | - openssl=1.1.1g=h516909a_0 94 | - opt_einsum=3.2.1=py_0 
95 | - pandas=1.0.3=py37h0da4684_1
96 | - pandoc=2.9.2.1=0
97 | - pandocfilters=1.4.2=py_1
98 | - parso=0.7.0=pyh9f0ad1d_0
99 | - pexpect=4.8.0=py37hc8dfbb8_1
100 | - pickleshare=0.7.5=py37hc8dfbb8_1001
101 | - pillow=6.1.0=py37h34e0f95_0
102 | - pip=20.0.2=py37_1
103 | - prometheus_client=0.7.1=py_0
104 | - prompt-toolkit=3.0.5=py_0
105 | - protobuf=3.11.4=py37h3340039_1
106 | - ptyprocess=0.6.0=py_1001
107 | - pyasn1=0.4.8=py_0
108 | - pyasn1-modules=0.2.7=py_0
109 | - pycparser=2.20=py_0
110 | - pygments=2.6.1=py_0
111 | - pyjwt=1.7.1=py_0
112 | - pyopenssl=19.1.0=py_1
113 | - pyparsing=2.4.7=pyh9f0ad1d_0
114 | - pyrsistent=0.16.0=py37h8f50634_0
115 | - pysocks=1.7.1=py37hc8dfbb8_1
116 | - python=3.7.7=hcf32534_0_cpython
117 | - python-dateutil=2.8.1=py_0
118 | - python_abi=3.7=1_cp37m
119 | - pytorch=1.3.1=cuda100py37h53c1284_0
120 | - pytz=2019.3=py_0
121 | - pywavelets=1.1.1=py37h03ebfcd_1
122 | - pyzmq=19.0.0=py37hac76be4_1
123 | - readline=8.0=h7b6447c_0
124 | - requests=2.23.0=pyh8c360ce_2
125 | - requests-oauthlib=1.2.0=py_0
126 | - rsa=4.0=py_0
127 | - scikit-image=0.16.2=py37hb3f55d8_0
128 | - scipy=1.4.1=py37h0b6359f_0
129 | - send2trash=1.5.0=py_0
130 | - setuptools=46.1.3=py37_0
131 | - six=1.14.0=py_1
132 | - sqlite=3.31.1=h62c20be_1
133 | - tensorboard=2.1.1=py_1
134 | - tensorflow=2.1.0=mkl_py37h80a91df_0
135 | - tensorflow-base=2.1.0=mkl_py37h6d63fb7_0
136 | - tensorflow-estimator=2.1.0=pyhd54b08b_0
137 | - termcolor=1.1.0=py_2
138 | - terminado=0.8.3=py37hc8dfbb8_1
139 | - testpath=0.4.4=py_0
140 | - tk=8.6.8=hbc83047_0
141 | - toolz=0.10.0=py_0
142 | - torchvision=0.4.2=cuda100py37hecfc37a_0
143 | - tornado=6.0.4=py37h8f50634_1
144 | - tqdm=4.45.0=pyh9f0ad1d_0
145 | - traitlets=4.3.3=py37hc8dfbb8_1
146 | - urllib3=1.25.9=py_0
147 | - wcwidth=0.1.9=pyh9f0ad1d_0
148 | - webencodings=0.5.1=py_1
149 | - werkzeug=1.0.1=pyh9f0ad1d_0
150 | - wheel=0.34.2=py37_0
151 | - wrapt=1.12.1=py37h8f50634_1
152 | - xz=5.2.5=h7b6447c_0
153 | - zeromq=4.3.2=he1b5a44_2
154 | - zipp=3.1.0=py_0
155 | - zlib=1.2.11=h7b6447c_3
156 | - zstd=1.4.4=h6597ccf_3
157 |
--------------------------------------------------------------------------------
/unet_segmentation/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | 
5 | def mean_iou(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
6 |     """ Computes the mean Intersection over Union (IoU) across all classes """
7 |     return np.mean(iou(y_true, y_pred))
8 | 
9 | 
10 | def iou(y_true: torch.Tensor, y_pred: torch.Tensor) -> list:
11 |     """ Computes the Intersection over Union (IoU) for each non-background class """
12 |     true_classes = y_true.unique().detach()
13 |     predicted_image = torch.argmax(y_pred, dim=1)
14 | 
15 |     def _class_iou(c):
16 |         predicted_class_mask = torch.eq(predicted_image, c).bool()
17 |         true_class_mask = torch.eq(y_true, c).bool()
18 |         intersection = predicted_class_mask & true_class_mask
19 |         union = predicted_class_mask | true_class_mask
20 |         return (intersection.float().sum() / union.float().sum()).mean()
21 | 
22 |     # Iterate over true classes only, so the union is never empty (no zero division)
23 |     return [_class_iou(c).item() for c in true_classes if c != 0]
--------------------------------------------------------------------------------
/unet_segmentation/predict.py:
--------------------------------------------------------------------------------
1 | """
2 | Predict masks on image samples using the CPU
3 | """
4 | 
5 | import torch
6 | 
7 | import numpy as np
8 | import 
matplotlib.pyplot as plt 9 | 10 | from unet_segmentation.data.dataset import SegmentationDataset 11 | 12 | # --------------- 13 | # Load model 14 | # --------------- 15 | 16 | unet = torch.load('unet.pt').cpu() 17 | 18 | # --------------- 19 | # Dataset 20 | # --------------- 21 | 22 | dataset = SegmentationDataset('seg_dataset', image_resize=256) 23 | 24 | # --------------- 25 | # Show predictions 26 | # --------------- 27 | 28 | n_rows = 2 29 | fig, axs = plt.subplots(n_rows, 3, figsize=(7.0, n_rows * 3.0)) 30 | 31 | for i in range(n_rows): 32 | 33 | image, mask = dataset[i] 34 | img_size = image.shape[-1] 35 | predicted_mask = unet(torch.Tensor(image).unsqueeze_(0)) 36 | predicted_mask = predicted_mask.detach().numpy() 37 | predicted_image = np.argmax(predicted_mask, axis=1)[0].astype('uint8') 38 | 39 | axs[i][0].imshow(np.transpose(image, (1, 2, 0))) 40 | axs[i][0].set_title('Sample image #{}'.format(i)) 41 | axs[i][1].imshow(mask) 42 | axs[i][1].set_title('Sample mask #{}'.format(i)) 43 | axs[i][2].imshow(predicted_image) 44 | axs[i][2].set_title('Predicted mask #{}'.format(i)) 45 | 46 | plt.tight_layout() 47 | plt.show() 48 | -------------------------------------------------------------------------------- /unet_segmentation/prediction/display.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | import torch 4 | from torchvision import transforms 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | from unet_segmentation.unet import Unet 10 | from unet_segmentation.prediction.post_processing import ( 11 | prediction_to_classes, 12 | mask_from_prediction, 13 | remove_predictions 14 | ) 15 | 16 | DEEP_FASHION2_CUSTOM_CLASS_MAP = { 17 | 1: "trousers", 18 | 2: "skirt", 19 | 3: "top", 20 | 4: "dress", 21 | 5: "outwear", 22 | 6: "shorts" 23 | } 24 | 25 | 26 | def display_prediction( 27 | unet: Unet, 28 | image_path: str, 29 | image_resize: int = 512, 30 | label_map: dict = DEEP_FASHION2_CUSTOM_CLASS_MAP, 31 | device: str = 'cuda', 32 | min_area_rate: float = 0.05) -> None: 33 | 34 | # Load image tensor 35 | img = Image.open(image_path) 36 | img_tensor = _preprocess_image(img, image_size=image_resize).to(device) 37 | 38 | # Predict classes from model 39 | prediction_map = \ 40 | torch.argmax(unet(img_tensor), dim=1).squeeze(0).cpu().numpy() 41 | 42 | # Remove spurious classes 43 | classes = prediction_to_classes(prediction_map, label_map) 44 | predicted_classes = list( 45 | filter(lambda x: x.area_ratio >= min_area_rate, classes)) 46 | spurious_classes = list( 47 | filter(lambda x: x.area_ratio < min_area_rate, classes)) 48 | clean_prediction_map = remove_predictions(spurious_classes, prediction_map) 49 | 50 | # Get masks for each of the predictions 51 | masks = [ 52 | mask_from_prediction(predicted_class, clean_prediction_map) 53 | for predicted_class in predicted_classes 54 | ] 55 | 56 | # Display predictions on top of original image 57 | plt.imshow(np.array(img)) 58 | for mask in masks: 59 | plt.imshow(mask.resize(img).binary_mask, cmap='jet', alpha=0.65) 60 | plt.show() 61 | 62 | 63 | def _preprocess_image(image: Image, image_size: int) -> torch.Tensor: 64 | preprocess_pipeline = transforms.Compose([ 65 | transforms.Resize((image_size, image_size)), 66 | transforms.ToTensor(), 67 | transforms.Normalize(mean=(0,), std=(1,)) 68 | ]) 69 | return preprocess_pipeline(image).unsqueeze(0) 70 | -------------------------------------------------------------------------------- 
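For quick experiments outside the display helper above, a raw class-id map can be obtained with the same resize/normalize pipeline. This is only a sketch: the checkpoint (`unet.pt`) and photo (`street_photo.jpg`) file names are assumptions, and the resulting array is the kind of prediction map the post-processing utilities in the next file operate on.

```python
import torch
from PIL import Image
from torchvision import transforms

# Assumed local files; replace with your own checkpoint and photo
model = torch.load('unet.pt', map_location='cpu').eval()

# Same preprocessing as display.py: square resize, tensor conversion, identity normalization
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0,), std=(1,)),
])

img = Image.open('street_photo.jpg').convert('RGB')
with torch.no_grad():
    logits = model(preprocess(img).unsqueeze(0))  # shape (1, n_classes, 512, 512)

# (512, 512) array of class ids, where 0 is background
prediction_map = torch.argmax(logits, dim=1).squeeze(0).numpy()
```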
/unet_segmentation/prediction/post_processing.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | from PIL import Image 5 | 6 | import torch 7 | import numpy as np 8 | import skimage.transform 9 | import skimage.morphology 10 | from scipy.signal import medfilt 11 | from skimage.measure import regionprops 12 | 13 | 14 | @dataclass 15 | class PredictedClass(object): 16 | area_ratio: float 17 | class_name: str 18 | class_id: int 19 | 20 | 21 | @dataclass 22 | class PredictedMask(PredictedClass): 23 | 24 | binary_mask: np.ndarray 25 | 26 | @staticmethod 27 | def from_prediction(binary_mask: np.ndarray, 28 | prediction: PredictedClass): 29 | return PredictedMask( 30 | binary_mask=binary_mask, 31 | area_ratio=prediction.area_ratio, 32 | class_name=prediction.class_name, 33 | class_id=prediction.class_id 34 | ) 35 | 36 | def resize(self, img: Image): 37 | width, height = img.size 38 | new_binary_mask = skimage.transform.resize( 39 | self.binary_mask, 40 | (height, width), 41 | anti_aliasing=False, 42 | order=0, # mode=nearest 43 | preserve_range=True).astype('uint8') 44 | 45 | return PredictedMask( 46 | binary_mask=new_binary_mask, 47 | area_ratio=self.area_ratio, 48 | class_name=self.class_name, 49 | class_id=self.class_id 50 | ) 51 | 52 | 53 | def prediction_to_classes(prediction_img: np.ndarray, 54 | class_id_to_name: dict) -> List[PredictedClass]: 55 | predicted_classes = [ 56 | class_id 57 | for class_id in np.unique(prediction_img) 58 | if class_id != 0 59 | ] 60 | img_area = prediction_img.shape[0] * prediction_img.shape[1] 61 | 62 | return [ 63 | PredictedClass( 64 | area_ratio=(prediction_img == class_id).sum() / img_area, 65 | class_name=class_id_to_name[class_id], 66 | class_id=class_id 67 | ) 68 | for class_id in predicted_classes 69 | ] 70 | 71 | 72 | def remove_predictions(predictions: List[PredictedClass], 73 | img: np.ndarray) -> np.ndarray: 74 | result = img.copy() 75 | for prediction in predictions: 76 | result[result == prediction.class_id] = 0 77 | return result 78 | 79 | 80 | def mask_from_prediction(prediction: PredictedClass, 81 | img: np.ndarray, 82 | smoothing_kernel_size: int = 7) -> PredictedMask: 83 | # Create binary mask for detection 84 | binary_img = np.zeros((img.shape)) 85 | class_idxs = np.where(img == prediction.class_id) 86 | binary_img[class_idxs] = 1 87 | 88 | # Closing to remove spurious pixels 89 | closed_img = skimage.morphology.closing(binary_img) 90 | # Smooth edges with median filtering 91 | smoothed_img = medfilt(closed_img, kernel_size=smoothing_kernel_size) 92 | return PredictedMask.from_prediction( 93 | binary_mask=smoothed_img, 94 | prediction=prediction 95 | ) 96 | -------------------------------------------------------------------------------- /unet_segmentation/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader 4 | 5 | from unet_segmentation.data.dataset import SegmentationDataset 6 | from unet_segmentation.data.augmentation import DataAugmentation 7 | 8 | from unet_segmentation.training.training import ( 9 | TrainingParams, 10 | fit 11 | ) 12 | 13 | 14 | # -------------------- 15 | # Training dataset 16 | # -------------------- 17 | 18 | batch_size = 1 19 | data_load_workers = 4 20 | image_crop = 512 21 | image_resize = 575 22 | 23 | data_aug = DataAugmentation( 24 | central_crop_size=image_crop, 25 | 
h_flipping_chance=0.50, 26 | brightness_rate=0.10, 27 | contrast_rate=0.10, 28 | saturation_rate=0.10, 29 | hue_rate=0.05) 30 | 31 | train_dataset = SegmentationDataset('full_dataset', 32 | image_resize=image_resize, 33 | data_augmentation=data_aug) 34 | 35 | train_loader = DataLoader(train_dataset, 36 | batch_size=batch_size, 37 | shuffle=True, 38 | num_workers=data_load_workers, 39 | pin_memory=True) 40 | 41 | # -------------------- 42 | # Validation dataset 43 | # -------------------- 44 | 45 | val_dataset = SegmentationDataset('full_dataset', 46 | subset='validation', 47 | image_resize=image_crop) 48 | 49 | val_loader = DataLoader(val_dataset, 50 | batch_size=1, 51 | shuffle=True, 52 | num_workers=data_load_workers, 53 | pin_memory=True) 54 | 55 | # -------------------- 56 | # Training 57 | # -------------------- 58 | 59 | 60 | # Hard-code presence for each class 61 | class_rate = np.array( 62 | # High value for background 63 | [0.50] + 64 | # Top, Shorts, Dress, Skirt, Trousers, Outwear 65 | [0.4039, 0.1173, 0.1587, 0.0988, 0.1774, 0.04485] 66 | ) 67 | 68 | # The higher the presence, the lower the weight in the loss 69 | base_weight = 0.25 70 | class_loss_weight = torch.Tensor(class_rate.max() / class_rate * base_weight) 71 | print(f'Using class weights: {class_loss_weight}') 72 | cross_entropy_loss = torch.nn.CrossEntropyLoss(weight=class_loss_weight) 73 | 74 | device = torch.device("cuda:0") if torch.cuda.is_available() else 'cpu' 75 | 76 | params = TrainingParams( 77 | train_loader=train_loader, 78 | val_loader=val_loader, 79 | loss=cross_entropy_loss, 80 | lr=1e-3, 81 | momentum=0.99, 82 | stats_interval=25000, 83 | save_model_interval=25000, 84 | save_model_dir='.', 85 | n_epochs=100, 86 | n_classes=6 + 1, # 6 + background 87 | device=device, 88 | checkpoint=None 89 | # checkpoint='unet_iter_750000.pt' 90 | ) 91 | 92 | fit(params) 93 | -------------------------------------------------------------------------------- /unet_segmentation/training/stats.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import List 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | 10 | @dataclass 11 | class MovingStats(object): 12 | 13 | losses: List 14 | ious: List 15 | 16 | def __init__(self): 17 | self.restart() 18 | 19 | def update(self, loss: float, iou: float) -> None: 20 | self.losses.append(loss) 21 | self.ious.append(iou) 22 | 23 | def moving_loss(self): 24 | return torch.mean(torch.Tensor(self.losses)) 25 | 26 | def moving_iou(self): 27 | return np.mean(self.ious) 28 | 29 | def restart(self): 30 | self.losses = [] 31 | self.ious = [] 32 | 33 | 34 | @dataclass 35 | class Stats(object): 36 | 37 | images: torch.Tensor 38 | masks: torch.Tensor 39 | predictions: torch.Tensor 40 | 41 | loss: torch.Tensor 42 | iou_value: float 43 | 44 | 45 | @dataclass 46 | class SummaryStats(object): 47 | 48 | iteration: int 49 | 50 | images: np.ndarray 51 | masks: np.ndarray 52 | predictions: np.ndarray 53 | 54 | loss_value: float 55 | iou_value: float 56 | 57 | def track(self, writer: SummaryWriter) -> None: 58 | writer.add_images('inputs', self.images, self.iteration) 59 | writer.add_images('groundtruth', self.masks, self.iteration) 60 | writer.add_images('predictions', self.predictions, self.iteration) 61 | writer.add_scalar('loss', self.loss_value, self.iteration) 62 | writer.add_scalar('iou', self.iou_value, self.iteration) 63 | 64 | @staticmethod 65 | 
def from_stats(stats: Stats, iteration: int): 66 | return SummaryStats( 67 | iteration=iteration, 68 | images=stats.images, 69 | masks=stats.masks, 70 | predictions=torch.unsqueeze( 71 | torch.argmax(stats.predictions, dim=1), dim=1), 72 | loss_value=stats.loss.item(), 73 | iou_value=stats.iou_value, 74 | ) 75 | -------------------------------------------------------------------------------- /unet_segmentation/training/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from unet_segmentation.unet import ( 12 | Unet, 13 | initialize_weights 14 | ) 15 | 16 | from unet_segmentation.metrics import mean_iou 17 | 18 | from unet_segmentation.training.stats import ( 19 | Stats, 20 | SummaryStats, 21 | MovingStats 22 | ) 23 | 24 | 25 | @dataclass 26 | class TrainingParams(object): 27 | 28 | train_loader: DataLoader 29 | val_loader: DataLoader 30 | loss: torch.nn.Module 31 | lr: float 32 | momentum: float 33 | stats_interval: int 34 | save_model_interval: int 35 | save_model_dir: str 36 | n_epochs: int 37 | n_classes: int 38 | device: torch.device 39 | n_channels: int = 3 40 | tensorboard_logs_dir: str = 'logs' 41 | checkpoint: str = None 42 | 43 | 44 | def _validation_eval(model: Unet, 45 | params: TrainingParams, 46 | n_samples: int = 5) -> Stats: 47 | """ Extract stats from a forward pass for whole validation dataset 48 | Args: 49 | model: Model used to extract masks from data 50 | params: Training parameters 51 | n_samples: Number of image samples to extract in order to include them 52 | in the generated report. 53 | Returns: 54 | stats: Validation stats report. 
55 | """ 56 | losses, ious = [], [] 57 | image_samples, mask_samples, prediction_samples = [], [], [] 58 | sample_idxs = np.random.randint(0, len(params.val_loader), size=n_samples) 59 | 60 | with torch.no_grad(): 61 | 62 | for i, (imgs, masks) in enumerate(params.val_loader): 63 | 64 | batch_imgs = imgs.to(params.device).float() 65 | batch_masks = masks.to(params.device).long() 66 | 67 | predicted_masks = model(batch_imgs) 68 | loss = params.loss(predicted_masks, batch_masks) 69 | 70 | losses.append(loss) 71 | ious.append(mean_iou(batch_masks, predicted_masks)) 72 | 73 | if i in sample_idxs: 74 | image_samples.append(imgs) 75 | mask_samples.append(masks) 76 | prediction_samples.append(predicted_masks) 77 | 78 | return Stats( 79 | images=torch.cat(image_samples, dim=0), 80 | masks=torch.unsqueeze(torch.cat(mask_samples, dim=0), dim=1), 81 | predictions=torch.cat(prediction_samples, dim=0), 82 | loss=torch.mean(torch.Tensor(losses)), 83 | iou_value=np.mean(ious) 84 | ) 85 | 86 | 87 | def _train_step(image_batch: torch.Tensor, 88 | mask_batch: torch.Tensor, 89 | model: Unet, 90 | optimizer: torch.optim.Optimizer, 91 | params: TrainingParams) -> Stats: 92 | 93 | # Compute loss and propagate backwards 94 | predicted_masks = model(image_batch) 95 | loss = params.loss(predicted_masks, mask_batch) 96 | 97 | # Set cumulative gradients in optimizer to 0 98 | optimizer.zero_grad() 99 | loss.backward() 100 | optimizer.step() 101 | 102 | return Stats( 103 | images=image_batch, 104 | masks=torch.unsqueeze(mask_batch, dim=1), 105 | predictions=predicted_masks, 106 | loss=loss, 107 | iou_value=mean_iou(mask_batch, predicted_masks) 108 | ) 109 | 110 | 111 | def _print_current_step(epoch: int, 112 | batch_idx: int, 113 | train_stats: Stats, 114 | params: TrainingParams) -> None: 115 | n_batches = len(params.train_loader) 116 | print( 117 | f'[Epoch {epoch}/{params.n_epochs}] [Batch {batch_idx}/{n_batches}] ' + 118 | f'Loss: {train_stats.loss.item():.4f}, Iou: {train_stats.iou_value:.2f}' # noqa 119 | ) 120 | 121 | 122 | def _print_stats(epoch: int, 123 | stats: Stats, 124 | params: TrainingParams, 125 | tag: str = 'training') -> None: 126 | print( 127 | f'[Epoch {epoch}/{params.n_epochs}] ' + 128 | f'{tag} loss: {stats.loss.item():.4f}, ' + 129 | f'{tag} Iou: {stats.iou_value:.2f}' 130 | ) 131 | 132 | 133 | def _load_model(params: TrainingParams) -> Tuple[Unet, int]: 134 | checkpoint = params.checkpoint 135 | if checkpoint is not None: 136 | filename = os.path.basename(checkpoint) 137 | iteration_offset = int(os.path.splitext(filename)[0].split('_')[-1]) 138 | print( 139 | f'Loading weights from checkpoint: {checkpoint}. 
' 140 | f'Iteration: {iteration_offset}' 141 | ) 142 | unet = torch.load(checkpoint) 143 | else: 144 | print('Initializing model from scratch') 145 | unet = Unet(n_channels=params.n_channels, n_classes=params.n_classes) 146 | unet.apply(initialize_weights) 147 | iteration_offset = 0 148 | 149 | return unet, iteration_offset 150 | 151 | 152 | def _initialize_writers(params: TrainingParams) -> Tuple[SummaryWriter, SummaryWriter]: # noqa 153 | def _init_writer(tag: str) -> SummaryWriter: 154 | return SummaryWriter( 155 | os.path.join(params.tensorboard_logs_dir, tag), 156 | filename_suffix=f'_{tag}' 157 | ) 158 | 159 | return _init_writer('train'), _init_writer('validation') 160 | 161 | 162 | def fit(params: TrainingParams) -> None: 163 | 164 | unet, iteration_offset = _load_model(params) 165 | optimizer = torch.optim.SGD( 166 | unet.parameters(), lr=params.lr, momentum=params.momentum) 167 | 168 | # Move model and loss to devices 169 | params.loss.to(params.device) 170 | unet.to(params.device) 171 | 172 | n_batches = len(params.train_loader) 173 | train_writer, val_writer = _initialize_writers(params) 174 | stats = MovingStats() 175 | 176 | for epoch in range(params.n_epochs): 177 | 178 | for batch_idx, (imgs, masks) in enumerate(params.train_loader): 179 | 180 | # Configure batch data 181 | iteration = epoch * n_batches + batch_idx + iteration_offset 182 | batch_imgs = imgs.to(params.device).float() 183 | batch_masks = masks.to(params.device).long() 184 | 185 | training_stats = _train_step(image_batch=batch_imgs, 186 | mask_batch=batch_masks, 187 | model=unet, 188 | optimizer=optimizer, 189 | params=params) 190 | stats.update(training_stats.loss, training_stats.iou_value) 191 | 192 | if iteration % params.stats_interval == 0 \ 193 | and iteration != iteration_offset: 194 | # Replace step metrics with moving metrics 195 | training_stats.loss = stats.moving_loss() 196 | training_stats.iou_value = stats.moving_iou() 197 | stats.restart() 198 | 199 | # Track training summary 200 | training_summary = SummaryStats.from_stats(training_stats, 201 | iteration) 202 | training_summary.track(train_writer) 203 | 204 | # Track validation summary 205 | validation_stats = _validation_eval(model=unet, params=params) 206 | validation_summary = SummaryStats.from_stats(validation_stats, 207 | iteration) 208 | validation_summary.track(val_writer) 209 | 210 | _print_stats(epoch, training_stats, params) 211 | _print_stats(epoch, validation_stats, params, tag='validation') 212 | 213 | if iteration % params.save_model_interval == 0 and \ 214 | iteration != iteration_offset: 215 | model_path = os.path.join(params.save_model_dir, 216 | f'unet_iter_{iteration}.pt') 217 | torch.save(unet, model_path) 218 | 219 | torch.save(unet, os.path.join(params.save_model_dir, f'unet_final.pt')) 220 | -------------------------------------------------------------------------------- /unet_segmentation/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | 8 | # Implementation of U-Net According to https://arxiv.org/pdf/1505.04597.pdf 9 | 10 | 11 | class Unet(nn.Module): 12 | 13 | def __init__(self, 14 | n_classes: int, 15 | n_channels: int = 3, 16 | base_maps: int = 64): 17 | super(Unet, self).__init__() 18 | 19 | # Compute contracting blocks 20 | # Input size of ith block: (I/2**i) - ((2**i - 1)/2**(i-2)) 21 | self._c_out_1 = DoubleConv2d(n_channels, base_maps) 22 | self._c_out_2 = 
DoubleConv2d(base_maps, base_maps * 2) 23 | self._c_out_3 = DoubleConv2d(base_maps * 2, base_maps * 4) 24 | self._c_out_4 = DoubleConv2d(base_maps * 4, base_maps * 8) 25 | self._maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2) 26 | 27 | # Junction between contracting and expanding parts 28 | self._bottleneck = DoubleConv2d(base_maps * 8, base_maps * 16) 29 | 30 | # Compute expanding blocks 31 | self._e_input_1 = Upsampling(base_maps * 16, base_maps * 8) 32 | self._e_input_2 = nn.Sequential( 33 | DoubleConv2d(base_maps * 16, base_maps * 8), 34 | Upsampling(base_maps * 8, base_maps * 4) 35 | ) 36 | self._e_input_3 = nn.Sequential( 37 | DoubleConv2d(base_maps * 8, base_maps * 4), 38 | Upsampling(base_maps * 4, base_maps * 2) 39 | ) 40 | self._e_input_4 = nn.Sequential( 41 | DoubleConv2d(base_maps * 4, base_maps * 2), 42 | Upsampling(base_maps * 2, base_maps) 43 | ) 44 | 45 | # Define expanding output layer 46 | self._output = nn.Sequential( 47 | DoubleConv2d(base_maps * 2, base_maps), 48 | nn.Dropout(0.25), 49 | # 1 x 1 convolution to generate output class number of channels 50 | nn.Conv2d(in_channels=base_maps, 51 | out_channels=n_classes, 52 | kernel_size=1, 53 | stride=1, 54 | padding=0), 55 | ) 56 | 57 | def forward(self, images): 58 | # Contracting path forward propagation 59 | c_out_1 = self._c_out_1(images) 60 | c_out_2 = self._c_out_2(self._maxpool(c_out_1)) 61 | c_out_3 = self._c_out_3(self._maxpool(c_out_2)) 62 | c_out_4 = self._c_out_4(self._maxpool(c_out_3)) 63 | 64 | # Bottleneck connects contracting end and expanding start 65 | bottleneck = self._bottleneck(self._maxpool(c_out_4)) 66 | # Build expanding path input by upsampling 67 | e_input_1 = self._e_input_1(bottleneck) 68 | # return e_input_1, c_out_4 69 | e_input_2 = self._e_input_2(torch.cat((c_out_4, e_input_1), dim=1)) 70 | # return e_input_2, c_out_3 71 | e_input_3 = self._e_input_3(torch.cat((c_out_3, e_input_2), dim=1)) 72 | # return e_input_3, c_out_2 73 | e_input_4 = self._e_input_4(torch.cat((c_out_2, e_input_3), dim=1)) 74 | 75 | # # End of expanding path 76 | return self._output(torch.cat((c_out_1, e_input_4), axis=1)) 77 | 78 | 79 | class DoubleConv2d(nn.Module): 80 | 81 | # Contracting block = Conv2 + Relu + Conv2 + Relu 82 | 83 | def __init__(self, in_channels: int, out_channels: int): 84 | super(DoubleConv2d, self).__init__() 85 | self._block = nn.Sequential( 86 | nn.Conv2d(in_channels=in_channels, 87 | out_channels=out_channels, 88 | kernel_size=3, 89 | stride=1, 90 | padding=1), 91 | nn.BatchNorm2d(out_channels), 92 | nn.ReLU(inplace=True), 93 | nn.Conv2d(in_channels=out_channels, 94 | out_channels=out_channels, 95 | kernel_size=3, 96 | stride=1, 97 | padding=1), 98 | nn.BatchNorm2d(out_channels), 99 | nn.ReLU(inplace=True) 100 | ) 101 | 102 | def forward(self, x): 103 | return self._block(x) 104 | 105 | 106 | class Upsampling(nn.Module): 107 | 108 | def __init__(self, 109 | in_channels: int, 110 | out_channels: int): 111 | super(Upsampling, self).__init__() 112 | self._block = nn.Sequential( 113 | nn.ConvTranspose2d(in_channels=in_channels, 114 | out_channels=out_channels, 115 | kernel_size=2, 116 | stride=2, 117 | # dilation=5, 118 | padding=0) 119 | ) 120 | 121 | def forward(self, x): 122 | return self._block(x) 123 | 124 | 125 | def initialize_weights(param): 126 | class_name = param.__class__.__name__ 127 | if class_name.startswith('Conv'): 128 | # Initialization according to original Unet paper 129 | # See https://arxiv.org/pdf/1505.04597.pdf 130 | print(f'Initializing weights for layer 
{class_name}') 131 | _, in_maps, k, _ = param.weight.shape 132 | n = k*k*in_maps 133 | std = np.sqrt(2/n) 134 | nn.init.normal_(param.weight.data, mean=0.0, std=std) 135 | else: 136 | print(f'No need to initialize weights for {class_name}') 137 | 138 | 139 | if __name__ == '__main__': 140 | device = torch.device("cuda:0") if torch.cuda.is_available() else 'cpu' 141 | model = Unet(n_classes=5).to(device) 142 | model.apply(initialize_weights) 143 | input_batch = torch.randn((1, 3, 512, 512)).to(device) 144 | print(model(input_batch).detach().shape) 145 | --------------------------------------------------------------------------------
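As an end-to-end sanity check, the building blocks above can be combined into a small evaluation loop. This is a sketch rather than part of the repository: the dataset root (`full_dataset`), checkpoint name (`unet.pt`) and image size are assumptions borrowed from `trainer.py`, and it reuses `SegmentationDataset` and `mean_iou` as defined in this code base.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

from unet_segmentation.data.dataset import SegmentationDataset
from unet_segmentation.metrics import mean_iou

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Assumed checkpoint name; any model saved by trainer.py should work
model = torch.load('unet.pt', map_location=device).eval()

# Same dataset root and validation subset used in trainer.py
val_loader = DataLoader(
    SegmentationDataset('full_dataset', subset='validation', image_resize=512),
    batch_size=1)

ious = []
with torch.no_grad():
    for imgs, masks in val_loader:
        predictions = model(imgs.to(device).float())
        ious.append(mean_iou(masks.to(device).long(), predictions))

print(f'Validation mean IoU: {np.mean(ious):.3f}')
```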