├── Hyper_opt.ipynb ├── LICENSE ├── README.md ├── data_functions.py ├── losses.py ├── model_functions.py ├── models.py ├── resnet.py ├── resnext.py └── training_functions.py /Hyper_opt.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import skopt\n", 10 | "from skopt import gp_minimize, forest_minimize\n", 11 | "from skopt.space import Real, Categorical, Integer\n", 12 | "from skopt.plots import plot_convergence\n", 13 | "from skopt.plots import plot_objective, plot_evaluations\n", 14 | "from skopt.utils import use_named_args\n", 15 | "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", 16 | "\n", 17 | "from data_functions import build_loader\n", 18 | "from model_functions import build_segmentation_model\n", 19 | "from training_functions import train\n", 20 | "\n", 21 | "import torch\n", 22 | "import torch.nn as nn\n", 23 | "import torch.backends.cudnn as cudnn\n", 24 | "from torch.autograd import Variable\n", 25 | "from training_functions import *\n", 26 | "from tensorboardX import SummaryWriter\n", 27 | "import numpy as np\n", 28 | "from imgaug import augmenters as iaa\n", 29 | "import imgaug as ia\n" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 1, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "Collecting scikit-optimize\n", 42 | " Using cached https://files.pythonhosted.org/packages/f4/44/60f82c97d1caa98752c7da2c1681cab5c7a390a0fdd3a55fac672b321cac/scikit_optimize-0.5.2-py2.py3-none-any.whl\n", 43 | "Requirement already satisfied: scipy>=0.14.0 in /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from scikit-optimize) (1.1.0)\n", 44 | "Requirement already satisfied: numpy in /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from scikit-optimize) (1.14.5)\n", 45 | "Requirement already satisfied: scikit-learn>=0.19.1 in /home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from scikit-optimize) (0.19.1)\n", 46 | "Installing collected packages: scikit-optimize\n", 47 | "Successfully installed scikit-optimize-0.5.2\n", 48 | "\u001b[33mYou are using pip version 10.0.1, however version 18.1 is available.\n", 49 | "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n" 50 | ] 51 | } 52 | ], 53 | "source": [ 54 | "!pip install scikit-optimize\n", 55 | "!pip install imgaug\n", 56 | "!pip install opencv-python\n", 57 | "!pip install tensorboardX" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "def dict_to_string(hyperdict,logdir=\"/\"):\n", 67 | " s=logdir\n", 68 | " for key,values in hyperdict.items():\n", 69 | " s=s+\"_\"+key+\"_{}_\".format(np.round(values, decimals=12))\n", 70 | " return s\n" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 3, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "wait_epoch=50\n", 80 | "max_epoch=500\n", 81 | "batch_size=12\n", 82 | "\n", 83 | "writer_name_list_eval=['valid/non_ema_loss_eval','valid/ema_loss_eval','valid/total_loss_eval','valid/lovasz_loss_eval','valid/focal_loss_eval','valid/lovasz_loss_ema_eval',\n", 84 | " 'valid/focal_loss_ema_eval','valid/unsupervised_loss_eval','valid/iou_score_eval']\n", 85 | 
"writer_name_list_train=['train/non_ema_loss','train/ema_loss','train/total_loss','train/lovasz_loss','train/focal_loss','train/lovasz_loss_ema',\n", 86 | " 'train/focal_loss_ema','train/unsupervised_loss','train/iou_score']\n", 87 | "\n", 88 | "#writer = SummaryWriter(log_dir=\"/home/ubuntu/Kaggle_Pytorch_TGS/plogs2\")\n" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "from skopt.space import Real, Integer\n", 98 | "from skopt.utils import use_named_args\n", 99 | "\n", 100 | "\n", 101 | "# The list of hyper-parameters we want to optimize. For each one we define the bounds,\n", 102 | "# the corresponding scikit-learn parameter name, as well as how to sample values\n", 103 | "# from that dimension (`'log-uniform'` for the learning rate)\n", 104 | "space = [Real(0.1, 0.5, \"uniform\", name='focal_scaling'),\n", 105 | " Integer(int(1), int(batch_size/2), name='second_batch_size'),\n", 106 | " Real(0.00000001, 0.0001, \"log-uniform\", name='decay'),\n", 107 | " Real(0.01, 0.3, \"uniform\", name='unsupervised_scaling'),\n", 108 | " Real(0.1, 0.5, \"uniform\", name='ema_scaling'),\n", 109 | " Real(0.1, 0.5, \"uniform\", name='non_ema_scaling'),\n", 110 | " Real(0.01, 0.5, \"log-uniform\", name='droppout'),\n", 111 | " Real(0.1, 0.5, \"uniform\", name='lovasz_scaling')]\n", 112 | "\n", 113 | "# this decorator allows your objective function to receive a the parameters as\n", 114 | "# keyword arguments. This is particularly convenient when you want to set scikit-learn\n", 115 | "# estimator parameters\n", 116 | "@use_named_args(space)\n", 117 | "def objective(**params):\n", 118 | " logdir = dict_to_string(params)\n", 119 | " logdir=\"/home/ubuntu/Kaggle_Pytorch_TGS/plogs2\"+logdir\n", 120 | " print(logdir)\n", 121 | " writer = SummaryWriter(log_dir=logdir)\n", 122 | " \n", 123 | " augs = iaa.Sequential([\n", 124 | " #iaa.Scale((512, 512)),\n", 125 | " iaa.Fliplr(0.5),\n", 126 | " iaa.Affine(rotate=(-25, 25),mode=\"reflect\",\n", 127 | " translate_percent={\"x\": (-0.01, 0.01), \"y\": (-0.01, 0.01)}),\n", 128 | " #iaa.Add((-40, 40), per_channel=0.5, name=\"color-jitter\") \n", 129 | " ])\n", 130 | " \n", 131 | " \n", 132 | " train_loader,valid_loader=build_loader(input_img_folder='data/train/images/'\n", 133 | " ,label_folder='data/train/masks/'\n", 134 | " ,test_img_folder='data/test/images/'\n", 135 | " ,second_batch_size=params[\"second_batch_size\"]\n", 136 | " ,batch_size=batch_size\n", 137 | " ,transform=augs\n", 138 | " ,show_image=False\n", 139 | " ,num_workers=4)\n", 140 | " \n", 141 | " \n", 142 | " #in the final training funciotn we will put them too. 
\n", 143 | " segmentation_module,segmentation_ema=build_segmentation_model(\n", 144 | " in_arch=\"resnet50_dilated8\",out_arch=\"upernet\" ,droppout=params[\"droppout\"])\n", 145 | " optimizer = torch.optim.SGD(\n", 146 | " group_weight(segmentation_module),\n", 147 | " lr=0.01,\n", 148 | " momentum=0.9,\n", 149 | " weight_decay=params[\"decay\"])\n", 150 | " scheduler = ReduceLROnPlateau(optimizer, 'max')\n", 151 | " \n", 152 | " best_metric=0\n", 153 | " wait=0\n", 154 | " n_iter=0\n", 155 | "\n", 156 | " \n", 157 | " for j in range(max_epoch):\n", 158 | " #Trains for one epoch\n", 159 | " train(train_loader,segmentation_module,segmentation_ema,optimizer\n", 160 | " ,writer=writer\n", 161 | " ,lovasz_scaling=params[\"lovasz_scaling\"]\n", 162 | " ,focal_scaling=params[\"focal_scaling\"]\n", 163 | " ,unsupervised_scaling=params[\"unsupervised_scaling\"]\n", 164 | " ,ema_scaling=params[\"ema_scaling\"]\n", 165 | " ,non_ema_scaling=params[\"non_ema_scaling\"]\n", 166 | " ,train=True\n", 167 | " #,test=True\n", 168 | " ,writer_name_list=writer_name_list_train\n", 169 | " ,second_batch_size=params[\"second_batch_size\"])\n", 170 | " \n", 171 | " # Does the Evaluation.\n", 172 | " metric=train(valid_loader,segmentation_module,segmentation_ema,optimizer\n", 173 | " ,writer=writer\n", 174 | " ,lovasz_scaling=params[\"lovasz_scaling\"]\n", 175 | " ,focal_scaling=params[\"focal_scaling\"]\n", 176 | " ,unsupervised_scaling=params[\"unsupervised_scaling\"]\n", 177 | " ,ema_scaling=params[\"ema_scaling\"]\n", 178 | " ,non_ema_scaling=params[\"non_ema_scaling\"]\n", 179 | " ,train=False\n", 180 | " #,test=True\n", 181 | " ,writer_name_list=writer_name_list_eval\n", 182 | " ,second_batch_size=params[\"second_batch_size\"])\n", 183 | " scheduler.step(metric)\n", 184 | " \n", 185 | " #Save best metric and do simple early stopping\n", 186 | " if metric > best_metric:\n", 187 | " best_metric=metric\n", 188 | " wait=0\n", 189 | " else:\n", 190 | " wait =wait+1\n", 191 | " if wait > wait_epoch:\n", 192 | " break\n", 193 | " print(best_metric)\n", 194 | " return -best_metric" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "/home/ubuntu/Kaggle_Pytorch_TGS/plogs2/_focal_scaling_0.385182128209__second_batch_size_3__decay_5.801488e-06__unsupervised_scaling_0.218553589945__ema_scaling_0.296447573373__non_ema_scaling_0.412011104765__droppout_0.049905473819__lovasz_scaling_0.331877718809_\n" 207 | ] 208 | }, 209 | { 210 | "name": "stderr", 211 | "output_type": "stream", 212 | "text": [ 213 | "/home/ubuntu/Kaggle_Pytorch_TGS/losses.py:163: UserWarning: Implicit dimension choice for log_softmax has been deprecated. 
Change the call to include dim=X as an argument.\n", 214 | " logpt = F.log_softmax(input)\n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "res_gp = gp_minimize(objective, space, n_calls=50, random_state=123)\n" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": {}, 226 | "outputs": [], 227 | "source": [] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Environment (conda_pytorch_p36)", 233 | "language": "python", 234 | "name": "conda_pytorch_p36" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.6.5" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 |       appropriateness of using or redistributing the Work and assume any 151 |       risks associated with Your exercise of permissions under this License. 152 | 153 |    8. Limitation of Liability. In no event and under no legal theory, 154 |       whether in tort (including negligence), contract, or otherwise, 155 |       unless required by applicable law (such as deliberate and grossly 156 |       negligent acts) or agreed to in writing, shall any Contributor be 157 |       liable to You for damages, including any direct, indirect, special, 158 |       incidental, or consequential damages of any character arising as a 159 |       result of this License or out of the use or inability to use the 160 |       Work (including but not limited to damages for loss of goodwill, 161 |       work stoppage, computer failure or malfunction, or any and all 162 |       other commercial damages or losses), even if such Contributor 163 |       has been advised of the possibility of such damages. 164 | 165 |    9. Accepting Warranty or Additional Liability. While redistributing 166 |       the Work or Derivative Works thereof, You may choose to offer, 167 |       and charge a fee for, acceptance of support, warranty, indemnity, 168 |       or other liability obligations and/or rights consistent with this 169 |       License. However, in accepting such obligations, You may act only 170 |       on Your own behalf and on Your sole responsibility, not on behalf 171 |       of any other Contributor, and only if You agree to indemnify, 172 |       defend, and hold each Contributor harmless for any liability 173 |       incurred by, or claims asserted against, such Contributor by reason 174 |       of your accepting any such warranty or additional liability. 175 | 176 |    END OF TERMS AND CONDITIONS 177 | 178 |    APPENDIX: How to apply the Apache License to your work. 179 | 180 |       To apply the Apache License to your work, attach the following 181 |       boilerplate notice, with the fields enclosed by brackets "[]" 182 |       replaced with your own identifying information. (Don't include 183 |       the brackets!)  The text should be enclosed in the appropriate 184 |       comment syntax for the file format. We also recommend that a 185 |       file or class name and description of purpose be included on the 186 |       same "printed page" as the copyright notice for easier 187 |       identification within third-party archives. 188 | 189 |    Copyright [yyyy] [name of copyright owner] 190 | 191 |    Licensed under the Apache License, Version 2.0 (the "License"); 192 |    you may not use this file except in compliance with the License. 193 |    You may obtain a copy of the License at 194 | 195 |        http://www.apache.org/licenses/LICENSE-2.0 196 | 197 |    Unless required by applicable law or agreed to in writing, software 198 |    distributed under the License is distributed on an "AS IS" BASIS, 199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 |    See the License for the specific language governing permissions and 201 |    limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semi-Supervised-Segmentation-Pytorch 2 | A work-in-progress repository for semi-supervised image segmentation using Mean Teacher. It includes the following features: 3 | 4 | - Easy to train on new train and test sets using the provided notebook. 5 | - Different pre-trained networks. 6 | - Many different losses. 7 | - TensorboardX integration. 8 | - Hyperparameter tuning.
9 | - Data Loader with Image Augmentation. 10 | 11 | TBD: 12 | A learned GAN loss. Testing on a new dataset. 13 | 14 | Sources: 15 | 16 | - The pre-trained networks and encoder/decoder code were taken from: 17 | - https://github.com/CSAILVision/semantic-segmentation-pytorch/blob/master/LICENSE 18 | - Much of the Mean Teacher code was taken from: 19 | - https://github.com/CuriousAI/mean-teacher 20 | -------------------------------------------------------------------------------- /data_functions.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from skimage.io import imshow, imread 4 | import torchvision 5 | from torchvision import transforms 6 | from torchvision.utils import make_grid 7 | import torch.nn as nn 8 | 9 | import matplotlib.pyplot as plt 10 | import matplotlib as mpl 11 | mpl.rcParams['axes.grid'] = False 12 | mpl.rcParams['image.interpolation'] = 'nearest' 13 | mpl.rcParams['figure.figsize'] = 15, 10 14 | 15 | def show(img): 16 |     npimg = img.numpy() 17 |     plt.figure() 18 |     plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest') 19 | 20 | from imgaug import augmenters as iaa 21 | import imgaug as ia 22 | 23 | import os 24 | import random 25 | 26 | import torch 27 | import torch.utils.data as data 28 | 29 | from PIL import Image 30 | from torch.utils.data.sampler import SubsetRandomSampler 31 | import torch.backends.cudnn as cudnn 32 | 33 | from random import sample,seed 34 | 35 | 36 | 37 | ## DATA LOADER ## 38 | 39 | #Basically how the semi-supervised loading works: we let the sampler do all the hard ID work. We also need 40 | #to treat the unsupervised samples differently when loading, so we have to check the IDs. 41 | 42 | 43 | 44 | 45 | 46 | 47 | class SegmentationDatasetImgaug(data.Dataset): 48 |     IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] 49 | 50 |     @staticmethod 51 |     def _isimage(image, ends): 52 |         return any(image.endswith(end) for end in ends) 53 | 54 |     @staticmethod 55 |     def _load_input_image(path): 56 |         return imread(path, as_gray=True) 57 | 58 |     @staticmethod 59 |     def _load_target_image(path): 60 |         return imread(path, as_gray=True)[..., np.newaxis] 61 | 62 |     def __init__(self, input_root, target_root,test_root=None, transform=None,normalize=True,image_size=101, input_only=None): 63 |         self.input_root = input_root 64 |         self.target_root = target_root 65 |         self.transform = transform 66 |         self.input_only = input_only 67 |         self.test_root = test_root 68 |         self.image_size = image_size 69 |         self.norm=normalize 70 | 71 | 72 |     #For the IDs we use the first len(target_ids) entries as the labeled examples and the test ids appended after them as the unlabeled ones.
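    #Concretely: indices in [0, len(target_ids)) are labeled examples, for which
    #an input image and its mask are loaded together, while any index from
    #len(target_ids) onwards refers to one of the appended test images, which
    #__getitem__ below returns with a dummy all -1 mask as an unlabeled sample.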
73 | self.input_ids = sorted(img for img in os.listdir(self.input_root) 74 | if self._isimage(img, self.IMG_EXTENSIONS)) 75 | 76 | self.target_ids = sorted(img for img in os.listdir(self.target_root) 77 | if self._isimage(img, self.IMG_EXTENSIONS)) 78 | if test_root: 79 | self.test_id = sorted(img for img in os.listdir(self.test_root) 80 | if self._isimage(img, self.IMG_EXTENSIONS)) 81 | self.input_ids=self.input_ids+self.test_id 82 | 83 | def _activator_masks(self, images, augmenter, parents, default): 84 | if self.input_only and augmenter.name in self.input_only: 85 | return False 86 | else: 87 | return default 88 | 89 | def __getitem__(self, idx): 90 | 91 | transform= self.transform 92 | 93 | if idx < len(self.target_ids): 94 | target_img = self._load_target_image( 95 | os.path.join(self.target_root, self.target_ids[idx])) 96 | input_img = self._load_input_image( 97 | os.path.join(self.input_root, self.input_ids[idx])) 98 | else : 99 | input_img = self._load_input_image( 100 | os.path.join(self.test_root, self.input_ids[idx])) 101 | target_img= torch.zeros([1,101, 101], dtype=torch.float32)-1 102 | transform = None 103 | if idx < len(self.target_ids): 104 | target_img=target_img.astype(np.uint8) 105 | 106 | input_img=input_img.astype(np.uint8) 107 | 108 | #This is a combined transformation for both Image and Label 109 | if transform: 110 | det_tf = self.transform.to_deterministic() 111 | input_img = det_tf.augment_image(input_img) 112 | target_img = det_tf.augment_image( 113 | target_img, 114 | hooks=ia.HooksImages(activator=self._activator_masks)) 115 | 116 | 117 | #npad = ( (14, 13), (14, 13),(0, 0)) 118 | #input_img = np.pad(input_img, pad_width=npad, mode='constant', constant_values=0) 119 | 120 | 121 | to_tensor = transforms.ToTensor() 122 | 123 | if idx < len(self.target_ids): 124 | target_img = to_tensor(target_img) 125 | 126 | 127 | 128 | input_img = to_tensor(input_img) 129 | if self.norm == True: 130 | trans=transforms.Compose([ 131 | transforms.Normalize(mean=[102.9801/255, 115.9465/255, 122.7717/255], std=[1., 1., 1.]) 132 | ]) 133 | input_img=trans(input_img) 134 | 135 | output = dict() 136 | output['img_data'] = input_img 137 | output['seg_label'] = target_img 138 | return output 139 | 140 | def __len__(self): 141 | return len(self.input_ids) 142 | 143 | from torch.utils.data.sampler import Sampler 144 | import itertools 145 | 146 | class TwoStreamBatchSampler(Sampler): 147 | """Iterate two sets of indices 148 | An 'epoch' is one iteration through the primary indices. 149 | During the epoch, the secondary indices are iterated through 150 | as many times as needed. 
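    Illustrative example: with batch_size=12 and secondary_batch_size=3, every
    batch holds 9 primary (labeled) and 3 secondary (unlabeled) indices, and
    one epoch yields len(primary_indices) // 9 batches.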
151 | """ 152 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 153 | self.primary_indices = primary_indices 154 | self.secondary_indices = secondary_indices 155 | self.secondary_batch_size = secondary_batch_size 156 | self.primary_batch_size = batch_size - secondary_batch_size 157 | 158 | assert len(self.primary_indices) >= self.primary_batch_size > 0 159 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 160 | 161 | def __iter__(self): 162 | primary_iter = iterate_once(self.primary_indices) 163 | secondary_iter = iterate_eternally(self.secondary_indices) 164 | return ( 165 | primary_batch + secondary_batch 166 | for (primary_batch, secondary_batch) 167 | in zip(grouper(primary_iter, self.primary_batch_size), 168 | grouper(secondary_iter, self.secondary_batch_size)) 169 | ) 170 | 171 | def __len__(self): 172 | return len(self.primary_indices) // self.primary_batch_size 173 | 174 | 175 | def iterate_once(iterable): 176 | return np.random.permutation(iterable) 177 | 178 | 179 | def iterate_eternally(indices): 180 | def infinite_shuffles(): 181 | while True: 182 | yield np.random.permutation(indices) 183 | return itertools.chain.from_iterable(infinite_shuffles()) 184 | 185 | def grouper(iterable, n): 186 | "Collect data into fixed-length chunks or blocks" 187 | # grouper('ABCDEFG', 3) --> ABC DEF" 188 | args = [iter(iterable)] * n 189 | return zip(*args) 190 | 191 | 192 | def build_loader(input_img_folder='data/train/images/' 193 | ,label_folder='data/train/masks/' 194 | ,test_img_folder='data/test/images/' 195 | ,second_batch_size=2 196 | ,show_image=True 197 | ,batch_size=8 198 | ,num_workers=4 199 | ,transform=None): 200 | ''' 201 | We build the datasets with augmentation and ultimately return the loaders. 
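    A minimal usage sketch, assuming the default TGS folder layout under data/:

        train_loader, valid_loader = build_loader(batch_size=12, second_batch_size=3)

    Roughly 5% of the labeled images are held out for validation; every training
    batch then mixes labeled train images with unlabeled test images through
    TwoStreamBatchSampler.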
202 |     ''' 203 |     if transform is None: 204 |         augs = iaa.Sequential([ 205 |                 #iaa.Scale((512, 512)), 206 |                 iaa.Fliplr(0.5), 207 |                 iaa.Affine(rotate=(-25, 25),mode="reflect", 208 |                           translate_percent={"x": (-0.01, 0.01), "y": (-0.01, 0.01)}), 209 |                 #iaa.Add((-40, 40), per_channel=0.5, name="color-jitter") 210 |             ]) 211 | 212 |     else: 213 |         augs=transform 214 | 215 |     #Get correct indices 216 |     num_train = len(sorted(img for img in os.listdir(input_img_folder))) 217 |     indices = list(range(num_train)) 218 |     seed(128381) 219 |     indices=sample(indices,len(indices)) 220 |     split = int(np.floor(0.05 * num_train)) 221 | 222 |     train_idx, valid_idx = indices[split:], indices[:split] 223 |     num_test = len(sorted(img for img in os.listdir(test_img_folder))) 224 |     test_idx=list(range(num_train,num_train+num_test)) 225 | 226 |     train_sampler = TwoStreamBatchSampler(primary_indices=train_idx,secondary_indices=test_idx,batch_size=batch_size,secondary_batch_size=second_batch_size) 227 | 228 |     #Set up datasets 229 |     train_dataset = SegmentationDatasetImgaug( 230 |         input_img_folder, label_folder, 231 |         transform=augs, 232 |         test_root=test_img_folder, 233 |         #input_only=['color-jitter'] 234 |     ) 235 | 236 | 237 | 238 |     valid_dataset = SegmentationDatasetImgaug( 239 |         input_img_folder, label_folder, 240 |         #input_only=['color-jitter'] 241 |     ) 242 |     if show_image: 243 |         train_dataset_show = SegmentationDatasetImgaug( 244 |             input_img_folder, label_folder, 245 |             transform=augs, 246 |             test_root=test_img_folder, 247 |             normalize=False, 248 |             #input_only=['color-jitter'] 249 |         ) 250 | 251 | 252 |         imgs = [train_dataset_show[i] for i in range(6)] 253 | 254 |         show(torchvision.utils.make_grid(torch.stack([img["img_data"] for img in imgs]))) 255 |         show(torchvision.utils.make_grid(torch.stack([img["seg_label"] for img in imgs]))) 256 | 257 |     valid_sampler = SubsetRandomSampler(valid_idx) 258 | 259 |     train_loader = torch.utils.data.DataLoader( 260 |         train_dataset, batch_sampler=train_sampler, 261 |         num_workers=num_workers, pin_memory=True 262 |     ) 263 | 264 |     valid_loader = torch.utils.data.DataLoader( 265 |         valid_dataset, batch_size=batch_size, sampler=valid_sampler, 266 |         num_workers=num_workers, pin_memory=True 267 |     ) 268 |     plt.show() 269 |     return train_loader,valid_loader 270 | 271 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | #Here we want to have all the different Losses 2 | from __future__ import print_function, division 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | import numpy as np 8 | try: 9 |     from itertools import ifilterfalse 10 | except ImportError: # py3k 11 |     from itertools import filterfalse as ifilterfalse 12 | 13 | import torch.nn as nn 14 | 15 | 16 | 17 | def lovasz_grad(gt_sorted): 18 |     """ 19 |     Computes gradient of the Lovasz extension w.r.t sorted errors 20 |     See Alg. 1 in paper 21 |     """ 22 |     p = len(gt_sorted) 23 |     gts = gt_sorted.sum() 24 |     intersection = gts - gt_sorted.float().cumsum(0) 25 |     union = gts + (1 - gt_sorted).float().cumsum(0) 26 |     jaccard = 1.
- intersection / union 27 |     if p > 1: # cover 1-pixel case 28 |         jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 29 |     return jaccard 30 | 31 | 32 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 33 |     """ 34 |     IoU for foreground class 35 |     binary: 1 foreground, 0 background 36 |     """ 37 |     if not per_image: 38 |         preds, labels = (preds,), (labels,) 39 |     ious = [] 40 |     for pred, label in zip(preds, labels): 41 |         intersection = ((label == 1) & (pred == 1)).sum() 42 |         union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 43 |         if not union: 44 |             iou = EMPTY 45 |         else: 46 |             iou = float(intersection) / union 47 |         ious.append(iou) 48 |     iou = mean(ious)    # mean across images if per_image 49 |     return 100 * iou 50 | 51 | 52 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 53 |     """ 54 |     Array of IoU for each (non ignored) class 55 |     """ 56 |     if not per_image: 57 |         preds, labels = (preds,), (labels,) 58 |     ious = [] 59 |     for pred, label in zip(preds, labels): 60 |         iou = [] 61 |         for i in range(C): 62 |             if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 63 |                 intersection = ((label == i) & (pred == i)).sum() 64 |                 union = ((label == i) | ((pred == i) & (label != ignore))).sum() 65 |                 if not union: 66 |                     iou.append(EMPTY) 67 |                 else: 68 |                     iou.append(float(intersection) / union) 69 |         ious.append(iou) 70 |     ious = list(map(mean, zip(*ious))) # mean across images if per_image; list() so np.array gets concrete values on Python 3 71 |     return 100 * np.array(ious) 72 | 73 | 74 | def mean(l, ignore_nan=False, empty=0): 75 |     """ 76 |     nanmean compatible with generators. 77 |     """ 78 |     l = iter(l) 79 |     if ignore_nan: 80 |         l = ifilterfalse(np.isnan, l) 81 |     try: 82 |         n = 1 83 |         acc = next(l) 84 |     except StopIteration: 85 |         if empty == 'raise': 86 |             raise ValueError('Empty mean') 87 |         return empty 88 |     for n, v in enumerate(l, 2): 89 |         acc += v 90 |     if n == 1: 91 |         return acc 92 |     return acc / n 93 | 94 | def lovasz_softmax_flat(probas, labels, only_present=False): 95 |     """ 96 |     Multi-class Lovasz-Softmax loss 97 |       probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 98 |       labels: [P] Tensor, ground truth labels (between 0 and C - 1) 99 |       only_present: average only on classes present in ground truth 100 |     """ 101 |     C = probas.size(1) 102 |     losses = [] 103 |     for c in range(C): 104 |         fg = (labels == c).float() # foreground for class c 105 |         if only_present and fg.sum() == 0: 106 |             continue 107 |         errors = (Variable(fg) - probas[:, c]).abs() 108 |         errors_sorted, perm = torch.sort(errors, 0, descending=True) 109 |         perm = perm.data 110 |         fg_sorted = fg[perm] 111 |         losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 112 |     return mean(losses) 113 | 114 | def flatten_probas(probas, labels, ignore=None): 115 |     """ 116 |     Flattens predictions in the batch 117 |     """ 118 |     B, C, H, W = probas.size() 119 |     probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 120 |     labels = labels.view(-1) 121 |     if ignore is None: 122 |         return probas, labels 123 |     valid = (labels != ignore) 124 |     vprobas = probas[valid.nonzero().squeeze()] 125 |     vlabels = labels[valid] 126 |     return vprobas, vlabels 127 | 128 | def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None): 129 |     """ 130 |     Multi-class Lovasz-Softmax loss 131 |       probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1) 132 |       labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 133 |       only_present: average only on classes
present in ground truth 134 |       per_image: compute the loss per image instead of per batch 135 |       ignore: void class labels 136 |     """ 137 |     #Because the model returns raw logits we apply the softmax here 138 |     probas = nn.functional.softmax(probas, dim=1) 139 | 140 |     if per_image: 141 |         loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present) 142 |                           for prob, lab in zip(probas, labels)) 143 |     else: 144 |         loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present) 145 |     return loss 146 | 147 | class FocalLoss(nn.Module): 148 |     def __init__(self, gamma=0, alpha=None, size_average=True): 149 |         super(FocalLoss, self).__init__() 150 |         self.gamma = gamma 151 |         self.alpha = alpha 152 |         if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) 153 |         if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) 154 |         self.size_average = size_average 155 | 156 |     def forward(self, input, target): 157 |         if input.dim()>2: 158 |             input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*W 159 |             input = input.transpose(1,2)    # N,C,H*W => N,H*W,C 160 |             input = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,C 161 |         target = target.view(-1,1) 162 | 163 |         logpt = F.log_softmax(input, dim=1)  # dim=1 indexes the classes after flattening to N*H*W,C 164 |         logpt = logpt.gather(1,target) 165 |         logpt = logpt.view(-1) 166 |         pt = Variable(logpt.data.exp()) 167 | 168 |         if self.alpha is not None: 169 |             if self.alpha.type()!=input.data.type(): 170 |                 self.alpha = self.alpha.type_as(input.data) 171 |             at = self.alpha.gather(0,target.data.view(-1)) 172 |             logpt = logpt * Variable(at) 173 | 174 |         loss = -1 * (1-pt)**self.gamma * logpt 175 |         if self.size_average: return loss.mean() 176 |         else: return loss.sum() -------------------------------------------------------------------------------- /model_functions.py: -------------------------------------------------------------------------------- 1 | import sys 2 | # make the local semantic-segmentation-pytorch checkout importable 3 | sys.path.append('/home/ubuntu/Kaggle_Pytorch_TGS/semantic-segmentation-pytorch') 4 | from models import * 5 | from torchvision.transforms import Normalize 6 | 7 | def build_segmentation_model(in_arch="resnet101_dilated8",out_arch="upernet",droppout=0.1): 8 | 9 |     ''' 10 |     We allow three versions of the model for now: a ResNet-50 or ResNet-101 encoder with a UPerNet decoder pre-trained on ImageNet, 11 |     plus a smaller model 12 | 13 |     ''' 14 |     #First we build two versions of the model: the student and the Mean Teacher copy.
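    #Mean Teacher background: the teacher (EMA) copy is never trained by
    #backprop. After each optimizer step its weights are moved towards the
    #student's with an exponential moving average, in the usual form
    #    ema_param.data = alpha * ema_param.data + (1 - alpha) * param.data
    #with a smoothing factor alpha close to 1 (0.99-0.999 is typical). That
    #per-step update is assumed to live in training_functions; in this function
    #we only build the two copies, copy the initial weights across, and detach
    #the teacher's parameters.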
15 | builder = ModelBuilder() 16 | #Define Encoder 17 | net_encoder = builder.build_encoder( 18 | arch=in_arch, 19 | #fc_dim=2048 20 | ) 21 | 22 | #Define Decoder 23 | net_decoder = builder.build_decoder( 24 | arch=out_arch, 25 | fc_dim=2048, 26 | #weights here lets us load our own weights neat 27 | num_class=2) 28 | 29 | net_encoder_ema = builder.build_encoder( 30 | arch=in_arch, 31 | #fc_dim=2048 32 | ) 33 | 34 | #Define Decoder 35 | net_decoder_ema = builder.build_decoder( 36 | arch=out_arch, 37 | fc_dim=2048, 38 | #weights here lets us load our own weights neat 39 | num_class=2) 40 | 41 | class SegmentationModule(SegmentationModuleBase): 42 | def __init__(self, net_enc, net_dec,drop=0,size=101): 43 | super(SegmentationModule, self).__init__() 44 | self.encoder = net_enc 45 | self.decoder = net_dec 46 | self.drop=drop 47 | self.size=size 48 | 49 | def forward(self, feed_dict, *, segSize=None): 50 | inpu=feed_dict['img_data'] 51 | 52 | encode= self.encoder(inpu, return_feature_maps=True) 53 | 54 | if self.drop>0: 55 | encode[0]=nn.Dropout(self.drop)(encode[0]) 56 | encode[1]=nn.Dropout(self.drop)(encode[1]) 57 | encode[2]=nn.Dropout(self.drop)(encode[2]) 58 | encode[3]=nn.Dropout(self.drop)(encode[3]) 59 | 60 | pred = self.decoder(encode) 61 | pred = nn.functional.upsample(pred, size=self.size, mode='bilinear', align_corners=True) 62 | #Lovasz Softmax needs Sonftmax inputs. 63 | #pred = nn.functional.softmax(pred, dim=1) 64 | 65 | return pred 66 | 67 | segmentation_ema=SegmentationModule( 68 | net_encoder_ema, net_decoder_ema,drop=droppout) 69 | 70 | 71 | 72 | segmentation_ema=segmentation_ema.cuda() 73 | 74 | #Set up the complete model 75 | segmentation_module = SegmentationModule( 76 | net_encoder, net_decoder,drop=droppout) 77 | segmentation_module=segmentation_module.cuda() 78 | 79 | for param,param2 in zip(segmentation_ema.parameters(),segmentation_module.parameters()): 80 | param.data=param2.data 81 | 82 | for param in segmentation_ema.parameters(): 83 | param.detach_() 84 | 85 | return segmentation_module,segmentation_ema 86 | 87 | 88 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import resnet, resnext 5 | from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d 6 | #from lib.nn import SynchronizedBatchNorm2d 7 | 8 | 9 | class SegmentationModuleBase(nn.Module): 10 | def __init__(self): 11 | super(SegmentationModuleBase, self).__init__() 12 | 13 | def pixel_acc(self, pred, label): 14 | _, preds = torch.max(pred, dim=1) 15 | valid = (label >= 0).long() 16 | acc_sum = torch.sum(valid * (preds == label).long()) 17 | pixel_sum = torch.sum(valid) 18 | acc = acc_sum.float() / (pixel_sum.float() + 1e-10) 19 | return acc 20 | 21 | 22 | class SegmentationModule(SegmentationModuleBase): 23 | def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None): 24 | super(SegmentationModule, self).__init__() 25 | self.encoder = net_enc 26 | self.decoder = net_dec 27 | self.crit = crit 28 | self.deep_sup_scale = deep_sup_scale 29 | 30 | def forward(self, feed_dict, *, segSize=None): 31 | if segSize is None: # training 32 | if self.deep_sup_scale is not None: # use deep supervision technique 33 | (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) 34 | else: 35 | pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) 36 | 37 | 
loss = self.crit(pred, feed_dict['seg_label']) 38 | if self.deep_sup_scale is not None: 39 | loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label']) 40 | loss = loss + loss_deepsup * self.deep_sup_scale 41 | 42 | acc = self.pixel_acc(pred, feed_dict['seg_label']) 43 | return loss, acc 44 | else: # inference 45 | pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize) 46 | return pred 47 | 48 | 49 | def conv3x3(in_planes, out_planes, stride=1, has_bias=False): 50 | "3x3 convolution with padding" 51 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 52 | padding=1, bias=has_bias) 53 | 54 | 55 | def conv3x3_bn_relu(in_planes, out_planes, stride=1): 56 | return nn.Sequential( 57 | conv3x3(in_planes, out_planes, stride), 58 | SynchronizedBatchNorm2d(out_planes), 59 | nn.ReLU(inplace=True), 60 | ) 61 | 62 | 63 | class ModelBuilder(): 64 | # custom weights initialization 65 | def weights_init(self, m): 66 | classname = m.__class__.__name__ 67 | if classname.find('Conv') != -1: 68 | nn.init.kaiming_normal_(m.weight.data) 69 | elif classname.find('BatchNorm') != -1: 70 | m.weight.data.fill_(1.) 71 | m.bias.data.fill_(1e-4) 72 | #elif classname.find('Linear') != -1: 73 | # m.weight.data.normal_(0.0, 0.0001) 74 | 75 | def build_encoder(self, arch='resnet50_dilated8', fc_dim=512, weights=''): 76 | pretrained = True if len(weights) == 0 else False 77 | if arch == 'resnet18': 78 | orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) 79 | net_encoder = Resnet(orig_resnet) 80 | elif arch == 'resnet18_dilated8': 81 | orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) 82 | net_encoder = ResnetDilated(orig_resnet, 83 | dilate_scale=8) 84 | elif arch == 'resnet18_dilated16': 85 | orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) 86 | net_encoder = ResnetDilated(orig_resnet, 87 | dilate_scale=16) 88 | elif arch == 'resnet34': 89 | raise NotImplementedError 90 | orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) 91 | net_encoder = Resnet(orig_resnet) 92 | elif arch == 'resnet34_dilated8': 93 | raise NotImplementedError 94 | orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) 95 | net_encoder = ResnetDilated(orig_resnet, 96 | dilate_scale=8) 97 | elif arch == 'resnet34_dilated16': 98 | raise NotImplementedError 99 | orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) 100 | net_encoder = ResnetDilated(orig_resnet, 101 | dilate_scale=16) 102 | elif arch == 'resnet50': 103 | orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) 104 | net_encoder = Resnet(orig_resnet) 105 | elif arch == 'resnet50_dilated8': 106 | orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) 107 | net_encoder = ResnetDilated(orig_resnet, 108 | dilate_scale=8) 109 | elif arch == 'resnet50_dilated16': 110 | orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) 111 | net_encoder = ResnetDilated(orig_resnet, 112 | dilate_scale=16) 113 | elif arch == 'resnet101': 114 | orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) 115 | net_encoder = Resnet(orig_resnet) 116 | elif arch == 'resnet101_dilated8': 117 | orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) 118 | net_encoder = ResnetDilated(orig_resnet, 119 | dilate_scale=8) 120 | elif arch == 'resnet101_dilated16': 121 | orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) 122 | net_encoder = ResnetDilated(orig_resnet, 123 | dilate_scale=16) 124 | elif arch == 'resnext101': 
125 | orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained) 126 | net_encoder = Resnet(orig_resnext) # we can still use class Resnet 127 | else: 128 | raise Exception('Architecture undefined!') 129 | 130 | # net_encoder.apply(self.weights_init) 131 | if len(weights) > 0: 132 | print('Loading weights for net_encoder') 133 | net_encoder.load_state_dict( 134 | torch.load(weights, map_location=lambda storage, loc: storage), strict=False) 135 | return net_encoder 136 | 137 | def build_decoder(self, arch='ppm_bilinear_deepsup', 138 | fc_dim=512, num_class=150, 139 | weights='', use_softmax=False): 140 | if arch == 'c1_bilinear_deepsup': 141 | net_decoder = C1BilinearDeepSup( 142 | num_class=num_class, 143 | fc_dim=fc_dim, 144 | use_softmax=use_softmax) 145 | elif arch == 'c1_bilinear': 146 | net_decoder = C1Bilinear( 147 | num_class=num_class, 148 | fc_dim=fc_dim, 149 | use_softmax=use_softmax) 150 | elif arch == 'ppm_bilinear': 151 | net_decoder = PPMBilinear( 152 | num_class=num_class, 153 | fc_dim=fc_dim, 154 | use_softmax=use_softmax) 155 | elif arch == 'ppm_bilinear_deepsup': 156 | net_decoder = PPMBilinearDeepsup( 157 | num_class=num_class, 158 | fc_dim=fc_dim, 159 | use_softmax=use_softmax) 160 | elif arch == 'upernet_lite': 161 | net_decoder = UPerNet( 162 | num_class=num_class, 163 | fc_dim=fc_dim, 164 | use_softmax=use_softmax, 165 | fpn_dim=256) 166 | elif arch == 'upernet': 167 | net_decoder = UPerNet( 168 | num_class=num_class, 169 | fc_dim=fc_dim, 170 | use_softmax=use_softmax, 171 | fpn_dim=512) 172 | elif arch == 'upernet_tmp': 173 | net_decoder = UPerNetTmp( 174 | num_class=num_class, 175 | fc_dim=fc_dim, 176 | use_softmax=use_softmax, 177 | fpn_dim=512) 178 | else: 179 | raise Exception('Architecture undefined!') 180 | 181 | net_decoder.apply(self.weights_init) 182 | if len(weights) > 0: 183 | print('Loading weights for net_decoder') 184 | net_decoder.load_state_dict( 185 | torch.load(weights, map_location=lambda storage, loc: storage), strict=False) 186 | return net_decoder 187 | 188 | 189 | class Resnet(nn.Module): 190 | def __init__(self, orig_resnet): 191 | super(Resnet, self).__init__() 192 | 193 | # take pretrained resnet, except AvgPool and FC 194 | self.conv1 = orig_resnet.conv1 195 | self.bn1 = orig_resnet.bn1 196 | self.relu1 = orig_resnet.relu1 197 | self.conv2 = orig_resnet.conv2 198 | self.bn2 = orig_resnet.bn2 199 | self.relu2 = orig_resnet.relu2 200 | self.conv3 = orig_resnet.conv3 201 | self.bn3 = orig_resnet.bn3 202 | self.relu3 = orig_resnet.relu3 203 | self.maxpool = orig_resnet.maxpool 204 | self.layer1 = orig_resnet.layer1 205 | self.layer2 = orig_resnet.layer2 206 | self.layer3 = orig_resnet.layer3 207 | self.layer4 = orig_resnet.layer4 208 | 209 | def forward(self, x, return_feature_maps=False): 210 | conv_out = [] 211 | 212 | x = self.relu1(self.bn1(self.conv1(x))) 213 | x = self.relu2(self.bn2(self.conv2(x))) 214 | x = self.relu3(self.bn3(self.conv3(x))) 215 | x = self.maxpool(x) 216 | 217 | x = self.layer1(x); conv_out.append(x); 218 | x = self.layer2(x); conv_out.append(x); 219 | x = self.layer3(x); conv_out.append(x); 220 | x = self.layer4(x); conv_out.append(x); 221 | 222 | if return_feature_maps: 223 | return conv_out 224 | return [x] 225 | 226 | 227 | class ResnetDilated(nn.Module): 228 | def __init__(self, orig_resnet, dilate_scale=8): 229 | super(ResnetDilated, self).__init__() 230 | from functools import partial 231 | 232 | if dilate_scale == 8: 233 | orig_resnet.conv1.apply( 234 | partial(self._nostride_dilate, dilate=2)) 235 | 
#orig_resnet.conv2.apply( 236 | # partial(self._nostride_dilate, dilate=2)) 237 | #orig_resnet.layer3.apply( 238 | # partial(self._nostride_dilate, dilate=2)) 239 | #orig_resnet.layer4.apply( 240 | # partial(self._nostride_dilate, dilate=4)) 241 | elif dilate_scale == 16: 242 | orig_resnet.layer4.apply( 243 | partial(self._nostride_dilate, dilate=2)) 244 | 245 | # take pretrained resnet, except AvgPool and FC 246 | self.conv1 = orig_resnet.conv1 247 | self.bn1 = orig_resnet.bn1 248 | self.relu1 = orig_resnet.relu1 249 | self.conv2 = orig_resnet.conv2 250 | self.bn2 = orig_resnet.bn2 251 | self.relu2 = orig_resnet.relu2 252 | self.conv3 = orig_resnet.conv3 253 | self.bn3 = orig_resnet.bn3 254 | self.relu3 = orig_resnet.relu3 255 | self.maxpool = orig_resnet.maxpool 256 | self.layer1 = orig_resnet.layer1 257 | self.layer2 = orig_resnet.layer2 258 | self.layer3 = orig_resnet.layer3 259 | self.layer4 = orig_resnet.layer4 260 | 261 | def _nostride_dilate(self, m, dilate): 262 | classname = m.__class__.__name__ 263 | if classname.find('Conv') != -1: 264 | # the convolution with stride 265 | if m.stride == (2, 2): 266 | m.stride = (1, 1) 267 | if m.kernel_size == (3, 3): 268 | m.dilation = (dilate//2, dilate//2) 269 | m.padding = (dilate//2, dilate//2) 270 | # other convoluions 271 | else: 272 | if m.kernel_size == (3, 3): 273 | m.dilation = (dilate, dilate) 274 | m.padding = (dilate, dilate) 275 | 276 | def forward(self, x, return_feature_maps=False): 277 | conv_out = [] 278 | 279 | x = self.relu1(self.bn1(self.conv1(x))) 280 | x = self.relu2(self.bn2(self.conv2(x))) 281 | x = self.relu3(self.bn3(self.conv3(x))) 282 | x = self.maxpool(x) 283 | 284 | x = self.layer1(x); conv_out.append(x); 285 | x = self.layer2(x); conv_out.append(x); 286 | x = self.layer3(x); conv_out.append(x); 287 | x = self.layer4(x); conv_out.append(x); 288 | 289 | if return_feature_maps: 290 | return conv_out 291 | return [x] 292 | 293 | 294 | # last conv, bilinear upsample 295 | class C1BilinearDeepSup(nn.Module): 296 | def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): 297 | super(C1BilinearDeepSup, self).__init__() 298 | self.use_softmax = use_softmax 299 | 300 | self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) 301 | self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) 302 | 303 | # last conv 304 | self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 305 | self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 306 | 307 | def forward(self, conv_out, segSize=None): 308 | conv5 = conv_out[-1] 309 | 310 | x = self.cbr(conv5) 311 | x = self.conv_last(x) 312 | 313 | if self.use_softmax: # is True during inference 314 | x = nn.functional.upsample( 315 | x, size=segSize, mode='bilinear', align_corners=False) 316 | x = nn.functional.softmax(x, dim=1) 317 | return x 318 | 319 | # deep sup 320 | conv4 = conv_out[-2] 321 | _ = self.cbr_deepsup(conv4) 322 | _ = self.conv_last_deepsup(_) 323 | 324 | x = nn.functional.log_softmax(x, dim=1) 325 | _ = nn.functional.log_softmax(_, dim=1) 326 | 327 | return (x, _) 328 | 329 | 330 | # last conv, bilinear upsample 331 | class C1Bilinear(nn.Module): 332 | def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): 333 | super(C1Bilinear, self).__init__() 334 | self.use_softmax = use_softmax 335 | 336 | self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) 337 | 338 | # last conv 339 | self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 340 | 341 | def forward(self, conv_out, segSize=None): 342 | conv5 = conv_out[-1] 343 | x = 
self.cbr(conv5) 344 | x = self.conv_last(x) 345 | 346 | if self.use_softmax: # is True during inference 347 | x = nn.functional.upsample( 348 | x, size=segSize, mode='bilinear', align_corners=False) 349 | x = nn.functional.softmax(x, dim=1) 350 | else: 351 | x = nn.functional.log_softmax(x, dim=1) 352 | 353 | return x 354 | 355 | 356 | # pyramid pooling, bilinear upsample 357 | class PPMBilinear(nn.Module): 358 | def __init__(self, num_class=150, fc_dim=4096, 359 | use_softmax=False, pool_scales=(1, 2, 3, 6)): 360 | super(PPMBilinear, self).__init__() 361 | self.use_softmax = use_softmax 362 | 363 | self.ppm = [] 364 | for scale in pool_scales: 365 | self.ppm.append(nn.Sequential( 366 | nn.AdaptiveAvgPool2d(scale), 367 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), 368 | SynchronizedBatchNorm2d(512), 369 | nn.ReLU(inplace=True) 370 | )) 371 | self.ppm = nn.ModuleList(self.ppm) 372 | 373 | self.conv_last = nn.Sequential( 374 | nn.Conv2d(fc_dim+len(pool_scales)*512, 512, 375 | kernel_size=3, padding=1, bias=False), 376 | SynchronizedBatchNorm2d(512), 377 | nn.ReLU(inplace=True), 378 | nn.Dropout2d(0.1), 379 | nn.Conv2d(512, num_class, kernel_size=1) 380 | ) 381 | 382 | def forward(self, conv_out, segSize=None): 383 | conv5 = conv_out[-1] 384 | 385 | input_size = conv5.size() 386 | ppm_out = [conv5] 387 | for pool_scale in self.ppm: 388 | ppm_out.append(nn.functional.upsample( 389 | pool_scale(conv5), 390 | (input_size[2], input_size[3]), 391 | mode='bilinear', align_corners=False)) 392 | ppm_out = torch.cat(ppm_out, 1) 393 | 394 | x = self.conv_last(ppm_out) 395 | 396 | if self.use_softmax: # is True during inference 397 | x = nn.functional.upsample( 398 | x, size=segSize, mode='bilinear', align_corners=False) 399 | x = nn.functional.softmax(x, dim=1) 400 | else: 401 | x = nn.functional.log_softmax(x, dim=1) 402 | return x 403 | 404 | 405 | # pyramid pooling, bilinear upsample 406 | class PPMBilinearDeepsup(nn.Module): 407 | def __init__(self, num_class=150, fc_dim=4096, 408 | use_softmax=False, pool_scales=(1, 2, 3, 6)): 409 | super(PPMBilinearDeepsup, self).__init__() 410 | self.use_softmax = use_softmax 411 | 412 | self.ppm = [] 413 | for scale in pool_scales: 414 | self.ppm.append(nn.Sequential( 415 | nn.AdaptiveAvgPool2d(scale), 416 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), 417 | SynchronizedBatchNorm2d(512), 418 | nn.ReLU(inplace=True) 419 | )) 420 | self.ppm = nn.ModuleList(self.ppm) 421 | self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) 422 | 423 | self.conv_last = nn.Sequential( 424 | nn.Conv2d(fc_dim+len(pool_scales)*512, 512, 425 | kernel_size=3, padding=1, bias=False), 426 | SynchronizedBatchNorm2d(512), 427 | nn.ReLU(inplace=True), 428 | nn.Dropout2d(0.1), 429 | nn.Conv2d(512, num_class, kernel_size=1) 430 | ) 431 | self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 432 | self.dropout_deepsup = nn.Dropout2d(0.1) 433 | 434 | def forward(self, conv_out, segSize=None): 435 | conv5 = conv_out[-1] 436 | 437 | input_size = conv5.size() 438 | ppm_out = [conv5] 439 | for pool_scale in self.ppm: 440 | ppm_out.append(nn.functional.upsample( 441 | pool_scale(conv5), 442 | (input_size[2], input_size[3]), 443 | mode='bilinear', align_corners=False)) 444 | ppm_out = torch.cat(ppm_out, 1) 445 | 446 | x = self.conv_last(ppm_out) 447 | 448 | if self.use_softmax: # is True during inference 449 | x = nn.functional.upsample( 450 | x, size=segSize, mode='bilinear', align_corners=False) 451 | x = nn.functional.softmax(x, dim=1) 452 | return x 453 | 454 
| # deep sup 455 | conv4 = conv_out[-2] 456 | _ = self.cbr_deepsup(conv4) 457 | _ = self.dropout_deepsup(_) 458 | _ = self.conv_last_deepsup(_) 459 | 460 | x = nn.functional.log_softmax(x, dim=1) 461 | _ = nn.functional.log_softmax(_, dim=1) 462 | 463 | return (x, _) 464 | 465 | 466 | # upernet 467 | class UPerNet(nn.Module): 468 | def __init__(self, num_class=150, fc_dim=4096, 469 | use_softmax=False, pool_scales=(1, 2, 3, 6), 470 | fpn_inplanes=(256,512,1024,2048), fpn_dim=256): 471 | super(UPerNet, self).__init__() 472 | self.use_softmax = use_softmax 473 | 474 | # PPM Module 475 | self.ppm_pooling = [] 476 | self.ppm_conv = [] 477 | 478 | for scale in pool_scales: 479 | self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale)) 480 | self.ppm_conv.append(nn.Sequential( 481 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), 482 | SynchronizedBatchNorm2d(512), 483 | nn.ReLU(inplace=True) 484 | )) 485 | self.ppm_pooling = nn.ModuleList(self.ppm_pooling) 486 | self.ppm_conv = nn.ModuleList(self.ppm_conv) 487 | self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1) 488 | 489 | # FPN Module 490 | self.fpn_in = [] 491 | for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer 492 | self.fpn_in.append(nn.Sequential( 493 | nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False), 494 | SynchronizedBatchNorm2d(fpn_dim), 495 | nn.ReLU(inplace=True) 496 | )) 497 | self.fpn_in = nn.ModuleList(self.fpn_in) 498 | 499 | self.fpn_out = [] 500 | for i in range(len(fpn_inplanes) - 1): # skip the top layer 501 | self.fpn_out.append(nn.Sequential( 502 | conv3x3_bn_relu(fpn_dim, fpn_dim, 1), 503 | )) 504 | self.fpn_out = nn.ModuleList(self.fpn_out) 505 | 506 | self.conv_last = nn.Sequential( 507 | conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1), 508 | nn.Conv2d(fpn_dim, num_class, kernel_size=1) 509 | ) 510 | 511 | def forward(self, conv_out, segSize=None): 512 | conv5 = conv_out[-1] 513 | 514 | input_size = conv5.size() 515 | ppm_out = [conv5] 516 | for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv): 517 | ppm_out.append(pool_conv(nn.functional.upsample( 518 | pool_scale(conv5), 519 | (input_size[2], input_size[3]), 520 | mode='bilinear', align_corners=False))) 521 | ppm_out = torch.cat(ppm_out, 1) 522 | f = self.ppm_last_conv(ppm_out) 523 | 524 | fpn_feature_list = [f] 525 | for i in reversed(range(len(conv_out) - 1)): 526 | conv_x = conv_out[i] 527 | conv_x = self.fpn_in[i](conv_x) # lateral branch 528 | 529 | f = nn.functional.upsample( 530 | f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch 531 | f = conv_x + f 532 | 533 | fpn_feature_list.append(self.fpn_out[i](f)) 534 | 535 | fpn_feature_list.reverse() # [P2 - P5] 536 | output_size = fpn_feature_list[0].size()[2:] 537 | fusion_list = [fpn_feature_list[0]] 538 | for i in range(1, len(fpn_feature_list)): 539 | fusion_list.append(nn.functional.upsample( 540 | fpn_feature_list[i], 541 | output_size, 542 | mode='bilinear', align_corners=False)) 543 | fusion_out = torch.cat(fusion_list, 1) 544 | x = self.conv_last(fusion_out) 545 | 546 | if self.use_softmax: # is True during inference 547 | x = nn.functional.upsample( 548 | x, size=segSize, mode='bilinear', align_corners=False) 549 | x = nn.functional.softmax(x, dim=1) 550 | return x 551 | 552 | x = nn.functional.log_softmax(x, dim=1) 553 | 554 | return x 555 | -------------------------------------------------------------------------------- /resnet.py: 
/resnet.py:
--------------------------------------------------------------------------------
import os
import sys
import torch
import torch.nn as nn
import math
from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d
try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve


__all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101']


model_urls = {
    'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth',
    'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
    'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth'
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = SynchronizedBatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = SynchronizedBatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = SynchronizedBatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = SynchronizedBatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = SynchronizedBatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = SynchronizedBatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = SynchronizedBatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = SynchronizedBatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                SynchronizedBatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

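A quick shape check of the class above. Note the deep stem of three 3x3 convolutions in place of torchvision's single 7x7, and that only the first block of each stage carries the 1x1 projection shortcut built in _make_layer:

if __name__ == "__main__":
    net = ResNet(Bottleneck, [3, 4, 6, 3])           # ResNet-50 layout
    print(net.layer2[0].downsample)                  # 1x1 stride-2 conv + BN
    print(net.layer2[1].downsample)                  # None: identity shortcut
    print(net(torch.randn(1, 3, 224, 224)).shape)    # torch.Size([1, 1000])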
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnet18']))
    return model

'''
def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnet34']))
    return model
'''

def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnet101']), strict=False)
    return model

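resnet50 and resnet101 load their checkpoints with strict=False, which tolerates missing or extra state-dict keys instead of raising. A small standalone illustration of what that buys (on PyTorch 1.0+, load_state_dict also reports the mismatches):

import torch.nn as nn

src = nn.Sequential(nn.Linear(4, 4))
dst = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))   # has one extra layer
result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys)   # ['1.weight', '1.bias'] simply keep their init values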
# def resnet152(pretrained=False, **kwargs):
#     """Constructs a ResNet-152 model.
#
#     Args:
#         pretrained (bool): If True, returns a model pre-trained on ImageNet
#     """
#     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
#     if pretrained:
#         model.load_state_dict(load_url(model_urls['resnet152']))
#     return model

def load_url(url, model_dir='./pretrained', map_location=None):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        urlretrieve(url, cached_file)
    return torch.load(cached_file, map_location=map_location)
--------------------------------------------------------------------------------
/resnext.py:
--------------------------------------------------------------------------------
import os
import sys
import torch
import torch.nn as nn
import math
from torch.nn import BatchNorm2d as SynchronizedBatchNorm2d
try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve


__all__ = ['ResNeXt', 'resnext101']  # supports ResNeXt-101


model_urls = {
    #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth',
    'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth'
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class GroupBottleneck(nn.Module):
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None):
        super(GroupBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = SynchronizedBatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, groups=groups, bias=False)
        self.bn2 = SynchronizedBatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False)
        self.bn3 = SynchronizedBatchNorm2d(planes * 2)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

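GroupBottleneck's middle 3x3 convolution is grouped, which divides its weight count (and FLOPs) by the number of groups; this is what lets ResNeXt widen its blocks without a parameter explosion. A quick standalone comparison:

import torch.nn as nn

dense = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False)
grouped = nn.Conv2d(256, 256, kernel_size=3, padding=1, groups=32, bias=False)
print(dense.weight.numel())    # 589824 = 256 * 256 * 3 * 3
print(grouped.weight.numel())  # 18432  = 256 * (256 // 32) * 3 * 3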
class ResNeXt(nn.Module):

    def __init__(self, block, layers, groups=32, num_classes=1000):
        self.inplanes = 128
        super(ResNeXt, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = SynchronizedBatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = SynchronizedBatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = SynchronizedBatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 128, layers[0], groups=groups)
        self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups)
        self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups)
        self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(1024 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, groups=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                SynchronizedBatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, groups, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=groups))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


'''
def resnext50(pretrained=False, **kwargs):
    """Constructs a ResNeXt-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnext50']), strict=False)
    return model
'''


def resnext101(pretrained=False, **kwargs):
    """Constructs a ResNeXt-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnext101']), strict=False)
    return model

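For segmentation, the decoders in models.py consume a list of per-stage feature maps (conv_out) rather than classification logits. A hedged sketch of collecting them from this backbone; the repo's actual wrapper lives in model_functions.py, whose interface isn't shown here:

if __name__ == "__main__":
    net = resnext101(pretrained=False)
    x = torch.randn(1, 3, 224, 224)
    x = net.relu1(net.bn1(net.conv1(x)))
    x = net.relu2(net.bn2(net.conv2(x)))
    x = net.relu3(net.bn3(net.conv3(x)))
    x = net.maxpool(x)

    conv_out = []
    for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
        x = stage(x)
        conv_out.append(x)
    print([f.shape[1] for f in conv_out])  # [256, 512, 1024, 2048], matching fpn_inplanes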
# def resnext152(pretrained=False, **kwargs):
#     """Constructs a ResNeXt-152 model.
#
#     Args:
#         pretrained (bool): If True, returns a model pre-trained on ImageNet
#     """
#     model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs)
#     if pretrained:
#         model.load_state_dict(load_url(model_urls['resnext152']))
#     return model


def load_url(url, model_dir='./pretrained', map_location=None):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    filename = url.split('/')[-1]
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        urlretrieve(url, cached_file)
    return torch.load(cached_file, map_location=map_location)
--------------------------------------------------------------------------------
/training_functions.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import numpy as np

n_iter = 0

def group_weight(module):
    """Split parameters into a weight-decay group (conv/linear weights) and a
    no-decay group (biases and batch-norm parameters)."""
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, nn.Linear):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.modules.conv._ConvNd):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.modules.batchnorm._BatchNorm):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)

    assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
    groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
    return groups

def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(1 - alpha, param.data)

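group_weight and update_ema_variables together implement the optimizer side of a mean-teacher setup: weight decay touches only conv/linear weights, and the teacher (EMA) copy trails the student without ever being optimized directly. A hedged sketch with a toy model standing in for the real segmentation networks:

if __name__ == "__main__":
    import copy

    student = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
    teacher = copy.deepcopy(student)
    for p in teacher.parameters():
        p.requires_grad_(False)   # gradients never flow into the EMA model

    # weight_decay applies only to the first group; BN params and biases are exempt
    optimizer = torch.optim.SGD(group_weight(student), lr=0.01,
                                momentum=0.9, weight_decay=1e-4)

    for step in range(1, 4):      # stand-in for the real loop in train()
        optimizer.zero_grad()
        student(torch.randn(2, 3, 16, 16)).mean().backward()
        optimizer.step()
        update_ema_variables(student, teacher, alpha=0.999, global_step=step)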
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.initialized = False
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None

    def initialize(self, val, weight):
        self.val = val
        self.avg = val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        if not self.initialized:
            self.initialize(val, weight)
        else:
            self.add(val, weight)

    def add(self, val, weight):
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        return self.val

    def average(self):
        return self.avg

def IOU_Score(y_pred, y_val):
    """Kaggle-style mean precision over IoU thresholds; expects array-like
    softmax output y_pred of shape (N, 2, H, W) and masks y_val of (N, H, W)."""
    def IoUOld(a, b):
        intersection = ((a == 1) & (a == b)).sum()
        union = ((a == 1) | (b == 1)).sum()
        if union > 0:
            return intersection / union
        elif union == 0 and intersection == 0:
            return 1
        else:
            return 0

    y_pred = y_pred[:, 1, :, :]  # probability of the foreground class

    t = 0.5
    IOU_list = []
    for j in range(y_pred.shape[0]):
        y_pred_ = np.array(y_pred[j, :, :] > t, dtype=bool)
        y_val_ = np.array(y_val[j, :, :], dtype=bool)

        IOU = IoUOld(y_pred_, y_val_)

        IOU_list.append(IOU)
    # Now sweep over IoU thresholds; each threshold decides whether an
    # image's IoU counts as a "true positive" or not.
    prec_list = []
    for IOU_t in np.arange(0.5, 1.0, 0.05):
        # true positives: all examples whose IoU exceeds the threshold
        TP = np.sum(np.asarray(IOU_list) > IOU_t)
        # Precision at this threshold: TP over the total number of images.
        # (The metric is often written TP/(TP+FP+FN); with exactly one
        # prediction per image that denominator is just the image count.)
        Prec = TP / len(IOU_list)
        prec_list.append(Prec)

    return np.mean(prec_list)



# Main training function
from losses import lovasz_softmax, FocalLoss
focal = FocalLoss(size_average=True)

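A small worked example of the metric computed by IOU_Score: each image's IoU is compared against the thresholds 0.5, 0.55, ..., 0.95 and the resulting precisions are averaged. With one perfect mask and one badly over-predicted one, the score is 0.5:

if __name__ == "__main__":
    probs = np.zeros((2, 2, 4, 4), dtype=np.float32)  # (N, classes, H, W) softmax output
    probs[0, 1, :2, :2] = 0.9   # image 0: predicts exactly the top-left 2x2 block
    probs[1, 1, :, :] = 0.9     # image 1: predicts every pixel

    masks = np.zeros((2, 4, 4), dtype=np.int64)
    masks[0, :2, :2] = 1        # IoU 1.0 -> true positive at all 10 thresholds
    masks[1, :2, :2] = 1        # IoU 4/16 = 0.25 -> below every threshold

    print(IOU_Score(probs, masks))  # 0.5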
def train(train_loader, segmentation_module, segmentation_ema, optimizer
          , writer
          , lovasz_scaling=0.1
          , focal_scaling=0.9
          , unsupervised_scaling=0.1
          , ema_scaling=0.2
          , non_ema_scaling=1
          , second_batch_size=2
          , train=True
          , test=False
          , writer_name_list=None
          ):

    global n_iter
    # training loop
    cudnn.benchmark = True

    lovasz_scaling = torch.tensor(lovasz_scaling).float().cuda()
    focal_scaling = torch.tensor(focal_scaling).float().cuda()
    unsupervised_scaling = torch.tensor(unsupervised_scaling).float().cuda()
    ema_scaling = torch.tensor(ema_scaling).float().cuda()
    non_ema_scaling = torch.tensor(non_ema_scaling).float().cuda()

    # average meters for all the losses we keep track of
    ave_total_loss = AverageMeter()
    ave_non_ema_loss = AverageMeter()
    ave_ema_loss = AverageMeter()
    ave_lovasz_loss = AverageMeter()
    ave_focal_loss = AverageMeter()
    ave_lovasz_loss_ema = AverageMeter()
    ave_focal_loss_ema = AverageMeter()
    ave_unsupervised_loss = AverageMeter()
    ave_iou_score = AverageMeter()
    if train:
        segmentation_module.train()
        segmentation_ema.train()
    else:
        segmentation_module.eval()
        segmentation_ema.eval()

    for batch_data in train_loader:

        batch_data["img_data"] = batch_data["img_data"].cuda()
        batch_data["seg_label"] = batch_data["seg_label"].cuda().long().squeeze()

        # prediction from the model and from the self-ensembled (EMA) model
        pred = segmentation_module(batch_data)
        pred_ema = segmentation_ema(batch_data)
        # no gradients flow into the EMA model
        pred_ema = pred_ema.detach()

        ### UNSUPERVISED LOSS ###
        # consistency penalty between the model and its EMA twin
        unsupervised_loss = torch.mean((pred - pred_ema)**2).cuda()

        ### SUPERVISED LOSS ###
        # drop the unlabeled examples (appended at the end of each batch)
        # for the supervised losses
        pred = pred[:-second_batch_size, :, :]
        pred_ema = pred_ema[:-second_batch_size, :, :]
        batch_data["seg_label"] = batch_data["seg_label"][:-second_batch_size, :, :]

        lovasz_loss = lovasz_softmax(pred, batch_data['seg_label'], ignore=-1, only_present=True).cuda()
        focal_loss = focal(pred, batch_data['seg_label'])

        lovasz_loss_ema = lovasz_softmax(pred_ema, batch_data['seg_label'], ignore=-1, only_present=True).cuda()
        focal_loss_ema = focal(pred_ema, batch_data['seg_label'])

        #### Loss combinations ####
        non_ema_loss = (lovasz_loss*lovasz_scaling + focal_loss*focal_scaling).cuda()
        ema_loss = (lovasz_loss_ema*lovasz_scaling + focal_loss_ema*focal_scaling).cuda()

        total_loss = (non_ema_loss*non_ema_scaling + ema_loss*ema_scaling + unsupervised_scaling*unsupervised_loss).cuda()
        # IOU_Score expects softmax probabilities, as numpy
        # (np.array() cannot consume CUDA tensors directly)
        pred = nn.functional.softmax(pred, dim=1)
        iou_score = IOU_Score(pred.detach().cpu().numpy(),
                              batch_data["seg_label"].detach().cpu().numpy())

        ### Backward pass ###
        if train:
            optimizer.zero_grad()
            total_loss.backward()

            optimizer.step()
            n_iter = n_iter + 1


            update_ema_variables(segmentation_module, segmentation_ema, 0.999, n_iter)


        ### Logging ###

        ave_non_ema_loss.update(non_ema_loss.data.item())
        ave_ema_loss.update(ema_loss.data.item())
        ave_total_loss.update(total_loss.data.item())
        ave_lovasz_loss.update(lovasz_loss.data.item())
        ave_focal_loss.update(focal_loss.data.item())
        ave_lovasz_loss_ema.update(lovasz_loss_ema.data.item())
        ave_focal_loss_ema.update(focal_loss_ema.data.item())
        ave_unsupervised_loss.update(unsupervised_loss.data.item())
        ave_iou_score.update(iou_score.item())

        if test:
            print(n_iter)
            break


    writer.add_scalar(writer_name_list[0], ave_non_ema_loss.average(), n_iter)
    writer.add_scalar(writer_name_list[1], ave_ema_loss.average(), n_iter)
    writer.add_scalar(writer_name_list[2], ave_total_loss.average(), n_iter)
    writer.add_scalar(writer_name_list[3], ave_lovasz_loss.average(), n_iter)
    writer.add_scalar(writer_name_list[4], ave_focal_loss.average(), n_iter)
    writer.add_scalar(writer_name_list[5], ave_lovasz_loss_ema.average(), n_iter)
    writer.add_scalar(writer_name_list[6], ave_focal_loss_ema.average(), n_iter)
    writer.add_scalar(writer_name_list[7], ave_unsupervised_loss.average(), n_iter)
    writer.add_scalar(writer_name_list[8], ave_iou_score.average(), n_iter)

    if not train:
        return ave_iou_score.average()
--------------------------------------------------------------------------------
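Pulling the pieces together, a hedged end-to-end sketch of how train() is meant to be driven. The build_segmentation_model call and the loader are assumptions here: the real construction lives in model_functions.py and data_functions.py, whose signatures are not shown in this file.

import copy
import torch
from tensorboardX import SummaryWriter
from model_functions import build_segmentation_model   # signature assumed
from training_functions import train, group_weight

model = build_segmentation_model().cuda()               # hypothetical call
model_ema = copy.deepcopy(model)
for p in model_ema.parameters():
    p.requires_grad_(False)

optimizer = torch.optim.SGD(group_weight(model), lr=0.01,
                            momentum=0.9, weight_decay=1e-4)
writer = SummaryWriter(log_dir="logs/example_run")
# nine scalar tags, matching writer_name_list[0..8] inside train()
tags = ["train/metric_{}".format(i) for i in range(9)]

# assumed: a DataLoader yielding dicts with "img_data" and "seg_label",
# with the unlabeled samples placed at the end of each batch
train_loader = ...

for epoch in range(10):
    train(train_loader, model, model_ema, optimizer, writer,
          second_batch_size=2, train=True, writer_name_list=tags)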