├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── crfasrnn ├── __init__.py ├── crfasrnn_model.py ├── crfrnn.py ├── fcn8s.py ├── filters.py ├── params.py ├── permuto.cpp ├── permutohedral.cpp ├── permutohedral.h ├── setup.py └── util.py ├── image.jpg ├── quick_run.py ├── requirements.txt ├── run_demo.py └── sample.png /.gitattributes: -------------------------------------------------------------------------------- 1 | crfasrnn/permutohedral.cpp linguist-vendored 2 | crfasrnn/permutohedral.h linguist-vendored 3 | 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | *.pyc 4 | 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Sadeep Jayasumana 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CRF-RNN for Semantic Image Segmentation - PyTorch version 2 | ![sample](sample.png) 3 | 4 | Live demo:                           [http://crfasrnn.torr.vision](http://crfasrnn.torr.vision)
5 | Caffe version:                      [http://github.com/torrvision/crfasrnn](http://github.com/torrvision/crfasrnn)
6 | Tensorflow/Keras version: [http://github.com/sadeepj/crfasrnn_keras](http://github.com/sadeepj/crfasrnn_keras)
7 | 8 | This repository contains the official PyTorch implementation of the "CRF-RNN" semantic image segmentation method, published in the ICCV 2015 paper [Conditional Random Fields as Recurrent Neural Networks](http://www.robots.ox.ac.uk/~szheng/papers/CRFasRNN.pdf). The [online demo](http://crfasrnn.torr.vision) of this project won the Best Demo Prize at ICCV 2015. Results of this PyTorch code are identical to those of the Caffe and Tensorflow/Keras based versions above. 9 | 10 | If you use this code/model for your research, please cite the following paper: 11 | ``` 12 | @inproceedings{crfasrnn_ICCV2015, 13 | author = {Shuai Zheng and Sadeep Jayasumana and Bernardino Romera-Paredes and Vibhav Vineet and 14 | Zhizhong Su and Dalong Du and Chang Huang and Philip H. S. Torr}, 15 | title = {Conditional Random Fields as Recurrent Neural Networks}, 16 | booktitle = {International Conference on Computer Vision (ICCV)}, 17 | year = {2015} 18 | } 19 | ``` 20 | 21 | ## Installation Guide 22 | 23 | _Note_: If you are using a Python virtualenv, make sure it is activated before running each command in this guide. 24 | 25 | ### Step 1: Clone the repository 26 | ``` 27 | $ git clone https://github.com/sadeepj/crfasrnn_pytorch.git 28 | ``` 29 | The root directory of the clone will be referred to as `crfasrnn_pytorch` hereafter. 30 | 31 | ### Step 2: Install dependencies 32 | 33 | 34 | Use the `requirements.txt` file in this repository to install all the dependencies via `pip`: 35 | ``` 36 | $ cd crfasrnn_pytorch 37 | $ pip install -r requirements.txt 38 | ``` 39 | 40 | After installing the dependencies, run the following command to make sure they are properly installed: 41 | ``` 42 | $ python 43 | >>> import torch 44 | ``` 45 | You should not see any errors while importing `torch` above.
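You can also check the environment with a script instead of the interactive session. The short sketch below reports which modules the current interpreter cannot locate. It is only an illustration. The module list (`torch`, `numpy`) is an assumption based on the imports in `crfasrnn/*.py`; it is not read from `requirements.txt`. The helper name `missing_modules` is hypothetical.

```python
import importlib.util
import sys

# Assumption: top-level modules imported by crfasrnn/*.py (not read from requirements.txt).
REQUIRED_MODULES = ("torch", "numpy")


def missing_modules(names):
    """Return the names in `names` that the current interpreter cannot locate."""
    # find_spec() only locates a top-level module; it does not execute its code.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # `python` must be the interpreter associated with your PyTorch installation.
    print("Interpreter:", sys.executable)
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All required modules were found.")
```

Run it with the same `python` command you will use in the later steps. A module that is found here can still fail at import time, so the interactive `import torch` check above remains the definitive test.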
46 | 47 | ### Step 3: Build CRF-RNN custom op 48 | 49 | Run `setup.py` inside the `crfasrnn_pytorch/crfasrnn` directory: 50 | ``` 51 | $ cd crfasrnn_pytorch/crfasrnn 52 | $ python setup.py install 53 | ``` 54 | Note that the `python` command in the console should refer to the Python interpreter associated with your PyTorch installation. 55 | 56 | ### Step 4: Download the pre-trained model weights 57 | 58 | Download the model weights from [here](https://github.com/sadeepj/crfasrnn_pytorch/releases/download/0.0.1/crfasrnn_weights.pth) and save the file in the `crfasrnn_pytorch` directory as `crfasrnn_weights.pth`. 59 | 60 | ### Step 5: Run the demo 61 | ``` 62 | $ cd crfasrnn_pytorch 63 | $ python run_demo.py 64 | ``` 65 | If all goes well, you will see the segmentation results in a file named "labels.png". 66 | 67 | ## Contributors 68 | * Sadeep Jayasumana ([sadeepj](https://github.com/sadeepj)) 69 | * Harsha Ranasinghe ([HarshaPrabhath](https://github.com/HarshaPrabhath)) 70 | 71 | -------------------------------------------------------------------------------- /crfasrnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sadeepj/crfasrnn_pytorch/24899c528981dfbc14f1212869a3eae328dc6570/crfasrnn/__init__.py -------------------------------------------------------------------------------- /crfasrnn/crfasrnn_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Sadeep Jayasumana 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so,
subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | from crfasrnn.crfrnn import CrfRnn 26 | from crfasrnn.fcn8s import Fcn8s 27 | 28 | 29 | class CrfRnnNet(Fcn8s): 30 | """ 31 | The full CRF-RNN network with the FCN-8s backbone as described in the paper: 32 | 33 | Conditional Random Fields as Recurrent Neural Networks, 34 | S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du, C. Huang and P. Torr, 35 | ICCV 2015 (https://arxiv.org/abs/1502.03240). 
36 | """ 37 | 38 | def __init__(self): 39 | super(CrfRnnNet, self).__init__() 40 | self.crfrnn = CrfRnn(num_labels=21, num_iterations=10) 41 | 42 | def forward(self, image): 43 | out = super(CrfRnnNet, self).forward(image) 44 | # Plug the CRF-RNN module at the end 45 | return self.crfrnn(image, out) 46 | -------------------------------------------------------------------------------- /crfasrnn/crfrnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Sadeep Jayasumana 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | import torch 26 | import torch.nn as nn 27 | 28 | from crfasrnn.filters import SpatialFilter, BilateralFilter 29 | from crfasrnn.params import DenseCRFParams 30 | 31 | 32 | class CrfRnn(nn.Module): 33 | """ 34 | PyTorch implementation of the CRF-RNN module described in the paper: 35 | 36 | Conditional Random Fields as Recurrent Neural Networks, 37 | S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du, C. Huang and P. Torr, 38 | ICCV 2015 (https://arxiv.org/abs/1502.03240). 39 | """ 40 | 41 | def __init__(self, num_labels, num_iterations=5, crf_init_params=None): 42 | """ 43 | Create a new instance of the CRF-RNN layer. 44 | 45 | Args: 46 | num_labels: Number of semantic labels in the dataset 47 | num_iterations: Number of mean-field iterations to perform 48 | crf_init_params: CRF initialization parameters 49 | """ 50 | super(CrfRnn, self).__init__() 51 | 52 | if crf_init_params is None: 53 | crf_init_params = DenseCRFParams() 54 | 55 | self.params = crf_init_params 56 | self.num_iterations = num_iterations 57 | 58 | self._softmax = torch.nn.Softmax(dim=0) 59 | 60 | self.num_labels = num_labels 61 | 62 | # -------------------------------------------------------------------------------------------- 63 | # --------------------------------- Trainable Parameters ------------------------------------- 64 | # -------------------------------------------------------------------------------------------- 65 | 66 | # Spatial kernel weights 67 | self.spatial_ker_weights = nn.Parameter( 68 | crf_init_params.spatial_ker_weight 69 | * torch.eye(num_labels, dtype=torch.float32) 70 | ) 71 | 72 | # Bilateral kernel weights 73 | self.bilateral_ker_weights = nn.Parameter( 74 | crf_init_params.bilateral_ker_weight 75 | * torch.eye(num_labels, dtype=torch.float32) 76 | ) 77 | 78 | # Compatibility transform matrix 79 | self.compatibility_matrix = nn.Parameter( 80 | torch.eye(num_labels, dtype=torch.float32) 81 | ) 82 | 83 | def forward(self, image, 
logits): 84 | """ 85 | Perform CRF inference. 86 | 87 | Args: 88 | image: Tensor of shape (1, 3, h, w) containing the RGB image (batch size must be 1) 89 | logits: Tensor of shape (1, num_labels, h, w) containing the unary logits 90 | Returns: 91 | log-Q distributions (logits) of shape (1, num_labels, h, w) after CRF inference 92 | """ 93 | if logits.shape[0] != 1: 94 | raise ValueError("Only batch size 1 is currently supported!") 95 | 96 | image = image[0] 97 | logits = logits[0] 98 | 99 | spatial_filter = SpatialFilter(image, gamma=self.params.gamma) 100 | bilateral_filter = BilateralFilter( 101 | image, alpha=self.params.alpha, beta=self.params.beta 102 | ) 103 | _, h, w = image.shape 104 | cur_logits = logits 105 | 106 | for _ in range(self.num_iterations): 107 | # Normalization 108 | q_values = self._softmax(cur_logits) 109 | 110 | # Spatial filtering 111 | spatial_out = torch.mm( 112 | self.spatial_ker_weights, 113 | spatial_filter.apply(q_values).view(self.num_labels, -1), 114 | ) 115 | 116 | # Bilateral filtering 117 | bilateral_out = torch.mm( 118 | self.bilateral_ker_weights, 119 | bilateral_filter.apply(q_values).view(self.num_labels, -1), 120 | ) 121 | 122 | # Compatibility transform 123 | msg_passing_out = ( 124 | spatial_out + bilateral_out 125 | ) # Shape: (self.num_labels, -1) 126 | msg_passing_out = torch.mm(self.compatibility_matrix, msg_passing_out).view( 127 | self.num_labels, h, w 128 | ) 129 | 130 | # Adding unary potentials 131 | cur_logits = msg_passing_out + logits 132 | 133 | return torch.unsqueeze(cur_logits, 0) 134 | -------------------------------------------------------------------------------- /crfasrnn/fcn8s.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains a modified version of the FCN-8s code available in https://github.com/wkentaro/pytorch-fcn 3 | The original copyright notice from that repository is included below: 4 | 5 | Copyright (c) 2017 - 2019 Kentaro Wada.
6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 24 | """ 25 | 26 | import numpy as np 27 | import torch 28 | import torch.nn as nn 29 | 30 | 31 | def _upsampling_weights(in_channels, out_channels, kernel_size): 32 | factor = (kernel_size + 1) // 2 33 | if kernel_size % 2 == 1: 34 | center = factor - 1 35 | else: 36 | center = factor - 0.5 37 | og = np.ogrid[:kernel_size, :kernel_size] 38 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) 39 | weight = np.zeros( 40 | (in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64 41 | ) 42 | weight[range(in_channels), range(out_channels), :, :] = filt 43 | return torch.from_numpy(weight).float() 44 | 45 | 46 | class Fcn8s(nn.Module): 47 | def __init__(self, n_class=21): 48 | """ 49 | Create the FCN-8s network with the given number of classes. 50 | 51 | Args: 52 | n_class: The number of semantic classes.
53 | """ 54 | 55 | super(Fcn8s, self).__init__() 56 | 57 | # conv1 58 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) 59 | self.relu1_1 = nn.ReLU(inplace=True) 60 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 61 | self.relu1_2 = nn.ReLU(inplace=True) 62 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 63 | 64 | # conv2 65 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 66 | self.relu2_1 = nn.ReLU(inplace=True) 67 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 68 | self.relu2_2 = nn.ReLU(inplace=True) 69 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 70 | 71 | # conv3 72 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 73 | self.relu3_1 = nn.ReLU(inplace=True) 74 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 75 | self.relu3_2 = nn.ReLU(inplace=True) 76 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 77 | self.relu3_3 = nn.ReLU(inplace=True) 78 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 79 | 80 | # conv4 81 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 82 | self.relu4_1 = nn.ReLU(inplace=True) 83 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 84 | self.relu4_2 = nn.ReLU(inplace=True) 85 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 86 | self.relu4_3 = nn.ReLU(inplace=True) 87 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 88 | 89 | # conv5 90 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 91 | self.relu5_1 = nn.ReLU(inplace=True) 92 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 93 | self.relu5_2 = nn.ReLU(inplace=True) 94 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 95 | self.relu5_3 = nn.ReLU(inplace=True) 96 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 97 | 98 | # fc6 99 | self.fc6 = nn.Conv2d(512, 4096, 7) 100 | self.relu6 = nn.ReLU(inplace=True) 101 | self.drop6 = nn.Dropout2d() 102 | 103 | # fc7 104 | self.fc7 = nn.Conv2d(4096, 4096, 1) 105 | self.relu7 = nn.ReLU(inplace=True) 106 | self.drop7 = nn.Dropout2d() 107 | 108 | self.score_fr = nn.Conv2d(4096, 
n_class, 1) 109 | self.score_pool3 = nn.Conv2d(256, n_class, 1) 110 | self.score_pool4 = nn.Conv2d(512, n_class, 1) 111 | 112 | self.upscore2 = nn.ConvTranspose2d(n_class, n_class, 4, stride=2, bias=True) 113 | self.upscore8 = nn.ConvTranspose2d(n_class, n_class, 16, stride=8, bias=False) 114 | self.upscore_pool4 = nn.ConvTranspose2d( 115 | n_class, n_class, 4, stride=2, bias=False 116 | ) 117 | 118 | self._initialize_weights() 119 | 120 | def _initialize_weights(self): 121 | for m in self.modules(): 122 | if isinstance(m, nn.Conv2d): 123 | m.weight.data.zero_() 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | if isinstance(m, nn.ConvTranspose2d): 127 | assert m.kernel_size[0] == m.kernel_size[1] 128 | initial_weight = _upsampling_weights( 129 | m.in_channels, m.out_channels, m.kernel_size[0] 130 | ) 131 | m.weight.data.copy_(initial_weight) 132 | 133 | def forward(self, image): 134 | h = self.relu1_1(self.conv1_1(image)) 135 | h = self.relu1_2(self.conv1_2(h)) 136 | h = self.pool1(h) 137 | 138 | h = self.relu2_1(self.conv2_1(h)) 139 | h = self.relu2_2(self.conv2_2(h)) 140 | h = self.pool2(h) 141 | 142 | h = self.relu3_1(self.conv3_1(h)) 143 | h = self.relu3_2(self.conv3_2(h)) 144 | h = self.relu3_3(self.conv3_3(h)) 145 | h = self.pool3(h) 146 | pool3 = h # 1/8 147 | 148 | h = self.relu4_1(self.conv4_1(h)) 149 | h = self.relu4_2(self.conv4_2(h)) 150 | h = self.relu4_3(self.conv4_3(h)) 151 | h = self.pool4(h) 152 | pool4 = h # 1/16 153 | 154 | h = self.relu5_1(self.conv5_1(h)) 155 | h = self.relu5_2(self.conv5_2(h)) 156 | h = self.relu5_3(self.conv5_3(h)) 157 | h = self.pool5(h) 158 | 159 | h = self.relu6(self.fc6(h)) 160 | h = self.drop6(h) 161 | 162 | h = self.relu7(self.fc7(h)) 163 | h = self.drop7(h) 164 | 165 | h = self.score_fr(h) 166 | h = self.upscore2(h) 167 | upscore2 = h # 1/16 168 | 169 | h = self.score_pool4(pool4) 170 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]] 171 | score_pool4c = h # 1/16 172 | 173 | h = upscore2 + 
score_pool4c # 1/16 174 | h = self.upscore_pool4(h) 175 | upscore_pool4 = h # 1/8 176 | 177 | h = self.score_pool3(pool3) 178 | h = h[:, :, 9:9 + upscore_pool4.size()[2], 9:9 + upscore_pool4.size()[3]] 179 | score_pool3c = h # 1/8 180 | 181 | h = upscore_pool4 + score_pool3c # 1/8 182 | 183 | h = self.upscore8(h) 184 | h = h[:, :, 31:31 + image.size()[2], 31:31 + image.size()[3]].contiguous() 185 | 186 | return h 187 | -------------------------------------------------------------------------------- /crfasrnn/filters.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Sadeep Jayasumana 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | from abc import ABC, abstractmethod 26 | 27 | import numpy as np 28 | import torch 29 | 30 | try: 31 | import permuto_cpp 32 | except ImportError as e: 33 | raise ImportError("Could not import `permuto_cpp`. Did you build it with setup.py and import `torch` first?") from e 34 | 35 | _CPU = torch.device("cpu") 36 | _EPS = np.finfo("float").eps 37 | 38 | 39 | class PermutoFunction(torch.autograd.Function): 40 | 41 | @staticmethod 42 | def forward(ctx, q_in, features): 43 | q_out = permuto_cpp.forward(q_in, features)[0] 44 | ctx.save_for_backward(features) 45 | return q_out 46 | 47 | @staticmethod 48 | def backward(ctx, grad_q_out): 49 | feature_saved = ctx.saved_tensors[0] 50 | grad_q_back = permuto_cpp.backward( 51 | grad_q_out.contiguous(), feature_saved.contiguous() 52 | )[0] 53 | return grad_q_back, None # No gradients needed w.r.t. features 54 | 55 | 56 | def _spatial_features(image, sigma): 57 | """ 58 | Return the spatial features as a Tensor 59 | 60 | Args: 61 | image: Image as a Tensor of shape (channels, height, width) 62 | sigma: Bandwidth parameter 63 | 64 | Returns: 65 | Tensor of shape [h, w, 2] with spatial features 66 | """ 67 | sigma = float(sigma) 68 | _, h, w = image.size() 69 | x = torch.arange(start=0, end=w, dtype=torch.float32, device=_CPU) 70 | xx = x.repeat([h, 1]) / sigma 71 | 72 | y = torch.arange( 73 | start=0, end=h, dtype=torch.float32, device=_CPU 74 | ).view(-1, 1) 75 | yy = y.repeat([1, w]) / sigma 76 | 77 | return torch.stack([xx, yy], dim=2) 78 | 79 | 80 | class AbstractFilter(ABC): 81 | """ 82 | Super-class for permutohedral-based Gaussian filters 83 | """ 84 | 85 | def __init__(self, image): 86 | self.features = self._calc_features(image) 87 | self.norm = self._calc_norm(image) 88 | 89 | def apply(self, input_): 90 | output = PermutoFunction.apply(input_, self.features) 91 | return output * self.norm 92 | 93 | @abstractmethod 94 | def _calc_features(self, image): 95 | pass 96 | 97 | def _calc_norm(self, image): 98 | _, h, w = image.size() 99 | all_ones = torch.ones((1, h,
w), dtype=torch.float32, device=_CPU) 100 | norm = PermutoFunction.apply(all_ones, self.features) 101 | return 1.0 / (norm + _EPS) 102 | 103 | 104 | class SpatialFilter(AbstractFilter): 105 | """ 106 | Gaussian filter in the spatial ([x, y]) domain 107 | """ 108 | 109 | def __init__(self, image, gamma): 110 | """ 111 | Create new instance 112 | 113 | Args: 114 | image: Image tensor of shape (3, height, width) 115 | gamma: Standard deviation 116 | """ 117 | self.gamma = gamma 118 | super(SpatialFilter, self).__init__(image) 119 | 120 | def _calc_features(self, image): 121 | return _spatial_features(image, self.gamma) 122 | 123 | 124 | class BilateralFilter(AbstractFilter): 125 | """ 126 | Gaussian filter in the bilateral ([r, g, b, x, y]) domain 127 | """ 128 | 129 | def __init__(self, image, alpha, beta): 130 | """ 131 | Create new instance 132 | 133 | Args: 134 | image: Image tensor of shape (3, height, width) 135 | alpha: Smoothness (spatial) sigma 136 | beta: Appearance (color) sigma 137 | """ 138 | self.alpha = alpha 139 | self.beta = beta 140 | super(BilateralFilter, self).__init__(image) 141 | 142 | def _calc_features(self, image): 143 | xy = _spatial_features( 144 | image, self.alpha 145 | ) # TODO Possible optimisation, was calculated in the spatial kernel 146 | rgb = (image / float(self.beta)).permute(1, 2, 0) # Channel last order 147 | return torch.cat([xy, rgb], dim=2) 148 | -------------------------------------------------------------------------------- /crfasrnn/params.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Sadeep Jayasumana 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies 
of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | 26 | class DenseCRFParams(object): 27 | """ 28 | Parameters for the DenseCRF model 29 | """ 30 | 31 | def __init__( 32 | self, 33 | alpha=160.0, 34 | beta=3.0, 35 | gamma=3.0, 36 | spatial_ker_weight=3.0, 37 | bilateral_ker_weight=5.0, 38 | ): 39 | """ 40 | Default values were taken from https://github.com/sadeepj/crfasrnn_keras. 
More details about these parameters 41 | can be found in https://arxiv.org/pdf/1210.5644.pdf 42 | 43 | Args: 44 | alpha: Bandwidth for the spatial component of the bilateral filter 45 | beta: Bandwidth for the color component of the bilateral filter 46 | gamma: Bandwidth for the spatial filter 47 | spatial_ker_weight: Spatial kernel weight 48 | bilateral_ker_weight: Bilateral kernel weight 49 | """ 50 | self.alpha = alpha 51 | self.beta = beta 52 | self.gamma = gamma 53 | self.spatial_ker_weight = spatial_ker_weight 54 | self.bilateral_ker_weight = bilateral_ker_weight 55 | -------------------------------------------------------------------------------- /crfasrnn/permuto.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "permutohedral.h" 6 | 7 | /** 8 | * 9 | * @param input_values Input values to filter (e.g. Q distributions). Has shape (channels, height, width) 10 | * @param features Features for the permutohedral lattice. Has shape (height, width, feature_channels). Note that 11 | * channels are at the end! 
12 | * @return Filtered values with shape (channels, height, width) 13 | */ 14 | std::vector<torch::Tensor> permuto_forward(torch::Tensor input_values, torch::Tensor features) { 15 | 16 | auto input_sizes = input_values.sizes(); // (channels, height, width) 17 | auto feature_sizes = features.sizes(); // (height, width, num_features) 18 | 19 | auto h = feature_sizes[0]; 20 | auto w = feature_sizes[1]; 21 | auto n_feature_dims = static_cast<int>(feature_sizes[2]); 22 | auto n_pixels = static_cast<int>(h * w); 23 | auto n_channels = static_cast<int>(input_sizes[0]); 24 | 25 | // Validate the arguments 26 | if (input_sizes[1] != h || input_sizes[2] != w) { 27 | throw std::runtime_error("Sizes of `input_values` and `features` do not match!"); 28 | } 29 | 30 | if (!(input_values.dtype() == torch::kFloat32)) { 31 | throw std::runtime_error("`input_values` must have float32 type."); 32 | } 33 | 34 | if (!(features.dtype() == torch::kFloat32)) { 35 | throw std::runtime_error("`features` must have float32 type."); 36 | } 37 | 38 | // Create the output tensor 39 | auto options = torch::TensorOptions() 40 | .dtype(torch::kFloat32) 41 | .layout(torch::kStrided) 42 | .device(torch::kCPU) 43 | .requires_grad(false); 44 | 45 | auto output_values = torch::empty(input_sizes, options); 46 | output_values = output_values.contiguous(); 47 | 48 | Permutohedral p; 49 | p.init(features.contiguous().data_ptr<float>(), n_feature_dims, n_pixels); 50 | p.compute(output_values.data_ptr<float>(), input_values.contiguous().data_ptr<float>(), n_channels); 51 | 52 | return {output_values}; 53 | } 54 | 55 | 56 | std::vector<torch::Tensor> permuto_backward(torch::Tensor grads, torch::Tensor features) { 57 | 58 | auto grad_sizes = grads.sizes(); // (channels, height, width) 59 | auto feature_sizes = features.sizes(); // (height, width, num_features) 60 | 61 | auto h = feature_sizes[0]; 62 | auto w = feature_sizes[1]; 63 | auto n_feature_dims = static_cast<int>(feature_sizes[2]); 64 | auto n_pixels = static_cast<int>(h * w); 65 | auto n_channels = static_cast<int>(grad_sizes[0]); 66 | 67 | //
Validate the arguments 68 | if (grad_sizes[1] != h || grad_sizes[2] != w) { 69 | throw std::runtime_error("Sizes of `grads` and `features` do not match!"); 70 | } 71 | 72 | if (!(grads.dtype() == torch::kFloat32)) { 73 | throw std::runtime_error("`grads` must have float32 type."); 74 | } 75 | 76 | if (!(features.dtype() == torch::kFloat32)) { 77 | throw std::runtime_error("`features` must have float32 type."); 78 | } 79 | 80 | // Create the output tensor 81 | auto options = torch::TensorOptions() 82 | .dtype(torch::kFloat32) 83 | .layout(torch::kStrided) 84 | .device(torch::kCPU) 85 | .requires_grad(false); 86 | 87 | auto grads_back = torch::empty(grad_sizes, options); 88 | grads_back = grads_back.contiguous(); 89 | 90 | Permutohedral p; 91 | p.init(features.contiguous().data_ptr<float>(), n_feature_dims, n_pixels); 92 | p.compute(grads_back.data_ptr<float>(), grads.contiguous().data_ptr<float>(), n_channels, true); 93 | 94 | return {grads_back}; 95 | } 96 | 97 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 98 | m.def("forward", &permuto_forward, "PERMUTO forward"); 99 | m.def("backward", &permuto_backward, "PERMUTO backward"); 100 | } 101 | -------------------------------------------------------------------------------- /crfasrnn/permutohedral.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | This file contains a modified version of the "permutohedral.cpp" code 3 | available at http://graphics.stanford.edu/projects/drf/. Copyright notice of 4 | the original file is included below: 5 | 6 | Copyright (c) 2013, Philipp Krähenbühl 7 | All rights reserved. 8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions are met: 11 | * Redistributions of source code must retain the above copyright 12 | notice, this list of conditions and the following disclaimer.
13 | * Redistributions in binary form must reproduce the above copyright 14 | notice, this list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | * Neither the name of the Stanford University nor the 17 | names of its contributors may be used to endorse or promote products 18 | derived from this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY Philipp Krähenbühl ''AS IS'' AND ANY 21 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 22 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL Philipp Krähenbühl BE LIABLE FOR ANY 24 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 25 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 26 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 27 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 29 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | */ 31 | 32 | //#include "stdafx.h" 33 | #include "permutohedral.h" 34 | 35 | #ifdef __SSE__ 36 | // SSE Permutohedral lattice 37 | # define SSE_PERMUTOHEDRAL 38 | #endif 39 | 40 | #if defined(SSE_PERMUTOHEDRAL) 41 | # include <emmintrin.h> 42 | # include <xmmintrin.h> 43 | # ifdef __SSE4_1__ 44 | # include <smmintrin.h> 45 | # endif 46 | #endif 47 | 48 | 49 | /************************************************/ 50 | /*** Hash Table ***/ 51 | /************************************************/ 52 | 53 | class HashTable{ 54 | protected: 55 | size_t key_size_, filled_, capacity_; 56 | std::vector< short > keys_; 57 | std::vector< int > table_; 58 | void grow(){ 59 | // Create the new memory and copy the values in 60 | int old_capacity = capacity_; 61 | capacity_ *= 2; 62 | std::vector<short> old_keys( (old_capacity+10)*key_size_ ); 63 | std::copy( keys_.begin(), keys_.end(), old_keys.begin() ); 64 | std::vector<int> old_table( capacity_, -1 ); 65 | 66 | // Swap the memory 67 | table_.swap( old_table ); 68 | keys_.swap( old_keys ); 69 | 70 | // Reinsert each element 71 | for( int i=0; i= 0){ 73 | int e = old_table[i]; 74 | size_t h = hash( getKey(e) ) % capacity_; 75 | for(; table_[h] >= 0; h = h= capacity_) grow(); 99 | // Get the hash value 100 | size_t h = hash( k ) % capacity_; 101 | // Find the element with the right key, using linear probing 102 | while(1){ 103 | int e = table_[h]; 104 | if (e==-1){ 105 | if (create){ 106 | // Insert a new key and return the new id 107 | for( size_t i=0; i0; j-- ){ 202 | __m128 cf = f[j-1]*scale_factor[j-1]; 203 | elevated[j] = sm - _mm_set1_ps(j)*cf; 204 | sm += cf; 205 | } 206 | elevated[0] = sm; 207 | 208 | // Find the closest 0-colored simplex through rounding 209 | __m128 sum = Zero; 210 | for( int i=0; i<=d_; i++ ){ 211 | __m128 v = invdplus1 * elevated[i]; 212 | #ifdef __SSE4_1__ 213 | v = _mm_round_ps( v, _MM_FROUND_TO_NEAREST_INT ); 214 | #else 215 | v = _mm_cvtepi32_ps( _mm_cvtps_epi32( v ) ); 216 | #endif 217 | rem0[i] = v*dplus1; 218 | sum += v; 219 | } 220 | 221 | //
Find the simplex we are in and store it in rank (where rank describes what position coordinate i has in the sorted order of the feature values) 222 | for( int i=0; i<=d_; i++ ) 223 | rank[i] = Zero; 224 | for( int i=0; i0; j-- ){ 366 | float cf = f[j-1]*scale_factor[j-1]; 367 | elevated[j] = sm - j*cf; 368 | sm += cf; 369 | } 370 | elevated[0] = sm; 371 | 372 | // Find the closest 0-colored simplex through rounding 373 | float down_factor = 1.0f / (d_+1); 374 | float up_factor = (d_+1); 375 | int sum = 0; 376 | for( int i=0; i<=d_; i++ ){ 377 | //int rd1 = round( down_factor * elevated[i]); 378 | int rd2; 379 | float v = down_factor * elevated[i]; 380 | float up = ceilf(v)*up_factor; 381 | float down = floorf(v)*up_factor; 382 | if (up - elevated[i] < elevated[i] - down) rd2 = (short)up; 383 | else rd2 = (short)down; 384 | 385 | //if(rd1!=rd2) 386 | // break; 387 | 388 | rem0[i] = rd2; 389 | sum += rd2*down_factor; 390 | } 391 | 392 | // Find the simplex we are in and store it in rank (where rank describes what position coordinate i has in the sorted order of the feature values) 393 | for( int i=0; i<=d_; i++ ) 394 | rank[i] = 0; 395 | for( int i=0; i d_ ){ 412 | rank[i] -= d_+1; 413 | rem0[i] -= d_+1; 414 | } 415 | } 416 | 417 | // Compute the barycentric coordinates (p.10 in [Adams et al. 2010]) 418 | for( int i=0; i<=d_+1; i++ ) 419 | barycentric[i] = 0; 420 | for( int i=0; i<=d_; i++ ){ 421 | float v = (elevated[i] - rem0[i])*down_factor; 422 | barycentric[d_-rank[i] ] += v; 423 | barycentric[d_-rank[i]+1] -= v; 424 | } 425 | // Wrap around 426 | barycentric[0] += 1.0 + barycentric[d_+1]; 427 | 428 | // Compute all vertices and their offset 429 | for( int remainder=0; remainder<=d_; remainder++ ){ 430 | for( int i=0; i 0 (used for blurring) 479 | float * values = new float[ (M_+2)*value_size ]; 480 | float * new_values = new float[ (M_+2)*value_size ]; 481 | 482 | for( int i=0; i<(M_+2)*value_size; i++ ) 483 | values[i] = new_values[i] = 0; 484 | 485 | //
Splatting 486 | for( int i=0; i=0; reverse?j--:j++ ){ 496 | for( int i=0; i 0 (used for blurring) 536 | __m128 * sse_val = (__m128*) _mm_malloc( sse_value_size*sizeof(__m128), 16 ); 537 | __m128 * values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 ); 538 | __m128 * new_values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 ); 539 | 540 | __m128 Zero = _mm_set1_ps( 0 ); 541 | 542 | for( int i=0; i<(M_+2)*sse_value_size; i++ ) 543 | values[i] = new_values[i] = Zero; 544 | for( int i=0; i=0; reverse?j--:j++ ){ 568 | for( int i=0; i 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | 39 | 40 | /************************************************/ 41 | /*** Permutohedral Lattice ***/ 42 | /************************************************/ 43 | class Permutohedral { 44 | protected: 45 | struct Neighbors { 46 | int n1, n2; 47 | 48 | Neighbors(int n1 = 0, int n2 = 0) : n1(n1), n2(n2) { 49 | } 50 | }; 51 | 52 | std::vector<int> offset_, rank_; 53 | std::vector<float> barycentric_; 54 | std::vector<Neighbors> blur_neighbors_; 55 | // Number of elements, size of sparse discretized space, dimension of features 56 | int N_, M_, d_; 57 | 58 | void sseCompute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const; 59 | 60 | void seqCompute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const; 61 | 62 | public: 63 | Permutohedral(); 64 | 65 | void init(const float *features, int num_dimensions, int num_points); 66 | 67 | void compute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const; 68 | }; 69 | -------------------------------------------------------------------------------- /crfasrnn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension 2 | from torch.utils import cpp_extension 3 | 4 | setup(name='permuto_cpp', 5 |
ext_modules=[cpp_extension.CppExtension('permuto_cpp', ['permuto.cpp', 'permutohedral.cpp'])], 6 | cmdclass={'build_ext': cpp_extension.BuildExtension}) 7 | -------------------------------------------------------------------------------- /crfasrnn/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Sadeep Jayasumana 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | import numpy as np 26 | from PIL import Image 27 | 28 | # Pascal VOC color palette for labels 29 | _PALETTE = [0, 0, 0, 30 | 128, 0, 0, 31 | 0, 128, 0, 32 | 128, 128, 0, 33 | 0, 0, 128, 34 | 128, 0, 128, 35 | 0, 128, 128, 36 | 128, 128, 128, 37 | 64, 0, 0, 38 | 192, 0, 0, 39 | 64, 128, 0, 40 | 192, 128, 0, 41 | 64, 0, 128, 42 | 192, 0, 128, 43 | 64, 128, 128, 44 | 192, 128, 128, 45 | 0, 64, 0, 46 | 128, 64, 0, 47 | 0, 192, 0, 48 | 128, 192, 0, 49 | 0, 64, 128, 50 | 128, 64, 128, 51 | 0, 192, 128, 52 | 128, 192, 128, 53 | 64, 64, 0, 54 | 192, 64, 0, 55 | 64, 192, 0, 56 | 192, 192, 0] 57 | 58 | _IMAGENET_MEANS = np.array([123.68, 116.779, 103.939], dtype=np.float32) # RGB mean values 59 | 60 | 61 | def get_preprocessed_image(file_name): 62 | """ 63 | Reads an image from the disk, pre-processes it by subtracting mean etc. and 64 | returns a numpy array that's ready to be fed into the PyTorch model. 65 | 66 | Args: 67 | file_name: File to read the image from 68 | 69 | Returns: 70 | A tuple containing: 71 | 72 | (preprocessed image, img_h, img_w, original width & height) 73 | """ 74 | 75 | image = Image.open(file_name) 76 | original_size = image.size 77 | w, h = original_size 78 | ratio = min(500.0 / w, 500.0 / h) 79 | image = image.resize((int(w * ratio), int(h * ratio)), resample=Image.BILINEAR) 80 | im = np.array(image).astype(np.float32) 81 | assert im.ndim == 3, 'Only RGB images are supported.' 82 | im = im[:, :, :3] 83 | im = im - _IMAGENET_MEANS 84 | im = im[:, :, ::-1] # Convert to BGR 85 | img_h, img_w, _ = im.shape 86 | 87 | pad_h = 500 - img_h 88 | pad_w = 500 - img_w 89 | im = np.pad(im, pad_width=((0, pad_h), (0, pad_w), (0, 0)), mode='constant', constant_values=0) 90 | return np.expand_dims(im.transpose([2, 0, 1]), 0), img_h, img_w, original_size 91 | 92 | 93 | def get_label_image(probs, img_h, img_w, original_size): 94 | """ 95 | Returns the label image (PNG with Pascal VOC colormap) given the probabilities. 
96 | 97 | Args: 98 | probs: Probability output of shape (num_labels, height, width) 99 | img_h: Image height 100 | img_w: Image width 101 | original_size: Original image size (width, height) 102 | 103 | Returns: 104 | Label image as a PIL Image 105 | """ 106 | 107 | labels = probs.argmax(axis=0).astype('uint8')[:img_h, :img_w] 108 | label_im = Image.fromarray(labels, 'P') 109 | label_im.putpalette(_PALETTE) 110 | label_im = label_im.resize(original_size) 111 | return label_im 112 | -------------------------------------------------------------------------------- /image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sadeepj/crfasrnn_pytorch/24899c528981dfbc14f1212869a3eae328dc6570/image.jpg -------------------------------------------------------------------------------- /quick_run.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Sadeep Jayasumana 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | import argparse 26 | 27 | import torch 28 | 29 | from crfasrnn import util 30 | from crfasrnn.crfasrnn_model import CrfRnnNet 31 | 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser() 35 | 36 | parser.add_argument( 37 | "--weights", 38 | help="Path to the .pth file (download from https://tinyurl.com/crfasrnn-weights-pth)", 39 | required=True, 40 | ) 41 | parser.add_argument("--image", help="Path to the input image", required=True) 42 | parser.add_argument("--output", help="Path to the output label image", default=None) 43 | args = parser.parse_args() 44 | 45 | img_data, img_h, img_w, size = util.get_preprocessed_image(args.image) 46 | 47 | output_file = args.output or args.image + "_labels.png" 48 | 49 | model = CrfRnnNet() 50 | model.load_state_dict(torch.load(args.weights)) 51 | model.eval() 52 | out = model.forward(torch.from_numpy(img_data)) 53 | 54 | probs = out.detach().numpy()[0] 55 | label_im = util.get_label_image(probs, img_h, img_w, size) 56 | label_im.save(output_file) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | Pillow 4 | 5 | -------------------------------------------------------------------------------- /run_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Sadeep Jayasumana 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to
deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | import torch 25 | 26 | from crfasrnn import util 27 | from crfasrnn.crfasrnn_model import CrfRnnNet 28 | 29 | 30 | def main(): 31 | input_file = "image.jpg" 32 | output_file = "labels.png" 33 | 34 | # Read the image 35 | img_data, img_h, img_w, size = util.get_preprocessed_image(input_file) 36 | 37 | # Download the model from https://tinyurl.com/crfasrnn-weights-pth 38 | saved_weights_path = "crfasrnn_weights.pth" 39 | 40 | model = CrfRnnNet() 41 | model.load_state_dict(torch.load(saved_weights_path)) 42 | model.eval() 43 | out = model.forward(torch.from_numpy(img_data)) 44 | 45 | probs = out.detach().numpy()[0] 46 | label_im = util.get_label_image(probs, img_h, img_w, size) 47 | label_im.save(output_file) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /sample.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sadeepj/crfasrnn_pytorch/24899c528981dfbc14f1212869a3eae328dc6570/sample.png --------------------------------------------------------------------------------
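
The sequential branch of `Permutohedral::init` in `crfasrnn/permutohedral.cpp` (elevation, rounding to the nearest 0-colored point, ranking, barycentric weights, and simplex vertices) is the core of the permutohedral-lattice embedding from Adams et al. 2010. The NumPy sketch below mirrors those steps for a single feature vector so they can be inspected in isolation. It is illustrative only: `lattice_embed` is a hypothetical helper, the per-axis scale `1/sqrt((i+1)(i+2))` is an assumption (the `scale_factor` setup is not visible in this dump), and the rounding offset is tracked as an exact integer rather than through the C++ `sum += rd2*down_factor` float accumulation.

```python
import numpy as np


def lattice_embed(features, inv_std=1.0):
    """Hypothetical helper: embed one d-dim feature vector in the permutohedral lattice.

    Returns (elevated, weights, vertices): the point lifted onto the plane
    sum(x) == 0 in R^(d+1), the barycentric weights of the enclosing simplex,
    and that simplex's d+1 lattice vertices (row r is the remainder-r vertex).
    """
    f = np.asarray(features, dtype=np.float64)
    d = f.shape[0]
    # Assumption: per-axis scale 1/sqrt((i+1)(i+2)), multiplied by an inverse std dev
    idx = np.arange(d, dtype=np.float64)
    scale = inv_std / np.sqrt((idx + 1.0) * (idx + 2.0))

    # Elevate onto the hyperplane sum(elevated) == 0 (same loop shape as the C++)
    elevated = np.zeros(d + 1)
    sm = 0.0
    for j in range(d, 0, -1):
        cf = f[j - 1] * scale[j - 1]
        elevated[j] = sm - j * cf
        sm += cf
    elevated[0] = sm

    # Round each coordinate to the nearest multiple of d+1 (the 0-colored point)
    rem0 = np.zeros(d + 1, dtype=np.int64)
    total = 0  # sum(rem0) / (d+1), kept as an exact integer
    for i in range(d + 1):
        v = elevated[i] / (d + 1)
        up, down = np.ceil(v), np.floor(v)
        k = up if up * (d + 1) - elevated[i] < elevated[i] - down * (d + 1) else down
        rem0[i] = int(k) * (d + 1)
        total += int(k)

    # rank[i]: position of coordinate i when residuals are sorted in descending order
    rank = np.zeros(d + 1, dtype=np.int64)
    for i in range(d):
        di = elevated[i] - rem0[i]
        for j in range(i + 1, d + 1):
            if di < elevated[j] - rem0[j]:
                rank[i] += 1
            else:
                rank[j] += 1

    # If rem0 is off the plane (total != 0), shift ranks and rem0 to bring it back
    for i in range(d + 1):
        rank[i] += total
        if rank[i] < 0:
            rank[i] += d + 1
            rem0[i] += d + 1
        elif rank[i] > d:
            rank[i] -= d + 1
            rem0[i] -= d + 1

    # Barycentric weights, including the final wrap-around step
    bary = np.zeros(d + 2)
    for i in range(d + 1):
        v = (elevated[i] - rem0[i]) / (d + 1)
        bary[d - rank[i]] += v
        bary[d - rank[i] + 1] -= v
    bary[0] += 1.0 + bary[d + 1]

    # Vertex for remainder r: rem0 + r, minus (d+1) where rank[i] > d - r
    vertices = np.array(
        [[rem0[i] + r - ((d + 1) if rank[i] > d - r else 0) for i in range(d + 1)]
         for r in range(d + 1)],
        dtype=np.float64,
    )
    return elevated, bary[: d + 1], vertices


if __name__ == "__main__":
    elevated, weights, vertices = lattice_embed([0.3, -1.2, 2.5])
    print("weights:", weights, "sum:", weights.sum())
```

For any input, the weights should be non-negative and sum to 1, and `weights @ vertices` should reproduce the elevated point. Splatting distributes a value to these `d+1` vertices with the same weights, and slicing reads it back with them, which is why `init` stores `barycentric_` and `offset_` per point.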