├── LICENSE ├── README.md ├── img ├── BE_module.png ├── Boundary_Enhancement_Semantic_Segmentation.pdf ├── HED_unit.png ├── Tail_part.png └── readme.md ├── nets ├── __init__.py ├── _torch_losses.py ├── assembly_block.py ├── callbacks.py ├── datagen.py ├── infer.py ├── losses.py ├── model_io.py ├── optimizers.py ├── torch_callbacks.py ├── train.py ├── transform.py └── zoo │ ├── __init__.py │ ├── brrnet.py │ ├── brrnet_BE.py │ ├── denet.py │ ├── enru.py │ ├── enru_BE.py │ ├── hrnet.yml │ ├── hrnet_config.py │ ├── hrnetv2.py │ ├── hrnetv2_BE.py │ ├── resunet.py │ ├── resunet_BE.py │ ├── ternaus.py │ ├── ternaus_BE.py │ ├── unet.py │ ├── unet_BE.py │ ├── uspp.py │ └── uspp_BE.py ├── notebooks ├── __init__.py ├── data_prep.ipynb └── get_mask_eval.ipynb ├── src ├── __init__.py ├── inference.py └── train.py ├── utils ├── __init__.py ├── config.py ├── core.py ├── data.py ├── io.py └── log.py └── yml ├── infer.yml └── train.yml /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Boundary Enhancement Semantic Segmentation for Building Extraction from Remote Sensed Image
2 | ## Introduction
3 | This repository contains implementations of binary semantic segmentation models, with a focus on building extraction from satellite images. [Link](https://ieeexplore.ieee.org/document/9527893) [pdf](./img/Boundary_Enhancement_Semantic_Segmentation.pdf)
4 | The boundary enhancement methods (BE module) are also contained in ```/nets/zoo/```.
5 | 
6 | ![HED_unit](./img/HED_unit.png)
7 | ![BE_module](./img/BE_module.png)
8 | ![Tail_part](./img/Tail_part.png)
9 | 
10 | ## Requirements
11 | 
12 | ```
13 | Python >= 3.7.0
14 | 
15 | PyTorch >= 1.9.0
16 | 
17 | skimage >= 0.18.2
18 | 
19 | CUDA >= 10.1
20 | ```
21 | 
22 | ## Data prep
23 | 
24 | ### Urban3D Dataset example
25 | 
26 | - The experiments were conducted on images cropped to 512 x 512, split into 2,912 training and 672 test tiles.
27 | - The original dataset can be downloaded from [Urban3D](https://github.com/topcoderinc/Urban3d).
28 | 
29 | - The data should be arranged as follows:
30 | 
31 | ```
32 | |-- Test
33 | |   `-- Urban3D_Test
34 | |       |-- RGB
35 | |       `-- masks
36 | `-- Train
37 |     `-- Urban3D_Train
38 |         |-- RGB
39 |         `-- masks
40 | ```
41 | 
42 | - Open ```/notebooks/data_prep.ipynb``` and build the dataframes for the train and test sets.
43 | ```Urban3D_Train_df.csv``` and ```Urban3D_Test_df.csv``` will be created in ```/csv/```.
44 | 
45 | ## Train
46 | 
47 | - Check and set the hyperparameters in ```/yml/train.yml```.
48 | - Choose a model, referring to ```/nets/zoo/__init__.py```.
49 | - Choose the area of interest. ```6``` is the default for the Urban3D dataset.
50 | - Set ```num_stage``` to match the number of stages in the backbone architecture.
51 | - Set the training hyperparameters: epochs, optimizer, learning rate, and loss functions.
52 | - To train a *Boundary Enhancement* model, set ```boundary``` to ```True```.
53 | - Run ```/src/train.py```.
54 | - The ```result``` and ```/result/models_weight``` directories will be created automatically.
55 | - Model weights are saved in ```/result/models_weight/{DATASET_NAME}_{MODEL_NAME}_{TRAINING_ID}```. ```TRAINING_ID``` is the UNIX timestamp at which training started.
56 | 
57 | ## Inference
58 | 
59 | - Check and set the parameters in ```/yml/infer.yml```.
60 | 
61 | - ```model_name``` and ```aoi``` must match those used in ```train.yml```.
62 | 
63 | - To run inference with a *Boundary Enhancement* model, set ```boundary``` to ```True```.
64 | 
65 | - Set ```training_date``` to the ```TRAINING_ID``` from training.
66 | 
67 | 
68 | 
69 | - Run ```/src/inference.py```.
70 | 
71 | - Inferred images are saved in ```/result/infer/```.
72 | 
73 | ## Evaluation
74 | 
75 | - Open ```/notebooks/get_mask_eval.ipynb```.
76 | - Check ```aois``` and ```training_date```; ```training_date``` is the ```TRAINING_ID``` from the training step.
77 | - Running all cells creates mask images from the inferred images.
78 | - The evaluation compares the predicted masks against the ground truth; results are saved in ```/result/eval_result/```.
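
## Usage sketches

The configure-and-train flow can also be driven programmatically. Below is a minimal sketch that edits ```/yml/train.yml``` before launching ```/src/train.py```; it assumes PyYAML is installed, and the key names shown are illustrative, so check them against your own ```train.yml```.

```python
# Hypothetical convenience script; key names below are illustrative.
import subprocess
import yaml  # assumes PyYAML is installed

with open('yml/train.yml') as f:
    cfg = yaml.safe_load(f)

cfg['model_name'] = 'unet_BE'  # a BE-module variant from /nets/zoo/
cfg['get_aoi'] = '6'           # Urban3D default area of interest
cfg['boundary'] = True         # enable the Boundary Enhancement path
cfg['training']['epochs'] = 100
cfg['training']['lr'] = 1e-3

with open('yml/train.yml', 'w') as f:
    yaml.safe_dump(cfg, f)

subprocess.run(['python', 'src/train.py'], check=True)
```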
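The boundary labels that supervise the BE module are derived from the building masks with scikit-image inside the data generator (see ```nets/datagen.py``` and ```nets/assembly_block.py```). The step in isolation looks like this:

```python
# Derive a boundary target from a binary building mask, as the data
# generator does internally with skimage.segmentation.find_boundaries.
import numpy as np
import skimage.segmentation

mask = np.zeros((512, 512), dtype=np.uint8)
mask[100:200, 150:300] = 1  # toy building footprint

boundary = skimage.segmentation.find_boundaries(
    mask, mode='inner', background=0).astype(np.float32)
# `boundary` is 1 on the inner edge pixels of each footprint and 0
# elsewhere; it serves as the target for the HED-style boundary head.
```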
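For reference, here is a minimal sketch of a pixel-wise IoU/F1 comparison between a ground-truth and a predicted mask; the notebook's exact metrics may differ:

```python
# Hypothetical helper: standard pixel-wise IoU and F1 for binary masks.
import numpy as np

def iou_f1(gt, pred, eps=1e-7):
    gt = gt.astype(bool)
    pred = pred.astype(bool)
    tp = np.logical_and(gt, pred).sum()   # true positives
    fp = np.logical_and(~gt, pred).sum()  # false positives
    fn = np.logical_and(gt, ~pred).sum()  # false negatives
    iou = tp / (tp + fp + fn + eps)
    f1 = 2 * tp / (2 * tp + fp + fn + eps)
    return iou, f1
```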
79 | 
80 | 
81 | 
82 | ## Implemented models and datasets
83 | 
84 | ### Models
85 | 
86 | - U-Net
87 | - ResUNet++
88 | - TernausNet
89 | - BRR-Net
90 | - USPP
91 | - DE-Net
92 | 
93 | ### Datasets
94 | 
95 | - DeepGlobe Dataset (Vegas, Paris, Shanghai, Khartoum)
96 | - Urban3D Dataset
97 | - WHU Dataset (aerial and satellite)
98 | - Massachusetts Dataset
99 | 
100 | ## File tree
101 | 
102 | ```
103 | |-- data
104 | |   |-- Test
105 | |   `-- Train
106 | |-- nets
107 | |   |-- __init__.py
108 | |   |-- _torch_losses.py
109 | |   |-- assembly_block.py
110 | |   |-- callbacks.py
111 | |   |-- datagen.py
112 | |   |-- infer.py
113 | |   |-- losses.py
114 | |   |-- model_io.py
115 | |   |-- optimizers.py
116 | |   |-- torch_callbacks.py
117 | |   |-- train.py
118 | |   |-- transform.py
119 | |   |-- weights
120 | |   `-- zoo
121 | |-- notebooks
122 | |   |-- __init__.py
123 | |   |-- data_prep.ipynb
124 | |   `-- get_mask_eval.ipynb
125 | |-- result
126 | |   |-- infer
127 | |   |-- infer_masks
128 | |   `-- models_weight
129 | |-- src
130 | |   |-- __init__.py
131 | |   |-- inference.py
132 | |   `-- train.py
133 | |-- utils
134 | |   |-- __init__.py
135 | |   |-- config.py
136 | |   |-- core.py
137 | |   |-- data.py
138 | |   |-- io.py
139 | |   `-- log.py
140 | `-- yml
141 |     |-- infer.yml
142 |     `-- train.yml
143 | ```
144 | 
145 | ## Contribution
146 | 
147 | This code is a modified and simplified version of [Solaris](https://github.com/CosmiQ/solaris), adapted for my own research.
148 | 
--------------------------------------------------------------------------------
/img/BE_module.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HoinJung/BEmodule-Satellite-Building-Segmentation/d4ed2c0f95bbee435abc97389df1357393b7e570/img/BE_module.png
--------------------------------------------------------------------------------
/img/Boundary_Enhancement_Semantic_Segmentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HoinJung/BEmodule-Satellite-Building-Segmentation/d4ed2c0f95bbee435abc97389df1357393b7e570/img/Boundary_Enhancement_Semantic_Segmentation.pdf
--------------------------------------------------------------------------------
/img/HED_unit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HoinJung/BEmodule-Satellite-Building-Segmentation/d4ed2c0f95bbee435abc97389df1357393b7e570/img/HED_unit.png
--------------------------------------------------------------------------------
/img/Tail_part.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HoinJung/BEmodule-Satellite-Building-Segmentation/d4ed2c0f95bbee435abc97389df1357393b7e570/img/Tail_part.png
--------------------------------------------------------------------------------
/img/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | weights_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)),
4 |                            'weights')
5 | 
6 | from . import callbacks, datagen, infer, losses, model_io
7 | from . 
import optimizers , losses, model_io, train 8 | 9 | if not os.path.isdir(weights_dir): 10 | os.mkdir(weights_dir) 11 | -------------------------------------------------------------------------------- /nets/_torch_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch import nn 6 | from math import exp 7 | try: 8 | from itertools import ifilterfalse 9 | except ImportError: # py3k 10 | from itertools import filterfalse as ifilterfalse 11 | 12 | epsilon_ = 1e-15 13 | 14 | class TorchDiceLoss(nn.Module): 15 | def __init__(self, weight=None, size_average=True, 16 | per_image=False, logits=False): 17 | super().__init__() 18 | self.size_average = size_average 19 | self.register_buffer('weight', weight) 20 | self.per_image = per_image 21 | self.logits = logits 22 | 23 | def forward(self, input, target): 24 | if self.logits: 25 | input = torch.sigmoid(input) 26 | return soft_dice_loss(input, target, per_image=self.per_image) 27 | 28 | 29 | class TorchFocalLoss(nn.Module): 30 | """Implementation of Focal Loss[1]_ modified from Catalyst [2]_ . 31 | 32 | Arguments 33 | --------- 34 | gamma : :class:`int` or :class:`float` 35 | Focusing parameter. See [1]_ . 36 | alpha : :class:`int` or :class:`float` 37 | Normalization factor. See [1]_ . 38 | 39 | References 40 | ---------- 41 | .. [1] https://arxiv.org/pdf/1708.02002.pdf 42 | .. [2] https://catalyst-team.github.io/catalyst/ 43 | """ 44 | 45 | def __init__(self, gamma=2, reduce=True, logits=False): 46 | super().__init__() 47 | self.gamma = gamma 48 | self.reduce = reduce 49 | self.logits = logits 50 | 51 | # TODO refactor 52 | def forward(self, outputs, targets): 53 | """Calculate the loss function between `outputs` and `targets`. 54 | 55 | Arguments 56 | --------- 57 | outputs : :class:`torch.Tensor` 58 | The output tensor from a model. 59 | targets : :class:`torch.Tensor` 60 | The training target. 61 | 62 | Returns 63 | ------- 64 | loss : :class:`torch.Variable` 65 | The loss value. 66 | """ 67 | 68 | if self.logits: 69 | BCE_loss = F.binary_cross_entropy_with_logits(outputs, targets, 70 | reduction='none') 71 | else: 72 | BCE_loss = F.binary_cross_entropy(outputs, targets, 73 | reduction='none') 74 | pt = torch.exp(-BCE_loss) 75 | F_loss = (1-pt)**self.gamma * BCE_loss 76 | if self.reduce: 77 | return torch.mean(F_loss) 78 | else: 79 | return F_loss 80 | 81 | 82 | def torch_lovasz_hinge(logits, labels, per_image=False, ignore=None): 83 | """Lovasz Hinge Loss. Implementation edited from Maxim Berman's GitHub. 84 | 85 | References 86 | ---------- 87 | https://github.com/bermanmaxim/LovaszSoftmax/ 88 | https://arxiv.org/abs/1705.08790 89 | 90 | Arguments 91 | --------- 92 | logits: :class:`torch.Variable` 93 | logits at each pixel (between -inf and +inf) 94 | labels: :class:`torch.Tensor` 95 | binary ground truth masks (0 or 1) 96 | per_image: bool, optional 97 | compute the loss per image instead of per batch. Defaults to ``False``. 98 | ignore: optional void class id. 99 | 100 | Returns 101 | ------- 102 | loss : :class:`torch.Variable` 103 | Lovasz loss value for the input logits and labels. Compatible with 104 | ``loss.backward()`` as its a :class:`torch.Variable` . 
105 | """ 106 | # TODO: Restructure into a class like TorchFocalLoss for compatibility 107 | if per_image: 108 | loss = mean( 109 | lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), 110 | lab.unsqueeze(0), 111 | ignore)) 112 | for log, lab in zip(logits, labels)) 113 | else: 114 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, 115 | labels, 116 | ignore)) 117 | return loss 118 | 119 | 120 | def lovasz_hinge_flat(logits, labels): 121 | """Binary Lovasz hinge loss. 122 | 123 | Arguments 124 | --------- 125 | logits: :class:`torch.Variable` 126 | Logits at each prediction (between -inf and +inf) 127 | labels: :class:`torch.Tensor` 128 | binary ground truth labels (0 or 1) 129 | 130 | Returns 131 | ------- 132 | loss : :class:`torch.Variable` 133 | Lovasz loss value for the input logits and labels. 134 | """ 135 | if len(labels) == 0: 136 | # only void pixels, the gradients should be 0 137 | return logits.sum() * 0. 138 | signs = 2. * labels.float() - 1. 139 | errors = (1. - logits * Variable(signs)) 140 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 141 | perm = perm.data 142 | gt_sorted = labels[perm] 143 | grad = lovasz_grad(gt_sorted) 144 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 145 | return loss 146 | 147 | 148 | def flatten_binary_scores(scores, labels, ignore=None): 149 | """ 150 | Flattens predictions in the batch (binary case) 151 | Remove labels equal to 'ignore' 152 | """ 153 | scores = scores.view(-1) 154 | labels = labels.view(-1) 155 | if ignore is None: 156 | return scores, labels 157 | valid = (labels != ignore) 158 | vscores = scores[valid] 159 | vlabels = labels[valid] 160 | return vscores, vlabels 161 | 162 | 163 | class TorchJaccardLoss(torch.nn.modules.Module): 164 | # modified from XD_XD's implementation 165 | def __init__(self): 166 | super(TorchJaccardLoss, self).__init__() 167 | 168 | def forward(self, outputs, targets): 169 | eps = 1e-15 170 | 171 | jaccard_target = (targets == 1).float() 172 | jaccard_output = torch.sigmoid(outputs) 173 | #jaccard_output = outputs # bear's modif part 174 | intersection = (jaccard_output * jaccard_target).sum() 175 | union = jaccard_output.sum() + jaccard_target.sum() 176 | jaccard_score = ((intersection + eps) / (union - intersection + eps)) 177 | self._stash_jaccard = jaccard_score 178 | loss = 1. - jaccard_score 179 | 180 | return loss 181 | 182 | 183 | class TorchStableBCELoss(torch.nn.modules.Module): 184 | def __init__(self): 185 | super(TorchStableBCELoss, self).__init__() 186 | 187 | def forward(self, input, target): 188 | neg_abs = - input.abs() 189 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 190 | return loss.mean() 191 | 192 | 193 | def binary_xloss(logits, labels, ignore=None): 194 | """ 195 | Binary Cross entropy loss 196 | logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf) 197 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 198 | ignore: void class id 199 | """ 200 | logits, labels = flatten_binary_scores(logits, labels, ignore) 201 | loss = TorchStableBCELoss()(logits, Variable(labels.float())) 202 | return loss 203 | 204 | 205 | def lovasz_grad(gt_sorted): 206 | """ 207 | Computes gradient of the Lovasz extension w.r.t sorted errors 208 | See Alg. 1 in paper 209 | """ 210 | p = len(gt_sorted) 211 | gts = gt_sorted.sum() 212 | intersection = gts - gt_sorted.float().cumsum(0) 213 | union = gts + (1 - gt_sorted).float().cumsum(0) 214 | jaccard = 1. 
- intersection / union 215 | if p > 1: # cover 1 - pixel case 216 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 217 | return jaccard 218 | 219 | 220 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 221 | """ 222 | IoU for foreground class 223 | binary: 1 foreground, 0 background 224 | """ 225 | if not per_image: 226 | preds, labels = (preds,), (labels,) 227 | ious = [] 228 | for pred, label in zip(preds, labels): 229 | intersection = ((label == 1) & (pred == 1)).sum() 230 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 231 | if not union: 232 | iou = EMPTY 233 | else: 234 | iou = float(intersection) / float(union) 235 | ious.append(iou) 236 | iou = mean(ious) # mean accross images if per_image 237 | return 100 * iou 238 | 239 | 240 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 241 | """ 242 | Array of IoU for each (non ignored) class 243 | """ 244 | if not per_image: 245 | preds, labels = (preds,), (labels,) 246 | ious = [] 247 | for pred, label in zip(preds, labels): 248 | iou = [] 249 | for i in range(C): 250 | if i != ignore: 251 | intersection = ((label == i) & (pred == i)).sum() 252 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 253 | if not union: 254 | iou.append(EMPTY) 255 | else: 256 | iou.append(float(intersection) / float(union)) 257 | ious.append(iou) 258 | ious = [mean(iou) for iou in zip(*ious)] # mean across images if per_image 259 | return 100 * np.array(ious) 260 | 261 | 262 | # helper functions 263 | def isnan(x): 264 | return x != x 265 | 266 | 267 | def mean(l, ignore_nan=False, empty=0): 268 | """ 269 | nanmean compatible with generators. 270 | """ 271 | l = iter(l) 272 | if ignore_nan: 273 | l = ifilterfalse(isnan, l) 274 | try: 275 | n = 1 276 | acc = next(l) 277 | except StopIteration: 278 | if empty == 'raise': 279 | raise ValueError('Empty mean') 280 | return empty 281 | for n, v in enumerate(l, 2): 282 | acc += v 283 | if n == 1: 284 | return acc 285 | return acc / n 286 | 287 | 288 | def dice_round(preds, trues): 289 | preds = preds.float() 290 | return soft_dice_loss(preds, trues) 291 | 292 | 293 | def soft_dice_loss(outputs, targets, per_image=False): 294 | batch_size = outputs.size()[0] 295 | eps = 1e-5 296 | if not per_image: 297 | batch_size = 1 298 | dice_target = targets.contiguous().view(batch_size, -1).float() 299 | dice_output = outputs.contiguous().view(batch_size, -1) 300 | intersection = torch.sum(dice_output * dice_target, dim=1) 301 | union = torch.sum(dice_output, dim=1) + torch.sum(dice_target, dim=1) + eps 302 | loss = (1 - (2 * intersection + eps) / union).mean() 303 | 304 | return loss 305 | class MSSSIM(torch.nn.Module): 306 | def __init__(self, window_size=11, size_average=True, channel=1): 307 | super(MSSSIM, self).__init__() 308 | self.window_size = window_size 309 | self.size_average = size_average 310 | self.channel = channel 311 | 312 | def forward(self, img1, img2): 313 | # TODO: store window between calls if possible 314 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average) 315 | class SSIM(torch.nn.Module): 316 | def __init__(self, window_size=11, size_average=True, val_range=None): 317 | super(SSIM, self).__init__() 318 | self.window_size = window_size 319 | self.size_average = size_average 320 | self.val_range = val_range 321 | 322 | # Assume 1 channel for SSIM 323 | self.channel = 1 324 | self.window = create_window(window_size) 325 | 326 | def forward(self, img1, img2): 327 | (_, channel, _, _) = img1.size() 
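        # Reuse the cached Gaussian window if the input's channel count and
        # dtype still match; otherwise rebuild it on img1's device (below).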
328 | 329 | if channel == self.channel and self.window.dtype == img1.dtype: 330 | window = self.window 331 | else: 332 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 333 | self.window = window 334 | self.channel = channel 335 | 336 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 337 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=True): 338 | device = img1.device 339 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device) 340 | 341 | 342 | 343 | levels = weights.size()[0] 344 | 345 | ssims = [] 346 | mcs = [] 347 | 348 | for _ in range(levels): 349 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range) 350 | 351 | # Relu normalize (not compliant with original definition) 352 | if normalize == "relu": 353 | ssims.append(torch.relu(sim)) 354 | mcs.append(torch.relu(cs)) 355 | else: 356 | ssims.append(sim) 357 | mcs.append(cs) 358 | 359 | img1 = F.avg_pool2d(img1, (2, 2)) 360 | img2 = F.avg_pool2d(img2, (2, 2)) 361 | 362 | ssims = torch.stack(ssims) 363 | mcs = torch.stack(mcs) 364 | 365 | # Simple normalize (not compliant with original definition) 366 | # TODO: remove support for normalize == True (kept for backward support) 367 | if normalize == "simple" or normalize == True: 368 | ssims = (ssims + 1) / 2 369 | mcs = (mcs + 1) / 2 370 | 371 | pow1 = mcs ** weights 372 | pow2 = ssims ** weights 373 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/ 374 | output = torch.prod(pow1[:-1] * pow2[-1]) 375 | return output 376 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 377 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
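    # L is the dynamic range of the input; it sets the SSIM stability
    # constants C1 = (0.01*L)**2 and C2 = (0.03*L)**2 used further below.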
378 | if val_range is None: 379 | if torch.max(img1) > 128: 380 | max_val = 255 381 | else: 382 | max_val = 1 383 | 384 | if torch.min(img1) < -0.5: 385 | min_val = -1 386 | else: 387 | min_val = 0 388 | L = max_val - min_val 389 | else: 390 | L = val_range 391 | 392 | padd = 0 393 | (_, channel, height, width) = img1.size() 394 | if window is None: 395 | real_size = min(window_size, height, width) 396 | window = create_window(real_size, channel=channel).to(img1.device) 397 | 398 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 399 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 400 | 401 | mu1_sq = mu1.pow(2) 402 | mu2_sq = mu2.pow(2) 403 | mu1_mu2 = mu1 * mu2 404 | 405 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 406 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 407 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 408 | 409 | C1 = (0.01 * L) ** 2 410 | C2 = (0.03 * L) ** 2 411 | 412 | v1 = 2.0 * sigma12 + C2 413 | v2 = sigma1_sq + sigma2_sq + C2 414 | cs = torch.mean(v1 / v2+epsilon_) # contrast sensitivity 415 | 416 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2+epsilon_) 417 | 418 | if size_average: 419 | ret = ssim_map.mean() 420 | else: 421 | ret = ssim_map.mean(1).mean(1).mean(1) 422 | 423 | if full: 424 | return ret, cs 425 | return ret 426 | def gaussian(window_size, sigma): 427 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 428 | return gauss/(gauss.sum()+epsilon_) 429 | 430 | 431 | def create_window(window_size, channel=1): 432 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 433 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 434 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 435 | return window 436 | 437 | class BCEDiceLoss(nn.Module): 438 | def __init__(self, weight=None, size_average=True): 439 | super().__init__() 440 | 441 | def forward(self, input, target): 442 | 443 | truth = target.view(-1) 444 | pred = input.view(-1) 445 | # pred = input 446 | # BCE loss 447 | bce_loss = nn.BCELoss()(pred, truth).double() 448 | 449 | # Dice Loss 450 | dice_coef = (2.0 * (pred * truth).double().sum() + 1) / ( 451 | pred.double().sum() + truth.double().sum() + 1 452 | ) 453 | 454 | return bce_loss + (1 - dice_coef) 455 | 456 | class BCEDiceLoss2(nn.Module): 457 | def __init__(self, weight=None, size_average=True): 458 | super().__init__() 459 | 460 | def forward(self, input, target): 461 | 462 | truth = target.view(-1) 463 | pred = input.view(-1) 464 | # pred = input 465 | # BCE loss 466 | bce_loss = nn.BCELoss()(pred, truth).double() 467 | eps = 1e-5 468 | # Dice Loss 469 | dice_coef = (2.0 * (pred * truth).double().sum() + eps) / ( 470 | pred.double().sum() + truth.double().sum() + eps 471 | ) 472 | 473 | return bce_loss + (1 - dice_coef) 474 | torch_losses = { 475 | 'l1loss': nn.L1Loss, 476 | 'l1': nn.L1Loss, 477 | 'mae': nn.L1Loss, 478 | 'mean_absolute_error': nn.L1Loss, 479 | 'smoothl1loss': nn.SmoothL1Loss, 480 | 'smoothl1': nn.SmoothL1Loss, 481 | 'mean_squared_error': nn.MSELoss, 482 | 'mse': nn.MSELoss, 483 | 'mseloss': nn.MSELoss, 484 | 'categorical_crossentropy': nn.CrossEntropyLoss, 485 | 'cce': nn.CrossEntropyLoss, 486 | 'crossentropyloss': nn.CrossEntropyLoss, 487 | 'negative_log_likelihood': nn.NLLLoss, 488 | 'nll': nn.NLLLoss, 489 | 'nllloss': nn.NLLLoss, 490 | 'poisson_negative_log_likelihood': 
nn.PoissonNLLLoss,
491 |     'poisson_nll': nn.PoissonNLLLoss,
492 |     'poissonnll': nn.PoissonNLLLoss,
493 |     'kullback_leibler_divergence': nn.KLDivLoss,
494 |     'kld': nn.KLDivLoss,
495 |     'kldivloss': nn.KLDivLoss,
496 |     'binary_crossentropy': nn.BCELoss,
497 |     'bce': nn.BCELoss,
498 |     'bceloss': nn.BCELoss,
499 |     'bcewithlogits': nn.BCEWithLogitsLoss,
500 |     'bcewithlogitsloss': nn.BCEWithLogitsLoss,
501 |     'hinge': nn.HingeEmbeddingLoss,
502 |     'hingeembeddingloss': nn.HingeEmbeddingLoss,
503 |     'multiclass_hinge': nn.MultiMarginLoss,
504 |     'multimarginloss': nn.MultiMarginLoss,
505 |     'softmarginloss': nn.SoftMarginLoss,
506 |     'softmargin': nn.SoftMarginLoss,
507 |     'multiclass_softmargin': nn.MultiLabelSoftMarginLoss,
508 |     'multilabelsoftmarginloss': nn.MultiLabelSoftMarginLoss,
509 |     'cosine': nn.CosineEmbeddingLoss,
510 |     'cosineloss': nn.CosineEmbeddingLoss,
511 |     'cosineembeddingloss': nn.CosineEmbeddingLoss,
512 |     'lovaszhinge': torch_lovasz_hinge,
513 |     'focalloss': TorchFocalLoss,
514 |     'focal': TorchFocalLoss,
515 |     'jaccard': TorchJaccardLoss,
516 |     'jaccardloss': TorchJaccardLoss,
517 |     'dice': TorchDiceLoss,
518 |     'diceloss': TorchDiceLoss,
519 |     'msssim': MSSSIM, 'bcedice': BCEDiceLoss, 'bcedice2': BCEDiceLoss2
520 | }
521 | 
--------------------------------------------------------------------------------
/nets/assembly_block.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import skimage.segmentation
4 | import torch
5 | from torch import nn
6 | 
7 | 
8 | def assembly_block(mask_64, mask):
9 | 
10 |     # obtain the boundary from the 1-channel mask prediction
11 |     arr_mask = mask.cpu().detach().numpy()
12 |     mask_boundary_arr = skimage.segmentation.find_boundaries(arr_mask, mode='inner', background=0).astype(np.float32)
13 |     mask_boundary = torch.from_numpy(mask_boundary_arr).cuda().float()
14 | 
15 |     # fuse the boundary with the 64-channel feature map produced just before
16 |     # the final convolution. NOTE: these convolutions are instantiated on
17 |     # every call, so their weights are random unless trained elsewhere.
18 |     conv1 = nn.Conv2d(1, 64, 3, padding=1).cuda()
19 |     conv2 = nn.Conv2d(128, 64, 3, padding=1).cuda()
20 |     conv3 = nn.Conv2d(64, 1, 3, padding=1).cuda()
21 | 
22 |     x = conv1(mask_boundary)
23 |     x = torch.cat([mask_64, x], dim=1)
24 |     x = conv2(x)
25 |     x = conv3(x)
26 | 
27 |     return mask_boundary, x
--------------------------------------------------------------------------------
/nets/callbacks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .torch_callbacks import torch_callback_dict
3 | import torch
4 | 
5 | 
6 | def get_callbacks(framework, config):
7 |     callbacks = []
8 |     if framework == 'torch':
9 |         for callback, params in config['training']['callbacks'].items():
10 |             if callback == 'lr_schedule':
11 |                 callbacks.append(get_lr_schedule(framework, config))
12 |             else:
13 |                 callbacks.append(torch_callback_dict[callback](**params))
14 | 
15 |     return callbacks
16 | 
17 | 
18 | def get_lr_schedule(framework, config):
19 |     # NOTE: initial_lr, update_frequency, factor and schedule_dict are read
20 |     # for completeness, but only schedule_type is used below.
21 |     schedule_type = config['training'][
22 |         'callbacks']['lr_schedule']['schedule_type']
23 |     initial_lr = config['training']['lr']
24 |     update_frequency = config['training']['callbacks']['lr_schedule'].get(
25 |         'update_frequency', 1)
26 |     factor = config['training']['callbacks']['lr_schedule'].get(
27 |         'factor', 0)
28 |     schedule_dict = 
config['training']['callbacks']['lr_schedule'].get(
29 |         'schedule_dict', None)
30 |     if framework == 'torch':
31 |         # just get the class itself to use; don't instantiate until the
32 |         # optimizer has been created.
33 |         if schedule_type == 'linear':
34 |             lr_scheduler = torch.optim.lr_scheduler.StepLR
35 |         elif schedule_type == 'exponential':
36 |             lr_scheduler = torch.optim.lr_scheduler.ExponentialLR
37 |         elif schedule_type == 'arbitrary':
38 |             lr_scheduler = torch.optim.lr_scheduler.MultiStepLR
39 |         elif schedule_type == 'cycle':
40 |             lr_scheduler = torch.optim.lr_scheduler.OneCycleLR
41 | 
42 |     return lr_scheduler
43 | 
--------------------------------------------------------------------------------
/nets/datagen.py:
--------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | import rasterio
4 | from torch.utils.data import Dataset, DataLoader
5 | from .transform import _check_augs, process_aug_dict
6 | from utils.core import _check_df_load
7 | from utils.io import imread, _check_channel_order
8 | import skimage
9 | import skimage.segmentation
10 | def make_data_generator(framework, config, df, stage='train'):
11 | 
12 |     if framework.lower() not in ['pytorch', 'torch']:
13 |         raise ValueError('{} is not an accepted value for `framework`'.format(
14 |             framework))
15 | 
16 |     # make sure the df is loaded
17 |     df = _check_df_load(df)
18 | 
19 |     if stage == 'train':
20 |         augs = config['training_augmentation']
21 |         shuffle = config['training_augmentation']['shuffle']
22 |     elif stage == 'validate':
23 |         augs = config['validation_augmentation']
24 |         shuffle = False
25 |     try:
26 |         num_classes = config['data_specs']['num_classes']
27 |     except KeyError:
28 |         num_classes = 1
29 | 
30 | 
31 |     if framework in ['torch', 'pytorch']:
32 |         dataset = TorchDataset(
33 |             df,
34 |             augs=augs,
35 |             batch_size=config['batch_size'],
36 |             label_type=config['data_specs']['label_type'],
37 |             is_categorical=config['data_specs']['is_categorical'],
38 |             num_classes=num_classes,
39 |             dtype=config['data_specs']['dtype'])
40 |         # set up workers for DataLoader for pytorch
41 |         data_workers = config['data_specs'].get('data_workers')
42 |         if data_workers == 1 or data_workers is None:
43 |             data_workers = 0  # for DataLoader to run in main process
44 |         data_gen = DataLoader(
45 |             dataset,
46 |             batch_size=config['batch_size'],
47 |             shuffle=shuffle,  # honors the stage: False for validation
48 |             num_workers=data_workers,
49 |             drop_last=True)
50 | 
51 |         return data_gen
52 | 
53 | 
54 | 
55 | class TorchDataset(Dataset):
56 | 
57 |     def __init__(self, df, augs, batch_size, label_type='mask',
58 |                  is_categorical=False, num_classes=1, dtype=None):
59 | 
60 |         super().__init__()
61 | 
62 |         self.df = df
63 |         self.batch_size = batch_size
64 |         self.n_batches = int(np.floor(len(self.df)/self.batch_size))
65 |         self.aug = _check_augs(augs)
66 |         self.is_categorical = is_categorical
67 |         self.num_classes = num_classes
68 | 
69 |         if dtype is None:
70 |             self.dtype = np.float32  # default
71 |         # if it's a string, get the appropriate object
72 |         elif isinstance(dtype, str):
73 |             try:
74 |                 self.dtype = 
getattr(np, dtype) 75 | except AttributeError: 76 | raise ValueError( 77 | 'The data type {} is not supported'.format(dtype)) 78 | # lastly, check if it's already defined in the right format for use 79 | elif issubclass(dtype, np.number) or isinstance(dtype, np.dtype): 80 | self.dtype = dtype 81 | 82 | def __len__(self): 83 | return len(self.df) 84 | 85 | def __getitem__(self, idx): 86 | """Get one image, mask pair""" 87 | # Generate indexes of the batch 88 | image = imread(self.df['image'].iloc[idx]) 89 | mask = imread(self.df['label'].iloc[idx]) 90 | boundary = mask 91 | 92 | if not self.is_categorical: 93 | mask[mask != 0] = 1 94 | if len(mask.shape) == 2: 95 | mask = mask[:, :, np.newaxis] 96 | if len(image.shape) == 2: 97 | image = image[:, :, np.newaxis] 98 | 99 | if len(boundary.shape) == 2: 100 | boundary = boundary[:, :, np.newaxis] 101 | 102 | sample = {'image': image, 'mask': mask, 'boundary' : boundary} 103 | 104 | if self.aug: 105 | sample = self.aug(**sample) 106 | 107 | 108 | 109 | sample['image'] = _check_channel_order(sample['image'], 110 | 'torch').astype(self.dtype) 111 | sample['mask'] = _check_channel_order(sample['mask'], 112 | 'torch').astype(np.float32) 113 | 114 | sample['boundary'] = _check_channel_order(skimage.segmentation.find_boundaries(sample['mask'], mode='inner', background=0), 115 | 'torch').astype(np.float32) 116 | 117 | return sample 118 | 119 | 120 | class InferenceTiler(object): 121 | 122 | 123 | def __init__(self, framework, width, height, x_step=None, y_step=None, 124 | augmentations=None): 125 | 126 | self.framework = framework 127 | self.width = width 128 | self.height = height 129 | if x_step is None: 130 | self.x_step = self.width 131 | else: 132 | self.x_step = x_step 133 | if y_step is None: 134 | self.y_step = self.height 135 | else: 136 | self.y_step = y_step 137 | self.aug = _check_augs(augmentations) 138 | 139 | def __call__(self, im): 140 | 141 | # read in the image if it's a path 142 | if isinstance(im, str): 143 | im = imread(im) 144 | 145 | # determine how many samples will be generated with the sliding window 146 | src_im_height = im.shape[0] 147 | src_im_width = im.shape[1] 148 | 149 | 150 | 151 | y_steps = int(1+np.ceil((src_im_height-self.height)/self.y_step)) 152 | x_steps = int(1+np.ceil((src_im_width-self.width)/self.x_step)) 153 | if len(im.shape) == 2: # if there's no channel axis 154 | im = im[:, :, np.newaxis] # create one - will be needed for model 155 | top_left_corner_idxs = [] 156 | output_arr = [] 157 | for y in range(y_steps): 158 | if self.y_step*y + self.height > im.shape[0]: 159 | y_min = im.shape[0] - self.height 160 | else: 161 | y_min = self.y_step*y 162 | 163 | for x in range(x_steps): 164 | if self.x_step*x + self.width > im.shape[1]: 165 | x_min = im.shape[1] - self.width 166 | else: 167 | x_min = self.x_step*x 168 | 169 | subarr = im[y_min:y_min + self.height, 170 | x_min:x_min + self.width, 171 | :] 172 | if self.aug is not None: 173 | subarr = self.aug(image=subarr)['image'] 174 | output_arr.append(subarr) 175 | top_left_corner_idxs.append((y_min, x_min)) 176 | output_arr = np.stack(output_arr).astype(np.float32) 177 | if self.framework in ['torch', 'pytorch']: 178 | output_arr = np.moveaxis(output_arr, 3, 1) 179 | 180 | 181 | return output_arr, top_left_corner_idxs, (src_im_height, src_im_width) 182 | -------------------------------------------------------------------------------- /nets/infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 
| import torch
4 | import skimage.io
5 | import numpy as np
6 | from warnings import warn
7 | from .model_io import get_model
8 | from .transform import process_aug_dict
9 | from .datagen import InferenceTiler
10 | from utils.core import get_data_paths
11 | import torch.nn.functional as F
12 | 
13 | class Inferer(object):
14 |     """Object for running inference with trained models using PyTorch."""
15 | 
16 |     def __init__(self, config, custom_model_dict=None):
17 |         self.config = config
18 |         self.batch_size = self.config['batch_size']
19 |         self.framework = self.config['nn_framework']
20 |         self.model_name = self.config['model_name']
21 |         self.aoi = self.config["get_aoi"]
22 |         self.date = self.config["training_date"]
23 |         self.boundary = self.config["boundary"]
24 |         self.weight_file = self.config['weight_file']
25 | 
26 |         # check if the model was trained as part of the same pipeline; if so,
27 |         # use the output from that. If not, use the pre-trained model directly.
28 |         if self.config['train']:
29 |             warn('Because the configuration specifies both training and '
30 |                  'inference, solaris is switching the model weights path '
31 |                  'to the training output path.')
32 |             self.model_path = self.config['training']['model_dest_path']
33 |             if custom_model_dict is not None:
34 |                 custom_model_dict['weight_path'] = self.config[
35 |                     'training']['model_dest_path']
36 |         else:
37 | 
38 |             if len(self.model_name.split('_')) == 2:
39 |                 self.model_path = self.config.get('model_path', None) + self.aoi + '_' + self.model_name.split('_')[0] + '_' + self.model_name.split('_')[1] + '_' + self.date + '/' + self.weight_file
40 |             else:
41 |                 self.model_path = self.config.get('model_path', None) + self.aoi + '_' + self.model_name.split('_')[0] + '_' + self.date + '/' + self.weight_file
42 |         self.infer_mode = self.config['infer']
43 |         if self.infer_mode:
44 |             self.mode = 'Infer'
45 | 
46 |         self.model = get_model(self.model_name, self.framework, self.mode,
47 |                                self.model_path, pretrained=True, custom_model_dict=custom_model_dict)
48 |         self.window_step_x = self.config['inference'].get('window_step_size_x',
49 |                                                           None)
50 |         self.window_step_y = self.config['inference'].get('window_step_size_y',
51 |                                                           None)
52 |         if self.window_step_x is None:
53 |             self.window_step_x = self.config['data_specs']['width']
54 |         if self.window_step_y is None:
55 |             self.window_step_y = self.config['data_specs']['height']
56 |         self.stitching_method = self.config['inference'].get(
57 |             'stitching_method', 'average')
58 |         self.output_dir = self.config['inference']['output_dir'] + self.aoi + '_' + self.date + '/'
59 | 
60 |         if not os.path.isdir(self.output_dir):
61 |             os.makedirs(self.output_dir)
62 | 
63 |         if self.framework in ['torch', 'pytorch']:
64 |             self.gpu_available = torch.cuda.is_available()
65 |             if self.gpu_available:
66 |                 self.gpu_count = torch.cuda.device_count()
67 |             else:
68 |                 self.gpu_count = 0
69 |     def __call__(self, infer_df=None):
70 | 
71 |         with torch.no_grad():
72 |             print(self.model_path)
73 |             if infer_df is None:
74 |                 infer_df = get_infer_df(self.config)
75 | 
76 |             inf_tiler = InferenceTiler(
77 |                 self.framework,
78 |                 width=self.config['data_specs']['width'],
79 |                 height=self.config['data_specs']['height'],
80 |                 x_step=self.window_step_x,
81 |                 y_step=self.window_step_y,
82 |                 augmentations=process_aug_dict(
83 |                     self.config['inference_augmentation']))
84 |             for idx, im_path in enumerate(infer_df['image']):
85 |                 leng = len(infer_df['image'])
86 |                 print(idx, '/', leng, ' (%0.2f%%)' % float(100*idx/leng))
87 | 
88 |                 inf_input, idx_refs, ( 
src_im_height, src_im_width) = inf_tiler(im_path) 90 | 91 | if self.framework in ['torch', 'pytorch']: 92 | 93 | with torch.no_grad(): 94 | self.model.eval() 95 | 96 | if torch.cuda.is_available(): 97 | device = torch.device('cuda') 98 | self.model = self.model.cuda() 99 | else: 100 | device = torch.device('cpu') 101 | 102 | inf_input = torch.from_numpy(inf_input).float().to(device) 103 | 104 | # add additional input data, if applicable 105 | if self.config['data_specs'].get('additional_inputs', 106 | None) is not None: 107 | inf_input = [inf_input] 108 | for i in self.config['data_specs']['additional_inputs']: 109 | inf_input.append( 110 | infer_df[i].iloc[idx].to(device)) 111 | 112 | 113 | 114 | subarr_preds = self.model(inf_input) 115 | 116 | 117 | subarr_preds = subarr_preds.cpu().data.numpy() 118 | subarr_preds = subarr_preds[:, :, :src_im_height,:src_im_width] 119 | 120 | 121 | 122 | skimage.io.imsave(os.path.join(self.output_dir,os.path.split(im_path)[1]), subarr_preds) 123 | 124 | 125 | def get_infer_df(config): 126 | 127 | infer_df = get_data_paths(config['inference_data_csv']+config['get_aoi']+'_Test_df.csv' , infer=True) 128 | return infer_df 129 | -------------------------------------------------------------------------------- /nets/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ._torch_losses import torch_losses 3 | from torch import nn 4 | 5 | 6 | def get_loss(framework, loss, loss_weights=None, custom_losses=None): 7 | 8 | # lots of exception handling here. TODO: Refactor. 9 | 10 | if not isinstance(loss, dict): 11 | raise TypeError('The loss description is formatted improperly.' 12 | ' See the docs for details.') 13 | if len(loss) > 1: 14 | 15 | # get the weights for each loss within the composite 16 | if loss_weights is None: 17 | # weight all losses equally 18 | weights = {k: 1 for k in loss.keys()} 19 | else: 20 | weights = loss_weights 21 | 22 | # check if sublosses dict and weights dict have the same keys 23 | if list(loss.keys()).sort() != list(weights.keys()).sort(): 24 | raise ValueError( 25 | 'The losses and weights must have the same name keys.') 26 | 27 | if framework in ['pytorch', 'torch']: 28 | return TorchCompositeLoss(loss, weights, custom_losses) 29 | 30 | else: # parse individual loss functions 31 | loss_name, loss_dict = list(loss.items())[0] 32 | return get_single_loss(framework, loss_name, loss_dict, custom_losses) 33 | 34 | 35 | def get_single_loss(framework, loss_name, params_dict, custom_losses=None): 36 | 37 | if framework in ['torch', 'pytorch']: 38 | if params_dict is None: 39 | if custom_losses is not None and loss_name in custom_losses: 40 | return custom_losses.get(loss_name)() 41 | else: 42 | return torch_losses.get(loss_name.lower())() 43 | else: 44 | if custom_losses is not None and loss_name in custom_losses: 45 | return custom_losses.get(loss_name)(**params_dict) 46 | else: 47 | return torch_losses.get(loss_name.lower())(**params_dict) 48 | 49 | 50 | class TorchCompositeLoss(nn.Module): 51 | """Composite loss function.""" 52 | 53 | def __init__(self, loss_dict, weight_dict=None, custom_losses=None): 54 | """Create a composite loss function from a set of pytorch losses.""" 55 | super().__init__() 56 | self.weights = weight_dict 57 | self.losses = {loss_name: get_single_loss('pytorch', 58 | loss_name, 59 | loss_params, 60 | custom_losses) 61 | for loss_name, loss_params in loss_dict.items()} 62 | self.values = {} # values from the individual loss functions 63 | 64 | 
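    # forward() returns the weighted sum of the sub-losses:
    #     loss = sum(weights[name] * losses[name](outputs, targets))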
def forward(self, outputs, targets): 65 | loss = 0 66 | for func_name, weight in self.weights.items(): 67 | self.values[func_name] = self.losses[func_name](outputs, targets) 68 | loss += weight*self.values[func_name] 69 | 70 | return loss 71 | -------------------------------------------------------------------------------- /nets/model_io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from warnings import warn 4 | import requests 5 | import numpy as np 6 | from tqdm.auto import tqdm 7 | from nets import weights_dir 8 | from .zoo import model_dict 9 | 10 | 11 | def get_model(model_name, framework, mode='Train', model_path=None, pretrained=False, 12 | custom_model_dict=None, num_classes=1): 13 | """Load a model from a file based on its name.""" 14 | if custom_model_dict is not None: 15 | md = custom_model_dict 16 | else: 17 | md = model_dict.get(model_name, None) 18 | if md is None: # if the model's not provided by solaris 19 | raise ValueError(f"{model_name} can't be found in solaris and no " 20 | "custom_model_dict was provided. Check your " 21 | "model_name in the config file and/or provide a " 22 | "custom_model_dict argument to Trainer(). ") 23 | if model_path is None or custom_model_dict is not None: 24 | 25 | model_path = md.get('weight_path') 26 | if num_classes == 1: 27 | model = md.get('arch')(pretrained=pretrained, mode=mode) 28 | else: 29 | model = md.get('arch')(num_classes=num_classes, pretrained=pretrained) 30 | 31 | if model is not None and pretrained: 32 | try: 33 | model = _load_model_weights(model, model_path, framework) 34 | except (OSError, FileNotFoundError): 35 | warn(f'The model weights file {model_path} was not found.' 36 | ' Attempting to download from the SpaceNet repository.') 37 | weight_path = _download_weights(md) 38 | model = _load_model_weights(model, weight_path, framework) 39 | 40 | return model 41 | 42 | 43 | def _load_model_weights(model, path, framework): 44 | """Backend for loading the model.""" 45 | 46 | if framework.lower() in ['torch', 'pytorch']: 47 | # pytorch already throws the right error on failed load, so no need 48 | # to fix exception 49 | if torch.cuda.is_available(): 50 | try: 51 | loaded = torch.load(path) 52 | except FileNotFoundError: 53 | # first, check to see if the weights are in the default sol dir 54 | default_path = os.path.join(weights_dir, 55 | os.path.split(path)[1]) 56 | loaded = torch.load(path) 57 | else: 58 | try: 59 | loaded = torch.load(path, map_location='cpu') 60 | except FileNotFoundError: 61 | default_path = os.path.join(weights_dir, 62 | os.path.split(path)[1]) 63 | loaded = torch.load(path, map_location='cpu') 64 | 65 | if isinstance(loaded, torch.nn.Module): # if it's a full model already 66 | model.load_state_dict(loaded.state_dict()) 67 | else: 68 | model.load_state_dict(loaded) 69 | 70 | return model 71 | 72 | 73 | def reset_weights(model, framework): 74 | 75 | if framework == 'torch': 76 | reinit_model = model.apply(_reset_torch_weights) 77 | 78 | return reinit_model 79 | 80 | 81 | def _reset_torch_weights(torch_layer): 82 | if isinstance(torch_layer, torch.nn.Conv2d) or \ 83 | isinstance(torch_layer, torch.nn.Linear): 84 | torch_layer.reset_parameters() 85 | 86 | 87 | def _download_weights(model_dict): 88 | """Download pretrained weights for a model.""" 89 | weight_url = model_dict.get('weight_url', None) 90 | weight_dest_path = model_dict.get('weight_path', os.path.join( 91 | weights_dir, weight_url.split('/')[-1])) 92 | if weight_url is 
None: 93 | raise KeyError("Can't find the weights file.") 94 | else: 95 | r = requests.get(weight_url, stream=True) 96 | if r.status_code != 200: 97 | raise ValueError('The file could not be downloaded. Check the URL' 98 | ' and network connections.') 99 | total_size = int(r.headers.get('content-length', 0)) 100 | block_size = 1024 101 | with open(weight_dest_path, 'wb') as f: 102 | for chunk in tqdm(r.iter_content(block_size), 103 | total=np.ceil(total_size//block_size), 104 | unit='KB', unit_scale=False): 105 | if chunk: 106 | f.write(chunk) 107 | 108 | return weight_dest_path 109 | -------------------------------------------------------------------------------- /nets/optimizers.py: -------------------------------------------------------------------------------- 1 | """Wrappers for training optimizers.""" 2 | import math 3 | import torch 4 | 5 | def get_optimizer(framework, config): 6 | 7 | if config['training']['optimizer'] is None: 8 | raise ValueError('An optimizer must be specified in the config ' 9 | 'file.') 10 | 11 | if framework in ['torch', 'pytorch']: 12 | return torch_optimizers.get(config['training']['optimizer'].lower()) 13 | 14 | class TorchAdamW(torch.optim.Optimizer): 15 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 16 | weight_decay=1e-2, amsgrad=False): 17 | if not 0.0 <= lr: 18 | raise ValueError("Invalid learning rate: {}".format(lr)) 19 | if not 0.0 <= eps: 20 | raise ValueError("Invalid epsilon value: {}".format(eps)) 21 | if not 0.0 <= betas[0] < 1.0: 22 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 23 | if not 0.0 <= betas[1] < 1.0: 24 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 25 | defaults = dict(lr=lr, betas=betas, eps=eps, 26 | weight_decay=weight_decay, amsgrad=amsgrad) 27 | super(TorchAdamW, self).__init__(params, defaults) 28 | 29 | def __setstate__(self, state): 30 | super(TorchAdamW, self).__setstate__(state) 31 | for group in self.param_groups: 32 | group.setdefault('amsgrad', False) 33 | 34 | def step(self, closure=None): 35 | """Performs a single optimization step. 36 | Arguments: 37 | closure (callable, optional): A closure that reevaluates the model 38 | and returns the loss. 39 | """ 40 | loss = None 41 | if closure is not None: 42 | loss = closure() 43 | 44 | for group in self.param_groups: 45 | for p in group['params']: 46 | if p.grad is None: 47 | continue 48 | 49 | # Perform stepweight decay 50 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 51 | 52 | # Perform optimization step 53 | grad = p.grad.data 54 | if grad.is_sparse: 55 | raise RuntimeError('Adam does not support sparse' 56 | 'gradients, please consider SparseAdam' 57 | ' instead') 58 | amsgrad = group['amsgrad'] 59 | 60 | state = self.state[p] 61 | 62 | # State initialization 63 | if len(state) == 0: 64 | state['step'] = 0 65 | # Exponential moving average of gradient values 66 | state['exp_avg'] = torch.zeros_like(p.data) 67 | # Exponential moving average of squared gradient values 68 | state['exp_avg_sq'] = torch.zeros_like(p.data) 69 | if amsgrad: 70 | # Maintains max of all exp. moving avg. of sq. grad. 
values
75 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
76 |
77 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
78 | if amsgrad:
79 | max_exp_avg_sq = state['max_exp_avg_sq']
80 | beta1, beta2 = group['betas']
81 |
82 | state['step'] += 1
83 |
84 | # Decay the first and second moment running average coefficient
85 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
86 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
87 | if amsgrad:
88 | # Maintains the maximum of all 2nd moment running avg. till now
89 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
90 | # Use the max. for normalizing running avg. of gradient
91 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
92 | else:
93 | denom = exp_avg_sq.sqrt().add_(group['eps'])
94 |
95 | bias_correction1 = 1 - beta1 ** state['step']
96 | bias_correction2 = 1 - beta2 ** state['step']
97 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
98 |
99 | p.data.addcdiv_(exp_avg, denom, value=-step_size)
100 |
101 | return loss
102 |
103 |
104 | torch_optimizers = {
105 | 'adadelta': torch.optim.Adadelta,
106 | 'adam': torch.optim.Adam,
107 | 'adamw': TorchAdamW,
108 | 'sparseadam': torch.optim.SparseAdam,
109 | 'adamax': torch.optim.Adamax,
110 | 'asgd': torch.optim.ASGD,
111 | 'rmsprop': torch.optim.RMSprop,
112 | 'sgd': torch.optim.SGD,
113 | }
114 |
115 |
--------------------------------------------------------------------------------
/nets/torch_callbacks.py:
--------------------------------------------------------------------------------
1 | """PyTorch Callbacks."""
2 |
3 | import os
4 | import numpy as np
5 |
6 | import torch
7 | import time
8 |
9 | now = time.localtime()
10 |
11 | class TorchEarlyStopping(object):
12 | """Tracks whether model training should stop based on rate of improvement.
13 |
14 | Arguments
15 | ---------
16 | patience : int, optional
17 | The number of epochs to wait before stopping the model if the metric
18 | didn't improve. Defaults to 5.
19 | threshold : float, optional
20 | The minimum metric improvement required to count as "improvement".
21 | Defaults to ``0.0`` (any improvement satisfies the requirement).
22 | verbose : bool, optional
23 | Verbose text output. Defaults to off (``False``). _NOTE_ : This
24 | currently does nothing.
25 | """
26 |
27 | def __init__(self, patience=5, threshold=0.0, verbose=False):
28 | self.patience = patience
29 | self.threshold = threshold
30 | self.counter = 0
31 | self.best = None
32 | self.stop = False
33 |
34 | def __call__(self, metric_score):
35 |
36 | if self.best is None:
37 | self.best = metric_score
38 | self.counter = 0
39 | else:
40 | if self.best - self.threshold < metric_score:
41 | self.counter += 1
42 | else:
43 | self.best = metric_score
44 | self.counter = 0
45 |
46 | if self.counter >= self.patience:
47 | self.stop = True
48 |
49 |
50 | class TorchTerminateOnNaN(object):
51 | """Sets a stop condition if the model loss reaches a NaN or inf value.
52 |
53 | Arguments
54 | ---------
55 | patience : int, optional
56 | The number of epochs that must display a NaN loss value before
57 | stopping. Defaults to ``1``.
58 | verbose : bool, optional
59 | Verbose text output. Defaults to off (``False``). _NOTE_ : This
60 | currently does nothing.
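Once ``patience`` consecutive NaN/inf loss values have been seen, the
``stop`` attribute is set to ``True``; the training loop is expected to
check it at the end of each epoch.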
61 | """ 62 | 63 | def __init__(self, patience=1, verbose=False): 64 | self.patience = patience 65 | self.counter = 0 66 | self.stop = False 67 | 68 | def __call__(self, loss): 69 | if np.isnan(loss) or np.isinf(loss): 70 | self.counter += 1 71 | if self.counter >= self.patience: 72 | self.stop = True 73 | else: 74 | self.counter = 0 75 | 76 | 77 | 78 | class TorchModelCheckpoint(object): 79 | """Save the model at specific points using Keras checkpointing args. 80 | 81 | Arguments 82 | --------- 83 | filepath : str, optional 84 | Path to save the model file to. The end of the path (before the 85 | file extension) will have ``'_[epoch]'`` added to it to ID specific 86 | checkpoints. 87 | monitor : str, optional 88 | The loss value to monitor. Options are 89 | ``['loss', 'val_loss', 'periodic']`` or a metric from the keys in 90 | :const:`solaris.nets.metrics.metric_dict` . Defaults to ``'loss'`` . If 91 | ``'periodic'``, it saves every n epochs (see `period` below). 92 | verbose : bool, optional 93 | Verbose text output. Defaults to ``False`` . 94 | save_best_only : bool, optional 95 | Save only the model with the best value? Defaults to no (``False`` ). 96 | mode : str, optional 97 | One of ``['auto', 'min', 'max']``. Is a better value higher or lower? 98 | Defaults to ``'auto'`` in which case it tries to infer it (if 99 | ``monitor='loss'`` or ``monitor='val_loss'`` , it assumes ``'min'`` , 100 | if it's a metric it assumes ``'max'`` .) If ``'min'``, it assumes lower 101 | values are better; if ``'max'`` , it assumes higher values are better. 102 | period : int, optional 103 | If using ``monitor='periodic'`` , this saves models every `period` 104 | epochs. Otherwise, it sets the minimum number of epochs between 105 | checkpoints. 106 | """ 107 | 108 | def __init__(self, filepath='', path_aoi='',monitor='loss', verbose=False, 109 | save_best_only=False, mode='auto', period=1, 110 | weights_only=True): 111 | 112 | self.filepath = filepath 113 | self.monitor = monitor 114 | self.aoi = path_aoi 115 | if self.monitor not in ['loss', 'val_loss', 'periodic']: 116 | self.monitor = metric_dict[self.monitor] 117 | self.verbose = verbose 118 | self.save_best_only = save_best_only 119 | self.period = period 120 | self.weights_only = weights_only 121 | self.mode = mode 122 | if self.mode == 'auto': 123 | if self.monitor in ['loss', 'val_loss']: 124 | self.mode = 'min' 125 | else: 126 | self.mode = 'max' 127 | 128 | self.epoch = 0 129 | self.last_epoch = 0 130 | self.last_saved_value = None 131 | 132 | def __call__(self, model, file_path, loss_value=None, y_true=None, y_pred=None): 133 | """Run a round of model checkpointing for an epoch. 134 | 135 | Arguments 136 | --------- 137 | model : model object 138 | The model to be saved during checkpoints. Must be a PyTorch model. 139 | loss_value : numeric, optional 140 | The numeric output of the loss function. Only required if using 141 | ``monitor='loss'`` or ``monitor='val_loss'`` . 142 | y_true : :class:`np.array` , optional 143 | The labels for the validation data. Only required if using 144 | a metric as the monitored value. 145 | y_pred : :class:`np.array` , optional 146 | The predicted values from the model. Only required if using 147 | a metric as the monitored value. 
148 | """ 149 | 150 | self.epoch += 1 151 | if self.monitor == 'periodic': # update based on period 152 | if self.last_epoch + self.period <= self.epoch: 153 | # self.last_saved_value = loss_value if loss_value else 0 154 | self.save(model, file_path,self.weights_only) 155 | self.last_epoch = self.epoch 156 | 157 | 158 | elif self.monitor in ['loss', 'val_loss']: 159 | if self.last_saved_value is None: 160 | self.last_saved_value = loss_value 161 | if self.last_epoch + self.period <= self.epoch: 162 | self.save(model,file_path, self.weights_only) 163 | self.last_epoch = self.epoch 164 | if self.last_epoch + self.period <= self.epoch: 165 | if self.check_is_best_value(loss_value): 166 | self.last_saved_value = loss_value 167 | self.save(model, file_path,self.weights_only) 168 | self.last_epoch = self.epoch 169 | 170 | else: 171 | if self.last_saved_value is None: 172 | self.last_saved_value = self.monitor(y_true, y_pred) 173 | if self.last_epoch + self.period <= self.epoch: 174 | self.save(model,file_path, self.weights_only) 175 | self.last_epoch = self.epoch 176 | if self.last_epoch + self.period <= self.epoch: 177 | metric_value = self.monitor(y_true, y_pred) 178 | if self.check_is_best_value(metric_value): 179 | self.last_saved_value = metric_value 180 | self.save(model, file_path, self.weights_only) 181 | self.last_epoch = self.epoch 182 | 183 | def check_is_best_value(self, value): 184 | """Check if `value` is better than the best stored value.""" 185 | if self.mode == 'min' and self.last_saved_value > value: 186 | return True 187 | elif self.mode == 'max' and self.last_saved_value < value: 188 | return True 189 | else: 190 | return False 191 | 192 | def save(self, model, file_path, weights_only): 193 | """Save the model. 194 | 195 | Arguments 196 | --------- 197 | model : :class:`torch.nn.Module` 198 | A PyTorch model instance to save. 199 | weights_only : bool, optional 200 | Should the entire model be saved, or only its weights (also known 201 | as the state_dict)? Defaults to ``False`` (saves entire model). The 202 | entire model must be saved to resume training without re-defining 203 | the model architecture, optimizer, and loss function. 
204 | """ 205 | # print("saved time : %04d/%02d/%02d %02d:%02d:%02d"% (now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec)) 206 | 207 | # print(self.aoi) 208 | # print("aoi") 209 | lossvalue=np.round(self.last_saved_value,3) 210 | # save_path = self.filepath + self.aoi + '_' + str(now.tm_mon) + str(now.tm_mday)+str(now.tm_hour)+str(now.tm_min)+ '/' 211 | save_path = file_path 212 | save_name = save_path + 'best'+ '_epoch{}_{}'.format(self.epoch, str(lossvalue)) 213 | #save_name = save_path + 'best'+ '_epoch_{}_{}'.format( str(self.epoch).zfill(3), str(lossvalue)) 214 | 215 | # save_name = self.filepath + self.aoi + '_' + str(now.tm_mon) + str(now.tm_mday)+ '/' + 'best'+ '_epoch{}_{}'.format( 216 | # self.epoch, np.round(self.last_saved_value, 1)) 217 | # save_name = os.path.splitext(self.filepath)[0] + self.aoi + + '_epoch{}_{}'.format( 218 | # self.epoch, np.round(self.last_saved_value, 3)) 219 | 220 | save_name = save_name + '.pth' 221 | print("saved path : ", save_path) 222 | print() 223 | print() 224 | if not os.path.exists(save_path) : 225 | os.makedirs(save_path) 226 | else : 227 | pass 228 | # os.makedirs(save_path) 229 | if isinstance(model, torch.nn.DataParallel): 230 | to_save = model.module 231 | else: 232 | to_save = model 233 | if weights_only: 234 | # os.makedirs(save_path) 235 | # torch.save(save_path, save_name) 236 | torch.save(to_save.state_dict(), save_name) 237 | 238 | else: 239 | torch.save(to_save, save_name) 240 | 241 | 242 | torch_callback_dict = { 243 | "early_stopping": TorchEarlyStopping, 244 | "model_checkpoint": TorchModelCheckpoint, 245 | "terminate_on_nan": TorchTerminateOnNaN, 246 | } 247 | -------------------------------------------------------------------------------- /nets/zoo/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .. 
import weights_dir
3 |
4 | from .unet import UNet
5 | from .unet_BE import UNet_BE
6 | from .resunet import ResUnetPlusPlus
7 | from .resunet_BE import ResUnetPlusPlus_BE
8 | from .ternaus import ternaus11
9 | from .ternaus_BE import ternaus_BE
10 | from .uspp import Uspp
11 | from .uspp_BE import Uspp_BE
12 | from .denet import DeNet
13 | from .brrnet import BRRNet
14 | from .brrnet_BE import BRRNet_BE
15 | from .enru import ENRUNet
16 | from .enru_BE import ENRUNet_BE
17 | # model_dict below references hrnetv2/hrnetv2_BE, so they must be imported
18 | # too (assuming the classes share their modules' names); without these
19 | # imports, defining model_dict raises a NameError
20 | from .hrnetv2 import hrnetv2
21 | from .hrnetv2_BE import hrnetv2_BE
22 |
23 | model_dict = {
24 | 'unet' : {
25 | 'weight_path': None,
26 | 'weight_url': None,
27 | 'arch': UNet},
28 | 'enru' : {
29 | 'weight_path': None,
30 | 'weight_url': None,
31 | 'arch': ENRUNet},
32 | 'enru_BE' : {
33 | 'weight_path': None,
34 | 'weight_url': None,
35 | 'arch': ENRUNet_BE},
36 | 'brrnet' : {
37 | 'weight_path': None,
38 | 'weight_url': None,
39 | 'arch': BRRNet},
40 | 'brrnet_BE' : {
41 | 'weight_path': None,
42 | 'weight_url': None,
43 | 'arch': BRRNet_BE},
44 | 'denet' : {
45 | 'weight_path': None,
46 | 'weight_url': None,
47 | 'arch': DeNet},
48 | 'uspp' : {
49 | 'weight_path': None,
50 | 'weight_url': None,
51 | 'arch': Uspp},
52 | 'uspp_BE' : {
53 | 'weight_path': None,
54 | 'weight_url': None,
55 | 'arch': Uspp_BE},
56 | 'resunet_BE' : {
57 | 'weight_path': None,
58 | 'weight_url': None,
59 | 'arch': ResUnetPlusPlus_BE},
60 | 'resunet' : {
61 | 'weight_path': None,
62 | 'weight_url': None,
63 | 'arch': ResUnetPlusPlus},
64 | 'unet_BE' : {
65 | 'weight_path': None,
66 | 'weight_url': None,
67 | 'arch': UNet_BE},
68 | 'ternaus' : {
69 | 'weight_path': None,
70 | 'weight_url': None,
71 | 'arch': ternaus11},
72 | 'ternaus_BE' : {
73 | 'weight_path': None,
74 | 'weight_url': None,
75 | 'arch': ternaus_BE},
76 | 'hrnetv2' : {
77 | 'weight_path': None,
78 | 'weight_url': None,
79 | 'arch': hrnetv2},
80 | 'hrnetv2_BE' : {
81 | 'weight_path': None,
82 | 'weight_url': None,
83 | 'arch': hrnetv2_BE},
84 | }
85 |
--------------------------------------------------------------------------------
/nets/zoo/brrnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def conv_block(in_channels, out_channels):
6 | return nn.Sequential(
7 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
8 | nn.BatchNorm2d(num_features=out_channels),
9 | nn.ReLU(inplace=True),
10 | nn.Conv2d(out_channels, out_channels, 3, padding=1),
11 | nn.BatchNorm2d(num_features=out_channels),
12 | nn.ReLU(inplace=True)
13 | )
14 |
15 |
16 | def up_transpose(in_channels, out_channels):
17 | return nn.Sequential(
18 | nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
19 | )
20 | class center_block(nn.Module):
21 | def __init__(self, in_channels, out_channels):
22 | super(center_block, self).__init__()
23 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1, dilation=1)
24 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=2, dilation=2)
25 | self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=4, dilation=4)
26 | self.conv4 = nn.Conv2d(out_channels, out_channels, 3, padding=8, dilation=8)
27 | self.conv5 = nn.Conv2d(out_channels, out_channels, 3, padding=16, dilation=16)
28 | self.conv6 = nn.Conv2d(out_channels, out_channels, 3, padding=32, dilation=32)
29 |
30 | self.bn_1 = nn.BatchNorm2d(num_features=out_channels)
31 | self.bn_2 = nn.BatchNorm2d(num_features=out_channels)
32 | self.bn_3 =
nn.BatchNorm2d(num_features=out_channels)
33 | self.bn_4 = nn.BatchNorm2d(num_features=out_channels)
34 | self.bn_5 = nn.BatchNorm2d(num_features=out_channels)
35 | self.bn_6 = nn.BatchNorm2d(num_features=out_channels)
36 | self.relu = nn.ReLU()
37 |
38 |
39 |
40 | def forward(self, x):  # note: the rrm part and the center block are mixed together at the moment..
41 |
42 |
43 | x1 = self.relu(self.bn_1(self.conv1(x)))
44 |
45 | x2 = self.relu(self.bn_2(self.conv2(x1)))
46 |
47 | x3 = self.relu(self.bn_3(self.conv3(x2)))
48 |
49 | x4 = self.relu(self.bn_4(self.conv4(x3)))
50 |
51 | x5 = self.relu(self.bn_5(self.conv5(x4)))
52 |
53 | x6 = self.relu(self.bn_6(self.conv6(x5)))
54 |
55 |
56 | x = x1+x2+x3+x4+x5+x6
57 |
58 | return x
59 |
60 | class rrm_module(nn.Module):
61 | def __init__(self, in_channels, out_channels):
62 | super(rrm_module, self).__init__()
63 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1, dilation=1)
64 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=2, dilation=2)
65 | self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=4, dilation=4)
66 | self.conv4 = nn.Conv2d(out_channels, out_channels, 3, padding=8, dilation=8)
67 | self.conv5 = nn.Conv2d(out_channels, out_channels, 3, padding=16, dilation=16)
68 | self.conv6 = nn.Conv2d(out_channels, out_channels, 3, padding=32, dilation=32)
69 |
70 | self.bn_1 = nn.BatchNorm2d(num_features=out_channels)
71 | self.bn_2 = nn.BatchNorm2d(num_features=out_channels)
72 | self.bn_3 = nn.BatchNorm2d(num_features=out_channels)
73 | self.bn_4 = nn.BatchNorm2d(num_features=out_channels)
74 | self.bn_5 = nn.BatchNorm2d(num_features=out_channels)
75 | self.bn_6 = nn.BatchNorm2d(num_features=out_channels)
76 | self.relu = nn.ReLU()
77 |
78 | self.out = nn.Conv2d(out_channels, 1, 3, padding=1, dilation=1)
79 |
80 | def forward(self, x):
81 | residual = x
82 | x1 = self.relu(self.bn_1(self.conv1(x)))
83 |
84 | x2 = self.relu(self.bn_2(self.conv2(x1)))
85 | x3 = self.relu(self.bn_3(self.conv3(x2)))
86 | x4 = self.relu(self.bn_4(self.conv4(x3)))
87 | x5 = self.relu(self.bn_5(self.conv5(x4)))
88 | x6 = self.relu(self.bn_6(self.conv6(x5)))
89 | x = x1+x2+x3+x4+x5+x6
90 | x = self.out(x)
91 | x = residual + x
92 |
93 | return x
94 |
95 | class decoder_block(nn.Module):
96 | def __init__(self, in_channels, out_channels):
97 | super(decoder_block, self).__init__()
98 | self.bn_i = nn.BatchNorm2d(num_features=in_channels)
99 | self.relu = nn.ReLU()
100 | self.conv = conv_block(in_channels, out_channels)
101 | def forward(self, x):
102 |
103 | out = self.bn_i(x)
104 | out = self.relu(out)
105 | out = self.conv(out)
106 | return out
107 |
108 | class BRRNet(nn.Module):
109 |
110 | def __init__(self, n_class=1, pretrained=False, mode='Train'):
111 | super().__init__()
112 | self.mode = mode
113 | self.dconv_down1 = conv_block(3, 64)
114 | self.dconv_down2 = conv_block(64, 128)
115 | self.dconv_down3 = conv_block(128, 256)
116 |
117 | self.maxpool = nn.MaxPool2d(2, 2)
118 | self.center = center_block(256, 512)
119 | self.deconv3 = up_transpose(512, 256)
120 | self.deconv2 = up_transpose(256, 128)
121 | self.deconv1 = up_transpose(128, 64)
122 |
123 | self.decoder_3 = decoder_block(512, 256)
124 | self.decoder_2 = decoder_block(256, 128)
125 | self.decoder_1 = decoder_block(128, 64)
126 | self.output_1 = nn.Conv2d(64, n_class, 1)
127 | self.rrm = rrm_module(1, 64)
128 | def forward(self, x):
129 |
130 | conv1 = self.dconv_down1(x)
131 | # print(conv1.shape)
132 | x = self.maxpool(conv1)
133 | # print(x.shape)
134 | conv2 = self.dconv_down2(x)
135 | x = self.maxpool(conv2)
136 |
137 | conv3 =
self.dconv_down3(x)
138 | x = self.maxpool(conv3)
139 |
140 | x = self.center(x)
141 |
142 | x = self.deconv3(x)  # 512 -> 256 channels
143 | x = torch.cat([conv3, x], 1)  # 256 + 256
144 |
145 | x = self.decoder_3(x)  # 512 -> 256
146 |
147 | x = self.deconv2(x)
148 | x = torch.cat([conv2, x], 1)
149 | x = self.decoder_2(x)
150 |
151 | x = self.deconv1(x)
152 | x = torch.cat([conv1, x], 1)
153 | x = self.decoder_1(x)
154 |
155 | x = self.output_1(x)
156 | out = self.rrm(x)
157 | if self.mode == 'Train':
158 | return F.sigmoid(out)
159 | elif self.mode == 'Infer':
160 | return out
--------------------------------------------------------------------------------
/nets/zoo/brrnet_BE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import skimage
4 | import numpy as np
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 |
8 | def conv_block(in_channels, out_channels):
9 | return nn.Sequential(
10 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
11 | nn.BatchNorm2d(num_features=out_channels),
12 | nn.ReLU(inplace=True),
13 | nn.Conv2d(out_channels, out_channels, 3, padding=1),
14 | nn.BatchNorm2d(num_features=out_channels),
15 | nn.ReLU(inplace=True)
16 | )
17 |
18 |
19 | def up_transpose(in_channels, out_channels):
20 | return nn.Sequential(
21 | nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
22 | )
23 | class center_block(nn.Module):
24 | def __init__(self, in_channels, out_channels):
25 | super(center_block, self).__init__()
26 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1, dilation=1)
27 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=2, dilation=2)
28 | self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=4, dilation=4)
29 | self.conv4 = nn.Conv2d(out_channels, out_channels, 3, padding=8, dilation=8)
30 | self.conv5 = nn.Conv2d(out_channels, out_channels, 3, padding=16, dilation=16)
31 | self.conv6 = nn.Conv2d(out_channels, out_channels, 3, padding=32, dilation=32)
32 |
33 | self.bn_1 = nn.BatchNorm2d(num_features=out_channels)
34 | self.bn_2 = nn.BatchNorm2d(num_features=out_channels)
35 | self.bn_3 = nn.BatchNorm2d(num_features=out_channels)
36 | self.bn_4 = nn.BatchNorm2d(num_features=out_channels)
37 | self.bn_5 = nn.BatchNorm2d(num_features=out_channels)
38 | self.bn_6 = nn.BatchNorm2d(num_features=out_channels)
39 | self.relu = nn.ReLU()
40 |
41 |
42 |
43 | def forward(self, x):  # note: the rrm part and the center block are mixed together at the moment..
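# the six 3x3 convolutions below use dilation rates 1, 2, 4, 8, 16 and 32,
# so the receptive field grows roughly exponentially while the spatial size
# is preserved; summing x1..x6 fuses the multi-scale responses at the
# network's bottleneck.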
44 | 45 | 46 | x1 = self.relu(self.bn_1(self.conv1(x))) 47 | 48 | x2 = self.relu(self.bn_2(self.conv2(x1))) 49 | 50 | x3 = self.relu(self.bn_3(self.conv3(x2))) 51 | 52 | x4 = self.relu(self.bn_4(self.conv4(x3))) 53 | 54 | x5 = self.relu(self.bn_5(self.conv5(x4))) 55 | 56 | x6 = self.relu(self.bn_6(self.conv6(x5))) 57 | 58 | 59 | x = x1+x2+x3+x4+x5+x6 60 | 61 | return x 62 | 63 | class rrm_module(nn.Module): 64 | def __init__(self, in_channels, out_channels): 65 | super(rrm_module,self).__init__() 66 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1,dilation=1) 67 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=2,dilation=2) 68 | self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=4,dilation=4) 69 | self.conv4 = nn.Conv2d(out_channels, out_channels, 3, padding=8,dilation=8) 70 | self.conv5 = nn.Conv2d(out_channels, out_channels, 3, padding=16,dilation=16) 71 | self.conv6 = nn.Conv2d(out_channels, out_channels, 3, padding=32,dilation=32) 72 | 73 | self.bn_1 = nn.BatchNorm2d(num_features=out_channels) 74 | self.bn_2 = nn.BatchNorm2d(num_features=out_channels) 75 | self.bn_3 = nn.BatchNorm2d(num_features=out_channels) 76 | self.bn_4 = nn.BatchNorm2d(num_features=out_channels) 77 | self.bn_5 = nn.BatchNorm2d(num_features=out_channels) 78 | self.bn_6 = nn.BatchNorm2d(num_features=out_channels) 79 | self.relu = nn.ReLU() 80 | 81 | # self.out = nn.Conv2d(out_channels, 1, 3, padding=1,dilation=1) 82 | # BE mode 83 | self.out = nn.Conv2d(out_channels, 64, 3, padding=1,dilation=1) 84 | 85 | def forward(self,x): 86 | residual = x 87 | x1 = self.relu(self.bn_1(self.conv1(x))) 88 | 89 | x2 = self.relu(self.bn_2(self.conv2(x1))) 90 | x3 = self.relu(self.bn_3(self.conv3(x2))) 91 | x4 = self.relu(self.bn_4(self.conv4(x3))) 92 | x5 = self.relu(self.bn_5(self.conv5(x4))) 93 | x6 = self.relu(self.bn_6(self.conv6(x5))) 94 | x = x1+x2+x3+x4+x5+x6 95 | x = self.out(x) 96 | x = residual + x 97 | output = x 98 | # output = F.sigmoid(x) 99 | return output 100 | 101 | class decoder_block(nn.Module): 102 | def __init__(self, in_channels, out_channels): 103 | super(decoder_block,self).__init__() 104 | self.bn_i = nn.BatchNorm2d(num_features=in_channels) 105 | self.relu = nn.ReLU() 106 | self.conv = conv_block(in_channels, out_channels) 107 | def forward(self, x): 108 | 109 | out = self.bn_i(x) 110 | out = self.relu(out) 111 | out = self.conv(out) 112 | return out 113 | 114 | class BRRNet_BE(nn.Module): 115 | 116 | def __init__(self, n_class=1, pretrained=False, mode= 'Train'): 117 | super().__init__() 118 | self.mode = mode 119 | self.dconv_down1 = conv_block(3, 64) 120 | self.dconv_down2 = conv_block(64, 128) 121 | self.dconv_down3 = conv_block(128, 256) 122 | 123 | self.maxpool = nn.MaxPool2d(2,2) 124 | self.center = center_block(256,512) 125 | self.deconv3 = up_transpose(512,256) 126 | self.deconv2 = up_transpose(256,128) 127 | self.deconv1 = up_transpose(128,64) 128 | 129 | self.decoder_3 = decoder_block(512, 256) 130 | self.decoder_2 = decoder_block(256, 128) 131 | self.decoder_1 = decoder_block(128, 64) 132 | # self.output_1 = nn.Conv2d(64,n_class, 1) 133 | # self.rrm = rrm_module(1,64) 134 | # BE mode 135 | self.output_1 = nn.Conv2d(64,64, 1) 136 | self.rrm = rrm_module(64,64) 137 | 138 | # HED Block 139 | self.dsn1 = nn.Conv2d(64, 1, 1) 140 | self.dsn2 = nn.Conv2d(128, 1, 1) 141 | self.dsn3 = nn.Conv2d(256, 1, 1) 142 | self.dsn4 = nn.Conv2d(512, 1, 1) 143 | 144 | 145 | #boundary enhancement part 146 | self.fuse = nn.Sequential(nn.Conv2d(4, 64, 1),nn.ReLU(inplace=True)) 
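# Boundary-enhancement head: the four HED-style side outputs (1-channel
# probability maps taken from conv1..conv3 and the center block) are
# concatenated into a 4-channel stack; `fuse` lifts that stack to 64
# channels, and the SE_mimic branch defined next squeezes the fused map to
# a 4-element gate that re-weights the side outputs channel-wise.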
147 | 148 | self.SE_mimic = nn.Sequential( 149 | nn.Linear(64, 64, bias=False), 150 | nn.ReLU(inplace=True), 151 | nn.Linear(64, 4, bias=False), 152 | nn.Sigmoid() 153 | ) 154 | self.final_boundary = nn.Conv2d(4,2,1) 155 | 156 | self.final_conv = nn.Sequential( 157 | nn.Conv2d(128,64,3, padding=1), 158 | nn.ReLU(inplace=True) 159 | ) 160 | self.final_mask = nn.Conv2d(64,2,1) 161 | 162 | 163 | 164 | self.relu = nn.ReLU() 165 | self.out = nn.Conv2d(64,1,1) 166 | 167 | 168 | 169 | 170 | 171 | def forward(self, x): 172 | h = x.size(2) 173 | w = x.size(3) 174 | conv1 = self.dconv_down1(x) 175 | # print(conv1.shape) 176 | x = self.maxpool(conv1) 177 | # print(x.shape) 178 | conv2 = self.dconv_down2(x) 179 | x = self.maxpool(conv2) 180 | 181 | conv3 = self.dconv_down3(x) 182 | x = self.maxpool(conv3) 183 | 184 | conv4 = self.center(x) 185 | 186 | x = self.deconv3(conv4) # 512 256 187 | x = torch.cat([conv3,x],1) # 256 + 256 188 | 189 | x = self.decoder_3(x) # 512 256 190 | 191 | x = self.deconv2(x) 192 | x = torch.cat([conv2,x],1) 193 | x = self.decoder_2(x) 194 | 195 | x = self.deconv1(x) 196 | x = torch.cat([conv1,x],1) 197 | x = self.decoder_1(x) 198 | 199 | x = self.output_1(x) 200 | out = self.rrm(x) 201 | 202 | 203 | d1 = self.dsn1(conv1) 204 | d2 = F.upsample_bilinear(self.dsn2(conv2), size=(h,w)) 205 | d3 = F.upsample_bilinear(self.dsn3(conv3), size=(h,w)) 206 | d4 = F.upsample_bilinear(self.dsn4(conv4), size=(h,w)) 207 | 208 | d1_out = F.sigmoid(d1) 209 | d2_out = F.sigmoid(d2) 210 | d3_out = F.sigmoid(d3) 211 | d4_out = F.sigmoid(d4) 212 | 213 | concat = torch.cat((d1_out, d2_out, d3_out, d4_out), 1) 214 | 215 | fuse_box = self.fuse(concat) 216 | GAP = F.adaptive_avg_pool2d(fuse_box,(1,1)) 217 | GAP = GAP.view(-1, 64) 218 | se_like = self.SE_mimic(GAP) 219 | se_like = torch.unsqueeze(se_like, 2) 220 | se_like = torch.unsqueeze(se_like, 3) 221 | 222 | feat_se = concat * se_like.expand_as(concat) 223 | boundary = self.final_boundary(feat_se) 224 | boundary_out = torch.unsqueeze(boundary[:,1,:,:],1) 225 | bd_sftmax = F.softmax(boundary, dim=1) 226 | boundary_scale = torch.unsqueeze(bd_sftmax[:,1,:,:],1) 227 | 228 | feat_concat = torch.cat( [out, fuse_box], 1) 229 | feat_concat_conv = self.final_conv(feat_concat) 230 | mask = self.final_mask(feat_concat_conv) 231 | mask_sftmax = F.softmax(mask,dim=1) 232 | mask_scale = torch.unsqueeze(mask_sftmax[:,1,:,:],1) 233 | 234 | if self.mode == 'Train': 235 | scalefactor = torch.clamp(mask_scale+boundary_scale,0,1) 236 | elif self.mode == 'Infer': 237 | scalefactor = torch.clamp(mask_scale+5*boundary_scale,0,1) 238 | 239 | 240 | mask_out = torch.unsqueeze(mask[:,1,:,:],1) 241 | relu = self.relu(mask_out) 242 | scalar = relu.cpu().detach().numpy() 243 | if np.sum(scalar) == 0: 244 | average = 0 245 | else : 246 | average = scalar[np.nonzero(scalar)].mean() 247 | mask_out = mask_out-relu + (average*scalefactor) 248 | 249 | if self.mode == 'Train': 250 | mask_out = F.sigmoid(mask_out) 251 | boundary_out = F.sigmoid(boundary_out) 252 | 253 | return d1_out, d2_out, d3_out, d4_out, boundary_out, mask_out 254 | elif self.mode =='Infer': 255 | return mask_out 256 | 257 | -------------------------------------------------------------------------------- /nets/zoo/denet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import skimage 4 | import numpy as np 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | 8 | class _downsampling(nn.Module): 9 | def 
__init__(self, channel_in): 10 | super(_downsampling, self).__init__() 11 | #channel_in, channel_out = channel_var 12 | 13 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_in, kernel_size=3, stride=2, padding=1) 14 | self.maxpool = nn.MaxPool2d(2) 15 | 16 | self.bn = nn.BatchNorm2d(2*channel_in) 17 | self.relu = nn.ReLU() 18 | 19 | def forward(self, x): 20 | out1= self.conv(x) 21 | out2= self.maxpool(x) 22 | 23 | out = torch.cat([out1, out2], 1) 24 | out = self.relu(self.bn(out)) 25 | return out 26 | 27 | 28 | class _linear_residual(nn.Module): 29 | def __init__(self, channel_in): 30 | super(_linear_residual, self).__init__() 31 | #channel_in, channel_out = channel_var 32 | 33 | self.conv1 = nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=1, stride=1, padding=0) 34 | self.bn1 = nn.BatchNorm2d(int(channel_in/4.)) 35 | self.relu1= nn.ELU(alpha=1.673) 36 | 37 | self.conv2 = nn.Conv2d(in_channels=int(channel_in/4.), out_channels=int(channel_in/4.), kernel_size=3, stride=1, padding=1) 38 | self.bn2 = nn.BatchNorm2d(int(channel_in/4.)) 39 | self.relu2= nn.ELU(alpha=1.673) 40 | 41 | self.conv3 = nn.Conv2d(in_channels=int(channel_in/4.), out_channels=channel_in, kernel_size=1, stride=1, padding=0) 42 | 43 | def forward(self, x): 44 | residual = x 45 | _lambda = 1.051 46 | 47 | out = self.bn1(self.conv1(x)) 48 | out = self.relu1(out) * _lambda 49 | 50 | out = self.bn2(self.conv2(out)) 51 | out = self.relu2(out) * _lambda 52 | 53 | out = self.conv3(out) 54 | 55 | out = torch.add(out, residual) 56 | return out 57 | 58 | class _encoding_block(nn.Module): 59 | def __init__(self, channel_in): 60 | super(_encoding_block, self).__init__() 61 | 62 | self.block_1 = nn.Sequential( 63 | _linear_residual(channel_in=channel_in), 64 | _linear_residual(channel_in=channel_in), 65 | _linear_residual(channel_in=channel_in), 66 | _linear_residual(channel_in=channel_in), 67 | _linear_residual(channel_in=channel_in), 68 | _linear_residual(channel_in=channel_in), 69 | ) 70 | 71 | def forward(self, x): 72 | return self.block_1(x) 73 | 74 | 75 | class _compressing_module(nn.Module): 76 | def __init__(self, channel_in): 77 | super(_compressing_module, self).__init__() 78 | 79 | self.conv1 = nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=1, stride=1, padding=0) 80 | self.bn1 = nn.BatchNorm2d(int(channel_in/4.)) 81 | self.relu1= nn.ReLU() 82 | 83 | self.conv2 = nn.Conv2d(in_channels=int(channel_in/4.), out_channels=int(channel_in/4.), kernel_size=3, stride=1, padding=1) 84 | self.bn2 = nn.BatchNorm2d(int(channel_in/4.)) 85 | self.relu2= nn.ReLU() 86 | 87 | self.conv3 = nn.Conv2d(in_channels=int(channel_in/4.), out_channels=channel_in, kernel_size=1, stride=1, padding=0) 88 | 89 | def forward(self, x): 90 | residual = x 91 | 92 | out = self.bn1(self.conv1(x)) 93 | out = self.relu1(out) 94 | 95 | out = self.bn2(self.conv2(out)) 96 | out = self.relu2(out) 97 | 98 | out = self.conv3(out) 99 | return out 100 | 101 | 102 | class _duc(nn.Module): 103 | def __init__(self): 104 | super(_duc, self).__init__() 105 | 106 | self.subpixel = nn.PixelShuffle(8) 107 | 108 | def forward(self, x): 109 | #out = self.relu(self.conv(x)) 110 | #out = self.subpixel(out) 111 | out = self.subpixel(x) 112 | return out 113 | 114 | 115 | class DeNet(nn.Module): 116 | def __init__(self, pretrained=False,mode='Train'): 117 | super(DeNet, self).__init__() 118 | self.mode=mode 119 | self.conv_i = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=1, stride=1, padding=0) 
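# 1x1 input projection: 3-band input -> 64 feature maps; each
# _downsampling block that follows halves the spatial size and doubles the
# channel count (a stride-2 conv concatenated with a max-pool).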
120 | #self.relu1 = nn.PReLU() 121 | self.relu1 = nn.ReLU() 122 | #self.relu1 = nn.LeakyReLU(0.1) 123 | self.DS_block_1 = self.make_layer(_downsampling, 64) 124 | self.EC_block_1 = self.make_layer(_encoding_block, 128) 125 | 126 | self.DS_block_2 = self.make_layer(_downsampling, 128) 127 | self.EC_block_2 = self.make_layer(_encoding_block, 256) 128 | 129 | self.DS_block_3 = self.make_layer(_downsampling, 256) 130 | self.EC_block_3 = self.make_layer(_encoding_block, 512) 131 | 132 | self.CP_block_41= self.make_layer(_compressing_module, 512) 133 | self.EC_block_42= self.make_layer(_encoding_block, 512) 134 | self.CP_block_43= self.make_layer(_compressing_module, 512) 135 | self.EC_block_44= self.make_layer(_encoding_block, 512) 136 | self.CP_block_45= self.make_layer(_compressing_module, 512) 137 | self.EC_block_46= self.make_layer(_encoding_block, 512) 138 | self.CP_block_47= self.make_layer(_compressing_module, 512) 139 | 140 | self.conv_f = nn.Conv2d(in_channels=512, out_channels=64, kernel_size=1, stride=1, padding=0) 141 | #self.relu2 = nn.PReLU() 142 | self.relu2 = nn.ReLU() 143 | #self.relu2 = nn.LeakyReLU(0.1) 144 | 145 | self.dcu = _duc() 146 | 147 | def make_layer(self, block, channel_in): 148 | layers = [] 149 | layers.append(block(channel_in)) 150 | return nn.Sequential(*layers) 151 | 152 | def forward(self, x): 153 | residual = x 154 | 155 | out = self.relu1(self.conv_i(x)) 156 | out = self.DS_block_1(out) 157 | out = self.EC_block_1(out) 158 | 159 | out = self.DS_block_2(out) 160 | out = self.EC_block_2(out) 161 | 162 | out = self.DS_block_3(out) 163 | out = self.EC_block_3(out) 164 | 165 | out = self.CP_block_41(out) 166 | out = self.EC_block_42(out) 167 | out = self.CP_block_43(out) 168 | out = self.EC_block_44(out) 169 | out = self.CP_block_45(out) 170 | out = self.EC_block_46(out) 171 | out = self.CP_block_47(out) 172 | 173 | out = self.relu2(self.conv_f(out)) 174 | out = self.dcu(out) 175 | if self.mode == 'Train': 176 | return F.sigmoid(out) 177 | elif self.mode == 'Infer': 178 | return out -------------------------------------------------------------------------------- /nets/zoo/enru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import math 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | "3x3 convolution with padding" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=1, bias=False) 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = conv3x3(inplanes, planes, stride) 18 | self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes) 19 | self.relu = nn.ReLU(inplace=True) 20 | self.conv2 = conv3x3(planes, planes) 21 | self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes) 22 | self.downsample = downsample 23 | self.stride = stride 24 | 25 | def forward(self, x): 26 | residual = x 27 | 28 | out = self.conv1(x) 29 | out = self.bn1(out) 30 | out = self.relu(out) 31 | 32 | out = self.conv2(out) 33 | out = self.bn2(out) 34 | 35 | if self.downsample is not None: 36 | residual = self.downsample(x) 37 | 38 | out += residual 39 | out = self.relu(out) 40 | 41 | return out 42 | 43 | 44 | class Bottleneck(nn.Module): 45 | expansion = 4 46 | 47 | def __init__(self, inplanes, planes, stride=1, downsample=None, 
norm_type=None): 48 | super(Bottleneck, self).__init__() 49 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 50 | self.bn1 = bn(planes) 51 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 52 | padding=1, bias=False) 53 | self.bn2 = bn(planes) 54 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 55 | self.bn3 = bn(planes * 4) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | residual = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class ResNet(nn.Module): 85 | 86 | def __init__(self, block, layers, num_classes=1000, deep_base=False, norm_type=None): 87 | super(ResNet, self).__init__() 88 | self.inplanes = 128 if deep_base else 16 89 | if deep_base: 90 | self.prefix = nn.Sequential(OrderedDict([ 91 | ('conv1', nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)), 92 | ('bn1', bn(64)), 93 | ('relu1', nn.ReLU(inplace=False)), 94 | ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)), 95 | ('bn2', bn(64)), 96 | ('relu2', nn.ReLU(inplace=False)), 97 | ('conv3', nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)), 98 | ('bn3', bn(self.inplanes)), 99 | ('relu3', nn.ReLU(inplace=False))] 100 | )) 101 | else: 102 | self.prefix = nn.Sequential(OrderedDict([ 103 | ('conv1', nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)), 104 | ('bn1', bn(self.inplanes)), 105 | ('relu', nn.ReLU(inplace=False))] 106 | )) 107 | 108 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change. 109 | 110 | self.layer1 = self._make_layer(block, 16, layers[0], norm_type=norm_type) 111 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2, norm_type=norm_type) 112 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2, norm_type=norm_type) 113 | self.layer4 = self._make_layer(block, 128, layers[3], stride=2, norm_type=norm_type) 114 | self.avgpool = nn.AvgPool2d(7, stride=1) 115 | self.fc = nn.Linear(128 * block.expansion, num_classes) 116 | 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 120 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
121 | # elif isinstance(m, ModuleHelper.BatchNorm2d(norm_type=norm_type, ret_cls=True)):
122 | # m.weight.data.fill_(1)
123 | # m.bias.data.zero_()
124 |
125 | def _make_layer(self, block, planes, blocks, stride=1, norm_type=None):
126 | downsample = None
127 | if stride != 1 or self.inplanes != planes * block.expansion:
128 | downsample = nn.Sequential(
129 | nn.Conv2d(self.inplanes, planes * block.expansion,
130 | kernel_size=1, stride=stride, bias=False),
131 | bn(planes * block.expansion),
132 | )
133 |
134 | layers = []
135 | layers.append(block(self.inplanes, planes, stride, downsample, norm_type=norm_type))
136 | self.inplanes = planes * block.expansion
137 | for i in range(1, blocks):
138 | layers.append(block(self.inplanes, planes, norm_type=norm_type))
139 |
140 | return nn.Sequential(*layers)
141 |
142 | def forward(self, x):
143 | # run the stem packaged as self.prefix in __init__ (conv-bn-relu, or
144 | # the deeper three-conv stem when deep_base is set)
145 | x = self.prefix(x)
146 | x = self.maxpool(x)
147 |
148 | x = self.layer1(x)
149 |
150 | x = self.layer2(x)
151 |
152 | x = self.layer3(x)
153 | x = self.layer4(x)
154 |
155 | x = self.avgpool(x)
156 | x = x.view(x.size(0), -1)
157 | x = self.fc(x)
158 |
159 | return x
160 |
161 |
162 | class NormalResnetBackbone(nn.Module):
163 | def __init__(self, orig_resnet):
164 | super(NormalResnetBackbone, self).__init__()
165 |
166 | self.num_features = 512
167 | # take pretrained resnet, except AvgPool and FC
168 | self.prefix = orig_resnet.prefix
169 | self.maxpool = orig_resnet.maxpool
170 | self.layer1 = orig_resnet.layer1
171 | self.layer2 = orig_resnet.layer2
172 | self.layer3 = orig_resnet.layer3
173 | self.layer4 = orig_resnet.layer4
174 |
175 | def get_num_features(self):
176 | return self.num_features
177 |
178 | def forward(self, x):
179 | tuple_features = list()
180 | x = self.prefix(x)
181 | x = self.maxpool(x)
182 | x0 = x
183 | x1 = self.layer1(x)
184 | tuple_features.append(x1)
185 |
186 | x2 = self.layer2(x1)
187 | tuple_features.append(x2)
188 | x3 = self.layer3(x2)
189 | tuple_features.append(x3)
190 | x4 = self.layer4(x3)
191 | tuple_features.append(x4)
192 |
193 | return x0, x1, x2, x3, x4
194 |
195 | def resnet50(**kwargs):
196 | """Constructs a ResNet-50 model.
197 | Args: 198 | pretrained (bool): If True, returns a model pre-trained on Places 199 | """ 200 | model = ResNet(Bottleneck, [3, 4, 6, 3], deep_base=False, **kwargs) 201 | 202 | return model 203 | 204 | 205 | def bn(num_features): 206 | return nn.Sequential( 207 | nn.BatchNorm2d(num_features), 208 | nn.ReLU() 209 | ) 210 | 211 | class PSPModule(nn.Module): 212 | # (1, 2, 3, 6) 213 | def __init__(self, sizes=(1, 3, 6, 8), dimension=2): 214 | super(PSPModule, self).__init__() 215 | self.stages = nn.ModuleList([self._make_stage(size, dimension) for size in sizes]) 216 | 217 | def _make_stage(self, size, dimension=2): 218 | if dimension == 1: 219 | prior = nn.AdaptiveAvgPool1d(output_size=size) 220 | elif dimension == 2: 221 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 222 | elif dimension == 3: 223 | prior = nn.AdaptiveAvgPool3d(output_size=(size, size, size)) 224 | return prior 225 | 226 | def forward(self, feats): 227 | n, c, _, _ = feats.size() 228 | priors = [stage(feats).view(n, c, -1) for stage in self.stages] 229 | center = torch.cat(priors, -1) 230 | return center 231 | 232 | 233 | class _SelfAttentionBlock(nn.Module): 234 | ''' 235 | The basic implementation for self-attention block/non-local block 236 | Input: 237 | N X C X H X W 238 | Parameters: 239 | in_channels : the dimension of the input feature map 240 | key_channels : the dimension after the key/query transform 241 | value_channels : the dimension after the value transform 242 | scale : choose the scale to downsample the input feature maps (save memory cost) 243 | Return: 244 | N X C X H X W 245 | position-aware context features.(w/o concate or add with the input) 246 | ''' 247 | 248 | def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1,psp_size=(1,3,6,8)): 249 | super(_SelfAttentionBlock, self).__init__() 250 | self.scale = scale 251 | self.in_channels = in_channels 252 | self.out_channels = out_channels 253 | self.key_channels = key_channels 254 | self.value_channels = value_channels 255 | if out_channels == None: 256 | self.out_channels = in_channels 257 | self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) 258 | self.f_key = nn.Sequential( 259 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, 260 | kernel_size=1, stride=1, padding=0), 261 | bn(self.key_channels), 262 | # ModuleHelper.BNReLU(self.key_channels, norm_type=norm_type), 263 | ) 264 | self.f_query = self.f_key 265 | self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels, 266 | kernel_size=1, stride=1, padding=0) 267 | self.W = nn.Conv2d(in_channels=self.value_channels, out_channels=self.out_channels, 268 | kernel_size=1, stride=1, padding=0) 269 | 270 | self.psp = PSPModule(psp_size) 271 | nn.init.constant_(self.W.weight, 0) 272 | nn.init.constant_(self.W.bias, 0) 273 | 274 | def forward(self, x): 275 | batch_size, h, w = x.size(0), x.size(2), x.size(3) 276 | if self.scale > 1: 277 | x = self.pool(x) 278 | 279 | value = self.psp(self.f_value(x)) 280 | 281 | query = self.f_query(x).view(batch_size, self.key_channels, -1) 282 | query = query.permute(0, 2, 1) 283 | key = self.f_key(x) 284 | # value=self.psp(value)#.view(batch_size, self.value_channels, -1) 285 | value = value.permute(0, 2, 1) 286 | key = self.psp(key) # .view(batch_size, self.key_channels, -1) 287 | sim_map = torch.matmul(query, key) 288 | sim_map = (self.key_channels ** -.5) * sim_map 289 | sim_map = F.softmax(sim_map, dim=-1) 290 | 291 | context = torch.matmul(sim_map, value) 292 | 
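# shapes: query is (N, H*W, key_ch), key is (N, key_ch, S) and value is
# (N, S, value_ch), where S = 1^2 + 3^2 + 6^2 + 8^2 = 110 anchor positions
# from the PSP pooling with the default psp_size; sim_map is therefore
# (N, H*W, S), and the matmul above yields per-pixel context of shape
# (N, H*W, value_ch), reshaped back to a feature map below.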
context = context.permute(0, 2, 1).contiguous() 293 | context = context.view(batch_size, self.value_channels, *x.size()[2:]) 294 | context = self.W(context) 295 | return context 296 | 297 | 298 | class SelfAttentionBlock2D(_SelfAttentionBlock): 299 | def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1,psp_size=(1,3,6,8)): 300 | super(SelfAttentionBlock2D, self).__init__(in_channels, 301 | key_channels, 302 | value_channels, 303 | out_channels, 304 | scale, 305 | 306 | psp_size=psp_size) 307 | 308 | 309 | class APNB(nn.Module): 310 | """ 311 | Parameters: 312 | in_features / out_features: the channels of the input / output feature maps. 313 | dropout: we choose 0.05 as the default value. 314 | size: you can apply multiple sizes. Here we only use one size. 315 | Return: 316 | features fused with Object context information. 317 | """ 318 | 319 | def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1]), psp_size=(1,3,6,8)): 320 | super(APNB, self).__init__() 321 | self.stages = [] 322 | 323 | self.psp_size=psp_size 324 | self.stages = nn.ModuleList( 325 | [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes]) 326 | self.conv_bn_dropout = nn.Sequential( 327 | nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, padding=0), 328 | # ModuleHelper.BNReLU(out_channels, norm_type=norm_type), 329 | bn(out_channels), 330 | nn.Dropout2d(dropout) 331 | ) 332 | 333 | def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size): 334 | return SelfAttentionBlock2D(in_channels, 335 | key_channels, 336 | value_channels, 337 | output_channels, 338 | size, 339 | 340 | self.psp_size) 341 | 342 | def forward(self, feats): 343 | priors = [stage(feats) for stage in self.stages] 344 | context = priors[0] 345 | for i in range(1, len(priors)): 346 | context += priors[i] 347 | output = self.conv_bn_dropout(torch.cat([context, feats], 1)) 348 | return output 349 | 350 | 351 | def double_conv(in_channels, out_channels): 352 | return nn.Sequential( 353 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 354 | nn.ReLU(inplace=True), 355 | nn.Conv2d(out_channels, out_channels, 3, padding=1), 356 | nn.ReLU(inplace=True) 357 | ) 358 | 359 | 360 | class ENRUNet(nn.Sequential): 361 | def __init__(self,pretrained=False,mode='Train'): 362 | super(ENRUNet, self).__init__() 363 | self.mode=mode 364 | self.backbone = NormalResnetBackbone(resnet50()) 365 | # low_in_channels, high_in_channels, out_channels, key_channels, value_channels, dropout 366 | self.dconv_up4 = double_conv(512+256, 256) 367 | self.dconv_up3 = double_conv(256+128, 128) 368 | self.dconv_up2 = double_conv(128+64, 64) 369 | self.dconv_up1 = double_conv(64 + 16, 64) 370 | self.APNB = nn.Sequential( 371 | APNB(in_channels=64, out_channels=64, key_channels=32, value_channels=32, 372 | dropout=0.05, sizes=([1])) 373 | ) 374 | 375 | self.conv_last = nn.Conv2d(64, 1, 1) 376 | 377 | def forward(self, x_): 378 | x0, x1, x2, x3, x4 = self.backbone(x_) 379 | up4 = F.interpolate(x4, size=(x3.size(2), x3.size(3)), mode="bilinear", align_corners=True) 380 | x = torch.cat([up4, x3], dim=1) 381 | x = self.dconv_up4(x) 382 | up3 = F.interpolate(x, size=(x2.size(2), x2.size(3)), mode="bilinear", align_corners=True) 383 | x = torch.cat([up3, x2], dim=1) 384 | x = self.dconv_up3(x) 385 | up2 = F.interpolate(x, size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True) 386 | x = torch.cat([up2, x1], dim=1) 387 | x = 
self.dconv_up2(x) 388 | up1 = F.interpolate(x, size=(x0.size(2), x0.size(3)), mode="bilinear", align_corners=True) 389 | x = torch.cat([up1, x0], dim=1) 390 | x = self.dconv_up1(x) 391 | x = F.interpolate(x, size=(x_.size(2), x_.size(3)), mode="bilinear", align_corners=True) 392 | x = self.APNB(x) 393 | out = self.conv_last(x) 394 | if self.mode == 'Train': 395 | return F.sigmoid(out) 396 | elif self.mode == 'Infer': 397 | return out -------------------------------------------------------------------------------- /nets/zoo/hrnet.yml: -------------------------------------------------------------------------------- 1 | # HRNET_32 : 2 | FINAL_CONV_KERNEL : 1 3 | STAGE1 : 4 | NUM_MODULES : 1 5 | NUM_BRANCHES : 1 6 | NUM_BLOCKS : [4] 7 | NUM_CHANNELS : [64] 8 | BLOCK : 'BOTTLENECK' 9 | FUSE_METHOD : 'SUM' 10 | STAGE2 : 11 | NUM_MODULES : 1 12 | NUM_BRANCHES : 2 13 | NUM_BLOCKS : [4,4] 14 | NUM_CHANNELS : [32,64] 15 | BLOCK : 'BASIC' 16 | FUSE_METHOD : 'SUM' 17 | STAGE3 : 18 | NUM_MODULES : 4 19 | NUM_BRANCHES : 3 20 | NUM_BLOCKS : [4,4,4] 21 | NUM_CHANNELS : [32,64,128] 22 | BLOCK : 'BASIC' 23 | FUSE_METHOD : 'SUM' 24 | STAGE4 : 25 | NUM_MODULES : 3 26 | NUM_BRANCHES : 4 27 | NUM_BLOCKS : [4,4,4,4] 28 | NUM_CHANNELS : [32,64,128,256] 29 | BLOCK : 'BASIC' 30 | FUSE_METHOD : 'SUM' 31 | 32 | 33 | # HRNET_32 : 34 | # FINAL_CONV_KERNEL : 1 35 | # STAGE1 : 36 | # NUM_MODULES : 1 37 | # NUM_BRANCHES : 1 38 | # NUM_BLOCKS : [4] 39 | # NUM_CHANNELS : [64] 40 | # BLOCK : 'BOTTLENECK' 41 | # FUSE_METHOD : 'SUM' 42 | # STAGE2 : 43 | # NUM_MODULES : 1 44 | # NUM_BRANCHES : 2 45 | # NUM_BLOCKS : [4,4] 46 | # NUM_CHANNELS : [32,64] 47 | # BLOCK : 'BASIC' 48 | # FUSE_METHOD : 'SUM' 49 | # STAGE3 : 50 | # NUM_MODULES : 4 51 | # NUM_BRANCHES : 3 52 | # NUM_BLOCKS : [4,4,4] 53 | # NUM_CHANNELS : [32,64,128] 54 | # BLOCK : 'BASIC' 55 | # FUSE_METHOD : 'SUM' 56 | # STAGE4 : 57 | # NUM_MODULES : 3 58 | # NUM_BRANCHES : 4 59 | # NUM_BLOCKS : [4,4,4,4] 60 | # NUM_CHANNELS : [32,64,128,256] 61 | # BLOCK : 'BASIC' 62 | # FUSE_METHOD : 'SUM' 63 | -------------------------------------------------------------------------------- /nets/zoo/hrnet_config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import sys 3 | import os 4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname("__file__")))) 5 | 6 | def parse(path): 7 | 8 | with open(path, 'r') as f: 9 | config = yaml.safe_load(f) 10 | f.close() 11 | return config 12 | -------------------------------------------------------------------------------- /nets/zoo/resunet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResUnetPlusPlus(nn.Module): 7 | def __init__(self, filters=[32, 64, 128, 256, 512], pretrained=False,mode='Train'): 8 | super(ResUnetPlusPlus, self).__init__() 9 | self.mode=mode 10 | self.input_layer = nn.Sequential( 11 | nn.Conv2d(3, filters[0], kernel_size=3, padding=1), 12 | nn.BatchNorm2d(filters[0]), 13 | nn.ReLU(), 14 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 15 | ) 16 | self.input_skip = nn.Sequential( 17 | nn.Conv2d(3, filters[0], kernel_size=3, padding=1) 18 | ) 19 | 20 | self.squeeze_excite1 = Squeeze_Excite_Block(filters[0]) 21 | 22 | self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1) 23 | 24 | self.squeeze_excite2 = Squeeze_Excite_Block(filters[1]) 25 | 26 | self.residual_conv2 = ResidualConv(filters[1], filters[2], 
2, 1) 27 | 28 | self.squeeze_excite3 = Squeeze_Excite_Block(filters[2]) 29 | 30 | self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1) 31 | 32 | self.aspp_bridge = ASPP(filters[3], filters[4]) 33 | 34 | self.attn1 = AttentionBlock(filters[2], filters[4], filters[4]) 35 | self.upsample1 = Upsample_(2) 36 | self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1) 37 | 38 | self.attn2 = AttentionBlock(filters[1], filters[3], filters[3]) 39 | self.upsample2 = Upsample_(2) 40 | self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1) 41 | 42 | self.attn3 = AttentionBlock(filters[0], filters[2], filters[2]) 43 | self.upsample3 = Upsample_(2) 44 | self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1) 45 | 46 | self.aspp_out = ASPP(filters[1], filters[0]) 47 | 48 | self.output_layer = nn.Sequential(nn.Conv2d(filters[0], 1, 1)) 49 | 50 | def forward(self, x): 51 | x1 = self.input_layer(x) + self.input_skip(x) 52 | 53 | x2 = self.squeeze_excite1(x1) 54 | x2 = self.residual_conv1(x2) 55 | 56 | x3 = self.squeeze_excite2(x2) 57 | x3 = self.residual_conv2(x3) 58 | 59 | x4 = self.squeeze_excite3(x3) 60 | x4 = self.residual_conv3(x4) 61 | 62 | x5 = self.aspp_bridge(x4) 63 | 64 | x6 = self.attn1(x3, x5) 65 | x6 = self.upsample1(x6) 66 | x6 = torch.cat([x6, x3], dim=1) 67 | x6 = self.up_residual_conv1(x6) 68 | 69 | x7 = self.attn2(x2, x6) 70 | x7 = self.upsample2(x7) 71 | x7 = torch.cat([x7, x2], dim=1) 72 | x7 = self.up_residual_conv2(x7) 73 | 74 | x8 = self.attn3(x1, x7) 75 | x8 = self.upsample3(x8) 76 | x8 = torch.cat([x8, x1], dim=1) 77 | x8 = self.up_residual_conv3(x8) 78 | 79 | x9 = self.aspp_out(x8) 80 | out = self.output_layer(x9) 81 | if self.mode == 'Train': 82 | return F.sigmoid(out) 83 | elif self.mode == 'Infer': 84 | return out 85 | class ResidualConv(nn.Module): 86 | def __init__(self, input_dim, output_dim, stride, padding): 87 | super(ResidualConv, self).__init__() 88 | 89 | self.conv_block = nn.Sequential( 90 | nn.BatchNorm2d(input_dim), 91 | nn.ReLU(), 92 | nn.Conv2d( 93 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding 94 | ), 95 | nn.BatchNorm2d(output_dim), 96 | nn.ReLU(), 97 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), 98 | ) 99 | self.conv_skip = nn.Sequential( 100 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), 101 | nn.BatchNorm2d(output_dim), 102 | ) 103 | 104 | def forward(self, x): 105 | 106 | return self.conv_block(x) + self.conv_skip(x) 107 | 108 | 109 | class Upsample(nn.Module): 110 | def __init__(self, input_dim, output_dim, kernel, stride): 111 | super(Upsample, self).__init__() 112 | 113 | self.upsample = nn.ConvTranspose2d( 114 | input_dim, output_dim, kernel_size=kernel, stride=stride 115 | ) 116 | 117 | def forward(self, x): 118 | return self.upsample(x) 119 | 120 | 121 | class Squeeze_Excite_Block(nn.Module): 122 | def __init__(self, channel, reduction=16): 123 | super(Squeeze_Excite_Block, self).__init__() 124 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 125 | self.fc = nn.Sequential( 126 | nn.Linear(channel, channel // reduction, bias=False), 127 | nn.ReLU(inplace=True), 128 | nn.Linear(channel // reduction, channel, bias=False), 129 | nn.Sigmoid(), 130 | ) 131 | 132 | def forward(self, x): 133 | b, c, _, _ = x.size() 134 | y = self.avg_pool(x).view(b, c) 135 | y = self.fc(y).view(b, c, 1, 1) 136 | return x * y.expand_as(x) 137 | 138 | 139 | class ASPP(nn.Module): 140 | def __init__(self, in_dims, out_dims, rate=[6, 
12, 18]): 141 | super(ASPP, self).__init__() 142 | 143 | self.aspp_block1 = nn.Sequential( 144 | nn.Conv2d( 145 | in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] 146 | ), 147 | nn.ReLU(inplace=True), 148 | nn.BatchNorm2d(out_dims), 149 | ) 150 | self.aspp_block2 = nn.Sequential( 151 | nn.Conv2d( 152 | in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] 153 | ), 154 | nn.ReLU(inplace=True), 155 | nn.BatchNorm2d(out_dims), 156 | ) 157 | self.aspp_block3 = nn.Sequential( 158 | nn.Conv2d( 159 | in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] 160 | ), 161 | nn.ReLU(inplace=True), 162 | nn.BatchNorm2d(out_dims), 163 | ) 164 | 165 | self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) 166 | self._init_weights() 167 | 168 | def forward(self, x): 169 | x1 = self.aspp_block1(x) 170 | x2 = self.aspp_block2(x) 171 | x3 = self.aspp_block3(x) 172 | out = torch.cat([x1, x2, x3], dim=1) 173 | return self.output(out) 174 | 175 | def _init_weights(self): 176 | for m in self.modules(): 177 | if isinstance(m, nn.Conv2d): 178 | nn.init.kaiming_normal_(m.weight) 179 | elif isinstance(m, nn.BatchNorm2d): 180 | m.weight.data.fill_(1) 181 | m.bias.data.zero_() 182 | 183 | 184 | class Upsample_(nn.Module): 185 | def __init__(self, scale=2): 186 | super(Upsample_, self).__init__() 187 | 188 | self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) 189 | 190 | def forward(self, x): 191 | return self.upsample(x) 192 | 193 | 194 | class AttentionBlock(nn.Module): 195 | def __init__(self, input_encoder, input_decoder, output_dim): 196 | super(AttentionBlock, self).__init__() 197 | 198 | self.conv_encoder = nn.Sequential( 199 | nn.BatchNorm2d(input_encoder), 200 | nn.ReLU(), 201 | nn.Conv2d(input_encoder, output_dim, 3, padding=1), 202 | nn.MaxPool2d(2, 2), 203 | ) 204 | 205 | self.conv_decoder = nn.Sequential( 206 | nn.BatchNorm2d(input_decoder), 207 | nn.ReLU(), 208 | nn.Conv2d(input_decoder, output_dim, 3, padding=1), 209 | ) 210 | 211 | self.conv_attn = nn.Sequential( 212 | nn.BatchNorm2d(output_dim), 213 | nn.ReLU(), 214 | nn.Conv2d(output_dim, 1, 1), 215 | ) 216 | 217 | def forward(self, x1, x2): 218 | out = self.conv_encoder(x1) + self.conv_decoder(x2) 219 | out = self.conv_attn(out) 220 | return out * x2 -------------------------------------------------------------------------------- /nets/zoo/resunet_BE.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class ResUnetPlusPlus_BE(nn.Module): 7 | def __init__(self, filters=[32, 64, 128, 256, 512], pretrained=False, mode = 'Train'): 8 | super(ResUnetPlusPlus_BE, self).__init__() 9 | self.mode = mode 10 | self.input_layer = nn.Sequential( 11 | nn.Conv2d(3, filters[0], kernel_size=3, padding=1), 12 | nn.BatchNorm2d(filters[0]), 13 | nn.ReLU(), 14 | nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1), 15 | ) 16 | self.input_skip = nn.Sequential( 17 | nn.Conv2d(3, filters[0], kernel_size=3, padding=1) 18 | ) 19 | 20 | self.squeeze_excite1 = Squeeze_Excite_Block(filters[0]) 21 | 22 | self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1) 23 | 24 | self.squeeze_excite2 = Squeeze_Excite_Block(filters[1]) 25 | 26 | self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1) 27 | 28 | self.squeeze_excite3 = Squeeze_Excite_Block(filters[2]) 29 | 30 | self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1) 31 | 32 | self.aspp_bridge = 
ASPP(filters[3], filters[4]) 33 | 34 | self.attn1 = AttentionBlock(filters[2], filters[4], filters[4]) 35 | self.upsample1 = Upsample_(2) 36 | self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1) 37 | 38 | self.attn2 = AttentionBlock(filters[1], filters[3], filters[3]) 39 | self.upsample2 = Upsample_(2) 40 | self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1) 41 | 42 | self.attn3 = AttentionBlock(filters[0], filters[2], filters[2]) 43 | self.upsample3 = Upsample_(2) 44 | self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1) 45 | 46 | self.aspp_out = ASPP(filters[1], filters[0]) 47 | 48 | self.output_layer = nn.Sequential(nn.Conv2d(filters[0], 1, 1)) 49 | 50 | # HED Block 51 | self.dsn1 = nn.Conv2d(filters[0], 1, 1) 52 | self.dsn2 = nn.Conv2d(filters[1], 1, 1) 53 | self.dsn3 = nn.Conv2d(filters[2], 1, 1) 54 | self.dsn4 = nn.Conv2d(filters[3], 1, 1) 55 | self.dsn5 = nn.Conv2d(filters[4], 1, 1) 56 | self.fuse = nn.Sequential(nn.Conv2d(5, 32, 1),nn.ReLU(inplace=True)) 57 | # self.fuse = nn.Conv2d(5, 64, 1) 58 | 59 | self.SE_mimic = nn.Sequential( 60 | nn.Linear(32, 32, bias=False), 61 | nn.ReLU(inplace=True), 62 | nn.Linear(32, 5, bias=False), 63 | nn.Sigmoid() 64 | ) 65 | self.final_boundary = nn.Conv2d(5,2,1) 66 | 67 | self.final_conv = nn.Sequential( 68 | nn.Conv2d(64,64,3, padding=1), 69 | nn.ReLU(inplace=True) 70 | ) 71 | self.final_mask = nn.Conv2d(64,2,1) 72 | 73 | 74 | 75 | self.relu = nn.ReLU() 76 | self.out = nn.Conv2d(64,1,1) 77 | 78 | 79 | 80 | def forward(self, x): 81 | h = x.size(2) 82 | w = x.size(3) 83 | x1 = self.input_layer(x) + self.input_skip(x) 84 | 85 | x2 = self.squeeze_excite1(x1) 86 | x2 = self.residual_conv1(x2) 87 | 88 | x3 = self.squeeze_excite2(x2) 89 | x3 = self.residual_conv2(x3) 90 | 91 | x4 = self.squeeze_excite3(x3) 92 | x4 = self.residual_conv3(x4) 93 | 94 | x5 = self.aspp_bridge(x4) 95 | 96 | x6 = self.attn1(x3, x5) 97 | x6 = self.upsample1(x6) 98 | x6 = torch.cat([x6, x3], dim=1) 99 | x6 = self.up_residual_conv1(x6) 100 | 101 | x7 = self.attn2(x2, x6) 102 | x7 = self.upsample2(x7) 103 | x7 = torch.cat([x7, x2], dim=1) 104 | x7 = self.up_residual_conv2(x7) 105 | 106 | x8 = self.attn3(x1, x7) 107 | x8 = self.upsample3(x8) 108 | x8 = torch.cat([x8, x1], dim=1) 109 | x8 = self.up_residual_conv3(x8) 110 | 111 | xx = self.aspp_out(x8) 112 | # out = self.output_layer(x9) 113 | # out = F.sigmoid(out) 114 | 115 | ## side output 116 | d1 = self.dsn1(x1) 117 | d2 = F.upsample_bilinear(self.dsn2(x2), size=(h,w)) 118 | d3 = F.upsample_bilinear(self.dsn3(x3), size=(h,w)) 119 | d4 = F.upsample_bilinear(self.dsn4(x4), size=(h,w)) 120 | d5 = F.upsample_bilinear(self.dsn5(x5), size=(h,w)) 121 | # 122 | ###########sigmoid ver 123 | d1_out = F.sigmoid(d1) 124 | d2_out = F.sigmoid(d2) 125 | d3_out = F.sigmoid(d3) 126 | d4_out = F.sigmoid(d4) 127 | d5_out = F.sigmoid(d5) 128 | 129 | concat = torch.cat((d1_out, d2_out, d3_out, d4_out, d5_out), 1) 130 | 131 | fuse_box = self.fuse(concat) 132 | GAP = F.adaptive_avg_pool2d(fuse_box,(1,1)) 133 | GAP = GAP.view(-1, 32) 134 | se_like = self.SE_mimic(GAP) 135 | se_like = torch.unsqueeze(se_like, 2) 136 | se_like = torch.unsqueeze(se_like, 3) 137 | 138 | feat_se = concat * se_like.expand_as(concat) 139 | boundary = self.final_boundary(feat_se) 140 | boundary_out = torch.unsqueeze(boundary[:,1,:,:],1) 141 | bd_sftmax = F.softmax(boundary, dim=1) 142 | boundary_scale = torch.unsqueeze(bd_sftmax[:,1,:,:],1) 143 | 144 | feat_concat = torch.cat( [xx, 
fuse_box], 1) 145 | feat_concat_conv = self.final_conv(feat_concat) 146 | mask = self.final_mask(feat_concat_conv) 147 | mask_sftmax = F.softmax(mask,dim=1) 148 | mask_scale = torch.unsqueeze(mask_sftmax[:,1,:,:],1) 149 | 150 | if self.mode == 'Train': 151 | scalefactor = torch.clamp(mask_scale+boundary_scale,0,1) 152 | elif self.mode == 'Infer': 153 | scalefactor = torch.clamp(mask_scale+5*boundary_scale,0,1) 154 | 155 | 156 | mask_out = torch.unsqueeze(mask[:,1,:,:],1) 157 | relu = self.relu(mask_out) 158 | scalar = relu.cpu().detach().numpy() 159 | if np.sum(scalar) == 0: 160 | average = 0 161 | else : 162 | average = scalar[np.nonzero(scalar)].mean() 163 | mask_out = mask_out-relu + (average*scalefactor) 164 | 165 | if self.mode == 'Train': 166 | mask_out = F.sigmoid(mask_out) 167 | boundary_out = F.sigmoid(boundary_out) 168 | 169 | return d1_out, d2_out, d3_out, d4_out, d5_out, boundary_out, mask_out 170 | elif self.mode =='Infer': 171 | return mask_out 172 | 173 | 174 | 175 | class ResidualConv(nn.Module): 176 | def __init__(self, input_dim, output_dim, stride, padding): 177 | super(ResidualConv, self).__init__() 178 | 179 | self.conv_block = nn.Sequential( 180 | nn.BatchNorm2d(input_dim), 181 | nn.ReLU(), 182 | nn.Conv2d( 183 | input_dim, output_dim, kernel_size=3, stride=stride, padding=padding 184 | ), 185 | nn.BatchNorm2d(output_dim), 186 | nn.ReLU(), 187 | nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1), 188 | ) 189 | self.conv_skip = nn.Sequential( 190 | nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1), 191 | nn.BatchNorm2d(output_dim), 192 | ) 193 | 194 | def forward(self, x): 195 | 196 | return self.conv_block(x) + self.conv_skip(x) 197 | 198 | 199 | class Upsample(nn.Module): 200 | def __init__(self, input_dim, output_dim, kernel, stride): 201 | super(Upsample, self).__init__() 202 | 203 | self.upsample = nn.ConvTranspose2d( 204 | input_dim, output_dim, kernel_size=kernel, stride=stride 205 | ) 206 | 207 | def forward(self, x): 208 | return self.upsample(x) 209 | 210 | 211 | class Squeeze_Excite_Block(nn.Module): 212 | def __init__(self, channel, reduction=16): 213 | super(Squeeze_Excite_Block, self).__init__() 214 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 215 | self.fc = nn.Sequential( 216 | nn.Linear(channel, channel // reduction, bias=False), 217 | nn.ReLU(inplace=True), 218 | nn.Linear(channel // reduction, channel, bias=False), 219 | nn.Sigmoid(), 220 | ) 221 | 222 | def forward(self, x): 223 | b, c, _, _ = x.size() 224 | y = self.avg_pool(x).view(b, c) 225 | y = self.fc(y).view(b, c, 1, 1) 226 | return x * y.expand_as(x) 227 | 228 | 229 | class ASPP(nn.Module): 230 | def __init__(self, in_dims, out_dims, rate=[6, 12, 18]): 231 | super(ASPP, self).__init__() 232 | 233 | self.aspp_block1 = nn.Sequential( 234 | nn.Conv2d( 235 | in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0] 236 | ), 237 | nn.ReLU(inplace=True), 238 | nn.BatchNorm2d(out_dims), 239 | ) 240 | self.aspp_block2 = nn.Sequential( 241 | nn.Conv2d( 242 | in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1] 243 | ), 244 | nn.ReLU(inplace=True), 245 | nn.BatchNorm2d(out_dims), 246 | ) 247 | self.aspp_block3 = nn.Sequential( 248 | nn.Conv2d( 249 | in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2] 250 | ), 251 | nn.ReLU(inplace=True), 252 | nn.BatchNorm2d(out_dims), 253 | ) 254 | 255 | self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1) 256 | self._init_weights() 257 | 258 | def forward(self, x): 259 | x1 = self.aspp_block1(x) 
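        # The three aspp_block branches apply the same 3x3 conv at dilation
        # rates 6 / 12 / 18; because padding == dilation, each branch keeps
        # H x W fixed while covering a progressively larger receptive field.
        # Illustrative shapes, assuming this is the ASPP(256, 512) bridge on
        # a [N, 256, 32, 32] input: each branch yields [N, 512, 32, 32], the
        # concat below gives [N, 1536, 32, 32], and the 1x1 self.output conv
        # fuses it back to [N, 512, 32, 32].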
260 | x2 = self.aspp_block2(x) 261 | x3 = self.aspp_block3(x) 262 | out = torch.cat([x1, x2, x3], dim=1) 263 | return self.output(out) 264 | 265 | def _init_weights(self): 266 | for m in self.modules(): 267 | if isinstance(m, nn.Conv2d): 268 | nn.init.kaiming_normal_(m.weight) 269 | elif isinstance(m, nn.BatchNorm2d): 270 | m.weight.data.fill_(1) 271 | m.bias.data.zero_() 272 | 273 | 274 | class Upsample_(nn.Module): 275 | def __init__(self, scale=2): 276 | super(Upsample_, self).__init__() 277 | 278 | self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale) 279 | 280 | def forward(self, x): 281 | return self.upsample(x) 282 | 283 | 284 | class AttentionBlock(nn.Module): 285 | def __init__(self, input_encoder, input_decoder, output_dim): 286 | super(AttentionBlock, self).__init__() 287 | 288 | self.conv_encoder = nn.Sequential( 289 | nn.BatchNorm2d(input_encoder), 290 | nn.ReLU(), 291 | nn.Conv2d(input_encoder, output_dim, 3, padding=1), 292 | nn.MaxPool2d(2, 2), 293 | ) 294 | 295 | self.conv_decoder = nn.Sequential( 296 | nn.BatchNorm2d(input_decoder), 297 | nn.ReLU(), 298 | nn.Conv2d(input_decoder, output_dim, 3, padding=1), 299 | ) 300 | 301 | self.conv_attn = nn.Sequential( 302 | nn.BatchNorm2d(output_dim), 303 | nn.ReLU(), 304 | nn.Conv2d(output_dim, 1, 1), 305 | ) 306 | 307 | def forward(self, x1, x2): 308 | out = self.conv_encoder(x1) + self.conv_decoder(x2) 309 | out = self.conv_attn(out) 310 | return out * x2 -------------------------------------------------------------------------------- /nets/zoo/ternaus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torchvision import models 6 | 7 | 8 | def conv3x3(in_: int, out: int) -> nn.Module: 9 | return nn.Conv2d(in_, out, 3, padding=1) 10 | 11 | 12 | class ConvRelu(nn.Module): 13 | def __init__(self, in_: int, out: int) -> None: 14 | super().__init__() 15 | self.conv = conv3x3(in_, out) 16 | self.activation = nn.ReLU(inplace=True) 17 | 18 | def forward(self, x: torch.Tensor) -> torch.Tensor: 19 | x = self.conv(x) 20 | x = self.activation(x) 21 | return x 22 | 23 | 24 | class DecoderBlock(nn.Module): 25 | def __init__( 26 | self, in_channels: int, middle_channels: int, out_channels: int 27 | ) -> None: 28 | super().__init__() 29 | 30 | self.block = nn.Sequential( 31 | ConvRelu(in_channels, middle_channels), 32 | nn.ConvTranspose2d( 33 | middle_channels, 34 | out_channels, 35 | kernel_size=3, 36 | stride=2, 37 | padding=1, 38 | output_padding=1, 39 | ), 40 | nn.ReLU(inplace=True), 41 | ) 42 | 43 | def forward(self, x: torch.Tensor) -> torch.Tensor: 44 | return self.block(x) 45 | 46 | 47 | class ternaus11(nn.Module): 48 | def __init__(self, num_filters: int = 32, pretrained: bool = False,mode='Train') -> None: 49 | """ 50 | Args: 51 | num_filters: 52 | pretrained: 53 | False - no pre-trained network is used 54 | True - encoder is pre-trained with VGG11 55 | """ 56 | super().__init__() 57 | self.pool = nn.MaxPool2d(2, 2) 58 | self.mode=mode 59 | self.encoder = models.vgg11(pretrained=pretrained).features 60 | 61 | self.relu = self.encoder[1] 62 | self.conv1 = self.encoder[0] 63 | self.conv2 = self.encoder[3] 64 | self.conv3s = self.encoder[6] 65 | self.conv3 = self.encoder[8] 66 | self.conv4s = self.encoder[11] 67 | self.conv4 = self.encoder[13] 68 | self.conv5s = self.encoder[16] 69 | self.conv5 = self.encoder[18] 70 | 71 | self.center = DecoderBlock( 72 | num_filters * 8 * 2, 
num_filters * 8 * 2, num_filters * 8 73 | ) 74 | self.dec5 = DecoderBlock( 75 | num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8 76 | ) 77 | self.dec4 = DecoderBlock( 78 | num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4 79 | ) 80 | self.dec3 = DecoderBlock( 81 | num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2 82 | ) 83 | self.dec2 = DecoderBlock( 84 | num_filters * (4 + 2), num_filters * 2 * 2, num_filters 85 | ) 86 | self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters) 87 | 88 | self.final = nn.Conv2d(num_filters, 1, kernel_size=1) 89 | 90 | def forward(self, x: torch.Tensor) -> torch.Tensor: 91 | conv1 = self.relu(self.conv1(x)) 92 | conv2 = self.relu(self.conv2(self.pool(conv1))) 93 | conv3s = self.relu(self.conv3s(self.pool(conv2))) 94 | conv3 = self.relu(self.conv3(conv3s)) 95 | conv4s = self.relu(self.conv4s(self.pool(conv3))) 96 | conv4 = self.relu(self.conv4(conv4s)) 97 | conv5s = self.relu(self.conv5s(self.pool(conv4))) 98 | conv5 = self.relu(self.conv5(conv5s)) 99 | 100 | center = self.center(self.pool(conv5)) 101 | 102 | dec5 = self.dec5(torch.cat([center, conv5], 1)) 103 | dec4 = self.dec4(torch.cat([dec5, conv4], 1)) 104 | dec3 = self.dec3(torch.cat([dec4, conv3], 1)) 105 | dec2 = self.dec2(torch.cat([dec3, conv2], 1)) 106 | dec1 = self.dec1(torch.cat([dec2, conv1], 1)) 107 | out = self.final(dec1) 108 | if self.mode == 'Train': 109 | return F.sigmoid(out) 110 | elif self.mode == 'Infer': 111 | return out 112 | 113 | class Interpolate(nn.Module): 114 | def __init__( 115 | self, 116 | size: int = None, 117 | scale_factor: int = None, 118 | mode: str = "nearest", 119 | align_corners: bool = False, 120 | ): 121 | super().__init__() 122 | self.interp = nn.functional.interpolate 123 | self.size = size 124 | self.mode = mode 125 | self.scale_factor = scale_factor 126 | self.align_corners = align_corners 127 | 128 | def forward(self, x: torch.Tensor) -> torch.Tensor: 129 | x = self.interp( 130 | x, 131 | size=self.size, 132 | scale_factor=self.scale_factor, 133 | mode=self.mode, 134 | align_corners=self.align_corners, 135 | ) 136 | return x 137 | 138 | 139 | class DecoderBlockV2(nn.Module): 140 | def __init__( 141 | self, 142 | in_channels: int, 143 | middle_channels: int, 144 | out_channels: int, 145 | is_deconv: bool = True, 146 | ): 147 | super().__init__() 148 | self.in_channels = in_channels 149 | 150 | if is_deconv: 151 | """ 152 | Paramaters for Deconvolution were chosen to avoid artifacts, following 153 | link https://distill.pub/2016/deconv-checkerboard/ 154 | """ 155 | 156 | self.block = nn.Sequential( 157 | ConvRelu(in_channels, middle_channels), 158 | nn.ConvTranspose2d( 159 | middle_channels, out_channels, kernel_size=4, stride=2, padding=1 160 | ), 161 | nn.ReLU(inplace=True), 162 | ) 163 | else: 164 | self.block = nn.Sequential( 165 | Interpolate(scale_factor=2, mode="bilinear"), 166 | ConvRelu(in_channels, middle_channels), 167 | ConvRelu(middle_channels, out_channels), 168 | ) 169 | 170 | def forward(self, x: torch.Tensor) -> torch.Tensor: 171 | return self.block(x) -------------------------------------------------------------------------------- /nets/zoo/ternaus_BE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torchvision import models 6 | import numpy as np 7 | 8 | def conv3x3(in_: int, out: int) -> nn.Module: 9 | return nn.Conv2d(in_, out, 3, padding=1) 10 | 11 | 12 | 
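# The *_BE models in this zoo share one boundary-enhancement idea: HED-style
# side outputs are fused into per-pixel boundary and mask probabilities, and
# the boundary evidence is folded back into the mask logits through a clamped
# scale factor. A minimal sketch of that blending step follows; the helper
# name is illustrative only and the model itself does not call it.
def _blend_mask_and_boundary_sketch(mask_scale: torch.Tensor,
                                    boundary_scale: torch.Tensor,
                                    boundary_weight: float = 1.0) -> torch.Tensor:
    # mask_scale and boundary_scale are foreground softmax maps of shape
    # [N, 1, H, W]; forward() below uses boundary_weight 1 in 'Train' mode
    # and 5 in 'Infer' mode, letting boundary evidence dominate at test time.
    return torch.clamp(mask_scale + boundary_weight * boundary_scale, 0, 1)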
class ConvRelu(nn.Module): 13 | def __init__(self, in_: int, out: int) -> None: 14 | super().__init__() 15 | self.conv = conv3x3(in_, out) 16 | self.activation = nn.ReLU(inplace=True) 17 | 18 | def forward(self, x: torch.Tensor) -> torch.Tensor: 19 | x = self.conv(x) 20 | x = self.activation(x) 21 | return x 22 | 23 | 24 | class DecoderBlock(nn.Module): 25 | def __init__( 26 | self, in_channels: int, middle_channels: int, out_channels: int 27 | ) -> None: 28 | super().__init__() 29 | 30 | self.block = nn.Sequential( 31 | ConvRelu(in_channels, middle_channels), 32 | nn.ConvTranspose2d( 33 | middle_channels, 34 | out_channels, 35 | kernel_size=3, 36 | stride=2, 37 | padding=1, 38 | output_padding=1, 39 | ), 40 | nn.ReLU(inplace=True), 41 | ) 42 | 43 | def forward(self, x: torch.Tensor) -> torch.Tensor: 44 | return self.block(x) 45 | 46 | 47 | class ternaus_BE(nn.Module): 48 | def __init__(self, num_filters: int = 32, pretrained: bool = False, mode='Train') -> None: 49 | """ 50 | Args: 51 | num_filters: 52 | pretrained: 53 | False - no pre-trained network is used 54 | True - encoder is pre-trained with VGG11 55 | """ 56 | super().__init__() 57 | self.pool = nn.MaxPool2d(2, 2) 58 | self.mode = mode 59 | self.encoder = models.vgg11(pretrained=pretrained).features 60 | 61 | self.relu = self.encoder[1] 62 | self.conv1 = self.encoder[0] 63 | self.conv2 = self.encoder[3] 64 | self.conv3s = self.encoder[6] 65 | self.conv3 = self.encoder[8] 66 | self.conv4s = self.encoder[11] 67 | self.conv4 = self.encoder[13] 68 | self.conv5s = self.encoder[16] 69 | self.conv5 = self.encoder[18] 70 | self.conv6 = ConvRelu(num_filters * 8 * 2, num_filters * 8 * 2) 71 | self.decoder6 = nn.ConvTranspose2d(num_filters * 8 * 2, 72 | num_filters * 8, 73 | kernel_size=3, 74 | stride=2, 75 | padding=1, 76 | output_padding=1,) 77 | 78 | self.center = DecoderBlock( 79 | num_filters * 8 * 2, num_filters * 8 * 2, num_filters * 8 80 | ) 81 | self.dec5 = DecoderBlock( 82 | num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8 83 | ) 84 | self.dec4 = DecoderBlock( 85 | num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4 86 | ) 87 | self.dec3 = DecoderBlock( 88 | num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2 89 | ) 90 | self.dec2 = DecoderBlock( 91 | num_filters * (4 + 2), num_filters * 2 * 2, num_filters 92 | ) 93 | self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters) 94 | 95 | self.final = nn.Conv2d(num_filters, 1, kernel_size=1) 96 | 97 | # HED Block 98 | self.dsn1 = nn.Conv2d(num_filters*2, 1, 1) 99 | self.dsn2 = nn.Conv2d(num_filters*4, 1, 1) 100 | self.dsn3 = nn.Conv2d(num_filters*8, 1, 1) 101 | self.dsn4 = nn.Conv2d(num_filters*16, 1, 1) 102 | self.dsn5 = nn.Conv2d(num_filters*16, 1, 1) 103 | self.dsn6 = nn.Conv2d(num_filters*8, 1, 1) 104 | self.fuse = nn.Sequential(nn.Conv2d(6, 32, 1),nn.ReLU(inplace=True)) 105 | # self.fuse = nn.Conv2d(5, 64, 1) 106 | 107 | self.SE_mimic = nn.Sequential( 108 | nn.Linear(32, 32, bias=False), 109 | nn.ReLU(inplace=True), 110 | nn.Linear(32, 6, bias=False), 111 | nn.Sigmoid() 112 | ) 113 | self.final_boundary = nn.Conv2d(6,2,1) 114 | 115 | self.final_conv = nn.Sequential( 116 | nn.Conv2d(64,64,3, padding=1), 117 | nn.ReLU(inplace=True) 118 | ) 119 | self.final_mask = nn.Conv2d(64,2,1) 120 | 121 | 122 | 123 | self.relu = nn.ReLU()  # rebinds self.relu (previously the VGG ReLU); both are plain ReLUs, so the encoder path is unaffected 124 | self.out = nn.Conv2d(64,1,1) 125 | 126 | 127 | def forward(self, x: torch.Tensor) -> torch.Tensor: 128 | h = x.size(2) 129 | w = x.size(3) 130 | conv1 = self.relu(self.conv1(x)) 131 | conv1p = self.pool(conv1) 132 | conv2 =
self.relu(self.conv2(conv1p)) 133 | conv2p = self.pool(conv2) 134 | conv3s = self.relu(self.conv3s(conv2p)) 135 | conv3 = self.relu(self.conv3(conv3s)) 136 | conv3p = self.pool(conv3) 137 | conv4s = self.relu(self.conv4s(conv3p)) 138 | conv4 = self.relu(self.conv4(conv4s)) 139 | conv4p = self.pool(conv4) 140 | conv5s = self.relu(self.conv5s(conv4p)) 141 | conv5 = self.relu(self.conv5(conv5s)) 142 | conv5p = self.pool(conv5) 143 | 144 | # center = self.center(conv5p) 145 | conv6s = self.conv6(conv5p) 146 | conv6 = self.relu(self.decoder6(conv6s)) 147 | center = conv6 148 | dec5 = self.dec5(torch.cat([center, conv5], 1)) 149 | dec4 = self.dec4(torch.cat([dec5, conv4], 1)) 150 | dec3 = self.dec3(torch.cat([dec4, conv3], 1)) 151 | dec2 = self.dec2(torch.cat([dec3, conv2], 1)) 152 | dec1 = self.dec1(torch.cat([dec2, conv1], 1)) 153 | xx = dec1 154 | # xx = self.final(dec1) 155 | # out = F.sigmoid(out) 156 | 157 | ## side output 158 | d1 = self.dsn1(conv1) 159 | d2 = F.upsample_bilinear(self.dsn2(conv2), size=(h,w)) 160 | d3 = F.upsample_bilinear(self.dsn3(conv3), size=(h,w)) 161 | d4 = F.upsample_bilinear(self.dsn4(conv4), size=(h,w)) 162 | d5 = F.upsample_bilinear(self.dsn5(conv5), size=(h,w)) 163 | 164 | d6 = F.upsample_bilinear(self.dsn6(conv6), size=(h,w)) 165 | # 166 | ###########sigmoid ver 167 | d1_out = F.sigmoid(d1) 168 | d2_out = F.sigmoid(d2) 169 | d3_out = F.sigmoid(d3) 170 | d4_out = F.sigmoid(d4) 171 | d5_out = F.sigmoid(d5) 172 | d6_out = F.sigmoid(d6) 173 | 174 | # concat = torch.cat((d1_out, d2_out, d3_out, d4_out, d5_out), 1) 175 | concat = torch.cat((d1_out, d2_out, d3_out, d4_out, d5_out,d6_out ), 1) 176 | 177 | fuse_box = self.fuse(concat) 178 | GAP = F.adaptive_avg_pool2d(fuse_box,(1,1)) 179 | GAP = GAP.view(-1, 32) 180 | se_like = self.SE_mimic(GAP) 181 | se_like = torch.unsqueeze(se_like, 2) 182 | se_like = torch.unsqueeze(se_like, 3) 183 | 184 | feat_se = concat * se_like.expand_as(concat) 185 | boundary = self.final_boundary(feat_se) 186 | boundary_out = torch.unsqueeze(boundary[:,1,:,:],1) 187 | bd_sftmax = F.softmax(boundary, dim=1) 188 | boundary_scale = torch.unsqueeze(bd_sftmax[:,1,:,:],1) 189 | 190 | feat_concat = torch.cat( [xx, fuse_box], 1) 191 | feat_concat_conv = self.final_conv(feat_concat) 192 | mask = self.final_mask(feat_concat_conv) 193 | mask_sftmax = F.softmax(mask,dim=1) 194 | mask_scale = torch.unsqueeze(mask_sftmax[:,1,:,:],1) 195 | 196 | if self.mode == 'Train': 197 | scalefactor = torch.clamp(mask_scale+boundary_scale,0,1) 198 | elif self.mode == 'Infer': 199 | scalefactor = torch.clamp(mask_scale+5*boundary_scale,0,1) 200 | 201 | 202 | mask_out = torch.unsqueeze(mask[:,1,:,:],1) 203 | relu = self.relu(mask_out) 204 | scalar = relu.cpu().detach().numpy() 205 | if np.sum(scalar) == 0: 206 | average = 0 207 | else : 208 | average = scalar[np.nonzero(scalar)].mean() 209 | mask_out = mask_out-relu + (average*scalefactor) 210 | 211 | if self.mode == 'Train': 212 | mask_out = F.sigmoid(mask_out) 213 | boundary_out = F.sigmoid(boundary_out) 214 | 215 | return d1_out, d2_out, d3_out, d4_out, d5_out, d6_out, boundary_out, mask_out 216 | elif self.mode =='Infer': 217 | return mask_out 218 | 219 | return out 220 | 221 | 222 | class Interpolate(nn.Module): 223 | def __init__( 224 | self, 225 | size: int = None, 226 | scale_factor: int = None, 227 | mode: str = "nearest", 228 | align_corners: bool = False, 229 | ): 230 | super().__init__() 231 | self.interp = nn.functional.interpolate 232 | self.size = size 233 | self.mode = mode 234 | self.scale_factor = 
scale_factor 235 | self.align_corners = align_corners 236 | 237 | def forward(self, x: torch.Tensor) -> torch.Tensor: 238 | x = self.interp( 239 | x, 240 | size=self.size, 241 | scale_factor=self.scale_factor, 242 | mode=self.mode, 243 | align_corners=self.align_corners, 244 | ) 245 | return x 246 | 247 | 248 | class DecoderBlockV2(nn.Module): 249 | def __init__( 250 | self, 251 | in_channels: int, 252 | middle_channels: int, 253 | out_channels: int, 254 | is_deconv: bool = True, 255 | ): 256 | super().__init__() 257 | self.in_channels = in_channels 258 | 259 | if is_deconv: 260 | """ 261 | Paramaters for Deconvolution were chosen to avoid artifacts, following 262 | link https://distill.pub/2016/deconv-checkerboard/ 263 | """ 264 | 265 | self.block = nn.Sequential( 266 | ConvRelu(in_channels, middle_channels), 267 | nn.ConvTranspose2d( 268 | middle_channels, out_channels, kernel_size=4, stride=2, padding=1 269 | ), 270 | nn.ReLU(inplace=True), 271 | ) 272 | else: 273 | self.block = nn.Sequential( 274 | Interpolate(scale_factor=2, mode="bilinear"), 275 | ConvRelu(in_channels, middle_channels), 276 | ConvRelu(middle_channels, out_channels), 277 | ) 278 | 279 | def forward(self, x: torch.Tensor) -> torch.Tensor: 280 | return self.block(x) -------------------------------------------------------------------------------- /nets/zoo/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def double_conv(in_channels, out_channels): 6 | return nn.Sequential( 7 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 8 | nn.ReLU(inplace=True), 9 | nn.Conv2d(out_channels, out_channels, 3, padding=1), 10 | nn.ReLU(inplace=True) 11 | ) 12 | 13 | class _up_deconv(nn.Module): 14 | def __init__(self, in_channels, out_channels): 15 | super(_up_deconv, self).__init__() 16 | 17 | self.deconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1) 18 | self.bn_i = nn.BatchNorm2d(num_features=out_channels) 19 | self.relu = nn.ReLU() 20 | 21 | def forward(self, x): 22 | 23 | out = self.bn_i(self.deconv(x)) 24 | out = self.relu(out) 25 | 26 | return out 27 | 28 | class UNet(nn.Module): 29 | 30 | def __init__(self, n_class=1, pretrained=False, mode='Train'): 31 | super().__init__() 32 | self.mode=mode 33 | self.dconv_down1 = double_conv(3, 64) 34 | self.dconv_down2 = double_conv(64, 128) 35 | self.dconv_down3 = double_conv(128, 256) 36 | self.dconv_down4 = double_conv(256, 512) 37 | self.dconv_down5 = double_conv(512, 1024) 38 | 39 | self.maxpool = nn.MaxPool2d(2) 40 | # self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 41 | self.upsample5 = _up_deconv(1024,512) 42 | self.upsample4 = _up_deconv(512,256) 43 | self.upsample3 = _up_deconv(256,128) 44 | self.upsample2 = _up_deconv(128,64) 45 | self.dconv_up4 = double_conv(512 + 512, 512) 46 | self.dconv_up3 = double_conv(256 + 256, 256) 47 | self.dconv_up2 = double_conv(128 + 128, 128) 48 | self.dconv_up1 = double_conv(64 + 64, 64) 49 | 50 | self.conv_last = nn.Conv2d(64, n_class, 1) 51 | 52 | 53 | def forward(self, x): 54 | 55 | conv1 = self.dconv_down1(x) 56 | x = self.maxpool(conv1) 57 | 58 | conv2 = self.dconv_down2(x) 59 | x = self.maxpool(conv2) 60 | 61 | conv3 = self.dconv_down3(x) 62 | x = self.maxpool(conv3) 63 | 64 | conv4 = self.dconv_down4(x) 65 | x = self.maxpool(conv4) 66 | 67 | x = self.dconv_down5(x) 68 | x = self.upsample5(x) 69 | x = 
torch.cat([x, conv4],1) 70 | 71 | x = self.dconv_up4(x) 72 | x = self.upsample4(x) 73 | x = torch.cat([x, conv3], dim=1) 74 | 75 | x = self.dconv_up3(x) 76 | x = self.upsample3(x) 77 | x = torch.cat([x, conv2], dim=1) 78 | 79 | x = self.dconv_up2(x) 80 | x = self.upsample2(x) 81 | x = torch.cat([x, conv1], dim=1) 82 | 83 | x = self.dconv_up1(x) 84 | 85 | out = self.conv_last(x) 86 | 87 | if self.mode == 'Train': 88 | return F.sigmoid(out) 89 | elif self.mode == 'Infer': 90 | return out -------------------------------------------------------------------------------- /nets/zoo/unet_BE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import skimage 4 | import numpy as np 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | 8 | 9 | def double_conv(in_channels, out_channels): 10 | return nn.Sequential( 11 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 12 | nn.ReLU(inplace=True), 13 | nn.Conv2d(out_channels, out_channels, 3, padding=1), 14 | nn.ReLU(inplace=True) 15 | ) 16 | 17 | class _up_deconv(nn.Module): 18 | def __init__(self, in_channels, out_channels): 19 | super(_up_deconv, self).__init__() 20 | 21 | self.deconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1) 22 | 23 | self.relu = nn.ReLU() 24 | 25 | def forward(self, x): 26 | 27 | out = self.deconv(x) 28 | out = self.relu(out) 29 | 30 | return out 31 | 32 | class UNet_BE(nn.Module): 33 | 34 | def __init__(self, n_class=1, pretrained=False, mode='Train'): 35 | super().__init__() 36 | self.mode = mode 37 | n_channels = 32 38 | 39 | 40 | self.dconv_down1 = double_conv(3, 64) 41 | self.dconv_down2 = double_conv(64, 128) 42 | self.dconv_down3 = double_conv(128, 256) 43 | self.dconv_down4 = double_conv(256, 512) 44 | self.dconv_down5 = double_conv(512, 1024) 45 | self.maxpool = nn.MaxPool2d(2) 46 | self.upsample5 = _up_deconv(1024,512) 47 | self.upsample4 = _up_deconv(512,256) 48 | self.upsample3 = _up_deconv(256,128) 49 | self.upsample2 = _up_deconv(128,64) 50 | self.dconv_up4 = double_conv(512 + 512, 512) 51 | self.dconv_up3 = double_conv(256 + 256, 256) 52 | self.dconv_up2 = double_conv(128 + 128, 128) 53 | self.dconv_up1 = double_conv(64 + 64, 64) 54 | # HED Block 55 | self.dsn1 = nn.Conv2d(64, 1, 1) 56 | self.dsn2 = nn.Conv2d(128, 1, 1) 57 | self.dsn3 = nn.Conv2d(256, 1, 1) 58 | self.dsn4 = nn.Conv2d(512, 1, 1) 59 | self.dsn5 = nn.Conv2d(1024, 1, 1) 60 | 61 | #boundary enhancement part 62 | self.fuse = nn.Sequential(nn.Conv2d(5, 64, 1),nn.ReLU(inplace=True)) 63 | self.SE_mimic = nn.Sequential( 64 | nn.Linear(64, 64, bias=False), 65 | nn.ReLU(inplace=True), 66 | nn.Linear(64, 5, bias=False), 67 | nn.Sigmoid() 68 | ) 69 | self.final_boundary = nn.Conv2d(5,2,1) 70 | self.final_conv = nn.Sequential( 71 | nn.Conv2d(128,64,3, padding=1), 72 | nn.ReLU(inplace=True) 73 | ) 74 | self.final_mask = nn.Conv2d(64,2,1) 75 | self.relu = nn.ReLU() 76 | self.out = nn.Conv2d(64,1,1) 77 | 78 | 79 | def forward(self, x): 80 | h = x.size(2) 81 | w = x.size(3) 82 | 83 | 84 | conv1 = self.dconv_down1(x) 85 | x = self.maxpool(conv1) 86 | conv2 = self.dconv_down2(x) 87 | x = self.maxpool(conv2) 88 | conv3 = self.dconv_down3(x) 89 | x = self.maxpool(conv3) 90 | conv4 = self.dconv_down4(x) 91 | x = self.maxpool(conv4) 92 | conv5 = self.dconv_down5(x) 93 | x = self.upsample5(conv5) 94 | x = torch.cat([x, conv4],1) 95 | x = self.dconv_up4(x) 96 | x = self.upsample4(x) 97 | 
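        # Decoder pattern used throughout UNet_BE: each upsampled feature map
        # is concatenated channel-wise with the matching encoder feature before
        # the next double_conv. At this step, upsample4 leaves x at 256
        # channels, conv3 contributes 256 more, and dconv_up3 maps 512 -> 256.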
x = torch.cat([x, conv3], dim=1) 98 | x = self.dconv_up3(x) 99 | x = self.upsample3(x) 100 | x = torch.cat([x, conv2], dim=1) 101 | x = self.dconv_up2(x) 102 | x = self.upsample2(x) 103 | x = torch.cat([x, conv1], dim=1) 104 | x = self.dconv_up1(x) 105 | # out = F.sigmoid(self.out(x)) 106 | 107 | 108 | ## side output 109 | d1 = self.dsn1(conv1) 110 | d2 = F.upsample_bilinear(self.dsn2(conv2), size=(h,w)) 111 | d3 = F.upsample_bilinear(self.dsn3(conv3), size=(h,w)) 112 | d4 = F.upsample_bilinear(self.dsn4(conv4), size=(h,w)) 113 | d5 = F.upsample_bilinear(self.dsn5(conv5), size=(h,w)) 114 | 115 | d1_out = F.sigmoid(d1) 116 | d2_out = F.sigmoid(d2) 117 | d3_out = F.sigmoid(d3) 118 | d4_out = F.sigmoid(d4) 119 | d5_out = F.sigmoid(d5) 120 | concat = torch.cat((d1_out, d2_out, d3_out, d4_out, d5_out), 1) 121 | 122 | fuse_box = self.fuse(concat) 123 | GAP = F.adaptive_avg_pool2d(fuse_box,(1,1)) 124 | GAP = GAP.view(-1, 64) 125 | se_like = self.SE_mimic(GAP) 126 | se_like = torch.unsqueeze(se_like, 2) 127 | se_like = torch.unsqueeze(se_like, 3) 128 | 129 | feat_se = concat * se_like.expand_as(concat) 130 | boundary = self.final_boundary(feat_se) 131 | boundary_out = torch.unsqueeze(boundary[:,1,:,:],1) 132 | bd_sftmax = F.softmax(boundary, dim=1) 133 | boundary_scale = torch.unsqueeze(bd_sftmax[:,1,:,:],1) 134 | 135 | feat_concat = torch.cat( [x, fuse_box], 1) 136 | feat_concat_conv = self.final_conv(feat_concat) 137 | mask = self.final_mask(feat_concat_conv) 138 | mask_sftmax = F.softmax(mask,dim=1) 139 | mask_scale = torch.unsqueeze(mask_sftmax[:,1,:,:],1) 140 | 141 | if self.mode == 'Train': 142 | scalefactor = torch.clamp(mask_scale+boundary_scale,0,1) 143 | elif self.mode == 'Infer': 144 | scalefactor = torch.clamp(mask_scale+5*boundary_scale,0,1) 145 | 146 | 147 | mask_out = torch.unsqueeze(mask[:,1,:,:],1) 148 | relu = self.relu(mask_out) 149 | scalar = relu.cpu().detach().numpy() 150 | if np.sum(scalar) == 0: 151 | average = 0 152 | else : 153 | average = scalar[np.nonzero(scalar)].mean() 154 | mask_out = mask_out-relu + (average*scalefactor) 155 | 156 | if self.mode == 'Train': 157 | mask_out = F.sigmoid(mask_out) 158 | boundary_out = F.sigmoid(boundary_out) 159 | 160 | return d1_out, d2_out, d3_out, d4_out, d5_out, boundary_out, mask_out 161 | elif self.mode =='Infer': 162 | return mask_out 163 | 164 | 165 | 166 | # 167 | -------------------------------------------------------------------------------- /nets/zoo/uspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class _stage_block(nn.Module): 7 | def __init__(self, channel_var): 8 | super(_stage_block, self).__init__() 9 | 10 | channel_in, channel_out = channel_var 11 | 12 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=3, stride=1, padding=1) 13 | self.bn = nn.BatchNorm2d(channel_out) 14 | self.relu = nn.ReLU() 15 | 16 | def forward(self, x): 17 | out = self.bn( self.conv(x) ) 18 | out = self.relu(out) 19 | return out 20 | 21 | 22 | class _upss_block(nn.Module): 23 | def __init__(self, channel_in): 24 | super(_upss_block, self).__init__() 25 | self.conv1 = nn.Sequential( 26 | nn.MaxPool2d(1), 27 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=1, stride=1, padding=0), 28 | ) 29 | self.conv2 = nn.Sequential( 30 | nn.MaxPool2d(2), 31 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=2, stride=1, 
padding=1), 32 | ) 33 | self.conv3 = nn.Sequential( 34 | nn.MaxPool2d(3), 35 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=3, stride=1, padding=1), 36 | ) 37 | self.conv4 = nn.Sequential( 38 | nn.MaxPool2d(6), 39 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=4, stride=1, padding=2), 40 | ) 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | h, w = x.size(2), x.size(3) 46 | 47 | out1 = self.conv1(x) 48 | out1 = F.upsample(input=out1, size=(h, w), mode='bilinear') 49 | out2 = self.conv2(x) 50 | out2 = F.upsample(input=out2, size=(h, w), mode='bilinear') 51 | out3 = self.conv3(x) 52 | out3 = F.upsample(input=out3, size=(h, w), mode='bilinear') 53 | out4 = self.conv4(x) 54 | out4 = F.upsample(input=out4, size=(h, w), mode='bilinear') 55 | 56 | out = torch.cat([out1, out2, out3, out4, residual], 1) 57 | return out 58 | 59 | 60 | class _down(nn.Module): 61 | def __init__(self, channel_in): 62 | super(_down, self).__init__() 63 | self.maxpool = nn.MaxPool2d(2) 64 | 65 | def forward(self, x): 66 | out = self.maxpool(x) 67 | return out 68 | 69 | 70 | class _up(nn.Module): 71 | def __init__(self, channel_in): 72 | super(_up, self).__init__() 73 | 74 | #self.relu = nn.PReLU() 75 | #self.subpixel = nn.PixelShuffle(2) 76 | self.subpixel = nn.ConvTranspose2d(in_channels=channel_in, out_channels=int(channel_in/2.), kernel_size=2, stride=2) 77 | #self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_in, kernel_size=1, stride=1, padding=0) 78 | 79 | def forward(self, x): 80 | #out = self.relu(self.conv(x)) 81 | #out = self.subpixel(out) 82 | out = self.subpixel(x) 83 | return out 84 | 85 | 86 | class Uspp(nn.Module): 87 | def __init__(self, pretrained=False,mode='Train'): 88 | super(Uspp, self).__init__() 89 | self.mode=mode 90 | self.DCR_block11 = self.make_layer(_stage_block, [ 3, 64]) 91 | self.DCR_block12 = self.make_layer(_stage_block, [ 64, 64]) 92 | self.down1 = self.make_layer(_down, 64) 93 | self.DCR_block21 = self.make_layer(_stage_block, [ 64,128]) 94 | self.DCR_block22 = self.make_layer(_stage_block, [128,128]) 95 | self.down2 = self.make_layer(_down, 128) 96 | self.DCR_block31 = self.make_layer(_stage_block, [128,256]) 97 | self.DCR_block32 = self.make_layer(_stage_block, [256,256]) 98 | self.down3 = self.make_layer(_down, 256) 99 | self.DCR_block41 = self.make_layer(_stage_block, [256,512]) 100 | self.DCR_block42 = self.make_layer(_stage_block, [512,512]) 101 | self.down4 = self.make_layer(_down, 512) 102 | 103 | self.uspp = self.make_layer(_upss_block, 512) 104 | 105 | self.up4 = self.make_layer(_up, 1024) 106 | self.DCR_block43 = self.make_layer(_stage_block,[1024,512]) 107 | self.DCR_block44 = self.make_layer(_stage_block, [512,512]) 108 | self.up3 = self.make_layer(_up, 512) 109 | self.DCR_block33 = self.make_layer(_stage_block, [512,256]) 110 | self.DCR_block34 = self.make_layer(_stage_block, [256,256]) 111 | self.up2 = self.make_layer(_up, 256) 112 | self.DCR_block23 = self.make_layer(_stage_block, [256,128]) 113 | self.DCR_block24 = self.make_layer(_stage_block, [128,128]) 114 | self.up1 = self.make_layer(_up, 128) 115 | self.DCR_block13 = self.make_layer(_stage_block, [128, 64]) 116 | self.DCR_block14 = self.make_layer(_stage_block, [ 64, 1]) 117 | 118 | def make_layer(self, block, channel_in): 119 | layers = [] 120 | layers.append(block(channel_in)) 121 | return nn.Sequential(*layers) 122 | 123 | def forward(self, x): 124 | residual = x 125 | 126 | out = self.DCR_block11(x) 127 | conc1= 
self.DCR_block12(out) 128 | out = self.down1(conc1) 129 | 130 | out = self.DCR_block21(out) 131 | conc2= self.DCR_block22(out) 132 | out = self.down2(conc2) 133 | 134 | out = self.DCR_block31(out) 135 | conc3= self.DCR_block32(out) 136 | out = self.down3(conc3) 137 | 138 | out = self.DCR_block41(out) 139 | conc4= self.DCR_block42(out) 140 | out = self.down4(conc4) 141 | 142 | # bridge part 143 | out = self.uspp(out) 144 | 145 | out = self.up4(out) 146 | out = torch.cat([conc4, out], 1) 147 | out = self.DCR_block43(out) 148 | out = self.DCR_block44(out) 149 | 150 | out = self.up3(out) 151 | out = torch.cat([conc3, out], 1) 152 | out = self.DCR_block33(out) 153 | out = self.DCR_block34(out) 154 | 155 | out = self.up2(out) 156 | out = torch.cat([conc2, out], 1) 157 | out = self.DCR_block23(out) 158 | out = self.DCR_block24(out) 159 | 160 | out = self.up1(out) 161 | out = torch.cat([conc1, out], 1) 162 | out = self.DCR_block13(out) 163 | out = self.DCR_block14(out) 164 | if self.mode == 'Train': 165 | return F.sigmoid(out) 166 | elif self.mode == 'Infer': 167 | return out -------------------------------------------------------------------------------- /nets/zoo/uspp_BE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import skimage 4 | import numpy as np 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | 8 | 9 | 10 | class _stage_block(nn.Module): 11 | def __init__(self, channel_var): 12 | super(_stage_block, self).__init__() 13 | 14 | channel_in, channel_out = channel_var 15 | 16 | self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_out, kernel_size=3, stride=1, padding=1) 17 | self.bn = nn.BatchNorm2d(channel_out) 18 | self.relu = nn.ReLU() 19 | 20 | def forward(self, x): 21 | out = self.bn( self.conv(x) ) 22 | out = self.relu(out) 23 | return out 24 | 25 | 26 | class _upss_block(nn.Module): 27 | def __init__(self, channel_in): 28 | super(_upss_block, self).__init__() 29 | self.conv1 = nn.Sequential( 30 | nn.MaxPool2d(1), 31 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=1, stride=1, padding=0), 32 | ) 33 | self.conv2 = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=2, stride=1, padding=1), 36 | ) 37 | self.conv3 = nn.Sequential( 38 | nn.MaxPool2d(3), 39 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=3, stride=1, padding=1), 40 | ) 41 | self.conv4 = nn.Sequential( 42 | nn.MaxPool2d(6), 43 | nn.Conv2d(in_channels=channel_in, out_channels=int(channel_in/4.), kernel_size=4, stride=1, padding=2), 44 | ) 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | h, w = x.size(2), x.size(3) 50 | 51 | out1 = self.conv1(x) 52 | out1 = F.upsample(input=out1, size=(h, w), mode='bilinear') 53 | out2 = self.conv2(x) 54 | out2 = F.upsample(input=out2, size=(h, w), mode='bilinear') 55 | out3 = self.conv3(x) 56 | out3 = F.upsample(input=out3, size=(h, w), mode='bilinear') 57 | out4 = self.conv4(x) 58 | out4 = F.upsample(input=out4, size=(h, w), mode='bilinear') 59 | 60 | out = torch.cat([out1, out2, out3, out4, residual], 1) 61 | return out 62 | 63 | 64 | class _down(nn.Module): 65 | def __init__(self, channel_in): 66 | super(_down, self).__init__() 67 | self.maxpool = nn.MaxPool2d(2) 68 | 69 | def forward(self, x): 70 | out = self.maxpool(x) 71 | return out 72 | 73 | 74 | class _up(nn.Module): 75 | def __init__(self, channel_in): 76 | 
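        """Upsample by 2x with a stride-2 transposed conv, halving the channel
        count (channel_in -> channel_in // 2). The commented-out PReLU /
        PixelShuffle lines preserve an alternative sub-pixel upsampling
        variant that is not used here.
        """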
super(_up, self).__init__() 77 | 78 | #self.relu = nn.PReLU() 79 | #self.subpixel = nn.PixelShuffle(2) 80 | self.subpixel = nn.ConvTranspose2d(in_channels=channel_in, out_channels=int(channel_in/2.), kernel_size=2, stride=2) 81 | #self.conv = nn.Conv2d(in_channels=channel_in, out_channels=channel_in, kernel_size=1, stride=1, padding=0) 82 | 83 | def forward(self, x): 84 | #out = self.relu(self.conv(x)) 85 | #out = self.subpixel(out) 86 | out = self.subpixel(x) 87 | return out 88 | 89 | 90 | class Uspp_BE(nn.Module): 91 | def __init__(self, pretrained=False, mode= 'Train'): 92 | super(Uspp_BE, self).__init__() 93 | self.mode = mode 94 | self.DCR_block11 = self.make_layer(_stage_block, [ 3, 64]) 95 | self.DCR_block12 = self.make_layer(_stage_block, [ 64, 64]) 96 | self.down1 = self.make_layer(_down, 64) 97 | self.DCR_block21 = self.make_layer(_stage_block, [ 64,128]) 98 | self.DCR_block22 = self.make_layer(_stage_block, [128,128]) 99 | self.down2 = self.make_layer(_down, 128) 100 | self.DCR_block31 = self.make_layer(_stage_block, [128,256]) 101 | self.DCR_block32 = self.make_layer(_stage_block, [256,256]) 102 | self.down3 = self.make_layer(_down, 256) 103 | self.DCR_block41 = self.make_layer(_stage_block, [256,512]) 104 | self.DCR_block42 = self.make_layer(_stage_block, [512,512]) 105 | self.down4 = self.make_layer(_down, 512) 106 | 107 | self.uspp = self.make_layer(_upss_block, 512) 108 | 109 | self.up4 = self.make_layer(_up, 1024) 110 | self.DCR_block43 = self.make_layer(_stage_block,[1024,512]) 111 | self.DCR_block44 = self.make_layer(_stage_block, [512,512]) 112 | self.up3 = self.make_layer(_up, 512) 113 | self.DCR_block33 = self.make_layer(_stage_block, [512,256]) 114 | self.DCR_block34 = self.make_layer(_stage_block, [256,256]) 115 | self.up2 = self.make_layer(_up, 256) 116 | self.DCR_block23 = self.make_layer(_stage_block, [256,128]) 117 | self.DCR_block24 = self.make_layer(_stage_block, [128,128]) 118 | self.up1 = self.make_layer(_up, 128) 119 | self.DCR_block13 = self.make_layer(_stage_block, [128, 64]) 120 | # self.DCR_block14 = self.make_layer(_stage_block, [ 64, 1]) 121 | self.DCR_block14 = self.make_layer(_stage_block, [ 64, 64]) 122 | # HED Block 123 | self.dsn1 = nn.Conv2d(64, 1, 1) 124 | self.dsn2 = nn.Conv2d(128, 1, 1) 125 | self.dsn3 = nn.Conv2d(256, 1, 1) 126 | self.dsn4 = nn.Conv2d(512, 1, 1) 127 | self.dsn5 = nn.Conv2d(1024, 1, 1) 128 | 129 | #boundary enhancement part 130 | self.fuse = nn.Sequential(nn.Conv2d(5, 64, 1),nn.ReLU(inplace=True)) 131 | 132 | self.SE_mimic = nn.Sequential( 133 | nn.Linear(64, 64, bias=False), 134 | nn.ReLU(inplace=True), 135 | nn.Linear(64, 5, bias=False), 136 | nn.Sigmoid() 137 | ) 138 | self.final_boundary = nn.Conv2d(5,2,1) 139 | 140 | self.final_conv = nn.Sequential( 141 | nn.Conv2d(128,64,3, padding=1), 142 | nn.ReLU(inplace=True) 143 | ) 144 | self.final_mask = nn.Conv2d(64,2,1) 145 | 146 | 147 | 148 | self.relu = nn.ReLU() 149 | self.out = nn.Conv2d(64,1,1) 150 | 151 | 152 | def make_layer(self, block, channel_in): 153 | layers = [] 154 | layers.append(block(channel_in)) 155 | return nn.Sequential(*layers) 156 | 157 | def forward(self, x): 158 | residual = x 159 | h = x.size(2) 160 | w = x.size(3) 161 | 162 | out = self.DCR_block11(x) 163 | conc1= self.DCR_block12(out) 164 | out = self.down1(conc1) 165 | 166 | out = self.DCR_block21(out) 167 | conc2= self.DCR_block22(out) 168 | out = self.down2(conc2) 169 | 170 | out = self.DCR_block31(out) 171 | conc3= self.DCR_block32(out) 172 | out = self.down3(conc3) 173 | 174 | out = 
self.DCR_block41(out) 175 | conc4= self.DCR_block42(out) 176 | out = self.down4(conc4) 177 | 178 | # bridge part 179 | conc5 = self.uspp(out) 180 | 181 | out = self.up4(conc5) 182 | out = torch.cat([conc4, out], 1) 183 | out = self.DCR_block43(out) 184 | out = self.DCR_block44(out) 185 | 186 | out = self.up3(out) 187 | out = torch.cat([conc3, out], 1) 188 | out = self.DCR_block33(out) 189 | out = self.DCR_block34(out) 190 | 191 | out = self.up2(out) 192 | out = torch.cat([conc2, out], 1) 193 | out = self.DCR_block23(out) 194 | out = self.DCR_block24(out) 195 | 196 | out = self.up1(out) 197 | out = torch.cat([conc1, out], 1) 198 | out = self.DCR_block13(out) 199 | out = self.DCR_block14(out) 200 | 201 | d1 = self.dsn1(conc1) 202 | d2 = F.upsample_bilinear(self.dsn2(conc2), size=(h,w)) 203 | d3 = F.upsample_bilinear(self.dsn3(conc3), size=(h,w)) 204 | d4 = F.upsample_bilinear(self.dsn4(conc4), size=(h,w)) 205 | d5 = F.upsample_bilinear(self.dsn5(conc5), size=(h,w)) 206 | d1_out = F.sigmoid(d1) 207 | d2_out = F.sigmoid(d2) 208 | d3_out = F.sigmoid(d3) 209 | d4_out = F.sigmoid(d4) 210 | d5_out = F.sigmoid(d5) 211 | concat = torch.cat((d1_out, d2_out, d3_out, d4_out, d5_out), 1) 212 | 213 | fuse_box = self.fuse(concat) 214 | GAP = F.adaptive_avg_pool2d(fuse_box,(1,1)) 215 | GAP = GAP.view(-1, 64) 216 | se_like = self.SE_mimic(GAP) 217 | se_like = torch.unsqueeze(se_like, 2) 218 | se_like = torch.unsqueeze(se_like, 3) 219 | 220 | feat_se = concat * se_like.expand_as(concat) 221 | boundary = self.final_boundary(feat_se) 222 | boundary_out = torch.unsqueeze(boundary[:,1,:,:],1) 223 | bd_sftmax = F.softmax(boundary, dim=1) 224 | boundary_scale = torch.unsqueeze(bd_sftmax[:,1,:,:],1) 225 | 226 | feat_concat = torch.cat( [out, fuse_box], 1) 227 | feat_concat_conv = self.final_conv(feat_concat) 228 | mask = self.final_mask(feat_concat_conv) 229 | mask_sftmax = F.softmax(mask,dim=1) 230 | mask_scale = torch.unsqueeze(mask_sftmax[:,1,:,:],1) 231 | 232 | if self.mode == 'Train': 233 | scalefactor = torch.clamp(mask_scale+boundary_scale,0,1) 234 | elif self.mode == 'Infer': 235 | scalefactor = torch.clamp(mask_scale+5*boundary_scale,0,1) 236 | 237 | 238 | mask_out = torch.unsqueeze(mask[:,1,:,:],1) 239 | relu = self.relu(mask_out) 240 | scalar = relu.cpu().detach().numpy() 241 | if np.sum(scalar) == 0: 242 | average = 0 243 | else : 244 | average = scalar[np.nonzero(scalar)].mean() 245 | mask_out = mask_out-relu + (average*scalefactor) 246 | 247 | if self.mode == 'Train': 248 | mask_out = F.sigmoid(mask_out) 249 | boundary_out = F.sigmoid(boundary_out) 250 | 251 | return d1_out, d2_out, d3_out, d4_out, d5_out, boundary_out, mask_out 252 | elif self.mode =='Infer': 253 | return mask_out 254 | 255 | -------------------------------------------------------------------------------- /notebooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HoinJung/BEmodule-Satellite-Building-Segmentation/d4ed2c0f95bbee435abc97389df1357393b7e570/notebooks/__init__.py -------------------------------------------------------------------------------- /notebooks/data_prep.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 7, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Dataset location (edit as needed)\n", 10 | "import os\n", 11 | "import pandas as pd\n", 12 | "root_dir = '../'" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 
17 | "execution_count": 8, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "name": "stdout", 22 | "output_type": "stream", 23 | "text": [ 24 | "3.7.10 (default, Feb 26 2021, 18:47:35) \n", 25 | "[GCC 7.3.0]\n", 26 | "Python 3.7.10\n" 27 | ] 28 | } 29 | ], 30 | "source": [ 31 | "import sys\n", 32 | "print(sys.version)\n", 33 | "!python --version" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 12, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "data": { 43 | "text/html": [ 44 | "
\n", 45 | "\n", 58 | "\n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | "
imagelabel
0../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0...../data/Train/Urban3D_Train/masks/_10_JAX_Tile...
1../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0...../data/Train/Urban3D_Train/masks/_10_JAX_Tile...
2../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0...../data/Train/Urban3D_Train/masks/_10_JAX_Tile...
3../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0...../data/Train/Urban3D_Train/masks/_10_JAX_Tile...
4../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0...../data/Train/Urban3D_Train/masks/_10_JAX_Tile...
\n", 94 | "
" 95 | ], 96 | "text/plain": [ 97 | " image \\\n", 98 | "0 ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... \n", 99 | "1 ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... \n", 100 | "2 ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... \n", 101 | "3 ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... \n", 102 | "4 ../data/Train/Urban3D_Train/RGB/_10_JAX_Tile_0... \n", 103 | "\n", 104 | " label \n", 105 | "0 ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... \n", 106 | "1 ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... \n", 107 | "2 ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... \n", 108 | "3 ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... \n", 109 | "4 ../data/Train/Urban3D_Train/masks/_10_JAX_Tile... " 110 | ] 111 | }, 112 | "metadata": {}, 113 | "output_type": "display_data" 114 | }, 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "output csv: ../csvs/Urban3D_Train_df.csv\n" 120 | ] 121 | } 122 | ], 123 | "source": [ 124 | "# Make dataframe csvs for train/test\n", 125 | "\n", 126 | "out_dir = os.path.join(root_dir, 'csvs/')\n", 127 | "os.makedirs(out_dir, exist_ok=True)\n", 128 | "# data_dir = 'Test/Urban3D_Test'\n", 129 | "data_dir = 'Train/Urban3D_Train'\n", 130 | "\n", 131 | "\n", 132 | "d = os.path.join(root_dir, 'data', data_dir)\n", 133 | "subdirs = sorted([f for f in os.listdir(d)]) \n", 134 | "outpath = os.path.join(out_dir, data_dir.split('/')[1] + '_df.csv')\n", 135 | "im_list, mask_list = [], []\n", 136 | "\n", 137 | "\n", 138 | "im_files = [os.path.join( d,'RGB', f.split('.')[0] + '.tif')\n", 139 | "for f in sorted(os.listdir(os.path.join(d,'RGB' )))]\n", 140 | "\n", 141 | "mask_files = [os.path.join(d, 'masks', f.split('.')[0] + '.tif')\n", 142 | "for f in sorted(os.listdir(os.path.join(d, 'masks')))]\n", 143 | "\n", 144 | "\n", 145 | "im_list.extend(im_files)\n", 146 | "mask_list.extend(mask_files)\n", 147 | "\n", 148 | "\n", 149 | "df = pd.DataFrame({'image': im_list, 'label': mask_list})\n", 150 | "display(df.head())\n", 151 | "\n", 152 | "df.to_csv(outpath, index=False)\n", 153 | "\n", 154 | "print(\"output csv:\", outpath)" 155 | ] 156 | } 157 | ], 158 | "metadata": { 159 | "kernelspec": { 160 | "display_name": "Python 3 (ipykernel)", 161 | "language": "python", 162 | "name": "python3" 163 | }, 164 | "language_info": { 165 | "codemirror_mode": { 166 | "name": "ipython", 167 | "version": 3 168 | }, 169 | "file_extension": ".py", 170 | "mimetype": "text/x-python", 171 | "name": "python", 172 | "nbconvert_exporter": "python", 173 | "pygments_lexer": "ipython3", 174 | "version": "3.7.10" 175 | } 176 | }, 177 | "nbformat": 4, 178 | "nbformat_minor": 4 179 | } 180 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HoinJung/BEmodule-Satellite-Building-Segmentation/d4ed2c0f95bbee435abc97389df1357393b7e570/src/__init__.py -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['CUDA_VISIBLE_DEVICES']='1' 4 | import sys 5 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname("__file__")))) 6 | 7 | import nets 8 | import utils 9 | 10 | 11 | config_path = '../yml/infer.yml' 12 | config = utils.config.parse(config_path) 13 | # print('Config:') 14 | # print(config) 15 | 16 | # make infernce output dir 17 | # 
os.makedirs(os.path.dirname(config['inference']['output_dir']), exist_ok=True) 18 | 19 | inferer = nets.infer.Inferer(config) 20 | inferer() 21 | # inferer.Inferer() 22 | 
-------------------------------------------------------------------------------- /src/train.py: --------------------------------------------------------------------------------
 1 | 2 | import os 3 | import sys 4 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname("__file__")))) 5 | # note: "__file__" above is a quoted string, so this appends the parent of the current working directory, not of this file 6 | import nets 7 | import utils 8 | import time 9 | os.environ['CUDA_VISIBLE_DEVICES']='0,1' 10 | 11 | config_path = '../yml/train.yml' 12 | config = utils.config.parse(config_path) 13 | 14 | # make model output dir 15 | 16 | os.makedirs(os.path.dirname(config['training']['callbacks']['model_checkpoint']['filepath']), exist_ok=True) 17 | start_time = str(int(time.time())) 18 | config['start_time'] = start_time 19 | trainer = nets.train.Trainer(config=config) 20 | trainer.train() 21 | 
-------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import config, core, io, data 2 | 
-------------------------------------------------------------------------------- /utils/config.py: --------------------------------------------------------------------------------
 1 | import yaml 2 | from nets import zoo 3 | 4 | 5 | def parse(path): 6 | 7 | with open(path, 'r') as f: 8 | config = yaml.safe_load(f) 9 | 10 | 11 | if not config['train'] and not config['infer']: 12 | raise ValueError('"train", "infer", or both must be true.') 13 | if config['train'] and config['train_csv_dir'] is None: 14 | raise ValueError('"train_csv_dir" must be provided if training.') 15 | if config['infer'] and config['inference_data_csv'] is None: 16 | raise ValueError('"inference_data_csv" must be provided if "infer".') 17 | 18 | train_aoi = config['aoi'] 19 | 20 | """ Custom AOI """ 21 | if train_aoi == 2: 22 | aoi = 'AOI_2_Vegas' 23 | elif train_aoi == 3: 24 | aoi = 'AOI_3_Paris' 25 | elif train_aoi == 4: 26 | aoi = 'AOI_4_Shanghai' 27 | elif train_aoi == 5: 28 | aoi = 'AOI_5_Khartoum' 29 | elif train_aoi == 6: 30 | aoi = 'Urban3D' 31 | elif train_aoi == 7: 32 | aoi = 'WHU' 33 | elif train_aoi == 8: 34 | aoi = 'mass' 35 | elif train_aoi == 9: 36 | aoi = 'WHU_asia' 37 | config['get_aoi'] = aoi  # NameError here if config['aoi'] is not one of 2-9 38 | 39 | if config['training']['lr'] is not None: 40 | config['training']['lr'] = float(config['training']['lr']) 41 | 42 | 43 | if config['validation_augmentation'] is not None \ 44 | and config['inference_augmentation'] is None: 45 | config['inference_augmentation'] = config['validation_augmentation'] 46 | 47 | return config 48 | 
-------------------------------------------------------------------------------- /utils/core.py: --------------------------------------------------------------------------------
 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import skimage 5 | 6 | 7 | def _check_skimage_im_load(im): 8 | """Check if `im` is already loaded in; if not, load it in.""" 9 | if isinstance(im, str): 10 | return skimage.io.imread(im) 11 | elif isinstance(im, np.ndarray): 12 | return im 13 | else: 14 | raise ValueError( 15 | "{} is not an accepted image format for scikit-image.".format(im)) 16 | 17 | 18 | def _check_df_load(df): 19 | """Check if `df` is already loaded in, if not, load from file.""" 20 | if isinstance(df, str): 21 | if df.lower().endswith('json'): 22 | return _check_gdf_load(df)  # GeoJSON labels go through a GeoDataFrame loader defined elsewhere 23 | else: 24 | return
pd.read_csv(df) 25 | elif isinstance(df, pd.DataFrame): 26 | return df 27 | else: 28 | raise ValueError(f"{df} is not an accepted DataFrame format.") 29 | 30 | 31 | 32 | def get_data_paths(path, infer=False): 33 | """Get a pandas dataframe of images and labels from a csv. 34 | 35 | This function is designed to parse image:label reference CSVs (or image-only 36 | CSVs for inference) as defined in the documentation. Briefly, these should be 37 | CSVs containing two columns: 38 | 39 | ``'image'``: the path to images. 40 | ``'label'``: the path to the label file that corresponds to the image. 41 | 42 | Arguments 43 | --------- 44 | path : str 45 | Path to a .CSV-formatted reference file defining the location of 46 | training, validation, or inference data. See docs for details. 47 | infer : bool, optional 48 | If ``infer=True``, the ``'label'`` column will not be returned (as it 49 | is unnecessary for inference), even if it is present. 50 | 51 | Returns 52 | ------- 53 | df : :class:`pandas.DataFrame` 54 | A :class:`pandas.DataFrame` containing the relevant `image` and `label` 55 | information from the CSV at `path` (unless ``infer=True``, in which 56 | case only the `image` column is returned.) 57 | 58 | """ 59 | df = pd.read_csv(path) 60 | if infer: 61 | return df[['image']] # no labels in those files 62 | else: 63 | return df[['image', 'label']] # remove anything extraneous 64 | 65 | 66 | def get_files_recursively(path, traverse_subdirs=False, extension='.tif'): 67 | """Get files from subdirs of `path`, joining them to the dir.""" 68 | if traverse_subdirs: 69 | walker = os.walk(path) 70 | path_list = [] 71 | for step in walker: 72 | if not step[2]: # if there are no files in the current dir 73 | continue 74 | path_list += [os.path.join(step[0], fname) 75 | for fname in step[2] if 76 | fname.lower().endswith(extension)] 77 | return path_list 78 | else: 79 | return [os.path.join(path, f) for f in os.listdir(path) 80 | if f.endswith(extension)] 81 | 
-------------------------------------------------------------------------------- /utils/data.py: --------------------------------------------------------------------------------
 1 | import os 2 | import pandas as pd 3 | from .log import _get_logging_level 4 | from .core import get_files_recursively 5 | import logging 6 | 7 | 8 | def make_dataset_csv(im_dir, im_ext='tif', label_dir=None, label_ext='json', 9 | output_path='dataset.csv', stage='train', match_re=None, 10 | recursive=False, ignore_mismatch=None, verbose=0): 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(_get_logging_level(int(verbose))) 14 | logger.debug('Checking arguments.') 15 | 16 | if stage != 'infer' and label_dir is None: 17 | raise ValueError("label_dir must be provided if stage is not infer.") 18 | logger.info('Matching images to labels.') 19 | logger.debug('Getting image file paths.') 20 | im_fnames = get_files_recursively(im_dir, traverse_subdirs=recursive, 21 | extension=im_ext) 22 | logger.debug(f"Got {len(im_fnames)} image file paths.") 23 | temp_im_df = pd.DataFrame({'image_path': im_fnames}) 24 | 25 | if stage != 'infer': 26 | logger.debug('Preparing training or validation set.') 27 | logger.debug('Getting label file paths.') 28 | label_fnames = get_files_recursively(label_dir, 29 | traverse_subdirs=recursive, 30 | extension=label_ext) 31 | logger.debug(f"Got {len(label_fnames)} label file paths.") 32 | if len(im_fnames) != len(label_fnames): 33 | logger.warning('The number of images and label files is not equal.') 34 | 35 | logger.debug("Matching image
files to label files.") 36 | logger.debug("Extracting image filename substrings for matching.") 37 | temp_label_df = pd.DataFrame({'label_path': label_fnames}) 38 | temp_im_df['image_fname'] = temp_im_df['image_path'].apply( 39 | lambda x: os.path.split(x)[1]) 40 | temp_label_df['label_fname'] = temp_label_df['label_path'].apply( 41 | lambda x: os.path.split(x)[1]) 42 | if match_re: 43 | logger.debug('match_re is True, extracting regex matches') 44 | im_match_strs = temp_im_df['image_fname'].str.extract(match_re) 45 | label_match_strs = temp_label_df['label_fname'].str.extract( 46 | match_re) 47 | if len(im_match_strs.columns) > 1 or \ 48 | len(label_match_strs.columns) > 1: 49 | raise ValueError('Multiple regex matches occurred within ' 50 | 'individual filenames.') 51 | else: 52 | temp_im_df['match_str'] = im_match_strs 53 | temp_label_df['match_str'] = label_match_strs 54 | else: 55 | logger.debug('match_re is False, will match by fname without ext') 56 | temp_im_df['match_str'] = temp_im_df['image_fname'].apply( 57 | lambda x: os.path.splitext(x)[0]) 58 | temp_label_df['match_str'] = temp_label_df['label_fname'].apply( 59 | lambda x: os.path.splitext(x)[0]) 60 | 61 | logger.debug('Aligning label and image dataframes by' 62 | ' match_str.') 63 | temp_join_df = pd.merge(temp_im_df, temp_label_df, on='match_str', 64 | how='inner') 65 | logger.debug(f'Length of joined dataframe: {len(temp_join_df)}') 66 | if len(temp_join_df) < len(temp_im_df) and \ 67 | ignore_mismatch is None: 68 | raise ValueError('There is not a perfect 1:1 match of images ' 69 | 'to label files. To allow this behavior, see ' 70 | 'the make_dataset_csv() ignore_mismatch ' 71 | 'argument.') 72 | elif len(temp_join_df) > len(temp_im_df) and ignore_mismatch is None: 73 | raise ValueError('There are multiple label files matching at ' 74 | 'least one image file.') 75 | elif len(temp_join_df) > len(temp_im_df) and ignore_mismatch == 'skip': 76 | logger.info('ignore_mismatch="skip", so dropping any images with ' 77 | f'duplicates. Original images: {len(temp_im_df)}') 78 | dup_rows = temp_join_df.duplicated(subset='match_str', keep=False) 79 | temp_join_df = temp_join_df.loc[~dup_rows, :] 80 | logger.info('Remaining images after dropping duplicates: ' 81 | f'{len(temp_join_df)}') 82 | logger.debug('Dropping extra columns from output dataframe.') 83 | output_df = temp_join_df[['image_path', 'label_path']].rename( 84 | columns={'image_path': 'image', 'label_path': 'label'}) 85 | 86 | elif stage == 'infer': 87 | logger.debug('Preparing inference dataset dataframe.') 88 | output_df = temp_im_df.rename(columns={'image_path': 'image'}) 89 | 90 | logger.debug(f'Saving output dataframe to {output_path} .') 91 | output_df.to_csv(output_path, index=False) 92 | 93 | return output_df 94 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | """Utility functions for data io.""" 2 | import numpy as np 3 | import skimage.io 4 | 5 | 6 | def imread(path, make_8bit=False, rescale=False, 7 | rescale_min='auto', rescale_max='auto'): 8 | """Read in an image file and rescale pixel values (if applicable). 9 | 10 | Note 11 | ---- 12 | Because overhead imagery is often either 16-bit or multispectral (i.e. >3 13 | channels or bands that don't directly translate into the RGB scheme of 14 | photographs), this package using scikit-image_ ``io`` algorithms. 
--------------------------------------------------------------------------------
/utils/io.py:
--------------------------------------------------------------------------------
1 | """Utility functions for data io."""
2 | import numpy as np
3 | import skimage.io
4 | 
5 | 
6 | def imread(path, make_8bit=False, rescale=False,
7 |            rescale_min='auto', rescale_max='auto'):
8 |     """Read in an image file and rescale pixel values (if applicable).
9 | 
10 |     Note
11 |     ----
12 |     Because overhead imagery is often either 16-bit or multispectral (i.e. >3
13 |     channels or bands that don't directly translate into the RGB scheme of
14 |     photographs), this package uses scikit-image_ ``io`` algorithms. Though
15 |     slightly slower, these algorithms are compatible with any bit depth or
16 |     channel count.
17 | 
18 |     .. _scikit-image: https://scikit-image.org
19 | 
20 |     Arguments
21 |     ---------
22 |     path : str
23 |         Path to the image file to load.
24 |     make_8bit : bool, optional
25 |         Should the image be converted to an 8-bit format? Defaults to False.
26 |     rescale : bool, optional
27 |         Should pixel intensities be rescaled? Defaults to no (False).
28 |     rescale_min : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
29 |         The minimum pixel value(s) for rescaling. If ``rescale=True`` but no
30 |         value is provided for `rescale_min`, the minimum pixel intensity in
31 |         each channel of the image will be subtracted such that the minimum
32 |         value becomes zero. If a single number is provided, that number will be
33 |         subtracted from each channel. If a list of values is provided that is
34 |         the same length as the number of channels, then those values will be
35 |         subtracted from the corresponding channels.
36 |     rescale_max : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
37 |         The max pixel value(s) for rescaling. If ``rescale=True`` but no
38 |         value is provided for `rescale_max`, each channel will be rescaled such
39 |         that the maximum value in the channel is set to the bit range's max.
40 |         If a single number is provided, that number will be used as the upper
41 |         limit for all channels. If a list of values is provided that is the
42 |         same length as the number of channels, then those values will be used
43 |         as the upper limits for the corresponding channels.
44 | 
45 |     Returns
46 |     -------
47 |     im : :class:`numpy.ndarray`
48 |         A NumPy array of shape ``[Y, X, C]`` containing the imagery, with dtype
49 |         ``uint8`` if ``make_8bit=True`` (otherwise the source dtype is kept).
50 | 
51 |     """
52 |     im_arr = skimage.io.imread(path)
53 |     # check dtype for preprocessing
54 |     if im_arr.dtype == np.uint8:
55 |         dtype = 'uint8'
56 |     elif im_arr.dtype == np.uint16:
57 |         dtype = 'uint16'
58 |     elif im_arr.dtype in [np.float16, np.float32, np.float64]:
59 |         if np.amax(im_arr) <= 1 and np.amin(im_arr) >= 0:
60 |             dtype = 'zero-one normalized'  # range = 0-1
61 |         elif np.amax(im_arr) > 0 and np.amin(im_arr) < 0:
62 |             dtype = 'z-scored'
63 |         elif np.amax(im_arr) <= 255:
64 |             dtype = '255 float'
65 |         elif np.amax(im_arr) <= 65535:
66 |             dtype = '65535 float'
67 |         else:
68 |             raise TypeError('The loaded image array is an unexpected dtype.')
69 |     else:
70 |         raise TypeError('The loaded image array is an unexpected dtype.')
71 |     if make_8bit:
72 |         im_arr = preprocess_im_arr(im_arr, dtype, rescale=rescale,
73 |                                    rescale_min=rescale_min,
74 |                                    rescale_max=rescale_max)
75 |     return im_arr
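
A usage sketch for `imread` (the file path and the per-channel clip values are hypothetical; they assume a 16-bit, 3-channel GeoTIFF):

    from utils.io import imread

    # Clip each channel to the given minima/maxima, then rescale into 8-bit.
    im = imread('tile.tif', make_8bit=True, rescale=True,
                rescale_min=[0, 0, 0], rescale_max=[2000, 2000, 1500])
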
76 | 
77 | 
78 | def preprocess_im_arr(im_arr, im_format, rescale=False,
79 |                       rescale_min='auto', rescale_max='auto'):
80 |     """Convert image to standard shape and dtype for use in the pipeline.
81 | 
82 |     Notes
83 |     -----
84 |     This repo requires the following of images:
85 | 
86 |     - Their shape is of form [Y, X, C]
87 |     - Input images are dtype ``uint8``
88 | 
89 |     This function will take an image array `im_arr` and reshape it accordingly.
90 | 
91 |     Arguments
92 |     ---------
93 |     im_arr : :class:`numpy.ndarray`
94 |         A numpy array representation of an image. `im_arr` should have either
95 |         two or three dimensions.
96 |     im_format : str
97 |         One of ``'uint8'``, ``'uint16'``, ``'z-scored'``,
98 |         ``'zero-one normalized'``, ``'255 float'``, or ``'65535 float'``.
99 |         String indicating the dtype of the input, which will dictate the
100 |         preprocessing applied.
101 |     rescale : bool, optional
102 |         Should pixel intensities be rescaled? Defaults to no (False).
103 |     rescale_min : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
104 |         The minimum pixel value(s) for rescaling. If ``rescale=True`` but no
105 |         value is provided for `rescale_min`, the minimum pixel intensity in
106 |         each channel of the image will be subtracted such that the minimum
107 |         value becomes zero. If a single number is provided, that number will be
108 |         subtracted from each channel. If a list of values is provided that is
109 |         the same length as the number of channels, then those values will be
110 |         subtracted from the corresponding channels.
111 |     rescale_max : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
112 |         The max pixel value(s) for rescaling. If ``rescale=True`` but no
113 |         value is provided for `rescale_max`, each channel will be rescaled such
114 |         that the maximum value in the channel is set to the bit range's max.
115 |         If a single number is provided, that number will be used as the upper
116 |         limit for all channels. If a list of values is provided that is the
117 |         same length as the number of channels, then those values will be used
118 |         as the upper limits for the corresponding channels.
119 | 
120 |     Returns
121 |     -------
122 |     A :class:`numpy.ndarray` with shape ``[Y, X, C]`` and dtype ``uint8``.
123 | 
124 |     """
125 |     # get [Y, X, C] axis order set up
126 |     if im_arr.ndim not in [2, 3]:
127 |         raise ValueError('This package can only read two-dimensional '
128 |                          'image data with an optional channel dimension.')
129 |     if im_arr.ndim == 2:
130 |         im_arr = im_arr[:, :, np.newaxis]
131 |     if im_arr.shape[0] < im_arr.shape[2]:  # if the channel axis comes first
132 |         im_arr = np.moveaxis(im_arr, 0, -1)  # move 0th axis to last position
133 | 
134 |     # rescale images (if applicable)
135 |     if rescale:
136 |         im_arr = rescale_arr(im_arr, im_format, rescale_min, rescale_max)
137 | 
138 |     if im_format == 'uint8':
139 |         return im_arr.astype('uint8')  # just to be sure
140 |     elif im_format == 'uint16':
141 |         im_arr = (im_arr.astype('float64')*255./65535.).astype('uint8')
142 |     elif im_format == 'z-scored':
143 |         im_arr = ((im_arr+1)*127.5).astype('uint8')  # map [-1, 1] onto [0, 255]
144 |     elif im_format == 'zero-one normalized':
145 |         im_arr = (im_arr*255).astype('uint8')
146 |     elif im_format == '255 float':
147 |         im_arr = im_arr.astype('uint8')
148 |     elif im_format == '65535 float':
149 |         # why are you using this format?
150 |         im_arr = (im_arr*255/65535).astype('uint8')
151 |     return im_arr
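
A usage sketch for `preprocess_im_arr` on a synthetic 16-bit array (the input data here is random, for illustration only):

    import numpy as np
    from utils.io import preprocess_im_arr

    arr = (np.random.rand(512, 512, 3) * 65535).astype('uint16')  # synthetic input
    out = preprocess_im_arr(arr, 'uint16')
    assert out.shape == (512, 512, 3) and out.dtype == np.uint8
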
172 | """ 173 | 174 | if output_type is None: 175 | return image 176 | elif output_type == 'normalized': 177 | out_im = image/image.max() 178 | return out_im 179 | elif output_type == 'zscored': 180 | return (image - np.mean(image))/np.std(image) 181 | elif output_type == '8bit': 182 | if image.max() > 255: 183 | # assume it's 16-bit, rescale to 8-bit scale to min/max 184 | out_im = 255.*image/65535 185 | return out_im.astype('uint8') 186 | elif image.max() <= 1: 187 | out_im = 255.*image 188 | return out_im.astype('uint8') 189 | else: 190 | return image.astype('uint8') 191 | elif output_type == '16bit': 192 | if (image.max() < 255) and (image.max() > 1): 193 | # scale to min/max 194 | out_im = 65535.*image/255 195 | return out_im.astype('uint16') 196 | elif image.max() <= 1: 197 | out_im = 65535.*image 198 | return out_im.astype('uint16') 199 | else: 200 | return image.astype('uint16') 201 | else: 202 | raise ValueError('output_type must be one of' 203 | ' "normalized", "zscored", "8bit", "16bit"') 204 | 205 | 206 | def rescale_arr(im_arr, im_format, rescale_min='auto', rescale_max='auto'): 207 | """Rescale array values in a 3D image array with channel order [Y, X, C]. 208 | 209 | Arguments 210 | --------- 211 | im_arr : :class:`numpy.array` 212 | A numpy array representation of an image. `im_arr` should have either 213 | two or three dimensions. 214 | im_format : str 215 | One of ``'uint8'``, ``'uint16'``, ``'z-scored'``, 216 | ``'zero-one normalized'``, ``'255 float'``, or ``'65535 float'``. 217 | String indicating the dtype of the input, which will dictate the 218 | preprocessing applied. 219 | rescale_min : ``'auto'`` or :class:`int` or :class:`float` or :class:`list` 220 | The minimum pixel value(s) for rescaling. If ``rescale=True`` but no 221 | value is provided for `rescale_min`, the minimum pixel intensity in 222 | each channel of the image will be subtracted such that the minimum 223 | value becomes zero. If a single number is provided, that number will be 224 | subtracted from each channel. If a list of values is provided that is 225 | the same length as the number of channels, then those values will be 226 | subtracted from the corresponding channels. 227 | rescale_max : ``'auto'`` or :class:`int` or :class:`float` or :class:`list` 228 | The max pixel value(s) for rescaling. If ``rescale=True`` but no 229 | value is provided for `rescale_max`, each channel will be rescaled such 230 | that the maximum value in the channel is set to the bit range's max. 231 | If a single number is provided, that number will be set as the upper 232 | limit for all channels. If a list of values is provided that is the 233 | same length as the number of channels, then those values will be 234 | set to the maximum value in the corresponding channels. 
204 | 
205 | 
206 | def rescale_arr(im_arr, im_format, rescale_min='auto', rescale_max='auto'):
207 |     """Rescale array values in a 3D image array with channel order [Y, X, C].
208 | 
209 |     Arguments
210 |     ---------
211 |     im_arr : :class:`numpy.ndarray`
212 |         A 3D numpy array representation of an image, with channel order
213 |         ``[Y, X, C]``.
214 |     im_format : str
215 |         One of ``'uint8'``, ``'uint16'``, ``'z-scored'``,
216 |         ``'zero-one normalized'``, ``'255 float'``, or ``'65535 float'``.
217 |         String indicating the dtype of the input, which will dictate the
218 |         preprocessing applied.
219 |     rescale_min : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
220 |         The minimum pixel value(s) for rescaling. If ``rescale=True`` but no
221 |         value is provided for `rescale_min`, the minimum pixel intensity in
222 |         each channel of the image will be subtracted such that the minimum
223 |         value becomes zero. If a single number is provided, that number will be
224 |         subtracted from each channel. If a list of values is provided that is
225 |         the same length as the number of channels, then those values will be
226 |         subtracted from the corresponding channels.
227 |     rescale_max : ``'auto'`` or :class:`int` or :class:`float` or :class:`list`
228 |         The max pixel value(s) for rescaling. If ``rescale=True`` but no
229 |         value is provided for `rescale_max`, each channel will be rescaled such
230 |         that the maximum value in the channel is set to the bit range's max.
231 |         If a single number is provided, that number will be used as the upper
232 |         limit for all channels. If a list of values is provided that is the
233 |         same length as the number of channels, then those values will be used
234 |         as the upper limits for the corresponding channels.
235 | 
236 |     Returns
237 |     -------
238 |     normalized_arr : :class:`numpy.ndarray`
239 |     """
240 | 
241 |     if isinstance(rescale_min, list):
242 |         if len(rescale_min) != im_arr.shape[2]:  # if list len != channels
243 |             raise ValueError('The channel rescaling parameters must be '
244 |                              'either a single value or a list of length '
245 |                              'n_channels.')
246 |         else:
247 |             rescale_min = np.array(rescale_min)
248 |     elif isinstance(rescale_min, (int, float)):
249 |         rescale_min = np.array([rescale_min]*im_arr.shape[2])
250 |     elif rescale_min == 'auto':
251 |         rescale_min = np.amin(im_arr, axis=(0, 1))
252 | 
253 |     if isinstance(rescale_max, list):
254 |         if len(rescale_max) != im_arr.shape[2]:  # if list len != channels
255 |             raise ValueError('The channel rescaling parameters must be '
256 |                              'either a single value or a list of length '
257 |                              'n_channels.')
258 |         else:
259 |             rescale_max = np.array(rescale_max)
260 |     elif isinstance(rescale_max, (int, float)):
261 |         rescale_max = np.array([rescale_max]*im_arr.shape[2])
262 |     elif rescale_max == 'auto':
263 |         rescale_max = np.amax(im_arr, axis=(0, 1))
264 | 
265 |     scale_factor = None
266 |     if im_format in ['uint8', '255 float']:
267 |         scale_factor = 255
268 |     elif im_format in ['uint16', '65535 float']:
269 |         scale_factor = 65535
270 |     elif im_format == 'zero-one normalized':
271 |         scale_factor = 1
272 | 
273 |     # set all values above the scale max to the scale max, and all values
274 |     # below the scale min to the scale min
275 |     for channel in range(im_arr.shape[2]):
276 |         subarr = im_arr[:, :, channel]
277 |         subarr[subarr < rescale_min[channel]] = rescale_min[channel]
278 |         subarr[subarr > rescale_max[channel]] = rescale_max[channel]
279 |         im_arr[:, :, channel] = subarr
280 | 
281 |     if scale_factor is not None:
282 |         im_arr = (im_arr-rescale_min)*(
283 |             scale_factor/(rescale_max-rescale_min))
284 | 
285 |     return im_arr
286 | 
287 | 
288 | def _check_channel_order(im_arr, framework):
289 |     im_shape = im_arr.shape
290 |     if len(im_shape) == 3:  # doesn't matter for 1-channel images
291 |         if im_shape[0] > im_shape[2] and framework in ['torch', 'pytorch']:
292 |             # in [Y, X, C], needs to be in [C, Y, X]
293 |             im_arr = np.moveaxis(im_arr, 2, 0)
294 |         elif im_shape[2] > im_shape[0] and framework == 'keras':
295 |             # in [C, Y, X], needs to be in [Y, X, C]
296 |             im_arr = np.moveaxis(im_arr, 0, 2)
297 |     elif len(im_shape) == 4:  # for a whole minibatch
298 |         if im_shape[1] > im_shape[3] and framework in ['torch', 'pytorch']:
299 |             # in [N, Y, X, C], needs to be in [N, C, Y, X]
300 |             im_arr = np.moveaxis(im_arr, 3, 1)
301 |         elif im_shape[3] > im_shape[1] and framework == 'keras':
302 |             # in [N, C, Y, X], needs to be in [N, Y, X, C]
303 |             im_arr = np.moveaxis(im_arr, 1, 3)
304 | 
305 |     return im_arr
306 | 
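
A usage sketch for `_check_channel_order`, which moves channels first for torch and last for keras:

    import numpy as np
    from utils.io import _check_channel_order

    im = np.zeros((512, 512, 3))                          # [Y, X, C]
    print(_check_channel_order(im, 'torch').shape)        # (3, 512, 512)
    print(_check_channel_order(im, 'keras').shape)        # (512, 512, 3), unchanged
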
--------------------------------------------------------------------------------
/utils/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | 
4 | def _get_logging_level(level_int):
5 |     """Convert a verbosity integer into a stdlib logging level."""
6 |     if isinstance(level_int, bool):
7 |         level_int = int(level_int)
8 |     if level_int < 0:
9 |         return logging.CRITICAL + 1  # silence all possible outputs
10 |     elif level_int == 0:
11 |         return logging.WARNING
12 |     elif level_int == 1:
13 |         return logging.INFO
14 |     elif level_int == 2:
15 |         return logging.DEBUG
16 |     elif level_int in [10, 20, 30, 40, 50]:  # if user provides a logging int
17 |         return level_int
18 |     elif isinstance(level_int, int):  # an int, but not one of the above
19 |         return level_int
20 |     else:
21 |         raise ValueError(f"logging level set to {level_int}, "
22 |                          "but it must be an integer.")
--------------------------------------------------------------------------------
/yml/infer.yml:
--------------------------------------------------------------------------------
1 | model_name: unet_BE
2 | model_path: '../result/models_weight/'
3 | training_date: '1629037296'
4 | aoi: 6
5 | boundary: True
6 | inference:
7 |   window_step_size_x:
8 |   window_step_size_y:
9 |   output_dir: '../result/infer/'
10 | weight_file: 'final.pth'
11 | train: false
12 | infer: true
13 | pretrained: false
14 | nn_framework: torch
15 | batch_size: 4
16 | data_specs:
17 |   width: 512
18 |   height: 512
19 |   dtype:
20 |   image_type: zscore
21 |   rescale: false
22 |   rescale_minima: auto
23 |   rescale_maxima: auto
24 |   channels: 3
25 |   label_type: mask
26 |   is_categorical: false
27 |   mask_channels: 1
28 |   val_holdout_frac: 0.1
29 |   data_workers:
30 | 
31 | training_data_csv:
32 | validation_data_csv:
33 | inference_data_csv: '../csvs/'
34 | training_augmentation:
35 |   augmentations:
36 |   p: 1.0
37 |   shuffle: true
38 | validation_augmentation:
39 |   augmentations:
40 |   p: 1.0
41 | inference_augmentation:
42 |   augmentations:
43 |   p: 1.0
44 | training:
45 |   epochs: 300
46 |   steps_per_epoch:
47 |   optimizer: Adam
48 |   lr: 1e-4
49 |   opt_args:
50 |   loss:
51 |     bcewithlogits:
52 |     jaccard:
53 |   loss_weights:
54 |     bcewithlogits: 10
55 |     jaccard: 2.5
56 |   metrics:
57 |     training: f1_score
58 |     validation: f1_score
59 |   checkpoint_frequency: 10
60 |   callbacks:
61 |     early_stopping:
62 |       patience: 24
63 |     model_checkpoint:
64 |       filepath:
65 |       monitor: val_loss
66 |     lr_schedule:
67 |       schedule_type: arbitrary
68 |       schedule_dict:
69 |         milestones:
70 |           - 200
71 |         gamma: 0.1
72 |   model_dest_path:
73 |   verbose: true
74 | 
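
The `arbitrary` lr_schedule above supplies `milestones` and `gamma`. A sketch of the behavior, assuming (the actual wiring lives elsewhere in nets/) that this maps onto PyTorch's built-in multi-step scheduler:

    # Assumption: schedule_type 'arbitrary' with milestones/gamma behaves like
    # torch.optim.lr_scheduler.MultiStepLR: multiply lr by gamma (0.1) at epoch 200.
    import torch

    model = torch.nn.Linear(8, 1)  # stand-in model for illustration
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[200], gamma=0.1)
    for epoch in range(300):
        # ... train one epoch, then:
        scheduler.step()
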
--------------------------------------------------------------------------------
/yml/train.yml:
--------------------------------------------------------------------------------
1 | 
2 | # Choose a model listed in 'nets/zoo/__init__.py'.
3 | 
4 | # model_name: ternaus
5 | model_name: unet_BE
6 | 
7 | # aoi   2      3      4         5         6        7       8              9
8 | # Area  Vegas  Paris  Shanghai  Khartoum  Urban3D  WHU-HR  Massachusetts  WHU-LR
9 | 
10 | aoi: 6
11 | 
12 | # If you adopt the BE module in your model, set 'boundary' to True.
13 | boundary: True
14 | # boundary: False
15 | 
16 | # Number of stages in the encoder. U-Net and ResUNet have 5 stages in their architectures, while TernausNet has 6.
17 | num_stage: 5
18 | 
19 | # Pretrained model path
20 | model_path: ''
21 | # model_path: '../result/models_weight/{WEIGHT_DIR}/{PRETRAIN_FILE}.pth'
22 | 
23 | 
24 | train: true
25 | infer: false
26 | pretrained: False
27 | nn_framework: torch
28 | batch_size: 4
29 | data_specs:
30 |   width: 512
31 |   height: 512
32 |   dtype:
33 |   image_type: zscore
34 |   rescale: false
35 |   rescale_minima: auto
36 |   rescale_maxima: auto
37 |   channels: 3
38 |   label_type: mask
39 |   is_categorical: false
40 |   mask_channels: 2
41 |   val_holdout_frac: 0.175
42 |   data_workers:
43 |   num_classes: 1
44 | 
45 | train_csv_dir: '../csvs/'
46 | validation_data_csv:
47 | inference_data_csv:
48 | 
49 | # No augmentations are applied beyond a center crop.
50 | # If you want to add any, follow the description in 'nets/transform.py'.
51 | training_augmentation:
52 |   augmentations:
53 |     CenterCrop:
54 |       height: 512
55 |       width: 512
56 |       p: 1.0
57 |   p: 1.0
58 |   shuffle: true
59 | validation_augmentation:
60 |   augmentations:
61 |     CenterCrop:
62 |       height: 512
63 |       width: 512
64 |       p: 1.0
65 |   p: 1.0
66 | inference_augmentation:
67 |   augmentations:
68 |   p: 1.0
69 | 
70 | # A large epoch count is set because EarlyStopping terminates training.
71 | training:
72 |   epochs: 10000
73 |   steps_per_epoch:
74 |   optimizer: adam
75 |   lr: 1e-4
76 |   opt_args:
77 | 
78 | 
79 |   # The BE module uses focal + MS-SSIM + BCE losses.
80 |   # If you do not use the BE module (boundary: False), 'loss_mask' and 'loss_boundary' are ignored.
81 |   loss:
82 |     focal:
83 |   loss_weights:
84 |     focal: 1
85 |   loss_mask:
86 |     msssim:
87 |   loss_mask_weights:
88 |     msssim: 1
89 |   loss_boundary:
90 |     bce:
91 |   loss_boundary_weights:
92 |     bce: 1
93 |   metrics:
94 |     training: p
95 |     validation: f
96 |   checkpoint_frequency: 10
97 |   callbacks:
98 |     early_stopping:
99 |       patience: 15
100 |     model_checkpoint:
101 |       filepath: '../result/models_weight/'
102 |       path_aoi:
103 |       monitor: val_loss
104 |     lr_schedule:
105 |       schedule_type: arbitrary
106 |       schedule_dict:
107 |   verbose: true
108 | 
109 | inference:
110 |   window_step_size_x:
111 |   window_step_size_y:
112 |   output_dir:
113 | 
--------------------------------------------------------------------------------
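
As the comments in train.yml note, the BE configuration drives three loss groups. A minimal sketch of how such weights could be combined (illustration only; the repo's actual implementations live in nets/losses.py and nets/_torch_losses.py, and `focal_loss`/`msssim_loss` below are hypothetical stand-ins for those callables):

    import torch

    bce = torch.nn.BCEWithLogitsLoss()

    def be_total_loss(output, mask_out, boundary_out, mask_gt, boundary_gt,
                      focal_loss, msssim_loss,
                      w_focal=1.0, w_msssim=1.0, w_bce=1.0):
        """Weighted sum of the three loss groups configured in train.yml."""
        return (w_focal * focal_loss(output, mask_gt)         # 'loss' group
                + w_msssim * msssim_loss(mask_out, mask_gt)   # 'loss_mask' group
                + w_bce * bce(boundary_out, boundary_gt))     # 'loss_boundary' group
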