├── .gitignore ├── LICENSE ├── README.md ├── benchmark.py ├── convcrf ├── __init__.py └── convcrf.py ├── data ├── .directory ├── 2007_000033_0img.png ├── 2007_000033_5labels.png ├── 2007_000129_0img.png ├── 2007_000129_5labels.png ├── 2007_000332_0img.png ├── 2007_000332_5labels.png ├── 2007_000346_0img.png ├── 2007_000346_5labels.png ├── 2007_000847_0img.png ├── 2007_000847_5labels.png ├── 2007_001284_0img.png ├── 2007_001284_5labels.png ├── 2007_001288_0img.png ├── 2007_001288_5labels.png └── output │ ├── Res1.png │ ├── Res2.pdf │ ├── Res2.png │ └── Res_1.png ├── demo.py ├── fullcrf ├── __init__.py └── fullcrf.py ├── requirements.txt ├── setup.py └── utils ├── __init__.py ├── pascal_visualizer.py ├── synthetic.py ├── test_utils.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | out.png 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Marvin Teichmann 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ConvCRF 2 | ======== 3 | This repository contains the reference implementation for our proposed [Convolutional CRFs][4] in PyTorch (Tensorflow planned). The two main entry-points are [demo.py](demo.py) and [benchmark.py](benchmark.py). Demo.py performs ConvCRF inference on a single input image, while benchmark.py compares ConvCRF with FullCRF. Both scripts output plots similar to the one shown below. 4 | 5 | ![Example Output](data/output/Res2.png) 6 | 7 | Requirements 8 | ------------- 9 | 10 | **Platform**: *Linux, python3 >= 3.4 (or python2 >= 2.7), pytorch 0.4 (or pytorch 0.3 + pyinn), cuda, cudnn* 11 | 12 | **Python Packages**: *numpy, imageio, cython, scikit-image, matplotlib* 13 | 14 | To install those python packages run `pip install -r requirements.txt` or `pip install numpy imageio cython scikit-image matplotlib`. I recommend using a [python virtualenv][1]. 15 | 16 | ### Optional Packages: pyinn, pydensecrf 17 | 18 | [**Pydensecrf**][2] is required to run FullCRF, which is only needed for the benchmark. To install pydensecrf, follow the instructions [here][2] or simply run `pip install git+https://github.com/lucasb-eyer/pydensecrf.git`. **Warning:** Running `pip install git+` downloads and installs external code from the internet. 19 | 20 | [**PyINN**][3] allows us to write native cuda operations and compile them on-the-fly at runtime. PyINN is used for our initial ConvCRF implementation and is required for PyTorch 0.3 users. PyTorch 0.4 introduces an Im2Col layer, making it possible to implement ConvCRFs entirely in PyTorch; PyINN can still be used as an alternative backend. Run `pip install git+https://github.com/szagoruyko/pyinn.git@master` to install PyINN. 21 | 22 | 23 | Execute 24 | -------- 25 | 26 | **Demo**: Run `python demo.py data/2007_001288_0img.png data/2007_001288_5labels.png` to perform ConvCRF inference on a single image. Try `python demo.py --help` to see more options. 27 | 28 | **Benchmark**: Run `python benchmark.py data/2007_001288_0img.png data/2007_001288_5labels.png` to compare the performance of ConvCRFs to FullCRFs. This script will also tell you how much faster ConvCRFs are. On my system ConvCRF7 is more than **40** and ConvCRF5 more than **60** times faster. 29 | 30 | 31 | Citation 32 | -------- 33 | If you benefit from this project, please consider citing our [paper][4].
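
Library usage
--------------

The scripts above wrap a small module-level API. The following minimal sketch mirrors what demo.py does internally; it assumes a CUDA-capable GPU and a 21-class Pascal VOC setup, and the random unary is only a stand-in for real network logits:

```python
import imageio
import numpy as np
import torch

from convcrf import convcrf

# Load an input image; the unary is a random placeholder here -- in practice
# use your network's unnormalized logits or utils.synthetic.augment_label,
# exactly as demo.py does.
image = imageio.imread('data/2007_001288_0img.png')   # uint8, shape [H, W, 3]
num_classes = 21
shape = image.shape[:2]
unary = np.random.rand(shape[0], shape[1], num_classes).astype(np.float32)

config = convcrf.get_default_conf()
config['filter_size'] = 7

# ConvCRF expects NCHW tensors: image [1, 3, H, W] and unary [1, C, H, W].
img = image.transpose(2, 0, 1).reshape(1, 3, shape[0], shape[1])
un = unary.transpose(2, 0, 1).reshape(1, num_classes, shape[0], shape[1])
img_var = torch.Tensor(img).cuda()
unary_var = torch.Tensor(un).cuda()

crf = convcrf.GaussCRF(conf=config, shape=shape, nclasses=num_classes,
                       use_gpu=True)
crf.cuda()

prediction = crf.forward(unary=unary_var, img=img_var)        # [1, C, H, W]
labels = np.argmax(prediction.data.cpu().numpy()[0], axis=0)  # hard labels [H, W]
```

Pass `use_gpu=False` and drop the `.cuda()` calls to run on the CPU (this is what the `--cpu` flag of the scripts does).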
34 | 35 | TODO 36 | ----- 37 | 38 | - [x] Build a native PyTorch 0.4 implementation of ConvCRF 39 | - [x] Provide python 2 implementation 40 | - [ ] Build a Tensorflow implementation of ConvCRF 41 | 42 | 43 | 44 | [1]: https://virtualenvwrapper.readthedocs.io/en/latest/ 45 | [2]: https://github.com/lucasb-eyer/pydensecrf 46 | [3]: https://github.com/szagoruyko/pyinn 47 | [4]: https://arxiv.org/abs/1805.04777 48 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2017 Marvin Teichmann 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import sys 13 | 14 | import numpy as np 15 | import imageio 16 | # import scipy as scp 17 | # import scipy.misc 18 | 19 | import argparse 20 | 21 | import logging 22 | 23 | from convcrf import convcrf 24 | from fullcrf import fullcrf 25 | 26 | import torch 27 | from torch.autograd import Variable 28 | 29 | from utils import pascal_visualizer as vis 30 | from utils import synthetic 31 | 32 | import time 33 | 34 | try: 35 | import matplotlib.pyplot as plt 36 | matplotlib = True 37 | figure = plt.figure() 38 | plt.close(figure) 39 | except: 40 | matplotlib = False 41 | pass 42 | 43 | logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', 44 | level=logging.INFO, 45 | stream=sys.stdout) 46 | 47 | 48 | def do_crf_inference(image, unary, args): 49 | 50 | if args.pyinn or not hasattr(torch.nn.functional, 'unfold'): 51 | # pytorch 0.3 or older requires pyinn. 52 | args.pyinn = True 53 | # Cheap and easy trick to make sure that pyinn is loadable. 54 | import pyinn 55 | 56 | # get basic hyperparameters 57 | num_classes = unary.shape[2] 58 | shape = image.shape[0:2] 59 | config = convcrf.default_conf 60 | config['filter_size'] = 7 61 | config['pyinn'] = args.pyinn 62 | 63 | if args.normalize: 64 | # Warning, applying image normalization affects CRF computation. 65 | # The parameter 'col_feats::schan' needs to be adapted. 66 | 67 | # Normalize image range 68 | # This changes the image features and influences CRF output 69 | image = image / 255 70 | # mean substraction 71 | # CRF is invariant to mean subtraction, output is NOT affected 72 | image = image - 0.5 73 | # std normalization 74 | # Affect CRF computation 75 | image = image / 0.3 76 | 77 | # schan = 0.1 is a good starting value for normalized images. 
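# Worked example of the pipeline above: a pixel value of 204 becomes
# 204 / 255 = 0.8, then 0.8 - 0.5 = 0.3, and finally 0.3 / 0.3 = 1.0,
# so the normalized features roughly span [-1.67, 1.67].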
78 | # The relation is f_i = image / schan 79 | config['col_feats']['schan'] = 0.1 80 | 81 | # make input pytorch compatible 82 | img = image.transpose(2, 0, 1) # shape: [3, hight, width] 83 | # Add batch dimension to image: [1, 3, height, width] 84 | img = img.reshape([1, 3, shape[0], shape[1]]) 85 | img_var = Variable(torch.Tensor(img)) 86 | 87 | un = unary.transpose(2, 0, 1) # shape: [3, hight, width] 88 | # Add batch dimension to unary: [1, 21, height, width] 89 | un = un.reshape([1, num_classes, shape[0], shape[1]]) 90 | unary_var = Variable(torch.Tensor(un)) 91 | 92 | logging.debug("Build ConvCRF.") 93 | ## 94 | # Create CRF module 95 | gausscrf = convcrf.GaussCRF(conf=config, shape=shape, nclasses=num_classes, 96 | use_gpu=not args.cpu) 97 | 98 | # move to GPU if requested 99 | if not args.cpu: 100 | img_var = img_var.cuda() 101 | unary_var = unary_var.cuda() 102 | gausscrf.cuda() 103 | 104 | 105 | # Perform ConvCRF inference 106 | """ 107 | 'Warm up': Our implementation compiles cuda kernels during runtime. 108 | The first inference call thus comes with some overhead. 109 | """ 110 | logging.info("Start Computation.") 111 | prediction = gausscrf.forward(unary=unary_var, img=img_var) 112 | 113 | if args.nospeed: 114 | 115 | logging.info("Doing speed benchmark with filter size: {}" 116 | .format(config['filter_size'])) 117 | logging.info("Running multiple iteration. This may take a while.") 118 | 119 | # Our implementation compiles cuda kernels during runtime. 120 | # The first inference run is those much slower. 121 | # prediction = gausscrf.forward(unary=unary_var, img=img_var) 122 | 123 | start_time = time.time() 124 | for i in range(10): 125 | # Running ConvCRF 10 times and report average total time 126 | prediction = gausscrf.forward(unary=unary_var, img=img_var) 127 | 128 | prediction.cpu() # wait for all GPU computations to finish 129 | duration = (time.time() - start_time) * 1000 / 10 130 | 131 | logging.debug("Finished running 10 predictions.") 132 | logging.debug("Avg Computation time: {} ms".format(duration)) 133 | 134 | # Perform FullCRF inference 135 | myfullcrf = fullcrf.FullCRF(config, shape, num_classes) 136 | fullprediction = myfullcrf.compute(unary, image, softmax=False) 137 | 138 | if args.nospeed: 139 | 140 | start_time = time.time() 141 | for i in range(5): 142 | # Running FullCRF 5 times and report average total time 143 | fullprediction = myfullcrf.compute(unary, image, softmax=False) 144 | 145 | fullduration = (time.time() - start_time) * 1000 / 5 146 | 147 | logging.debug("Finished running 5 predictions.") 148 | logging.debug("Avg Computation time: {} ms".format(fullduration)) 149 | 150 | logging.info("Using FullCRF took {:4.0f} ms ({:2.2f} s)".format( 151 | fullduration, fullduration / 1000)) 152 | 153 | logging.info("Using ConvCRF took {:4.0f} ms ({:2.2f} s)".format( 154 | duration, duration / 1000)) 155 | 156 | logging.info("Congratulation. 
Using ConvCRF provids a speed-up" 157 | " of {:.0f}.".format(fullduration / duration)) 158 | 159 | logging.info("") 160 | 161 | return prediction.data.cpu().numpy(), fullprediction 162 | 163 | 164 | def plot_results(image, unary, conv_out, full_out, label, args): 165 | 166 | logging.debug("Plot results.") 167 | 168 | # Create visualizer 169 | myvis = vis.PascalVisualizer() 170 | 171 | # Transform id image to coloured labels 172 | coloured_label = myvis.id2color(id_image=label) 173 | 174 | unary_hard = np.argmax(unary, axis=2) 175 | coloured_unary = myvis.id2color(id_image=unary_hard) 176 | 177 | conv_out = conv_out[0] # Remove Batch dimension 178 | conv_hard = np.argmax(conv_out, axis=0) 179 | coloured_conv = myvis.id2color(id_image=conv_hard) 180 | 181 | full_hard = np.argmax(full_out, axis=2) 182 | coloured_full = myvis.id2color(id_image=full_hard) 183 | 184 | if matplotlib: 185 | # Plot results using matplotlib 186 | figure = plt.figure() 187 | figure.tight_layout() 188 | # Plot parameters 189 | num_rows = 2 190 | num_cols = 3 191 | off = 0 192 | 193 | ax = figure.add_subplot(num_rows, num_cols, 1) 194 | # img_name = os.path.basename(args.image) 195 | ax.set_title('Image ') 196 | ax.axis('off') 197 | ax.imshow(image) 198 | 199 | ax = figure.add_subplot(num_rows, num_cols, 2) 200 | ax.set_title('Label') 201 | ax.axis('off') 202 | ax.imshow(coloured_label.astype(np.uint8)) 203 | 204 | ax = figure.add_subplot(num_rows, num_cols, 3 - off) 205 | ax.set_title('Unary') 206 | ax.axis('off') 207 | ax.imshow(coloured_unary.astype(np.uint8)) 208 | 209 | ax = figure.add_subplot(num_rows, num_cols, 4 - off) 210 | ax.set_title('ConvCRF Output') 211 | ax.axis('off') 212 | ax.imshow(coloured_conv.astype(np.uint8)) 213 | 214 | ax = figure.add_subplot(num_rows, num_cols, 5 - off) 215 | ax.set_title('FullCRF Output') 216 | ax.axis('off') 217 | ax.imshow(coloured_full.astype(np.uint8)) 218 | 219 | # plt.subplots_adjust(left=0.02, right=0.98, 220 | # wspace=0.15, hspace=0.15) 221 | 222 | plt.show() 223 | else: 224 | if args.output is None: 225 | args.output = "out.png" 226 | 227 | logging.warning("Matplotlib not found.") 228 | logging.info("Saving output to {} instead".format(args.output)) 229 | 230 | if args.output is not None: 231 | # Save results to disk 232 | out_img = np.concatenate( 233 | (image, coloured_label, coloured_unary, coloured_conv), 234 | axis=1) 235 | 236 | imageio.imwrite(args.output, out_img.astype(np.uint8)) 237 | 238 | logging.info("Plot has been saved to {}".format(args.output)) 239 | 240 | return 241 | 242 | 243 | def get_parser(): 244 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 245 | parser = ArgumentParser(description=__doc__, 246 | formatter_class=ArgumentDefaultsHelpFormatter) 247 | 248 | parser.add_argument("image", type=str, 249 | help="input image") 250 | 251 | parser.add_argument("label", type=str, 252 | help="Label file.") 253 | 254 | parser.add_argument("--gpu", type=str, default='0', 255 | help="which gpu to use") 256 | 257 | parser.add_argument('--output', type=str, 258 | help="Optionally save output as img.") 259 | 260 | parser.add_argument('--nospeed', action='store_false', 261 | help="Skip speed evaluation.") 262 | 263 | parser.add_argument('--normalize', action='store_true', 264 | help="Normalize input image before inference.") 265 | 266 | parser.add_argument('--pyinn', action='store_true', 267 | help="Use pyinn based Cuda implementation" 268 | "for message passing.") 269 | 270 | parser.add_argument('--cpu', action='store_true', 271 | 
help="Run on CPU instead of GPU.") 272 | 273 | return parser 274 | 275 | 276 | if __name__ == '__main__': 277 | parser = get_parser() 278 | args = parser.parse_args() 279 | 280 | # Load data 281 | image = imageio.imread(args.image) 282 | label = imageio.imread(args.label) 283 | 284 | # Produce unary by adding noise to label 285 | unary = synthetic.augment_label(label, num_classes=21) 286 | # Compute CRF inference 287 | 288 | conv_out, full_out = do_crf_inference(image, unary, args) 289 | plot_results(image, unary, conv_out, full_out, label, args) 290 | logging.info("Thank you for trying ConvCRFs.") 291 | -------------------------------------------------------------------------------- /convcrf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/convcrf/__init__.py -------------------------------------------------------------------------------- /convcrf/convcrf.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2017 Marvin Teichmann 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import sys 13 | 14 | import numpy as np 15 | import scipy as scp 16 | import math 17 | 18 | import logging 19 | import warnings 20 | 21 | logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', 22 | level=logging.INFO, 23 | stream=sys.stdout) 24 | 25 | try: 26 | import pyinn as P 27 | has_pyinn = True 28 | except ImportError: 29 | # PyInn is required to use our cuda based message-passing implementation 30 | # Torch 0.4 provides a im2col operation, which will be used instead. 31 | # It is ~15% slower. 32 | has_pyinn = False 33 | pass 34 | 35 | from utils import test_utils 36 | 37 | import torch 38 | import torch.nn as nn 39 | from torch.nn import functional as nnfun 40 | from torch.autograd import Variable 41 | from torch.nn.parameter import Parameter 42 | 43 | import torch.nn.functional as F 44 | 45 | import gc 46 | 47 | 48 | # Default config as proposed by Philipp Kraehenbuehl and Vladlen Koltun, 49 | default_conf = { 50 | 'filter_size': 11, 51 | 'blur': 4, 52 | 'merge': True, 53 | 'norm': 'none', 54 | 'weight': 'vector', 55 | "unary_weight": 1, 56 | "weight_init": 0.2, 57 | 58 | 'trainable': False, 59 | 'convcomp': False, 60 | 'logsoftmax': True, # use logsoftmax for numerical stability 61 | 'softmax': True, 62 | 'final_softmax': False, 63 | 64 | 'pos_feats': { 65 | 'sdims': 3, 66 | 'compat': 3, 67 | }, 68 | 'col_feats': { 69 | 'sdims': 80, 70 | 'schan': 13, # schan depend on the input scale. 
71 | # use schan = 13 for images in [0, 255] 72 | # for normalized images in [-0.5, 0.5] try schan = 0.1 73 | 'compat': 10, 74 | 'use_bias': False 75 | }, 76 | "trainable_bias": False, 77 | 78 | "pyinn": False 79 | } 80 | 81 | # Config used for test cases on 10 x 10 pixel greyscale inpu 82 | test_config = { 83 | 'filter_size': 5, 84 | 'blur': 1, 85 | 'merge': False, 86 | 'norm': 'sym', 87 | 'trainable': False, 88 | 'weight': 'scalar', 89 | "unary_weight": 1, 90 | "weight_init": 0.5, 91 | 'convcomp': False, 92 | 93 | 'trainable': False, 94 | 'convcomp': False, 95 | "logsoftmax": True, # use logsoftmax for numerical stability 96 | "softmax": True, 97 | 98 | 'pos_feats': { 99 | 'sdims': 1.5, 100 | 'compat': 3, 101 | }, 102 | 103 | 'col_feats': { 104 | 'sdims': 2, 105 | 'schan': 2, 106 | 'compat': 3, 107 | 'use_bias': True 108 | }, 109 | "trainable_bias": False, 110 | } 111 | 112 | 113 | class GaussCRF(nn.Module): 114 | """ Implements ConvCRF with hand-crafted features. 115 | 116 | It uses the more generic ConvCRF class as basis and utilizes a config 117 | dict to easily set hyperparameters and follows the design choices of: 118 | Philipp Kraehenbuehl and Vladlen Koltun, "Efficient Inference in Fully 119 | "Connected CRFs with Gaussian Edge Pots" (arxiv.org/abs/1210.5644) 120 | """ 121 | 122 | def __init__(self, conf, shape, nclasses=None, use_gpu=True): 123 | super(GaussCRF, self).__init__() 124 | 125 | self.conf = conf 126 | self.shape = shape 127 | self.nclasses = nclasses 128 | 129 | self.trainable = conf['trainable'] 130 | 131 | if not conf['trainable_bias']: 132 | self.register_buffer('mesh', self._create_mesh()) 133 | else: 134 | self.register_parameter('mesh', Parameter(self._create_mesh())) 135 | 136 | if self.trainable: 137 | def register(name, tensor): 138 | self.register_parameter(name, Parameter(tensor)) 139 | else: 140 | def register(name, tensor): 141 | self.register_buffer(name, Variable(tensor)) 142 | 143 | register('pos_sdims', torch.Tensor([1 / conf['pos_feats']['sdims']])) 144 | 145 | if conf['col_feats']['use_bias']: 146 | register('col_sdims', 147 | torch.Tensor([1 / conf['col_feats']['sdims']])) 148 | else: 149 | self.col_sdims = None 150 | 151 | register('col_schan', torch.Tensor([1 / conf['col_feats']['schan']])) 152 | register('col_compat', torch.Tensor([conf['col_feats']['compat']])) 153 | register('pos_compat', torch.Tensor([conf['pos_feats']['compat']])) 154 | 155 | if conf['weight'] is None: 156 | weight = None 157 | elif conf['weight'] == 'scalar': 158 | val = conf['weight_init'] 159 | weight = torch.Tensor([val]) 160 | elif conf['weight'] == 'vector': 161 | val = conf['weight_init'] 162 | weight = val * torch.ones(1, nclasses, 1, 1) 163 | 164 | self.CRF = ConvCRF( 165 | shape, nclasses, mode="col", conf=conf, 166 | use_gpu=use_gpu, filter_size=conf['filter_size'], 167 | norm=conf['norm'], blur=conf['blur'], trainable=conf['trainable'], 168 | convcomp=conf['convcomp'], weight=weight, 169 | final_softmax=conf['final_softmax'], 170 | unary_weight=conf['unary_weight'], 171 | pyinn=conf['pyinn']) 172 | 173 | return 174 | 175 | def forward(self, unary, img, num_iter=5): 176 | """ Run a forward pass through ConvCRF. 177 | 178 | Arguments: 179 | unary: torch.Tensor with shape [bs, num_classes, height, width]. 180 | The unary predictions. Logsoftmax is applied to the unaries 181 | during inference. When using CNNs don't apply softmax, 182 | use unnormalized output (logits) instead. 183 | 184 | img: torch.Tensor with shape [bs, 3, height, width] 185 | The input image. 
Default config assumes image 186 | data in [0, 255]. For normalized images adapt 187 | `schan`. Try schan = 0.1 for images in [-0.5, 0.5] 188 | """ 189 | 190 | conf = self.conf 191 | 192 | bs, c, x, y = img.shape 193 | 194 | pos_feats = self.create_position_feats(sdims=self.pos_sdims, bs=bs) 195 | col_feats = self.create_colour_feats( 196 | img, sdims=self.col_sdims, schan=self.col_schan, 197 | bias=conf['col_feats']['use_bias'], bs=bs) 198 | 199 | compats = [self.pos_compat, self.col_compat] 200 | 201 | self.CRF.add_pairwise_energies([pos_feats, col_feats], 202 | compats, conf['merge']) 203 | 204 | prediction = self.CRF.inference(unary, num_iter=num_iter) 205 | 206 | self.CRF.clean_filters() 207 | return prediction 208 | 209 | def _create_mesh(self, requires_grad=False): 210 | hcord_range = [range(s) for s in self.shape] 211 | mesh = np.array(np.meshgrid(*hcord_range, indexing='ij'), 212 | dtype=np.float32) 213 | 214 | return torch.from_numpy(mesh) 215 | 216 | def create_colour_feats(self, img, schan, sdims=0.0, bias=True, bs=1): 217 | norm_img = img * schan 218 | 219 | if bias: 220 | norm_mesh = self.create_position_feats(sdims=sdims, bs=bs) 221 | feats = torch.cat([norm_mesh, norm_img], dim=1) 222 | else: 223 | feats = norm_img 224 | return feats 225 | 226 | def create_position_feats(self, sdims, bs=1): 227 | if type(self.mesh) is Parameter: 228 | return torch.stack(bs * [self.mesh * sdims]) 229 | else: 230 | return torch.stack(bs * [Variable(self.mesh) * sdims]) 231 | 232 | 233 | def show_memusage(device=0, name=""): 234 | import gpustat 235 | gc.collect() 236 | gpu_stats = gpustat.GPUStatCollection.new_query() 237 | item = gpu_stats.jsonify()["gpus"][device] 238 | 239 | logging.info("{:>5}/{:>5} MB Usage at {}".format( 240 | item["memory.used"], item["memory.total"], name)) 241 | 242 | 243 | def exp_and_normalize(features, dim=0): 244 | """ 245 | Aka "softmax" in deep learning literature 246 | """ 247 | normalized = torch.nn.functional.softmax(features, dim=dim) 248 | return normalized 249 | 250 | 251 | def _get_ind(dz): 252 | if dz == 0: 253 | return 0, 0 254 | if dz < 0: 255 | return 0, -dz 256 | if dz > 0: 257 | return dz, 0 258 | 259 | 260 | def _negative(dz): 261 | """ 262 | Computes -dz for numpy indexing. Goal is to use as in array[i:-dz]. 263 | 264 | However, if dz=0 this indexing does not work. 265 | None needs to be used instead. 266 | """ 267 | if dz == 0: 268 | return None 269 | else: 270 | return -dz 271 | 272 | 273 | class MessagePassingCol(): 274 | """ Perform the Message passing of ConvCRFs. 275 | 276 | The main magic happens here. 
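In contrast to FullCRF, pairwise interactions are restricted to a local filter_size x filter_size window. The Gaussian pairwise potentials are materialised as one small filter per pixel (see _create_convolutional_filters), and message passing then reduces to an im2col-style product-and-sum over that window (see _compute_gaussian), optionally on a blurred, i.e. downsampled, grid.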
277 | """ 278 | 279 | def __init__(self, feat_list, compat_list, merge, npixels, nclasses, 280 | norm="sym", 281 | filter_size=5, clip_edges=0, use_gpu=False, 282 | blur=1, matmul=False, verbose=False, pyinn=False): 283 | 284 | if not norm == "sym" and not norm == "none": 285 | raise NotImplementedError 286 | 287 | span = filter_size // 2 288 | assert(filter_size % 2 == 1) 289 | self.span = span 290 | self.filter_size = filter_size 291 | self.use_gpu = use_gpu 292 | self.verbose = verbose 293 | self.blur = blur 294 | self.pyinn = pyinn 295 | 296 | self.merge = merge 297 | 298 | self.npixels = npixels 299 | 300 | if not self.blur == 1 and self.blur % 2: 301 | raise NotImplementedError 302 | 303 | self.matmul = matmul 304 | 305 | self._gaus_list = [] 306 | self._norm_list = [] 307 | 308 | for feats, compat in zip(feat_list, compat_list): 309 | gaussian = self._create_convolutional_filters(feats) 310 | if not norm == "none": 311 | mynorm = self._get_norm(gaussian) 312 | self._norm_list.append(mynorm) 313 | else: 314 | self._norm_list.append(None) 315 | 316 | gaussian = compat * gaussian 317 | self._gaus_list.append(gaussian) 318 | 319 | if merge: 320 | self.gaussian = sum(self._gaus_list) 321 | if not norm == 'none': 322 | raise NotImplementedError 323 | 324 | def _get_norm(self, gaus): 325 | norm_tensor = torch.ones([1, 1, self.npixels[0], self.npixels[1]]) 326 | normalization_feats = torch.autograd.Variable(norm_tensor) 327 | if self.use_gpu: 328 | normalization_feats = normalization_feats.cuda() 329 | 330 | norm_out = self._compute_gaussian(normalization_feats, gaussian=gaus) 331 | return 1 / torch.sqrt(norm_out + 1e-20) 332 | 333 | def _create_convolutional_filters(self, features): 334 | 335 | span = self.span 336 | 337 | bs = features.shape[0] 338 | 339 | if self.blur > 1: 340 | off_0 = (self.blur - self.npixels[0] % self.blur) % self.blur 341 | off_1 = (self.blur - self.npixels[1] % self.blur) % self.blur 342 | pad_0 = math.ceil(off_0 / 2) 343 | pad_1 = math.ceil(off_1 / 2) 344 | if self.blur == 2: 345 | assert(pad_0 == self.npixels[0] % 2) 346 | assert(pad_1 == self.npixels[1] % 2) 347 | 348 | features = torch.nn.functional.avg_pool2d(features, 349 | kernel_size=self.blur, 350 | padding=(pad_0, pad_1), 351 | count_include_pad=False) 352 | 353 | npixels = [math.ceil(self.npixels[0] / self.blur), 354 | math.ceil(self.npixels[1] / self.blur)] 355 | assert(npixels[0] == features.shape[2]) 356 | assert(npixels[1] == features.shape[3]) 357 | else: 358 | npixels = self.npixels 359 | 360 | gaussian_tensor = features.data.new( 361 | bs, self.filter_size, self.filter_size, 362 | npixels[0], npixels[1]).fill_(0) 363 | 364 | gaussian = Variable(gaussian_tensor) 365 | 366 | for dx in range(-span, span + 1): 367 | for dy in range(-span, span + 1): 368 | 369 | dx1, dx2 = _get_ind(dx) 370 | dy1, dy2 = _get_ind(dy) 371 | 372 | feat_t = features[:, :, dx1:_negative(dx2), dy1:_negative(dy2)] 373 | feat_t2 = features[:, :, dx2:_negative(dx1), dy2:_negative(dy1)] # NOQA 374 | 375 | diff = feat_t - feat_t2 376 | diff_sq = diff * diff 377 | exp_diff = torch.exp(torch.sum(-0.5 * diff_sq, dim=1)) 378 | 379 | gaussian[:, dx + span, dy + span, 380 | dx2:_negative(dx1), dy2:_negative(dy1)] = exp_diff 381 | 382 | return gaussian.view( 383 | bs, 1, self.filter_size, self.filter_size, 384 | npixels[0], npixels[1]) 385 | 386 | def compute(self, input): 387 | if self.merge: 388 | pred = self._compute_gaussian(input, self.gaussian) 389 | else: 390 | assert(len(self._gaus_list) == len(self._norm_list)) 391 | pred = 0 
392 | for gaus, norm in zip(self._gaus_list, self._norm_list): 393 | pred += self._compute_gaussian(input, gaus, norm) 394 | 395 | return pred 396 | 397 | def _compute_gaussian(self, input, gaussian, norm=None): 398 | 399 | if norm is not None: 400 | input = input * norm 401 | 402 | shape = input.shape 403 | num_channels = shape[1] 404 | bs = shape[0] 405 | 406 | if self.blur > 1: 407 | off_0 = (self.blur - self.npixels[0] % self.blur) % self.blur 408 | off_1 = (self.blur - self.npixels[1] % self.blur) % self.blur 409 | pad_0 = int(math.ceil(off_0 / 2)) 410 | pad_1 = int(math.ceil(off_1 / 2)) 411 | input = torch.nn.functional.avg_pool2d(input, 412 | kernel_size=self.blur, 413 | padding=(pad_0, pad_1), 414 | count_include_pad=False) 415 | npixels = [math.ceil(self.npixels[0] / self.blur), 416 | math.ceil(self.npixels[1] / self.blur)] 417 | assert(npixels[0] == input.shape[2]) 418 | assert(npixels[1] == input.shape[3]) 419 | else: 420 | npixels = self.npixels 421 | 422 | if self.verbose: 423 | show_memusage(name="Init") 424 | 425 | if self.pyinn: 426 | input_col = P.im2col(input, self.filter_size, 1, self.span) 427 | else: 428 | # An alternative implementation of num2col. 429 | # 430 | # This has implementation uses the torch 0.4 im2col operation. 431 | # This implementation was not avaible when we did the experiments 432 | # published in our paper. So less "testing" has been done. 433 | # 434 | # It is around ~20% slower then the pyinn implementation but 435 | # easier to use as it removes a dependency. 436 | input_unfold = F.unfold(input, self.filter_size, 1, self.span) 437 | input_unfold = input_unfold.view( 438 | bs, num_channels, self.filter_size, self.filter_size, 439 | npixels[0], npixels[1]) 440 | input_col = input_unfold 441 | 442 | k_sqr = self.filter_size * self.filter_size 443 | 444 | if self.verbose: 445 | show_memusage(name="Im2Col") 446 | 447 | product = gaussian * input_col 448 | if self.verbose: 449 | show_memusage(name="Product") 450 | 451 | product = product.view([bs, num_channels, 452 | k_sqr, npixels[0], npixels[1]]) 453 | 454 | message = product.sum(2) 455 | 456 | if self.verbose: 457 | show_memusage(name="FinalNorm") 458 | 459 | if self.blur > 1: 460 | in_0 = self.npixels[0] 461 | in_1 = self.npixels[1] 462 | message = message.view(bs, num_channels, npixels[0], npixels[1]) 463 | with warnings.catch_warnings(): 464 | warnings.simplefilter("ignore") 465 | # Suppress warning regarding corner alignment 466 | message = torch.nn.functional.upsample(message, 467 | scale_factor=self.blur, 468 | mode='bilinear') 469 | 470 | message = message[:, :, pad_0:pad_0 + in_0, pad_1:in_1 + pad_1] 471 | message = message.contiguous() 472 | 473 | message = message.view(shape) 474 | assert(message.shape == shape) 475 | 476 | if norm is not None: 477 | message = norm * message 478 | 479 | return message 480 | 481 | 482 | class ConvCRF(nn.Module): 483 | """ 484 | Implements a generic CRF class. 485 | 486 | This class provides tools to build 487 | your own ConvCRF based model. 
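Typical usage mirrors GaussCRF.forward above: call add_pairwise_energies(feat_list, compat_list, merge) to build the message-passing kernels for the current image, run inference(unary, num_iter) to perform the mean-field updates, and call clean_filters() afterwards to release the cached kernels.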
488 | """ 489 | 490 | def __init__(self, npixels, nclasses, conf, 491 | mode="conv", filter_size=5, 492 | clip_edges=0, blur=1, use_gpu=False, 493 | norm='sym', merge=False, 494 | verbose=False, trainable=False, 495 | convcomp=False, weight=None, 496 | final_softmax=True, unary_weight=10, 497 | pyinn=False): 498 | 499 | super(ConvCRF, self).__init__() 500 | self.nclasses = nclasses 501 | 502 | self.filter_size = filter_size 503 | self.clip_edges = clip_edges 504 | self.use_gpu = use_gpu 505 | self.mode = mode 506 | self.norm = norm 507 | self.merge = merge 508 | self.kernel = None 509 | self.verbose = verbose 510 | self.blur = blur 511 | self.final_softmax = final_softmax 512 | self.pyinn = pyinn 513 | 514 | self.conf = conf 515 | 516 | self.unary_weight = unary_weight 517 | 518 | if self.use_gpu: 519 | if not torch.cuda.is_available(): 520 | logging.error("GPU mode requested but not avaible.") 521 | logging.error("Please run using use_gpu=False.") 522 | raise ValueError 523 | 524 | self.npixels = npixels 525 | 526 | if type(npixels) is tuple or type(npixels) is list: 527 | self.height = npixels[0] 528 | self.width = npixels[1] 529 | else: 530 | self.npixels = npixels 531 | 532 | if trainable: 533 | def register(name, tensor): 534 | self.register_parameter(name, Parameter(tensor)) 535 | else: 536 | def register(name, tensor): 537 | self.register_buffer(name, Variable(tensor)) 538 | 539 | if weight is None: 540 | self.weight = None 541 | else: 542 | register('weight', weight) 543 | 544 | if convcomp: 545 | self.comp = nn.Conv2d(nclasses, nclasses, 546 | kernel_size=1, stride=1, padding=0, 547 | bias=False) 548 | 549 | self.comp.weight.data.fill_(0.1 * math.sqrt(2.0 / nclasses)) 550 | else: 551 | self.comp = None 552 | 553 | def clean_filters(self): 554 | self.kernel = None 555 | 556 | def add_pairwise_energies(self, feat_list, compat_list, merge): 557 | assert(len(feat_list) == len(compat_list)) 558 | 559 | self.kernel = MessagePassingCol( 560 | feat_list=feat_list, 561 | compat_list=compat_list, 562 | merge=merge, 563 | npixels=self.npixels, 564 | filter_size=self.filter_size, 565 | nclasses=self.nclasses, 566 | use_gpu=self.use_gpu, 567 | norm=self.norm, 568 | verbose=self.verbose, 569 | blur=self.blur, 570 | pyinn=self.pyinn) 571 | 572 | def inference(self, unary, num_iter=5): 573 | 574 | if not self.conf['logsoftmax']: 575 | lg_unary = torch.log(unary) 576 | prediction = exp_and_normalize(lg_unary, dim=1) 577 | else: 578 | lg_unary = nnfun.log_softmax(unary, dim=1, _stacklevel=5) 579 | if self.conf['softmax'] and False: 580 | prediction = exp_and_normalize(lg_unary, dim=1) 581 | else: 582 | prediction = lg_unary 583 | 584 | for i in range(num_iter): 585 | message = self.kernel.compute(prediction) 586 | 587 | if self.comp is not None: 588 | # message_r = message.view(tuple([1]) + message.shape) 589 | comp = self.comp(message) 590 | message = message + comp 591 | 592 | if self.weight is None: 593 | prediction = lg_unary + message 594 | else: 595 | prediction = (self.unary_weight - self.weight) * lg_unary + \ 596 | self.weight * message 597 | 598 | if not i == num_iter - 1 or self.final_softmax: 599 | if self.conf['softmax']: 600 | prediction = exp_and_normalize(prediction, dim=1) 601 | 602 | return prediction 603 | 604 | def start_inference(self): 605 | pass 606 | 607 | def step_inference(self): 608 | pass 609 | 610 | 611 | def get_test_conf(): 612 | return test_config.copy() 613 | 614 | 615 | def get_default_conf(): 616 | return default_conf.copy() 617 | 618 | if __name__ == "__main__": 
619 | conf = get_test_conf() 620 | tcrf = GaussCRF(conf, [10, 10], None).cuda() 621 | 622 | unary = test_utils._get_simple_unary() 623 | img = test_utils._get_simple_img() 624 | 625 | img = np.transpose(img, [2, 0, 1]) 626 | img_torch = Variable(torch.Tensor(img), requires_grad=False).cuda() 627 | 628 | unary_var = Variable(torch.Tensor(unary)).cuda() 629 | unary_var = unary_var.view(2, 10, 10) 630 | img_var = Variable(torch.Tensor(img)).cuda() 631 | 632 | prediction = tcrf.forward(unary_var, img_var).cpu().data.numpy() 633 | res = np.argmax(prediction, axis=0) 634 | import scipy.misc 635 | scp.misc.imsave("out.png", res) 636 | # d.addPairwiseBilateral(2, 2, img, 3) 637 | -------------------------------------------------------------------------------- /data/.directory: -------------------------------------------------------------------------------- 1 | [Dolphin] 2 | PreviewsShown=true 3 | Timestamp=2018,5,11,21,58,17 4 | Version=3 5 | -------------------------------------------------------------------------------- /data/2007_000033_0img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_000033_0img.png -------------------------------------------------------------------------------- /data/2007_000033_5labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_000033_5labels.png -------------------------------------------------------------------------------- /data/2007_000129_0img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_000129_0img.png -------------------------------------------------------------------------------- /data/2007_000129_5labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_000129_5labels.png -------------------------------------------------------------------------------- /data/2007_000332_0img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_000332_0img.png -------------------------------------------------------------------------------- /data/2007_000332_5labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_000332_5labels.png -------------------------------------------------------------------------------- /data/2007_000346_0img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_000346_0img.png -------------------------------------------------------------------------------- /data/2007_000346_5labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_000346_5labels.png 
-------------------------------------------------------------------------------- /data/2007_000847_0img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_000847_0img.png -------------------------------------------------------------------------------- /data/2007_000847_5labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_000847_5labels.png -------------------------------------------------------------------------------- /data/2007_001284_0img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_001284_0img.png -------------------------------------------------------------------------------- /data/2007_001284_5labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_001284_5labels.png -------------------------------------------------------------------------------- /data/2007_001288_0img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_001288_0img.png -------------------------------------------------------------------------------- /data/2007_001288_5labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/2007_001288_5labels.png -------------------------------------------------------------------------------- /data/output/Res1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/output/Res1.png -------------------------------------------------------------------------------- /data/output/Res2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/output/Res2.pdf -------------------------------------------------------------------------------- /data/output/Res2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/output/Res2.png -------------------------------------------------------------------------------- /data/output/Res_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/data/output/Res_1.png -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2017 Marvin Teichmann 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import sys 13 | 
14 | import numpy as np 15 | import imageio 16 | # import scipy as scp 17 | # import scipy.misc 18 | 19 | import argparse 20 | 21 | import logging 22 | import time 23 | 24 | from convcrf import convcrf 25 | 26 | import torch 27 | from torch.autograd import Variable 28 | 29 | from utils import pascal_visualizer as vis 30 | from utils import synthetic 31 | 32 | try: 33 | import matplotlib.pyplot as plt 34 | figure = plt.figure() 35 | matplotlib = True 36 | plt.close(figure) 37 | except: 38 | matplotlib = False 39 | pass 40 | 41 | logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', 42 | level=logging.INFO, 43 | stream=sys.stdout) 44 | 45 | 46 | def do_crf_inference(image, unary, args): 47 | 48 | if args.pyinn or not hasattr(torch.nn.functional, 'unfold'): 49 | # pytorch 0.3 or older requires pyinn. 50 | args.pyinn = True 51 | # Cheap and easy trick to make sure that pyinn is loadable. 52 | import pyinn 53 | 54 | # get basic hyperparameters 55 | num_classes = unary.shape[2] 56 | shape = image.shape[0:2] 57 | config = convcrf.default_conf 58 | config['filter_size'] = 7 59 | config['pyinn'] = args.pyinn 60 | 61 | if args.normalize: 62 | # Warning, applying image normalization affects CRF computation. 63 | # The parameter 'col_feats::schan' needs to be adapted. 64 | 65 | # Normalize image range 66 | # This changes the image features and influences CRF output 67 | image = image / 255 68 | # mean substraction 69 | # CRF is invariant to mean subtraction, output is NOT affected 70 | image = image - 0.5 71 | # std normalization 72 | # Affect CRF computation 73 | image = image / 0.3 74 | 75 | # schan = 0.1 is a good starting value for normalized images. 76 | # The relation is f_i = image * schan 77 | config['col_feats']['schan'] = 0.1 78 | 79 | # make input pytorch compatible 80 | image = image.transpose(2, 0, 1) # shape: [3, hight, width] 81 | # Add batch dimension to image: [1, 3, height, width] 82 | image = image.reshape([1, 3, shape[0], shape[1]]) 83 | img_var = Variable(torch.Tensor(image)) 84 | 85 | unary = unary.transpose(2, 0, 1) # shape: [3, hight, width] 86 | # Add batch dimension to unary: [1, 21, height, width] 87 | unary = unary.reshape([1, num_classes, shape[0], shape[1]]) 88 | unary_var = Variable(torch.Tensor(unary)) 89 | 90 | logging.info("Build ConvCRF.") 91 | ## 92 | # Create CRF module 93 | gausscrf = convcrf.GaussCRF(conf=config, shape=shape, nclasses=num_classes, 94 | use_gpu=not args.cpu) 95 | 96 | # move to GPU if requested 97 | if not args.cpu: 98 | img_var = img_var.cuda() 99 | unary_var = unary_var.cuda() 100 | gausscrf.cuda() 101 | 102 | logging.info("Start Computation.") 103 | # Perform CRF inference 104 | prediction = gausscrf.forward(unary=unary_var, img=img_var) 105 | 106 | if args.nospeed: 107 | # Evaluate inference speed 108 | logging.info("Doing speed evaluation.") 109 | start_time = time.time() 110 | for i in range(10): 111 | # Running ConvCRF 10 times and average total time 112 | pred = gausscrf.forward(unary=unary_var, img=img_var) 113 | 114 | pred.cpu() # wait for all GPU computations to finish 115 | 116 | duration = (time.time() - start_time) * 1000 / 10 117 | 118 | logging.info("Finished running 10 predictions.") 119 | logging.info("Avg. 
Computation time: {} ms".format(duration)) 120 | 121 | return prediction.data.cpu().numpy() 122 | 123 | 124 | def plot_results(image, unary, prediction, label, args): 125 | 126 | logging.info("Plot results.") 127 | 128 | # Create visualizer 129 | myvis = vis.PascalVisualizer() 130 | 131 | # Transform id image to coloured labels 132 | coloured_label = myvis.id2color(id_image=label) 133 | 134 | unary_hard = np.argmax(unary, axis=2) 135 | coloured_unary = myvis.id2color(id_image=unary_hard) 136 | 137 | prediction = prediction[0] # Remove Batch dimension 138 | prediction_hard = np.argmax(prediction, axis=0) 139 | coloured_crf = myvis.id2color(id_image=prediction_hard) 140 | 141 | if matplotlib: 142 | # Plot results using matplotlib 143 | figure = plt.figure() 144 | figure.tight_layout() 145 | 146 | # Plot parameters 147 | num_rows = 2 148 | num_cols = 2 149 | 150 | ax = figure.add_subplot(num_rows, num_cols, 1) 151 | # img_name = os.path.basename(args.image) 152 | ax.set_title('Image ') 153 | ax.axis('off') 154 | ax.imshow(image) 155 | 156 | ax = figure.add_subplot(num_rows, num_cols, 2) 157 | ax.set_title('Label') 158 | ax.axis('off') 159 | ax.imshow(coloured_label.astype(np.uint8)) 160 | 161 | ax = figure.add_subplot(num_rows, num_cols, 3) 162 | ax.set_title('Unary') 163 | ax.axis('off') 164 | ax.imshow(coloured_unary.astype(np.uint8)) 165 | 166 | ax = figure.add_subplot(num_rows, num_cols, 4) 167 | ax.set_title('CRF Output') 168 | ax.axis('off') 169 | ax.imshow(coloured_crf.astype(np.uint8)) 170 | 171 | plt.show() 172 | else: 173 | if args.output is None: 174 | args.output = "out.png" 175 | 176 | logging.warning("Matplotlib not found.") 177 | logging.info("Saving output to {} instead".format(args.output)) 178 | 179 | if args.output is not None: 180 | # Save results to disk 181 | out_img = np.concatenate( 182 | (image, coloured_label, coloured_unary, coloured_crf), 183 | axis=1) 184 | 185 | imageio.imwrite(args.output, out_img.astype(np.uint8)) 186 | 187 | logging.info("Plot has been saved to {}".format(args.output)) 188 | 189 | return 190 | 191 | 192 | def get_parser(): 193 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 194 | parser = ArgumentParser(description=__doc__, 195 | formatter_class=ArgumentDefaultsHelpFormatter) 196 | 197 | parser.add_argument("image", type=str, 198 | help="input image") 199 | 200 | parser.add_argument("label", type=str, 201 | help="Label file.") 202 | 203 | parser.add_argument("--gpu", type=str, default='0', 204 | help="which gpu to use") 205 | 206 | parser.add_argument('--output', type=str, 207 | help="Optionally save output as img.") 208 | 209 | parser.add_argument('--nospeed', action='store_false', 210 | help="Skip speed evaluation.") 211 | 212 | parser.add_argument('--normalize', action='store_true', 213 | help="Normalize input image before inference.") 214 | 215 | parser.add_argument('--pyinn', action='store_true', 216 | help="Use pyinn based Cuda implementation" 217 | "for message passing.") 218 | 219 | parser.add_argument('--cpu', action='store_true', 220 | help="Run on CPU instead of GPU.") 221 | 222 | # parser.add_argument('--compare', action='store_true') 223 | # parser.add_argument('--embed', action='store_true') 224 | 225 | # args = parser.parse_args() 226 | 227 | return parser 228 | 229 | 230 | if __name__ == '__main__': 231 | parser = get_parser() 232 | args = parser.parse_args() 233 | 234 | # Load data 235 | image = imageio.imread(args.image) 236 | label = imageio.imread(args.label) 237 | 238 | # Produce unary by adding noise 
to label 239 | unary = synthetic.augment_label(label, num_classes=21) 240 | # Compute CRF inference 241 | prediction = do_crf_inference(image, unary, args) 242 | # Plot output 243 | plot_results(image, unary, prediction, label, args) 244 | logging.info("Thank you for trying ConvCRFs.") 245 | -------------------------------------------------------------------------------- /fullcrf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/fullcrf/__init__.py -------------------------------------------------------------------------------- /fullcrf/fullcrf.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2017 Marvin Teichmann 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import sys 13 | 14 | import numpy as np 15 | import scipy as scp 16 | import math 17 | 18 | import logging 19 | 20 | logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', 21 | level=logging.INFO, 22 | stream=sys.stdout) 23 | 24 | 25 | from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral 26 | from pydensecrf.utils import create_pairwise_gaussian 27 | 28 | import torch 29 | import torch.nn as nn 30 | from torch.nn import functional as nnfun 31 | from torch.autograd import Variable 32 | from torch.nn.parameter import Parameter 33 | 34 | import gc 35 | 36 | from pydensecrf import densecrf as dcrf 37 | from pydensecrf import utils 38 | 39 | 40 | default_conf = { 41 | 'blur': 4, 42 | 'merge': False, 43 | 'norm': 'none', 44 | 'trainable': False, 45 | 'weight': 'scalar', 46 | 'weight_init': 0.2, 47 | 'convcomp': False, 48 | 49 | 'pos_feats': { 50 | 'sdims': 3, 51 | 'compat': 3, 52 | }, 53 | 'col_feats': { 54 | 'sdims': 80, 55 | 'schan': 13, 56 | 'compat': 10, 57 | 'use_bias': False 58 | }, 59 | "trainable_bias": False, 60 | } 61 | 62 | 63 | test_config = { 64 | 'filter_size': 5, 65 | 'blur': 1, 66 | 'merge': False, 67 | 'norm': 'sym', 68 | 'trainable': False, 69 | 'weight': None, 70 | 'weight_init': 5, 71 | 'convcomp': False, 72 | 73 | 'pos_feats': { 74 | 'sdims': 3, 75 | 'compat': 3, 76 | }, 77 | 78 | 'col_feats': { 79 | 'sdims': 80, 80 | 'schan': 13, 81 | 'compat': 10, 82 | 'use_bias': True 83 | }, 84 | "trainable_bias": False, 85 | } 86 | 87 | 88 | class FullCRF(): 89 | """ Implements FullCRF with hand-crafted features. 90 | 91 | This class uses pydensecrf to implement the CRF model proposed by: 92 | Philipp Kraehenbuehl and Vladlen Koltun, "Efficient Inference in Fully 93 | "Connected CRFs with Gaussian Edge Pots" (arxiv.org/abs/1210.5644) 94 | """ 95 | 96 | def __init__(self, conf, shape, num_classes=None): 97 | self.crf = None 98 | self.conf = conf 99 | self.num_classes = num_classes 100 | self.shape = shape 101 | 102 | def compute_lattice(self, img, num_classes=None): 103 | """ 104 | Compute indices for the lattice approximation. 105 | 106 | Arguments: 107 | img: np.array with shape [height, width, 3] 108 | The input image. Default config assumes image 109 | data in [0, 255]. For normalized images adapt 110 | `schan`. 
Try schan = 0.1 for images in [-0.5, 0.5] 111 | """ 112 | 113 | if num_classes is not None: 114 | self.num_classes = num_classes 115 | 116 | assert self.num_classes is not None 117 | 118 | npixels = self.shape[0] * self.shape[1] 119 | crf = dcrf.DenseCRF(npixels, self.num_classes) 120 | 121 | sdims = self.conf['pos_feats']['sdims'] 122 | 123 | feats = utils.create_pairwise_gaussian( 124 | sdims=(sdims, sdims), 125 | shape=img.shape[:2]) 126 | 127 | self.smooth_feats = feats 128 | 129 | self.crf = crf 130 | 131 | self.crf.addPairwiseEnergy( 132 | self.smooth_feats, compat=self.conf['pos_feats']['compat']) 133 | 134 | sdims = self.conf['col_feats']['sdims'] 135 | schan = self.conf['col_feats']['schan'] 136 | 137 | feats = utils.create_pairwise_bilateral(sdims=(sdims, sdims), 138 | schan=(schan, schan, schan), 139 | img=img, chdim=2) 140 | 141 | self.appear_feats = feats 142 | 143 | self.crf.addPairwiseEnergy( 144 | self.appear_feats, compat=self.conf['pos_feats']['compat']) 145 | 146 | def compute_dcrf(self, unary): 147 | """ 148 | Compute dcrf assuming compute_lattice was called. 149 | 150 | Arguments: 151 | unary: np.array with shape [height, width, num_classes] 152 | The unary predictions. 153 | """ 154 | 155 | eps = 1e-20 156 | unary = unary + eps 157 | unary = unary.reshape(-1, self.num_classes) 158 | unary = np.transpose(unary) 159 | unary = np.ascontiguousarray(unary, dtype=np.float32) 160 | self.crf.setUnaryEnergy(-np.log(unary)) 161 | 162 | # Run five inference steps. 163 | crfout = self.crf.inference(5) 164 | crfout = np.transpose(crfout) 165 | crfout = crfout.reshape(self.shape[0], self.shape[1], -1) 166 | 167 | return crfout 168 | 169 | def compute(self, unary, img, softmax=False): 170 | """ 171 | Full forward pass on numpy arrays. 172 | 173 | This function calls `compute_lattice` followed by compute_dcrf 174 | 175 | Arguments: 176 | unary: np.array with shape [height, width, num_classes] 177 | The unary predictions. 178 | img: np.array with shape [height, width, 3] 179 | The input image. Default config assumes image 180 | data in [0, 255]. For normalized images adapt 181 | `schan`. Try schan = 0.1 for images in [-0.5, 0.5] 182 | 183 | softmax: bool 184 | Whether to apply softmax. Unaries need to be normalized. 185 | """ 186 | if softmax: 187 | unary = torch.nn.functional.softmax( 188 | Variable(torch.Tensor(unary)), dim=2) 189 | unary = unary.data.numpy() 190 | self.compute_lattice(img) 191 | return self.compute_dcrf(unary) 192 | 193 | def batched_compute(self, unary, img, softmax=False): 194 | """ 195 | Perform compute on batched torch.tensors. 196 | 197 | Arguments: 198 | unary: torch.Tensor with shape [bs, num_classes, height, width]. 199 | The unary predictions. 200 | 201 | img: torch.Tensor with shape [bs, 3, height, width] 202 | The input image. Default config assumes image 203 | data in [0, 255]. For normalized images adapt 204 | `schan`. Try schan = 0.1 for images in [-0.5, 0.5] 205 | 206 | softmax: bool 207 | Whether to apply softmax. Unaries need to be normalized. 
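Returns: A list of np.arrays, one per batch element, each with shape [height, width, num_classes].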
208 | """ 209 | 210 | img = img.data.cpu().numpy() 211 | unary = unary.data.cpu().numpy() 212 | 213 | img = img.transpose(0, 2, 3, 1) 214 | unary = unary.transpose(0, 2, 3, 1) 215 | 216 | results = [] 217 | 218 | for d in range(img.shape[0]): 219 | img_d = img[d] 220 | unary_d = unary[d] 221 | res = self.compute(unary_d, img_d, softmax) 222 | results.append(res) 223 | 224 | return results 225 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=1.5.1 2 | numpy>=1.11.1 3 | cython>=0.27.1 4 | imageio 5 | scikit-image 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup(name='ConvCRF', 6 | version='1.0', 7 | description='Reference Implementation of ConvCRF.', 8 | author='Marvin Teichmann', 9 | author_email=('marvin.teichmann@googlemail.com'), 10 | packages=find_packages(), 11 | package_data={'': ['*.lst']} 12 | ) 13 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarvinTeichmann/ConvCRF/09306378ebc76e38f91aeaf57be5ce36ec2b873c/utils/__init__.py -------------------------------------------------------------------------------- /utils/pascal_visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | from collections import OrderedDict 4 | import json 5 | import logging 6 | import sys 7 | import random 8 | 9 | import numpy as np 10 | import scipy as scp 11 | import scipy.misc 12 | 13 | try: 14 | import matplotlib.pyplot as plt 15 | except ImportError: 16 | pass 17 | 18 | from . 
import visualization as vis 19 | 20 | logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', 21 | level=logging.INFO, 22 | stream=sys.stdout) 23 | 24 | voc_names = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 25 | 'bottle', 'bus', 'car', 'cat', 26 | 'chair', 'cow', 'diningtable', 'dog', 27 | 'horse', 'motorbike', 'person', 'potted-plant', 28 | 'sheep', 'sofa', 'train', 'tv/monitor'] 29 | 30 | color_list = [(0, 0, 0), 31 | (128, 0, 0), 32 | (0, 128, 0), 33 | (128, 128, 0), 34 | (0, 0, 128), 35 | (128, 0, 128), 36 | (0, 128, 128), 37 | (128, 128, 128), 38 | (64, 0, 0), 39 | (192, 0, 0), 40 | (64, 128, 0), 41 | (192, 128, 0), 42 | (64, 0, 128), 43 | (192, 0, 128), 44 | (64, 128, 128), 45 | (192, 128, 128), 46 | (0, 64, 0), 47 | (128, 64, 0), 48 | (0, 192, 0), 49 | (128, 192, 0), 50 | (0, 64, 128)] 51 | 52 | 53 | class PascalVisualizer(vis.SegmentationVisualizer): 54 | 55 | def __init__(self): 56 | super(PascalVisualizer, self).__init__( 57 | color_list=color_list, name_list=voc_names) 58 | 59 | def plot_sample(self, sample): 60 | 61 | image = sample['image'].transpose(1, 2, 0) 62 | label = sample['label'] 63 | mask = label != -100 64 | 65 | idx = eval(sample['load_dict'])['idx'] 66 | 67 | coloured_label = self.id2color(id_image=label, 68 | mask=mask) 69 | 70 | figure = plt.figure() 71 | figure.tight_layout() 72 | 73 | ax = figure.add_subplot(1, 2, 1) 74 | ax.set_title('Image #{}'.format(idx)) 75 | ax.axis('off') 76 | ax.imshow(image) 77 | 78 | ax = figure.add_subplot(1, 2, 2) 79 | ax.set_title('Label') 80 | ax.axis('off') 81 | ax.imshow(coloured_label.astype(np.uint8)) 82 | 83 | return figure 84 | 85 | def plot_segmentation_batch(self, sample_batch, prediction): 86 | figure = plt.figure() 87 | figure.tight_layout() 88 | 89 | batch_size = len(sample_batch['load_dict']) 90 | figure.set_size_inches(12, 3 * batch_size) 91 | 92 | for d in range(batch_size): 93 | image = sample_batch['image'][d].numpy().transpose(1, 2, 0) 94 | label = sample_batch['label'][d].numpy() 95 | 96 | mask = label != -100 97 | 98 | pred = prediction[d].cpu().data.numpy().transpose(1, 2, 0) 99 | pred_hard = np.argmax(pred, axis=2) 100 | 101 | idx = eval(sample_batch['load_dict'][d])['idx'] 102 | 103 | coloured_label = self.id2color(id_image=label, 104 | mask=mask) 105 | 106 | coloured_prediction = self.pred2color(pred_image=pred, 107 | mask=mask) 108 | 109 | coloured_hard = self.id2color(id_image=pred_hard, 110 | mask=mask) 111 | 112 | ax = figure.add_subplot(batch_size, 4, batch_size * d + 1) 113 | ax.set_title('Image #{}'.format(idx)) 114 | ax.axis('off') 115 | ax.imshow(image) 116 | 117 | ax = figure.add_subplot(batch_size, 4, batch_size * d + 2) 118 | ax.set_title('Label') 119 | ax.axis('off') 120 | ax.imshow(coloured_label.astype(np.uint8)) 121 | 122 | ax = figure.add_subplot(batch_size, 4, batch_size * d + 3) 123 | ax.set_title('Prediction (hard)') 124 | ax.axis('off') 125 | ax.imshow(coloured_hard.astype(np.uint8)) 126 | 127 | ax = figure.add_subplot(batch_size, 4, batch_size * d + 4) 128 | ax.set_title('Prediction (soft)') 129 | ax.axis('off') 130 | ax.imshow(coloured_prediction.astype(np.uint8)) 131 | 132 | return figure 133 | 134 | def plot_batch(self, sample_batch): 135 | 136 | figure = plt.figure() 137 | figure.tight_layout() 138 | 139 | batch_size = len(sample_batch['load_dict']) 140 | 141 | for d in range(batch_size): 142 | 143 | image = sample_batch['image'][d].numpy().transpose(1, 2, 0) 144 | label = sample_batch['label'][d].numpy() 145 | mask = label != -100 146 | 147 | idx = 
eval(sample_batch['load_dict'][d])['idx'] 148 | 149 | coloured_label = self.id2color(id_image=label, 150 | mask=mask) 151 | 152 | ax = figure.add_subplot(2, batch_size, d + 1) 153 | ax.set_title('Image #{}'.format(idx)) 154 | ax.axis('off') 155 | ax.imshow(image) 156 | 157 | ax = figure.add_subplot(2, batch_size, d + batch_size + 1) 158 | ax.set_title('Label') 159 | ax.axis('off') 160 | ax.imshow(coloured_label.astype(np.uint8)) 161 | 162 | return figure 163 | -------------------------------------------------------------------------------- /utils/synthetic.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2017 Marvin Teichmann 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import sys 13 | 14 | import numpy as np 15 | import scipy as scp 16 | 17 | import logging 18 | 19 | import skimage 20 | import skimage.transform 21 | 22 | logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', 23 | level=logging.INFO, 24 | stream=sys.stdout) 25 | 26 | 27 | def np_onehot(label, num_classes): 28 | return np.eye(num_classes)[label] 29 | 30 | 31 | def augment_label(label, num_classes, scale=8, keep_prop=0.8): 32 | """ 33 | Add noise to label for synthetic benchmark. 34 | """ 35 | 36 | shape = label.shape 37 | label = label.reshape(shape[0], shape[1]) 38 | 39 | onehot = np_onehot(label, num_classes) 40 | lower_shape = (shape[0] // scale, shape[1] // scale) 41 | 42 | label_down = skimage.transform.resize( 43 | onehot, (lower_shape[0], lower_shape[1], num_classes), 44 | order=1, preserve_range=True, mode='constant') 45 | 46 | onehot = skimage.transform.resize(label_down, 47 | (shape[0], shape[1], num_classes), 48 | order=1, preserve_range=True, 49 | mode='constant') 50 | 51 | noise = np.random.randint(0, num_classes, lower_shape) 52 | 53 | noise = np_onehot(noise, num_classes) 54 | 55 | noise_up = skimage.transform.resize(noise, 56 | (shape[0], shape[1], num_classes), 57 | order=1, preserve_range=True, 58 | mode='constant') 59 | 60 | mask = np.floor(keep_prop + np.random.rand(*lower_shape)) 61 | mask_up = skimage.transform.resize(mask, (shape[0], shape[1], 1), 62 | order=1, preserve_range=True, 63 | mode='constant') 64 | 65 | noised_label = mask_up * onehot + (1 - mask_up) * noise_up 66 | 67 | return noised_label 68 | 69 | 70 | if __name__ == '__main__': 71 | logging.info("Hello World.") 72 | -------------------------------------------------------------------------------- /utils/test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2017 Marvin Teichmann 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import sys 13 | 14 | import numpy as np 15 | import scipy as scp 16 | 17 | import logging 18 | 19 | logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', 20 | level=logging.INFO, 21 | stream=sys.stdout) 22 | 23 | 24 | if __name__ == '__main__': 25 | logging.info("Hello World.") 26 | 27 | 28 | def _get_simple_unary(batched=False): 29 | unary1 = np.zeros((10, 10), dtype=np.float32) 30 | unary1[:, [0, -1]] = unary1[[0, -1], :] = 1 31 | 32 | unary2 = np.zeros((10, 10), dtype=np.float32) 33 | unary2[4:7, 4:7] = 1 34 | 35 | unary = np.vstack([unary1.flat, unary2.flat]) 36 | unary = (unary + 
1) / (np.sum(unary, axis=0) + 2) 37 | 38 | if batched: 39 | unary = unary.reshape(tuple([1]) + unary.shape) 40 | 41 | return unary 42 | 43 | 44 | def _get_simple_img(batched=False): 45 | 46 | img = np.zeros((10, 10, 3), dtype=np.uint8) 47 | img[2:8, 2:8, :] = 255 48 | 49 | if batched: 50 | img = img.reshape(tuple([1]) + img.shape) 51 | 52 | return img 53 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2017 Marvin Teichmann 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import sys 13 | 14 | import numpy as np 15 | import scipy as scp 16 | 17 | import logging 18 | 19 | logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', 20 | level=logging.INFO, 21 | stream=sys.stdout) 22 | 23 | 24 | class SegmentationVisualizer(object): 25 | """docstring for label_converter""" 26 | def __init__(self, color_list=None, name_list=None, 27 | mode='RGB'): 28 | super(SegmentationVisualizer, self).__init__() 29 | self.color_list = color_list 30 | self.name_list = name_list 31 | 32 | self.mask_color = [0, 0, 0] 33 | 34 | if mode == 'RGB': 35 | self.chan = 3 36 | 37 | def id2color(self, id_image, mask=None, ignore_idx=-100): 38 | """ 39 | Input: Int Array of shape [height, width] 40 | Containing Integers 0 <= i < num_classes. 41 | """ 42 | 43 | if mask is None: 44 | if np.any(id_image != ignore_idx): 45 | mask = id_image != ignore_idx 46 | 47 | shape = id_image.shape 48 | gt_out = np.zeros([shape[0], shape[1], self.chan], dtype=np.int32) 49 | id_image 50 | 51 | for train_id, color in enumerate(self.color_list): 52 | c_mask = id_image == train_id 53 | c_mask = c_mask.reshape(c_mask.shape + tuple([1])) 54 | gt_out = gt_out + color * c_mask 55 | 56 | if mask is not None: 57 | mask = mask.reshape(mask.shape + tuple([1])) 58 | bg_color = [0, 0, 0] 59 | mask2 = np.all(gt_out == bg_color, axis=2) 60 | mask2 = mask2.reshape(mask2.shape + tuple([1])) 61 | gt_out = gt_out + mask2 * (self.mask_color * (1 - mask)) 62 | 63 | return gt_out 64 | 65 | def pred2color(self, pred_image, mask=None): 66 | 67 | color_image = np.dot(pred_image, self.color_list) 68 | 69 | if mask is not None: 70 | 71 | if len(mask.shape) == 2: 72 | mask = mask.reshape(mask.shape + tuple([1])) 73 | 74 | color_image = mask * color_image + (1 - mask) * self.mask_color 75 | 76 | return color_image 77 | 78 | def color2id(self, color_gt): 79 | assert(False) 80 | shape = color_gt.shape 81 | gt_reshaped = np.zeros([shape[0], shape[1]], dtype=np.int32) 82 | mask = np.zeros([shape[0], shape[1]], dtype=np.int32) 83 | 84 | for train_id, color in enumerate(self.color_list): 85 | gt_label = np.all(color_gt == color, axis=2) 86 | mask = mask + gt_label 87 | gt_reshaped = gt_reshaped + 10 * train_id * gt_label 88 | 89 | assert(np.max(mask) == 1) 90 | np.unique(gt_reshaped) 91 | assert(np.max(gt_reshaped) <= 200) 92 | 93 | gt_reshaped = gt_reshaped + 255 * (1 - mask) 94 | return gt_reshaped 95 | 96 | def underlay2(self, image, gt_image, labels): 97 | # TODO 98 | color_img = self.id2color(gt_image) 99 | color_labels = self.id2color(labels) 100 | 101 | output = np.concatenate((image, color_img, color_labels), axis=0) 102 | 103 | return output 104 | 105 | def overlay(self, image, gt_image): 106 | # TODO 107 | color_img = self.id2color((gt_image)) 108 | output = 0.4 *
color_img[:, :] + 0.6 * image 109 | 110 | return output 111 | --------------------------------------------------------------------------------
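
Usage note: the modules above (`utils/synthetic.py`, `utils/test_utils.py`, `utils/visualization.py`, `utils/pascal_visualizer.py`) are library code with no entry point of their own; `demo.py` and `benchmark.py` wire them together. The snippet below is a minimal sketch of how they compose, assuming the repository root is on `PYTHONPATH`; the toy label map, the chosen class index, and the figure layout are made up for illustration and are not taken from `demo.py` or `benchmark.py`.

import numpy as np
import matplotlib.pyplot as plt

from utils import synthetic
from utils.pascal_visualizer import PascalVisualizer

num_classes = 21  # Pascal VOC: background + 20 object classes

# Toy (H, W) integer label map: a 'person' square (class 15) on 'background' (class 0).
label = np.zeros((128, 128), dtype=np.int64)
label[32:96, 32:96] = 15

# Soft, noisy one-hot unaries of shape (H, W, num_classes), as built for the
# synthetic benchmark by augment_label().
unary = synthetic.augment_label(label, num_classes, scale=8, keep_prop=0.8)

# Colourize with the Pascal palette defined in pascal_visualizer.py.
viz = PascalVisualizer()
mask = (label != -100)                              # every pixel is valid in this toy example
hard = viz.id2color(id_image=label, mask=mask)      # (H, W, 3) ints in [0, 255]
soft = viz.pred2color(pred_image=unary, mask=mask)  # soft colouring of the noisy unaries

fig, axes = plt.subplots(1, 2)
axes[0].set_title('Label'); axes[0].axis('off'); axes[0].imshow(hard.astype(np.uint8))
axes[1].set_title('Noisy unary'); axes[1].axis('off'); axes[1].imshow(soft.astype(np.uint8))
plt.show()

This mirrors, in spirit, how the synthetic benchmark prepares its noisy unaries and how the plotting helpers render labels and predictions side by side.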