├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images.pdf
│   ├── arch.png
│   ├── matching.png
│   ├── matrix.png
│   ├── outputs.png
│   └── vertex-detection.png
├── models
│   ├── backbone.py
│   └── matching.py
├── train.ipynb
└── utils
    ├── coco_IoU_cIoU.py
    ├── coco_to_shp.py
    ├── dataset.py
    ├── loss.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /data/
2 | /trained_weights/
3 |
4 |
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
115 | .pdm.toml
116 | .pdm-python
117 | .pdm-build/
118 |
119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # Cython debug symbols
160 | cython_debug/
161 |
162 | # PyCharm
163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165 | # and can be added to the global gitignore or merged into this file. For a more nuclear
166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
167 | #.idea/
168 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | SOFTWARE LICENSE AGREEMENT
2 |
3 | BlackShark.ai Software – 2022, all rights reserved, hereinafter "the Software".
4 |
5 | This software has been developed by researchers of ICG and BlackShark.ai.
6 |
7 | Institute of Computer Graphics and Vision (ICG), Inffeldgasse 16/II,
8 | 8010 Graz, Austria
9 |
10 | BlackShark.ai, Am Eisernen Tor 1/3
11 | 8010 Graz, Austria
12 |
13 | BlackShark.ai holds all the ownership rights on the Software.
14 |
15 | The Software is still under development. It is BlackShark.ai's aim for the Software
16 | to be used by the scientific community so as to test and evaluate it, so that BlackShark.ai may improve it.
17 |
18 | For these reasons BlackShark.ai has decided to distribute the Software.
19 |
20 | BlackShark.ai grants the academic user a free-of-charge, non-exclusive right, without the right
21 | to sublicense, to use the Software for research purposes for a period of one (1) year from the date of the download
22 | of the source code. Any other use without the prior consent of BlackShark.ai is prohibited.
23 |
24 | The academic user explicitly acknowledges having received from BlackShark.ai all information allowing him
25 | to assess the adequacy of the Software to his needs and to undertake all necessary
26 | precautions for its execution and use.
27 |
28 | The Software is provided as source code only.
29 |
30 | When using the Software for a publication, or for other results obtained through the use of the Software,
31 | the user should cite the Software as follows:
32 |
33 | @article{zorzi2021polyworld,
34 | title={PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images},
35 | author={Zorzi, Stefano and Bazrafkan, Shabab and Habenschuss, Stefan and Fraundorfer, Friedrich},
36 | journal={arXiv preprint arXiv:2111.15491},
37 | year={2021}
38 | }
39 |
40 | Every user of the Software may communicate to the developers [stefano.zorzi@icg.tugraz.at]
41 | his or her remarks on the use of the Software.
42 |
43 | THE USER CANNOT USE, EXPLOIT OR COMMERCIALLY DISTRIBUTE THE SOFTWARE WITHOUT PRIOR AND EXPLICIT CONSENT
44 | OF BlackShark.ai (sbazrafkan@blackshark.ai). ANY SUCH ACTION WILL CONSTITUTE A FORGERY.
45 |
46 | THIS SOFTWARE IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES,
47 | WITH REGARDS TO COMMERCIAL USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALIZATION OR ADAPTATION.
48 |
49 | UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL BlackShark.ai OR THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
50 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
51 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
52 | WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM,
53 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
54 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images
2 | 
3 |
4 | PolyWorld is a research project conducted by the Institute of Computer Graphics and Vision of TUGraz, in collaboration with BlackShark.ai. PolyWorld is a neural network that extracts polygonal objects from an image in an end-to-end fashion. The model detects vertex candidates and predicts the connection strength between each pair of vertices using a Graph Neural Network.
5 |
6 | This repo attempts to train the network from scratch.
7 |
8 | - Official Repository: [PolyWorld](https://github.com/zorzi-s/PolyWorldPretrainedNetwork)
9 |
10 | - Paper PDF: [PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images](https://arxiv.org/abs/2111.15491)
11 |
12 | - Authors: Stefano Zorzi, Shabab Bazrafkan, Stefan Habenschuss, Friedrich Fraundorfer
13 |
14 | - Video: [YouTube link](https://youtu.be/C80dojBosLQ)
15 |
16 | - Poster: [Seafile link](https://files.icg.tugraz.at/f/6a044f133c0d4dd992c5/)
17 |
18 | ---
19 |
20 | You can start training the network with the **train.ipynb** notebook.
21 | - Pre-trained weights for the backbone and the vertex detection network can be loaded from the [official repository](https://github.com/zorzi-s/PolyWorldPretrainedNetwork) and frozen during training, so that only the matching network is trained (see the sketch below).
22 |
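A minimal sketch of loading and freezing the pre-trained parts (the checkpoint paths below are assumptions; substitute whatever files the official repository ships):

```python
import torch
from models.backbone import R2U_Net, DetectionBranch

backbone = R2U_Net()
head = DetectionBranch()

# hypothetical checkpoint paths; use the files from the official repo
backbone.load_state_dict(torch.load('trained_weights/polyworld_backbone'))
head.load_state_dict(torch.load('trained_weights/polyworld_seg_head'))

# freeze both modules so only the matching network receives gradients
for module in (backbone, head):
    module.eval()
    for p in module.parameters():
        p.requires_grad = False
```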
23 | ## Architecture
24 |
25 | ![arch](docs/arch.png)
26 |
27 | ## Main Components
28 | - Backbone CNN (Pre-trained)
29 | - Vertex Detection (1x1 Conv) (Pre-trained)
30 |
31 |
32 | - NMS (Selects top 256 points) (Not Trainable)
33 | - Optimal Matching (Attentional GNN) (Our main challenge for training)
34 |
35 |
36 | - Polygon Reconstruction (Not Trainable) (Uses the predicted adjacency matrix and the top-256 points) (Our main challenge for training; see the pipeline sketch below)
37 |
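The components above chain together roughly as follows. This is a minimal sketch using the classes in this repo, with a dummy 320x320 input and a randomly initialized matching head:

```python
import torch
from models.backbone import R2U_Net, DetectionBranch, NonMaxSuppression
from models.matching import OptimalMatching

backbone = R2U_Net()                  # CNN features (B, 64, H, W)
detector = DetectionBranch()          # vertex heatmap via 1x1 convs
nms = NonMaxSuppression(n_peaks=256)  # keeps the top 256 peaks
matcher = OptimalMatching()           # attentional GNN + ScoreNets

image = torch.rand(1, 3, 320, 320)    # dummy batch
features = backbone(image)
heatmap = detector(features)
_, graph = nms(heatmap)               # (1, 256, 2) peak coordinates
polygons = matcher.predict(image, features, graph)  # COCO-style polygons
```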
38 | ## Contributions
39 | - Applied random permutations to the ground-truth permutation matrix (lines 31-32, 98, and 105-106 in dataset.py)
40 | - Applied the Sinkhorn algorithm in the matching step (thanks to https://github.com/henokyen/henokyen_polyworld); see the sketch below
41 | - Dataloader for CrowdAI dataset
42 |
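For reference, a minimal log-space Sinkhorn sketch (not the exact code from henokyen_polyworld): iterated row/column normalization turns the raw ScoreNet output into an approximately doubly-stochastic matrix, and its log-values are what `cross_entropy_loss` in utils/loss.py averages over the positive ground-truth matches:

```python
import torch

def log_sinkhorn(scores: torch.Tensor, n_iters: int = 100) -> torch.Tensor:
    """Normalize a (B, N, N) score matrix into log doubly-stochastic form."""
    log_p = scores
    for _ in range(n_iters):
        log_p = log_p - torch.logsumexp(log_p, dim=2, keepdim=True)  # rows sum to 1
        log_p = log_p - torch.logsumexp(log_p, dim=1, keepdim=True)  # cols sum to 1
    return log_p  # exp(log_p) approximates a permutation matrix
```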
43 | ## Issues
44 | - Everything works until we try to build the polygons from the points (the top-256 predictions) and the predicted adjacency matrix (the adjacency matrix itself is predicted correctly with respect to the ground truth). The main problem is probably the way we reconstruct the polygons: we have the coordinates of the points and the adjacency matrix, but we don't know which vertex a given point (x, y) corresponds to. Is it v1, v2, or vn? We probably need to assign an ID to each point; see the sketch below.
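One way to make the vertex ID explicit is to treat each point's row index in the permutation matrix as its ID and trace the cycles of the matrix. A rough sketch, assuming a hard 0/1 permutation matrix as produced by `scores_to_permutations`:

```python
import torch

def trace_polygons(perm: torch.Tensor, points: torch.Tensor) -> list:
    """perm: (N, N) hard permutation matrix; points: (N, 2) coordinates.

    Row i of `perm` says which vertex follows vertex i, so the row index
    itself is the vertex ID that keeps `points` aligned with the matrix.
    Diagonal rows (self-matches) are unmatched detections and are skipped.
    """
    succ = perm.argmax(dim=1).tolist()
    visited, polygons = set(), []
    for start in range(perm.shape[0]):
        if start in visited or succ[start] == start:
            continue  # already traced, or an unmatched point
        cycle, v = [], start
        while v not in visited:
            visited.add(v)
            cycle.append(points[v].tolist())
            v = succ[v]
        if len(cycle) >= 3:  # a valid polygon needs at least 3 vertices
            polygons.append(cycle)
    return polygons
```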
--------------------------------------------------------------------------------
/docs/PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images.pdf
--------------------------------------------------------------------------------
/docs/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/arch.png
--------------------------------------------------------------------------------
/docs/matching.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/matching.png
--------------------------------------------------------------------------------
/docs/matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/matrix.png
--------------------------------------------------------------------------------
/docs/outputs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/outputs.png
--------------------------------------------------------------------------------
/docs/vertex-detection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/vertex-detection.png
--------------------------------------------------------------------------------
/models/backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 |
6 |
7 | class DetectionBranch(nn.Module):
8 | def __init__(self):
9 | super(DetectionBranch,self).__init__()
10 | self.conv = nn.Sequential(
11 | nn.Conv2d(64, 64, kernel_size=1,stride=1,padding=0,bias=True),
12 | nn.BatchNorm2d(64),
13 | nn.ReLU(inplace=True),
14 | nn.Conv2d(64, 1, kernel_size=1,stride=1,padding=0,bias=True)
15 | )
16 |
17 | def forward(self,x):
18 | x = self.conv(x)
19 | return x
20 |
21 |
22 | class up_conv(nn.Module):
23 | def __init__(self,ch_in,ch_out):
24 | super(up_conv,self).__init__()
25 | self.up = nn.Sequential(
26 | nn.Upsample(scale_factor=2),
27 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
28 | nn.BatchNorm2d(ch_out),
29 | nn.ReLU(inplace=True)
30 | )
31 |
32 | def forward(self,x):
33 | x = self.up(x)
34 | return x
35 |
36 |
37 | class Recurrent_block(nn.Module):
38 | def __init__(self,ch_out,t=2):
39 | super(Recurrent_block,self).__init__()
40 | self.t = t
41 | self.ch_out = ch_out
42 | self.conv = nn.Sequential(
43 | nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
44 | nn.BatchNorm2d(ch_out),
45 | nn.ReLU(inplace=True)
46 | )
47 |
48 | def forward(self,x):
49 | for i in range(self.t):
50 |
51 | if i==0:
52 | x1 = self.conv(x)
53 |
54 | x1 = self.conv(x+x1)
55 | return x1
56 |
57 |
58 | class RRCNN_block(nn.Module):
59 | def __init__(self,ch_in,ch_out,t=2):
60 | super(RRCNN_block,self).__init__()
61 | self.RCNN = nn.Sequential(
62 | Recurrent_block(ch_out,t=t),
63 | Recurrent_block(ch_out,t=t)
64 | )
65 | self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)
66 |
67 | def forward(self,x):
68 | x = self.Conv_1x1(x)
69 | x1 = self.RCNN(x)
70 | return x+x1
71 |
72 |
73 | class R2U_Net(nn.Module):
74 | def __init__(self,img_ch=3,t=1):
75 | super(R2U_Net,self).__init__()
76 |
77 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2)
78 | self.Upsample = nn.Upsample(scale_factor=2)
79 |
80 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t)
81 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t)
82 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t)
83 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t)
84 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t)
85 |
86 | self.Up5 = up_conv(ch_in=1024,ch_out=512)
87 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t)
88 |
89 | self.Up4 = up_conv(ch_in=512,ch_out=256)
90 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t)
91 |
92 | self.Up3 = up_conv(ch_in=256,ch_out=128)
93 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t)
94 |
95 | self.Up2 = up_conv(ch_in=128,ch_out=64)
96 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t)
97 |
98 |
99 | def forward(self,x):
100 | # encoding path
101 | x1 = self.RRCNN1(x)
102 |
103 | x2 = self.Maxpool(x1)
104 | x2 = self.RRCNN2(x2)
105 |
106 | x3 = self.Maxpool(x2)
107 | x3 = self.RRCNN3(x3)
108 |
109 | x4 = self.Maxpool(x3)
110 | x4 = self.RRCNN4(x4)
111 |
112 | x5 = self.Maxpool(x4)
113 | x5 = self.RRCNN5(x5)
114 |
115 | # decoding + concat path
116 | d5 = self.Up5(x5)
117 | d5 = torch.cat((x4,d5),dim=1)
118 | d5 = self.Up_RRCNN5(d5)
119 |
120 | d4 = self.Up4(d5)
121 | d4 = torch.cat((x3,d4),dim=1)
122 | d4 = self.Up_RRCNN4(d4)
123 |
124 | d3 = self.Up3(d4)
125 | d3 = torch.cat((x2,d3),dim=1)
126 | d3 = self.Up_RRCNN3(d3)
127 |
128 | d2 = self.Up2(d3)
129 | d2 = torch.cat((x1,d2),dim=1)
130 | d2 = self.Up_RRCNN2(d2)
131 |
132 | return d2
133 |
134 |
135 | class NonMaxSuppression(nn.Module):
136 | def __init__(self, n_peaks=256):
137 | super(NonMaxSuppression,self).__init__()
138 | self.k = 3 # kernel
139 | self.p = 1 # padding
140 | self.s = 1 # stride
141 | self.center_idx = self.k**2//2
142 | self.sigmoid = nn.Sigmoid()
143 | self.unfold = nn.Unfold(kernel_size=self.k, padding=self.p, stride=self.s)
144 | self.n_peaks = n_peaks
145 |
146 | def sample_peaks(self, x):
147 | B, _, H, W = x.shape
148 | for b in range(B):
149 | x_b = x[b,0]
150 | idx = torch.topk(x_b.flatten(), self.n_peaks).indices
151 | idx_i = torch.div(idx, W, rounding_mode='floor')
152 | idx_j = idx % W
153 | idx = torch.cat((idx_i.unsqueeze(1), idx_j.unsqueeze(1)), dim=1)
154 | idx = idx.unsqueeze(0)
155 |
156 | if b == 0:
157 | graph = idx
158 | else:
159 | graph = torch.cat((graph, idx), dim=0)
160 |
161 | return graph
162 |
163 | def forward(self, feat):
164 | B, C, H, W = feat.shape
165 |
166 | x = self.sigmoid(feat)
167 |
168 | # Prepare filter
169 | f = self.unfold(x).view(B, self.k**2, H, W)
170 | f = torch.argmax(f, dim=1).unsqueeze(1)
171 | f = (f == self.center_idx).float()
172 |
173 | # Apply filter
174 | x = x * f
175 |
176 | # Sample top peaks
177 | graph = self.sample_peaks(x)
178 | return x, graph
--------------------------------------------------------------------------------
/models/matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from copy import deepcopy
4 | from utils.utils import scores_to_permutations, permutations_to_polygons
5 |
6 |
7 | def MultiLayerPerceptron(channels: list, batch_norm=True):
8 | n_layers = len(channels)
9 |
10 | layers = []
11 | for i in range(1, n_layers):
12 | layers.append(nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
13 |
14 | if i < (n_layers - 1):
15 | if batch_norm:
16 | layers.append(nn.BatchNorm1d(channels[i]))
17 | layers.append(nn.ReLU())
18 |
19 | return nn.Sequential(*layers)
20 |
21 |
22 | class Attention(nn.Module):
23 |
24 | def __init__(self, n_heads: int, d_model: int):
25 | super().__init__()
26 | assert d_model % n_heads == 0
27 | self.dim = d_model // n_heads
28 | self.n_heads = n_heads
29 | self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
30 | self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
31 |
32 | def forward(self, query, key, value):
33 | b = query.size(0)
34 | query, key, value = [l(x).view(b, self.dim, self.n_heads, -1)
35 | for l, x in zip(self.proj, (query, key, value))]
36 |
37 | b, d, h, n = query.shape
38 | scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / d**.5
39 | attn = torch.einsum('bhnm,bdhm->bdhn', torch.nn.functional.softmax(scores, dim=-1), value)
40 |
41 | return self.merge(attn.contiguous().view(b, self.dim*self.n_heads, -1))
42 |
43 |
44 | class AttentionalPropagation(nn.Module):
45 |
46 | def __init__(self, feature_dim: int, n_heads: int):
47 | super().__init__()
48 | self.attn = Attention(n_heads, feature_dim)
49 | self.mlp = MultiLayerPerceptron([feature_dim*2, feature_dim*2, feature_dim])
50 | nn.init.constant_(self.mlp[-1].bias, 0.0)
51 |
52 | def forward(self, x):
53 | message = self.attn(x, x, x)
54 | return self.mlp(torch.cat([x, message], dim=1))
55 |
56 |
57 | class AttentionalGNN(nn.Module):
58 |
59 | def __init__(self, feature_dim: int, num_layers: int):
60 | super().__init__()
61 | self.conv_init = nn.Sequential(
62 | nn.Conv1d(feature_dim + 2, feature_dim, kernel_size=1,stride=1,padding=0,bias=True),
63 | nn.BatchNorm1d(feature_dim),
64 | nn.ReLU(inplace=True),
65 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True),
66 | nn.BatchNorm1d(feature_dim),
67 | nn.ReLU(inplace=True),
68 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True)
69 | )
70 |
71 | self.layers = nn.ModuleList([
72 | AttentionalPropagation(feature_dim, 4)
73 | for _ in range(num_layers)])
74 |
75 | self.conv_desc = nn.Sequential(
76 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True),
77 | nn.BatchNorm1d(feature_dim),
78 | nn.ReLU(inplace=True),
79 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True),
80 | nn.BatchNorm1d(feature_dim),
81 | nn.ReLU(inplace=True)
82 | )
83 |
84 | self.conv_offset = nn.Sequential(
85 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True),
86 | nn.BatchNorm1d(feature_dim),
87 | nn.ReLU(inplace=True),
88 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True),
89 | nn.BatchNorm1d(feature_dim),
90 | nn.ReLU(inplace=True),
91 | nn.Conv1d(feature_dim, 2, kernel_size=1,stride=1,padding=0,bias=True),
92 | nn.Hardtanh()
93 | )
94 |
95 | def forward(self, feat, graph):
96 | graph = graph.permute(0,2,1)
97 | feat = torch.cat((feat, graph), dim=1)
98 | feat = self.conv_init(feat)
99 |
100 | for layer in self.layers:
101 | feat = feat + layer(feat)
102 |
103 | desc = self.conv_desc(feat)
104 | offset = self.conv_offset(feat).permute(0,2,1)
105 | return desc, offset
106 |
107 |
108 | class ScoreNet(nn.Module):
109 |
110 | def __init__(self, in_ch):
111 | super().__init__()
112 | self.relu = nn.ReLU(inplace=True)
113 | self.conv1 = nn.Conv2d(in_ch, 256, kernel_size=1, stride=1, padding=0, bias=True)
114 | self.bn1 = nn.BatchNorm2d(256)
115 | self.conv2 = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True)
116 | self.bn2 = nn.BatchNorm2d(128)
117 | self.conv3 = nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=True)
118 | self.bn3 = nn.BatchNorm2d(64)
119 | self.conv4 = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0, bias=True)
120 |
121 | def forward(self, x):
122 | n_points = x.shape[-1]
123 |
124 | x = x.unsqueeze(-1)
125 | x = x.repeat(1,1,1,n_points)
126 | t = torch.transpose(x, 2, 3)
127 | x = torch.cat((x, t), dim=1)
128 |
129 | x = self.conv1(x)
130 | x = self.bn1(x)
131 | x = self.relu(x)
132 |
133 | x = self.conv2(x)
134 | x = self.bn2(x)
135 | x = self.relu(x)
136 |
137 | x = self.conv3(x)
138 | x = self.bn3(x)
139 | x = self.relu(x)
140 |
141 | x = self.conv4(x)
142 | return x[:,0]
143 |
144 |
145 | class OptimalMatching(nn.Module):
146 |
147 | def __init__(self):
148 | super(OptimalMatching, self).__init__()
149 |
150 | # Default configuration settings
151 | self.descriptor_dim = 64
152 | self.sinkhorn_iterations = 100
153 | self.attention_layers = 4
154 | self.correction_radius = 0.05
155 |
156 | # Modules
157 | self.scorenet1 = ScoreNet(self.descriptor_dim * 2)
158 | self.scorenet2 = ScoreNet(self.descriptor_dim * 2)
159 | self.gnn = AttentionalGNN(self.descriptor_dim, self.attention_layers)
160 |
161 |
162 | def normalize_coordinates(self, graph, ws, input):
163 | if input == 'global':
164 | graph = (graph * 2 / ws - 1)
165 | elif input == 'normalized':
166 | graph = ((graph + 1) * ws / 2)
167 | graph = torch.round(graph).long()
168 | graph[graph < 0] = 0
169 | graph[graph >= ws] = ws - 1
170 | return graph
171 |
172 |
173 | def predict(self, image, descriptors, graph):
174 | B, _, H, W = image.shape
175 | B, N, _ = graph.shape
176 |
177 | for b in range(B):
178 | b_desc = descriptors[b]
179 | b_graph = graph[b]
180 |
181 | # Extract descriptors
182 | b_desc = b_desc[:, b_graph[:,0], b_graph[:,1]]
183 |
184 | # Concatenate descriptors in batches
185 | if b == 0:
186 | sel_desc = b_desc.unsqueeze(0)
187 | else:
188 | sel_desc = torch.cat((sel_desc, b_desc.unsqueeze(0)), dim=0)
189 |
190 | # Multi-layer Transformer network.
191 | norm_graph = self.normalize_coordinates(graph, W, input="global") #out: normalized coordinate system [-1, 1]
192 | sel_desc, offset = self.gnn(sel_desc, norm_graph)
193 |
194 | # Correct points coordinates
195 | norm_graph = norm_graph + offset * self.correction_radius
196 | graph = self.normalize_coordinates(norm_graph, W, input="normalized") # out: global coordinate system [0, W]
197 |
198 | # Compute scores
199 | scores_1 = self.scorenet1(sel_desc)
200 | scores_2 = self.scorenet2(sel_desc)
201 | scores = scores_1 + torch.transpose(scores_2, 1, 2)
202 |
203 | scores = scores_to_permutations(scores)
204 | poly = permutations_to_polygons(scores, graph, out='coco')
205 |
206 | return poly
207 |
--------------------------------------------------------------------------------
/utils/coco_IoU_cIoU.py:
--------------------------------------------------------------------------------
1 | from pycocotools.coco import COCO
2 | from pycocotools import mask as cocomask
3 | import numpy as np
4 | import json
5 | from tqdm import tqdm
6 |
7 | def calc_IoU(a, b):
8 | i = np.logical_and(a, b)
9 | u = np.logical_or(a, b)
10 | I = np.sum(i)
11 | U = np.sum(u)
12 |
13 | iou = I/(U + 1e-9)
14 |
15 | is_void = U == 0
16 | if is_void:
17 | return 1.0
18 | else:
19 | return iou
20 |
21 | def compute_IoU_cIoU(input_json, gti_annotations):
22 | # Ground truth annotations
23 | coco_gti = COCO(gti_annotations)
24 |
25 | # Predictions annotations
26 | submission_file = json.loads(open(input_json).read())
27 | coco = COCO(gti_annotations)
28 | coco = coco.loadRes(submission_file)
29 |
30 |
31 | image_ids = coco.getImgIds(catIds=coco.getCatIds())
32 | bar = tqdm(image_ids)
33 |
34 | list_iou = []
35 | list_ciou = []
36 | for image_id in bar:
37 |
38 | img = coco.loadImgs(image_id)[0]
39 |
40 | annotation_ids = coco.getAnnIds(imgIds=img['id'])
41 | annotations = coco.loadAnns(annotation_ids)
42 | N = 0
43 | for _idx, annotation in enumerate(annotations):
44 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width'])
45 | m = cocomask.decode(rle)
46 | if _idx == 0:
47 | mask = m.reshape((img['height'], img['width']))
48 | N = len(annotation['segmentation'][0]) // 2
49 | else:
50 | mask = mask + m.reshape((img['height'], img['width']))
51 | N = N + len(annotation['segmentation'][0]) // 2
52 |
53 | mask = mask != 0
54 |
55 |
56 | annotation_ids = coco_gti.getAnnIds(imgIds=img['id'])
57 | annotations = coco_gti.loadAnns(annotation_ids)
58 | N_GT = 0
59 | for _idx, annotation in enumerate(annotations):
60 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width'])
61 | m = cocomask.decode(rle)
62 | if _idx == 0:
63 | mask_gti = m.reshape((img['height'], img['width']))
64 | N_GT = len(annotation['segmentation'][0]) // 2
65 | else:
66 | mask_gti = mask_gti + m.reshape((img['height'], img['width']))
67 | N_GT = N_GT + len(annotation['segmentation'][0]) // 2
68 |
69 | mask_gti = mask_gti != 0
70 |
71 | ps = 1 - np.abs(N - N_GT) / (N + N_GT + 1e-9)
72 | iou = calc_IoU(mask, mask_gti)
73 | list_iou.append(iou)
74 | list_ciou.append(iou * ps)
75 |
76 | bar.set_description("iou: %2.4f, c-iou: %2.4f" % (np.mean(list_iou), np.mean(list_ciou)))
77 | bar.refresh()
78 |
79 | print("Done!")
80 | print("Mean IoU: ", np.mean(list_iou))
81 | print("Mean C-IoU: ", np.mean(list_ciou))
82 |
83 |
84 |
85 | if __name__ == "__main__":
86 | compute_IoU_cIoU(input_json="./predictions.json",
87 | gti_annotations="/home/stefano/Workspace/data/mapping_challenge_dataset/raw/val/annotation.json")
88 |
--------------------------------------------------------------------------------
/utils/coco_to_shp.py:
--------------------------------------------------------------------------------
1 | from pycocotools.coco import COCO
2 | import numpy as np
3 | import json
4 | from tqdm import tqdm
5 | import shapefile
6 |
7 |
8 | def cocojson_to_shapefiles(input_json, gti_annotations, output_folder):
9 |
10 | submission_file = json.loads(open(input_json).read())
11 | coco = COCO(gti_annotations)
12 | coco = coco.loadRes(submission_file)
13 |
14 | image_ids = coco.getImgIds(catIds=coco.getCatIds())
15 |
16 | for image_id in tqdm(image_ids):
17 |
18 | img = coco.loadImgs(image_id)[0]
19 |
20 | annotation_ids = coco.getAnnIds(imgIds=img['id'])
21 | annotations = coco.loadAnns(annotation_ids)
22 |
23 | list_poly = []
24 | for _idx, annotation in enumerate(annotations):
25 | poly = annotation['segmentation'][0]
26 | poly = np.array(poly)
27 | poly = poly.reshape((-1,2))
28 |
29 | poly[:,1] = -poly[:,1]
30 | list_poly.append(poly.tolist())
31 |
32 | number_str = str(image_id).zfill(12)
33 | w = shapefile.Writer(output_folder + '%s.shp' % number_str)
34 | w.field('name', 'C')
35 | w.poly(list_poly)
36 | w.record("polygon")
37 | w.close()
38 |
39 | print("Done!")
40 |
41 |
42 | if __name__ == "__main__":
43 | cocojson_to_shapefiles(input_json="./predictions.json",
44 | gti_annotations="data/val/annotation.json",
45 | output_folder="./shapefiles/")
46 |
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | from skimage import io
4 | from skimage.transform import resize
5 | import torch
6 | from torch.utils.data import Dataset
7 | import json
8 | import pandas as pd
9 | import cv2
10 |
11 | class CrowdAI(Dataset):
12 |
13 |
14 | def __init__(self, images_directory, annotations_path, window_size=320):
15 |
16 | self.IMAGES_DIRECTORY = images_directory
17 | self.ANNOTATIONS_PATH = annotations_path
18 |
19 | self.window_size = window_size
20 | self.max_points = 256
21 |
22 | # load annotation json
23 | with open(self.ANNOTATIONS_PATH) as f:
24 | self.annotations = json.load(f)
25 |
26 | self.images = pd.DataFrame(self.annotations['images'])
27 | self.labels = pd.DataFrame(self.annotations['annotations'])
28 |
29 |
30 | # Generate the Permutation
31 | # torch.manual_seed(0)
32 | # self.permutation = torch.randperm(self.max_points)
33 |
34 |
35 |
36 |
37 |
38 | def _shuffle_adjacency_matrix(self, A):
39 | """
40 | generates a new permutation for each sample (or batch)
41 | and shuffles the adjacency matrix A accordingly
42 | """
43 |
44 | n = A.shape[0]
45 | torch.manual_seed(0)
46 | permutation = torch.randperm(n)
47 |
48 | shuffled_A = A[permutation, :][:, permutation]
49 |
50 | return shuffled_A, permutation
51 |
52 |
53 |
54 | '''
55 | def _shuffle_adjacency_matrix(self, A):
56 | """
57 | generates a new permutation for each sample (or batch)
58 | and shuffles the adjacency matrix A accordingly
59 | """
60 | n = A.shape[0]
61 | shuffled_A = A[self.permutation, :][:, self.permutation]
62 | return shuffled_A, self.permutation[:n]
63 | '''
64 |
65 |
66 | def _shuffle_vector(self, v, permutation):
67 | return v[permutation]
68 |
69 |
70 | def _create_adjacency_matrix(self, segmentations, N=256):
71 |
72 | adjacency_matrix = torch.zeros((N, N), dtype=torch.uint8)
73 | dic = {}
74 |
75 | n, m = 0, 0
76 | graph = torch.zeros((N, 2), dtype=torch.float32) # used to create the seg mask with permutations_to_polygons()
77 | vertices = torch.zeros((N, 2), dtype=torch.float32) # used to sort the nms points during training
78 | for i, polygon in enumerate(segmentations):
79 | for v, point in enumerate(polygon):
80 | if v != len(polygon) - 1:
81 | adjacency_matrix[n, n+1] = 1
82 | else:
83 | adjacency_matrix[n, n-v] = 1
84 |
85 |
86 | # We only use `graph` to create the segmentation mask with permutations_to_polygons(); it has no other use
87 | graph[n] = torch.tensor(point)
88 | n += 1 # n must advance once per polygon vertex, as permutations_to_polygons() expects, so we use m to index the dictionary of unique points
89 |
90 | if tuple(map(float, point)) not in dic:
91 | vertices[m] = torch.tensor(point)
92 | dic[tuple(map(float, point))] = m
93 | m += 1
94 |
95 |
96 |
97 | # Permute the adjacency matrix
98 | # adjacency_matrix[:n,:n], permutation = self._shuffle_adjacency_matrix(adjacency_matrix[:n, :n])
99 |
100 | # Fill the diagonal with 1s
101 | for i in range(n, N):
102 | adjacency_matrix[i, i] = 1
103 |
104 | # Permute the graph
105 | # graph[:n] = self._shuffle_vector(graph, permutation)
106 | # graph = graph[:n]
107 |
108 | return adjacency_matrix, graph, vertices[:m], dic
109 |
110 |
111 |
112 |
113 |
114 | def _create_segmentation_mask(self, polygons, image_size):
115 | mask = np.zeros((image_size, image_size), dtype=np.uint8)
116 | for polygon in polygons:
117 | cv2.fillPoly(mask, [polygon], 1)
118 | return torch.tensor(mask, dtype=torch.uint8)
119 |
120 |
121 |
122 | def _create_vertex_mask(self, polygons, image_shape=(320, 320)):
123 | mask = torch.zeros(image_shape, dtype=torch.uint8)
124 |
125 | for poly in polygons:
126 | for p in poly:
127 | mask[p[1], p[0]] = 1
128 |
129 | return mask
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 | def __len__(self):
139 | return len(self.images)
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 | def __getitem__(self, idx):
148 |
149 | image = io.imread(self.IMAGES_DIRECTORY + self.images['file_name'][idx])
150 | image = resize(image, (self.window_size, self.window_size), anti_aliasing=True)
151 | image = torch.from_numpy(image)
152 | width, height = self.images['width'][idx], self.images['height'][idx]
153 | ratio = self.window_size / max(width, height)
154 |
155 |
156 |
157 | # Get the image ID
158 | image_id = self.images['id'][idx]
159 | # Get all annotations for this image
160 | image_annotations = self.labels[self.labels['image_id'] == image_id]
161 | # get all polygons for the image
162 | segmentations = image_annotations['segmentation'].values
163 | segmentations = [e[0] for e in segmentations]
164 | for i, poly in enumerate(segmentations):
165 | # rescale the polygon
166 | poly = [int(e * ratio) for e in poly]
167 | # out of bounds check
168 | for j, e in enumerate(poly):
169 | # x (even j) and y (odd j) coordinates are clamped the same way
170 | poly[j] = min(max(0, e), self.window_size - 1)
171 |
172 |
173 | segmentations[i] = poly
174 |
175 |
176 |
177 |
178 | # print(segmentations)
179 | segmentations = [np.array(poly, dtype=int).reshape(-1, 2) for poly in segmentations] # convert a list of polygons to a list of numpy arrays of points
180 | # print(segmentations)
181 |
182 |
183 |
184 |
185 | # create permutation matrix
186 | # permutation_matrix, graph, permutation = self._create_adjacency_matrix(segmentations, N=self.max_points)
187 | # create the simple adjacency matrix
188 | permutation_matrix, graph, vertices, dic = self._create_adjacency_matrix(segmentations, N=self.max_points)
189 | # create vertex mask
190 | vertex_mask = self._create_vertex_mask(segmentations, image_shape=(self.window_size, self.window_size))
191 | # create segmentation mask
192 | seg_mask = self._create_segmentation_mask(segmentations, image_size=self.window_size)
193 |
194 |
195 | segmentations = [torch.from_numpy(poly) for poly in segmentations]
196 |
197 |
198 |
199 |
200 |
201 | return image, vertex_mask, seg_mask, permutation_matrix, segmentations, graph, vertices, dic
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn.functional as F
3 | import torch
4 | from utils.utils import angle_between_points, prepare_gt_vertices
5 |
6 |
7 |
8 |
9 | def compute_l_angle_loss(gt_permutation_mask,
10 | vertices,
11 | pred_permutation_mat,
12 | graph, device='cuda'):
13 | v_gt = prepare_gt_vertices(vertices, device=device)
14 |
15 | v_gt_1 = torch.matmul(gt_permutation_mask, v_gt)
16 | v_gt_2 = torch.matmul(gt_permutation_mask, v_gt_1)
17 | gt_angle = angle_between_points(v_gt, v_gt_1, v_gt_2)
18 | #torch.isnan(gt_angle).any()
19 |
20 | pred_permutation_mat = pred_permutation_mat.to(device)
21 | v_pred_1 = torch.matmul(pred_permutation_mat, graph)
22 | v_pred_2 = torch.matmul(pred_permutation_mat, v_pred_1)
23 | pred_angle = angle_between_points(graph, v_pred_1.float(), v_pred_2.float())
24 | #torch.isnan(pred_angle).any()
25 |
26 | return pred_angle, gt_angle
27 |
28 |
29 |
30 | def cross_entropy_loss(sinkhorn_results, gt_permutation):
31 | '''
32 | Only considers the positive matches: it maximizes the predicted
33 | (log-)scores of the positive matches by minimizing their negative mean.
34 |
35 |
36 | One consideration is the distribution of the vertices in the GT permutation matrix,
37 | needed to overcome overfitting and help the model learn the
38 | correct permutation matrix.
39 |
40 |
41 | One solution is:
42 | 1- order the polygons based on the number of vertices
43 | 2- select the first vertex of each polygon and assign it a unique ID
44 | 3- repeat step 2
45 |
46 |
47 | Another solution:
48 | - use a random permutation matrix/vector
49 |
50 | '''
51 | loss_match = -torch.mean(torch.masked_select(sinkhorn_results, gt_permutation == 1))
52 | return loss_match
53 |
54 |
55 |
56 |
57 |
58 | # def iou_loss_function(pred, gt):
59 | # B, H, W = gt.shape
60 | # iou = 0
61 | # for batch in range(B):
62 | # K = len(pred[batch]) # Number of polygons
63 | # batch_tensor = np.zeros((K, H, W), dtype=np.uint8)
64 | # for i, poly in enumerate(pred[batch]):
65 | # cv2.fillPoly(batch_tensor[i], [poly.detach().cpu().numpy()], 1)
66 |
67 | # batch_pred_mask = torch.sum(torch.tensor(batch_tensor), dim=0).permute(1,0)
68 |
69 | # # plt.imshow(batch_pred_mask)
70 | # # plt.show()
71 | # # plt.imshow(gt[batch])
72 | # # plt.show()
73 |
74 | # intersection = torch.min(batch_pred_mask, gt[batch])
75 | # union = torch.max(batch_pred_mask, gt[batch])
76 | # batch_iou = torch.sum(intersection) / torch.sum(union)
77 | # iou += batch_iou
78 |
79 | # return torch.tensor(1 - iou, requires_grad=True)
80 |
81 |
82 |
83 | # def iou_loss_function(pred, gt):
84 | # B, H, W = gt.shape
85 | # total_iou = 0
86 |
87 | # for batch in range(B):
88 | # batch_pred = torch.zeros((H, W), device=gt.device)
89 | # for poly in pred[batch]:
90 | # # Convert polygon to mask
91 | # mask = torch.zeros((H, W), device=gt.device)
92 | # poly_tensor = poly.long() # Ensure integer coordinates
93 | # mask[poly_tensor[:, 1], poly_tensor[:, 0]] = 1
94 | # mask = F.max_pool2d(mask.unsqueeze(0).float(), kernel_size=3, stride=1, padding=1).squeeze(0)
95 | # batch_pred = torch.max(batch_pred, mask)
96 |
97 | # plt.imshow(mask)
98 | # plt.show()
99 | # plt.imshow(gt[batch])
100 | # plt.show()
101 |
102 | # intersection = torch.sum(torch.min(batch_pred, gt[batch]))
103 | # union = torch.sum(torch.max(batch_pred, gt[batch]))
104 | # batch_iou = intersection / (union + 1e-6) # Add small epsilon to avoid division by zero
105 | # total_iou += batch_iou
106 |
107 | # avg_iou = total_iou / B
108 | # return 1 - avg_iou
109 |
110 |
111 |
112 | # def iou_loss_function(pred_mask, gt_mask):
113 | # '''
114 | # pred_mask: (B, H, W)
115 | # gt_mask: (B, H, W)
116 | # '''
117 | # pred_mask = F.sigmoid(pred_mask)
118 | # intersection = torch.sum(torch.min(pred_mask, gt_mask))
119 | # union = torch.sum(torch.max(pred_mask, gt_mask))
120 | # iou = intersection / (union + 1e-6) # Add small epsilon to avoid division by zero
121 | # loss = 1 - iou
122 | # # loss.requires_grad = True
123 |
124 | # return loss # Return 1 - IoU to minimize
125 |
126 |
127 |
128 | def iou_loss_function(pred_mask, target_mask):
129 | pred_mask = F.sigmoid(pred_mask)
130 |
131 | intersection = (pred_mask * target_mask).sum()
132 | union = ((pred_mask + target_mask) - (pred_mask * target_mask)).sum()
133 | iou = intersection / (union + 1e-6)  # epsilon avoids division by zero
134 | iou_dual = pred_mask.size(0) - iou
135 |
136 | #iou_dual = iou_dual / pred_mask.size(0)
137 | # gradients already flow through pred_mask; setting requires_grad on a non-leaf tensor is a bug
138 | return torch.mean(iou_dual)
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.optimize import linear_sum_assignment
4 | import torch.nn.functional as F
5 | import math
6 | import cv2  # required by polygon_to_seg_mask below
7 | def scores_to_permutations(scores):
8 | """
9 | Input a batched array of scores and returns the hungarian optimized
10 | permutation matrices.
11 | """
12 | B, N, N = scores.shape
13 |
14 | scores = scores.detach().cpu().numpy()
15 | perm = np.zeros_like(scores)
16 | for b in range(B):
17 | r, c = linear_sum_assignment(-scores[b])
18 | perm[b,r,c] = 1
19 | return torch.tensor(perm)
20 |
21 |
22 |
23 |
24 | def permutations_to_polygons(perm, graph, out='torch'):
25 | B, N, N = perm.shape
26 |
27 | def bubble_merge(poly):
28 | s = 0
29 | P = len(poly)
30 | while s < P:
31 | head = poly[s][-1]
32 |
33 | t = s+1
34 | while t < P:
35 | tail = poly[t][0]
36 | if head == tail:
37 | poly[s] = poly[s] + poly[t][1:]
38 | del poly[t]
39 | poly = bubble_merge(poly)
40 | P = len(poly)
41 | t += 1
42 | s += 1
43 | return poly
44 |
45 | diag = torch.logical_not(perm[:,range(N),range(N)])
46 | batch = []
47 | for b in range(B):
48 | b_perm = perm[b]
49 | b_graph = graph[b]
50 | b_diag = diag[b]
51 |
52 | idx = torch.arange(N)[b_diag]
53 |
54 | if idx.shape[0] > 0:
55 | # If there are vertices in the batch
56 |
57 | b_perm = b_perm[idx,:]
58 | b_graph = b_graph[idx,:]
59 | b_perm = b_perm[:,idx]
60 |
61 | first = torch.arange(idx.shape[0]).unsqueeze(1)
62 | second = torch.argmax(b_perm, dim=1).unsqueeze(1).cpu()
63 |
64 | polygons_idx = torch.cat((first, second), dim=1).tolist()
65 | polygons_idx = bubble_merge(polygons_idx)
66 |
67 | batch_poly = []
68 | for p_idx in polygons_idx:
69 | if out == 'torch':
70 | batch_poly.append(b_graph[p_idx,:])
71 | elif out == 'numpy':
72 | batch_poly.append(b_graph[p_idx,:].numpy())
73 | elif out == 'list':
74 | g = b_graph[p_idx,:] * 300 / 320
75 | g[:,0] = -g[:,0]
76 | g = torch.fliplr(g)
77 | batch_poly.append(g.tolist())
78 | elif out == 'coco':
79 | g = b_graph[p_idx,:] * 300 / 320
80 | g = torch.fliplr(g)
81 | batch_poly.append(g.view(-1).tolist())
82 | else:
83 | print("Indicate a valid output polygon format")
84 | exit()
85 |
86 | batch.append(batch_poly)
87 |
88 | else:
89 | # If the batch has no vertices
90 | batch.append([])
91 |
92 | return batch
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 | def graph_to_vertex_mask(points, image):
102 | B, _, H, W = image.shape
103 |
104 | mask = torch.zeros((B, H, W), dtype=torch.uint8)
105 |
106 | # Loop style
107 | # for batch in range(B):
108 | # mask[batch, points[batch, :, 0], points[batch, :, 1]] = 1
109 |
110 | # Vectorized Style
111 | batch_indices = np.arange(B)[:, None]
112 | mask[batch_indices, points[:, :, 0], points[:, :, 1]] = 1
113 |
114 | return mask
115 |
116 |
117 |
118 |
119 |
120 | def polygon_to_vertex_mask(polygons: list):
121 | B = len(polygons)
122 | mask = torch.zeros((B, 320, 320), dtype=torch.uint8)
123 |
124 | for batch in range(B):
125 | batch_polygons = [np.array(poly, dtype=int) for poly in polygons[batch]]
126 | for poly in batch_polygons:
127 | for point in poly:
128 | mask[batch, point[1], point[0]] = 1
129 |
130 |
131 | return mask
132 |
133 |
134 |
135 |
136 |
137 | def tensor_to_numpy(input: list):
138 | ''' convert a list of tensors to a list of numpy arrays '''
139 | numpy = []
140 | for batch in range(len(input)):
141 | batch_polygons = [tensor.cpu().numpy() for tensor in input[batch]]
142 | numpy.append(batch_polygons)
143 | return numpy
144 |
145 |
146 |
147 |
148 |
149 | def point_to_polygon(points: list):
150 |
151 | B = len(points)
152 | polygons = []
153 |
154 | for batch in range(B):
155 | batch_polygons = [np.array(poly, dtype=int).reshape(-1, 2) for poly in points[batch]]
156 | polygons.append(batch_polygons)
157 |
158 | return polygons
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 | def polygon_to_seg_mask(polygons, image_size, device='cuda'):
169 | B = len(polygons)
170 | mask = np.zeros((B, image_size, image_size), dtype=np.uint8)
171 | for batch in range(B):
172 | for polygon in polygons[batch]:
173 | # fillPoly rasterizes the polygon interior into the batch mask
174 | cv2.fillPoly(mask[batch], [polygon.detach().cpu().numpy()], 1)
175 | return torch.tensor(mask, dtype=torch.float32, device=device, requires_grad=True)
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 | def sort_sync_nsm_points(nms_graph, vertices, gt_index):
191 | '''
192 | nms_graph: (B, N, 2)
193 | vertices: numpy array of points (B, N, 2) --> N: number of unique points
194 | gt_index: a dictionary with points as keys and their index as values (the same as vertices, but holding indices) (B, N)
195 | '''
196 | B, N, D = nms_graph.shape # 1, 256, 2
197 | sorted_nsm_points = np.zeros((B, N, D), dtype=int)
198 | nms_graph = nms_graph.detach().cpu().numpy()
199 |
200 | for b in range(B):
201 | sorted_nsm = np.zeros((N, D), dtype=int)
202 |
203 | n = vertices[b].shape[0] # Number of GT Vertices (Points)
204 | m = nms_graph[b].shape[0] # Number of Predicted Vertices (Points)
205 | # print(vertices[b].shape, nms_graph[b].shape)
206 | distances = np.linalg.norm(vertices[b][:, None] - nms_graph[b], axis=2) # Adds a new axis to vertices[b], changing its shape from (M, D) to (M, 1, D).
207 | distances = distances.reshape(n*m, 1)
208 | # print(distances.shape)
209 | distances = np.hstack((distances, np.repeat(vertices[b], m, axis=0),np.tile(nms_graph[b], (n, 1))))
210 | # columns of distances: [distance, gt_x, gt_y, cndd_x, cndd_y]
211 |
212 | # Sort distances by the first column (distance)
213 | sorted_distances = distances[np.argsort(distances[:,0])]
214 |
215 | cndd_used = set()
216 | gt_used = set()
217 | cndd_mapped = {tuple(cndd):0 for cndd in nms_graph[b]}
218 |
219 | for d_p in sorted_distances:
220 | gt_p = tuple((d_p[1], d_p[2]))
221 | cndd_p = tuple((d_p[3], d_p[4]))
222 | if gt_p not in gt_used and cndd_p not in cndd_used:
223 | #print('we have a match ..', gt_p ,'->', cndd_p, ' with distance of ', d_p[0])
224 | # print(gt_p)
225 | sorted_nsm[gt_index[b][gt_p]]= list(cndd_p)
226 | gt_used.add(gt_p)
227 | cndd_used.add(cndd_p)
228 | cndd_mapped[cndd_p] = 1
229 |
230 | restart_index = n
231 | for k, v in cndd_mapped.items():
232 | if v ==0:
233 | sorted_nsm[restart_index] = list(k)
234 | restart_index +=1
235 | sorted_nsm_points[b] = sorted_nsm
236 | return torch.from_numpy(sorted_nsm_points)
237 |
238 |
239 |
240 |
241 |
242 | def prepare_gt_vertices(vertices, device='cuda', MAX_POINTS=256):
243 | B = len(vertices)
244 | v_gt = torch.empty((B, MAX_POINTS, 2), dtype=torch.float64)
245 | for b in range(B):
246 | gt_size = vertices[b].shape[0]
247 | extra = torch.full((MAX_POINTS - gt_size, 2), 0, dtype=torch.float64)
248 | extra_gt = torch.cat((vertices[b], extra), dim=0).to(device)
249 | v_gt[b] = extra_gt
250 | return v_gt.to(device)
251 |
252 |
253 |
254 |
255 | def angle_between_points(A, B, C, batch=False):
256 | d = 2 if batch else 1
257 | AB = A - B
258 | BC = C - B
259 | epsilon = 3e-8
260 |
261 | AB_mag = torch.norm(AB, dim=d) + epsilon
262 | BC_mag = torch.norm(BC, dim=d) + epsilon
263 |
264 | dot_product = torch.sum(AB * BC, dim=d)
265 | cos_theta = dot_product / (AB_mag * BC_mag)
266 |
267 | zero_mask = (AB_mag == 0) | (BC_mag == 0)
268 | cos_theta[zero_mask] = 0
269 | theta = torch.acos(torch.clamp(cos_theta, -1 + epsilon, 1 - epsilon))
270 | theta[zero_mask] = 0
271 | return theta * 180 / math.pi
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 | def soft_winding_number(pred_polys, lam=1000, img_size=320, device='cuda'):
285 |
286 | B = len(pred_polys)
287 | IMG_SIZE = img_size
288 | pred_mask = torch.zeros((B, IMG_SIZE, IMG_SIZE)).to(device)
289 |
290 | x = torch.arange(IMG_SIZE)
291 | y = torch.arange(IMG_SIZE)
292 | xx, yy = torch.meshgrid(x,y)
293 |
294 | pixel_coords = torch.stack([yy, xx], dim=-1).float()
295 |
296 | for b in range(B):
297 | vertices = torch.vstack(pred_polys[b]).float()
298 | #vertices = vertices.detach().cpu()
299 | #vertices.requires_grad=True
300 | #vertices = vertices.unfold(dimension = 0,size = 2, step = 1)
301 | #vertices_repeated = vertices.repeat_interleave(IMG_SIZE*IMG_SIZE, dim=0)
302 |
303 | pairs = vertices[:-1].unsqueeze(1).repeat(1, 2, 1)
304 | pairs[:, 1, :] = vertices[1:]
305 |
306 | pairs_repeated = pairs.repeat_interleave(IMG_SIZE*IMG_SIZE, dim=0)
307 |
308 | #pixel_coords_angle = pixel_coords.repeat(vertices.shape[0],1,1).view(vertices.shape[0] *IMG_SIZE*IMG_SIZE,2)
309 | #pixel_coords_det = pixel_coords.repeat(vertices.shape[0],1,1).view(vertices.shape[0] *IMG_SIZE*IMG_SIZE ,2,1)
310 | pixel_coords_angle = pixel_coords.repeat(pairs.shape[0],1,1).view(pairs.shape[0] *IMG_SIZE*IMG_SIZE ,1 , 2).to(device)
311 |
312 | concatenated = torch.cat([pairs_repeated, pixel_coords_angle], dim=1)
313 |
314 | #ones = torch.ones(IMG_SIZE*IMG_SIZE*vertices.shape[0], 3).reshape(IMG_SIZE*IMG_SIZE*vertices.shape[0],1, 3)
315 |
316 | ones = torch.ones(pairs.shape[0] *IMG_SIZE*IMG_SIZE, 3, 1).to(device) #.reshape(vertices.shape[0]-1 *IMG_SIZE*IMG_SIZE,3, 1)
317 | output = torch.cat((concatenated, ones), dim=2)
318 |
319 | det = torch.det(output)
320 |
321 | # compute angle
322 | angles = angle_between_points(pairs_repeated[:, 0], pixel_coords_angle.view(pairs.shape[0] *IMG_SIZE*IMG_SIZE, 2), pairs_repeated[:, 1], batch=False)
323 |
324 | #Compute the soft winding number using equation 13
325 | w = (lam * det) / (1 + torch.abs(det *lam))
326 | w = w * angles
327 |
328 | w = w.view(pairs.shape[0], IMG_SIZE, IMG_SIZE)
329 | # Sum over all pairs of adjacent vertices to get the winding number
330 | w = w.sum(dim=0)
331 |
332 | pred_mask[b] = w
333 |
334 | return pred_mask
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 | # def sinkhorn_knopp(cost_matrix, epsilon=0.05, iterations=100):
354 | # """
355 | # Sinkhorn-Knopp algorithm to approximate optimal transport
356 | # """
357 | # B, N, _ = cost_matrix.shape
358 | # log_mu = torch.zeros_like(cost_matrix)
359 | # log_nu = torch.zeros_like(cost_matrix)
360 |
361 | # u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
362 | # for _ in range(iterations):
363 | # u = epsilon * (torch.log(torch.ones(B, N, 1, device=cost_matrix.device)) -
364 | # torch.logsumexp((-cost_matrix + v) / epsilon, dim=-1, keepdim=True)) + log_mu
365 | # v = epsilon * (torch.log(torch.ones(B, 1, N, device=cost_matrix.device)) -
366 | # torch.logsumexp((-cost_matrix + u) / epsilon, dim=-2, keepdim=True)) + log_nu
367 |
368 | # return torch.exp((-cost_matrix + u + v) / epsilon)
369 |
370 | # def scores_to_permutations(scores, temperature=0.1):
371 | # """
372 | # Input a batched array of scores and returns the approximate
373 | # permutation matrices using Sinkhorn-Knopp algorithm.
374 | # Preserves gradients for backpropagation.
375 | # """
376 | # B, N, _ = scores.shape
377 |
378 | # # Normalize scores to be non-negative
379 | # scores_normalized = scores - scores.min(dim=-1, keepdim=True)[0]
380 |
381 | # # Use Sinkhorn-Knopp to approximate permutation matrices
382 | # perm_soft = sinkhorn_knopp(-scores_normalized)
383 |
384 | # # Use a differentiable approximation of argmax
385 | # perm_hard = F.gumbel_softmax(torch.log(perm_soft + 1e-8), tau=temperature, hard=True, dim=-1)
386 |
387 | # return perm_hard
--------------------------------------------------------------------------------