├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images.pdf
│   ├── arch.png
│   ├── matching.png
│   ├── matrix.png
│   ├── outputs.png
│   └── vertex-detection.png
├── models
│   ├── backbone.py
│   └── matching.py
├── train.ipynb
└── utils
    ├── coco_IoU_cIoU.py
    ├── coco_to_shp.py
    ├── dataset.py
    ├── loss.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | /data/
2 | /trained_weights/
3 |
4 |
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | #  Usually these files are written by a python script from a template
36 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | #   For a library or package, you might want to ignore these files since the code is
92 | #   intended to run in multiple environments; otherwise, check them in:
93 | #   .python-version
94 |
95 | # pipenv
96 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | #   install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | #   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | #   This is especially recommended for binary packages to ensure reproducibility, and is more
105 | #   commonly ignored for libraries.
106 | #   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | #   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | #   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | #   in version control.
114 | #   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
115 | .pdm.toml
116 | .pdm-python
117 | .pdm-build/
118 |
119 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
120 | __pypackages__/
121 |
122 | # Celery stuff
123 | celerybeat-schedule
124 | celerybeat.pid
125 |
126 | # SageMath parsed files
127 | *.sage.py
128 |
129 | # Environments
130 | .env
131 | .venv
132 | env/
133 | venv/
134 | ENV/
135 | env.bak/
136 | venv.bak/
137 |
138 | # Spyder project settings
139 | .spyderproject
140 | .spyproject
141 |
142 | # Rope project settings
143 | .ropeproject
144 |
145 | # mkdocs documentation
146 | /site
147 |
148 | # mypy
149 | .mypy_cache/
150 | .dmypy.json
151 | dmypy.json
152 |
153 | # Pyre type checker
154 | .pyre/
155 |
156 | # pytype static type analyzer
157 | .pytype/
158 |
159 | # Cython debug symbols
160 | cython_debug/
161 |
162 | # PyCharm
163 | #  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
164 | #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
165 | #  and can be added to the global gitignore or merged into this file.  For a more nuclear
166 | #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
167 | #.idea/
168 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | SOFTWARE LICENSE AGREEMENT
2 |
3 | BlackShark.ai Software – 2022, all rights reserved, hereinafter "the Software".
4 |
5 | This software has been developed by researchers of ICG and BlackShark.ai.
6 |
7 | Institute of Computer Graphics and Vision (ICG), Inffeldgasse 16/II,
8 | 8010 Graz, Austria
9 |
10 | BlackShark.ai, Am Eisernen Tor 1/3
11 | 8010 Graz, Austria
12 |
13 | BlackShark.ai holds all the ownership rights on the Software.
14 |
15 | The Software is still under development. It is BlackShark.ai's aim for the Software
16 | to be used by the scientific community so as to test and evaluate it, so that BlackShark.ai may improve it.
17 |
18 | For these reasons BlackShark.ai has decided to distribute the Software.
19 |
20 | BlackShark.ai grants to the academic user a free-of-charge, non-exclusive right, without the right to sublicense,
21 | to use the Software for research purposes for a period of one (1) year from the date of the download
22 | of the source code. Any other use without the prior consent of BlackShark.ai is prohibited.
23 |
24 | The academic user explicitly acknowledges having received from BlackShark.ai all information allowing him
25 | to appreciate the adequacy of the Software to his needs and to undertake all necessary
26 | precautions for its execution and use.
27 |
28 | The Software is provided as source code only.
29 |
30 | In case of using the Software for a publication or other results obtained through the use of the Software,
31 | the user should cite the Software as follows:
32 |
33 | @article{zorzi2021polyworld,
34 |   title={PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images},
35 |   author={Zorzi, Stefano and Bazrafkan, Shabab and Habenschuss, Stefan and Fraundorfer, Friedrich},
36 |   journal={arXiv preprint arXiv:2111.15491},
37 |   year={2021}
38 | }
39 |
40 | Every user of the Software may communicate to the developers [stefano.zorzi@icg.tugraz.at]
41 | his or her remarks as to the use of the Software.
42 |
43 | THE USER CANNOT USE, EXPLOIT OR COMMERCIALLY DISTRIBUTE THE SOFTWARE WITHOUT PRIOR AND EXPLICIT CONSENT
44 | OF BlackShark.ai (sbazrafkan@blackshark.ai). ANY SUCH ACTION WILL CONSTITUTE A FORGERY.
45 |
46 | THIS SOFTWARE IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES,
47 | WITH REGARDS TO COMMERCIAL USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALIZATION OR ADAPTATION.
48 |
49 | UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL BlackShark.ai OR THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
50 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
51 | GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
52 | WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM,
53 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
54 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images
2 | ![](docs/outputs.png)
3 |
4 | PolyWorld is a research project conducted by the Institute of Computer Graphics and Vision of TUGraz, in collaboration with BlackShark.ai. PolyWorld is a neural network that extracts polygonal objects from an image in an end-to-end fashion. The model detects vertex candidates and predicts the connection strength between each pair of vertices using a Graph Neural Network.
5 |
6 | This repo is an attempt to train the network from scratch.
7 |
8 | - Official Repository: [PolyWorld](https://github.com/zorzi-s/PolyWorldPretrainedNetwork)
9 |
10 | - Paper PDF: [PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images](https://arxiv.org/abs/2111.15491)
11 |
12 | - Authors: Stefano Zorzi, Shabab Bazrafkan, Stefan Habenschuss, Friedrich Fraundorfer
13 |
14 | - Video: [YouTube link](https://youtu.be/C80dojBosLQ)
15 |
16 | - Poster: [Seafile link](https://files.icg.tugraz.at/f/6a044f133c0d4dd992c5/)
17 |
18 | ---
19 |
20 | You can start training the network in the **train.ipynb** notebook.
21 | - We can load the pre-trained weights for the backbone and the vertex detection network from the [official repository](https://github.com/zorzi-s/PolyWorldPretrainedNetwork) and freeze them during training, so that only the Matching Network is trained.
22 |
23 | ## Architecture
24 | ![Architecture](docs/arch.png)
25 |
26 |
27 | ## Main Components
28 | - Backbone CNN (Pre-trained)
29 | - Vertex Detection (1x1 Conv) (Pre-trained)
30 | ![Vertex Detection](docs/vertex-detection.png)
31 |
32 | - NMS (Selects top 256 points) (Not Trainable)
33 | - Optimal Matching (Attentional GNN) (Our main challenge for training)
34 | ![Matching](docs/matching.png)
35 |
36 | - Polygon Reconstruction (Not Trainable) (Uses the predicted adjacency matrix and the top 256 points) (Our main challenge for training)
37 |
38 | ## Contributions
39 | - Applied random permutations to the ground-truth permutation matrix (lines 31-32, 98, and 105-106 in dataset.py)
40 | - Applied the Sinkhorn algorithm in the matching step (thanks to https://github.com/henokyen/henokyen_polyworld)
41 | - Dataloader for the CrowdAI dataset
42 |
43 | ## Issues
44 | - Everything works until we try to create the polygons from the points (the top 256 predictions) and the predicted adjacency matrix (the adjacency matrix itself is predicted correctly with respect to the ground truth). The main problem is probably the way we reconstruct the polygons: we have the coordinates of the points and the adjacency matrix, but we do not know which vertex a given point (x, y) corresponds to; is it v1, v2, or vn? We probably need to assign an ID to each point. See the sketch below.
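A minimal sketch of this idea (a hypothetical helper, not code from this repo): treat the row index of the permutation matrix as the vertex ID, so that `P[i].argmax()` gives the successor of vertex `i`, and walk the cycles to recover closed polygons.

```python
import torch

def permutation_to_polygons_sketch(P, points):
    """P: (N, N) permutation matrix, points: (N, 2) vertex coordinates.
    Vertex i is identified with row i of P; P[i].argmax() is its successor."""
    succ = P.argmax(dim=1).tolist()
    visited, polygons = set(), []
    for start in range(len(succ)):
        # Diagonal entries mark unused slots (self-loops); skip them.
        if start in visited or succ[start] == start:
            continue
        poly, i = [], start
        while i not in visited:
            visited.add(i)
            poly.append(points[i].tolist())
            i = succ[i]
        polygons.append(poly)
    return polygons
```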
--------------------------------------------------------------------------------
/docs/PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/PolyWorld: Polygonal Building Extraction with Graph Neural Networks in Satellite Images.pdf
--------------------------------------------------------------------------------
/docs/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/arch.png
--------------------------------------------------------------------------------
/docs/matching.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/matching.png
--------------------------------------------------------------------------------
/docs/matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/matrix.png
--------------------------------------------------------------------------------
/docs/outputs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/outputs.png
--------------------------------------------------------------------------------
/docs/vertex-detection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amirafshari/polyworld/a8e0d72d32312e56c8c3f1bbc123060d45b40373/docs/vertex-detection.png
--------------------------------------------------------------------------------
/models/backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import init
5 |
6 |
7 | class DetectionBranch(nn.Module):
8 |     def __init__(self):
9 |         super(DetectionBranch,self).__init__()
10 |         self.conv = nn.Sequential(
11 |             nn.Conv2d(64, 64, kernel_size=1,stride=1,padding=0,bias=True),
12 |             nn.BatchNorm2d(64),
13 |             nn.ReLU(inplace=True),
14 |             nn.Conv2d(64, 1, kernel_size=1,stride=1,padding=0,bias=True)
15 |         )
16 |
17 |     def forward(self,x):
18 |         x = self.conv(x)
19 |         return x
20 |
21 |
22 | class up_conv(nn.Module):
23 |     def __init__(self,ch_in,ch_out):
24 |         super(up_conv,self).__init__()
25 |         self.up = nn.Sequential(
26 |             nn.Upsample(scale_factor=2),
27 |             nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
28 |             nn.BatchNorm2d(ch_out),
29 |             nn.ReLU(inplace=True)
30 |         )
31 |
32 |     def forward(self,x):
33 |         x = self.up(x)
34 |         return x
35 |
36 |
37 | class Recurrent_block(nn.Module):
38 |     def __init__(self,ch_out,t=2):
39 |         super(Recurrent_block,self).__init__()
40 |         self.t = t
41 |         self.ch_out = ch_out
42 |         self.conv = nn.Sequential(
43 |             nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
44 |             nn.BatchNorm2d(ch_out),
45 |             nn.ReLU(inplace=True)
46 |         )
47 |
48 |     def forward(self,x):
49 |         for i in range(self.t):
50 |
51 |             if i==0:
52 |                 x1 = self.conv(x)
53 |
54 |             x1 = self.conv(x+x1)
55 |         return x1
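# Note: with t loop iterations, self.conv is applied t+1 times in total:
# x1 = conv(x) once on entry (i == 0), then x1 = conv(x + x1) on every
# iteration, so the input x is re-injected at each refinement step. This is
# the recurrent refinement block of R2U-Net (used below by R2U_Net).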
56 | 57 | 58 | class RRCNN_block(nn.Module): 59 | def __init__(self,ch_in,ch_out,t=2): 60 | super(RRCNN_block,self).__init__() 61 | self.RCNN = nn.Sequential( 62 | Recurrent_block(ch_out,t=t), 63 | Recurrent_block(ch_out,t=t) 64 | ) 65 | self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0) 66 | 67 | def forward(self,x): 68 | x = self.Conv_1x1(x) 69 | x1 = self.RCNN(x) 70 | return x+x1 71 | 72 | 73 | class R2U_Net(nn.Module): 74 | def __init__(self,img_ch=3,t=1): 75 | super(R2U_Net,self).__init__() 76 | 77 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 78 | self.Upsample = nn.Upsample(scale_factor=2) 79 | 80 | self.RRCNN1 = RRCNN_block(ch_in=img_ch,ch_out=64,t=t) 81 | self.RRCNN2 = RRCNN_block(ch_in=64,ch_out=128,t=t) 82 | self.RRCNN3 = RRCNN_block(ch_in=128,ch_out=256,t=t) 83 | self.RRCNN4 = RRCNN_block(ch_in=256,ch_out=512,t=t) 84 | self.RRCNN5 = RRCNN_block(ch_in=512,ch_out=1024,t=t) 85 | 86 | self.Up5 = up_conv(ch_in=1024,ch_out=512) 87 | self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512,t=t) 88 | 89 | self.Up4 = up_conv(ch_in=512,ch_out=256) 90 | self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256,t=t) 91 | 92 | self.Up3 = up_conv(ch_in=256,ch_out=128) 93 | self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128,t=t) 94 | 95 | self.Up2 = up_conv(ch_in=128,ch_out=64) 96 | self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64,t=t) 97 | 98 | 99 | def forward(self,x): 100 | # encoding path 101 | x1 = self.RRCNN1(x) 102 | 103 | x2 = self.Maxpool(x1) 104 | x2 = self.RRCNN2(x2) 105 | 106 | x3 = self.Maxpool(x2) 107 | x3 = self.RRCNN3(x3) 108 | 109 | x4 = self.Maxpool(x3) 110 | x4 = self.RRCNN4(x4) 111 | 112 | x5 = self.Maxpool(x4) 113 | x5 = self.RRCNN5(x5) 114 | 115 | # decoding + concat path 116 | d5 = self.Up5(x5) 117 | d5 = torch.cat((x4,d5),dim=1) 118 | d5 = self.Up_RRCNN5(d5) 119 | 120 | d4 = self.Up4(d5) 121 | d4 = torch.cat((x3,d4),dim=1) 122 | d4 = self.Up_RRCNN4(d4) 123 | 124 | d3 = self.Up3(d4) 125 | d3 = torch.cat((x2,d3),dim=1) 126 | d3 = self.Up_RRCNN3(d3) 127 | 128 | d2 = self.Up2(d3) 129 | d2 = torch.cat((x1,d2),dim=1) 130 | d2 = self.Up_RRCNN2(d2) 131 | 132 | return d2 133 | 134 | 135 | class NonMaxSuppression(nn.Module): 136 | def __init__(self, n_peaks=256): 137 | super(NonMaxSuppression,self).__init__() 138 | self.k = 3 # kernel 139 | self.p = 1 # padding 140 | self.s = 1 # stride 141 | self.center_idx = self.k**2//2 142 | self.sigmoid = nn.Sigmoid() 143 | self.unfold = nn.Unfold(kernel_size=self.k, padding=self.p, stride=self.s) 144 | self.n_peaks = n_peaks 145 | 146 | def sample_peaks(self, x): 147 | B, _, H, W = x.shape 148 | for b in range(B): 149 | x_b = x[b,0] 150 | idx = torch.topk(x_b.flatten(), self.n_peaks).indices 151 | idx_i = torch.div(idx, W, rounding_mode='floor') 152 | idx_j = idx % W 153 | idx = torch.cat((idx_i.unsqueeze(1), idx_j.unsqueeze(1)), dim=1) 154 | idx = idx.unsqueeze(0) 155 | 156 | if b == 0: 157 | graph = idx 158 | else: 159 | graph = torch.cat((graph, idx), dim=0) 160 | 161 | return graph 162 | 163 | def forward(self, feat): 164 | B, C, H, W = feat.shape 165 | 166 | x = self.sigmoid(feat) 167 | 168 | # Prepare filter 169 | f = self.unfold(x).view(B, self.k**2, H, W) 170 | f = torch.argmax(f, dim=1).unsqueeze(1) 171 | f = (f == self.center_idx).float() 172 | 173 | # Apply filter 174 | x = x * f 175 | 176 | # Sample top peaks 177 | graph = self.sample_peaks(x) 178 | return x, graph -------------------------------------------------------------------------------- /models/matching.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from copy import deepcopy 4 | from utils.utils import scores_to_permutations, permutations_to_polygons 5 | 6 | 7 | def MultiLayerPerceptron(channels: list, batch_norm=True): 8 | n_layers = len(channels) 9 | 10 | layers = [] 11 | for i in range(1, n_layers): 12 | layers.append(nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)) 13 | 14 | if i < (n_layers - 1): 15 | if batch_norm: 16 | layers.append(nn.BatchNorm1d(channels[i])) 17 | layers.append(nn.ReLU()) 18 | 19 | return nn.Sequential(*layers) 20 | 21 | 22 | class Attention(nn.Module): 23 | 24 | def __init__(self, n_heads: int, d_model: int): 25 | super().__init__() 26 | assert d_model % n_heads == 0 27 | self.dim = d_model // n_heads 28 | self.n_heads = n_heads 29 | self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) 30 | self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) 31 | 32 | def forward(self, query, key, value): 33 | b = query.size(0) 34 | query, key, value = [l(x).view(b, self.dim, self.n_heads, -1) 35 | for l, x in zip(self.proj, (query, key, value))] 36 | 37 | b, d, h, n = query.shape 38 | scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / d**.5 39 | attn = torch.einsum('bhnm,bdhm->bdhn', torch.nn.functional.softmax(scores, dim=-1), value) 40 | 41 | return self.merge(attn.contiguous().view(b, self.dim*self.n_heads, -1)) 42 | 43 | 44 | class AttentionalPropagation(nn.Module): 45 | 46 | def __init__(self, feature_dim: int, n_heads: int): 47 | super().__init__() 48 | self.attn = Attention(n_heads, feature_dim) 49 | self.mlp = MultiLayerPerceptron([feature_dim*2, feature_dim*2, feature_dim]) 50 | nn.init.constant_(self.mlp[-1].bias, 0.0) 51 | 52 | def forward(self, x): 53 | message = self.attn(x, x, x) 54 | return self.mlp(torch.cat([x, message], dim=1)) 55 | 56 | 57 | class AttentionalGNN(nn.Module): 58 | 59 | def __init__(self, feature_dim: int, num_layers: int): 60 | super().__init__() 61 | self.conv_init = nn.Sequential( 62 | nn.Conv1d(feature_dim + 2, feature_dim, kernel_size=1,stride=1,padding=0,bias=True), 63 | nn.BatchNorm1d(feature_dim), 64 | nn.ReLU(inplace=True), 65 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True), 66 | nn.BatchNorm1d(feature_dim), 67 | nn.ReLU(inplace=True), 68 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True) 69 | ) 70 | 71 | self.layers = nn.ModuleList([ 72 | AttentionalPropagation(feature_dim, 4) 73 | for _ in range(num_layers)]) 74 | 75 | self.conv_desc = nn.Sequential( 76 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True), 77 | nn.BatchNorm1d(feature_dim), 78 | nn.ReLU(inplace=True), 79 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True), 80 | nn.BatchNorm1d(feature_dim), 81 | nn.ReLU(inplace=True) 82 | ) 83 | 84 | self.conv_offset = nn.Sequential( 85 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True), 86 | nn.BatchNorm1d(feature_dim), 87 | nn.ReLU(inplace=True), 88 | nn.Conv1d(feature_dim, feature_dim, kernel_size=1,stride=1,padding=0,bias=True), 89 | nn.BatchNorm1d(feature_dim), 90 | nn.ReLU(inplace=True), 91 | nn.Conv1d(feature_dim, 2, kernel_size=1,stride=1,padding=0,bias=True), 92 | nn.Hardtanh() 93 | ) 94 | 95 | def forward(self, feat, graph): 96 | graph = graph.permute(0,2,1) 97 | feat = torch.cat((feat, graph), dim=1) 98 | feat = self.conv_init(feat) 99 | 100 | for layer 
in self.layers: 101 | feat = feat + layer(feat) 102 | 103 | desc = self.conv_desc(feat) 104 | offset = self.conv_offset(feat).permute(0,2,1) 105 | return desc, offset 106 | 107 | 108 | class ScoreNet(nn.Module): 109 | 110 | def __init__(self, in_ch): 111 | super().__init__() 112 | self.relu = nn.ReLU(inplace=True) 113 | self.conv1 = nn.Conv2d(in_ch, 256, kernel_size=1, stride=1, padding=0, bias=True) 114 | self.bn1 = nn.BatchNorm2d(256) 115 | self.conv2 = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=True) 116 | self.bn2 = nn.BatchNorm2d(128) 117 | self.conv3 = nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=True) 118 | self.bn3 = nn.BatchNorm2d(64) 119 | self.conv4 = nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0, bias=True) 120 | 121 | def forward(self, x): 122 | n_points = x.shape[-1] 123 | 124 | x = x.unsqueeze(-1) 125 | x = x.repeat(1,1,1,n_points) 126 | t = torch.transpose(x, 2, 3) 127 | x = torch.cat((x, t), dim=1) 128 | 129 | x = self.conv1(x) 130 | x = self.bn1(x) 131 | x = self.relu(x) 132 | 133 | x = self.conv2(x) 134 | x = self.bn2(x) 135 | x = self.relu(x) 136 | 137 | x = self.conv3(x) 138 | x = self.bn3(x) 139 | x = self.relu(x) 140 | 141 | x = self.conv4(x) 142 | return x[:,0] 143 | 144 | 145 | class OptimalMatching(nn.Module): 146 | 147 | def __init__(self): 148 | super(OptimalMatching, self).__init__() 149 | 150 | # Default configuration settings 151 | self.descriptor_dim = 64 152 | self.sinkhorn_iterations = 100 153 | self.attention_layers = 4 154 | self.correction_radius = 0.05 155 | 156 | # Modules 157 | self.scorenet1 = ScoreNet(self.descriptor_dim * 2) 158 | self.scorenet2 = ScoreNet(self.descriptor_dim * 2) 159 | self.gnn = AttentionalGNN(self.descriptor_dim, self.attention_layers) 160 | 161 | 162 | def normalize_coordinates(self, graph, ws, input): 163 | if input == 'global': 164 | graph = (graph * 2 / ws - 1) 165 | elif input == 'normalized': 166 | graph = ((graph + 1) * ws / 2) 167 | graph = torch.round(graph).long() 168 | graph[graph < 0] = 0 169 | graph[graph >= ws] = ws - 1 170 | return graph 171 | 172 | 173 | def predict(self, image, descriptors, graph): 174 | B, _, H, W = image.shape 175 | B, N, _ = graph.shape 176 | 177 | for b in range(B): 178 | b_desc = descriptors[b] 179 | b_graph = graph[b] 180 | 181 | # Extract descriptors 182 | b_desc = b_desc[:, b_graph[:,0], b_graph[:,1]] 183 | 184 | # Concatenate descriptors in batches 185 | if b == 0: 186 | sel_desc = b_desc.unsqueeze(0) 187 | else: 188 | sel_desc = torch.cat((sel_desc, b_desc.unsqueeze(0)), dim=0) 189 | 190 | # Multi-layer Transformer network. 
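# Coordinates are first mapped to a resolution-independent [-1, 1] range for
# the GNN; the offsets it predicts are scaled by correction_radius (0.05 here)
# and applied in that normalized space before mapping back to pixel coordinates.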
191 | norm_graph = self.normalize_coordinates(graph, W, input="global") #out: normalized coordinate system [-1, 1] 192 | sel_desc, offset = self.gnn(sel_desc, norm_graph) 193 | 194 | # Correct points coordinates 195 | norm_graph = norm_graph + offset * self.correction_radius 196 | graph = self.normalize_coordinates(norm_graph, W, input="normalized") # out: global coordinate system [0, W] 197 | 198 | # Compute scores 199 | scores_1 = self.scorenet1(sel_desc) 200 | scores_2 = self.scorenet2(sel_desc) 201 | scores = scores_1 + torch.transpose(scores_2, 1, 2) 202 | 203 | scores = scores_to_permutations(scores) 204 | poly = permutations_to_polygons(scores, graph, out='coco') 205 | 206 | return poly 207 | -------------------------------------------------------------------------------- /utils/coco_IoU_cIoU.py: -------------------------------------------------------------------------------- 1 | from pycocotools.coco import COCO 2 | from pycocotools import mask as cocomask 3 | import numpy as np 4 | import json 5 | from tqdm import tqdm 6 | 7 | def calc_IoU(a, b): 8 | i = np.logical_and(a, b) 9 | u = np.logical_or(a, b) 10 | I = np.sum(i) 11 | U = np.sum(u) 12 | 13 | iou = I/(U + 1e-9) 14 | 15 | is_void = U == 0 16 | if is_void: 17 | return 1.0 18 | else: 19 | return iou 20 | 21 | def compute_IoU_cIoU(input_json, gti_annotations): 22 | # Ground truth annotations 23 | coco_gti = COCO(gti_annotations) 24 | 25 | # Predictions annotations 26 | submission_file = json.loads(open(input_json).read()) 27 | coco = COCO(gti_annotations) 28 | coco = coco.loadRes(submission_file) 29 | 30 | 31 | image_ids = coco.getImgIds(catIds=coco.getCatIds()) 32 | bar = tqdm(image_ids) 33 | 34 | list_iou = [] 35 | list_ciou = [] 36 | for image_id in bar: 37 | 38 | img = coco.loadImgs(image_id)[0] 39 | 40 | annotation_ids = coco.getAnnIds(imgIds=img['id']) 41 | annotations = coco.loadAnns(annotation_ids) 42 | N = 0 43 | for _idx, annotation in enumerate(annotations): 44 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width']) 45 | m = cocomask.decode(rle) 46 | if _idx == 0: 47 | mask = m.reshape((img['height'], img['width'])) 48 | N = len(annotation['segmentation'][0]) // 2 49 | else: 50 | mask = mask + m.reshape((img['height'], img['width'])) 51 | N = N + len(annotation['segmentation'][0]) // 2 52 | 53 | mask = mask != 0 54 | 55 | 56 | annotation_ids = coco_gti.getAnnIds(imgIds=img['id']) 57 | annotations = coco_gti.loadAnns(annotation_ids) 58 | N_GT = 0 59 | for _idx, annotation in enumerate(annotations): 60 | rle = cocomask.frPyObjects(annotation['segmentation'], img['height'], img['width']) 61 | m = cocomask.decode(rle) 62 | if _idx == 0: 63 | mask_gti = m.reshape((img['height'], img['width'])) 64 | N_GT = len(annotation['segmentation'][0]) // 2 65 | else: 66 | mask_gti = mask_gti + m.reshape((img['height'], img['width'])) 67 | N_GT = N_GT + len(annotation['segmentation'][0]) // 2 68 | 69 | mask_gti = mask_gti != 0 70 | 71 | ps = 1 - np.abs(N - N_GT) / (N + N_GT + 1e-9) 72 | iou = calc_IoU(mask, mask_gti) 73 | list_iou.append(iou) 74 | list_ciou.append(iou * ps) 75 | 76 | bar.set_description("iou: %2.4f, c-iou: %2.4f" % (np.mean(list_iou), np.mean(list_ciou))) 77 | bar.refresh() 78 | 79 | print("Done!") 80 | print("Mean IoU: ", np.mean(list_iou)) 81 | print("Mean C-IoU: ", np.mean(list_ciou)) 82 | 83 | 84 | 85 | if __name__ == "__main__": 86 | compute_IoU_cIoU(input_json="./predictions.json", 87 | gti_annotations="/home/stefano/Workspace/data/mapping_challenge_dataset/raw/val/annotation.json") 
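# Note: the reported C-IoU is the mask IoU weighted by the vertex-count
# similarity ps = 1 - |N - N_GT| / (N + N_GT), so predictions that use far
# more (or fewer) vertices than the ground truth are penalized even when
# their raster masks overlap well.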
88 | -------------------------------------------------------------------------------- /utils/coco_to_shp.py: -------------------------------------------------------------------------------- 1 | from pycocotools.coco import COCO 2 | import numpy as np 3 | import json 4 | from tqdm import tqdm 5 | import shapefile 6 | 7 | 8 | def cocojson_to_shapefiles(input_json, gti_annotations, output_folder): 9 | 10 | submission_file = json.loads(open(input_json).read()) 11 | coco = COCO(gti_annotations) 12 | coco = coco.loadRes(submission_file) 13 | 14 | image_ids = coco.getImgIds(catIds=coco.getCatIds()) 15 | 16 | for image_id in tqdm(image_ids): 17 | 18 | img = coco.loadImgs(image_id)[0] 19 | 20 | annotation_ids = coco.getAnnIds(imgIds=img['id']) 21 | annotations = coco.loadAnns(annotation_ids) 22 | 23 | list_poly = [] 24 | for _idx, annotation in enumerate(annotations): 25 | poly = annotation['segmentation'][0] 26 | poly = np.array(poly) 27 | poly = poly.reshape((-1,2)) 28 | 29 | poly[:,1] = -poly[:,1] 30 | list_poly.append(poly.tolist()) 31 | 32 | number_str = str(image_id).zfill(12) 33 | w = shapefile.Writer(output_folder + '%s.shp' % number_str) 34 | w.field('name', 'C') 35 | w.poly(list_poly) 36 | w.record("polygon") 37 | w.close() 38 | 39 | print("Done!") 40 | 41 | 42 | if __name__ == "__main__": 43 | cocojson_to_shapefiles(input_json="./predictions.json", 44 | gti_annotations="data/val/annotation.json", 45 | output_folder="./shapefiles/") 46 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from skimage import io 4 | from skimage.transform import resize 5 | import torch 6 | from torch.utils.data import Dataset 7 | import json 8 | import pandas as pd 9 | import cv2 10 | 11 | class CrowdAI(Dataset): 12 | 13 | 14 | def __init__(self, images_directory, annotations_path, window_size=320): 15 | 16 | self.IMAGES_DIRECTORY = images_directory 17 | self.ANNOTATIONS_PATH = annotations_path 18 | 19 | self.window_size = window_size 20 | self.max_points = 256 21 | 22 | # load annotation json 23 | with open(self.ANNOTATIONS_PATH) as f: 24 | self.annotations = json.load(f) 25 | 26 | self.images = pd.DataFrame(self.annotations['images']) 27 | self.labels = pd.DataFrame(self.annotations['annotations']) 28 | 29 | 30 | # Generate the Permutation 31 | # torch.manual_seed(0) 32 | # self.permutation = torch.randperm(self.max_points) 33 | 34 | 35 | 36 | 37 | 38 | def _shuffle_adjacency_matrix(self, A): 39 | """ 40 | generates a new permutation for each sample (or batch) 41 | and shuffles the adjacency matrix A accordingly 42 | """ 43 | 44 | n = A.shape[0] 45 | torch.manual_seed(0) 46 | permutation = torch.randperm(n) 47 | 48 | shuffled_A = A[permutation, :][:, permutation] 49 | 50 | return shuffled_A, permutation 51 | 52 | 53 | 54 | ''' 55 | def _shuffle_adjacency_matrix(self, A): 56 | """ 57 | generates a new permutation for each sample (or batch) 58 | and shuffles the adjacency matrix A accordingly 59 | """ 60 | n = A.shape[0] 61 | shuffled_A = A[self.permutation, :][:, self.permutation] 62 | return shuffled_A, self.permutation[:n] 63 | ''' 64 | 65 | 66 | def _shuffle_vector(self, v, permutation): 67 | return v[permutation] 68 | 69 | 70 | def _create_adjacency_matrix(self, segmentations, N=256): 71 | 72 | adjacency_matrix = torch.zeros((N, N), dtype=torch.uint8) 73 | dic = {} 74 | 75 | n, m = 0, 0 76 | graph = torch.zeros((N, 2), 
dtype=torch.float32) # to create the seg mask with permutation_to_polygon() 77 | vertices = torch.zeros((N, 2), dtype=torch.float32) # to sort the nms points in the training 78 | for i, polygon in enumerate(segmentations): 79 | for v, point in enumerate(polygon): 80 | if v != len(polygon) - 1: 81 | adjacency_matrix[n, n+1] = 1 82 | else: 83 | adjacency_matrix[n, n-v] = 1 84 | 85 | 86 | # We just use it to create the segmentation mask with permutation_to_polygon(), no other use 87 | graph[n] = torch.tensor(point) 88 | n += 1 # n must be incremented in this way due to the functionality of the permutation_to_polygon() function, so we use m for the dictionary 89 | 90 | if tuple(map(float, point)) not in dic: 91 | vertices[m] = torch.tensor(point) 92 | dic[tuple(map(float, point))] = m 93 | m += 1 94 | 95 | 96 | 97 | # Permute the adjacency matrix 98 | # adjacency_matrix[:n,:n], permutation = self._shuffle_adjacency_matrix(adjacency_matrix[:n, :n]) 99 | 100 | # Fill the diagonal with 1s 101 | for i in range(n, N): 102 | adjacency_matrix[i, i] = 1 103 | 104 | # Permute the graph 105 | # graph[:n] = self._shuffle_vector(graph, permutation) 106 | # graph = graph[:n] 107 | 108 | return adjacency_matrix, graph, vertices[:m], dic 109 | 110 | 111 | 112 | 113 | 114 | def _create_segmentation_mask(self, polygons, image_size): 115 | mask = np.zeros((image_size, image_size), dtype=np.uint8) 116 | for polygon in polygons: 117 | cv2.fillPoly(mask, [polygon], 1) 118 | return torch.tensor(mask, dtype=torch.uint8) 119 | 120 | 121 | 122 | def _create_vertex_mask(self, polygons, image_shape=(320, 320)): 123 | mask = torch.zeros(image_shape, dtype=torch.uint8) 124 | 125 | for poly in polygons: 126 | for p in poly: 127 | mask[p[1], p[0]] = 1 128 | 129 | return mask 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | def __len__(self): 139 | return len(self.images) 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | def __getitem__(self, idx): 148 | 149 | image = io.imread(self.IMAGES_DIRECTORY + self.images['file_name'][idx]) 150 | image = resize(image, (self.window_size, self.window_size), anti_aliasing=True) 151 | image = torch.from_numpy(image) 152 | width, height = self.images['width'][idx], self.images['height'][idx] 153 | ratio = self.window_size / max(width, height) 154 | 155 | 156 | 157 | # Get the image ID 158 | image_id = self.images['id'][idx] 159 | # Get all annotations for this image 160 | image_annotations = self.labels[self.labels['image_id'] == image_id] 161 | # get all polygons for the image 162 | segmentations = image_annotations['segmentation'].values 163 | segmentations = [e[0] for e in segmentations] 164 | for i, poly in enumerate(segmentations): 165 | # rescale the polygon 166 | poly = [int(e * ratio) for e in poly] 167 | # out of bounds check 168 | for j, e in enumerate(poly): 169 | if j % 2 == 0: 170 | poly[j] = min(max(0, e), self.window_size - 1) 171 | else: 172 | poly[j] = min(max(0, e), self.window_size - 1) 173 | segmentations[i] = poly 174 | 175 | 176 | 177 | 178 | # print(segmentations) 179 | segmentations = [np.array(poly, dtype=int).reshape(-1, 2) for poly in segmentations] # convert a list of polygons to a list of numpy arrays of points 180 | # print(segmentations) 181 | 182 | 183 | 184 | 185 | # create permutation matrix 186 | # permutation_matrix, graph, permutation = self._create_adjacency_matrix(segmentations, N=self.max_points) 187 | # create the simple adjacency matrix 188 | permutation_matrix, graph, vertices, dic = self._create_adjacency_matrix(segmentations, 
N=self.max_points)
189 |         # create vertex mask
190 |         vertex_mask = self._create_vertex_mask(segmentations, image_shape=(self.window_size, self.window_size))
191 |         # create segmentation mask
192 |         seg_mask = self._create_segmentation_mask(segmentations, image_size=self.window_size)
193 |
194 |
195 |         segmentations = [torch.from_numpy(poly) for poly in segmentations]
196 |
197 |
198 |
199 |
200 |
201 |         return image, vertex_mask, seg_mask, permutation_matrix, segmentations, graph, vertices, dic
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn.functional as F
3 | import torch
4 | from utils.utils import angle_between_points, prepare_gt_vertices
5 |
6 |
7 |
8 |
9 | def compute_l_angle_loss(gt_permutation_mask,
10 |                          vertices,
11 |                          pred_permutation_mat,
12 |                          graph, device='cuda'):
13 |     v_gt = prepare_gt_vertices(vertices, device=device)
14 |
15 |     v_gt_1 = torch.matmul(gt_permutation_mask, v_gt)
16 |     v_gt_2 = torch.matmul(gt_permutation_mask, v_gt_1)
17 |     gt_angle = angle_between_points(v_gt, v_gt_1, v_gt_2)
18 |     #torch.isnan(gt_angle).any()
19 |
20 |     pred_permutation_mat = pred_permutation_mat.to(device)
21 |     v_pred_1 = torch.matmul(pred_permutation_mat, graph)
22 |     v_pred_2 = torch.matmul(pred_permutation_mat, v_pred_1)
23 |     pred_angle = angle_between_points(graph, v_pred_1.float(), v_pred_2.float())
24 |     #torch.isnan(pred_angle).any()
25 |
26 |     return pred_angle, gt_angle
27 |
28 |
29 |
30 | def cross_entropy_loss(sinkhorn_results, gt_permutation):
31 |     '''
32 |     It only considers the positive matches
33 |     and tries to maximize their scores (the loss is the negative mean of the selected entries)
34 |
35 |
36 |     One consideration is the distribution of the vertices in the GT permutation matrix:
37 |     a sensible ordering helps overcome overfitting and lets the model learn the
38 |     correct permutation matrix
39 |
40 |
41 |     One solution is:
42 |     1- order the polygons, based on the number of vertices
43 |     2- select the first vertex of each polygon and assign it a unique ID
44 |     3- repeat 2
45 |
46 |
47 |     Another solution:
48 |     - using a random permutation matrix/vector
49 |
50 |     '''
51 |     loss_match = -torch.mean(torch.masked_select(sinkhorn_results, gt_permutation == 1))
52 |     return loss_match
53 |
54 |
55 |
56 |
57 |
58 | # def iou_loss_function(pred, gt):
59 | #     B, H, W = gt.shape
60 | #     iou = 0
61 | #     for batch in range(B):
62 | #         K = len(pred[batch]) # Number of polygons
63 | #         batch_tensor = np.zeros((K, H, W), dtype=np.uint8)
64 | #         for i, poly in enumerate(pred[batch]):
65 | #             cv2.fillPoly(batch_tensor[i], [poly.detach().cpu().numpy()], 1)
66 |
67 | #         batch_pred_mask = torch.sum(torch.tensor(batch_tensor), dim=0).permute(1,0)
68 |
69 | #         # plt.imshow(batch_pred_mask)
70 | #         # plt.show()
71 | #         # plt.imshow(gt[batch])
72 | #         # plt.show()
73 |
74 | #         intersection = torch.min(batch_pred_mask, gt[batch])
75 | #         union = torch.max(batch_pred_mask, gt[batch])
76 | #         batch_iou = torch.sum(intersection) / torch.sum(union)
77 | #         iou += batch_iou
78 |
79 | #     return torch.tensor(1 - iou, requires_grad=True)
80 |
81 |
82 |
83 | # def iou_loss_function(pred, gt):
84 | #     B, H, W = gt.shape
85 | #     total_iou = 0
86 |
87 | #     for batch in range(B):
88 | #         batch_pred = torch.zeros((H, W), device=gt.device)
89 | #         for poly in pred[batch]:
90 | #             # Convert polygon to mask
91 | #             mask = torch.zeros((H, W), device=gt.device)
92 | #             poly_tensor = poly.long() # Ensure integer coordinates
93 | #             mask[poly_tensor[:, 1], poly_tensor[:, 0]] = 1
94 | #             mask = F.max_pool2d(mask.unsqueeze(0).float(), kernel_size=3, stride=1, padding=1).squeeze(0)
95 | #             batch_pred = torch.max(batch_pred, mask)
96 |
97 | #         plt.imshow(mask)
98 | #         plt.show()
99 | #         plt.imshow(gt[batch])
100 | #         plt.show()
101 |
102 | #         intersection = torch.sum(torch.min(batch_pred, gt[batch]))
103 | #         union = torch.sum(torch.max(batch_pred, gt[batch]))
104 | #         batch_iou = intersection / (union + 1e-6) # Add small epsilon to avoid division by zero
105 | #         total_iou += batch_iou
106 |
107 | #     avg_iou = total_iou / B
108 | #     return 1 - avg_iou
109 |
110 |
111 |
112 | # def iou_loss_function(pred_mask, gt_mask):
113 | #     '''
114 | #     pred_mask: (B, H, W)
115 | #     gt_mask: (B, H, W)
116 | #     '''
117 | #     pred_mask = F.sigmoid(pred_mask)
118 | #     intersection = torch.sum(torch.min(pred_mask, gt_mask))
119 | #     union = torch.sum(torch.max(pred_mask, gt_mask))
120 | #     iou = intersection / (union + 1e-6) # Add small epsilon to avoid division by zero
121 | #     loss = 1 - iou
122 | #     # loss.requires_grad = True
123 |
124 | #     return loss # Return 1 - IoU to minimize
125 |
126 |
127 |
128 | def iou_loss_function(pred_mask, target_mask):
129 |     pred_mask = F.sigmoid(pred_mask)
130 |
131 |     intersection = (pred_mask * target_mask).sum()
132 |     union = ((pred_mask + target_mask) - (pred_mask * target_mask)).sum()
133 |     iou = intersection / (union + 1e-6)  # small epsilon avoids division by zero
134 |     iou_dual = pred_mask.size(0) - iou
135 |
136 |     #iou_dual = iou_dual / pred_mask.size(0)
137 |     # Gradients already flow through pred_mask; setting requires_grad on a non-leaf tensor would raise an error.
138 |     return torch.mean(iou_dual)
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.optimize import linear_sum_assignment
4 | import torch.nn.functional as F
5 | import math
6 | import cv2  # used by polygon_to_seg_mask below
7 | def scores_to_permutations(scores):
8 |     """
9 |     Input a batched array of scores and returns the Hungarian-optimized
10 |     permutation matrices.
11 | """ 12 | B, N, N = scores.shape 13 | 14 | scores = scores.detach().cpu().numpy() 15 | perm = np.zeros_like(scores) 16 | for b in range(B): 17 | r, c = linear_sum_assignment(-scores[b]) 18 | perm[b,r,c] = 1 19 | return torch.tensor(perm) 20 | 21 | 22 | 23 | 24 | def permutations_to_polygons(perm, graph, out='torch'): 25 | B, N, N = perm.shape 26 | 27 | def bubble_merge(poly): 28 | s = 0 29 | P = len(poly) 30 | while s < P: 31 | head = poly[s][-1] 32 | 33 | t = s+1 34 | while t < P: 35 | tail = poly[t][0] 36 | if head == tail: 37 | poly[s] = poly[s] + poly[t][1:] 38 | del poly[t] 39 | poly = bubble_merge(poly) 40 | P = len(poly) 41 | t += 1 42 | s += 1 43 | return poly 44 | 45 | diag = torch.logical_not(perm[:,range(N),range(N)]) 46 | batch = [] 47 | for b in range(B): 48 | b_perm = perm[b] 49 | b_graph = graph[b] 50 | b_diag = diag[b] 51 | 52 | idx = torch.arange(N)[b_diag] 53 | 54 | if idx.shape[0] > 0: 55 | # If there are vertices in the batch 56 | 57 | b_perm = b_perm[idx,:] 58 | b_graph = b_graph[idx,:] 59 | b_perm = b_perm[:,idx] 60 | 61 | first = torch.arange(idx.shape[0]).unsqueeze(1) 62 | second = torch.argmax(b_perm, dim=1).unsqueeze(1).cpu() 63 | 64 | polygons_idx = torch.cat((first, second), dim=1).tolist() 65 | polygons_idx = bubble_merge(polygons_idx) 66 | 67 | batch_poly = [] 68 | for p_idx in polygons_idx: 69 | if out == 'torch': 70 | batch_poly.append(b_graph[p_idx,:]) 71 | elif out == 'numpy': 72 | batch_poly.append(b_graph[p_idx,:].numpy()) 73 | elif out == 'list': 74 | g = b_graph[p_idx,:] * 300 / 320 75 | g[:,0] = -g[:,0] 76 | g = torch.fliplr(g) 77 | batch_poly.append(g.tolist()) 78 | elif out == 'coco': 79 | g = b_graph[p_idx,:] * 300 / 320 80 | g = torch.fliplr(g) 81 | batch_poly.append(g.view(-1).tolist()) 82 | else: 83 | print("Indicate a valid output polygon format") 84 | exit() 85 | 86 | batch.append(batch_poly) 87 | 88 | else: 89 | # If the batch has no vertices 90 | batch.append([]) 91 | 92 | return batch 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | def graph_to_vertex_mask(points, image): 102 | B, _, H, W = image.shape 103 | 104 | mask = torch.zeros((B, H, W), dtype=torch.uint8) 105 | 106 | # Loop style 107 | # for batch in range(B): 108 | # mask[batch, points[batch, :, 0], points[batch, :, 1]] = 1 109 | 110 | # Vectorized Style 111 | batch_indices = np.arange(B)[:, None] 112 | mask[batch_indices, points[:, :, 0], points[:, :, 1]] = 1 113 | 114 | return mask 115 | 116 | 117 | 118 | 119 | 120 | def polygon_to_vertex_mask(polygons: list): 121 | B = len(polygons) 122 | mask = torch.zeros((B, 320, 320), dtype=torch.uint8) 123 | 124 | for batch in range(B): 125 | batch_polygons = [np.array(poly, dtype=int) for poly in polygons[batch]] 126 | for poly in batch_polygons: 127 | for point in poly: 128 | mask[batch, point[1], point[0]] = 1 129 | 130 | 131 | return mask 132 | 133 | 134 | 135 | 136 | 137 | def tensor_to_numpy(input: list): 138 | ''' convert a list of tensors to a list of numpy arrays ''' 139 | numpy = [] 140 | for batch in range(len(input)): 141 | batch_polygons = [tensor.cpu().numpy() for tensor in input[batch]] 142 | numpy.append(batch_polygons) 143 | return numpy 144 | 145 | 146 | 147 | 148 | 149 | def point_to_polygon(points: list): 150 | 151 | B = len(points) 152 | polygons = [] 153 | 154 | for batch in range(B): 155 | batch_polygons = [np.array(poly, dtype=int).reshape(-1, 2) for poly in points[batch]] 156 | polygons.append(batch_polygons) 157 | 158 | return polygons 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | def 
polygon_to_seg_mask(polygons, image_size, device='cuda'):
169 |     B = len(polygons)
170 |     mask = np.zeros((B, image_size, image_size), dtype=np.uint8)
171 |     for batch in range(B):
172 |         for polygon in polygons[batch]:
173 |         # for polygon in polygons[0]:
174 |             cv2.fillPoly(mask[batch], [polygon.detach().cpu().numpy()], 1)
175 |     return torch.tensor(mask, dtype=torch.float32, device=device, requires_grad=True)
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 | def sort_sync_nsm_points(nms_graph, vertices, gt_index):
191 |     '''
192 |     nms graph: (B, N, 2)
193 |     Vertices: numpy array of points (B, N, 2) --> N: number of unique points
194 |     GT_Index: a dictionary with points as its keys and their index as values (it's the same as vertices but with index values) (B, N)
195 |     '''
196 |     B, N, D = nms_graph.shape # 1, 256, 2
197 |     sorted_nsm_points = np.zeros((B, N, D), dtype=int)
198 |     nms_graph = nms_graph.detach().cpu().numpy()
199 |
200 |     for b in range(B):
201 |         sorted_nsm = np.zeros((N, D), dtype=int)
202 |
203 |         n = vertices[b].shape[0] # Number of Vertices (Points)
204 |         m = nms_graph[b].shape[0] # Number of Predicted Vertices (Points)
205 |         # print(vertices[b].shape, nms_graph[b].shape)
206 |         distances = np.linalg.norm(vertices[b][:, None] - nms_graph[b], axis=2) # Adds a new axis to vertices[b], changing its shape from (M, D) to (M, 1, D).
207 |         distances = distances.reshape(n*m, 1)
208 |         # print(distances.shape)
209 |         distances = np.hstack((distances, np.repeat(vertices[b], m, axis=0),np.tile(nms_graph[b], (n, 1))))
210 |
211 |
212 |         # Sort distances by the first column (distance)
213 |         sorted_distances = distances[np.argsort(distances[:,0])]
214 |
215 |         cndd_used = set()
216 |         gt_used = set()
217 |         cndd_mapped = {tuple(cndd):0 for cndd in nms_graph[b]}
218 |
219 |         for d_p in sorted_distances:
220 |             gt_p = tuple((d_p[1], d_p[2]))
221 |             cndd_p = tuple((d_p[3], d_p[4]))
222 |             if gt_p not in gt_used and cndd_p not in cndd_used:
223 |                 #print('we have a match ..', gt_p ,'->', cndd_p, ' with distance of ', d_p[0])
224 |                 # print(gt_p)
225 |                 sorted_nsm[gt_index[b][gt_p]]= list(cndd_p)
226 |                 gt_used.add(gt_p)
227 |                 cndd_used.add(cndd_p)
228 |                 cndd_mapped[cndd_p] = 1
229 |
230 |         restart_index = n
231 |         for k, v in cndd_mapped.items():
232 |             if v ==0:
233 |                 sorted_nsm[restart_index] = list(k)
234 |                 restart_index +=1
235 |         sorted_nsm_points[b] = sorted_nsm
236 |     return torch.from_numpy(sorted_nsm_points)
237 |
238 |
239 |
240 |
241 |
242 | def prepare_gt_vertices(vertices, device='cuda', MAX_POINTS=256):
243 |     B = len(vertices)
244 |     v_gt = torch.empty((B, MAX_POINTS, 2), dtype=torch.float64)
245 |     for b in range(B):
246 |         gt_size = vertices[b].shape[0]
247 |         extra = torch.full((MAX_POINTS - gt_size, 2), 0, dtype=torch.float64)
248 |         extra_gt = torch.cat((vertices[b], extra), dim=0).to(device)
249 |         v_gt[b] = extra_gt
250 |     return v_gt.to(device)
251 |
252 |
253 |
254 |
255 | def angle_between_points(A, B, C, batch=False):
256 |     d = 2 if batch else 1
257 |     AB = A - B
258 |     BC = C - B
259 |     epsilon = 3e-8
260 |
261 |     AB_mag = torch.norm(AB, dim=d) + epsilon
262 |     BC_mag = torch.norm(BC, dim=d) + epsilon
263 |
264 |     dot_product = torch.sum(AB * BC, dim=d)
265 |     cos_theta = dot_product / (AB_mag * BC_mag)
266 |
267 |     zero_mask = (AB_mag == 0) | (BC_mag == 0)
268 |     cos_theta[zero_mask] = 0
269 |     theta = torch.acos(torch.clamp(cos_theta, -1 + epsilon, 1 - epsilon))
270 |     theta[zero_mask] = 0
271 |     return theta * 180 / math.pi
272
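# Note on usage (see compute_l_angle_loss in utils/loss.py): multiplying the
# vertex list by a permutation matrix P gives each vertex's successor, so
# v1 = P @ v and v2 = P @ v1 are the next two vertices along each polygon;
# angle_between_points(v, v1, v2) then measures the angle at v1 between its
# two incident edges, letting predicted and ground-truth angles be compared.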
| 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | def soft_winding_number(pred_polys, lam=1000, img_size=320, device='cuda'): 285 | 286 | B = len(pred_polys) 287 | IMG_SIZE = img_size 288 | pred_mask = torch.zeros((B, IMG_SIZE, IMG_SIZE)).to(device) 289 | 290 | x = torch.arange(IMG_SIZE) 291 | y = torch.arange(IMG_SIZE) 292 | xx, yy = torch.meshgrid(x,y) 293 | 294 | pixel_coords = torch.stack([yy, xx], dim=-1).float() 295 | 296 | for b in range(B): 297 | vertices = torch.vstack(pred_polys[b]).float() 298 | #vertices = vertices.detach().cpu() 299 | #vertices.requires_grad=True 300 | #vertices = vertices.unfold(dimension = 0,size = 2, step = 1) 301 | #vertices_repeated = vertices.repeat_interleave(IMG_SIZE*IMG_SIZE, dim=0) 302 | 303 | pairs = vertices[:-1].unsqueeze(1).repeat(1, 2, 1) 304 | pairs[:, 1, :] = vertices[1:] 305 | 306 | pairs_repeated = pairs.repeat_interleave(IMG_SIZE*IMG_SIZE, dim=0) 307 | 308 | #pixel_coords_angle = pixel_coords.repeat(vertices.shape[0],1,1).view(vertices.shape[0] *IMG_SIZE*IMG_SIZE,2) 309 | #pixel_coords_det = pixel_coords.repeat(vertices.shape[0],1,1).view(vertices.shape[0] *IMG_SIZE*IMG_SIZE ,2,1) 310 | pixel_coords_angle = pixel_coords.repeat(pairs.shape[0],1,1).view(pairs.shape[0] *IMG_SIZE*IMG_SIZE ,1 , 2).to(device) 311 | 312 | concatenated = torch.cat([pairs_repeated, pixel_coords_angle], dim=1) 313 | 314 | #ones = torch.ones(IMG_SIZE*IMG_SIZE*vertices.shape[0], 3).reshape(IMG_SIZE*IMG_SIZE*vertices.shape[0],1, 3) 315 | 316 | ones = torch.ones(pairs.shape[0] *IMG_SIZE*IMG_SIZE, 3, 1).to(device) #.reshape(vertices.shape[0]-1 *IMG_SIZE*IMG_SIZE,3, 1) 317 | output = torch.cat((concatenated, ones), dim=2) 318 | 319 | det = torch.det(output) 320 | 321 | # compute angle 322 | angles = angle_between_points(pairs_repeated[:, 0], pixel_coords_angle.view(pairs.shape[0] *IMG_SIZE*IMG_SIZE, 2), pairs_repeated[:, 1], batch=False) 323 | 324 | #Compute the soft winding number using equation 13 325 | w = (lam * det) / (1 + torch.abs(det *lam)) 326 | w = w * angles 327 | 328 | w = w.view(pairs.shape[0], IMG_SIZE, IMG_SIZE) 329 | # Sum over all pairs of adjacent vertices to get the winding number 330 | w = w.sum(dim=0) 331 | 332 | pred_mask[b] = w 333 | 334 | return pred_mask 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | # def sinkhorn_knopp(cost_matrix, epsilon=0.05, iterations=100): 354 | # """ 355 | # Sinkhorn-Knopp algorithm to approximate optimal transport 356 | # """ 357 | # B, N, _ = cost_matrix.shape 358 | # log_mu = torch.zeros_like(cost_matrix) 359 | # log_nu = torch.zeros_like(cost_matrix) 360 | 361 | # u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) 362 | # for _ in range(iterations): 363 | # u = epsilon * (torch.log(torch.ones(B, N, 1, device=cost_matrix.device)) - 364 | # torch.logsumexp((-cost_matrix + v) / epsilon, dim=-1, keepdim=True)) + log_mu 365 | # v = epsilon * (torch.log(torch.ones(B, 1, N, device=cost_matrix.device)) - 366 | # torch.logsumexp((-cost_matrix + u) / epsilon, dim=-2, keepdim=True)) + log_nu 367 | 368 | # return torch.exp((-cost_matrix + u + v) / epsilon) 369 | 370 | # def scores_to_permutations(scores, temperature=0.1): 371 | # """ 372 | # Input a batched array of scores and returns the approximate 373 | # permutation matrices using Sinkhorn-Knopp algorithm. 374 | # Preserves gradients for backpropagation. 
375 | # """ 376 | # B, N, _ = scores.shape 377 | 378 | # # Normalize scores to be non-negative 379 | # scores_normalized = scores - scores.min(dim=-1, keepdim=True)[0] 380 | 381 | # # Use Sinkhorn-Knopp to approximate permutation matrices 382 | # perm_soft = sinkhorn_knopp(-scores_normalized) 383 | 384 | # # Use a differentiable approximation of argmax 385 | # perm_hard = F.gumbel_softmax(torch.log(perm_soft + 1e-8), tau=temperature, hard=True, dim=-1) 386 | 387 | # return perm_hard --------------------------------------------------------------------------------