├── .gitignore ├── LICENSE ├── README.md ├── examples ├── imgs │ └── VOC │ │ ├── 2007_004189.jpg │ │ ├── 2007_004189.png │ │ ├── 2007_006449.jpg │ │ ├── 2007_006449.png │ │ ├── 2008_002536.jpg │ │ ├── 2008_002536.png │ │ ├── 2010_005582.jpg │ │ └── 2010_005582.png └── notebooks │ └── VOC.ipynb ├── models ├── __init__.py └── resnet.py ├── requirements3.txt └── utils ├── __init__.py ├── cmap.npy ├── cs_cmap.npy ├── helpers.py └── layer_factory.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .flake8 6 | .idea 7 | 8 | experiments 9 | pretrained 10 | ckpt 11 | weights/ 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | env/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | tmp/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # dotenv 90 | .env 91 | 92 | # virtualenv 93 | .venv 94 | venv/ 95 | ENV/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | search-*/ 110 | eval-*/ 111 | runs/ 112 | 113 | # Swap 114 | [._]*.s[a-v][a-z] 115 | [._]*.sw[a-p] 116 | [._]s[a-v][a-z] 117 | [._]sw[a-p] 118 | 119 | # Session 120 | Session.vim 121 | 122 | # Temporary 123 | .netrwhist 124 | *~ 125 | # Auto-generated tag files 126 | tags 127 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | RefineNet for non-commercial purposes 2 | 3 | Copyright (c) 2018, Vladimir Nekrasov 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RefineNet (in PyTorch) 2 | 3 | This repository provides the ResNet-101-based model trained on PASCAL VOC from the paper `RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation` (the provided weights achieve **80.5**% mean IoU on the validation set in the single scale setting) 4 | 5 | ``` 6 | RefineNet: Multi-Path Refinement Networks for High-Resolution Semantic Segmentation 7 | Guosheng Lin, Anton Milan, Chunhua Shen, Ian Reid 8 | In CVPR 2017 9 | ``` 10 | 11 | ## Getting Started 12 | 13 | For flawless reproduction of our results, the Ubuntu OS is recommended. The model have been tested using Python 3.6. 14 | 15 | ### Dependencies 16 | 17 | ``` 18 | pip3 19 | torch>=0.4.0 20 | ``` 21 | To install required Python packages, please run `pip3 install -r requirements3.txt` (Python3) - use the flag `-u` for local installation. 22 | The given examples can be run with, or without GPU. 23 | 24 | ## Running examples 25 | 26 | For the ease of reproduction, we have embedded all our examples inside Jupyter notebooks. One can either download them from this repository and proceed working with them on his/her local machine/server, or can resort to online version supported by the Google Colab service. 27 | 28 | ### Jupyter Notebooks [Local] 29 | 30 | If all the installation steps have been smoothly executed, you can proceed with running any of the notebooks provided in the `examples/notebooks` folder. 31 | To start the Jupyter Notebook server, on your local machine run `jupyter notebook`. This will open a web page inside your browser. If it did not open automatically, find the port number from the command's output and paste it into your browser manually. 32 | After that, navigate to the repository folder and choose any of the examples given. 33 | 34 | Inside the notebook, one can try out their own images, write loops to iterate over videos / whole datasets / streams (e.g., from webcam). Feel free to contribute your cool use cases of the notebooks! 35 | 36 | ### Colab Notebooks [Web] 37 | 38 | *Coming soon* 39 | 40 | ## Training scripts 41 | 42 | Please refer to the training scripts for [Light-Weight-RefineNet](https://github.com/DrSleep/light-weight-refinenet) 43 | 44 | 45 | ## More projects to check out 46 | 47 | [Light-Weight-RefineNet](https://github.com/DrSleep/light-weight-refinenet) - compact version of RefineNet running in real-time with minimal decrease in accuracy (3x decrease in the number of parameters, 5x decrease in the number of FLOPs) 48 | 49 | ## License 50 | 51 | For academic usage, this project is licensed under the 2-clause BSD License - see the [LICENSE](LICENSE) file for details. For commercial usage, please contact the authors. 52 | -------------------------------------------------------------------------------- /examples/imgs/VOC/2007_004189.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/examples/imgs/VOC/2007_004189.jpg -------------------------------------------------------------------------------- /examples/imgs/VOC/2007_004189.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/examples/imgs/VOC/2007_004189.png -------------------------------------------------------------------------------- /examples/imgs/VOC/2007_006449.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/examples/imgs/VOC/2007_006449.jpg -------------------------------------------------------------------------------- /examples/imgs/VOC/2007_006449.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/examples/imgs/VOC/2007_006449.png -------------------------------------------------------------------------------- /examples/imgs/VOC/2008_002536.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/examples/imgs/VOC/2008_002536.jpg -------------------------------------------------------------------------------- /examples/imgs/VOC/2008_002536.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/examples/imgs/VOC/2008_002536.png -------------------------------------------------------------------------------- /examples/imgs/VOC/2010_005582.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/examples/imgs/VOC/2010_005582.jpg -------------------------------------------------------------------------------- /examples/imgs/VOC/2010_005582.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/examples/imgs/VOC/2010_005582.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/models/__init__.py -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """RefineNet 2 | 3 | RefineNet PyTorch for non-commercial purposes 4 | 5 | Copyright (c) 2018, Vladimir Nekrasov (vladimir.nekrasov@adelaide.edu.au) 6 | All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted provided that the following conditions are met: 10 | 11 | * Redistributions of source code must retain the above copyright notice, this 12 | list of conditions and the following disclaimer. 13 | 14 | * Redistributions in binary form must reproduce the above copyright notice, 15 | this list of conditions and the following disclaimer in the documentation 16 | and/or other materials provided with the distribution. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | """ 29 | 30 | import torch.nn as nn 31 | import torch.nn.functional as F 32 | import torch 33 | 34 | import numpy as np 35 | 36 | from utils.helpers import maybe_download 37 | from utils.layer_factory import conv1x1, conv3x3, CRPBlock, RCUBlock 38 | 39 | data_info = { 40 | 21: 'VOC', 41 | } 42 | 43 | models_urls = { 44 | '101_voc' : 'https://cloudstor.aarnet.edu.au/plus/s/Owmttk9bdPROwc6/download', 45 | 46 | '101_imagenet': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 47 | } 48 | 49 | class BasicBlock(nn.Module): 50 | expansion = 1 51 | 52 | def __init__(self, inplanes, planes, stride=1, downsample=None): 53 | super(BasicBlock, self).__init__() 54 | self.conv1 = conv3x3(inplanes, planes, stride) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.conv2 = conv3x3(planes, planes) 58 | self.bn2 = nn.BatchNorm2d(planes) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | residual = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | residual = self.downsample(x) 74 | 75 | out += residual 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None): 85 | super(Bottleneck, self).__init__() 86 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 87 | self.bn1 = nn.BatchNorm2d(planes) 88 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 89 | padding=1, bias=False) 90 | self.bn2 = nn.BatchNorm2d(planes) 91 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 92 | self.bn3 = nn.BatchNorm2d(planes * 4) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.downsample = downsample 95 | self.stride = stride 96 | 97 | def forward(self, x): 98 | residual = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | 111 | if self.downsample is not None: 112 | residual = self.downsample(x) 113 | 114 | out += residual 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | 120 | class RefineNet(nn.Module): 121 | 122 | def __init__(self, block, layers, num_classes=21): 123 | self.inplanes = 64 124 | super(RefineNet, self).__init__() 125 | self.do = nn.Dropout(p=0.5) 126 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 127 | bias=False) 128 | self.bn1 = nn.BatchNorm2d(64) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | self.layer1 = self._make_layer(block, 64, layers[0]) 132 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 133 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 134 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 135 | self.p_ims1d2_outl1_dimred = conv3x3(2048, 512, bias=False) 136 | self.adapt_stage1_b = self._make_rcu(512, 512, 2, 2) 137 | self.mflow_conv_g1_pool = self._make_crp(512, 512, 4) 138 | self.mflow_conv_g1_b = self._make_rcu(512, 512, 3, 2) 139 | self.mflow_conv_g1_b3_joint_varout_dimred = conv3x3(512, 256, bias=False) 140 | self.p_ims1d2_outl2_dimred = conv3x3(1024, 256, bias=False) 141 | self.adapt_stage2_b = self._make_rcu(256, 256, 2, 2) 142 | self.adapt_stage2_b2_joint_varout_dimred = conv3x3(256, 256, bias=False) 143 | self.mflow_conv_g2_pool = self._make_crp(256, 256, 4) 144 | self.mflow_conv_g2_b = self._make_rcu(256, 256, 3, 2) 145 | self.mflow_conv_g2_b3_joint_varout_dimred = conv3x3(256, 256, bias=False) 146 | 147 | self.p_ims1d2_outl3_dimred = conv3x3(512, 256, bias=False) 148 | self.adapt_stage3_b = self._make_rcu(256, 256, 2, 2) 149 | self.adapt_stage3_b2_joint_varout_dimred = conv3x3(256, 256, bias=False) 150 | self.mflow_conv_g3_pool = self._make_crp(256, 256, 4) 151 | self.mflow_conv_g3_b = self._make_rcu(256, 256, 3, 2) 152 | self.mflow_conv_g3_b3_joint_varout_dimred = conv3x3(256, 256, bias=False) 153 | 154 | self.p_ims1d2_outl4_dimred = conv3x3(256, 256, bias=False) 155 | self.adapt_stage4_b = self._make_rcu(256, 256, 2, 2) 156 | self.adapt_stage4_b2_joint_varout_dimred = conv3x3(256, 256, bias=False) 157 | self.mflow_conv_g4_pool = self._make_crp(256, 256, 4) 158 | self.mflow_conv_g4_b = self._make_rcu(256, 256, 3, 2) 159 | 160 | self.clf_conv = nn.Conv2d(256, num_classes, kernel_size=3, stride=1, 161 | padding=1, bias=True) 162 | 163 | def _make_crp(self, in_planes, out_planes, stages): 164 | layers = [CRPBlock(in_planes, out_planes,stages)] 165 | return nn.Sequential(*layers) 166 | 167 | def _make_rcu(self, in_planes, out_planes, blocks, stages): 168 | layers = [RCUBlock(in_planes, out_planes, blocks, stages)] 169 | return nn.Sequential(*layers) 170 | 171 | def _make_layer(self, block, planes, blocks, stride=1): 172 | downsample = None 173 | if stride != 1 or self.inplanes != planes * block.expansion: 174 | downsample = nn.Sequential( 175 | nn.Conv2d(self.inplanes, planes * block.expansion, 176 | kernel_size=1, stride=stride, bias=False), 177 | nn.BatchNorm2d(planes * block.expansion), 178 | ) 179 | 180 | layers = [] 181 | layers.append(block(self.inplanes, planes, stride, downsample)) 182 | self.inplanes = planes * block.expansion 183 | for i in range(1, blocks): 184 | layers.append(block(self.inplanes, planes)) 185 | 186 | return nn.Sequential(*layers) 187 | 188 | def forward(self, x): 189 | x = self.conv1(x) 190 | x = self.bn1(x) 191 | x = self.relu(x) 192 | x = self.maxpool(x) 193 | 194 | l1 = self.layer1(x) 195 | l2 = self.layer2(l1) 196 | l3 = self.layer3(l2) 197 | l4 = self.layer4(l3) 198 | 199 | l4 = self.do(l4) 200 | l3 = self.do(l3) 201 | 202 | x4 = self.p_ims1d2_outl1_dimred(l4) 203 | x4 = self.adapt_stage1_b(x4) 204 | x4 = self.relu(x4) 205 | x4 = self.mflow_conv_g1_pool(x4) 206 | x4 = self.mflow_conv_g1_b(x4) 207 | x4 = self.mflow_conv_g1_b3_joint_varout_dimred(x4) 208 | x4 = nn.Upsample(size=l3.size()[2:], mode='bilinear', align_corners=True)(x4) 209 | 210 | x3 = self.p_ims1d2_outl2_dimred(l3) 211 | x3 = self.adapt_stage2_b(x3) 212 | x3 = self.adapt_stage2_b2_joint_varout_dimred(x3) 213 | x3 = x3 + x4 214 | x3 = F.relu(x3) 215 | x3 = self.mflow_conv_g2_pool(x3) 216 | x3 = self.mflow_conv_g2_b(x3) 217 | x3 = self.mflow_conv_g2_b3_joint_varout_dimred(x3) 218 | x3 = nn.Upsample(size=l2.size()[2:], mode='bilinear', align_corners=True)(x3) 219 | 220 | x2 = self.p_ims1d2_outl3_dimred(l2) 221 | x2 = self.adapt_stage3_b(x2) 222 | x2 = self.adapt_stage3_b2_joint_varout_dimred(x2) 223 | x2 = x2 + x3 224 | x2 = F.relu(x2) 225 | x2 = self.mflow_conv_g3_pool(x2) 226 | x2 = self.mflow_conv_g3_b(x2) 227 | x2 = self.mflow_conv_g3_b3_joint_varout_dimred(x2) 228 | x2 = nn.Upsample(size=l1.size()[2:], mode='bilinear', align_corners=True)(x2) 229 | 230 | x1 = self.p_ims1d2_outl4_dimred(l1) 231 | x1 = self.adapt_stage4_b(x1) 232 | x1 = self.adapt_stage4_b2_joint_varout_dimred(x1) 233 | x1 = x1 + x2 234 | x1 = F.relu(x1) 235 | x1 = self.mflow_conv_g4_pool(x1) 236 | x1 = self.mflow_conv_g4_b(x1) 237 | x1 = self.do(x1) 238 | 239 | out = self.clf_conv(x1) 240 | return out 241 | 242 | 243 | def rf101(num_classes, imagenet=False, pretrained=True, **kwargs): 244 | model = RefineNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, **kwargs) 245 | if imagenet: 246 | key = '101_imagenet' 247 | url = models_urls[key] 248 | model.load_state_dict(maybe_download(key, url), strict=False) 249 | elif pretrained: 250 | dataset = data_info.get(num_classes, None) 251 | if dataset: 252 | bname = '101_' + dataset.lower() 253 | key = 'rf' + bname 254 | url = models_urls[bname] 255 | model.load_state_dict(maybe_download(key, url), strict=False) 256 | return model -------------------------------------------------------------------------------- /requirements3.txt: -------------------------------------------------------------------------------- 1 | ipykernel==4.8.2 2 | ipython==6.4.0 3 | ipython-genutils==0.2.0 4 | jupyter==1.0.0 5 | jupyter-client==5.2.3 6 | jupyter-console==5.2.0 7 | jupyter-core==4.4.0 8 | matplotlib==2.2.2 9 | notebook>=5.7.2 10 | numpy==1.14.5 11 | opencv-python==3.4.1.15 12 | Pillow==6.2.0 13 | scikit-image==0.14.0 14 | six==1.11.0 15 | torch==0.4.1 16 | tqdm==4.23.4 17 | urllib3==1.24.2 18 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/utils/__init__.py -------------------------------------------------------------------------------- /utils/cmap.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/utils/cmap.npy -------------------------------------------------------------------------------- /utils/cs_cmap.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DrSleep/refinenet-pytorch/8f25c076016e61a835551493aae303e81cf36c53/utils/cs_cmap.npy -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | IMG_SCALE = 1./255 5 | IMG_MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) 6 | IMG_STD = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) 7 | 8 | def maybe_download(model_name, model_url, model_dir=None, map_location=None): 9 | import os, sys 10 | from six.moves import urllib 11 | if model_dir is None: 12 | torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch')) 13 | model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models')) 14 | if not os.path.exists(model_dir): 15 | os.makedirs(model_dir) 16 | filename = '{}.pth.tar'.format(model_name) 17 | cached_file = os.path.join(model_dir, filename) 18 | if not os.path.exists(cached_file): 19 | url = model_url 20 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 21 | urllib.request.urlretrieve(url, cached_file) 22 | return torch.load(cached_file, map_location=map_location) 23 | 24 | def prepare_img(img): 25 | return (img * IMG_SCALE - IMG_MEAN) / IMG_STD -------------------------------------------------------------------------------- /utils/layer_factory.py: -------------------------------------------------------------------------------- 1 | """RefineNet-CRP-RCU-blocks in PyTorch 2 | 3 | RefineNet-PyTorch for non-commercial purposes 4 | 5 | Copyright (c) 2018, Vladimir Nekrasov (vladimir.nekrasov@adelaide.edu.au) 6 | All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted provided that the following conditions are met: 10 | 11 | * Redistributions of source code must retain the above copyright notice, this 12 | list of conditions and the following disclaimer. 13 | 14 | * Redistributions in binary form must reproduce the above copyright notice, 15 | this list of conditions and the following disclaimer in the documentation 16 | and/or other materials provided with the distribution. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | """ 29 | 30 | import torch.nn as nn 31 | import torch.nn.functional as F 32 | 33 | def batchnorm(in_planes): 34 | "batch norm 2d" 35 | return nn.BatchNorm2d(in_planes, affine=True, eps=1e-5, momentum=0.1) 36 | 37 | def conv3x3(in_planes, out_planes, stride=1, bias=False): 38 | "3x3 convolution with padding" 39 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 40 | padding=1, bias=bias) 41 | 42 | def conv1x1(in_planes, out_planes, stride=1, bias=False): 43 | "1x1 convolution" 44 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 45 | padding=0, bias=bias) 46 | 47 | def convbnrelu(in_planes, out_planes, kernel_size, stride=1, groups=1, act=True): 48 | "conv-batchnorm-relu" 49 | if act: 50 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size, stride=stride, padding=int(kernel_size / 2.), groups=groups, bias=False), 51 | batchnorm(out_planes), 52 | nn.ReLU6(inplace=True)) 53 | else: 54 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size, stride=stride, padding=int(kernel_size / 2.), groups=groups, bias=False), 55 | batchnorm(out_planes)) 56 | 57 | class CRPBlock(nn.Module): 58 | 59 | def __init__(self, in_planes, out_planes, n_stages): 60 | super(CRPBlock, self).__init__() 61 | for i in range(n_stages): 62 | setattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'), 63 | conv3x3(in_planes if (i == 0) else out_planes, 64 | out_planes, stride=1, 65 | bias=False)) 66 | self.stride = 1 67 | self.n_stages = n_stages 68 | self.maxpool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2) 69 | 70 | def forward(self, x): 71 | top = x 72 | for i in range(self.n_stages): 73 | top = self.maxpool(top) 74 | top = getattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'))(top) 75 | x = top + x 76 | return x 77 | 78 | stages_suffixes = {0 : '_conv', 79 | 1 : '_conv_relu_varout_dimred'} 80 | 81 | class RCUBlock(nn.Module): 82 | 83 | def __init__(self, in_planes, out_planes, n_blocks, n_stages): 84 | super(RCUBlock, self).__init__() 85 | for i in range(n_blocks): 86 | for j in range(n_stages): 87 | setattr(self, '{}{}'.format(i + 1, stages_suffixes[j]), 88 | conv3x3(in_planes if (i == 0) and (j == 0) else out_planes, 89 | out_planes, stride=1, 90 | bias=(j == 0))) 91 | self.stride = 1 92 | self.n_blocks = n_blocks 93 | self.n_stages = n_stages 94 | 95 | def forward(self, x): 96 | for i in range(self.n_blocks): 97 | residual = x 98 | for j in range(self.n_stages): 99 | x = F.relu(x) 100 | x = getattr(self, '{}{}'.format(i + 1, stages_suffixes[j]))(x) 101 | x += residual 102 | return x 103 | 104 | 105 | --------------------------------------------------------------------------------