├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── crfasrnn
    ├── __init__.py
    ├── crfasrnn_model.py
    ├── crfrnn.py
    ├── fcn8s.py
    ├── filters.py
    ├── params.py
    ├── permuto.cpp
    ├── permutohedral.cpp
    ├── permutohedral.h
    ├── setup.py
    └── util.py
├── image.jpg
├── quick_run.py
├── requirements.txt
├── run_demo.py
└── sample.png
/.gitattributes:
--------------------------------------------------------------------------------
1 | crfasrnn/permutohedral.cpp linguist-vendored
2 | crfasrnn/permutohedral.h linguist-vendored
3 |
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
3 | *.pyc
4 |
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Sadeep Jayasumana
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CRF-RNN for Semantic Image Segmentation - PyTorch version
2 | 
3 |
4 | Live demo: [http://crfasrnn.torr.vision](http://crfasrnn.torr.vision)
5 | Caffe version: [http://github.com/torrvision/crfasrnn](http://github.com/torrvision/crfasrnn)
6 | Tensorflow/Keras version: [http://github.com/sadeepj/crfasrnn_keras](http://github.com/sadeepj/crfasrnn_keras)
7 |
8 | This repository contains the official PyTorch implementation of the "CRF-RNN" semantic image segmentation method, published in the ICCV 2015 paper [Conditional Random Fields as Recurrent Neural Networks](http://www.robots.ox.ac.uk/~szheng/papers/CRFasRNN.pdf). The [online demo](http://crfasrnn.torr.vision) of this project won the Best Demo Prize at ICCV 2015. Results of this PyTorch code are identical to those of the Caffe and Tensorflow/Keras based versions above.
9 |
10 | If you use this code/model for your research, please cite the following paper:
11 | ```
12 | @inproceedings{crfasrnn_ICCV2015,
13 | author = {Shuai Zheng and Sadeep Jayasumana and Bernardino Romera-Paredes and Vibhav Vineet and
14 | Zhizhong Su and Dalong Du and Chang Huang and Philip H. S. Torr},
15 | title = {Conditional Random Fields as Recurrent Neural Networks},
16 | booktitle = {International Conference on Computer Vision (ICCV)},
17 | year = {2015}
18 | }
19 | ```
20 |
21 | ## Installation Guide
22 |
23 | _Note_: If you are using a Python virtualenv, make sure it is activated before running each command in this guide.
24 |
25 | ### Step 1: Clone the repository
26 | ```
27 | $ git clone https://github.com/sadeepj/crfasrnn_pytorch.git
28 | ```
29 | The root directory of the clone will be referred to as `crfasrnn_pytorch` hereafter.
30 |
31 | ### Step 2: Install dependencies
32 |
33 |
34 | Use the `requirements.txt` file in this repository to install all the dependencies via `pip`:
35 | ```
36 | $ cd crfasrnn_pytorch
37 | $ pip install -r requirements.txt
38 | ```
39 |
40 | After installing the dependencies, run the following commands to make sure they are properly installed:
41 | ```
42 | $ python
43 | >>> import torch
44 | ```
45 | You should not see any errors while importing `torch` above.
46 |
47 | ### Step 3: Build CRF-RNN custom op
48 |
49 | Run `setup.py` inside the `crfasrnn_pytorch/crfasrnn` directory:
50 | ```
51 | $ cd crfasrnn_pytorch/crfasrnn
52 | $ python setup.py install
53 | ```
54 | Note that the `python` command in the console should refer to the Python interpreter associated with your PyTorch installation.
55 |
56 | ### Step 4: Download the pre-trained model weights
57 |
58 | Download the model weights from [here](https://github.com/sadeepj/crfasrnn_pytorch/releases/download/0.0.1/crfasrnn_weights.pth) and place the file in the `crfasrnn_pytorch` directory with the file name `crfasrnn_weights.pth`.
59 |
60 | ### Step 5: Run the demo
61 | ```
62 | $ cd crfasrnn_pytorch
63 | $ python run_demo.py
64 | ```
65 | If all goes well, you will see the segmentation results in a file named "labels.png".
66 |
67 | ## Contributors
68 | * Sadeep Jayasumana ([sadeepj](https://github.com/sadeepj))
69 | * Harsha Ranasinghe ([HarshaPrabhath](https://github.com/HarshaPrabhath))
70 |
71 |
--------------------------------------------------------------------------------
/crfasrnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sadeepj/crfasrnn_pytorch/24899c528981dfbc14f1212869a3eae328dc6570/crfasrnn/__init__.py
--------------------------------------------------------------------------------
/crfasrnn/crfasrnn_model.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2019 Sadeep Jayasumana
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 | from crfasrnn.crfrnn import CrfRnn
26 | from crfasrnn.fcn8s import Fcn8s
27 |
28 |
29 | class CrfRnnNet(Fcn8s):
30 | """
31 | The full CRF-RNN network with the FCN-8s backbone as described in the paper:
32 |
33 | Conditional Random Fields as Recurrent Neural Networks,
34 | S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du, C. Huang and P. Torr,
35 | ICCV 2015 (https://arxiv.org/abs/1502.03240).
36 | """
37 |
38 | def __init__(self):
39 | super(CrfRnnNet, self).__init__()
40 | self.crfrnn = CrfRnn(num_labels=21, num_iterations=10)
41 |
42 | def forward(self, image):
43 | out = super(CrfRnnNet, self).forward(image)
44 | # Plug the CRF-RNN module at the end
45 | return self.crfrnn(image, out)
46 |
--------------------------------------------------------------------------------
/crfasrnn/crfrnn.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2019 Sadeep Jayasumana
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 | import torch
26 | import torch.nn as nn
27 |
28 | from crfasrnn.filters import SpatialFilter, BilateralFilter
29 | from crfasrnn.params import DenseCRFParams
30 |
31 |
32 | class CrfRnn(nn.Module):
33 | """
34 | PyTorch implementation of the CRF-RNN module described in the paper:
35 |
36 | Conditional Random Fields as Recurrent Neural Networks,
37 | S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du, C. Huang and P. Torr,
38 | ICCV 2015 (https://arxiv.org/abs/1502.03240).
39 | """
40 |
41 | def __init__(self, num_labels, num_iterations=5, crf_init_params=None):
42 | """
43 | Create a new instance of the CRF-RNN layer.
44 |
45 | Args:
46 | num_labels: Number of semantic labels in the dataset
47 | num_iterations: Number of mean-field iterations to perform
48 | crf_init_params: CRF initialization parameters
49 | """
50 | super(CrfRnn, self).__init__()
51 |
52 | if crf_init_params is None:
53 | crf_init_params = DenseCRFParams()
54 |
55 | self.params = crf_init_params
56 | self.num_iterations = num_iterations
57 |
58 | self._softmax = torch.nn.Softmax(dim=0)
59 |
60 | self.num_labels = num_labels
61 |
62 | # --------------------------------------------------------------------------------------------
63 | # --------------------------------- Trainable Parameters -------------------------------------
64 | # --------------------------------------------------------------------------------------------
65 |
66 | # Spatial kernel weights
67 | self.spatial_ker_weights = nn.Parameter(
68 | crf_init_params.spatial_ker_weight
69 | * torch.eye(num_labels, dtype=torch.float32)
70 | )
71 |
72 | # Bilateral kernel weights
73 | self.bilateral_ker_weights = nn.Parameter(
74 | crf_init_params.bilateral_ker_weight
75 | * torch.eye(num_labels, dtype=torch.float32)
76 | )
77 |
78 | # Compatibility transform matrix
79 | self.compatibility_matrix = nn.Parameter(
80 | torch.eye(num_labels, dtype=torch.float32)
81 | )
82 |
83 | def forward(self, image, logits):
84 | """
85 | Perform CRF inference.
86 |
87 | Args:
88 | image: Tensor of shape (3, h, w) containing the RGB image
89 | logits: Tensor of shape (num_classes, h, w) containing the unary logits
90 | Returns:
91 | log-Q distributions (logits) after CRF inference
92 | """
93 | if logits.shape[0] != 1:
94 | raise ValueError("Only batch size 1 is currently supported!")
95 |
96 | image = image[0]
97 | logits = logits[0]
98 |
99 | spatial_filter = SpatialFilter(image, gamma=self.params.gamma)
100 | bilateral_filter = BilateralFilter(
101 | image, alpha=self.params.alpha, beta=self.params.beta
102 | )
103 | _, h, w = image.shape
104 | cur_logits = logits
105 |
106 | for _ in range(self.num_iterations):
107 | # Normalization
108 | q_values = self._softmax(cur_logits)
109 |
110 | # Spatial filtering
111 | spatial_out = torch.mm(
112 | self.spatial_ker_weights,
113 | spatial_filter.apply(q_values).view(self.num_labels, -1),
114 | )
115 |
116 | # Bilateral filtering
117 | bilateral_out = torch.mm(
118 | self.bilateral_ker_weights,
119 | bilateral_filter.apply(q_values).view(self.num_labels, -1),
120 | )
121 |
122 | # Compatibility transform
123 | msg_passing_out = (
124 | spatial_out + bilateral_out
125 | ) # Shape: (self.num_labels, -1)
126 | msg_passing_out = torch.mm(self.compatibility_matrix, msg_passing_out).view(
127 | self.num_labels, h, w
128 | )
129 |
130 | # Adding unary potentials
131 | cur_logits = msg_passing_out + logits
132 |
133 | return torch.unsqueeze(cur_logits, 0)
134 |
--------------------------------------------------------------------------------
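Each pass through the loop in `CrfRnn.forward` is one mean-field update: softmax normalisation, Gaussian message passing with the spatial and bilateral filters, per-label kernel weighting, the compatibility transform, and adding the unary logits back. The NumPy sketch below mirrors those steps on a tiny image so the shapes can be followed end to end. It is an illustration under stated assumptions, not the repository's implementation: the permutohedral lattice is replaced by a hypothetical brute-force `dense_gaussian_filter` (exact, but O(N²) in the number of pixels, and assuming a unit-variance Gaussian on the scaled features), and `mean_field` is likewise a hypothetical helper whose default bandwidths copy `DenseCRFParams`.

```python
import numpy as np


def softmax(logits):
    # Softmax over the label axis (axis 0), like torch.nn.Softmax(dim=0) in CrfRnn.
    shifted = logits - logits.max(axis=0, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=0, keepdims=True)


def dense_gaussian_filter(values, features):
    """Hypothetical brute-force stand-in for the permutohedral lattice.

    values: (L, N) array; features: (N, D) array of already-scaled features.
    """
    diff = features[:, None, :] - features[None, :, :]
    kernel = np.exp(-0.5 * (diff ** 2).sum(axis=-1))  # (N, N), includes each pixel itself
    norm = kernel.sum(axis=1)  # response to an all-ones input, as in AbstractFilter._calc_norm
    return (values @ kernel.T) / norm


def mean_field(image, logits, spatial_w, bilateral_w, compat,
               alpha=160.0, beta=3.0, gamma=3.0, num_iterations=5):
    """image: (3, h, w), logits: (L, h, w). Returns refined logits of shape (L, h, w)."""
    num_labels, h, w = logits.shape
    yy, xx = np.mgrid[0:h, 0:w].astype(np.float64)
    xy = np.stack([xx.ravel(), yy.ravel()], axis=1)        # (N, 2) pixel coordinates [x, y]
    rgb = image.reshape(3, -1).T                           # (N, 3) colours, channel last
    spatial_feats = xy / gamma
    bilateral_feats = np.concatenate([xy / alpha, rgb / beta], axis=1)

    unary = logits.reshape(num_labels, -1)
    cur = unary
    for _ in range(num_iterations):
        q = softmax(cur)                                   # normalisation
        spatial_out = spatial_w @ dense_gaussian_filter(q, spatial_feats)
        bilateral_out = bilateral_w @ dense_gaussian_filter(q, bilateral_feats)
        msg = compat @ (spatial_out + bilateral_out)       # compatibility transform
        cur = msg + unary                                  # add the unary logits back
    return cur.reshape(num_labels, h, w)


rng = np.random.default_rng(0)
image = rng.uniform(0.0, 255.0, size=(3, 4, 5))
logits = rng.normal(size=(3, 4, 5))
eye = np.eye(3)
refined = mean_field(image, logits, 3.0 * eye, 5.0 * eye, eye)
print(refined.shape)  # (3, 4, 5)
```

As in the repository, the transformed messages are added to the unary logits and the compatibility matrix starts from the identity. The quadratic cost of the dense kernel limits this sketch to toy inputs; avoiding that cost is what the permutohedral lattice is for.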
/crfasrnn/fcn8s.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains a modified version of the FCN-8s code available in https://github.com/wkentaro/pytorch-fcn
3 | The original copyright notice from that repository is included below:
4 |
5 | Copyright (c) 2017 - 2019 Kentaro Wada.
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in
15 | all copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 | THE SOFTWARE.
24 | """
25 |
26 | import numpy as np
27 | import torch
28 | import torch.nn as nn
29 |
30 |
31 | def _upsampling_weights(in_channels, out_channels, kernel_size):
32 | factor = (kernel_size + 1) // 2
33 | if kernel_size % 2 == 1:
34 | center = factor - 1
35 | else:
36 | center = factor - 0.5
37 | og = np.ogrid[:kernel_size, :kernel_size]
38 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
39 | weight = np.zeros(
40 | (in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64
41 | )
42 | weight[range(in_channels), range(out_channels), :, :] = filt
43 | return torch.from_numpy(weight).float()
44 |
45 |
46 | class Fcn8s(nn.Module):
47 | def __init__(self, n_class=21):
48 | """
49 | Create the FCN-8s network with the given number of classes.
50 |
51 | Args:
52 | n_class: The number of semantic classes.
53 | """
54 |
55 | super(Fcn8s, self).__init__()
56 |
57 | # conv1
58 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
59 | self.relu1_1 = nn.ReLU(inplace=True)
60 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
61 | self.relu1_2 = nn.ReLU(inplace=True)
62 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
63 |
64 | # conv2
65 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
66 | self.relu2_1 = nn.ReLU(inplace=True)
67 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
68 | self.relu2_2 = nn.ReLU(inplace=True)
69 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
70 |
71 | # conv3
72 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
73 | self.relu3_1 = nn.ReLU(inplace=True)
74 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
75 | self.relu3_2 = nn.ReLU(inplace=True)
76 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
77 | self.relu3_3 = nn.ReLU(inplace=True)
78 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
79 |
80 | # conv4
81 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
82 | self.relu4_1 = nn.ReLU(inplace=True)
83 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
84 | self.relu4_2 = nn.ReLU(inplace=True)
85 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
86 | self.relu4_3 = nn.ReLU(inplace=True)
87 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
88 |
89 | # conv5
90 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
91 | self.relu5_1 = nn.ReLU(inplace=True)
92 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
93 | self.relu5_2 = nn.ReLU(inplace=True)
94 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
95 | self.relu5_3 = nn.ReLU(inplace=True)
96 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
97 |
98 | # fc6
99 | self.fc6 = nn.Conv2d(512, 4096, 7)
100 | self.relu6 = nn.ReLU(inplace=True)
101 | self.drop6 = nn.Dropout2d()
102 |
103 | # fc7
104 | self.fc7 = nn.Conv2d(4096, 4096, 1)
105 | self.relu7 = nn.ReLU(inplace=True)
106 | self.drop7 = nn.Dropout2d()
107 |
108 | self.score_fr = nn.Conv2d(4096, n_class, 1)
109 | self.score_pool3 = nn.Conv2d(256, n_class, 1)
110 | self.score_pool4 = nn.Conv2d(512, n_class, 1)
111 |
112 | self.upscore2 = nn.ConvTranspose2d(n_class, n_class, 4, stride=2, bias=True)
113 | self.upscore8 = nn.ConvTranspose2d(n_class, n_class, 16, stride=8, bias=False)
114 | self.upscore_pool4 = nn.ConvTranspose2d(
115 | n_class, n_class, 4, stride=2, bias=False
116 | )
117 |
118 | self._initialize_weights()
119 |
120 | def _initialize_weights(self):
121 | for m in self.modules():
122 | if isinstance(m, nn.Conv2d):
123 | m.weight.data.zero_()
124 | if m.bias is not None:
125 | m.bias.data.zero_()
126 | if isinstance(m, nn.ConvTranspose2d):
127 | assert m.kernel_size[0] == m.kernel_size[1]
128 | initial_weight = _upsampling_weights(
129 | m.in_channels, m.out_channels, m.kernel_size[0]
130 | )
131 | m.weight.data.copy_(initial_weight)
132 |
133 | def forward(self, image):
134 | h = self.relu1_1(self.conv1_1(image))
135 | h = self.relu1_2(self.conv1_2(h))
136 | h = self.pool1(h)
137 |
138 | h = self.relu2_1(self.conv2_1(h))
139 | h = self.relu2_2(self.conv2_2(h))
140 | h = self.pool2(h)
141 |
142 | h = self.relu3_1(self.conv3_1(h))
143 | h = self.relu3_2(self.conv3_2(h))
144 | h = self.relu3_3(self.conv3_3(h))
145 | h = self.pool3(h)
146 | pool3 = h # 1/8
147 |
148 | h = self.relu4_1(self.conv4_1(h))
149 | h = self.relu4_2(self.conv4_2(h))
150 | h = self.relu4_3(self.conv4_3(h))
151 | h = self.pool4(h)
152 | pool4 = h # 1/16
153 |
154 | h = self.relu5_1(self.conv5_1(h))
155 | h = self.relu5_2(self.conv5_2(h))
156 | h = self.relu5_3(self.conv5_3(h))
157 | h = self.pool5(h)
158 |
159 | h = self.relu6(self.fc6(h))
160 | h = self.drop6(h)
161 |
162 | h = self.relu7(self.fc7(h))
163 | h = self.drop7(h)
164 |
165 | h = self.score_fr(h)
166 | h = self.upscore2(h)
167 | upscore2 = h # 1/16
168 |
169 | h = self.score_pool4(pool4)
170 | h = h[:, :, 5:5 + upscore2.size()[2], 5:5 + upscore2.size()[3]]
171 | score_pool4c = h # 1/16
172 |
173 | h = upscore2 + score_pool4c # 1/16
174 | h = self.upscore_pool4(h)
175 | upscore_pool4 = h # 1/8
176 |
177 | h = self.score_pool3(pool3)
178 | h = h[:, :, 9:9 + upscore_pool4.size()[2], 9:9 + upscore_pool4.size()[3]]
179 | score_pool3c = h # 1/8
180 |
181 | h = upscore_pool4 + score_pool3c # 1/8
182 |
183 | h = self.upscore8(h)
184 | h = h[:, :, 31:31 + image.size()[2], 31:31 + image.size()[3]].contiguous()
185 |
186 | return h
187 |
--------------------------------------------------------------------------------
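The hard-coded crop offsets in `Fcn8s.forward` (5, 9 and 31) rely on `padding=100` in `conv1_1` keeping every intermediate map large enough. The plain-Python sketch below traces one spatial side length through the network using the output-size formulas for `Conv2d`, `MaxPool2d` with `ceil_mode=True`, and `ConvTranspose2d`, and reports how much of each map is available versus how much each crop needs. It also reproduces the bilinear filter built by `_upsampling_weights`. The helper names are hypothetical, and the formulas assume dilation 1 and no `output_padding`, which matches the layers defined above.

```python
import numpy as np


def conv_out(n, kernel, padding=0):
    # nn.Conv2d output size with stride 1 and dilation 1.
    return n + 2 * padding - kernel + 1


def pool_out(n):
    # nn.MaxPool2d(2, stride=2, ceil_mode=True) output size, i.e. ceil(n / 2) for n >= 2.
    return (n - 1) // 2 + 1


def deconv_out(n, kernel, stride):
    # nn.ConvTranspose2d output size with padding 0 and output_padding 0.
    return (n - 1) * stride + kernel


def trace_fcn8s(h):
    """Return {crop: (available size, size needed)} for an input side length h."""
    n = conv_out(h, 3, padding=100)   # conv1_1; all other 3x3 convs use padding=1 and keep the size
    pools = []
    for _ in range(5):                # pool1 .. pool5
        n = pool_out(n)
        pools.append(n)
    pool3, pool4, pool5 = pools[2], pools[3], pools[4]
    fc6 = conv_out(pool5, 7)          # 7x7 conv without padding; fc7 and score layers are 1x1
    upscore2 = deconv_out(fc6, 4, 2)
    upscore_pool4 = deconv_out(upscore2, 4, 2)
    upscore8 = deconv_out(upscore_pool4, 16, 8)
    return {
        "pool4": (pool4, 5 + upscore2),         # crop h[:, :, 5:5 + upscore2, ...]
        "pool3": (pool3, 9 + upscore_pool4),    # crop h[:, :, 9:9 + upscore_pool4, ...]
        "upscore8": (upscore8, 31 + h),         # crop h[:, :, 31:31 + image_h, ...]
    }


def bilinear_kernel(kernel_size):
    # The 2-D filter that _upsampling_weights writes into every (c, c) channel pair.
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)


print(trace_fcn8s(500))
# {'pool4': (44, 39), 'pool3': (88, 79), 'upscore8': (568, 531)}
```

For every side length from 1 to 1024 each crop fits inside its source map, so the fixed offsets do not depend on a particular input resolution in that range.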
/crfasrnn/filters.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2019 Sadeep Jayasumana
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 | from abc import ABC, abstractmethod
26 |
27 | import numpy as np
28 | import torch
29 |
30 | try:
31 | import permuto_cpp
32 | except ImportError as e:
33 | raise ImportError("Could not import `permuto_cpp`. Did you build it with `python setup.py install` and import `torch` first?") from e
34 |
35 | _CPU = torch.device("cpu")
36 | _EPS = np.finfo("float").eps
37 |
38 |
39 | class PermutoFunction(torch.autograd.Function):
40 |
41 | @staticmethod
42 | def forward(ctx, q_in, features):
43 | q_out = permuto_cpp.forward(q_in, features)[0]
44 | ctx.save_for_backward(features)
45 | return q_out
46 |
47 | @staticmethod
48 | def backward(ctx, grad_q_out):
49 | feature_saved = ctx.saved_tensors[0]
50 | grad_q_back = permuto_cpp.backward(
51 | grad_q_out.contiguous(), feature_saved.contiguous()
52 | )[0]
53 | return grad_q_back, None # No need of grads w.r.t. features
54 |
55 |
56 | def _spatial_features(image, sigma):
57 | """
58 | Return the spatial features as a Tensor
59 |
60 | Args:
61 | image: Image as a Tensor of shape (channels, height, width)
62 | sigma: Bandwidth parameter
63 |
64 | Returns:
65 | Tensor of shape [h, w, 2] with spatial features
66 | """
67 | sigma = float(sigma)
68 | _, h, w = image.size()
69 | x = torch.arange(start=0, end=w, dtype=torch.float32, device=_CPU)
70 | xx = x.repeat([h, 1]) / sigma
71 |
72 | y = torch.arange(
73 | start=0, end=h, dtype=torch.float32, device=_CPU
74 | ).view(-1, 1)
75 | yy = y.repeat([1, w]) / sigma
76 |
77 | return torch.stack([xx, yy], dim=2)
78 |
79 |
80 | class AbstractFilter(ABC):
81 | """
82 | Super-class for permutohedral-based Gaussian filters
83 | """
84 |
85 | def __init__(self, image):
86 | self.features = self._calc_features(image)
87 | self.norm = self._calc_norm(image)
88 |
89 | def apply(self, input_):
90 | output = PermutoFunction.apply(input_, self.features)
91 | return output * self.norm
92 |
93 | @abstractmethod
94 | def _calc_features(self, image):
95 | pass
96 |
97 | def _calc_norm(self, image):
98 | _, h, w = image.size()
99 | all_ones = torch.ones((1, h, w), dtype=torch.float32, device=_CPU)
100 | norm = PermutoFunction.apply(all_ones, self.features)
101 | return 1.0 / (norm + _EPS)
102 |
103 |
104 | class SpatialFilter(AbstractFilter):
105 | """
106 | Gaussian filter in the spatial ([x, y]) domain
107 | """
108 |
109 | def __init__(self, image, gamma):
110 | """
111 | Create new instance
112 |
113 | Args:
114 | image: Image tensor of shape (3, height, width)
115 | gamma: Standard deviation
116 | """
117 | self.gamma = gamma
118 | super(SpatialFilter, self).__init__(image)
119 |
120 | def _calc_features(self, image):
121 | return _spatial_features(image, self.gamma)
122 |
123 |
124 | class BilateralFilter(AbstractFilter):
125 | """
126 | Gaussian filter in the bilateral ([r, g, b, x, y]) domain
127 | """
128 |
129 | def __init__(self, image, alpha, beta):
130 | """
131 | Create new instance
132 |
133 | Args:
134 | image: Image tensor of shape (3, height, width)
135 | alpha: Smoothness (spatial) sigma
136 | beta: Appearance (color) sigma
137 | """
138 | self.alpha = alpha
139 | self.beta = beta
140 | super(BilateralFilter, self).__init__(image)
141 |
142 | def _calc_features(self, image):
143 | xy = _spatial_features(
144 | image, self.alpha
145 | ) # TODO Possible optimisation: the spatial kernel builds the same pixel grid (scaled by gamma instead of alpha)
146 | rgb = (image / float(self.beta)).permute(1, 2, 0) # Channel last order
147 | return torch.cat([xy, rgb], dim=2)
148 |
--------------------------------------------------------------------------------
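`AbstractFilter.apply` multiplies every filtered map by `self.norm`, the reciprocal of the filter's response to an all-ones image (`_calc_norm`). The NumPy sketch below illustrates why: without normalisation, pixels near the border receive less total kernel weight than interior pixels, whereas the normalised filter maps a constant field to the same constant. It is a sketch under assumptions: `spatial_features` mirrors `_spatial_features`, while the brute-force `gaussian_filter` and the `NormalisedFilter` class are hypothetical stand-ins for `PermutoFunction` and `AbstractFilter` that compute an exact Gaussian (assumed unit variance on the scaled features) of the kind the permutohedral lattice approximates.

```python
import numpy as np


def spatial_features(h, w, sigma):
    # NumPy mirror of _spatial_features: channel-last (h, w, 2) array of [x / sigma, y / sigma].
    xx = np.tile(np.arange(w, dtype=np.float32), (h, 1)) / sigma
    yy = np.tile(np.arange(h, dtype=np.float32).reshape(-1, 1), (1, w)) / sigma
    return np.stack([xx, yy], axis=2)


def gaussian_filter(values, features):
    # Exact brute-force Gaussian filter. values: (c, h, w); features: (h, w, d).
    c, h, w = values.shape
    f = features.reshape(h * w, -1)
    sq_dist = ((f[:, None, :] - f[None, :, :]) ** 2).sum(axis=-1)
    kernel = np.exp(-0.5 * sq_dist)
    return (values.reshape(c, -1) @ kernel.T).reshape(c, h, w)


class NormalisedFilter:
    # Hypothetical stand-in for AbstractFilter: apply() rescales by 1 / filter(ones).
    def __init__(self, features, eps=np.finfo("float").eps):
        self.features = features
        h, w, _ = features.shape
        ones = np.ones((1, h, w), dtype=np.float32)
        self.norm = 1.0 / (gaussian_filter(ones, features) + eps)

    def apply(self, values):
        return gaussian_filter(values, self.features) * self.norm


feats = spatial_features(5, 7, sigma=3.0)
raw = gaussian_filter(np.ones((1, 5, 7), dtype=np.float32), feats)
print(raw[0, 0, 0] < raw[0, 2, 3])  # True: the corner has fewer close neighbours than the centre
```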
/crfasrnn/params.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2019 Sadeep Jayasumana
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 |
26 | class DenseCRFParams(object):
27 | """
28 | Parameters for the DenseCRF model
29 | """
30 |
31 | def __init__(
32 | self,
33 | alpha=160.0,
34 | beta=3.0,
35 | gamma=3.0,
36 | spatial_ker_weight=3.0,
37 | bilateral_ker_weight=5.0,
38 | ):
39 | """
40 | Default values were taken from https://github.com/sadeepj/crfasrnn_keras. More details about these parameters
41 | can be found in https://arxiv.org/pdf/1210.5644.pdf
42 |
43 | Args:
44 | alpha: Bandwidth for the spatial component of the bilateral filter
45 | beta: Bandwidth for the color component of the bilateral filter
46 | gamma: Bandwidth for the spatial filter
47 | spatial_ker_weight: Spatial kernel weight
48 | bilateral_ker_weight: Bilateral kernel weight
49 | """
50 | self.alpha = alpha
51 | self.beta = beta
52 | self.gamma = gamma
53 | self.spatial_ker_weight = spatial_ker_weight
54 | self.bilateral_ker_weight = bilateral_ker_weight
55 |
--------------------------------------------------------------------------------
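For reference, the pairwise kernel in the DenseCRF paper cited in the docstring (Krähenbühl and Koltun) combines an appearance (bilateral) term and a smoothness (spatial) term, where $p$ denotes pixel positions and $I$ colour vectors. Mapping the constructor arguments onto the paper's symbols is our reading of the code rather than something the source states: `alpha`, `beta` and `gamma` play the roles of $\theta_\alpha$, $\theta_\beta$ and $\theta_\gamma$, and `bilateral_ker_weight` and `spatial_ker_weight` correspond to $w^{(1)}$ and $w^{(2)}$. In `CrfRnn` these two weights only initialise per-label diagonal matrices that are then trained.

```latex
k(\mathbf{f}_i, \mathbf{f}_j)
  = w^{(1)} \exp\!\left( -\frac{\lVert p_i - p_j \rVert^2}{2\theta_\alpha^2}
                         -\frac{\lVert I_i - I_j \rVert^2}{2\theta_\beta^2} \right)
  + w^{(2)} \exp\!\left( -\frac{\lVert p_i - p_j \rVert^2}{2\theta_\gamma^2} \right)
```

Dividing positions by `alpha` and colours by `beta` (as `BilateralFilter._calc_features` does), or positions by `gamma` (as `SpatialFilter` does), turns each exponent into $-\tfrac{1}{2}\lVert \mathbf{f}_i - \mathbf{f}_j \rVert^2$ on the scaled features.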
/crfasrnn/permuto.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <vector>
3 | #include <iostream>
4 | #include <stdexcept>
5 | #include "permutohedral.h"
6 |
7 | /**
8 | *
9 | * @param input_values Input values to filter (e.g. Q distributions). Has shape (channels, height, width)
10 | * @param features Features for the permutohedral lattice. Has shape (height, width, feature_channels). Note that
11 | * channels are at the end!
12 | * @return Filtered values with shape (channels, height, width)
13 | */
14 | std::vector<torch::Tensor> permuto_forward(torch::Tensor input_values, torch::Tensor features) {
15 |
16 | auto input_sizes = input_values.sizes(); // (channels, height, width)
17 | auto feature_sizes = features.sizes(); // (height, width, num_features)
18 |
19 | auto h = feature_sizes[0];
20 | auto w = feature_sizes[1];
21 | auto n_feature_dims = static_cast<int>(feature_sizes[2]);
22 | auto n_pixels = static_cast<int>(h * w);
23 | auto n_channels = static_cast<int>(input_sizes[0]);
24 |
25 | // Validate the arguments
26 | if (input_sizes[1] != h || input_sizes[2] != w) {
27 | throw std::runtime_error("Sizes of `input_values` and `features` do not match!");
28 | }
29 |
30 | if (!(input_values.dtype() == torch::kFloat32)) {
31 | throw std::runtime_error("`input_values` must have float32 type.");
32 | }
33 |
34 | if (!(features.dtype() == torch::kFloat32)) {
35 | throw std::runtime_error("`features` must have float32 type.");
36 | }
37 |
38 | // Create the output tensor
39 | auto options = torch::TensorOptions()
40 | .dtype(torch::kFloat32)
41 | .layout(torch::kStrided)
42 | .device(torch::kCPU)
43 | .requires_grad(false);
44 |
45 | auto output_values = torch::empty(input_sizes, options);
46 | output_values = output_values.contiguous();
47 |
48 | Permutohedral p;
49 | p.init(features.contiguous().data_ptr<float>(), n_feature_dims, n_pixels);
50 | p.compute(output_values.data_ptr<float>(), input_values.contiguous().data_ptr<float>(), n_channels);
51 |
52 | return {output_values};
53 | }
54 |
55 |
56 | std::vector<torch::Tensor> permuto_backward(torch::Tensor grads, torch::Tensor features) {
57 |
58 | auto grad_sizes = grads.sizes(); // (channels, height, width)
59 | auto feature_sizes = features.sizes(); // (height, width, num_features)
60 |
61 | auto h = feature_sizes[0];
62 | auto w = feature_sizes[1];
63 | auto n_feature_dims = static_cast<int>(feature_sizes[2]);
64 | auto n_pixels = static_cast<int>(h * w);
65 | auto n_channels = static_cast<int>(grad_sizes[0]);
66 |
67 | // Validate the arguments
68 | if (grad_sizes[1] != h || grad_sizes[2] != w) {
69 | throw std::runtime_error("Sizes of `grads` and `features` do not match!");
70 | }
71 |
72 | if (!(grads.dtype() == torch::kFloat32)) {
73 | throw std::runtime_error("`grads` must have float32 type.");
74 | }
75 |
76 | if (!(features.dtype() == torch::kFloat32)) {
77 | throw std::runtime_error("`features` must have float32 type.");
78 | }
79 |
80 | // Create the output tensor
81 | auto options = torch::TensorOptions()
82 | .dtype(torch::kFloat32)
83 | .layout(torch::kStrided)
84 | .device(torch::kCPU)
85 | .requires_grad(false);
86 |
87 | auto grads_back = torch::empty(grad_sizes, options);
88 | grads_back = grads_back.contiguous();
89 |
90 | Permutohedral p;
91 | p.init(features.contiguous().data_ptr<float>(), n_feature_dims, n_pixels);
92 | p.compute(grads_back.data_ptr<float>(), grads.contiguous().data_ptr<float>(), n_channels, true);
93 |
94 | return {grads_back};
95 | }
96 |
97 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
98 | m.def("forward", &permuto_forward, "PERMUTO forward");
99 | m.def("backward", &permuto_backward, "PERMUTO backward");
100 | }
101 |
--------------------------------------------------------------------------------
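`permuto_backward` filters the incoming gradient with a lattice built from the same features, passing `true` as the extra argument to `compute`. The reasoning is that filtering is linear in the Q values: if the forward pass computes $Kq$, the gradient of a scalar loss with respect to $q$ is $K^\top g$, and an exact Gaussian kernel is symmetric, so $K^\top = K$. We assume, without confirmation from the source, that the `reverse` flag exists to make the lattice's approximate blur match that transpose. The NumPy sketch below checks the identity with central finite differences on an exact, symmetric Gaussian kernel; all function names are hypothetical.

```python
import numpy as np


def gaussian_kernel(features):
    # Dense, symmetric Gaussian kernel: K[i, j] = exp(-0.5 * ||f_i - f_j||^2).
    diff = features[:, None, :] - features[None, :, :]
    return np.exp(-0.5 * (diff ** 2).sum(axis=-1))


def filter_forward(q, kernel):
    # q: (channels, n_pixels). out[c, i] = sum_j K[i, j] * q[c, j], i.e. q @ K^T.
    return q @ kernel.T


def filter_backward(grad_out, kernel):
    # d/dq of sum(grad_out * filter_forward(q)) is grad_out @ K (the transposed map).
    return grad_out @ kernel


def finite_difference_grad(q, grad_out, kernel, eps=1e-6):
    # Central differences of the scalar loss sum(grad_out * filter_forward(q)).
    numeric = np.zeros_like(q)
    for idx in np.ndindex(q.shape):
        step = np.zeros_like(q)
        step[idx] = eps
        plus = np.sum(grad_out * filter_forward(q + step, kernel))
        minus = np.sum(grad_out * filter_forward(q - step, kernel))
        numeric[idx] = (plus - minus) / (2 * eps)
    return numeric


rng = np.random.default_rng(0)
features = rng.normal(size=(6, 3))  # 6 "pixels" with 3-D features
K = gaussian_kernel(features)
q = rng.normal(size=(2, 6))         # 2 label channels
g = rng.normal(size=(2, 6))         # incoming gradient
print(np.allclose(filter_backward(g, K), finite_difference_grad(q, g, K), atol=1e-6))  # True
print(np.allclose(filter_backward(g, K), filter_forward(g, K)))  # True: K is symmetric
```

This is also why `PermutoFunction.backward` returns `None` for `features`: the Q values are the only input that gradients flow through.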
/crfasrnn/permutohedral.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | This file contains a modified version of the "permutohedral.cpp" code
3 | available at http://graphics.stanford.edu/projects/drf/. Copyright notice of
4 | the original file is included below:
5 |
6 | Copyright (c) 2013, Philipp Krähenbühl
7 | All rights reserved.
8 |
9 | Redistribution and use in source and binary forms, with or without
10 | modification, are permitted provided that the following conditions are met:
11 | * Redistributions of source code must retain the above copyright
12 | notice, this list of conditions and the following disclaimer.
13 | * Redistributions in binary form must reproduce the above copyright
14 | notice, this list of conditions and the following disclaimer in the
15 | documentation and/or other materials provided with the distribution.
16 | * Neither the name of the Stanford University nor the
17 | names of its contributors may be used to endorse or promote products
18 | derived from this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY Philipp Krähenbühl ''AS IS'' AND ANY
21 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
22 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL Philipp Krähenbühl BE LIABLE FOR ANY
24 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
25 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
26 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
27 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | */
31 |
32 | //#include "stdafx.h"
33 | #include "permutohedral.h"
34 |
35 | #ifdef __SSE__
36 | // SSE Permutohedral lattice
37 | # define SSE_PERMUTOHEDRAL
38 | #endif
39 |
40 | #if defined(SSE_PERMUTOHEDRAL)
41 | # include <emmintrin.h>
42 | # include <xmmintrin.h>
43 | # ifdef __SSE4_1__
44 | # include <smmintrin.h>
45 | # endif
46 | #endif
47 |
48 |
49 | /************************************************/
50 | /*** Hash Table ***/
51 | /************************************************/
52 |
53 | class HashTable{
54 | protected:
55 | size_t key_size_, filled_, capacity_;
56 | std::vector< short > keys_;
57 | std::vector< int > table_;
58 | void grow(){
59 | // Create the new memory and copy the values in
60 | int old_capacity = capacity_;
61 | capacity_ *= 2;
62 | std::vector<short> old_keys( (old_capacity+10)*key_size_ );
63 | std::copy( keys_.begin(), keys_.end(), old_keys.begin() );
64 | std::vector<int> old_table( capacity_, -1 );
65 |
66 | // Swap the memory
67 | table_.swap( old_table );
68 | keys_.swap( old_keys );
69 |
70 | // Reinsert each element
71 | for( int i=0; i<old_capacity; i++ ){
72 | if (old_table[i] >= 0){
73 | int e = old_table[i];
74 | size_t h = hash( getKey(e) ) % capacity_;
75 | for(; table_[h] >= 0; h = h= capacity_) grow();
99 | // Get the hash value
100 | size_t h = hash( k ) % capacity_;
101 | // Find the element with the right key, using linear probing
102 | while(1){
103 | int e = table_[h];
104 | if (e==-1){
105 | if (create){
106 | // Insert a new key and return the new id
107 | for( size_t i=0; i0; j-- ){
202 | __m128 cf = f[j-1]*scale_factor[j-1];
203 | elevated[j] = sm - _mm_set1_ps(j)*cf;
204 | sm += cf;
205 | }
206 | elevated[0] = sm;
207 |
208 | // Find the closest 0-colored simplex through rounding
209 | __m128 sum = Zero;
210 | for( int i=0; i<=d_; i++ ){
211 | __m128 v = invdplus1 * elevated[i];
212 | #ifdef __SSE4_1__
213 | v = _mm_round_ps( v, _MM_FROUND_TO_NEAREST_INT );
214 | #else
215 | v = _mm_cvtepi32_ps( _mm_cvtps_epi32( v ) );
216 | #endif
217 | rem0[i] = v*dplus1;
218 | sum += v;
219 | }
220 |
221 | // Find the simplex we are in and store it in rank (where rank describes what position coordinate i has in the sorted order of the feature values)
222 | for( int i=0; i<=d_; i++ )
223 | rank[i] = Zero;
224 | for( int i=0; i0; j-- ){
366 | float cf = f[j-1]*scale_factor[j-1];
367 | elevated[j] = sm - j*cf;
368 | sm += cf;
369 | }
370 | elevated[0] = sm;
371 |
372 | // Find the closest 0-colored simplex through rounding
373 | float down_factor = 1.0f / (d_+1);
374 | float up_factor = (d_+1);
375 | int sum = 0;
376 | for( int i=0; i<=d_; i++ ){
377 | //int rd1 = round( down_factor * elevated[i]);
378 | int rd2;
379 | float v = down_factor * elevated[i];
380 | float up = ceilf(v)*up_factor;
381 | float down = floorf(v)*up_factor;
382 | if (up - elevated[i] < elevated[i] - down) rd2 = (short)up;
383 | else rd2 = (short)down;
384 |
385 | //if(rd1!=rd2)
386 | // break;
387 |
388 | rem0[i] = rd2;
389 | sum += rd2*down_factor;
390 | }
391 |
392 | // Find the simplex we are in and store it in rank (where rank describes what position coordinate i has in the sorted order of the feature values)
393 | for( int i=0; i<=d_; i++ )
394 | rank[i] = 0;
395 | for( int i=0; i d_ ){
412 | rank[i] -= d_+1;
413 | rem0[i] -= d_+1;
414 | }
415 | }
416 |
417 | // Compute the barycentric coordinates (p.10 in [Adams etal 2010])
418 | for( int i=0; i<=d_+1; i++ )
419 | barycentric[i] = 0;
420 | for( int i=0; i<=d_; i++ ){
421 | float v = (elevated[i] - rem0[i])*down_factor;
422 | barycentric[d_-rank[i] ] += v;
423 | barycentric[d_-rank[i]+1] -= v;
424 | }
425 | // Wrap around
426 | barycentric[0] += 1.0 + barycentric[d_+1];
427 |
428 | // Compute all vertices and their offset
429 | for( int remainder=0; remainder<=d_; remainder++ ){
430 | for( int i=0; i 0 (used for blurring)
479 | float * values = new float[ (M_+2)*value_size ];
480 | float * new_values = new float[ (M_+2)*value_size ];
481 |
482 | for( int i=0; i<(M_+2)*value_size; i++ )
483 | values[i] = new_values[i] = 0;
484 |
485 | // Splatting
486 | for( int i=0; i=0; reverse?j--:j++ ){
496 | for( int i=0; i 0 (used for blurring)
536 | __m128 * sse_val = (__m128*) _mm_malloc( sse_value_size*sizeof(__m128), 16 );
537 | __m128 * values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 );
538 | __m128 * new_values = (__m128*) _mm_malloc( (M_+2)*sse_value_size*sizeof(__m128), 16 );
539 |
540 | __m128 Zero = _mm_set1_ps( 0 );
541 |
542 | for( int i=0; i<(M_+2)*sse_value_size; i++ )
543 | values[i] = new_values[i] = Zero;
544 | for( int i=0; i=0; reverse?j--:j++ ){
568 | for( int i=0; i
33 | #include
34 | #include
35 | #include
36 | #include
37 | #include
38 |
39 |
40 | /************************************************/
41 | /*** Permutohedral Lattice ***/
42 | /************************************************/
43 | class Permutohedral {
44 | protected:
45 | struct Neighbors {
46 | int n1, n2;
47 |
48 | Neighbors(int n1 = 0, int n2 = 0) : n1(n1), n2(n2) {
49 | }
50 | };
51 |
52 | std::vector<int> offset_, rank_;
53 | std::vector<float> barycentric_;
54 | std::vector<Neighbors> blur_neighbors_;
55 | // Number of elements, size of sparse discretized space, dimension of features
56 | int N_, M_, d_;
57 |
58 | void sseCompute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;
59 |
60 | void seqCompute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;
61 |
62 | public:
63 | Permutohedral();
64 |
65 | void init(const float *features, int num_dimensions, int num_points);
66 |
67 | void compute(float *out, const float *in, int value_size, bool reverse = false, bool add = false) const;
68 | };
69 |
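The `Permutohedral` class declared above places every feature vector in a lattice before values are splatted and blurred. A minimal pure-Python sketch of the per-point embedding loops visible in `permutohedral.cpp` follows (elevation onto the plane sum(x) = 0, rounding to the nearest remainder-0 point, ranking, barycentric weights, after Adams et al. 2010). The names `elevate`, `embed` and the `scale` argument are hypothetical, and the scale factors are simply passed in by the caller (an assumption; the C++ code uses its own `scale_factor` array).

```python
import math
from typing import List, Sequence, Tuple


def elevate(f: Sequence[float], scale: Sequence[float]) -> List[float]:
    """Lift a d-dimensional feature onto the hyperplane sum(x) == 0 in R^(d+1).

    Same recurrence as the C++ loop (j runs from d down to 1). `scale` is a
    hypothetical stand-in for the per-dimension scale factors.
    """
    d = len(f)
    elevated = [0.0] * (d + 1)
    sm = 0.0
    for j in range(d, 0, -1):
        cf = f[j - 1] * scale[j - 1]
        elevated[j] = sm - j * cf
        sm += cf
    elevated[0] = sm
    return elevated


def embed(elevated: Sequence[float]) -> Tuple[List[List[int]], List[float]]:
    """Return the d+1 vertices of the enclosing simplex and their barycentric weights."""
    d = len(elevated) - 1
    n = d + 1
    # Closest remainder-0 point: each coordinate rounded to a multiple of d+1.
    # ceil(v - 0.5) rounds halfway cases down, like the C++ comparison.
    rem0 = [math.ceil(e / n - 0.5) * n for e in elevated]
    total = sum(r // n for r in rem0)

    # rank[i] = how many residuals are larger than coordinate i's (ties by index).
    rank = [0] * n
    for i in range(n):
        for j in range(i + 1, n):
            if elevated[i] - rem0[i] < elevated[j] - rem0[j]:
                rank[i] += 1
            else:
                rank[j] += 1

    # If rem0 is off the hyperplane (total != 0), shift the coordinates whose
    # rank wraps around by +/-(d+1) to bring it back.
    for i in range(n):
        rank[i] += total
        if rank[i] < 0:
            rank[i] += n
            rem0[i] += n
        elif rank[i] > d:
            rank[i] -= n
            rem0[i] -= n

    # Barycentric coordinates, accumulated as in the C++ code.
    bary = [0.0] * (d + 2)
    for i in range(n):
        v = (elevated[i] - rem0[i]) / n
        bary[d - rank[i]] += v
        bary[d - rank[i] + 1] -= v
    bary[0] += 1.0 + bary[d + 1]  # wrap around

    # Vertex r = rem0 plus the canonical offset: +r where rank <= d - r,
    # r - (d+1) elsewhere. Every coordinate of vertex r is congruent to r mod d+1.
    vertices = [
        [rem0[i] + (r if rank[i] <= d - r else r - n) for i in range(n)]
        for r in range(n)
    ]
    return vertices, bary[:n]
```

Each weight in the returned list is the share of the point's value that would be splatted onto the matching vertex; the weights sum to 1 and recombine the vertices into the elevated point. A lattice implementation would then hash the vertex coordinates to find each vertex's storage slot.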
--------------------------------------------------------------------------------
/crfasrnn/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, Extension
2 | from torch.utils import cpp_extension
3 |
4 | setup(name='permuto_cpp',
5 | ext_modules=[cpp_extension.CppExtension('permuto_cpp', ['permuto.cpp', 'permutohedral.cpp'])],
6 | cmdclass={'build_ext': cpp_extension.BuildExtension})
7 |
--------------------------------------------------------------------------------
/crfasrnn/util.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2019 Sadeep Jayasumana
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 | import numpy as np
26 | from PIL import Image
27 |
28 | # Pascal VOC color palette for labels
29 | _PALETTE = [0, 0, 0,
30 | 128, 0, 0,
31 | 0, 128, 0,
32 | 128, 128, 0,
33 | 0, 0, 128,
34 | 128, 0, 128,
35 | 0, 128, 128,
36 | 128, 128, 128,
37 | 64, 0, 0,
38 | 192, 0, 0,
39 | 64, 128, 0,
40 | 192, 128, 0,
41 | 64, 0, 128,
42 | 192, 0, 128,
43 | 64, 128, 128,
44 | 192, 128, 128,
45 | 0, 64, 0,
46 | 128, 64, 0,
47 | 0, 192, 0,
48 | 128, 192, 0,
49 | 0, 64, 128,
50 | 128, 64, 128,
51 | 0, 192, 128,
52 | 128, 192, 128,
53 | 64, 64, 0,
54 | 192, 64, 0,
55 | 64, 192, 0,
56 | 192, 192, 0]
57 |
58 | _IMAGENET_MEANS = np.array([123.68, 116.779, 103.939], dtype=np.float32) # RGB mean values
59 |
60 |
61 | def get_preprocessed_image(file_name):
62 | """
63 | Reads an image from disk, scales it so that its longer side is 500 pixels, subtracts the ImageNet
64 | channel means, converts it to BGR and zero-pads it to 500x500, returning a NumPy array ready for the PyTorch model.
65 |
66 | Args:
67 | file_name: File to read the image from
68 |
69 | Returns:
70 | A tuple containing:
71 |
72 | (preprocessed image of shape (1, 3, 500, 500), resized height img_h, resized width img_w, original (width, height))
73 | """
74 |
75 | image = Image.open(file_name)
76 | original_size = image.size
77 | w, h = original_size
78 | ratio = min(500.0 / w, 500.0 / h)
79 | image = image.resize((int(w * ratio), int(h * ratio)), resample=Image.BILINEAR)
80 | im = np.array(image).astype(np.float32)
81 | assert im.ndim == 3, 'Only RGB images are supported.'
82 | im = im[:, :, :3]
83 | im = im - _IMAGENET_MEANS
84 | im = im[:, :, ::-1] # Convert to BGR
85 | img_h, img_w, _ = im.shape
86 |
87 | pad_h = 500 - img_h
88 | pad_w = 500 - img_w
89 | im = np.pad(im, pad_width=((0, pad_h), (0, pad_w), (0, 0)), mode='constant', constant_values=0)
90 | return np.expand_dims(im.transpose([2, 0, 1]), 0), img_h, img_w, original_size
91 |
92 |
93 | def get_label_image(probs, img_h, img_w, original_size):
94 | """
95 | Returns the label image (PNG with Pascal VOC colormap) given the probabilities.
96 |
97 | Args:
98 | probs: Probability output of shape (num_labels, height, width)
99 | img_h: Height of the resized image, before padding
100 | img_w: Width of the resized image, before padding
101 | original_size: Original image size (width, height)
102 |
103 | Returns:
104 | Label image as a PIL Image
105 | """
106 |
107 | labels = probs.argmax(axis=0).astype('uint8')[:img_h, :img_w]
108 | label_im = Image.fromarray(labels, 'P')
109 | label_im.putpalette(_PALETTE)
110 | label_im = label_im.resize(original_size)
111 | return label_im
112 |
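The resize-and-pad arithmetic in `get_preprocessed_image`, which `get_label_image` later undoes, can be checked without opening an image. The sketch below uses a hypothetical helper, `padded_layout`, that reproduces only the size computations (not the pixel operations) under the same 500-pixel target.

```python
from typing import Tuple


def padded_layout(w: int, h: int, target: int = 500) -> Tuple[int, int, int, int]:
    """Hypothetical helper: (new_w, new_h, pad_w, pad_h) for a w x h input.

    Mirrors only the size arithmetic of get_preprocessed_image: scale so the
    longer side becomes `target` (int() truncates), then zero-pad the right
    and bottom edges up to target x target.
    """
    ratio = min(float(target) / w, float(target) / h)
    new_w, new_h = int(w * ratio), int(h * ratio)
    return new_w, new_h, target - new_w, target - new_h


if __name__ == "__main__":
    # A 640x480 landscape image becomes 500x375 plus 125 rows of bottom padding.
    print(padded_layout(640, 480))  # (500, 375, 0, 125)
```

Because `ratio` is the smaller of the two scale factors, images smaller than 500x500 are upscaled as well as larger ones downscaled. `get_label_image` removes the padding by slicing `[:img_h, :img_w]` (the `new_h`, `new_w` above) before resizing the labels back to `original_size`.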
--------------------------------------------------------------------------------
/image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sadeepj/crfasrnn_pytorch/24899c528981dfbc14f1212869a3eae328dc6570/image.jpg
--------------------------------------------------------------------------------
/quick_run.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2019 Sadeep Jayasumana
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 |
25 | import argparse
26 |
27 | import torch
28 |
29 | from crfasrnn import util
30 | from crfasrnn.crfasrnn_model import CrfRnnNet
31 |
32 |
33 | def main():
34 | parser = argparse.ArgumentParser()
35 |
36 | parser.add_argument(
37 | "--weights",
38 | help="Path to the .pth file (download from https://tinyurl.com/crfasrnn-weights-pth)",
39 | required=True,
40 | )
41 | parser.add_argument("--image", help="Path to the input image", required=True)
42 | parser.add_argument("--output", help="Path to the output label image", default=None)
43 | args = parser.parse_args()
44 |
45 | img_data, img_h, img_w, size = util.get_preprocessed_image(args.image)
46 |
47 | output_file = args.output or args.image + "_labels.png"
48 |
49 | model = CrfRnnNet()
50 | model.load_state_dict(torch.load(args.weights))
51 | model.eval()
52 | out = model.forward(torch.from_numpy(img_data))
53 |
54 | probs = out.detach().numpy()[0]
55 | label_im = util.get_label_image(probs, img_h, img_w, size)
56 | label_im.save(output_file)
57 |
58 |
59 | if __name__ == "__main__":
60 | main()
61 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | Pillow
4 |
5 |
--------------------------------------------------------------------------------
/run_demo.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2019 Sadeep Jayasumana
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 | """
24 | import torch
25 |
26 | from crfasrnn import util
27 | from crfasrnn.crfasrnn_model import CrfRnnNet
28 |
29 |
30 | def main():
31 | input_file = "image.jpg"
32 | output_file = "labels.png"
33 |
34 | # Read the image
35 | img_data, img_h, img_w, size = util.get_preprocessed_image(input_file)
36 |
37 | # Download the model from https://tinyurl.com/crfasrnn-weights-pth
38 | saved_weights_path = "crfasrnn_weights.pth"
39 |
40 | model = CrfRnnNet()
41 | model.load_state_dict(torch.load(saved_weights_path))
42 | model.eval()
43 | out = model.forward(torch.from_numpy(img_data))
44 |
45 | probs = out.detach().numpy()[0]
46 | label_im = util.get_label_image(probs, img_h, img_w, size)
47 | label_im.save(output_file)
48 |
49 |
50 | if __name__ == "__main__":
51 | main()
52 |
--------------------------------------------------------------------------------
/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sadeepj/crfasrnn_pytorch/24899c528981dfbc14f1212869a3eae328dc6570/sample.png
--------------------------------------------------------------------------------