├── .gitignore
├── README.md
├── assets
│   ├── aligned_exception.png
│   ├── deeplabv3_plus_diagram.png
│   ├── deeplabv3plus.svg
│   └── human_parsing_results
│       ├── test_result_1.png
│       ├── test_result_2.png
│       ├── test_result_3.png
│       ├── test_result_4.png
│       ├── train_result_1.png
│       ├── train_result_2.png
│       ├── train_result_3.png
│       ├── train_result_4.png
│       ├── training_results.png
│       ├── val_result_1.png
│       ├── val_result_2.png
│       ├── val_result_3.png
│       └── val_result_4.png
├── checkpoints
│   └── .gitignore
├── config
│   ├── __init__.py
│   ├── camvid_resnet50.py
│   └── human_parsing_resnet50.py
├── dataset
│   ├── .gitignore
│   ├── camvid.sh
│   ├── cityscapes.sh
│   └── download_human_parsing_dataset.sh
├── deeplabv3plus
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── augmentations.py
│   │   └── dataloader.py
│   ├── inference.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── backbones.py
│   │   ├── blocks.py
│   │   └── deeplabv3_plus.py
│   ├── train.py
│   └── utils.py
├── models.md
├── notebooks
│   └── Demo-Human-Parsing.ipynb
├── requirements.txt
└── trainer.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | env/
3 | test.py
4 | .vscode/
5 | src/backbones/__pycache__/
6 | dump.txt
7 | src/__pycache__/
8 | .DS_Store
9 | assets/.DS_Store
10 | dataset/CamVid/
11 | src/datasets/__pycache__/
12 | dataset/camvid-secret.sh
13 | src/model/__pycache__/
14 | secret.py
15 | __pycache__/
16 | src/.ipynb_checkpoints/
17 | .ipynb_checkpoints/
18 | logs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepLabV3-Plus (Ongoing)
2 | 
3 | [![](https://camo.githubusercontent.com/7ce7d8e78ad8ddab3bea83bb9b98128528bae110/68747470733a2f2f616c65656e34322e6769746875622e696f2f6261646765732f7372632f74656e736f72666c6f772e737667)](https://tensorflow.org/)
4 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/deepwrex/DeepLabV3-Plus/augmentations)
5 | [![HitCount](http://hits.dwyl.com/lattice-ai/DeepLabV3-Plus.svg?style=flat-square)](http://hits.dwyl.com/lattice-ai/DeepLabV3-Plus)
6 | 
7 | TensorFlow 2.2.0 implementation of the DeepLabV3+ architecture, as proposed in the paper [Encoder-Decoder with Atrous Separable
8 | Convolution for Semantic Image Segmentation](https://arxiv.org/pdf/1802.02611.pdf).
9 | 
10 | ![](./assets/deeplabv3_plus_diagram.png)
11 | 
12 | **Project Link:** [https://github.com/deepwrex/DeepLabV3-Plus/projects/](https://github.com/deepwrex/DeepLabV3-Plus/projects/).
13 | 
14 | **Experiments:** [https://app.wandb.ai/19soumik-rakshit96/deeplabv3-plus](https://app.wandb.ai/19soumik-rakshit96/deeplabv3-plus).
15 | 
16 | Model Architectures can be found [here](./models.md).
17 | 
18 | ## Setup Datasets
19 | 
20 | - **CamVid**
21 | 
22 | ```shell script
23 | cd dataset
24 | bash camvid.sh
25 | ```
26 | 
27 | - **Multi-Person Human Parsing**
28 | 
29 | Register on [https://www.kaggle.com/](https://www.kaggle.com/).
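30 | 
31 | The helper script `dataset/download_human_parsing_dataset.sh` reads your Kaggle
32 | username and API key from its first two arguments (it exports them as
33 | `KAGGLE_USERNAME` and `KAGGLE_KEY` for the Kaggle CLI), so keep both values
34 | from your account page at hand.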
35 | 
36 | Generate Kaggle API Token
37 | 
38 | ```shell script
39 | bash download_human_parsing_dataset.sh <kaggle-username> <kaggle-api-key>
40 | ```
41 | 
42 | 
43 | ## Code to test Model
44 | 
45 | ```python
46 | import tensorflow as tf
47 | 
48 | from deeplabv3plus.model.deeplabv3_plus import DeeplabV3Plus
49 | 
50 | model = DeeplabV3Plus(backbone='resnet50', num_classes=20)
51 | input_shape = (1, 512, 512, 3)
52 | input_tensor = tf.random.normal(input_shape)
53 | result = model(input_tensor)  # build model by one forward pass
54 | model.summary()
55 | ```
56 | 
57 | ## Training
58 | 
59 | Use the `trainer.py` script as documented in the help description below:
60 | 
61 | ```
62 | usage: trainer.py [-h] [--wandb_api_key WANDB_API_KEY] config_key
63 | 
64 | Runs DeeplabV3+ trainer with the given config setting.
65 | 
66 | Registered config_key values:
67 |   camvid_resnet50
68 |   human_parsing_resnet50
69 | 
70 | positional arguments:
71 |   config_key            Key to use while looking up configuration from the
72 |                         CONFIG_MAP dictionary.
73 | 
74 | optional arguments:
75 |   -h, --help            show this help message and exit
76 |   --wandb_api_key WANDB_API_KEY
77 |                         Wandb API Key for logging run on Wandb.
78 |                         If provided, checkpoint_dir is set to wandb://
79 |                         (Model checkpoints are saved to wandb.)
80 | ```
81 | 
82 | ### If you want to use your own custom training configuration, you can define it in the following way:
83 | 
84 | - #### Define your configuration in a Python dictionary as follows:
85 | 
86 | `config/camvid_resnet50.py`
87 | 
88 | ```python
89 | #!/usr/bin/env python
90 | 
91 | """Module providing configuration for training deeplabv3plus on the CamVid
92 | dataset."""
93 | 
94 | from glob import glob
95 | 
96 | import tensorflow as tf
97 | 
98 | 
99 | # Sample Configuration
100 | CONFIG = {
101 |     # We mandate specifying project_name and experiment_name in every config
102 |     # file. They are used for wandb runs if wandb api key is specified.
103 |     'project_name': 'deeplabv3-plus',
104 |     'experiment_name': 'camvid-segmentation-resnet-50-backbone',
105 | 
106 |     'train_dataset_config': {
107 |         'images': sorted(glob('./dataset/camvid/train/*')),
108 |         'labels': sorted(glob('./dataset/camvid/trainannot/*')),
109 |         'height': 512, 'width': 512, 'batch_size': 8
110 |     },
111 | 
112 |     'val_dataset_config': {
113 |         'images': sorted(glob('./dataset/camvid/val/*')),
114 |         'labels': sorted(glob('./dataset/camvid/valannot/*')),
115 |         'height': 512, 'width': 512, 'batch_size': 8
116 |     },
117 | 
118 |     'strategy': tf.distribute.OneDeviceStrategy(device="/gpu:0"),
119 |     'num_classes': 20, 'backbone': 'resnet50', 'learning_rate': 0.0001,
120 | 
121 |     'checkpoint_dir': "./checkpoints/",
122 |     'checkpoint_file_prefix': "deeplabv3plus_with_resnet50_",
123 | 
124 |     'epochs': 100
125 | }
126 | ```
127 | 
128 | - #### Save this file inside the `config` directory. (As hinted by the file path above.)
129 | - #### Register your config in the `__init__.py` module like below:
130 | 
131 | `config/__init__.py`
132 | 
133 | ```python
134 | #!/usr/bin/env python
135 | # -*- coding: utf-8 -*-
136 | 
137 | """__init__ module for configs. Register your config file here by adding its
138 | entry in the CONFIG_MAP as shown.
139 | """
140 | 
141 | import config.camvid_resnet50
142 | import config.human_parsing_resnet50
143 | 
144 | 
145 | CONFIG_MAP = {
146 |     'camvid_resnet50': config.camvid_resnet50.CONFIG,  # the config file we defined above
147 |     'human_parsing_resnet50': config.human_parsing_resnet50.CONFIG  # another config
148 | }
149 | ```
150 | 
151 | - #### Now you can run the trainer script like so (using the `camvid_resnet50` config key we registered above):
152 | 
153 | ```bash
154 | ./trainer.py camvid_resnet50 --wandb_api_key <your-wandb-api-key>
155 | ```
156 | 
157 | or, if you don't need wandb logging:
158 | 
159 | ```bash
160 | ./trainer.py camvid_resnet50
161 | ```
162 | 
163 | ## Inference
164 | 
165 | Sample Inference Code:
166 | 
167 | ```python
168 | from glob import glob
169 | 
170 | from deeplabv3plus.inference import infer, read_image
171 | from deeplabv3plus.utils import plot_samples_matplotlib
172 | 
173 | model_file = './dataset/deeplabv3-plus-human-parsing-resnet-50-backbone.h5'
174 | train_images = glob('./dataset/instance-level_human_parsing/Training/Images/*')
175 | val_images = glob('./dataset/instance-level_human_parsing/Validation/Images/*')
176 | test_images = glob('./dataset/instance-level_human_parsing/Testing/Images/*')
177 | 
178 | 
179 | def plot_predictions(images_list, size):
180 |     for image_file in images_list:
181 |         image_tensor = read_image(image_file, size)
182 |         prediction = infer(
183 |             image_tensor=image_tensor,
184 |             model_file=model_file
185 |         )
186 |         plot_samples_matplotlib(
187 |             [image_tensor, prediction], figsize=(10, 6)
188 |         )
189 | 
190 | 
191 | plot_predictions(train_images[:4], (512, 512))
192 | ```
193 | 
194 | ## Results
195 | 
196 | ### Multi-Person Human Parsing
197 | 
198 | ![](./assets/human_parsing_results/training_results.png)
199 | 
200 | #### Training Set Results
201 | 
202 | ![](./assets/human_parsing_results/train_result_1.png)
203 | 
204 | ![](./assets/human_parsing_results/train_result_2.png)
205 | 
206 | ![](./assets/human_parsing_results/train_result_3.png)
207 | 
208 | ![](./assets/human_parsing_results/train_result_4.png)
209 | 
210 | #### Validation Set Results
211 | 
212 | ![](./assets/human_parsing_results/val_result_1.png)
213 | 
214 | ![](./assets/human_parsing_results/val_result_2.png)
215 | 
216 | ![](./assets/human_parsing_results/val_result_3.png)
217 | 
218 | ![](./assets/human_parsing_results/val_result_4.png)
219 | 
220 | #### Test Set Results
221 | 
222 | ![](./assets/human_parsing_results/test_result_1.png)
223 | 
224 | ![](./assets/human_parsing_results/test_result_2.png)
225 | 
226 | ![](./assets/human_parsing_results/test_result_3.png)
227 | 
228 | ![](./assets/human_parsing_results/test_result_4.png)
229 | 
230 | ## Citation
231 | 
232 | ```
233 | @misc{1802.02611,
234 |     Author = {Liang-Chieh Chen and Yukun Zhu and George Papandreou and Florian Schroff and Hartwig Adam},
235 |     Title = {Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation},
236 |     Year = {2018},
237 |     Eprint = {arXiv:1802.02611},
238 | }
239 | ```
240 | 
--------------------------------------------------------------------------------
/assets/aligned_exception.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/aligned_exception.png
--------------------------------------------------------------------------------
/assets/deeplabv3_plus_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/deeplabv3_plus_diagram.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/test_result_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/test_result_1.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/test_result_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/test_result_2.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/test_result_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/test_result_3.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/test_result_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/test_result_4.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/train_result_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/train_result_1.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/train_result_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/train_result_2.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/train_result_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/train_result_3.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/train_result_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/train_result_4.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/training_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/training_results.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/val_result_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/val_result_1.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/val_result_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/val_result_2.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/val_result_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/val_result_3.png
--------------------------------------------------------------------------------
/assets/human_parsing_results/val_result_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lattice-ai/DeepLabV3-Plus/f596eb70d385cf779f3fffc408358319b57cc655/assets/human_parsing_results/val_result_4.png
--------------------------------------------------------------------------------
/checkpoints/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 | 
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | """__init__ module for configs. Register your config file here by adding its
5 | entry in the CONFIG_MAP as shown.
6 | """
7 | 
8 | import config.camvid_resnet50
9 | import config.human_parsing_resnet50
10 | 
11 | 
12 | CONFIG_MAP = {
13 |     'camvid_resnet50': config.camvid_resnet50.CONFIG,
14 |     'human_parsing_resnet50': config.human_parsing_resnet50.CONFIG
15 | }
16 | 
--------------------------------------------------------------------------------
/config/camvid_resnet50.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | """Module providing configuration for training deeplabv3plus on the CamVid dataset."""
4 | 
5 | from glob import glob
6 | 
7 | import tensorflow as tf
8 | 
9 | 
10 | # Sample Configuration
11 | CONFIG = {
12 |     # We mandate specifying project_name and experiment_name in every config
13 |     # file. They are used for wandb runs if wandb api key is specified.
14 |     'project_name': 'deeplabv3-plus',
15 |     'experiment_name': 'camvid-segmentation-resnet-50-backbone',
16 | 
17 |     'train_dataset_config': {
18 |         'images': sorted(glob('./dataset/camvid/train/*')),
19 |         'labels': sorted(glob('./dataset/camvid/trainannot/*')),
20 |         'height': 512, 'width': 512, 'batch_size': 8
21 |     },
22 | 
23 |     'val_dataset_config': {
24 |         'images': sorted(glob('./dataset/camvid/val/*')),
25 |         'labels': sorted(glob('./dataset/camvid/valannot/*')),
26 |         'height': 512, 'width': 512, 'batch_size': 8
27 |     },
28 | 
29 |     'strategy': tf.distribute.OneDeviceStrategy(device="/gpu:0"),
30 |     'num_classes': 20, 'backbone': 'resnet50', 'learning_rate': 0.0001,
31 | 
32 |     'checkpoint_dir': "./checkpoints/",
33 |     'checkpoint_file_prefix': "deeplabv3plus_with_resnet50_",
34 | 
35 |     'epochs': 100
36 | }
37 | 
--------------------------------------------------------------------------------
/config/human_parsing_resnet50.py:
--------------------------------------------------------------------------------
1 | """Module providing configuration for training on human parsing with resnet50
2 | backbone"""
3 | 
4 | from glob import glob
5 | 
6 | import tensorflow as tf
7 | 
8 | 
9 | CONFIG = {
10 |     'project_name': 'deeplabv3-plus',
11 |     'experiment_name': 'human-parsing-resnet-50-backbone',
12 | 
13 |     'train_dataset_config': {
14 |         'images': sorted(
15 |             glob(
16 |                 './dataset/instance-level_human_parsing/'
17 |                 'instance-level_human_parsing/Training/Images/*'
18 |             )
19 |         ),
20 |         'labels': sorted(
21 |             glob(
22 |                 './dataset/instance-level_human_parsing/'
23 |                 'instance-level_human_parsing/Training/Category_ids/*'
24 |             )
25 |         ),
26 |         'height': 512, 'width': 512, 'batch_size': 8
27 |     },
28 | 
29 |     'val_dataset_config': {
30 |         'images': sorted(
31 |             glob(
32 |                 './dataset/instance-level_human_parsing/'
33 |                 'instance-level_human_parsing/Validation/Images/*'
34 |             )
35 |         ),
36 |         'labels': sorted(
37 |             glob(
38 |                 './dataset/instance-level_human_parsing/'
39 |                 'instance-level_human_parsing/Validation/Category_ids/*'
40 |             )
41 |         ),
42 |         'height': 512, 'width': 512, 'batch_size': 8
43 |     },
44 | 
45 |     'strategy': tf.distribute.OneDeviceStrategy(device="/gpu:0"),
46 | 
47 |     'num_classes': 20,
48 |     'backbone': 'resnet50',
49 |     'learning_rate': 0.0001,
50 | 
51 |     'checkpoint_dir': "./checkpoints/",
52 |     'checkpoint_file_prefix':
53 |         'deeplabv3-plus-human-parsing-resnet-50-backbone_',
54 | 
55 |     'epochs': 100
56 | }
57 | 
--------------------------------------------------------------------------------
/dataset/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except these files
4 | !.gitignore
5 | !camvid.sh
6 | !cityscapes.sh
7 | !download_human_parsing_dataset.sh
8 | 
--------------------------------------------------------------------------------
/dataset/camvid.sh:
--------------------------------------------------------------------------------
1 | mkdir camvid
2 | cd camvid
3 | wget https://www.dropbox.com/s/ej1gx48bxqbtwd2/CamVid.zip?dl=0 -O CamVid.zip
4 | unzip -qq CamVid.zip
5 | rm CamVid.zip
6 | cd ..
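A quick sanity check after `camvid.sh` finishes: the sample configs above glob
`./dataset/camvid/train`, `trainannot`, `val` and `valannot`, so each split
should pair every image with an annotation mask. A minimal sketch (a
hypothetical helper, not part of the repository):

```python
# Hypothetical check, run from the repository root after `bash camvid.sh`.
from glob import glob

for split in ('train', 'val'):
    images = sorted(glob(f'./dataset/camvid/{split}/*'))
    masks = sorted(glob(f'./dataset/camvid/{split}annot/*'))
    print(f'{split}: {len(images)} images, {len(masks)} masks')
    assert len(images) == len(masks), 'every image needs a matching mask'
```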
--------------------------------------------------------------------------------
/dataset/cityscapes.sh:
--------------------------------------------------------------------------------
1 | mkdir train_imgs train_labels val_imgs val_labels
2 | wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=&password=&submit=Login' https://www.cityscapes-dataset.com/login/
3 | wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1
4 | rm index.html
5 | wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=&password=&submit=Login' https://www.cityscapes-dataset.com/login/
6 | wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3
7 | rm index.html
8 | unzip -q gtFine_trainvaltest.zip
9 | rm README
10 | rm license.txt
11 | unzip -q leftImg8bit_trainvaltest.zip
12 | mv ./leftImg8bit/train/**/*leftImg8bit.png ./train_imgs
13 | mv ./leftImg8bit/val/**/*leftImg8bit.png ./val_imgs
14 | mv ./gtFine/train/**/*labelIds.png ./train_labels
15 | mv ./gtFine/val/**/*labelIds.png ./val_labels
--------------------------------------------------------------------------------
/dataset/download_human_parsing_dataset.sh:
--------------------------------------------------------------------------------
1 | export KAGGLE_USERNAME=$1
2 | export KAGGLE_KEY=$2
3 | kaggle datasets download -d soumikrakshit/human-segmentation
4 | unzip -q human-segmentation
5 | rm human-segmentation.zip
--------------------------------------------------------------------------------
/deeplabv3plus/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataloader import GenericDataLoader
--------------------------------------------------------------------------------
/deeplabv3plus/datasets/augmentations.py:
--------------------------------------------------------------------------------
1 | import random
2 | import tensorflow as tf
3 | 
4 | 
5 | class Augmentation:
6 | 
7 |     def __init__(self, configs) -> None:
8 |         self.configs = configs
9 |         self.choice = None
10 | 
11 |     def apply_random_brightness(self, image, mask):
12 |         condition = tf.cast(
13 |             tf.random.uniform(
14 |                 [], maxval=2, dtype=tf.int32
15 |             ), tf.bool
16 |         )
17 |         image = tf.cond(
18 |             condition, lambda: tf.image.random_brightness(
19 |                 image, self.configs['random_brightness_max_delta']),
20 |             lambda: tf.identity(image)
21 |         )
22 |         return image, mask
23 | 
24 |     def apply_random_contrast(self, image, mask):
25 |         condition = tf.cast(
26 |             tf.random.uniform(
27 |                 [], maxval=2, dtype=tf.int32
28 |             ), tf.bool
29 |         )
30 |         image = tf.cond(
31 |             condition, lambda: tf.image.random_contrast(
32 |                 image, self.configs['random_contrast_lower_bound'],
33 |                 self.configs['random_contrast_upper_bound']
34 |             ), lambda: tf.identity(image)
35 |         )
36 |         return image, mask
37 | 
38 |     def apply_random_saturation(self, image, mask):
39 |         condition = tf.cast(
40 |             tf.random.uniform(
41 |                 [], maxval=2, dtype=tf.int32
42 |             ), tf.bool
43 |         )
44 |         image = tf.cond(
45 |             condition, lambda: tf.image.random_saturation(
46 |                 image, self.configs['random_saturation_lower_bound'],
47 |                 self.configs['random_saturation_upper_bound']
48 |             ), lambda: tf.identity(image)
49 |         )
50 |         return image, mask
51 | 
52 |     def apply_horizontal_flip(self, image, mask):
53 |         combined_tensor = tf.concat([image, mask], axis=2)
54 |         combined_tensor = tf.image.random_flip_left_right(
55 |             combined_tensor, seed=self.configs['seed']
56 |         )
57 |         image, mask = tf.split(
58 |             combined_tensor,
59 |             [self.configs['image_channels'], self.configs['label_channels']], axis=2
60 |         )
61 |         return image, mask
62 | 
63 |     def apply_vertical_flip(self, image, mask):
64 |         combined_tensor = tf.concat([image, mask], axis=2)
65 |         combined_tensor = tf.image.random_flip_up_down(
66 |             combined_tensor, seed=self.configs['seed']
67 |         )
68 |         image, mask = tf.split(
69 |             combined_tensor,
70 |             [self.configs['image_channels'], self.configs['label_channels']], axis=2
71 |         )
72 |         return image, mask
73 | 
74 |     def apply_resize(self, image, mask):
75 |         image = tf.image.resize(image, self.configs['image_size'])
76 |         mask = tf.image.resize(mask, self.configs['image_size'], method="nearest")
77 |         return image, mask
78 | 
79 |     def apply_random_crop(self, image, mask):
80 |         condition = tf.cast(
81 |             tf.random.uniform(
82 |                 [], maxval=2, dtype=tf.int32,
83 |                 seed=self.configs['seed']
84 |             ), tf.bool
85 |         )
86 |         shape = tf.cast(tf.shape(image), tf.float32)
87 |         h = tf.cast(shape[0] * self.configs['crop_percent'], tf.int32)
88 |         w = tf.cast(shape[1] * self.configs['crop_percent'], tf.int32)
89 |         combined_tensor = tf.concat([image, mask], axis=2)
90 |         combined_tensor = tf.cond(
91 |             condition, lambda: tf.image.random_crop(
92 |                 combined_tensor,
93 |                 [h, w, self.configs['image_channels'] + self.configs['label_channels']],
94 |                 seed=self.configs['seed']
95 |             ), lambda: tf.identity(combined_tensor)
96 |         )
97 |         image, mask = tf.split(
98 |             combined_tensor,
99 |             [self.configs['image_channels'], self.configs['label_channels']], axis=2
100 |         )
101 |         return image, mask
102 | 
103 |     def compose_sequential(self, image, mask):
104 |         image, mask = self.apply_random_brightness(image, mask)
105 |         image, mask = self.apply_random_contrast(image, mask)
106 |         image, mask = self.apply_random_saturation(image, mask)
107 |         image, mask = self.apply_random_crop(image, mask)
108 |         image, mask = self.apply_horizontal_flip(image, mask)
109 |         image, mask = self.apply_vertical_flip(image, mask)
110 |         image, mask = self.apply_resize(image, mask)
111 |         return image, mask
112 | 
113 |     @tf.function
114 |     def compose_random(self, image, mask):
115 | 
116 |         def compose(_image, _mask):
117 |             options = [
118 |                 self.apply_random_brightness,
119 |                 self.apply_random_contrast,
120 |                 self.apply_random_saturation,
121 |                 self.apply_random_crop,
122 |                 self.apply_horizontal_flip,
123 |                 self.apply_vertical_flip
124 |             ]
125 |             augment_func = random.choice(options)
126 |             _image, _mask = augment_func(_image, _mask)
127 |             _image, _mask = self.apply_resize(_image, _mask)
128 |             return _image, _mask
129 | 
130 |         return tf.py_function(
131 |             compose, [image, mask],
132 |             [tf.float32, tf.uint8]
133 |         )
134 | 
135 |     def set_choice(self, choice):
136 |         self.choice = choice
137 | 
138 |     @tf.function
139 |     def compose_by_choice(self, image, mask):
140 | 
141 |         def compose(_image, _mask):
142 |             options = self.choice
143 |             augment_func = random.choice(options)
144 |             _image, _mask = augment_func(_image, _mask)
145 |             _image, _mask = self.apply_resize(_image, _mask)
146 |             return _image, _mask
147 | 
148 |         return tf.py_function(
149 |             compose, [image, mask],
150 |             [tf.float32, tf.uint8]
151 |         )
152 | 
--------------------------------------------------------------------------------
/deeplabv3plus/datasets/dataloader.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | class GenericDataLoader:
5 | 
6 |     def __init__(self, configs):
7 |         self.configs = configs
8 |         self.assert_dataset()
9 | 
10 |     def assert_dataset(self):
11 |         assert 'images' in self.configs and 'labels' in self.configs
12 |         assert len(self.configs['images']) == len(self.configs['labels'])
13 |         print('Images and labels are good to go')
14 | 
15 |     def __len__(self):
16 |         return len(self.configs['images'])
17 | 
18 |     def read_img(self, image_path, mask=False):
19 |         image = tf.io.read_file(image_path)
20 |         if mask:
21 |             image = tf.image.decode_png(image, channels=1)
22 |             image.set_shape([None, None, 1])
23 |             image = (tf.image.resize(
24 |                 images=image, size=[
25 |                     self.configs['height'],
26 |                     self.configs['width']
27 |                 ], method="nearest"
28 |             ))
29 |             image = tf.cast(image, tf.float32)
30 |         else:
31 |             image = tf.image.decode_png(image, channels=3)
32 |             image.set_shape([None, None, 3])
33 |             image = (tf.image.resize(
34 |                 images=image, size=[
35 |                     self.configs['height'],
36 |                     self.configs['width']
37 |                 ]
38 |             ))
39 |             image = tf.cast(image, tf.float32) / 127.5 - 1
40 |         return image
41 | 
42 |     def _map_function(self, image_list, mask_list):
43 |         image = self.read_img(image_list)
44 |         mask = self.read_img(mask_list, mask=True)
45 |         return image, mask
46 | 
47 |     def get_dataset(self):
48 |         dataset = tf.data.Dataset.from_tensor_slices(
49 |             (self.configs['images'], self.configs['labels'])
50 |         )
51 |         dataset = dataset.map(self._map_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
52 |         dataset = dataset.batch(self.configs['batch_size'], drop_remainder=True)
53 |         dataset = dataset.repeat()
54 |         dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
55 |         return dataset
56 | 
--------------------------------------------------------------------------------
/deeplabv3plus/inference.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | 
4 | 
5 | def read_image(image_file, image_size, is_mask=False):
6 |     image = tf.io.read_file(image_file)
7 |     if is_mask:
8 |         image = tf.image.decode_png(image, channels=1)
9 |         image.set_shape([None, None, 1])
10 |     else:
11 |         image = tf.image.decode_png(image, channels=3)
12 |         image.set_shape([None, None, 3])
13 |     image = (tf.image.resize(images=image, size=image_size))
14 |     image = tf.cast(image, tf.float32) / 127.5 - 1
15 |     return image
16 | 
17 | 
18 | def infer(model_file, image_tensor):
19 |     model = tf.keras.models.load_model(model_file)
20 |     predictions = model.predict(np.expand_dims((image_tensor), axis=0))
21 |     predictions = np.squeeze(predictions)
22 |     predictions = np.argmax(predictions, axis=2)
23 |     return predictions
24 | 
--------------------------------------------------------------------------------
/deeplabv3plus/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .deeplabv3_plus import DeeplabV3Plus
--------------------------------------------------------------------------------
/deeplabv3plus/model/backbones.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | BACKBONES = {
5 |     'resnet50': {
6 |         'model': tf.keras.applications.ResNet50,
7 |         'feature_1': 'conv4_block6_2_relu',
8 |         'feature_2': 'conv2_block3_2_relu'
9 |     },
10 |     'mobilenetv2': {
11 |         'model': tf.keras.applications.MobileNetV2,
12 |         'feature_1': 'out_relu',
13 |         'feature_2': 'block_3_depthwise_relu'
14 |     }
15 | }
16 | 
--------------------------------------------------------------------------------
/deeplabv3plus/model/blocks.py:
--------------------------------------------------------------------------------
1 | """Module providing building blocks for the DeepLabV3+ network architecture.
2 | """
3 | 
4 | import tensorflow as tf
5 | 
6 | 
7 | class ConvBlock(tf.keras.layers.Layer):
8 |     """Convolutional Block for DeepLabV3+
9 | 
10 |     Convolutional block consisting of Conv2D -> BatchNorm -> ReLU
11 | 
12 |     Args:
13 |         n_filters:
14 |             number of output filters
15 |         kernel_size:
16 |             kernel_size for convolution
17 |         padding:
18 |             padding for convolution
19 |         kernel_initializer:
20 |             kernel initializer for convolution
21 |         use_bias:
22 |             boolean, whether or not to use bias in convolution
23 |         dilation_rate:
24 |             dilation rate for convolution
25 |         conv_activation:
26 |             activation to be used for convolution
27 |     """
28 |     # !pylint:disable=too-many-arguments
29 |     def __init__(self, n_filters, kernel_size, padding, dilation_rate,
30 |                  kernel_initializer, use_bias, conv_activation=None):
31 |         super(ConvBlock, self).__init__()
32 | 
33 |         self.conv = tf.keras.layers.Conv2D(
34 |             n_filters, kernel_size=kernel_size, padding=padding,
35 |             kernel_initializer=kernel_initializer,
36 |             use_bias=use_bias, dilation_rate=dilation_rate,
37 |             activation=conv_activation)
38 | 
39 |         self.batch_norm = tf.keras.layers.BatchNormalization()
40 |         self.relu = tf.keras.layers.ReLU()
41 | 
42 |     def call(self, inputs, **kwargs):
43 |         tensor = self.conv(inputs)
44 |         tensor = self.batch_norm(tensor)
45 |         tensor = self.relu(tensor)
46 |         return tensor
47 | 
48 | 
49 | class AtrousSpatialPyramidPooling(tf.keras.layers.Layer):
50 |     """Atrous Spatial Pyramid Pooling layer for DeepLabV3+ architecture."""
51 |     # !pylint:disable=too-many-instance-attributes
52 |     def __init__(self):
53 |         super(AtrousSpatialPyramidPooling, self).__init__()
54 | 
55 |         # layer architecture components
56 |         self.avg_pool = None
57 |         self.conv1, self.conv2 = None, None
58 |         self.pool = None
59 |         self.out1, self.out6, self.out12, self.out18 = None, None, None, None
60 | 
61 |     @staticmethod
62 |     def _get_conv_block(kernel_size, dilation_rate, use_bias=False):
63 |         return ConvBlock(256,
64 |                          kernel_size=kernel_size,
65 |                          dilation_rate=dilation_rate,
66 |                          padding='same',
67 |                          use_bias=use_bias,
68 |                          kernel_initializer=tf.keras.initializers.he_normal())
69 | 
70 |     def build(self, input_shape):
71 |         dummy_tensor = tf.random.normal(input_shape)  # used for calculating
72 |         # output shape of convolutional layers
73 | 
74 |         self.avg_pool = tf.keras.layers.AveragePooling2D(
75 |             pool_size=(input_shape[-3], input_shape[-2]))
76 | 
77 |         self.conv1 = AtrousSpatialPyramidPooling._get_conv_block(
78 |             kernel_size=1, dilation_rate=1, use_bias=True)
79 | 
80 |         self.conv2 = AtrousSpatialPyramidPooling._get_conv_block(
81 |             kernel_size=1, dilation_rate=1)
82 | 
83 |         dummy_tensor = self.conv1(self.avg_pool(dummy_tensor))
84 | 
85 |         self.pool = tf.keras.layers.UpSampling2D(
86 |             size=(
87 |                 input_shape[-3] // dummy_tensor.shape[1],
88 |                 input_shape[-2] // dummy_tensor.shape[2]
89 |             ),
90 |             interpolation='bilinear'
91 |         )
92 | 
93 |         self.out1, self.out6, self.out12, self.out18 = map(
94 |             lambda tup: AtrousSpatialPyramidPooling._get_conv_block(
95 |                 kernel_size=tup[0], dilation_rate=tup[1]
96 |             ),
97 |             [(1, 1), (3, 6), (3, 12), (3, 18)]
98 |         )
99 | 
100 |     def call(self, inputs, **kwargs):
101 |         tensor = self.avg_pool(inputs)
102 |         tensor = self.conv1(tensor)
103 |         tensor = tf.keras.layers.Concatenate(axis=-1)([
104 |             self.pool(tensor),
105 |             self.out1(inputs),
106 |             self.out6(inputs),
107 |             self.out12(inputs),
108 |             self.out18(inputs)
109 |         ])
110 |         tensor = self.conv2(tensor)
111 |         return tensor
112 | 
--------------------------------------------------------------------------------
/deeplabv3plus/model/deeplabv3_plus.py:
--------------------------------------------------------------------------------
1 | """Module providing the DeeplabV3+ network architecture as a tf.keras.Model.
2 | """
3 | 
4 | import tensorflow as tf
5 | 
6 | from .backbones import BACKBONES
7 | from .blocks import (AtrousSpatialPyramidPooling,
8 |                      ConvBlock)
9 | 
10 | 
11 | # !pylint:disable=too-many-ancestors, too-many-instance-attributes
12 | class DeeplabV3Plus(tf.keras.Model):
13 |     """DeeplabV3+ network architecture, implemented as a tf.keras.Model.
14 |     Args:
15 |         num_classes:
16 |             number of segmentation classes, effectively - number of output
17 |             filters
18 |         height, width:
19 |             expected height, width of image (taken from the input shape at
20 |             build time)
21 |         backbone:
22 |             backbone to be used
23 |     """
24 |     def __init__(self, num_classes, backbone='resnet50', **kwargs):
25 |         super(DeeplabV3Plus, self).__init__()
26 | 
27 |         self.num_classes = num_classes
28 |         self.backbone = backbone
29 |         self.aspp = None
30 |         self.backbone_feature_1, self.backbone_feature_2 = None, None
31 |         self.input_a_upsampler_getter = None
32 |         self.otensor_upsampler_getter = None
33 |         self.input_b_conv, self.conv1, self.conv2, self.out_conv = (None,
34 |                                                                     None,
35 |                                                                     None,
36 |                                                                     None)
37 | 
38 |     @staticmethod
39 |     def _get_conv_block(filters, kernel_size, conv_activation=None):
40 |         return ConvBlock(filters, kernel_size=kernel_size, padding='same',
41 |                          conv_activation=conv_activation,
42 |                          kernel_initializer=tf.keras.initializers.he_normal(),
43 |                          use_bias=False, dilation_rate=1)
44 | 
45 |     @staticmethod
46 |     def _get_upsample_layer_fn(input_shape, factor: int):
47 |         return lambda fan_in_shape: \
48 |             tf.keras.layers.UpSampling2D(
49 |                 size=(
50 |                     input_shape[1]
51 |                     // factor // fan_in_shape[1],
52 |                     input_shape[2]
53 |                     // factor // fan_in_shape[2]
54 |                 ),
55 |                 interpolation='bilinear'
56 |             )
57 | 
58 |     def _get_backbone_feature(self, feature: str,
59 |                               input_shape) -> tf.keras.Model:
60 |         input_layer = tf.keras.Input(shape=input_shape[1:])
61 | 
62 |         backbone_model = BACKBONES[self.backbone]['model'](
63 |             input_tensor=input_layer, weights='imagenet', include_top=False)
64 | 
65 |         output_layer = backbone_model.get_layer(
66 |             BACKBONES[self.backbone][feature]).output
67 |         return tf.keras.Model(inputs=input_layer, outputs=output_layer)
68 | 
69 |     def build(self, input_shape):
70 |         self.backbone_feature_1 = self._get_backbone_feature('feature_1',
71 |                                                              input_shape)
72 |         self.backbone_feature_2 = self._get_backbone_feature('feature_2',
73 |                                                              input_shape)
74 | 
75 |         self.input_a_upsampler_getter = self._get_upsample_layer_fn(
76 |             input_shape, factor=4)
77 | 
78 |         self.aspp = AtrousSpatialPyramidPooling()
79 | 
80 |         self.input_b_conv = DeeplabV3Plus._get_conv_block(48,
81 |                                                           kernel_size=(1, 1))
82 | 
83 |         self.conv1 = DeeplabV3Plus._get_conv_block(256, kernel_size=3,
84 |                                                    conv_activation='relu')
85 | 
86 |         self.conv2 = DeeplabV3Plus._get_conv_block(256, kernel_size=3,
87 |                                                    conv_activation='relu')
88 | 
89 |         self.otensor_upsampler_getter = self._get_upsample_layer_fn(
90 |             input_shape, factor=1)
91 | 
92 |         self.out_conv = tf.keras.layers.Conv2D(self.num_classes,
93 |                                                kernel_size=(1, 1),
94 |                                                padding='same')
95 | 
96 |     def call(self, inputs, training=None, mask=None):
97 |         input_a = self.backbone_feature_1(inputs)
98 | 
99 |         input_a = self.aspp(input_a)
100 |         input_a = self.input_a_upsampler_getter(input_a.shape)(input_a)
101 | 
102 |         input_b = self.backbone_feature_2(inputs)
103 |         input_b = self.input_b_conv(input_b)
104 | 
105 |         tensor = tf.keras.layers.Concatenate(axis=-1)([input_a, input_b])
106 |         tensor = self.conv2(self.conv1(tensor))
107 | 
108 |         tensor = self.otensor_upsampler_getter(tensor.shape)(tensor)
109 |         return self.out_conv(tensor)
110 | 
--------------------------------------------------------------------------------
/deeplabv3plus/train.py:
--------------------------------------------------------------------------------
1 | """Module providing Trainer class for deeplabv3plus"""
2 | 
3 | import os
4 | 
5 | import tensorflow as tf
6 | 
7 | import wandb
8 | from wandb.keras import WandbCallback
9 | 
10 | from deeplabv3plus.datasets import GenericDataLoader
11 | from deeplabv3plus.model import DeeplabV3Plus
12 | 
13 | 
14 | class Trainer:
15 |     """Class for managing DeeplabV3+ model training.
16 | 
17 |     Args:
18 |         config:
19 |             python dictionary containing training configuration for
20 |             DeeplabV3Plus
21 |     """
22 |     def __init__(self, config):
23 |         self.config = config
24 |         self._assert_config()
25 | 
26 |         # Train Dataset
27 |         train_dataloader = GenericDataLoader(self.config[
28 |             'train_dataset_config'])
29 |         self.train_data_length = len(train_dataloader)
30 |         print('[+] Data points in train dataset: {}'.format(
31 |             self.train_data_length))
32 |         self.train_dataset = train_dataloader.get_dataset()
33 |         print('Train Dataset:', self.train_dataset)
34 | 
35 |         # Validation Dataset
36 |         val_dataloader = GenericDataLoader(self.config[
37 |             'val_dataset_config'])
38 |         self.val_data_length = len(val_dataloader)
39 |         print('[+] Data points in val dataset: {}'.format(
40 |             self.val_data_length))
41 |         self.val_dataset = val_dataloader.get_dataset()
42 |         print('Val Dataset:', self.val_dataset)
43 | 
44 |         self._model = None
45 |         self._wandb_initialized = False
46 | 
47 |     @property
48 |     def model(self):
49 |         """Property returning model being trained."""
50 | 
51 |         if self._model is not None:
52 |             return self._model
53 | 
54 |         with self.config['strategy'].scope():
55 |             self._model = DeeplabV3Plus(
56 |                 num_classes=self.config['num_classes'],
57 |                 backbone=self.config['backbone']
58 |             )
59 | 
60 |             self._model.compile(
61 |                 optimizer=tf.keras.optimizers.Adam(
62 |                     learning_rate=self.config['learning_rate']
63 |                 ),
64 |                 loss=tf.keras.losses.SparseCategoricalCrossentropy(),
65 |                 metrics=['accuracy']
66 |             )
67 | 
68 |         return self._model
69 | 
70 |     @staticmethod
71 |     def _assert_dataset_config(dataset_config):
72 |         assert 'images' in dataset_config and \
73 |             isinstance(dataset_config['images'], list)
74 |         assert 'labels' in dataset_config and \
75 |             isinstance(dataset_config['labels'], list)
76 | 
77 |         assert 'height' in dataset_config and \
78 |             isinstance(dataset_config['height'], int)
79 |         assert 'width' in dataset_config and \
80 |             isinstance(dataset_config['width'], int)
81 | 
82 |         assert 'batch_size' in dataset_config and \
83 |             isinstance(dataset_config['batch_size'], int)
84 | 
85 |     def _assert_config(self):
86 |         assert 'project_name' in self.config and \
87 |             isinstance(self.config['project_name'], str)
88 |         assert 'experiment_name' in self.config and \
89 |             isinstance(self.config['experiment_name'], str)
90 | 
91 |         assert 'train_dataset_config' in self.config
92 |         Trainer._assert_dataset_config(self.config['train_dataset_config'])
93 |         assert 'val_dataset_config' in self.config
94 |         Trainer._assert_dataset_config(self.config['val_dataset_config'])
95 | 
96 |         assert 'strategy' in self.config and \
97 |             isinstance(self.config['strategy'], tf.distribute.Strategy)
98 | 
99 |         assert 'num_classes' in self.config and \
100 |             isinstance(self.config['num_classes'], int)
101 |         assert 'backbone' in self.config and \
102 |             isinstance(self.config['backbone'], str)
103 | 
104 |         assert 'learning_rate' in self.config and \
105 |             isinstance(self.config['learning_rate'], float)
106 | 
107 |         assert 'checkpoint_dir' in self.config and \
108 |             isinstance(self.config['checkpoint_dir'], str)
109 |         assert 'checkpoint_file_prefix' in self.config and \
110 |             isinstance(self.config['checkpoint_file_prefix'], str)
111 | 
112 |         assert 'epochs' in self.config and \
113 |             isinstance(self.config['epochs'], int)
114 | 
115 |     def connect_wandb(self):
116 |         """Connects Trainer to wandb.
117 | 
118 |         Runs wandb.init() with the given wandb_api_key, project_name and
119 |         experiment_name.
120 |         """
121 |         if 'wandb_api_key' not in self.config:
122 |             return
123 | 
124 |         os.environ['WANDB_API_KEY'] = self.config['wandb_api_key']
125 |         wandb.init(
126 |             project=self.config['project_name'],
127 |             name=self.config['experiment_name']
128 |         )
129 |         self._wandb_initialized = True
130 | 
131 |     def _get_checkpoint_filename_format(self):
132 |         if self.config['checkpoint_dir'] == 'wandb://':
133 |             if 'wandb_api_key' not in self.config:
134 |                 raise ValueError("Invalid configuration, wandb_api_key not "
135 |                                  "provided!")
136 |             if not self._wandb_initialized:
137 |                 raise ValueError("Wandb not initialized, "
138 |                                  "checkpoint_filename_format is unusable.")
139 | 
140 |             return os.path.join(wandb.run.dir,
141 |                                 self.config['checkpoint_file_prefix'] +
142 |                                 "{epoch}")
143 | 
144 |         return os.path.join(self.config['checkpoint_dir'],
145 |                             self.config['checkpoint_file_prefix'] +
146 |                             "{epoch}")
147 | 
148 |     def _get_logger_callback(self):
149 |         if 'wandb_api_key' not in self.config:
150 |             return tf.keras.callbacks.TensorBoard()
151 | 
152 |         try:
153 |             return WandbCallback(save_weights_only=True, save_model=False)
154 |         except wandb.Error as error:
155 |             if 'wandb_api_key' in self.config:
156 |                 raise error  # rethrow
157 | 
158 |             print("[-] Defaulting to TensorBoard logging...")
159 |             return tf.keras.callbacks.TensorBoard()
160 | 
161 |     def train(self):
162 |         """Trainer entry point.
163 | 
164 |         Attempts to connect to wandb before starting training. Runs .fit() on
165 |         loaded model.
166 |         """
167 |         if not self._wandb_initialized:
168 |             self.connect_wandb()
169 | 
170 |         callbacks = [
171 |             tf.keras.callbacks.ModelCheckpoint(
172 |                 filepath=self._get_checkpoint_filename_format(),
173 |                 monitor='val_loss',
174 |                 save_best_only=True,
175 |                 mode='min',
176 |                 save_weights_only=True
177 |             ),
178 | 
179 |             self._get_logger_callback()
180 |         ]
181 | 
182 |         history = self.model.fit(
183 |             self.train_dataset, validation_data=self.val_dataset,
184 | 
185 |             steps_per_epoch=self.train_data_length //
186 |             self.config['train_dataset_config']['batch_size'],
187 | 
188 |             validation_steps=self.val_data_length //
189 |             self.config['val_dataset_config']['batch_size'],
190 | 
191 |             epochs=self.config['epochs'], callbacks=callbacks
192 |         )
193 | 
194 |         return history
195 | 
--------------------------------------------------------------------------------
/deeplabv3plus/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import tensorflow as tf
4 | from matplotlib import pyplot as plt
5 | 
6 | 
7 | def plot_samples_matplotlib(display_list, figsize=(5, 3)):
8 |     _, axes = plt.subplots(nrows=1, ncols=len(display_list), figsize=figsize)
9 |     for i in range(len(display_list)):
10 |         if display_list[i].shape[-1] == 3:
11 |             axes[i].imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
12 |         else:
13 |             axes[i].imshow(display_list[i])
14 |     plt.show()
15 | 
16 | 
17 | def decode_segmask(mask, colormap, n_classes):
18 |     r = np.zeros_like(mask).astype(np.uint8)
19 |     g = np.zeros_like(mask).astype(np.uint8)
20 |     b = np.zeros_like(mask).astype(np.uint8)
21 |     for l in range(0, n_classes):
22 |         idx = mask == l
23 |         r[idx] = colormap[l, 0]
24 |         g[idx] = colormap[l, 1]
25 |         b[idx] = colormap[l, 2]
26 |     rgb = np.stack([r, g, b], axis=2)
27 |     return rgb
28 | 
29 | 
30 | def get_overlay(image, colored_mask):
31 |     image = tf.keras.preprocessing.image.array_to_img(image)
32 |     image = np.array(image).astype(np.uint8)
33 |     overlay = cv2.addWeighted(image, 0.35, colored_mask, 0.65, 0)
34 |     return overlay
35 | 
--------------------------------------------------------------------------------
/models.md:
--------------------------------------------------------------------------------
1 | # Model Architectures
2 | 
3 | 
4 | ![](./assets/deeplabv3plus.svg)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.9.0
2 | appnope==0.1.0
3 | astunparse==1.6.3
4 | attrs==19.3.0
5 | backcall==0.1.0
6 | bleach==3.1.5
7 | cachetools==4.1.0
8 | certifi==2020.4.5.1
9 | chardet==3.0.4
10 | click==7.1.2
11 | configparser==5.0.0
12 | cycler==0.10.0
13 | decorator==4.4.2
14 | defusedxml==0.6.0
15 | docker-pycreds==0.4.0
16 | entrypoints==0.3
17 | gast==0.3.3
18 | gitdb==4.0.5
19 | GitPython==3.1.2
20 | google-auth==1.14.3
21 | google-auth-oauthlib==0.4.1
22 | google-pasta==0.2.0
23 | gql==0.2.0
24 | graphql-core==1.1
25 | grpcio==1.29.0
26 | h5py==2.10.0
27 | idna==2.9
28 | importlib-metadata==1.6.0
29 | ipykernel==5.2.1
30 | ipython==7.14.0
31 | ipython-genutils==0.2.0
32 | ipywidgets==7.5.1
33 | jedi==0.17.0
34 | Jinja2==2.11.2
35 | jsonschema==3.2.0
36 | jupyter==1.0.0
37 | jupyter-client==6.1.3
38 | jupyter-console==6.1.0
39 | jupyter-core==4.6.3
40 | kaggle==1.5.6
41 | Keras-Preprocessing==1.1.2
42 | kiwisolver==1.2.0
43 | Markdown==3.2.2
44 | MarkupSafe==1.1.1
45 | matplotlib==3.2.1
46 | mistune==0.8.4
47 | nbconvert==5.6.1
48 | nbformat==5.0.6
49 | netron==4.1.8
50 | notebook==6.0.3
51 | numpy==1.18.4
52 | nvidia-ml-py3==7.352.0
53 | oauthlib==3.1.0
54 | opt-einsum==3.2.1
55 | packaging==20.3
56 | pandocfilters==1.4.2
57 | parso==0.7.0
58 | pathtools==0.1.2
59 | pexpect==4.8.0
60 | pickleshare==0.7.5
61 | Pillow==7.2.0
62 | prometheus-client==0.7.1
63 | promise==2.3
64 | prompt-toolkit==3.0.5
65 | protobuf==3.12.0
66 | psutil==5.7.0
67 | ptyprocess==0.6.0
68 | pyasn1==0.4.8
69 | pyasn1-modules==0.2.8
70 | Pygments==2.6.1
71 | pyparsing==2.4.7
72 | pyrsistent==0.16.0
73 | python-dateutil==2.8.1
74 | python-slugify==4.0.0
75 | PyYAML==5.3.1
76 | pyzmq==19.0.1
77 | qtconsole==4.7.4
78 | QtPy==1.9.0
79 | requests==2.23.0
80 | requests-oauthlib==1.3.0
81 | rsa==4.0
82 | scipy==1.4.1
83 | Send2Trash==1.5.0
84 | sentry-sdk==0.14.4
85 | shortuuid==1.0.1
86 | six==1.14.0
87 | smmap==3.0.4
88 | subprocess32==3.5.4
89 | tensorboard==2.2.1
90 | tensorboard-plugin-wit==1.6.0.post3
91 | tensorflow==2.2.1
92 | tensorflow-estimator==2.2.0
93 | termcolor==1.1.0
94 | terminado==0.8.3
95 | testpath==0.4.4
96 | text-unidecode==1.3
97 | tornado==6.0.4
98 | tqdm==4.46.0
99 | traitlets==4.3.3
100 | urllib3==1.24.3
101 | wandb==0.8.36
102 | watchdog==0.10.2
103 | wcwidth==0.1.9
104 | webencodings==0.5.1
105 | Werkzeug==1.0.1
106 | widgetsnbextension==3.5.1
107 | wrapt==1.12.1
108 | zipp==3.1.0
109 | 
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | """Module for running DeepLabV3+ training with a registered config."""
4 | 
5 | # !pylint:disable=wrong-import-position
6 | 
7 | import argparse
8 | from argparse import RawTextHelpFormatter
9 | 
10 | print("[-] Importing tensorflow...")
11 | import tensorflow as tf  # noqa: E402
12 | print(f"[+] Done! Tensorflow version: {tf.version.VERSION}")
13 | 
14 | print("[-] Importing Deeplabv3plus Trainer class...")
15 | from deeplabv3plus.train import Trainer  # noqa: E402
16 | 
17 | print("[-] Importing config files...")
18 | from config import CONFIG_MAP  # noqa: E402
19 | 
20 | 
21 | if __name__ == "__main__":
22 |     REGISTERED_CONFIG_KEYS = "".join(map(lambda s: f"  {s}\n", CONFIG_MAP.keys()))
23 | 
24 |     PARSER = argparse.ArgumentParser(
25 |         description=f"""
26 | Runs DeeplabV3+ trainer with the given config setting.
27 | 
28 | Registered config_key values:
29 | {REGISTERED_CONFIG_KEYS}""",
30 |         formatter_class=RawTextHelpFormatter
31 |     )
32 |     PARSER.add_argument('config_key', help="Key to use while looking up "
33 |                         "configuration from the CONFIG_MAP dictionary.")
34 |     PARSER.add_argument("--wandb_api_key",
35 |                         help="""Wandb API Key for logging run on Wandb.
36 | If provided, checkpoint_dir is set to wandb://
37 | (Model checkpoints are saved to wandb.)""",
38 |                         default=None)
39 |     ARGS = PARSER.parse_args()
40 | 
41 |     CONFIG = CONFIG_MAP[ARGS.config_key]
42 |     if ARGS.wandb_api_key is not None:
43 |         CONFIG['wandb_api_key'] = ARGS.wandb_api_key
44 |         CONFIG['checkpoint_dir'] = "wandb://"
45 | 
46 |     TRAINER = Trainer(CONFIG)
47 |     HISTORY = TRAINER.train()
48 | 
--------------------------------------------------------------------------------
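Putting the pieces together, a minimal end-to-end sketch of the intended
workflow (it assumes the CamVid archive has already been fetched with
`dataset/camvid.sh` and that a GPU is visible, since the sample configs pin
training to `/gpu:0`):

```python
# Minimal training sketch using only the modules shipped in this repository.
import numpy as np

from config import CONFIG_MAP
from deeplabv3plus.train import Trainer

trainer = Trainer(CONFIG_MAP['camvid_resnet50'])
history = trainer.train()  # checkpoints are written under ./checkpoints/

# The compiled model remains available on the trainer for quick spot checks:
probs = trainer.model.predict(np.zeros((1, 512, 512, 3), dtype='float32'))
print(probs.shape)  # expected: (1, 512, 512, 20), one channel per class
```

Passing `--wandb_api_key` through `trainer.py` instead routes checkpoints to
`wandb://`, as described in the README's help text.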