├── .gitignore
├── LICENSE
├── README.md
├── augmentations.py
├── configs
│   ├── america.yaml
│   ├── asia.yaml
│   ├── base.yaml
│   ├── baseline.yaml
│   ├── baseline_10mbands.yaml
│   ├── baseline_bceloss.yaml
│   ├── baseline_diceloss.yaml
│   ├── baseline_europe.yaml
│   ├── baseline_frankensteinloss.yaml
│   ├── baseline_highlr.yaml
│   ├── baseline_jaccardlikeloss.yaml
│   ├── baseline_jaccardlikeloss_long.yaml
│   ├── baseline_largeb.yaml
│   ├── baseline_largecrops_smallb.yaml
│   ├── baseline_long.yaml
│   ├── baseline_lowlr.yaml
│   ├── baseline_oversampling.yaml
│   ├── baseline_smallcrops.yaml
│   ├── baseline_verylargeb.yaml
│   ├── baseline_verylowlr.yaml
│   ├── debug.yaml
│   ├── deterministic_debug.yaml
│   ├── dualstream.yaml
│   ├── dualstreamunet_debug.yaml
│   ├── europe.yaml
│   ├── fusion.yaml
│   ├── fusion_1.yaml
│   ├── fusion_10.yaml
│   ├── fusion_2.yaml
│   ├── fusion_3.yaml
│   ├── fusion_4.yaml
│   ├── fusion_5.yaml
│   ├── fusion_6.yaml
│   ├── fusion_7.yaml
│   ├── fusion_8.yaml
│   ├── fusion_9.yaml
│   ├── fusion_dualstream.yaml
│   ├── fusion_dualstream_1.yaml
│   ├── fusion_dualstream_10.yaml
│   ├── fusion_dualstream_2.yaml
│   ├── fusion_dualstream_3.yaml
│   ├── fusion_dualstream_4.yaml
│   ├── fusion_dualstream_5.yaml
│   ├── fusion_dualstream_6.yaml
│   ├── fusion_dualstream_7.yaml
│   ├── fusion_dualstream_8.yaml
│   ├── fusion_dualstream_9.yaml
│   ├── fusion_evenlargerb.yaml
│   ├── fusion_fewbands.yaml
│   ├── fusion_highlr.yaml
│   ├── fusion_lowlr.yaml
│   ├── fusion_sgd.yaml
│   ├── initial_parameter_search.yaml
│   ├── model_architecture_explorer.yaml
│   ├── new_base.yaml
│   ├── new_baseline.yaml
│   ├── optical.yaml
│   ├── optical_1.yaml
│   ├── optical_10.yaml
│   ├── optical_2.yaml
│   ├── optical_3.yaml
│   ├── optical_4.yaml
│   ├── optical_5.yaml
│   ├── optical_6.yaml
│   ├── optical_7.yaml
│   ├── optical_8.yaml
│   ├── optical_9.yaml
│   ├── optical_visualization.yaml
│   ├── radar_debug.yaml
│   ├── sar.yaml
│   ├── sar_1.yaml
│   ├── sar_10.yaml
│   ├── sar_2.yaml
│   ├── sar_3.yaml
│   ├── sar_4.yaml
│   ├── sar_5.yaml
│   ├── sar_6.yaml
│   ├── sar_7.yaml
│   ├── sar_8.yaml
│   ├── sar_9.yaml
│   ├── threshold_debug.yaml
│   ├── train_multiplier_debug.yaml
│   └── urban_extraction_loader.yaml
├── custom.py
├── datasets.py
├── evaluation.py
├── evaluation_metrics.py
├── experiment_manager
│   ├── __init__.py
│   ├── args.py
│   └── config
│       ├── __init__.py
│       ├── config.py
│       └── defaults.py
├── inference.py
├── loss_functions.py
├── make_xys.py
├── network_summary.py
├── networks
│   ├── __init__.py
│   ├── daudt2018.py
│   ├── network_loader.py
│   ├── ours.py
│   └── papadomanolaki2019.py
├── preprocess.py
├── preprocessing.py
├── tools.py
├── train.py
├── train_network.py
├── urban_extraction.py
└── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 
42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 
102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. 
However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 
214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | <one line to give the program's name and a brief idea of what it does.> 294 | Copyright (C) <year> <name of author> 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 331 | 332 | <signature of Ty Coon>, 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. 
If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License. 340 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UNetLSTM 2 | Code of the following manuscript: 3 | 4 | 'Detecting Urban Changes With Recurrent Neural Networks From Multitemporal Sentinel-2 Data' 5 | 6 | https://arxiv.org/abs/1910.07778 7 | 8 | # Steps 9 | # 1. Preprocessing with preprocess.py 10 | Create a folder (e.g. 'images') containing the raw data with the following structure: 11 | 12 | / images / city / imgs_i / (13 2D tif images, one per Sentinel-2 channel) 13 | 14 | where i = [1, 2, 3, 4, 5] 15 | 16 | and city = ['abudhabi', 'aguasclaras', 'beihai', 'beirut', 'bercy', 'bordeaux', 'brasilia', 'chongqing', 17 | 'cupertino', 'dubai', 'hongkong', 'lasvegas', 'milano', 'montpellier', 'mumbai', 'nantes', 18 | 'norcia', 'paris', 'pisa', 'rennes', 'rio', 'saclay_e', 'saclay_w', 'valencia'] 19 | 20 | Use preprocess.py to preprocess the images of the OSCD dataset. 21 | 22 | # 2. Create a csv file with the (x, y) locations for patch extraction during training, using make_xys.py 23 | Here you need to specify the folder with the OSCD dataset's Labels. 24 | 25 | Note that the 'train_areas' list must be defined identically in make_xys.py and main.py. 26 | 27 | # 3. Start the training process with main.py 28 | 29 | # 4. Make predictions on the OSCD dataset's testing images with inference.py 30 | 31 | Comments are included in the scripts for further instructions. 32 | 
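For orientation, the four steps above amount to running four scripts in sequence. A minimal sketch, assuming that paths and the 'train_areas' list are edited inside the scripts (as the note above implies) rather than passed on the command line; note also that the file tree of this repository ships train.py/train_network.py rather than main.py, so the training entry point may differ from the original README:

```
python preprocess.py   # 1. preprocess the raw Sentinel-2 images
python make_xys.py     # 2. write the (x, y) patch locations to a csv file
python main.py         # 3. train the network (train.py / train_network.py in this fork)
python inference.py    # 4. predict change maps for the OSCD test cities
```
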
33 | If you find this work useful, please consider citing: M. Papadomanolaki, S. Verma, M. Vakalopoulou, S. Gupta and K. Singh, 'Detecting Urban Changes With Recurrent Neural Networks From Multitemporal Sentinel-2 Data', IGARSS 2019, Yokohama, Japan 34 | -------------------------------------------------------------------------------- /augmentations.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms.functional as TF 2 | from torchvision import transforms 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def compose_transformations(cfg): 8 | transformations = [] 9 | 10 | if cfg.AUGMENTATION.CROP_TYPE == 'uniform': 11 | transformations.append(UniformCrop(crop_size=cfg.AUGMENTATION.CROP_SIZE)) 12 | elif cfg.AUGMENTATION.CROP_TYPE == 'importance': 13 | transformations.append(ImportanceRandomCrop(crop_size=cfg.AUGMENTATION.CROP_SIZE)) 14 | 15 | if cfg.AUGMENTATION.RANDOM_FLIP: 16 | transformations.append(RandomFlip()) 17 | 18 | if cfg.AUGMENTATION.RANDOM_ROTATE: 19 | transformations.append(RandomRotate()) 20 | 21 | transformations.append(Numpy2Torch()) 22 | 23 | return transforms.Compose(transformations) 24 | 25 | 26 | class Numpy2Torch(object): 27 | def __call__(self, sample: tuple): 28 | img1, img2, label = sample 29 | img1_tensor = TF.to_tensor(img1) 30 | img2_tensor = TF.to_tensor(img2) 31 | label_tensor = TF.to_tensor(label) 32 | return img1_tensor, img2_tensor, label_tensor 33 | 34 | 35 | class RandomFlip(object): 36 | def __call__(self, sample): 37 | img1, img2, label = sample 38 | h_flip = np.random.choice([True, False]) 39 | v_flip = np.random.choice([True, False]) 40 | 41 | if h_flip: 42 | img1 = np.flip(img1, axis=1).copy() 43 | img2 = np.flip(img2, axis=1).copy() 44 | label = np.flip(label, axis=1).copy() 45 | 46 | if v_flip: 47 | img1 = np.flip(img1, axis=0).copy() 48 | img2 = np.flip(img2, axis=0).copy() 49 | label = np.flip(label, axis=0).copy() 50 | 51 | return img1, img2, label 52 | 53 | 54 | class RandomRotate(object): 55 | def __call__(self, sample): 56 | img1, img2, label = sample 57 | k = np.random.randint(1, 4) # number of 90 degree rotations 58 | img1 = np.rot90(img1, k, axes=(0, 1)).copy() 59 | img2 = np.rot90(img2, k, axes=(0, 1)).copy() 60 | label = np.rot90(label, k, axes=(0, 1)).copy() 61 | return img1, img2, label 62 | 63 | 64 | # Performs uniform cropping on images 65 | class UniformCrop(object): 66 | def __init__(self, crop_size): 67 | self.crop_size = crop_size 68 | 69 | def random_crop(self, img1: np.ndarray, img2: np.ndarray, label: np.ndarray): 70 | height, width, _ = label.shape 71 | crop_limit_x = width - self.crop_size 72 | crop_limit_y = height - self.crop_size 73 | x = np.random.randint(0, crop_limit_x) 74 | y = np.random.randint(0, crop_limit_y) 75 | 76 | img1_crop = img1[y:y + self.crop_size, x:x + self.crop_size] 77 | img2_crop = img2[y:y + self.crop_size, x:x + self.crop_size] 78 | label_crop = label[y:y + self.crop_size, x:x + self.crop_size] 79 | return img1_crop, img2_crop, label_crop 80 | 81 | def __call__(self, sample: tuple): 82 | img1, img2, label = sample 83 | img1, img2, label = self.random_crop(img1, img2, label) 84 | return img1, img2, label 85 | 86 | 87 | class ImportanceRandomCrop(UniformCrop): # favors crops that contain changed pixels 88 | def __call__(self, sample): 89 | img1, img2, label = sample 90 | 91 | sample_size = 20 # number of candidate crops 92 | balancing_factor = 5 # keeps all-negative crops selectable 93 | 94 | random_crops = [self.random_crop(img1, img2, label) for _ in range(sample_size)] 95 | crop_weights = np.array([crop_label.sum() for _, _, crop_label in random_crops]) + balancing_factor 96 | crop_weights = crop_weights / crop_weights.sum() 97 | 98 | sample_idx = np.random.choice(sample_size, p=crop_weights) 99 | img1, img2, label = random_crops[sample_idx] 100 | 101 | return img1, img2, label 102 | 103 | 
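To make the pipeline above concrete, here is a minimal, self-contained usage sketch. The `_Cfg`/`_Aug` classes are hypothetical stand-ins for the yacs-style config node that `compose_transformations` reads (the real one is built by experiment_manager.config), and the arrays are random stand-ins for co-registered H x W x C image pairs:

```python
import numpy as np
# compose_transformations and the transform classes are defined in augmentations.py above

class _Aug:  # hypothetical stand-in for cfg.AUGMENTATION
    CROP_TYPE = 'importance'
    CROP_SIZE = 32
    RANDOM_FLIP = True
    RANDOM_ROTATE = True

class _Cfg:
    AUGMENTATION = _Aug

transform = compose_transformations(_Cfg)

img_t1 = np.random.rand(128, 128, 13).astype(np.float32)         # pre-change acquisition
img_t2 = np.random.rand(128, 128, 13).astype(np.float32)         # post-change acquisition
label = (np.random.rand(128, 128, 1) > 0.95).astype(np.float32)  # sparse change mask

x1, x2, y = transform((img_t1, img_t2, label))
print(x1.shape, x2.shape, y.shape)  # torch.Size([13, 32, 32]) twice, torch.Size([1, 32, 32])
```

The importance crop weights each candidate patch by its changed-pixel count (plus the balancing constant), so changed pixels are far more likely to appear in a training patch, which matters on OSCD where change is rare.
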
-------------------------------------------------------------------------------- /configs/america.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | DATASET: 4 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 5 | TRAIN: ['aguasclaras', 'cupertino'] 6 | TEST: ['brasilia', 'rio', 'lasvegas'] -------------------------------------------------------------------------------- /configs/asia.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | DATASET: 4 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 5 | TRAIN: ['abudhabi', 'beihai', 'hongkong', 'beirut', 'mumbai'] 6 | TEST: ['dubai', 'chongqing'] 7 | -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | SEED: 7 2 | THRESH: 0.0 3 | DEBUG: False 4 | SAVE_MODEL: True 5 | LOGGING: 1 6 | MODEL: 7 | TYPE: 'unet' # should support unet, unet_lstm, siamese_conc, siamese_diff 8 | TOPOLOGY: [64, 128, 256, 512,] 9 | OUT_CHANNELS: 1 10 | IN_CHANNELS: 26 11 | LOSS_TYPE: 'JaccardLikeLoss' 12 | POSITIVE_WEIGHT: 0.9 13 | PRINT_SUMMARY: False 14 | DATALOADER: 15 | NUM_WORKER: 8 16 | SHUFFLE: True 17 | DATASET: 18 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 19 | MODE: 'optical' # optical, sar or fusion 20 | SENTINEL1_BANDS: ['VV'] 21 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 22 | TRAIN: ['aguasclaras', 'bercy', 'bordeaux', 'nantes', 'paris', 'rennes', 'saclay_e', 'abudhabi', 'cupertino', 23 | 'pisa', 'beihai', 'hongkong', 'beirut', 'mumbai'] 24 | TEST: ['brasilia', 'montpellier', 'norcia', 'rio', 'saclay_w', 'valencia', 'dubai', 'lasvegas', 'milano', 25 | 'chongqing'] 26 | TRAIN_MULTIPLIER: 5 27 | OUTPUT_BASE_DIR: '/storage/shafner/urban_change_detection/run_logs/' 28 | TRAINER: 29 | LR: 5e-5 30 | EPOCHS: 400 31 | BATCH_SIZE: 64 32 | OPTIMIZER: 'adam' # adam or sgd 33 | AUGMENTATION: 34 | OVERSAMPLING: 'none' # none, pixel or change 35 | CROP_TYPE: 'importance' # uniform or importance 36 | CROP_SIZE: 32 37 | RANDOM_FLIP: True 38 | RANDOM_ROTATE: True -------------------------------------------------------------------------------- /configs/baseline.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.76 4 | 5 | -------------------------------------------------------------------------------- /configs/baseline_10mbands.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.0 4 | LOGGING: 1 5 | SAVE_MODEL: True 6 | 7 | MODEL: 8 | LOSS_TYPE: 'JaccardLikeLoss' 9 | IN_CHANNELS: 8 10 | 11 | TRAINER: 12 | LR: 5e-5 13 | EPOCHS: 1500 14 | 15 | DATASET: 16 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 17 | MODE: 'optical' # optical, radar or fusion 18 | SENTINEL2: 19 | BANDS: ['B02', 'B03', 'B04', 'B08'] 20 | TEMPORAL_MODE: 'bi-temporal'
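The config files above follow a yacs-style inheritance pattern: a file names its parent in `_BASE_` and overrides only the keys it lists, so e.g. baseline.yaml is base.yaml with THRESH raised to 0.76. The repository's actual loader lives in experiment_manager/config (not shown in this excerpt); a minimal sketch of how such recursive merging typically works:

```python
import yaml
from pathlib import Path

def merge(base: dict, override: dict) -> dict:
    """Recursively overlay `override` onto `base`; scalars and lists replace, dicts merge."""
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)  # descend into sections like MODEL or DATASET
        else:
            out[key] = value
    return out

def load_config(path) -> dict:
    path = Path(path)
    cfg = yaml.safe_load(path.read_text())
    base_name = cfg.pop('_BASE_', None)  # e.g. "base.yaml", resolved next to the child file
    return cfg if base_name is None else merge(load_config(path.parent / base_name), cfg)

cfg = load_config('configs/baseline.yaml')
print(cfg['THRESH'], cfg['MODEL']['IN_CHANNELS'])  # 0.76 26
```

Note also how MODEL.IN_CHANNELS tracks the band selection across two acquisition dates: 13 Sentinel-2 bands x 2 dates = 26 for 'optical', 28 once the VV channel of both dates is fused in, and 4 x 2 = 8 for the 10 m-band variant above.
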
-------------------------------------------------------------------------------- /configs/baseline_bceloss.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.69 4 | LOGGING: 1 5 | SAVE_MODEL: True 6 | 7 | MODEL: 8 | LOSS_TYPE: 'WeightedBCEWithLogitsLoss' 9 | POSITIVE_WEIGHT: 0.8 10 | 11 | TRAINER: 12 | LR: 5e-5 13 | EPOCHS: 1500 -------------------------------------------------------------------------------- /configs/baseline_diceloss.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.8 4 | LOGGING: 10 5 | 6 | MODEL: 7 | LOSS_TYPE: 'SoftDiceLoss' 8 | 9 | TRAINER: 10 | EPOCHS: 1000 -------------------------------------------------------------------------------- /configs/baseline_europe.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | LOGGING: 1 5 | SAVE_MODEL: True 6 | 7 | MODEL: 8 | LOSS_TYPE: 'JaccardLikeLoss' 9 | 10 | TRAINER: 11 | LR: 5e-5 12 | EPOCHS: 2000 13 | BATCH_SIZE: 6 14 | 15 | AUGMENTATION: 16 | CROP_TYPE: 'importance' # uniform or importance 17 | CROP_SIZE: 32 18 | RANDOM_FLIP: True 19 | RANDOM_ROTATE: True 20 | 21 | DATASET: 22 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 23 | TRAIN: ['bercy', 'bordeaux', 'nantes', 'paris', 'rennes', 'saclay_e', 'pisa'] 24 | TEST: ['montpellier', 'norcia', 'saclay_w', 'valencia', 'milano'] -------------------------------------------------------------------------------- /configs/baseline_frankensteinloss.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.8 4 | 5 | MODEL: 6 | LOSS_TYPE: 'FrankensteinLoss' -------------------------------------------------------------------------------- /configs/baseline_highlr.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.76 4 | LOGGING: 1 5 | SAVE_MODEL: False 6 | TRAINER: 7 | LR: 5e-3 8 | EPOCHS: 1000 -------------------------------------------------------------------------------- /configs/baseline_jaccardlikeloss.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.8 4 | 5 | MODEL: 6 | LOSS_TYPE: 'JaccardLikeLoss' -------------------------------------------------------------------------------- /configs/baseline_jaccardlikeloss_long.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.69 4 | LOGGING: 1 5 | SAVE_MODEL: True 6 | 7 | MODEL: 8 | LOSS_TYPE: 'JaccardLikeLoss' 9 | 10 | TRAINER: 11 | LR: 5e-5 12 | EPOCHS: 1500 -------------------------------------------------------------------------------- /configs/baseline_largeb.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.69 4 | LOGGING: 1 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | LOSS_TYPE: 'JaccardLikeLoss' 9 | 10 | TRAINER: 11 | LR: 5e-5 12 | EPOCHS: 1500 13 | BATCH_SIZE: 6 -------------------------------------------------------------------------------- /configs/baseline_largecrops_smallb.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.0 4 | LOGGING: 1 5 | SAVE_MODEL: True 6 | 7 | MODEL: 8 | LOSS_TYPE: 'JaccardLikeLoss' 9 | 10 | TRAINER: 11 | LR: 5e-5 12 | EPOCHS: 1500 13 | BATCH_SIZE: 2 14 | 15 | AUGMENTATION: 16 | CROP_SIZE: 64 --------------------------------------------------------------------------------
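Several of the configs above switch MODEL.LOSS_TYPE between 'WeightedBCEWithLogitsLoss', 'SoftDiceLoss' and 'JaccardLikeLoss'. The actual implementations live in loss_functions.py (not shown in this excerpt); purely as an illustration of the family these names suggest, a minimal soft-Jaccard (soft IoU) loss for the single-channel change map could look like this:

```python
import torch

def soft_jaccard_loss(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Illustrative soft-IoU loss; the repository's JaccardLikeLoss may differ in detail."""
    probs = torch.sigmoid(logits)  # MODEL.OUT_CHANNELS is 1, so a single binary change map
    intersection = (probs * targets).sum()
    union = probs.sum() + targets.sum() - intersection
    return 1.0 - (intersection + eps) / (union + eps)
```

Because overlap-based losses like this operate on the sigmoid output rather than on hard 0/1 decisions, the decision threshold remains a free parameter, which is presumably why each config carries its own tuned THRESH value.
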
/configs/baseline_long.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.76 4 | 5 | LOGGING: 1 6 | 7 | TRAINER: 8 | EPOCHS: 1000 -------------------------------------------------------------------------------- /configs/baseline_lowlr.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.64 4 | LOGGING: 1 5 | SAVE_MODEL: True 6 | TRAINER: 7 | LR: 5e-5 8 | EPOCHS: 1000 -------------------------------------------------------------------------------- /configs/baseline_oversampling.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "new_base.yaml" 2 | 3 | DEBUG: False 4 | SAVE_MODEL: False 5 | 6 | AUGMENTATION: 7 | OVERSAMPLING: 'pixel' -------------------------------------------------------------------------------- /configs/baseline_smallcrops.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "new_base.yaml" 2 | 3 | AUGMENTATION: 4 | CROP_SIZE: 16 -------------------------------------------------------------------------------- /configs/baseline_verylargeb.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.69 4 | 5 | MODEL: 6 | LOSS_TYPE: 'JaccardLikeLoss' 7 | 8 | TRAINER: 9 | LR: 5e-5 10 | EPOCHS: 1500 11 | BATCH_SIZE: 12 -------------------------------------------------------------------------------- /configs/baseline_verylowlr.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.76 4 | LOGGING: 1 5 | SAVE_MODEL: False 6 | TRAINER: 7 | LR: 1e-5 8 | EPOCHS: 2000 -------------------------------------------------------------------------------- /configs/debug.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | DEBUG: True 4 | -------------------------------------------------------------------------------- /configs/deterministic_debug.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | 5 | SAVE_MODEL: False 6 | DEBUG: False 7 | 8 | SEED: 7 9 | 10 | MODEL: 11 | IN_CHANNELS: 2 12 | 13 | TRAINER: 14 | LR: 1e-2 15 | 16 | DATASET: 17 | MODE: 'sar' 18 | SENTINEL1_BANDS: ['VV'] 19 | 20 | AUGMENTATION: 21 | RANDOM_FLIP: True 22 | RANDOM_ROTATE: True 23 | CROP_TYPE: 'importance' -------------------------------------------------------------------------------- /configs/dualstream.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.0 4 | SAVE_MODEL: True 5 | DEBUG: False 6 | 7 | MODEL: 8 | TYPE: 'dualstreamunet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/dualstreamunet_debug.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.0 4 | SAVE_MODEL: False 5 | DEBUG: True 6 | 7 | 8 | MODEL: 9 | TYPE: 'dualstreamunet' 10 | IN_CHANNELS: 28 11 | 12 | DATASET: 13 | MODE: 'fusion' 14 | SENTINEL1_BANDS: ['VV'] 15 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 
'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 16 | -------------------------------------------------------------------------------- /configs/europe.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | DATASET: 4 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 5 | TRAIN: ['bercy', 'bordeaux', 'nantes', 'paris', 'rennes', 'saclay_e', 'pisa'] 6 | TEST: ['montpellier', 'norcia', 'saclay_w', 'valencia', 'milano'] -------------------------------------------------------------------------------- /configs/fusion.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "new_base.yaml" 2 | 3 | THRESH: 0.0 4 | 5 | MODEL: 6 | IN_CHANNELS: 28 7 | 8 | DATASET: 9 | MODE: 'fusion' # optical, radar or fusion 10 | SENTINEL1: 11 | BANDS: ['VV'] 12 | SENTINEL2: 13 | BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/fusion_1.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 1 4 | THRESH: 0.00 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/fusion_10.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 10 4 | THRESH: 0.00 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/fusion_2.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 2 4 | THRESH: 0.00 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/fusion_3.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 3 4 | THRESH: 0.00 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/fusion_4.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 4 4 | THRESH: 0.00 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | 
SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/fusion_5.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 5 4 | THRESH: 0.00 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/fusion_6.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 6 4 | THRESH: 0.00 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/fusion_7.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 7 4 | THRESH: 0.00 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/fusion_8.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 8 4 | THRESH: 0.00 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/fusion_9.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 9 4 | THRESH: 0.75 5 | SAVE_MODEL: True 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/fusion_dualstream.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.66 4 | DEBUG: False 5 | 6 | MODEL: 7 | TYPE: 'dualstreamunet' 8 | IN_CHANNELS: 28 9 | 10 | DATASET: 11 | MODE: 'fusion' 12 | SENTINEL1_BANDS: ['VV'] 13 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/fusion_dualstream_1.yaml: 
-------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 1 4 | SAVE_MODEL: False 5 | THRESH: 0.66 6 | 7 | MODEL: 8 | TYPE: 'dualstreamunet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/fusion_dualstream_10.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 10 4 | THRESH: 0.66 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'dualstreamunet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKER: 16 -------------------------------------------------------------------------------- /configs/fusion_dualstream_2.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 2 4 | SAVE_MODEL: False 5 | THRESH: 0.66 6 | 7 | MODEL: 8 | TYPE: 'dualstreamunet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/fusion_dualstream_3.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 3 4 | SAVE_MODEL: False 5 | DEBUG: True 6 | THRESH: 0.66 7 | 8 | MODEL: 9 | TYPE: 'dualstreamunet' 10 | IN_CHANNELS: 28 11 | 12 | DATASET: 13 | MODE: 'fusion' 14 | SENTINEL1_BANDS: ['VV'] 15 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/fusion_dualstream_4.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 4 4 | SAVE_MODEL: False 5 | THRESH: 0.66 6 | 7 | MODEL: 8 | TYPE: 'dualstreamunet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/fusion_dualstream_5.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 5 4 | THRESH: 0.66 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'dualstreamunet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/fusion_dualstream_6.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 6 4 | THRESH: 0.66 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'dualstreamunet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 
'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKER: 16 -------------------------------------------------------------------------------- /configs/fusion_dualstream_7.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 7 4 | THRESH: 0.79 5 | SAVE_MODEL: True 6 | 7 | MODEL: 8 | TYPE: 'dualstreamunet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKER: 16 -------------------------------------------------------------------------------- /configs/fusion_dualstream_8.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 8 4 | THRESH: 0.66 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'dualstreamunet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKER: 16 -------------------------------------------------------------------------------- /configs/fusion_dualstream_9.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | SEED: 9 4 | THRESH: 0.66 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'dualstreamunet' 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' 13 | SENTINEL1_BANDS: ['VV'] 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKER: 16 -------------------------------------------------------------------------------- /configs/fusion_evenlargerb.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "new_base.yaml" 2 | 3 | THRESH: 0.0 4 | 5 | TRAINER: 6 | BATCH_SIZE: 32 7 | 8 | MODEL: 9 | IN_CHANNELS: 28 10 | 11 | DATASET: 12 | MODE: 'fusion' # optical, radar or fusion 13 | SENTINEL1: 14 | BANDS: ['VV'] 15 | SENTINEL2: 16 | BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 17 | TRAIN_MULTIPLIER: 5 -------------------------------------------------------------------------------- /configs/fusion_fewbands.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "new_base.yaml" 2 | 3 | THRESH: 0.0 4 | 5 | MODEL: 6 | IN_CHANNELS: 10 7 | 8 | DATASET: 9 | MODE: 'fusion' # optical, radar or fusion 10 | SENTINEL1: 11 | BANDS: ['VV'] 12 | SENTINEL2: 13 | BANDS: ['B02', 'B03', 'B04','B08'] -------------------------------------------------------------------------------- /configs/fusion_highlr.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "new_base.yaml" 2 | 3 | THRESH: 0.0 4 | 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | IN_CHANNELS: 28 9 | 10 | TRAINER: 11 | LR: 1e-3 12 | 13 | DATASET: 14 | MODE: 'fusion' # optical, radar or fusion 15 | SENTINEL1: 16 | BANDS: ['VV'] 17 | SENTINEL2: 18 | BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/fusion_lowlr.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "new_base.yaml" 2 | 3 | THRESH: 0.0 4 | 5 | 
SAVE_MODEL: False 6 | 7 | MODEL: 8 | IN_CHANNELS: 28 9 | 10 | TRAINER: 11 | LR: 1e-5 12 | 13 | DATASET: 14 | MODE: 'fusion' # optical, radar or fusion 15 | SENTINEL1: 16 | BANDS: ['VV'] 17 | SENTINEL2: 18 | BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/fusion_sgd.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "new_base.yaml" 2 | 3 | THRESH: 0.0 4 | 5 | MODEL: 6 | IN_CHANNELS: 28 7 | 8 | TRAINER: 9 | OPTIMIZER: 'sgd' 10 | 11 | DATASET: 12 | MODE: 'fusion' # optical, radar or fusion 13 | SENTINEL1: 14 | BANDS: ['VV'] 15 | SENTINEL2: 16 | BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/initial_parameter_search.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.0 4 | 5 | TRAINER: 6 | BATCH_SIZE: 14 7 | -------------------------------------------------------------------------------- /configs/model_architecture_explorer.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | SAVE_MODEL: False 5 | DEBUG: True 6 | SEED: 1 7 | 8 | MODEL: 9 | IN_CHANNELS: 2 10 | 11 | DATASET: 12 | MODE: 'sar' 13 | SENTINEL1_BANDS: ['VV'] -------------------------------------------------------------------------------- /configs/new_base.yaml: -------------------------------------------------------------------------------- 1 | SEED: 7 2 | THRESH: 0.0 3 | DEBUG: False 4 | SAVE_MODEL: True 5 | LOGGING: 1 6 | MODEL: 7 | TYPE: 'dualstreamunet' # should support unet, unet_lstm, siamese_conc, siamese_diff 8 | TOPOLOGY: [64, 128, 256, 512,] 9 | OUT_CHANNELS: 1 10 | IN_CHANNELS: 26 11 | LOSS_TYPE: 'JaccardLikeLoss' 12 | POSITIVE_WEIGHT: 0.9 13 | DATALOADER: 14 | NUM_WORKER: 8 15 | SHUFFLE: True 16 | DATASET: 17 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 18 | MODE: 'optical' # optical, radar or fusion 19 | SENTINEL1_BANDS: ['VV'] 20 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 21 | TRAIN: ['aguasclaras', 'bercy', 'bordeaux', 'nantes', 'paris', 'rennes', 'saclay_e', 'abudhabi', 'cupertino', 22 | 'pisa', 'beihai', 'hongkong', 'beirut', 'mumbai'] 23 | TEST: ['brasilia', 'montpellier', 'norcia', 'rio', 'saclay_w', 'valencia', 'dubai', 'lasvegas', 'milano', 24 | 'chongqing'] 25 | TRAIN_MULTIPLIER: 1 26 | OUTPUT_BASE_DIR: '/storage/shafner/urban_change_detection/run_logs/' 27 | TRAINER: 28 | LR: 5e-5 29 | EPOCHS: 1500 30 | BATCH_SIZE: 12 31 | OPTIMIZER: 'adam' # adam or sgd 32 | AUGMENTATION: 33 | OVERSAMPLING: 'none' # none, pixel or change 34 | CROP_TYPE: 'importance' # uniform or importance 35 | CROP_SIZE: 32 36 | RANDOM_FLIP: True 37 | RANDOM_ROTATE: True -------------------------------------------------------------------------------- /configs/new_baseline.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.0 4 | LOGGING: 1 5 | SAVE_MODEL: True 6 | TRAINER: 7 | LR: 5e-5 8 | EPOCHS: 2000 -------------------------------------------------------------------------------- /configs/optical.yaml: -------------------------------------------------------------------------------- 1 
| _BASE_: "base.yaml" 2 | 3 | THRESH: 0.64 # 56 80 4 | 5 | MODEL: 6 | TYPE: 'unet' 7 | IN_CHANNELS: 26 8 | 9 | DATASET: 10 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 11 | MODE: 'optical' 12 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/optical_1.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.64 # 56 80 4 | SEED: 1 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 26 10 | 11 | DATASET: 12 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 13 | MODE: 'optical' 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/optical_10.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.64 # 56 80 4 | SEED: 10 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 26 10 | 11 | DATASET: 12 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 13 | MODE: 'optical' 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/optical_2.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.64 # 56 80 4 | SEED: 2 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 26 10 | 11 | DATASET: 12 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 13 | MODE: 'optical' 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/optical_3.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.64 4 | SEED: 3 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 26 10 | 11 | DATASET: 12 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 13 | MODE: 'optical' 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/optical_4.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.64 4 | SEED: 4 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 26 10 | 11 | DATASET: 12 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 13 | MODE: 'optical' 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/optical_5.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | 
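# Aside: optical_1.yaml through optical_10.yaml below differ only in SEED (and, for
# optical_7.yaml, THRESH and SAVE_MODEL) -- they are seed replicates of the same optical
# U-Net run. A minimal sketch of generating such variants programmatically instead of
# copying files by hand; the abbreviated template and the write loop are assumptions,
# not part of this repo.
from pathlib import Path

template = (
    '_BASE_: "base.yaml"\n'
    '\n'
    'THRESH: 0.64\n'
    'SEED: {seed}\n'
    'SAVE_MODEL: False\n'
    # ... plus the MODEL / DATASET / DATALOADER sections shared by the files below ...
)

for seed in range(1, 11):
    (Path('configs') / f'optical_{seed}.yaml').write_text(template.format(seed=seed))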
THRESH: 0.64 4 | SEED: 5 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 26 10 | 11 | DATASET: 12 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 13 | MODE: 'optical' 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/optical_6.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.64 4 | SEED: 6 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 26 10 | 11 | DATASET: 12 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 13 | MODE: 'optical' 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/optical_7.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.82 4 | SEED: 7 5 | SAVE_MODEL: True 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 26 10 | 11 | DATASET: 12 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 13 | MODE: 'optical' 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/optical_8.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.64 4 | SEED: 8 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 26 10 | 11 | DATASET: 12 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 13 | MODE: 'optical' 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/optical_9.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.64 # 56 80 4 | SEED: 9 5 | SAVE_MODEL: False 6 | 7 | MODEL: 8 | TYPE: 'unet' 9 | IN_CHANNELS: 26 10 | 11 | DATASET: 12 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 13 | MODE: 'optical' 14 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/optical_visualization.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | DATASET: 4 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 5 | MODE: 'optical' 6 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/radar_debug.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "new_base.yaml" 2 | 3 | DEBUG: True 4 | 5 | TRAINER: 6 | EPOCHS: 10 7 | 8 | MODEL: 9 | OUT_CHANNELS: 1 10 | IN_CHANNELS: 2 11 | 12 | 
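# Aside: radar_debug.yaml overrides only a handful of keys; everything else is pulled in
# from new_base.yaml through the _BASE_ mechanism. A sketch of resolving such a config
# with this repo's own helpers (see experiment_manager/config); the printed values follow
# from the two files, but treat this as an illustration rather than a tested snippet.
from pathlib import Path

from experiment_manager.config import new_config

cfg = new_config()                                              # starts from defaults.py
cfg.merge_from_file(str(Path('configs') / 'radar_debug.yaml'))  # follows _BASE_ recursively
print(cfg.TRAINER.EPOCHS)  # 10, overridden in radar_debug.yaml
print(cfg.TRAINER.LR)      # 5e-5, inherited from new_base.yaml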
DATASET: 13 | MODE: 'radar' # optical, radar or fusion 14 | SENTINEL1: 15 | BANDS: ['VV'] 16 | TEMPORAL_MODE: 'bi-temporal' 17 | TRAIN: ['aguasclaras', 'bercy', 'bordeaux', 'nantes', 'paris', 'rennes', 'saclay_e', 'abudhabi', 'cupertino', 18 | 'pisa', 'beihai', 'hongkong', 'beirut', 'mumbai'] 19 | TEST: ['brasilia', 'montpellier', 'norcia', 'rio', 'saclay_w', 'valencia', 'dubai', 'lasvegas', 'milano', 20 | 'chongqing'] -------------------------------------------------------------------------------- /configs/sar.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | 5 | MODEL: 6 | IN_CHANNELS: 2 7 | 8 | DATASET: 9 | MODE: 'sar' 10 | SENTINEL1_BANDS: ['VV'] -------------------------------------------------------------------------------- /configs/sar_1.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | SAVE_MODEL: False 5 | SEED: 1 6 | 7 | MODEL: 8 | IN_CHANNELS: 2 9 | 10 | DATASET: 11 | MODE: 'sar' 12 | SENTINEL1_BANDS: ['VV'] -------------------------------------------------------------------------------- /configs/sar_10.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | 5 | SEED: 10 6 | 7 | SAVE_MODEL: False 8 | 9 | MODEL: 10 | IN_CHANNELS: 2 11 | 12 | DATASET: 13 | MODE: 'sar' 14 | SENTINEL1_BANDS: ['VV'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/sar_2.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | 5 | SEED: 2 6 | 7 | SAVE_MODEL: False 8 | 9 | MODEL: 10 | IN_CHANNELS: 2 11 | 12 | DATASET: 13 | MODE: 'sar' 14 | SENTINEL1_BANDS: ['VV'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/sar_3.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | 5 | SEED: 3 6 | 7 | SAVE_MODEL: False 8 | 9 | MODEL: 10 | IN_CHANNELS: 2 11 | 12 | DATASET: 13 | MODE: 'sar' 14 | SENTINEL1_BANDS: ['VV'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 18 | -------------------------------------------------------------------------------- /configs/sar_4.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.99 4 | 5 | SEED: 4 6 | 7 | SAVE_MODEL: True 8 | 9 | MODEL: 10 | IN_CHANNELS: 2 11 | 12 | DATASET: 13 | MODE: 'sar' 14 | SENTINEL1_BANDS: ['VV'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/sar_5.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | 5 | SEED: 5 6 | 7 | SAVE_MODEL: False 8 | 9 | MODEL: 10 | IN_CHANNELS: 2 11 | 12 | DATASET: 13 | MODE: 'sar' 14 | SENTINEL1_BANDS: ['VV'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/sar_6.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | 5 | SEED: 6 6 | 7 | SAVE_MODEL: False 8 | 9 | MODEL: 10 | IN_CHANNELS: 2 11 | 12 | DATASET: 13 | MODE: 'sar' 14 | SENTINEL1_BANDS: ['VV'] 15 | 16 
| DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/sar_7.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | 5 | SEED: 7 6 | 7 | SAVE_MODEL: False 8 | 9 | MODEL: 10 | IN_CHANNELS: 2 11 | 12 | DATASET: 13 | MODE: 'sar' 14 | SENTINEL1_BANDS: ['VV'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/sar_8.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | 5 | SEED: 8 6 | 7 | SAVE_MODEL: False 8 | 9 | MODEL: 10 | IN_CHANNELS: 2 11 | 12 | DATASET: 13 | MODE: 'sar' 14 | SENTINEL1_BANDS: ['VV'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/sar_9.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.98 4 | 5 | SEED: 9 6 | 7 | SAVE_MODEL: False 8 | 9 | MODEL: 10 | IN_CHANNELS: 2 11 | 12 | DATASET: 13 | MODE: 'sar' 14 | SENTINEL1_BANDS: ['VV'] 15 | 16 | DATALOADER: 17 | NUM_WORKERS: 16 -------------------------------------------------------------------------------- /configs/threshold_debug.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | 3 | THRESH: 0.64 4 | DEBUG: False 5 | 6 | MODEL: 7 | TYPE: 'unet' 8 | IN_CHANNELS: 26 9 | 10 | DATASET: 11 | PATH: '/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed/' 12 | MODE: 'optical' 13 | SENTINEL2_BANDS: ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] -------------------------------------------------------------------------------- /configs/train_multiplier_debug.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "new_base.yaml" 2 | 3 | THRESH: 0.0 4 | DEBUG: True 5 | 6 | DATASET: 7 | TRAIN_MULTIPLIER: 5 8 | 9 | TRAINER: 10 | BATCH_SIZE: 64 -------------------------------------------------------------------------------- /configs/urban_extraction_loader.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "new_base.yaml" 2 | 3 | DATASET: 4 | MODE: 'optical' 5 | SENTINEL2: 6 | BANDS: ['B02', 'B03', 'B04', 'B08', 'B11', 'B12'] -------------------------------------------------------------------------------- /custom.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import Dataset 2 | from torchvision import transforms 3 | from skimage import io 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | 8 | class MyDataset(Dataset): 9 | def __init__(self, csv_path, image_ids, image_folder, label_folder, nb_dates, patch_size): 10 | # Read the csv file 11 | self.data_info = pd.read_csv(str(csv_path)) 12 | 13 | self.patch_size = patch_size 14 | self.nb_dates = nb_dates 15 | 16 | self.all_imgs = [] 17 | for nd in self.nb_dates: 18 | imgs_i = [] 19 | for city in image_ids: 20 | image_file = image_folder / city / f'{city}_{nd}.npy' 21 | imgs_i.append(np.load(image_file)) 22 | self.all_imgs.append(imgs_i) 23 | 24 | self.all_labels = [] 25 | for city in image_ids: 26 | label_file = label_folder / city / 'cm' / f'{city}-cm.tif' 27 | label = io.imread(label_file) 28 | label[label == 1] = 0 29 | label[label 
== 2] = 1 30 | self.all_labels.append(label) 31 | 32 | # Calculate len 33 | self.data_len = self.data_info.shape[0] - 1 34 | 35 | def __getitem__(self, index): 36 | x = int(self.data_info.iloc[:, 0][index]) 37 | y = int(self.data_info.iloc[:, 1][index]) 38 | image_id = int(self.data_info.iloc[:, 2][index]) 39 | transformation_id = int(self.data_info.iloc[:, 3][index]) 40 | 41 | def transform_date(patch, tr_id): 42 | if tr_id == 0: 43 | patch = patch 44 | elif tr_id == 1: 45 | patch = np.rot90(patch, k=1) 46 | elif tr_id == 2: 47 | patch = np.rot90(patch, k=2) 48 | elif tr_id == 3: 49 | patch = np.rot90(patch, k=3) 50 | 51 | return patch 52 | 53 | image_patch = [] 54 | for nd in self.nb_dates: 55 | find_patch = self.all_imgs[self.nb_dates.index(nd)][image_id] [x:x + self.patch_size, y:y + self.patch_size, :] 56 | find_patch = np.concatenate( (find_patch[:,:,1:4], np.reshape(find_patch[:,:,7], (find_patch.shape[0],find_patch.shape[1],1))), 2) #take the 4 highest resolution channels 57 | image_patch.append(np.transpose(transform_date(find_patch, transformation_id), (2,0,1))) 58 | find_labels = self.all_labels[image_id] [x:x + self.patch_size, y:y + self.patch_size] 59 | label_patch = transform_date(find_labels, transformation_id) 60 | 61 | 62 | return np.ascontiguousarray(image_patch), np.ascontiguousarray(label_patch) 63 | 64 | def __len__(self): 65 | return self.data_len 66 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data as torch_data 3 | from torchvision import transforms 4 | from pathlib import Path 5 | import numpy as np 6 | import augmentations as aug 7 | import random 8 | 9 | ORBITS = { 10 | 'aguasclaras': [24], 11 | 'bercy': [59, 8, 110], 12 | 'bordeaux': [30, 8, 81], 13 | 'nantes': [30, 81], 14 | 'paris': [59, 8, 110], 15 | 'rennes': [30, 81], 16 | 'saclay_e': [59, 8], 17 | 'abudhabi': [130], 18 | 'cupertino': [35, 115, 42], 19 | 'pisa': [15, 168], 20 | 'beihai': [157], 21 | 'hongkong': [11, 113], 22 | 'beirut': [14, 87], 23 | 'mumbai': [34], 24 | 'brasilia': [24], 25 | 'montpellier': [59, 37], 26 | 'norcia': [117, 44, 22, 95], 27 | 'rio': [155], 28 | 'saclay_w': [59, 8, 110], 29 | 'valencia': [30, 103, 8, 110], 30 | 'dubai': [130, 166], 31 | 'lasvegas': [166, 173], 32 | 'milano': [66, 168], 33 | 'chongqing': [55, 164] 34 | } 35 | 36 | 37 | class OSCDDataset(torch.utils.data.Dataset): 38 | def __init__(self, cfg, dataset: str, no_augmentation: bool = False): 39 | super().__init__() 40 | 41 | self.cfg = cfg 42 | self.root_dir = Path(cfg.DATASET.PATH) 43 | 44 | if dataset == 'train': 45 | multiplier = cfg.DATASET.TRAIN_MULTIPLIER 46 | self.cities = multiplier * cfg.DATASET.TRAIN 47 | else: 48 | self.cities = cfg.DATASET.TEST 49 | 50 | self.length = len(self.cities) 51 | 52 | if no_augmentation: 53 | self.transform = transforms.Compose([aug.Numpy2Torch()]) 54 | else: 55 | self.transform = aug.compose_transformations(cfg) 56 | 57 | self.mode = cfg.DATASET.MODE 58 | 59 | # creating boolean feature vector to subset sentinel 1 and sentinel 2 bands 60 | available_features_sentinel1 = ['VV'] 61 | selected_features_sentinel1 = cfg.DATASET.SENTINEL1_BANDS 62 | self.s1_feature_selection = self._get_feature_selection(available_features_sentinel1, 63 | selected_features_sentinel1) 64 | available_features_sentinel2 = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 65 | 'B11', 'B12'] 66 | 
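# Aside: the band subsetting around this point builds a boolean mask over the available
# band names and indexes the last axis of the stored numpy arrays with it. A tiny
# self-contained illustration of the same idea; the dummy array and band subset are
# made up for the example.
import numpy as np

available = ['B02', 'B03', 'B04', 'B08']
selected = ['B03', 'B08']
mask = [band in selected for band in available]  # [False, True, False, True]

img = np.zeros((256, 256, 4), dtype=np.float32)  # H x W x bands
print(img[:, :, mask].shape)                     # (256, 256, 2)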
selected_features_sentinel2 = cfg.DATASET.SENTINEL2_BANDS 67 | self.s2_feature_selection = self._get_feature_selection(available_features_sentinel2, 68 | selected_features_sentinel2) 69 | 70 | def __getitem__(self, index): 71 | 72 | city = self.cities[index] 73 | 74 | # np.random.seed(self.cfg.SEED) 75 | # random.seed(self.cfg.SEED) 76 | 77 | # randomly choosing an orbit for sentinel1 78 | orbit = np.random.choice(ORBITS[city]) 79 | # orbit = ORBITS[city][0] 80 | 81 | if self.cfg.DATASET.MODE == 'optical': 82 | t1_img = self._get_sentinel2_data(city, 't1') 83 | t2_img = self._get_sentinel2_data(city, 't2') 84 | elif self.cfg.DATASET.MODE == 'sar': 85 | t1_img = self._get_sentinel1_data(city, orbit, 't1') 86 | t2_img = self._get_sentinel1_data(city, orbit, 't2') 87 | else: 88 | s1_t1_img = self._get_sentinel1_data(city, orbit, 't1') 89 | s2_t1_img = self._get_sentinel2_data(city, 't1') 90 | t1_img = np.concatenate((s1_t1_img, s2_t1_img), axis=2) 91 | 92 | s1_t2_img = self._get_sentinel1_data(city, orbit, 't2') 93 | s2_t2_img = self._get_sentinel2_data(city, 't2') 94 | t2_img = np.concatenate((s1_t2_img, s2_t2_img), axis=2) 95 | 96 | label = self._get_label_data(city) 97 | t1_img, t2_img, label = self.transform((t1_img, t2_img, label)) 98 | 99 | sample = { 100 | 't1_img': t1_img, 101 | 't2_img': t2_img, 102 | 'label': label, 103 | 'city': city 104 | } 105 | 106 | return sample 107 | 108 | def _get_sentinel1_data(self, city, orbit, t): 109 | file = self.root_dir / city / 'sentinel1' / f'sentinel1_{city}_{orbit}_{t}.npy' 110 | img = np.load(file)[:, :, self.s1_feature_selection] 111 | return img.astype(np.float32) 112 | 113 | def _get_sentinel2_data(self, city, t): 114 | file = self.root_dir / city / 'sentinel2' / f'sentinel2_{city}_{t}.npy' 115 | img = np.load(file)[:, :, self.s2_feature_selection] 116 | return img.astype(np.float32) 117 | 118 | def _get_label_data(self, city): 119 | label_file = self.root_dir / city / 'label' / f'urbanchange_{city}.npy' 120 | label = np.load(label_file).astype(np.float32) 121 | label = label[:, :, np.newaxis] 122 | return label 123 | 124 | def _get_feature_selection(self, features, selection): 125 | feature_selection = [False for _ in range(len(features))] 126 | for feature in selection: 127 | i = features.index(feature) 128 | feature_selection[i] = True 129 | return feature_selection 130 | 131 | def __len__(self): 132 | return self.length 133 | 134 | def sampler(self): 135 | if self.cfg.AUGMENTATION.OVERSAMPLING == 'pixel': 136 | sampling_weights = np.array([float(self._get_label_data(city).size) for city in self.cities]) 137 | if self.cfg.AUGMENTATION.OVERSAMPLING == 'change': 138 | sampling_weights = np.array([float(np.sum(self._get_label_data(city))) for city in self.cities]) 139 | sampler = torch_data.WeightedRandomSampler(weights=sampling_weights, num_samples=self.length, 140 | replacement=True) 141 | return sampler 142 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | import torchvision.transforms.functional as TF 4 | from torch.utils import data as torch_data 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from networks.network_loader import load_network 9 | import datasets 10 | from experiment_manager.config import new_config 11 | from pathlib import Path 12 | import evaluation_metrics as eval 13 | 14 | 15 | # loading cfg for inference 16 | def 
load_cfg(cfg_file: Path): 17 | cfg = new_config() 18 | cfg.merge_from_file(str(cfg_file)) 19 | return cfg 20 | 21 | 22 | # loading network for inference 23 | def load_net(cfg, net_file): 24 | 25 | net = load_network(cfg) 26 | 27 | state_dict = torch.load(str(net_file), map_location=lambda storage, loc: storage) 28 | net.load_state_dict(state_dict) 29 | 30 | mode = 'cuda' if torch.cuda.is_available() else 'cpu' 31 | device = torch.device(mode) 32 | 33 | net.to(device) 34 | net.eval() 35 | 36 | return net 37 | 38 | 39 | def visual_evaluation(root_dir: Path, cfg_file: Path, net_file: Path, dataset: str = 'test', n: int = 10, 40 | save_dir: Path = None, label_pred_only: bool = False): 41 | 42 | mode = 'cuda' if torch.cuda.is_available() else 'cpu' 43 | device = torch.device(mode) 44 | 45 | # loading cfg and network 46 | cfg = load_cfg(cfg_file) 47 | net = load_net(cfg, net_file) 48 | 49 | # bands for visualizaiton 50 | s1_bands, s2_bands = cfg.DATASET.SENTINEL1_BANDS, cfg.DATASET.SENTINEL2_BANDS 51 | all_bands = s1_bands + s2_bands 52 | 53 | dataset = datasets.OSCDDataset(cfg, dataset, no_augmentation=True) 54 | dataloader_kwargs = { 55 | 'batch_size': 1, 56 | 'num_workers': 0, 57 | 'shuffle': False, 58 | 'pin_memory': True, 59 | } 60 | dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs) 61 | 62 | with torch.no_grad(): 63 | net.eval() 64 | for step, batch in enumerate(dataloader): 65 | city = batch['city'][0] 66 | print(city) 67 | t1_img = batch['t1_img'].to(device) 68 | t2_img = batch['t2_img'].to(device) 69 | y_true = batch['label'].to(device) 70 | y_pred = net(t1_img, t2_img) 71 | y_pred = torch.sigmoid(y_pred) 72 | y_pred = y_pred.cpu().detach().numpy()[0, ] 73 | y_pred = y_pred > cfg.THRESH 74 | y_pred = y_pred.transpose((1, 2, 0)).astype('uint8') 75 | 76 | # label 77 | y_true = y_true.cpu().detach().numpy()[0, ] 78 | y_true = y_true.transpose((1, 2, 0)).astype('uint8') 79 | 80 | if label_pred_only: 81 | fig, axs = plt.subplots(1, 2, figsize=(10, 10)) 82 | axs[0].imshow(y_true[:, :, 0]) 83 | axs[1].imshow(y_pred[:, :, 0]) 84 | else: 85 | fig, axs = plt.subplots(1, 4, figsize=(20, 10)) 86 | rgb_indices = [all_bands.index(band) for band in ('B04', 'B03', 'B02')] 87 | for i, img in enumerate([t1_img, t2_img]): 88 | img = img.cpu().detach().numpy()[0, ] 89 | img = img.transpose((1, 2, 0)) 90 | rgb = img[:, :, rgb_indices] / 0.3 91 | rgb = np.minimum(rgb, 1) 92 | axs[i+2].imshow(rgb) 93 | axs[0].imshow(y_true[:, :, 0]) 94 | axs[1].imshow(y_pred[:, :, 0]) 95 | 96 | for ax in axs: 97 | ax.set_axis_off() 98 | 99 | if save_dir is None: 100 | save_dir = root_dir / 'evaluation' / cfg_file.stem 101 | if not save_dir.exists(): 102 | save_dir.mkdir() 103 | file = save_dir / f'eval_{cfg_file.stem}_{city}.png' 104 | 105 | plt.savefig(file, dpi=300, bbox_inches='tight') 106 | plt.close() 107 | 108 | 109 | def visualize_images(root_dir: Path, save_dir: Path = None): 110 | 111 | mode = 'cuda' if torch.cuda.is_available() else 'cpu' 112 | device = torch.device(mode) 113 | 114 | cfg_file = Path.cwd() / 'configs' / 'optical_visualization.yaml' 115 | cfg = load_cfg(cfg_file) 116 | 117 | dataset = datasets.OSCDDataset(cfg, 'test', no_augmentation=True) 118 | dataloader_kwargs = { 119 | 'batch_size': 1, 120 | 'num_workers': 0, 121 | 'shuffle': False, 122 | 'pin_memory': True, 123 | } 124 | dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs) 125 | 126 | with torch.no_grad(): 127 | for step, batch in enumerate(dataloader): 128 | city = batch['city'][0] 129 | print(city) 130 | t1_img = 
batch['t1_img'].to(device) 131 | t2_img = batch['t2_img'].to(device) 132 | 133 | rgb_indices = [3, 2, 1] 134 | for i, img in enumerate([t1_img, t2_img]): 135 | fig, ax = plt.subplots() 136 | img = img.cpu().detach().numpy()[0,] 137 | img = img.transpose((1, 2, 0)) 138 | rgb = img[:, :, rgb_indices] / 0.3 139 | rgb = np.minimum(rgb, 1) 140 | ax.imshow(rgb) 141 | ax.set_axis_off() 142 | 143 | if save_dir is None: 144 | save_dir = root_dir / 'evaluation' / 'images' 145 | if not save_dir.exists(): 146 | save_dir.mkdir() 147 | file = save_dir / f'{city}_img{i + 1}.png' 148 | 149 | plt.savefig(file, dpi=300, bbox_inches='tight') 150 | plt.close() 151 | 152 | 153 | def visualize_missclassifications(root_dir: Path, cfg_file: Path, net_file: Path, save_dir: Path = None): 154 | 155 | mode = 'cuda' if torch.cuda.is_available() else 'cpu' 156 | device = torch.device(mode) 157 | 158 | # loading cfg and network 159 | cfg = load_cfg(cfg_file) 160 | net = load_net(cfg, net_file) 161 | 162 | dataset = datasets.OSCDDataset(cfg, 'test', no_augmentation=True) 163 | dataloader_kwargs = { 164 | 'batch_size': 1, 165 | 'num_workers': 0, 166 | 'shuffle': False, 167 | 'pin_memory': True, 168 | } 169 | dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs) 170 | 171 | with torch.no_grad(): 172 | net.eval() 173 | for step, batch in enumerate(dataloader): 174 | city = batch['city'][0] 175 | print(city) 176 | t1_img = batch['t1_img'].to(device) 177 | t2_img = batch['t2_img'].to(device) 178 | y_true = batch['label'].to(device) 179 | y_pred = net(t1_img, t2_img) 180 | y_pred = torch.sigmoid(y_pred) 181 | y_pred = y_pred.cpu().detach().numpy()[0, ] 182 | y_pred = y_pred > cfg.THRESH 183 | y_pred = y_pred.transpose((1, 2, 0)).astype('uint8')[:, :, 0] 184 | 185 | # label 186 | y_true = y_true.cpu().detach().numpy()[0, ] 187 | y_true = y_true.transpose((1, 2, 0)).astype('uint8')[:, :, 0] 188 | 189 | img = np.zeros((*y_true.shape, 3)) 190 | true_positives = np.logical_and(y_pred, y_true) 191 | false_positives = np.logical_and(y_pred, np.logical_not(y_true)) 192 | false_negatives = np.logical_and(np.logical_not(y_pred), y_true) 193 | img[true_positives, :] = [1, 1, 1] 194 | img[false_positives] = [0, 1, 0] 195 | img[false_negatives] = [1, 0, 1] 196 | 197 | fig, ax = plt.subplots() 198 | ax.imshow(img) 199 | ax.set_axis_off() 200 | 201 | if save_dir is None: 202 | save_dir = root_dir / 'evaluation' / cfg_file.stem 203 | if not save_dir.exists(): 204 | save_dir.mkdir() 205 | file = save_dir / f'missclassfications_{cfg_file.stem}_{city}.png' 206 | 207 | plt.savefig(file, dpi=300, bbox_inches='tight') 208 | plt.close() 209 | 210 | 211 | def numeric_evaluation(cfg_file: Path, net_file: Path): 212 | 213 | tta_thresholds = np.linspace(0, 1, 11) 214 | 215 | mode = 'cuda' if torch.cuda.is_available() else 'cpu' 216 | device = torch.device(mode) 217 | 218 | # loading cfg and network 219 | cfg = load_cfg(cfg_file) 220 | net = load_net(cfg, net_file) 221 | dataset = datasets.OSCDDataset(cfg, 'test', no_augmentation=True) 222 | 223 | dataloader_kwargs = { 224 | 'batch_size': 1, 225 | 'num_workers': 0, 226 | 'shuffle':cfg.DATALOADER.SHUFFLE, 227 | 'pin_memory': True, 228 | } 229 | dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs) 230 | 231 | def predict(t1, t2): 232 | pred = net(t1, t2) 233 | pred = torch.sigmoid(pred) > cfg.THRESH 234 | pred = pred.detach().float() 235 | return pred 236 | 237 | def evaluate(true, pred): 238 | f1_score = eval.f1_score(true.flatten(), pred.flatten(), dim=0).item() 239 | true_pos = 
eval.true_pos(true.flatten(), pred.flatten(), dim=0).item() 240 | false_pos = eval.false_pos(true.flatten(), pred.flatten(), dim=0).item() 241 | false_neg = eval.false_neg(true.flatten(), pred.flatten(), dim=0).item() 242 | return f1_score, true_pos, false_pos, false_neg 243 | 244 | cities, f1_scores, true_positives, false_positives, false_negatives = [], [], [], [], [] 245 | tta = [] 246 | with torch.no_grad(): 247 | net.eval() 248 | for step, batch in enumerate(dataloader): 249 | 250 | city = batch['city'][0] 251 | print(city) 252 | cities.append(city) 253 | 254 | t1_img = batch['t1_img'].to(device) 255 | t2_img = batch['t2_img'].to(device) 256 | 257 | y_true = batch['label'].to(device) 258 | 259 | y_pred = predict(t1_img, t2_img) 260 | f1_score, tp, fp, fn = evaluate(y_true, y_pred) 261 | f1_scores.append(f1_score) 262 | true_positives.append(tp) 263 | false_positives.append(fp) 264 | false_negatives.append(fn) 265 | 266 | sum_preds = torch.zeros(y_true.shape).float().to(device) 267 | n_augs = 0 268 | 269 | # rotations 270 | for k in range(4): 271 | t1_img_rot = torch.rot90(t1_img, k, (2, 3)) 272 | t2_img_rot = torch.rot90(t2_img, k, (2, 3)) 273 | y_pred = predict(t1_img_rot, t2_img_rot) 274 | y_pred = torch.rot90(y_pred, 4 - k, (2, 3)) 275 | 276 | sum_preds += y_pred 277 | n_augs += 1 278 | 279 | # flips (horizontal and vertical) 280 | for flip in [(2,), (3,)]: 281 | t1_img_flip = torch.flip(t1_img, flip) 282 | t2_img_flip = torch.flip(t2_img, flip) 283 | y_pred = predict(t1_img_flip, t2_img_flip) 284 | y_pred = torch.flip(y_pred, flip) 285 | 286 | sum_preds += y_pred 287 | n_augs += 1 288 | 289 | pred_tta = sum_preds.float() / n_augs 290 | tta_city = [] 291 | for ts in tta_thresholds: 292 | y_pred = pred_tta > ts 293 | y_pred = y_pred.float() 294 | eval_ts = evaluate(y_true, y_pred) 295 | tta_city.append(eval_ts) 296 | tta.append(tta_city) 297 | 298 | precision = np.sum(true_positives) / (np.sum(true_positives) + np.sum(false_positives)) 299 | recall = np.sum(true_positives) / (np.sum(true_positives) + np.sum(false_negatives)) 300 | f1_score = 2 * (precision * recall / (precision + recall)) 301 | print(f'precision: {precision:.3f}, recall: {recall:.3f}, f1: {f1_score:.3f}') 302 | 303 | tta_f1_scores = [] 304 | for i, ts in enumerate(tta_thresholds): 305 | tta_ts = [city[i] for city in tta] 306 | tp = np.sum([eval_ts[1] for eval_ts in tta_ts]) 307 | fp = np.sum([eval_ts[2] for eval_ts in tta_ts]) 308 | fn = np.sum([eval_ts[3] for eval_ts in tta_ts]) 309 | pre_tta = tp / (tp + fp + 1e-5) 310 | re_tta = tp / (tp + fn + 1e-5) 311 | f1_score_tta = 2 * (pre_tta * re_tta / (pre_tta + re_tta + 1e-5)) 312 | tta_f1_scores.append(f1_score_tta) 313 | print(f'{ts:.2f}: {f1_score_tta:.3f}') 314 | 315 | fig, ax = plt.subplots() 316 | ax.plot(tta_thresholds, tta_f1_scores) 317 | ax.plot(tta_thresholds, [f1_score] * 11, label=f'without tta ({f1_score:.3f})') 318 | ax.legend() 319 | ax.set_xlabel('tta threshold (gt)') 320 | ax.set_ylabel('f1 score') 321 | ax.set_title(cfg_file.stem) 322 | # plt.show() 323 | 324 | 325 | 326 | 327 | 328 | 329 | def subset_pred_results(pred_results, cities): 330 | indices = [i for i, city in enumerate(pred_results['city']) if city in cities] 331 | for key in pred_results.keys(): 332 | sublist_key = [pred_results[key][i] for i in indices] 333 | pred_results[key] = sublist_key 334 | return pred_results 335 | 336 | 337 | def orbit_comparison(cfg_file, net_file): 338 | pass 339 | 340 | 341 | if __name__ == '__main__': 342 | 343 | CFG_DIR = Path.cwd() / 'configs' 344 | NET_DIR =
Path('/storage/shafner/urban_change_detection/run_logs/') 345 | STORAGE_DIR = Path('/storage/shafner/urban_change_detection') 346 | 347 | dataset = 'OSCD_dataset' 348 | cfg = 'fusion_9' 349 | 350 | cfg_file = CFG_DIR / f'{cfg}.yaml' 351 | net_file = NET_DIR / cfg / 'final_net.pkl' 352 | 353 | # visual_evaluation(STORAGE_DIR, cfg_file, net_file, 'test', 100, label_pred_only=False) 354 | # visualize_missclassifications(STORAGE_DIR, cfg_file, net_file) 355 | visualize_images(STORAGE_DIR) 356 | # numeric_evaluation(cfg_file, net_file) 357 | 358 | -------------------------------------------------------------------------------- /evaluation_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MultiThresholdMetric(object): 5 | def __init__(self, threshold): 6 | 7 | # FIXME Does not operate properly 8 | 9 | ''' 10 | Takes in rasterized and batched images 11 | :param y_true: [B, H, W] 12 | :param y_pred: [B, C, H, W] 13 | :param threshold: [Thresh] 14 | ''' 15 | 16 | self._thresholds = threshold[ :, None, None, None, None] # [Thresh, B, C, H, W] 17 | self._data_dims = (-1, -2, -3, -4) # For a B/W image, it should be [Thresh, B, C, H, W], 18 | 19 | self.TP = 0 20 | self.TN = 0 21 | self.FP = 0 22 | self.FN = 0 23 | 24 | def _normalize_dimensions(self): 25 | ''' Converts y_truth, y_label and threshold to [B, Thresh, C, H, W]''' 26 | # Naively assume that for all existing tensor shapes we transform [B, H, W] -> [B, Thresh, C, H, W] 27 | self._thresholds = self._thresholds[ :, None, None, None, None] # [Thresh, B, C, H, W] 28 | # self._y_pred = self._y_pred[None, ...] # [B, Thresh, C, ...] 29 | # self._y_true = self._y_true[None,:, None, ...] # [Thresh, B, C, ...] 30 | 31 | def add_sample(self, y_true:torch.Tensor, y_pred): 32 | y_true = y_true.bool()[None,...] # [Thresh, B, C, ...] 33 | y_pred = y_pred[None, ...] # [Thresh, B, C, ...] 34 | y_pred_offset = (y_pred - self._thresholds + 0.5).round().bool() 35 | 36 | self.TP += (y_true & y_pred_offset).sum(dim=self._data_dims).float() 37 | self.TN += (~y_true & ~y_pred_offset).sum(dim=self._data_dims).float() 38 | self.FP += (~y_true & y_pred_offset).sum(dim=self._data_dims).float() 39 | self.FN += (y_true & ~y_pred_offset).sum(dim=self._data_dims).float() 40 | 41 | @property 42 | def precision(self): 43 | if hasattr(self, '_precision'): 44 | '''precision previously computed''' 45 | return self._precision 46 | 47 | denom = (self.TP + self.FP).clamp(10e-05) 48 | self._precision = self.TP / denom 49 | return self._precision 50 | 51 | @property 52 | def recall(self): 53 | if hasattr(self, '_recall'): 54 | '''recall previously computed''' 55 | return self._recall 56 | 57 | denom = (self.TP + self.FN).clamp(10e-05) 58 | self._recall = self.TP / denom 59 | return self._recall 60 | 61 | def compute_basic_metrics(self): 62 | ''' 63 | Computes False Positive Rate and False Negative Rate 64 | :return: 65 | ''' 66 | 67 | false_pos_rate = self.FP/(self.FP + self.TN) 68 | false_neg_rate = self.FN / (self.FN + self.TP) 69 | 70 | return false_pos_rate, false_neg_rate 71 | 72 | def compute_f1(self): 73 | denom = (self.precision + self.recall).clamp(10e-05) 74 | return 2 * self.precision * self.recall / denom 75 | 76 | 77 | def true_pos(y_true: torch.Tensor, y_pred: torch.Tensor, dim=0): 78 | return torch.sum(y_true * torch.round(y_pred), dim=dim) 79 | 80 | 81 | def false_pos(y_true: torch.Tensor, y_pred: torch.Tensor, dim=0): 82 | return torch.sum((1. - y_true) * torch.round(y_pred), dim=dim) 83 | 84 | 85 | def false_neg(y_true: torch.Tensor, y_pred: torch.Tensor, dim=0): 86 | return torch.sum(y_true * (1. - torch.round(y_pred)), dim=dim) 87 | 88 | 89 | def precision(y_true: torch.Tensor, y_pred: torch.Tensor, dim): 90 | denominator = (true_pos(y_true, y_pred, dim) + false_pos(y_true, y_pred, dim)) 91 | denominator = torch.clamp(denominator, 10e-05) 92 | return true_pos(y_true, y_pred, dim) / denominator 93 | 94 | 95 | def recall(y_true: torch.Tensor, y_pred: torch.Tensor, dim): 96 | denominator = (true_pos(y_true, y_pred, dim) + false_neg(y_true, y_pred, dim)) 97 | denominator = torch.clamp(denominator, 10e-05) 98 | return true_pos(y_true, y_pred, dim) / denominator 99 | 100 | 101 | def f1_score(gts: torch.Tensor, preds: torch.Tensor, dim=(-1, -2)): 102 | gts = gts.float() 103 | preds = preds.float() 104 | 105 | with torch.no_grad(): 106 | recall_val = recall(gts, preds, dim) 107 | precision_val = precision(gts, preds, dim) 108 | denom = torch.clamp( (recall_val + precision_val), 10e-5) 109 | 110 | f1 = 2. * recall_val * precision_val / denom 111 | 112 | return f1 113 | -------------------------------------------------------------------------------- /experiment_manager/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SebastianHafner/urban_change_detection/4632539dbd98cc1be6817a9a7b236cfa4fbc652d/experiment_manager/__init__.py -------------------------------------------------------------------------------- /experiment_manager/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def default_argument_parser(): 4 | """ 5 | Create a parser with some common arguments used by detectron2 users.
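    Example (a sketch of typical use; the config path is a placeholder)::

        from experiment_manager.args import default_argument_parser

        args = default_argument_parser().parse_known_args(['-c', 'configs/fusion.yaml'])[0]
        print(args.config_file)  # configs/fusion.yaml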
6 | 7 | Returns: 8 | argparse.ArgumentParser: 9 | """ 10 | parser = argparse.ArgumentParser(description="Experiment Args") 11 | parser.add_argument('-c',"--config-file", dest='config_file', default="", required=True, metavar="FILE", help="path to config file") 12 | parser.add_argument('-d', '--data-dir', dest='data_dir', type=str, 13 | default='', help='dataset directory') 14 | parser.add_argument('-o', '--output-dir', dest='log_dir', type=str, 15 | default='', help='output directory') 16 | parser.add_argument( 17 | "--resume", 18 | dest='resume', 19 | action="store_true", 20 | help="whether to attempt to resume from the checkpoint directory", 21 | ) 22 | parser.add_argument('--resume-from', dest='resume_from', type=str, 23 | default='', help='path of which the model will be loaded from') 24 | parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") 25 | parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*") 26 | 27 | # Hacky hack 28 | # parser.add_argument("--eval-training", action="store_true", help="perform evaluation on training set only") 29 | 30 | parser.add_argument( 31 | "opts", 32 | help="Modify config options using the command-line", 33 | default=None, 34 | nargs=argparse.REMAINDER, 35 | ) 36 | return parser 37 | -------------------------------------------------------------------------------- /experiment_manager/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import CfgNode, new_config, global_config 2 | 3 | __all__ = [ 4 | "CfgNode", 5 | "new_config", 6 | "global_config" 7 | ] 8 | -------------------------------------------------------------------------------- /experiment_manager/config/config.py: -------------------------------------------------------------------------------- 1 | # Largely taken from FVCore and Detectron2 2 | 3 | import logging 4 | from argparse import ArgumentParser 5 | from tabulate import tabulate 6 | from collections import OrderedDict 7 | import yaml 8 | from fvcore.common.config import CfgNode as _CfgNode 9 | # TODO Initialize Cfg from Base Config 10 | class CfgNode(_CfgNode): 11 | """ 12 | The same as `fvcore.common.config.CfgNode`, but different in: 13 | 14 | 1. Use unsafe yaml loading by default. 15 | Note that this may lead to arbitrary code execution: you must not 16 | load a config file from untrusted sources before manually inspecting 17 | the content of the file. 18 | 2. Support config versioning. 19 | When attempting to merge an old config, it will convert the old config automatically. 
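    Example (a sketch of the unsafe-loading default and the always-on NEW_ALLOWED
    behaviour; configs/base.yaml ships with this repo, everything else is illustrative)::

        from experiment_manager.config import CfgNode

        cfg = CfgNode()                           # NEW_ALLOWED is forced to True in __init__
        cfg.merge_from_file('configs/base.yaml')  # allow_unsafe defaults to True here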
20 | 21 | """ 22 | def __init__(self, init_dict=None, key_list=None, new_allowed=False): 23 | 24 | # Always allow merging new configs 25 | self.__dict__[CfgNode.NEW_ALLOWED] = True 26 | super(CfgNode, self).__init__(init_dict, key_list, True) 27 | 28 | 29 | # Note that the default value of allow_unsafe is changed to True 30 | def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None: 31 | loaded_cfg = _CfgNode.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe) 32 | loaded_cfg = type(self)(loaded_cfg) 33 | 34 | # defaults.py needs to import CfgNode 35 | self.merge_from_other_cfg(loaded_cfg) 36 | 37 | def new_config(): 38 | ''' 39 | Creates a new config based on the default config file 40 | :return: 41 | ''' 42 | from .defaults import C 43 | return C.clone() 44 | 45 | global_config = CfgNode() 46 | 47 | class HPConfig(): 48 | ''' 49 | A hyperparameter config object 50 | ''' 51 | def __init__(self): 52 | self.data = {} 53 | self.argparser = ArgumentParser() 54 | 55 | def create_hp(self, name, value, argparse=False, argparse_args={}): 56 | ''' 57 | Creates a new hyperparameter, optionally sourced from argparse external arguments 58 | :param name: 59 | :param value: 60 | :param argparse: 61 | :param argparse_args: 62 | :return: 63 | ''' 64 | self.data[name] = value 65 | if argparse: 66 | datatype = type(value) 67 | # Handle boolean type 68 | if datatype == bool: 69 | self.argparser.add_argument(f'--{name}', action='store_true', *argparse_args) 70 | else: 71 | self.argparser.add_argument(f'--{name}', type=datatype, *argparse_args) 72 | 73 | def parse_args(self): 74 | ''' 75 | Performs a parse operation from the program arguments 76 | :return: 77 | ''' 78 | args = self.argparser.parse_known_args()[0] 79 | for key, value in args.__dict__.items(): 80 | # Arg not present, using default 81 | if value is None: continue 82 | self.data[key] = value 83 | 84 | def __str__(self): 85 | ''' 86 | Converts the HP into a human readable string format 87 | :return: 88 | ''' 89 | table = {'hyperparameter': self.data.keys(), 90 | 'values': list(self.data.values()), 91 | } 92 | return tabulate(table, headers='keys', tablefmt="fancy_grid", ) 93 | 94 | 95 | def save_yml(self, file_path): 96 | ''' 97 | Save HP config to a yaml file 98 | :param file_path: 99 | :return: 100 | ''' 101 | with open(file_path, 'w') as file: 102 | yaml.dump(self.data, file, default_flow_style=False) 103 | 104 | def load_yml(self, file_path): 105 | ''' 106 | Load HP Config from a yaml file 107 | :param file_path: 108 | :return: 109 | ''' 110 | with open(file_path, 'r') as file: 111 | yml_hp = yaml.safe_load(file) 112 | 113 | for hp_name, hp_value in yml_hp.items(): 114 | self.data[hp_name] = hp_value 115 | 116 | def __getattr__(self, name): 117 | return self.data[name] 118 | 119 | def config(name='default') -> HPConfig: 120 | ''' 121 | Retrives a configuration (optionally, creating it) of the run. 
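    (A sketch of the HPConfig workflow this accessor is meant to hand out; the
    hyperparameter name and value are made-up examples)::

        from experiment_manager.config.config import HPConfig

        hp = HPConfig()
        hp.create_hp('learning_rate', 1e-4, argparse=True)  # also registers --learning_rate
        hp.parse_args()          # apply CLI overrides, if any
        print(hp.learning_rate)  # resolved value, served via __getattr__
        hp.save_yml('hp.yaml')   # round-trips through load_yml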
If no `name` provided, then 'default' is used 122 | :param name: Optional name of the 123 | :return: HPConfig object 124 | ''' 125 | # Configuration doesn't exist yet 126 | # if name not in _config_data.keys(): 127 | # _config_data[name] = HPConfig() 128 | # return _config_data[name] 129 | pass 130 | 131 | def load_from_yml(): 132 | ''' 133 | Load a HPConfig from a YML file 134 | :return: 135 | ''' 136 | pass -------------------------------------------------------------------------------- /experiment_manager/config/defaults.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This is a global default file, each individual project will have their own respective default file 3 | ''' 4 | from .config import CfgNode as CN 5 | 6 | C = CN() 7 | 8 | C.CONFIG_DIR = 'config/' 9 | C.OUTPUT_BASE_DIR = 'output/' 10 | 11 | C.SEED = 7 12 | 13 | C.MODEL = CN() 14 | C.MODEL.TYPE = 'unet' 15 | C.MODEL.OUT_CHANNELS = 1 16 | C.MODEL.IN_CHANNELS = 2 17 | C.MODEL.LOSS_TYPE = 'FrankensteinLoss' 18 | 19 | C.DATALOADER = CN() 20 | C.DATALOADER.NUM_WORKER = 8 21 | C.DATALOADER.SHUFFLE = True 22 | 23 | C.DATASET = CN() 24 | C.DATASET.PATH = '' 25 | C.DATASET.MODE = '' 26 | C.DATASET.SENTINEL1 = CN() 27 | C.DATASET.SENTINEL1.BANDS = ['VV', 'VH'] 28 | C.DATASET.SENTINEL1.TEMPORAL_MODE = 'bi-temporal' 29 | C.DATASET.SENTINEL2 = CN() 30 | C.DATASET.SENTINEL2.BANDS = ['B2', 'B3', 'B4', 'B8', 'B11', 'B12'] 31 | C.DATASET.SENTINEL2.TEMPORAL_MODE = 'bi-temporal' 32 | C.DATASET.ALL_CITIES = [] 33 | C.DATASET.TEST_CITIES = [] 34 | 35 | C.OUTPUT_BASE_DIR = '' 36 | 37 | C.TRAINER = CN() 38 | C.TRAINER.LR = 1e-4 39 | C.TRAINER.BATCH_SIZE = 16 40 | C.TRAINER.EPOCHS = 50 41 | 42 | C.AUGMENTATION = CN() 43 | C.AUGMENTATION.CROP_TYPE = 'none' 44 | C.AUGMENTATION.CROP_SIZE = 32 45 | C.RANDOM_FLIP = True 46 | C.RANDOM_ROTATE = True 47 | 48 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import cv2 3 | from skimage import io 4 | import numpy as np 5 | import os 6 | from tqdm import tqdm 7 | import torch 8 | import network 9 | import tools 10 | 11 | def sliding_window(IMAGE, patch_size, step): 12 | prediction = np.zeros((IMAGE.shape[3], IMAGE.shape[4], 2)) 13 | x=0 14 | while (x!=IMAGE.shape[0]): 15 | y=0 16 | while(y!=IMAGE.shape[1]): 17 | 18 | if (not y+patch_size > IMAGE.shape[4]) and (not x+patch_size > IMAGE.shape[3]): 19 | patch = IMAGE[:, :, :, x:x + patch_size, y:y + patch_size] 20 | patch = tools.to_cuda(torch.from_numpy(patch).float()) 21 | output = model(patch) 22 | output = output.cpu().data.numpy().squeeze() 23 | output = np.transpose(output, (1,2,0)) 24 | for i in range(0, patch_size): 25 | for j in range(0, patch_size): 26 | prediction[x+i, y+j] += (output[i,j,:]) 27 | 28 | stride=step 29 | 30 | if y + patch_size == IMAGE.shape[4]: 31 | break 32 | 33 | if y + patch_size > IMAGE.shape[4]: 34 | y = IMAGE.shape[4] - patch_size 35 | else: 36 | y = y+stride 37 | 38 | if x + patch_size == IMAGE.shape[3]: 39 | break 40 | 41 | if x + patch_size > IMAGE.shape[3]: 42 | x = IMAGE.shape[3] - patch_size 43 | else: 44 | x = x+stride 45 | 46 | final_pred = np.zeros((IMAGE.shape[3], IMAGE.shape[4])) 47 | print('ok') 48 | for i in range(0, final_pred.shape[0]): 49 | for j in range(0, final_pred.shape[1]): 50 | final_pred[i,j] = np.argmax(prediction[i,j]) 51 | 52 | final_pred[final_pred==1]=2 53 | final_pred[final_pred==0]=1 54 | 55 | return 
final_pred 56 | 57 | test_areas = ['brasilia', 'milano', 'norcia', 'chongqing', 'dubai', 'lasvegas', 'montpellier', 'rio', 'saclay_w', 'valencia'] 58 | test_areas=['brasilia'] 59 | nb_dates = [1,2,3,4,5] 60 | patch_size = 32 61 | step = 16 62 | model=network.U_Net(4,2,nb_dates) 63 | BATCH_SIZE=1 64 | save_dir = 'PREDICTIONS' 65 | os.mkdir(save_dir) 66 | model.load_state_dict(torch.load('./saved_models/model_22.pt')) #ena apo to 5D 67 | model=tools.to_cuda(model) 68 | model = model.eval() 69 | 70 | FOLDER = './IMGS_PREPROCESSED/' 71 | 72 | for id in test_areas: 73 | print('test_area', id) 74 | 75 | imgs = [] 76 | for nd in nb_dates: 77 | img = np.load(FOLDER + id + '/' + id + '_{}.npy'.format(str(nd))) 78 | img = np.concatenate((img[:,:,1:4], np.reshape(img[:,:,7], (img.shape[0],img.shape[1],1))), 2) 79 | img = np.transpose(img, (2,0,1)) 80 | imgs.append(img) 81 | imgs = np.asarray(imgs) 82 | imgs = np.reshape(imgs, (imgs.shape[0], 1, imgs.shape[1], imgs.shape[2], imgs.shape[3])) 83 | 84 | pred = sliding_window(imgs, patch_size, step) 85 | cv2.imwrite('./' + save_dir + '/' + id + '.tif', pred) 86 | -------------------------------------------------------------------------------- /loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def soft_dice_loss(input:torch.Tensor, target:torch.Tensor): 5 | input_sigmoid = torch.sigmoid(input) 6 | eps = 1e-6 7 | 8 | iflat = input_sigmoid.flatten() 9 | tflat = target.flatten() 10 | intersection = (iflat * tflat).sum() 11 | 12 | return 1 - ((2. * intersection) / 13 | (iflat.sum() + tflat.sum() + eps)) 14 | 15 | 16 | def soft_dice_loss_multi_class(input:torch.Tensor, y:torch.Tensor): 17 | p = torch.softmax(input, dim=1) 18 | eps = 1e-6 19 | 20 | sum_dims= (0, 2, 3) # Batch, height, width 21 | 22 | intersection = (y * p).sum(dim=sum_dims) 23 | denom = (y.sum(dim=sum_dims) + p.sum(dim=sum_dims)).clamp(eps) 24 | 25 | loss = 1 - (2. * intersection / denom).mean() 26 | return loss 27 | 28 | 29 | def soft_dice_loss_multi_class_debug(input:torch.Tensor, y:torch.Tensor): 30 | p = torch.softmax(input, dim=1) 31 | eps = 1e-6 32 | 33 | sum_dims= (0, 2, 3) # Batch, height, width 34 | 35 | intersection = (y * p).sum(dim=sum_dims) 36 | denom = (y.sum(dim=sum_dims) + p.sum(dim=sum_dims)).clamp(eps) 37 | 38 | loss = 1 - (2. * intersection / denom).mean() 39 | loss_components = 1 - 2 * intersection/denom 40 | return loss, loss_components 41 | 42 | 43 | def generalized_soft_dice_loss_multi_class(input:torch.Tensor, y:torch.Tensor): 44 | p = torch.softmax(input, dim=1) 45 | eps = 1e-12 46 | 47 | # TODO [B, C, H, W] -> [C, B, H, W] because softdice includes all pixels 48 | 49 | sum_dims= (0, 2, 3) # Batch, height, width 50 | ysum = y.sum(dim=sum_dims) 51 | wc = 1 / (ysum ** 2 + eps) 52 | intersection = ((y * p).sum(dim=sum_dims) * wc).sum() 53 | denom = ((ysum + p.sum(dim=sum_dims)) * wc).sum() 54 | 55 | loss = 1 - (2. * intersection / denom) 56 | return loss 57 | 58 | 59 | def jaccard_like_loss_multi_class(input:torch.Tensor, y:torch.Tensor): 60 | p = torch.softmax(input, dim=1) 61 | eps = 1e-6 62 | 63 | # TODO [B, C, H, W] -> [C, B, H, W] because softdice includes all pixels 64 | 65 | sum_dims= (0, 2, 3) # Batch, height, width 66 | 67 | intersection = (y * p).sum(dim=sum_dims) 68 | denom = (y ** 2 + p ** 2).sum(dim=sum_dims) + (y*p).sum(dim=sum_dims) + eps 69 | 70 | loss = 1 - (2. 
* intersection / denom).mean() 71 | return loss 72 | 73 | 74 | def jaccard_like_loss(input:torch.Tensor, target:torch.Tensor): 75 | input_sigmoid = torch.sigmoid(input) 76 | eps = 1e-6 77 | 78 | iflat = input_sigmoid.flatten() 79 | tflat = target.flatten() 80 | intersection = (iflat * tflat).sum() 81 | denom = (iflat**2 + tflat**2).sum() - (iflat * tflat).sum() + eps 82 | 83 | return 1 - ((2. * intersection) / denom) 84 | 85 | 86 | def jaccard_like_balanced_loss(input:torch.Tensor, target:torch.Tensor): 87 | input_sigmoid = torch.sigmoid(input) 88 | eps = 1e-6 89 | 90 | iflat = input_sigmoid.flatten() 91 | tflat = target.flatten() 92 | intersection = (iflat * tflat).sum() 93 | denom = (iflat**2 + tflat**2).sum() - (iflat * tflat).sum() + eps 94 | piccard = (2. * intersection)/denom 95 | 96 | n_iflat = 1-iflat 97 | n_tflat = 1-tflat 98 | neg_intersection = (n_iflat * n_tflat).sum() 99 | neg_denom = (n_iflat**2 + n_tflat**2).sum() - (n_iflat * n_tflat).sum() 100 | n_piccard = (2. * neg_intersection)/neg_denom 101 | 102 | return 1 - piccard - n_piccard 103 | 104 | 105 | def soft_dice_loss_balanced(input:torch.Tensor, target:torch.Tensor): 106 | input_sigmoid = torch.sigmoid(input) 107 | eps = 1e-6 108 | 109 | iflat = input_sigmoid.flatten() 110 | tflat = target.flatten() 111 | intersection = (iflat * tflat).sum() 112 | 113 | dice_pos = ((2. * intersection) / 114 | (iflat.sum() + tflat.sum() + eps)) 115 | 116 | negatiev_intersection = ((1-iflat) * (1 - tflat)).sum() 117 | dice_neg = (2 * negatiev_intersection) / ((1-iflat).sum() + (1-tflat).sum() + eps) 118 | 119 | return 1 - dice_pos - dice_neg -------------------------------------------------------------------------------- /make_xys.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage import io 3 | from skimage.transform import rotate, resize 4 | import os 5 | import cv2 6 | import pandas as pd 7 | from pathlib import Path 8 | 9 | 10 | def shuffle(vector): 11 | vector = np.asarray(vector) 12 | p = np.random.permutation(len(vector)) 13 | vector = vector[p] 14 | return vector 15 | 16 | 17 | def sliding_window_train(i_city, labeled_areas, label, window_size, step): 18 | city = [] 19 | fpatches_labels = [] 20 | 21 | x = 0 22 | while x != label.shape[0]: 23 | y = 0 24 | while y != label.shape[1]: 25 | 26 | if (not y + window_size > label.shape[1]) and (not x + window_size > label.shape[0]): 27 | line = np.array([x, y, labeled_areas.index(i_city), 0]) 28 | # (x,y) are the saved coordinates, 29 | # labeled_areas.index(i_city)... are the image ids, e.g according to train_areas, 30 | # the indice for abudhabi in the list is 0, for beihai it is 1, for beirut is 3, etc.. 
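# (Aside: concretely, one saved row is a 4-vector, e.g. np.array([96, 64, 0, 2]) describes a
# patch at x=96, y=64 in the image with id 0 ('abudhabi' here) that the dataloader will
# rotate by 180 degrees (transformation id 2); the coordinates are made up for illustration.)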
31 | # the fourth element which has been set as 0, represents the transformadion index, 32 | # which in this case indicates that no data augmentation will be performed for the 33 | # specific patch 34 | 35 | city.append(line) 36 | 37 | # CONDITIONS 38 | new_patch_label = label[x:x + window_size, y:y + window_size] 39 | ff = np.where(new_patch_label == 2) 40 | perc = ff[0].shape[0] / float(window_size * window_size) 41 | if ff[0].shape[0] == 0: 42 | stride = window_size 43 | else: 44 | stride = step 45 | if perc >= 0.05: 46 | # if percentage of change exceeds a threshold, perform data augmentation on this patch 47 | # Below, 1, 2, 3 are transformation indexes that will be used by the custom dataloader 48 | # to perform various rotations 49 | line = np.array([x, y, labeled_areas.index(i_city), 1]) 50 | city.append(line) 51 | line = np.array([x, y, labeled_areas.index(i_city), 2]) 52 | city.append(line) 53 | line=np.array([x, y, labeled_areas.index(i_city), 3]) 54 | city.append(line) 55 | # CONDITIONS 56 | 57 | if y + window_size == label.shape[1]: 58 | break 59 | 60 | if y + window_size > label.shape[1]: 61 | y = label.shape[1] - window_size 62 | else: 63 | y = y + stride 64 | 65 | if x + window_size == label.shape[0]: 66 | break 67 | 68 | if x + window_size > label.shape[0]: 69 | x = label.shape[0] - window_size 70 | else: 71 | x = x + stride 72 | 73 | return np.asarray(city) 74 | 75 | 76 | if __name__ == '__main__': 77 | 78 | train_areas = ['abudhabi', 'beihai', 'aguasclaras', 'beirut', 'bercy', 'bordeaux', 'cupertino', 79 | 'hongkong', 'mumbai', 'nantes', 'rennes', 'saclay_e', 'pisa', 'rennes'] 80 | 81 | FOLDER = Path('C:/Users/hafne/urban_change_detection/data/Onera/') 82 | 83 | step = 6 84 | patch_s = 32 85 | 86 | cities = [] 87 | for i_city in train_areas: 88 | file = FOLDER / 'labels' / i_city / 'cm' / f'{i_city}-cm.tif' 89 | print('icity', i_city) 90 | train_gt = io.imread(file) 91 | xy_city = sliding_window_train(i_city, train_areas, train_gt, patch_s, step) 92 | cities.append(xy_city) 93 | 94 | # from all training (x,y) locations, divide 4/5 for training and 1/5 for validation 95 | final_cities = np.concatenate(cities, axis=0) 96 | size_len = len(final_cities) 97 | portion = int(size_len / 5) 98 | final_cities = shuffle(final_cities) 99 | final_cities_train = final_cities[:-portion] 100 | final_cities_val = final_cities[-portion:] 101 | 102 | # save train to csv file 103 | df = pd.DataFrame({'X': list(final_cities_train[:, 0]), 104 | 'Y': list(final_cities_train[:, 1]), 105 | 'image_ID': list(final_cities_train[:, 2]), 106 | 'transform_ID': list(final_cities_train[:, 3]), 107 | }) 108 | train_file = FOLDER / 'myxys_train.csv' 109 | df.to_csv(str(train_file), index=False, columns=["X", "Y", "image_ID", "transform_ID"]) 110 | 111 | # save val to csv file 112 | df = pd.DataFrame({'X': list(final_cities_val[:, 0]), 113 | 'Y': list(final_cities_val[:, 1]), 114 | 'image_ID': list(final_cities_val[:, 2]), 115 | 'transform_ID': list(final_cities_val[:, 3]), 116 | }) 117 | val_file = FOLDER / 'myxys_val.csv' 118 | df.to_csv(str(val_file), index=False, columns=["X", "Y", "image_ID", "transform_ID"]) 119 | -------------------------------------------------------------------------------- /network_summary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchsummary import summary 3 | from networks.network_loader import load_network 4 | from experiment_manager.config import new_config 5 | from pathlib import Path 6 | CFG_DIR = 
Path.cwd() / 'configs' 7 | 8 | 9 | # loading cfg for inference 10 | def load_cfg(cfg_file: Path): 11 | cfg = new_config() 12 | cfg.merge_from_file(str(cfg_file)) 13 | return cfg 14 | 15 | 16 | if __name__ == '__main__': 17 | 18 | cfg = 'fusion_dualstream_1' 19 | 20 | cfg_file = CFG_DIR / f'{cfg}.yaml' 21 | 22 | # loading cfg and network 23 | cfg = load_cfg(cfg_file) 24 | 25 | net = load_network(cfg) 26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | net.to(device) 28 | 29 | # TODO: replace by number of bands function 30 | h = w = cfg.AUGMENTATION.CROP_SIZE 31 | img_channels = cfg.MODEL.IN_CHANNELS // 2 32 | img_size = (img_channels, h, w) 33 | 34 | summary(net, input_size=[img_size, img_size]) 35 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SebastianHafner/urban_change_detection/4632539dbd98cc1be6817a9a7b236cfa4fbc652d/networks/__init__.py -------------------------------------------------------------------------------- /networks/daudt2018.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | 11 | class SiameseUNetConc(nn.Module): 12 | """SiamUnet_conc segmentation network.""" 13 | 14 | def __init__(self, input_nbr, label_nbr): 15 | super(SiameseUNetConc, self).__init__() 16 | 17 | self.input_nbr = input_nbr 18 | 19 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 20 | self.bn11 = nn.BatchNorm2d(16) 21 | self.do11 = nn.Dropout2d(p=0.2) 22 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 23 | self.bn12 = nn.BatchNorm2d(16) 24 | self.do12 = nn.Dropout2d(p=0.2) 25 | 26 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 27 | self.bn21 = nn.BatchNorm2d(32) 28 | self.do21 = nn.Dropout2d(p=0.2) 29 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 30 | self.bn22 = nn.BatchNorm2d(32) 31 | self.do22 = nn.Dropout2d(p=0.2) 32 | 33 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 34 | self.bn31 = nn.BatchNorm2d(64) 35 | self.do31 = nn.Dropout2d(p=0.2) 36 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 37 | self.bn32 = nn.BatchNorm2d(64) 38 | self.do32 = nn.Dropout2d(p=0.2) 39 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 40 | self.bn33 = nn.BatchNorm2d(64) 41 | self.do33 = nn.Dropout2d(p=0.2) 42 | 43 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 44 | self.bn41 = nn.BatchNorm2d(128) 45 | self.do41 = nn.Dropout2d(p=0.2) 46 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 47 | self.bn42 = nn.BatchNorm2d(128) 48 | self.do42 = nn.Dropout2d(p=0.2) 49 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 50 | self.bn43 = nn.BatchNorm2d(128) 51 | self.do43 = nn.Dropout2d(p=0.2) 52 | 53 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 54 | 55 | self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1) 56 | self.bn43d = nn.BatchNorm2d(128) 57 | self.do43d = nn.Dropout2d(p=0.2) 58 | self.conv42d = 
nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 59 | self.bn42d = nn.BatchNorm2d(128) 60 | self.do42d = nn.Dropout2d(p=0.2) 61 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 62 | self.bn41d = nn.BatchNorm2d(64) 63 | self.do41d = nn.Dropout2d(p=0.2) 64 | 65 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 66 | 67 | self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1) 68 | self.bn33d = nn.BatchNorm2d(64) 69 | self.do33d = nn.Dropout2d(p=0.2) 70 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 71 | self.bn32d = nn.BatchNorm2d(64) 72 | self.do32d = nn.Dropout2d(p=0.2) 73 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 74 | self.bn31d = nn.BatchNorm2d(32) 75 | self.do31d = nn.Dropout2d(p=0.2) 76 | 77 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 78 | 79 | self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1) 80 | self.bn22d = nn.BatchNorm2d(32) 81 | self.do22d = nn.Dropout2d(p=0.2) 82 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 83 | self.bn21d = nn.BatchNorm2d(16) 84 | self.do21d = nn.Dropout2d(p=0.2) 85 | 86 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 87 | 88 | self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1) 89 | self.bn12d = nn.BatchNorm2d(16) 90 | self.do12d = nn.Dropout2d(p=0.2) 91 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 92 | 93 | self.sm = nn.LogSoftmax(dim=1) 94 | 95 | def forward(self, x1, x2): 96 | 97 | """Forward method.""" 98 | # Stage 1 99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 100 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 101 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 102 | 103 | 104 | # Stage 2 105 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 106 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 107 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 108 | 109 | # Stage 3 110 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 111 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 112 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 113 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 114 | 115 | # Stage 4 116 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 117 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 118 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 119 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 120 | 121 | 122 | #################################################### 123 | # Stage 1 124 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 125 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 126 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 127 | 128 | # Stage 2 129 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 130 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 131 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 132 | 133 | # Stage 3 134 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 135 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 136 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 137 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 138 | 139 | # Stage 4 140 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 141 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 142 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 143 | x4p = 
F.max_pool2d(x43_2, kernel_size=2, stride=2) 144 | 145 | 146 | #################################################### 147 | # Stage 4d 148 | x4d = self.upconv4(x4p) 149 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 150 | x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1) 151 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 152 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 153 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 154 | 155 | # Stage 3d 156 | x3d = self.upconv3(x41d) 157 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 158 | x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1) 159 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 160 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 161 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 162 | 163 | # Stage 2d 164 | x2d = self.upconv2(x31d) 165 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 166 | x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1) 167 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 168 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 169 | 170 | # Stage 1d 171 | x1d = self.upconv1(x21d) 172 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 173 | x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1) 174 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 175 | x11d = self.conv11d(x12d) 176 | 177 | # return self.sm(x11d) 178 | return x11d 179 | 180 | 181 | class SiameseUNetDiff(nn.Module): 182 | """SiamUnet_diff segmentation network.""" 183 | 184 | def __init__(self, input_nbr, label_nbr): 185 | super(SiameseUNetDiff, self).__init__() 186 | 187 | self.input_nbr = input_nbr 188 | 189 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 190 | self.bn11 = nn.BatchNorm2d(16) 191 | self.do11 = nn.Dropout2d(p=0.2) 192 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 193 | self.bn12 = nn.BatchNorm2d(16) 194 | self.do12 = nn.Dropout2d(p=0.2) 195 | 196 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 197 | self.bn21 = nn.BatchNorm2d(32) 198 | self.do21 = nn.Dropout2d(p=0.2) 199 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 200 | self.bn22 = nn.BatchNorm2d(32) 201 | self.do22 = nn.Dropout2d(p=0.2) 202 | 203 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 204 | self.bn31 = nn.BatchNorm2d(64) 205 | self.do31 = nn.Dropout2d(p=0.2) 206 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 207 | self.bn32 = nn.BatchNorm2d(64) 208 | self.do32 = nn.Dropout2d(p=0.2) 209 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 210 | self.bn33 = nn.BatchNorm2d(64) 211 | self.do33 = nn.Dropout2d(p=0.2) 212 | 213 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 214 | self.bn41 = nn.BatchNorm2d(128) 215 | self.do41 = nn.Dropout2d(p=0.2) 216 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 217 | self.bn42 = nn.BatchNorm2d(128) 218 | self.do42 = nn.Dropout2d(p=0.2) 219 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 220 | self.bn43 = nn.BatchNorm2d(128) 221 | self.do43 = nn.Dropout2d(p=0.2) 222 | 223 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 224 | 225 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 226 | self.bn43d = nn.BatchNorm2d(128) 227 | self.do43d = nn.Dropout2d(p=0.2) 228 | self.conv42d = 
nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 229 | self.bn42d = nn.BatchNorm2d(128) 230 | self.do42d = nn.Dropout2d(p=0.2) 231 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 232 | self.bn41d = nn.BatchNorm2d(64) 233 | self.do41d = nn.Dropout2d(p=0.2) 234 | 235 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 236 | 237 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 238 | self.bn33d = nn.BatchNorm2d(64) 239 | self.do33d = nn.Dropout2d(p=0.2) 240 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 241 | self.bn32d = nn.BatchNorm2d(64) 242 | self.do32d = nn.Dropout2d(p=0.2) 243 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 244 | self.bn31d = nn.BatchNorm2d(32) 245 | self.do31d = nn.Dropout2d(p=0.2) 246 | 247 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 248 | 249 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 250 | self.bn22d = nn.BatchNorm2d(32) 251 | self.do22d = nn.Dropout2d(p=0.2) 252 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 253 | self.bn21d = nn.BatchNorm2d(16) 254 | self.do21d = nn.Dropout2d(p=0.2) 255 | 256 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 257 | 258 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 259 | self.bn12d = nn.BatchNorm2d(16) 260 | self.do12d = nn.Dropout2d(p=0.2) 261 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 262 | 263 | self.sm = nn.LogSoftmax(dim=1) 264 | 265 | def forward(self, x1, x2): 266 | 267 | 268 | """Forward method.""" 269 | # Stage 1 270 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 271 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 272 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 273 | 274 | 275 | # Stage 2 276 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 277 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 278 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 279 | 280 | # Stage 3 281 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 282 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 283 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 284 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 285 | 286 | # Stage 4 287 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 288 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 289 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 290 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 291 | 292 | #################################################### 293 | # Stage 1 294 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 295 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 296 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 297 | 298 | 299 | # Stage 2 300 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 301 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 302 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 303 | 304 | # Stage 3 305 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 306 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 307 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 308 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 309 | 310 | # Stage 4 311 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 312 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 313 | x43_2 = 
self.do43(F.relu(self.bn43(self.conv43(x42)))) 314 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 315 | 316 | 317 | 318 | # Stage 4d 319 | x4d = self.upconv4(x4p) 320 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 321 | x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1) 322 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 323 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 324 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 325 | 326 | # Stage 3d 327 | x3d = self.upconv3(x41d) 328 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 329 | x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1) 330 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 331 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 332 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 333 | 334 | # Stage 2d 335 | x2d = self.upconv2(x31d) 336 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 337 | x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1) 338 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 339 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 340 | 341 | # Stage 1d 342 | x1d = self.upconv1(x21d) 343 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 344 | x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1) 345 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 346 | x11d = self.conv11d(x12d) 347 | 348 | # return self.sm(x11d) 349 | return x11d 350 | 351 | 352 | class UNet(nn.Module): 353 | """EF segmentation network.""" 354 | 355 | def __init__(self, input_nbr, label_nbr): 356 | super(UNet, self).__init__() 357 | 358 | self.input_nbr = input_nbr 359 | 360 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 361 | self.bn11 = nn.BatchNorm2d(16) 362 | self.do11 = nn.Dropout2d(p=0.2) 363 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 364 | self.bn12 = nn.BatchNorm2d(16) 365 | self.do12 = nn.Dropout2d(p=0.2) 366 | 367 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 368 | self.bn21 = nn.BatchNorm2d(32) 369 | self.do21 = nn.Dropout2d(p=0.2) 370 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 371 | self.bn22 = nn.BatchNorm2d(32) 372 | self.do22 = nn.Dropout2d(p=0.2) 373 | 374 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 375 | self.bn31 = nn.BatchNorm2d(64) 376 | self.do31 = nn.Dropout2d(p=0.2) 377 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 378 | self.bn32 = nn.BatchNorm2d(64) 379 | self.do32 = nn.Dropout2d(p=0.2) 380 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 381 | self.bn33 = nn.BatchNorm2d(64) 382 | self.do33 = nn.Dropout2d(p=0.2) 383 | 384 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 385 | self.bn41 = nn.BatchNorm2d(128) 386 | self.do41 = nn.Dropout2d(p=0.2) 387 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 388 | self.bn42 = nn.BatchNorm2d(128) 389 | self.do42 = nn.Dropout2d(p=0.2) 390 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 391 | self.bn43 = nn.BatchNorm2d(128) 392 | self.do43 = nn.Dropout2d(p=0.2) 393 | 394 | 395 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 396 | 397 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 398 | self.bn43d = nn.BatchNorm2d(128) 399 | self.do43d = nn.Dropout2d(p=0.2) 400 | 
self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 401 | self.bn42d = nn.BatchNorm2d(128) 402 | self.do42d = nn.Dropout2d(p=0.2) 403 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 404 | self.bn41d = nn.BatchNorm2d(64) 405 | self.do41d = nn.Dropout2d(p=0.2) 406 | 407 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 408 | 409 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 410 | self.bn33d = nn.BatchNorm2d(64) 411 | self.do33d = nn.Dropout2d(p=0.2) 412 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 413 | self.bn32d = nn.BatchNorm2d(64) 414 | self.do32d = nn.Dropout2d(p=0.2) 415 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 416 | self.bn31d = nn.BatchNorm2d(32) 417 | self.do31d = nn.Dropout2d(p=0.2) 418 | 419 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 420 | 421 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 422 | self.bn22d = nn.BatchNorm2d(32) 423 | self.do22d = nn.Dropout2d(p=0.2) 424 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 425 | self.bn21d = nn.BatchNorm2d(16) 426 | self.do21d = nn.Dropout2d(p=0.2) 427 | 428 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 429 | 430 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 431 | self.bn12d = nn.BatchNorm2d(16) 432 | self.do12d = nn.Dropout2d(p=0.2) 433 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 434 | 435 | self.sm = nn.LogSoftmax(dim=1) 436 | 437 | def forward(self, x1, x2): 438 | 439 | x = torch.cat((x1, x2), 1) 440 | 441 | """Forward method.""" 442 | # Stage 1 443 | x11 = self.do11(F.relu(self.bn11(self.conv11(x)))) 444 | x12 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 445 | x1p = F.max_pool2d(x12, kernel_size=2, stride=2) 446 | 447 | # Stage 2 448 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 449 | x22 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 450 | x2p = F.max_pool2d(x22, kernel_size=2, stride=2) 451 | 452 | # Stage 3 453 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 454 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 455 | x33 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 456 | x3p = F.max_pool2d(x33, kernel_size=2, stride=2) 457 | 458 | # Stage 4 459 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 460 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 461 | x43 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 462 | x4p = F.max_pool2d(x43, kernel_size=2, stride=2) 463 | 464 | 465 | # Stage 4d 466 | x4d = self.upconv4(x4p) 467 | pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2))) 468 | x4d = torch.cat((pad4(x4d), x43), 1) 469 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 470 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 471 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 472 | 473 | # Stage 3d 474 | x3d = self.upconv3(x41d) 475 | pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2))) 476 | x3d = torch.cat((pad3(x3d), x33), 1) 477 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 478 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 479 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 480 | 481 | # Stage 2d 482 | x2d = self.upconv2(x31d) 483 | pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 
0, x22.size(2) - x2d.size(2))) 484 | x2d = torch.cat((pad2(x2d), x22), 1) 485 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 486 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 487 | 488 | # Stage 1d 489 | x1d = self.upconv1(x21d) 490 | pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2))) 491 | x1d = torch.cat((pad1(x1d), x12), 1) 492 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 493 | x11d = self.conv11d(x12d) 494 | 495 | return x11d 496 | # return self.sm(x11d) -------------------------------------------------------------------------------- /networks/network_loader.py: -------------------------------------------------------------------------------- 1 | from networks.ours import UNet, DualStreamUNet 2 | 3 | 4 | def load_network(cfg): 5 | architecture = cfg.MODEL.TYPE 6 | if architecture == 'unet': 7 | return UNet(cfg) 8 | elif architecture == 'dualstreamunet': 9 | return DualStreamUNet(cfg) 10 | else: 11 | return UNet(cfg) 12 | -------------------------------------------------------------------------------- /networks/ours.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch.nn.functional as F 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.padding import ReplicationPad2d 6 | 7 | 8 | class DualStreamUNet(nn.Module): 9 | 10 | def __init__(self, cfg): 11 | super(DualStreamUNet, self).__init__() 12 | assert (cfg.DATASET.MODE == 'fusion') 13 | self._cfg = cfg 14 | out = cfg.MODEL.OUT_CHANNELS 15 | topology = cfg.MODEL.TOPOLOGY 16 | 17 | # sentinel-1 unet stream 18 | n_s1_bands = len(cfg.DATASET.SENTINEL1_BANDS) 19 | s1_in = n_s1_bands * 2 # t1 and t2 20 | self.s1_stream = UNet(cfg, n_channels=s1_in, n_classes=out, topology=topology, enable_outc=False) 21 | self.n_s1_bands = n_s1_bands 22 | 23 | # sentinel-2 unet stream 24 | n_s2_bands = len(cfg.DATASET.SENTINEL2_BANDS) 25 | s2_in = n_s2_bands * 2 26 | self.s2_stream = UNet(cfg, n_channels=s2_in, n_classes=out, topology=topology, enable_outc=False) 27 | self.n_s2_bands = n_s2_bands 28 | 29 | # out block combining unet outputs 30 | out_dim = 2 * topology[0] 31 | self.out_conv = OutConv(out_dim, out) 32 | 33 | def forward(self, t1_img, t2_img): 34 | 35 | s1_t1, s2_t1 = torch.split(t1_img, [self.n_s1_bands, self.n_s2_bands], dim=1) 36 | s1_t2, s2_t2 = torch.split(t2_img, [self.n_s1_bands, self.n_s2_bands], dim=1) 37 | 38 | s1_feature = self.s1_stream(s1_t1, s1_t2) 39 | s2_feature = self.s2_stream(s2_t1, s2_t2) 40 | 41 | fusion = torch.cat((s1_feature, s2_feature), dim=1) 42 | 43 | out = self.out_conv(fusion) 44 | 45 | return out 46 | 47 | 48 | class UNet(nn.Module): 49 | def __init__(self, cfg, n_channels=None, n_classes=None, topology=None, enable_outc=True): 50 | 51 | self._cfg = cfg 52 | 53 | n_channels = cfg.MODEL.IN_CHANNELS if n_channels is None else n_channels 54 | n_classes = cfg.MODEL.OUT_CHANNELS if n_classes is None else n_classes 55 | topology = cfg.MODEL.TOPOLOGY if topology is None else topology 56 | 57 | super(UNet, self).__init__() 58 | 59 | first_chan = topology[0] 60 | self.inc = InConv(n_channels, first_chan, DoubleConv) 61 | self.enable_outc = enable_outc 62 | self.outc = OutConv(first_chan, n_classes) 63 | 64 | # Variable scale 65 | down_topo = topology 66 | down_dict = OrderedDict() 67 | n_layers = len(down_topo) 68 | up_topo = [first_chan] # topography upwards 69 | up_dict = OrderedDict() 70 | 71 | # Downward layers 72 | for idx in range(n_layers): 
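# (annotation added for clarity; not in the original source) Worked example of
# how this loop and the mirrored one below size the layers: with
# cfg.MODEL.TOPOLOGY = [64, 128],
#   down1: in 64,  out 128   (MaxPool2d followed by DoubleConv)
#   down2: in 128, out 128   (the last layer keeps its width)
# and the upward pass then derives
#   up2: in 256, out 64      (in_dim is doubled by the skip concatenation)
#   up1: in 128, out 64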
73 | is_not_last_layer = idx != n_layers-1 74 | in_dim = down_topo[idx] 75 | out_dim = down_topo[idx+1] if is_not_last_layer else down_topo[idx] # last layer 76 | 77 | layer = Down(in_dim, out_dim, DoubleConv) 78 | 79 | print(f'down{idx+1}: in {in_dim}, out {out_dim}') 80 | down_dict[f'down{idx+1}'] = layer 81 | up_topo.append(out_dim) 82 | self.down_seq = nn.ModuleDict(down_dict) 83 | 84 | # Upward layers 85 | for idx in reversed(range(n_layers)): 86 | is_not_last_layer = idx != 0 87 | x1_idx = idx 88 | x2_idx = idx - 1 if is_not_last_layer else idx 89 | in_dim = up_topo[x1_idx] * 2 90 | out_dim = up_topo[x2_idx] 91 | 92 | layer = Up(in_dim, out_dim, DoubleConv) 93 | 94 | print(f'up{idx+1}: in {in_dim}, out {out_dim}') 95 | up_dict[f'up{idx+1}'] = layer 96 | 97 | self.up_seq = nn.ModuleDict(up_dict) 98 | 99 | def forward(self, x1, x2=None): 100 | x = x1 if x2 is None else torch.cat((x1, x2), 1) 101 | 102 | x1 = self.inc(x) 103 | 104 | inputs = [x1] 105 | # Downward U: 106 | for layer in self.down_seq.values(): 107 | out = layer(inputs[-1]) 108 | inputs.append(out) 109 | 110 | # Upward U: 111 | inputs.reverse() 112 | x1 = inputs.pop(0) 113 | for idx, layer in enumerate(self.up_seq.values()): 114 | x2 = inputs[idx] 115 | x1 = layer(x1, x2) # x1 for next up layer 116 | 117 | out = self.outc(x1) if self.enable_outc else x1 118 | 119 | return out 120 | 121 | 122 | # sub-parts of the U-Net model 123 | class DoubleConv(nn.Module): 124 | '''(conv => BN => ReLU) * 2''' 125 | 126 | def __init__(self, in_ch, out_ch): 127 | super(DoubleConv, self).__init__() 128 | self.conv = nn.Sequential( 129 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 130 | nn.BatchNorm2d(out_ch), 131 | nn.ReLU(inplace=True), 132 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 133 | nn.BatchNorm2d(out_ch), 134 | nn.ReLU(inplace=True) 135 | ) 136 | 137 | def forward(self, x): 138 | x = self.conv(x) 139 | return x 140 | 141 | 142 | class InConv(nn.Module): 143 | def __init__(self, in_ch, out_ch, conv_block): 144 | super(InConv, self).__init__() 145 | self.conv = conv_block(in_ch, out_ch) 146 | 147 | def forward(self, x): 148 | x = self.conv(x) 149 | return x 150 | 151 | 152 | class Down(nn.Module): 153 | def __init__(self, in_ch, out_ch, conv_block): 154 | super(Down, self).__init__() 155 | 156 | self.mpconv = nn.Sequential( 157 | nn.MaxPool2d(2), 158 | conv_block(in_ch, out_ch) 159 | ) 160 | 161 | def forward(self, x): 162 | x = self.mpconv(x) 163 | return x 164 | 165 | 166 | class Up(nn.Module): 167 | def __init__(self, in_ch, out_ch, conv_block): 168 | super(Up, self).__init__() 169 | 170 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) 171 | self.conv = conv_block(in_ch, out_ch) 172 | 173 | def forward(self, x1, x2): 174 | x1 = self.up(x1) 175 | 176 | # input is CHW 177 | diffY = x2.detach().size()[2] - x1.detach().size()[2] 178 | diffX = x2.detach().size()[3] - x1.detach().size()[3] 179 | 180 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) 181 | 182 | # for padding issues, see 183 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 184 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 185 | 186 | x = torch.cat([x2, x1], dim=1) 187 | x = self.conv(x) 188 | return x 189 | 190 | 191 | class OutConv(nn.Module): 192 | def __init__(self, in_ch, out_ch): 193 | super(OutConv, self).__init__() 194 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 195 | 196 | def forward(self, x): 197 | x 
= self.conv(x) 198 | return x 199 | -------------------------------------------------------------------------------- /networks/papadomanolaki2019.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | from torch.autograd import Variable 6 | import tools 7 | 8 | class conv_block(nn.Module): 9 | def __init__(self,ch_in,ch_out): 10 | super(conv_block,self).__init__() 11 | self.conv = nn.Sequential( 12 | nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 13 | nn.BatchNorm2d(ch_out), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 16 | nn.BatchNorm2d(ch_out), 17 | nn.ReLU(inplace=True) 18 | ) 19 | 20 | 21 | def forward(self,x): 22 | x = self.conv(x) 23 | return x 24 | 25 | 26 | class up_conv(nn.Module): 27 | def __init__(self,ch_in,ch_out): 28 | super(up_conv,self).__init__() 29 | self.up = nn.Sequential( 30 | nn.Upsample(scale_factor=2), 31 | nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True), 32 | nn.BatchNorm2d(ch_out), 33 | nn.ReLU(inplace=True) 34 | ) 35 | 36 | def forward(self,x): 37 | x = self.up(x) 38 | return x 39 | 40 | 41 | class RNNCell(nn.Module): 42 | def __init__(self, input_size, hidden_size): 43 | super(RNNCell, self).__init__() 44 | self.input_size = input_size 45 | self.hidden_size = hidden_size 46 | self.in_gate = nn.Conv2d(input_size + hidden_size, hidden_size, 3, 1, 1) 47 | self.forget_gate = nn.Conv2d(input_size + hidden_size, hidden_size, 3, 1, 1) 48 | self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, 3, 1, 1) 49 | self.cell_gate = nn.Conv2d(input_size + hidden_size, hidden_size, 3, 1, 1) 50 | 51 | def forward(self, input, h_state, c_state): 52 | 53 | conc_inputs = torch.cat( (input, h_state), 1) 54 | 55 | in_gate = self.in_gate(conc_inputs) 56 | forget_gate = self.forget_gate(conc_inputs) 57 | out_gate = self.out_gate(conc_inputs) 58 | cell_gate = self.cell_gate(conc_inputs) 59 | 60 | in_gate = torch.sigmoid(in_gate) 61 | forget_gate = torch.sigmoid(forget_gate) 62 | out_gate = torch.sigmoid(out_gate) 63 | cell_gate = torch.tanh(cell_gate) 64 | 65 | c_state = (forget_gate * c_state) + (in_gate * cell_gate) 66 | h_state = out_gate * torch.tanh(c_state) 67 | 68 | return h_state, c_state 69 | 70 | 71 | class set_values(nn.Module): 72 | def __init__(self, hidden_size, height, width): 73 | super(set_values, self).__init__() 74 | self.hidden_size=hidden_size 75 | self.height=height 76 | self.width=width 77 | self.dropout = nn.Dropout(0.7) 78 | self.RCell = RNNCell(self.hidden_size, self.hidden_size) 79 | 80 | 81 | def forward(self, seq, xinp): 82 | xout = tools.to_cuda(Variable(torch.zeros(xinp.size()[0], xinp.size()[1], self.hidden_size, self.height, self.width))) 83 | 84 | h_state, c_state = ( tools.to_cuda(Variable(torch.zeros(xinp[0].shape[0], self.hidden_size, self.height, self.width))), 85 | tools.to_cuda(Variable(torch.zeros(xinp[0].shape[0], self.hidden_size, self.height, self.width))) ) 86 | 87 | for t in range(xinp.size()[0]): 88 | input_t = seq(xinp[t]) 89 | xout[t] = input_t 90 | h_state, c_state = self.RCell(input_t, h_state, c_state) 91 | 92 | return self.dropout(h_state), xout 93 | 94 | 95 | class U_Net(nn.Module): 96 | def __init__(self,img_ch, output_ch, patch_size): 97 | super(U_Net,self).__init__() 98 | 99 | self.patch_size = patch_size 100 | self.Maxpool = nn.MaxPool2d(kernel_size=2,stride=2) 101 | 102 | 
self.Conv1 = conv_block(ch_in=img_ch,ch_out=16) 103 | self.set1 = set_values(16, self.patch_size, self.patch_size) 104 | 105 | self.Conv2 = conv_block(ch_in=16,ch_out=32) 106 | self.set2 = set_values(32, self.patch_size // 2, self.patch_size // 2)  # floor division: set_values needs integer spatial sizes (plain / yields floats in Python 3) 107 | 108 | self.Conv3 = conv_block(ch_in=32,ch_out=64) 109 | self.set3 = set_values(64, self.patch_size // 4, self.patch_size // 4) 110 | 111 | self.Conv4 = conv_block(ch_in=64,ch_out=128) 112 | self.set4 = set_values(128, self.patch_size // 8, self.patch_size // 8) 113 | 114 | self.Conv5 = conv_block(ch_in=128,ch_out=256) 115 | self.set5 = set_values(256, self.patch_size // 16, self.patch_size // 16) 116 | 117 | self.Up5 = up_conv(ch_in=256,ch_out=128) 118 | self.Up_conv5 = conv_block(ch_in=256, ch_out=128) 119 | 120 | self.Up4 = up_conv(ch_in=128,ch_out=64) 121 | self.Up_conv4 = conv_block(ch_in=128, ch_out=64) 122 | 123 | self.Up3 = up_conv(ch_in=64,ch_out=32) 124 | self.Up_conv3 = conv_block(ch_in=64, ch_out=32) 125 | 126 | self.Up2 = up_conv(ch_in=32,ch_out=16) 127 | self.Up_conv2 = conv_block(ch_in=32, ch_out=16) 128 | 129 | self.Conv_1x1 = nn.Conv2d(16,output_ch,kernel_size=1,stride=1,padding=0) 130 | 131 | 132 | def encoder(self, x): 133 | x1, xout = self.set1(self.Conv1, x) 134 | 135 | x2, xout = self.set2( nn.Sequential(self.Maxpool, self.Conv2), xout) 136 | 137 | x3, xout = self.set3( nn.Sequential(self.Maxpool, self.Conv3), xout) 138 | 139 | x4, xout = self.set4( nn.Sequential(self.Maxpool, self.Conv4), xout) 140 | 141 | x5, xout = self.set5( nn.Sequential(self.Maxpool, self.Conv5), xout) 142 | 143 | return x1,x2,x3,x4,x5 144 | 145 | def forward(self,input): 146 | # encoding path 147 | x1,x2,x3,x4,x5 = self.encoder(input) 148 | 149 | # decoding + concat path 150 | d5 = self.Up5(x5) 151 | d5 = torch.cat((d5,x4),dim=1) 152 | d5 = self.Up_conv5(d5) 153 | 154 | d4 = self.Up4(d5) 155 | d4 = torch.cat((d4,x3),dim=1) 156 | d4 = self.Up_conv4(d4) 157 | 158 | d3 = self.Up3(d4) 159 | d3 = torch.cat((d3,x2),dim=1) 160 | d3 = self.Up_conv3(d3) 161 | 162 | d2 = self.Up2(d3) 163 | d2 = torch.cat((d2,x1),dim=1) 164 | d2 = self.Up_conv2(d2) 165 | 166 | d1 = self.Conv_1x1(d2) 167 | 168 | return d1 169 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from skimage import io 3 | import numpy as np 4 | from pathlib import Path 5 | 6 | 7 | def stretch_8bit(band, lower_percent=2, higher_percent=98): 8 | a = 0 9 | b = 255 10 | real_values = band.flatten() 11 | real_values = real_values[real_values > 0] 12 | c = np.percentile(real_values, lower_percent) 13 | d = np.percentile(real_values, higher_percent) 14 | t = a + (band - c) * (b - a) / float(d - c) 15 | t[t < a] = a 16 | t[t > b] = b 17 | return t.astype(np.uint8) / 255.
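# Minimal usage sketch (added annotation, not part of the original script):
# stretch_8bit maps the 2nd percentile of a band's nonzero values to 0 and the
# 98th percentile to 1, clipping everything outside that range, i.e. roughly
# t = clip((band - p2) / (p98 - p2), 0, 1) after the 8-bit round trip.
def _stretch_8bit_example():
    band = np.random.randint(1, 10000, size=(64, 64)).astype(np.uint16)  # synthetic band
    out = stretch_8bit(band)  # float array with values in [0, 1]
    assert 0.0 <= out.min() <= out.max() <= 1.0
    return out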
18 | 19 | 20 | def histogram_match(source, reference, match_proportion=1.0): 21 | orig_shape = source.shape 22 | source = source.ravel() 23 | 24 | if np.ma.is_masked(reference): 25 | reference = reference.compressed() 26 | else: 27 | reference = reference.ravel() 28 | 29 | s_values, s_idx, s_counts = np.unique( 30 | source, return_inverse=True, return_counts=True) 31 | r_values, r_counts = np.unique(reference, return_counts=True) 32 | s_size = source.size 33 | 34 | if np.ma.is_masked(source): 35 | mask_index = np.ma.where(s_values.mask) 36 | s_size = np.ma.where(s_idx != mask_index[0])[0].size 37 | s_values = s_values.compressed() 38 | s_counts = np.delete(s_counts, mask_index) 39 | 40 | s_quantiles = np.cumsum(s_counts).astype(np.float64) / s_size 41 | r_quantiles = np.cumsum(r_counts).astype(np.float64) / reference.size 42 | 43 | interp_r_values = np.interp(s_quantiles, r_quantiles, r_values) 44 | 45 | if np.ma.is_masked(source): 46 | interp_r_values = np.insert(interp_r_values, mask_index[0], source.fill_value) 47 | 48 | target = interp_r_values[s_idx] 49 | 50 | if match_proportion is not None and match_proportion != 1: 51 | diff = source - target 52 | target = source - (diff * match_proportion) 53 | 54 | if np.ma.is_masked(source): 55 | target = np.ma.masked_where(s_idx == mask_index[0], target) 56 | target.fill_value = source.fill_value 57 | 58 | return target.reshape(orig_shape) 59 | 60 | 61 | if __name__ == '__main__': 62 | 63 | FOLDER = Path('C:/Users/hafne/urban_change_detection/data/Onera/') 64 | 65 | IMG_FOLDER = FOLDER / 'images' 66 | # folder of the form ./IMGS_PREPROCESSED/abudhabi/imgs_1/..(13 tif 2D images of sentinel channels).. 67 | # ./IMGS_PREPROCESSED/abudhabi/imgs_2/..(13 tif 2D images of sentinel channels).. 68 | # .... 69 | # ./IMGS_PREPROCESSED/abudhabi/imgs_n/..(13 tif 2D images of sentinel channels).. 70 | # where n = number of dates 71 | 72 | # here you specify which dates you want to use 73 | nb_dates = [1, 2] 74 | channels = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 75 | 76 | # all areas of the OSCD dataset 77 | all_areas = ['abudhabi', 'aguasclaras', 'beihai', 'beirut', 'bercy', 'bordeaux', 'brasilia', 'chongqing', 78 | 'cupertino', 'dubai', 'hongkong', 'lasvegas', 'milano', 'montpellier', 'mumbai', 'nantes', 79 | 'norcia', 'paris', 'pisa', 'rennes', 'rio', 'saclay_e', 'saclay_w', 'valencia'] 80 | 81 | DESTINATION_FOLDER = FOLDER / 'images_preprocessed' 82 | if not DESTINATION_FOLDER.exists(): 83 | DESTINATION_FOLDER.mkdir() 84 | 85 | for i_path in all_areas: 86 | print(i_path) 87 | 88 | date_folders = [] 89 | for nd in nb_dates: 90 | temp_path = IMG_FOLDER / i_path / f'imgs_{nd}' 91 | files = [file for file in temp_path.glob('**/*')] 92 | date_folders.append(files) 93 | 94 | # B02 channel has the same dimensions with the groundtruth for the labeled images. 
95 | # So we keep it to reshape the rest of the channels for both labeled images and nonlabeled images 96 | gts = [s for s in date_folders[0] if 'B02' in str(s)] 97 | gts = io.imread(gts[0]) 98 | 99 | temp_path = DESTINATION_FOLDER / i_path 100 | if not temp_path.exists(): 101 | temp_path.mkdir() 102 | 103 | for nd in nb_dates: 104 | print('date', nd) 105 | imgs = [] 106 | if nd == 1: 107 | for ch in channels: 108 | im = [s for s in date_folders[nd-1] if ch in str(s)] 109 | im = io.imread(im[0]) 110 | im[im > 5500] = 5500 111 | im = stretch_8bit(im) 112 | im = cv2.resize(im, (gts.shape[1], gts.shape[0])) 113 | im = np.reshape(im, (gts.shape[0], gts.shape[1], 1)) 114 | imgs.append(im) 115 | imgs0 = imgs 116 | else: 117 | 118 | for ch in channels: 119 | im = [s for s in date_folders[nd-1] if ch in str(s)] 120 | im = io.imread(im[0]) 121 | im[im > 5500] = 5500 122 | im = stretch_8bit(im) 123 | im = histogram_match(im, imgs0[channels.index(ch)]) 124 | im = cv2.resize(im, (gts.shape[1], gts.shape[0])) 125 | im = np.reshape(im, (gts.shape[0], gts.shape[1], 1)) 126 | imgs.append(im) 127 | 128 | im_merge = np.stack(imgs, axis=2) 129 | im_merge = np.asarray(im_merge) 130 | im_merge = np.reshape(im_merge, (im_merge.shape[0], im_merge.shape[1], im_merge.shape[2])) 131 | 132 | im_file = DESTINATION_FOLDER / i_path / f'{i_path}_{nd}.npy' 133 | np.save(str(im_file), im_merge) 134 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pathlib import Path 3 | import utils 4 | import cv2 5 | import tifffile 6 | 7 | 8 | def get_band(file: Path) -> str: 9 | return file.stem.split('_')[-1] 10 | 11 | 12 | def combine_bands(folder: Path) -> tuple: 13 | 14 | bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'] 15 | n_bands = len(bands) 16 | 17 | # using blue band as reference (10 m) to create img 18 | blue_file = folder / 'B02.tif' 19 | blue = tifffile.imread(str(blue_file)) 20 | img = np.ndarray((*blue.shape, n_bands), dtype=np.float32) 21 | 22 | for i, band in enumerate(bands): 23 | band_file = folder / f'{band}.tif' 24 | arr = tifffile.imread(str(band_file)) 25 | band_h, band_w = arr.shape 26 | 27 | # up-sample 20 m bands 28 | # arr = cv2.resize(arr, (w, h), interpolation=cv2.INTER_CUBIC) 29 | 30 | # rescaling image to [0, 1] 31 | arr = np.clip(arr / 10000, a_min=0, a_max=1) 32 | img[:, :, i] = arr 33 | 34 | return img 35 | 36 | 37 | def process_city(img_folder: Path, label_folder: Path, city: str, new_root: Path) -> None: 38 | 39 | print(city) 40 | 41 | new_parent = new_root / city 42 | new_parent.mkdir(exist_ok=True) 43 | 44 | # image data 45 | for t in [1, 2]: 46 | 47 | # get data 48 | from_folder = img_folder / city / f'imgs_{t}_rect' 49 | img = combine_bands(from_folder) 50 | 51 | # save data 52 | to_folder = new_parent / 'sentinel2' 53 | to_folder.mkdir(exist_ok=True) 54 | 55 | save_file = to_folder / f'sentinel2_{city}_t{t}.npy' 56 | np.save(save_file, img) 57 | 58 | from_label_file = label_folder / city / 'cm' / f'{city}-cm.tif' 59 | label = tifffile.imread(str(from_label_file)) 60 | label = label - 1 61 | 62 | to_label_file = new_parent / 'label' / f'urbanchange_{city}.npy' 63 | to_label_file.parent.mkdir(exist_ok=True) 64 | np.save(to_label_file, label) 65 | 66 | 67 | def add_sentinel1(s1_folder: Path, label_folder: Path, city: str, orbit: int, new_root: Path): 68 | 69 | label_file = label_folder 
/ city / 'cm' / f'{city}-cm.tif' 70 | label = tifffile.imread(str(label_file)) 71 | h, w = label.shape 72 | 73 | for t in [1, 2]: 74 | s1_file = s1_folder / f'sentinel1_{city}_{orbit}_t{t}.tif' 75 | img = tifffile.imread(str(s1_file)) 76 | 77 | img = cv2.resize(img, (w, h), interpolation=cv2.INTER_CUBIC) 78 | img = img[:, :, None] 79 | 80 | # save data 81 | to_folder = new_root / city / 'sentinel1' 82 | to_folder.mkdir(exist_ok=True) 83 | 84 | save_file = to_folder / f'sentinel1_{city}_{orbit}_t{t}.npy' 85 | np.save(save_file, img) 86 | 87 | 88 | 89 | if __name__ == '__main__': 90 | # assume unchanged OSCD dataset 91 | IMG_FOLDER = Path('/storage/shafner/urban_change_detection/OSCD_dataset/images/') 92 | LABEL_FOLDER = Path('/storage/shafner/urban_change_detection/OSCD_dataset/labels/') 93 | NEW_ROOT = Path('/storage/shafner/urban_change_detection/OSCD_dataset/preprocessed') 94 | S1_FOLDER = Path('/storage/shafner/urban_change_detection/OSCD_dataset/sentinel1') 95 | 96 | CITIES = ['aguasclaras', 'bercy', 'bordeaux', 'nantes', 'paris', 'rennes', 'saclay_e', 'abudhabi', 'cupertino', 97 | 'pisa', 'beihai', 'hongkong', 'beirut', 'mumbai', 'brasilia', 'montpellier', 'norcia', 'rio', 'saclay_w', 98 | 'valencia', 'dubai', 'lasvegas', 'milano', 'chongqing'] 99 | 100 | ORBITS = { 101 | 'aguasclaras': [24], 102 | 'bercy': [59, 8, 110], 103 | 'bordeaux': [30, 8, 81], 104 | 'nantes': [30, 81], 105 | 'paris': [59, 8, 110], 106 | 'rennes': [30, 81], 107 | 'saclay_e': [59, 8], 108 | 'abudhabi': [130], 109 | 'cupertino': [35, 115, 42], 110 | 'pisa': [15, 168], 111 | 'beihai': [157], 112 | 'hongkong': [11, 113], 113 | 'beirut': [14, 87], 114 | 'mumbai': [34], 115 | 'brasilia': [24], 116 | 'montpellier': [59, 37], 117 | 'norcia': [117, 44, 22, 95], 118 | 'rio': [155], 119 | 'saclay_w': [59, 8, 110], 120 | 'valencia': [30, 103, 8, 110], 121 | 'dubai': [130, 166], 122 | 'lasvegas': [166, 173], 123 | 'milano': [66, 168], 124 | 'chongqing': [55, 164] 125 | } 126 | 127 | for city in CITIES: 128 | # process_city(IMG_FOLDER, LABEL_FOLDER, city, NEW_ROOT) 129 | orbits = ORBITS[city] 130 | for orbit in orbits: 131 | add_sentinel1(S1_FOLDER, LABEL_FOLDER, city, orbit, NEW_ROOT) 132 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import datasets, transforms 8 | from torch.autograd import Variable 9 | import numpy as np 10 | import torchnet as tnt 11 | import cv2 12 | 13 | USE_CUDA = torch.cuda.is_available() 14 | DEVICE = 0 15 | def to_cuda(v): 16 | if USE_CUDA: 17 | return v.cuda(DEVICE) 18 | return v 19 | 20 | def accuracy(input, target): 21 | return 100 * float(np.count_nonzero(input == target)) / target.size 22 | 23 | def conf_m(output, target_th): 24 | 25 | output_conf=((output.data.squeeze()).transpose(1,3)).transpose(1,2) 26 | output_conf=(output_conf.contiguous()).view(output_conf.size(0)*output_conf.size(1)*output_conf.size(2), output_conf.size(3)) 27 | target_conf=target_th.data.squeeze() 28 | target_conf=(target_conf.contiguous()).view(target_conf.size(0)*target_conf.size(1)*target_conf.size(2)) 29 | return output_conf, target_conf 30 | 31 | def write_results(ff, save_folder, epoch, train_acc, test_acc, change_acc, non_ch, train_losses, val_losses): 32 | ff=open('./' + save_folder + '/progress_run.txt','a') 33 | ff.write('train: ') 34 | 
ff.write(str('%.3f' % train_acc)) 35 | ff.write(' ') 36 | ff.write(' val: ') 37 | ff.write(str('%.3f' % test_acc)) 38 | ff.write(' ') 39 | ff.write(' CHANGE: ') 40 | ff.write(str('%.3f' % change_acc)) 41 | ff.write(' ') 42 | ff.write(' NON_CHANGE: ') 43 | ff.write(str('%.3f' % non_ch)) 44 | ff.write(' ') 45 | ff.write(' E: ') 46 | ff.write(str(epoch)) 47 | ff.write(' ') 48 | ff.write(' TRAIN_LOSS: ') 49 | ff.write(str('%.3f' % train_losses)) 50 | ff.write(' VAL_LOSS: ') 51 | ff.write(str('%.3f' % val_losses)) 52 | ff.write('\n') 53 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | from tqdm import tqdm 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import numpy as np 9 | import torchnet as tnt 10 | import tools 11 | from networks import network, networkL  # note: legacy imports; these two modules are not included in this repository (networks/papadomanolaki2019.py provides the LSTM variant) 12 | import custom 13 | from torch.utils.data import DataLoader 14 | from pathlib import Path 15 | 16 | 17 | if __name__ == '__main__': 18 | 19 | train_areas = ['abudhabi', 'beihai', 'aguasclaras', 'beirut', 'bercy', 'bordeaux', 'cupertino', 20 | 'hongkong', 'mumbai', 'nantes', 'rennes', 'saclay_e', 'pisa', 'rennes']  # note: 'rennes' is listed twice, so it is sampled twice 21 | 22 | # FOLDER = Path('C:/Users/hafne/urban_change_detection/data/Onera/') 23 | FOLDER = Path('/storage/shafner/urban_change_detection/Onera/') 24 | 25 | 26 | csv_file_train = FOLDER / 'myxys_train.csv' 27 | csv_file_val = FOLDER / 'myxys_val.csv' 28 | img_folder = FOLDER / 'images_preprocessed' # folder with preprocessed images according to preprocess.py 29 | lbl_folder = FOLDER / 'labels' # folder with OSCD dataset's labels 30 | save_folder = FOLDER / 'save_models' 31 | save_folder.mkdir(exist_ok=True) 32 | 33 | patch_size = 32 34 | 35 | # specify the dates you want to use, e.g. put [1, 2, 3, 4, 5] to use all five dates 36 | # or [1, 2, 5] to use just three of them 37 | nb_dates = [1, 2] 38 | 39 | # setting device on GPU if available, else CPU 40 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 41 | print('Using device:', device) 42 | 43 | 44 | model_type = 'simple'  # choose network type ('simple' or 'lstm') 45 | # 'simple' refers to a plain U-Net, while 'lstm' refers to a U-Net with LSTM blocks 46 | 47 | 48 | 49 | if model_type == 'simple': 50 | net = network.U_Net(4, 2, nb_dates) 51 | elif model_type == 'lstm': 52 | net = networkL.U_Net(4, 2, patch_size) 53 | else: 54 | net = None 55 | print('invalid network type argument') 56 | 57 | model = tools.to_cuda(net) 58 | 59 | change_dataset_train = custom.MyDataset(csv_file_train, train_areas, img_folder, lbl_folder, nb_dates, patch_size) 60 | change_dataset_val = custom.MyDataset(csv_file_val, train_areas, img_folder, lbl_folder, nb_dates, patch_size) 61 | mydataset_val = DataLoader(change_dataset_val, batch_size=32) 62 | 63 | # images_train, labels_train, images_val, labels_val = tools.make_data(size_len, portion, change_dataset) 64 | base_lr = 0.0001 65 | optimizer = optim.Adam(model.parameters(), lr=base_lr) 66 | weight_tensor = torch.FloatTensor(2) 67 | weight_tensor[0] = 0.20 68 | weight_tensor[1] = 0.80 69 | criterion = tools.to_cuda(nn.CrossEntropyLoss(tools.to_cuda(weight_tensor))) 70 | confusion_matrix = tnt.meter.ConfusionMeter(2, normalized=True) 71 | epochs
= 60 72 | 73 | save_file = save_folder / 'progress_L2D.txt' 74 | save_file.touch(exist_ok=True) 75 | # ff = open(save_file, 'w') 76 | iter_ = 0 77 | for epoch in range(1, epochs + 1): 78 | mydataset = DataLoader(change_dataset_train, batch_size=32, shuffle=True) 79 | model.train() 80 | train_losses = [] 81 | confusion_matrix.reset() 82 | 83 | for i, batch, in enumerate(mydataset): 84 | img_batch, lbl_batch = batch 85 | img_batch, lbl_batch = tools.to_cuda(img_batch.permute(1, 0, 2, 3, 4)), tools.to_cuda(lbl_batch) 86 | 87 | optimizer.zero_grad() 88 | output = model(img_batch.float()) 89 | output_conf, target_conf = tools.conf_m(output, lbl_batch) 90 | confusion_matrix.add(output_conf, target_conf) 91 | 92 | loss = criterion(output, lbl_batch.long()) 93 | train_losses.append(loss.item()) 94 | loss.backward() 95 | optimizer.step() 96 | 97 | del(img_batch, lbl_batch, loss) 98 | 99 | train_acc = (np.trace(confusion_matrix.conf) / float(np.ndarray.sum(confusion_matrix.conf))) * 100 100 | print(f'train loss: {np.mean(train_losses):.3f}, train acc: {train_acc:.3f}') 101 | confusion_matrix.reset() 102 | # end of epoch 103 | 104 | # VALIDATION 105 | with torch.no_grad(): 106 | model.eval() 107 | 108 | val_losses = [] 109 | print(len(mydataset_val)) 110 | 111 | for i, batch, in enumerate(mydataset_val): 112 | # TODO: maybe fix this (last batch does not work) 113 | if i < (len(mydataset_val) - 1): 114 | img_batch, lbl_batch = batch 115 | img_batch, lbl_batch = tools.to_cuda(img_batch.permute(1, 0, 2, 3, 4)), tools.to_cuda(lbl_batch) 116 | 117 | output = model(img_batch.float()) 118 | loss = criterion(output, lbl_batch.long()) 119 | val_losses.append(loss.item()) 120 | output_conf, target_conf = tools.conf_m(output, lbl_batch) 121 | confusion_matrix.add(output_conf, target_conf) 122 | 123 | print(confusion_matrix.conf) 124 | test_acc = (np.trace(confusion_matrix.conf) / float(np.ndarray.sum(confusion_matrix.conf))) * 100 125 | change_acc = confusion_matrix.conf[1, 1] / float(confusion_matrix.conf[1, 0] + confusion_matrix.conf[1, 1]) * 100 126 | non_ch = confusion_matrix.conf[0, 0] / float(confusion_matrix.conf[0, 0]+confusion_matrix.conf[0, 1]) * 100 127 | print(f'val loss: {np.mean(val_losses):.3f}, val acc: {test_acc:.3f}') 128 | print(f'Non_ch_Acc: {non_ch:.3f}, Change_Accuracy: {change_acc:.3f}') 129 | confusion_matrix.reset() 130 | 131 | # tools.write_results(ff, save_folder, epoch, train_acc, test_acc, change_acc, non_ch, np.mean(train_losses), np.mean(val_losses)) 132 | if epoch % 5 == 0: # save model every 5 epochs 133 | model_file = save_folder / f'model_{epoch}.pt' 134 | # torch.save(model.state_dict(), model_file) 135 | -------------------------------------------------------------------------------- /train_network.py: -------------------------------------------------------------------------------- 1 | # general modules 2 | import json 3 | import sys 4 | import os 5 | import numpy as np 6 | from pathlib import Path 7 | 8 | # learning framework 9 | import torch 10 | from torch.utils import data as torch_data 11 | from torch.nn import functional as F 12 | from torchvision import transforms 13 | 14 | # config for experiments 15 | from experiment_manager import args 16 | from experiment_manager.config import config 17 | 18 | # custom stuff 19 | import augmentations as aug 20 | import evaluation_metrics as eval 21 | import loss_functions as lf 22 | import datasets 23 | 24 | # networks from papers and ours 25 | from networks.network_loader import load_network 26 | 27 | # logging 28 | import wandb 
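# Note on the config flow (added annotation, not in the original file): setup()
# below layers three sources, later ones overriding earlier ones -- the defaults
# behind config.new_config(), the named YAML file from configs/, and finally the
# leftover command-line tokens in args.opts, which yacs-style configs consume as
# flat key/value pairs, e.g. cfg.merge_from_list(['TRAINER.LR', '0.0001']).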
29 | 30 | 31 | def setup(args): 32 | cfg = config.new_config() 33 | cfg.merge_from_file(f'configs/{args.config_file}.yaml') 34 | cfg.merge_from_list(args.opts) 35 | cfg.NAME = args.config_file 36 | return cfg 37 | 38 | 39 | def train(net, cfg): 40 | 41 | # setting device on GPU if available, else CPU 42 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 43 | print('Using device:', device) 44 | 45 | net.to(device) 46 | 47 | if cfg.TRAINER.OPTIMIZER == 'adam': 48 | optimizer = torch.optim.Adam(net.parameters(), lr=cfg.TRAINER.LR, weight_decay=0.0005) 49 | else: 50 | optimizer = torch.optim.SGD(net.parameters(), lr=cfg.TRAINER.LR, momentum=0.9) 51 | 52 | # loss functions 53 | if cfg.MODEL.LOSS_TYPE == 'BCEWithLogitsLoss': 54 | criterion = torch.nn.BCEWithLogitsLoss() 55 | elif cfg.MODEL.LOSS_TYPE == 'WeightedBCEWithLogitsLoss': 56 | positive_weight = torch.tensor([cfg.MODEL.POSITIVE_WEIGHT]).float().to(device) 57 | criterion = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weight) 58 | elif cfg.MODEL.LOSS_TYPE == 'SoftDiceLoss': 59 | criterion = lf.soft_dice_loss 60 | elif cfg.MODEL.LOSS_TYPE == 'SoftDiceBalancedLoss': 61 | criterion = lf.soft_dice_loss_balanced 62 | elif cfg.MODEL.LOSS_TYPE == 'JaccardLikeLoss': 63 | criterion = lf.jaccard_like_loss 64 | elif cfg.MODEL.LOSS_TYPE == 'ComboLoss': 65 | criterion = lambda pred, gts: F.binary_cross_entropy_with_logits(pred, gts) + lf.soft_dice_loss(pred, gts) 66 | elif cfg.MODEL.LOSS_TYPE == 'WeightedComboLoss': 67 | criterion = lambda pred, gts: 2 * F.binary_cross_entropy_with_logits(pred, gts) + lf.soft_dice_loss(pred, gts) 68 | elif cfg.MODEL.LOSS_TYPE == 'FrankensteinLoss': 69 | criterion = lambda pred, gts: F.binary_cross_entropy_with_logits(pred, gts) + lf.jaccard_like_balanced_loss(pred, gts) 70 | elif cfg.MODEL.LOSS_TYPE == 'WeightedFrankensteinLoss': 71 | positive_weight = torch.tensor([cfg.MODEL.POSITIVE_WEIGHT]).float().to(device) 72 | criterion = lambda pred, gts: F.binary_cross_entropy_with_logits(pred, gts, pos_weight=positive_weight) + 5 * lf.jaccard_like_balanced_loss(pred, gts) 73 | else: 74 | criterion = lf.soft_dice_loss 75 | 76 | # reset the generators 77 | dataset = datasets.OSCDDataset(cfg, 'train') 78 | drop_last = True 79 | batch_size = cfg.TRAINER.BATCH_SIZE 80 | dataloader_kwargs = { 81 | 'batch_size': batch_size, 82 | 'num_workers': 0 if cfg.DEBUG else cfg.DATALOADER.NUM_WORKER, 83 | 'shuffle': cfg.DATALOADER.SHUFFLE, 84 | 'drop_last': drop_last, 85 | 'pin_memory': True, 86 | } 87 | if cfg.AUGMENTATION.OVERSAMPLING != 'none': 88 | dataloader_kwargs['sampler'] = dataset.sampler() 89 | dataloader_kwargs['shuffle'] = False 90 | 91 | dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs) 92 | 93 | save_path = Path(cfg.OUTPUT_BASE_DIR) / cfg.NAME 94 | save_path.mkdir(exist_ok=True) 95 | 96 | best_test_f1 = 0 97 | positive_pixels = 0 98 | pixels = 0 99 | global_step = 0 100 | epochs = cfg.TRAINER.EPOCHS 101 | batches = len(dataloader) // batch_size if drop_last else len(dataloader) // batch_size + 1 102 | for epoch in range(epochs): 103 | 104 | loss_tracker = 0 105 | net.train() 106 | 107 | for i, batch in enumerate(dataloader): 108 | 109 | t1_img = batch['t1_img'].to(device) 110 | t2_img = batch['t2_img'].to(device) 111 | 112 | label = batch['label'].to(device) 113 | 114 | optimizer.zero_grad() 115 | 116 | output = net(t1_img, t2_img) 117 | 118 | loss = criterion(output, label) 119 | loss_tracker += loss.item() 120 | loss.backward() 121 | optimizer.step() 122 | 123 | positive_pixels += 
torch.sum(label).item() 124 | pixels += torch.numel(label) 125 | 126 | global_step += 1 127 | 128 | if epoch % cfg.LOGGING == 0: 129 | print(f'epoch {epoch} / {cfg.TRAINER.EPOCHS}') 130 | 131 | # printing and logging loss 132 | avg_loss = loss_tracker / batches 133 | print(f'avg training loss {avg_loss:.5f}') 134 | 135 | # positive pixel ratio used to check oversampling 136 | if cfg.DEBUG: 137 | print(f'positive pixel ratio: {positive_pixels / pixels:.3f}') 138 | else: 139 | wandb.log({f'positive pixel ratio': positive_pixels / pixels}) 140 | positive_pixels = 0 141 | pixels = 0 142 | 143 | # model evaluation 144 | # train (different thresholds are tested) 145 | train_thresholds = torch.linspace(0, 1, 101).to(device) 146 | train_maxF1, train_maxTresh = model_evaluation(net, cfg, device, train_thresholds, run_type='train', 147 | epoch=epoch, step=global_step) 148 | # test (using the best training threshold) 149 | test_threshold = torch.tensor([train_maxTresh]) 150 | test_f1, _ = model_evaluation(net, cfg, device, test_threshold, run_type='test', epoch=epoch, 151 | step=global_step) 152 | 153 | if test_f1 > best_test_f1: 154 | print(f'BEST PERFORMANCE SO FAR! <--------------------', flush=True) 155 | best_test_f1 = test_f1 156 | 157 | if cfg.SAVE_MODEL and not cfg.DEBUG: 158 | print(f'saving network', flush=True) 159 | # model_file = save_path / 'best_net.pkl' 160 | # torch.save(net.state_dict(), model_file) 161 | 162 | if (epoch + 1) == 390: 163 | if cfg.SAVE_MODEL and not cfg.DEBUG: 164 | print(f'saving network', flush=True) 165 | model_file = save_path / f'final_net.pkl' 166 | torch.save(net.state_dict(), model_file) 167 | 168 | 169 | def model_evaluation(net, cfg, device, thresholds, run_type, epoch, step): 170 | 171 | thresholds = thresholds.to(device) 172 | y_true_set = [] 173 | y_pred_set = [] 174 | 175 | measurer = eval.MultiThresholdMetric(thresholds) 176 | 177 | dataset = datasets.OSCDDataset(cfg, run_type, no_augmentation=True) 178 | dataloader_kwargs = { 179 | 'batch_size': 1, 180 | 'num_workers': 0 if cfg.DEBUG else cfg.DATALOADER.NUM_WORKER, 181 | 'shuffle': cfg.DATALOADER.SHUFFLE, 182 | 'pin_memory': True, 183 | } 184 | dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs) 185 | 186 | with torch.no_grad(): 187 | net.eval() 188 | for step, batch in enumerate(dataloader): 189 | t1_img = batch['t1_img'].to(device) 190 | t2_img = batch['t2_img'].to(device) 191 | y_true = batch['label'].to(device) 192 | 193 | y_pred = net(t1_img, t2_img) 194 | 195 | y_pred = torch.sigmoid(y_pred) 196 | 197 | y_true = y_true.detach() 198 | y_pred = y_pred.detach() 199 | y_true_set.append(y_true.cpu()) 200 | y_pred_set.append(y_pred.cpu()) 201 | 202 | measurer.add_sample(y_true, y_pred) 203 | 204 | print(f'Computing {run_type} F1 score ', end=' ', flush=True) 205 | 206 | f1 = measurer.compute_f1() 207 | fpr, fnr = measurer.compute_basic_metrics() 208 | maxF1 = f1.max() 209 | argmaxF1 = f1.argmax() 210 | best_fpr = fpr[argmaxF1] 211 | best_fnr = fnr[argmaxF1] 212 | best_thresh = thresholds[argmaxF1] 213 | 214 | if not cfg.DEBUG: 215 | wandb.log({ 216 | f'{run_type} max F1': maxF1, 217 | f'{run_type} argmax F1': argmaxF1, 218 | f'{run_type} false positive rate': best_fpr, 219 | f'{run_type} false negative rate': best_fnr, 220 | 'step': step, 221 | 'epoch': epoch, 222 | }) 223 | 224 | print(f'{maxF1.item():.3f}', flush=True) 225 | 226 | return maxF1.item(), best_thresh.item() 227 | 228 | 229 | if __name__ == '__main__': 230 | 231 | # setting up config based on parsed argument 232 | parser = 
args.default_argument_parser() 233 | args = parser.parse_known_args()[0] 234 | cfg = setup(args) 235 | 236 | torch.manual_seed(cfg.SEED) 237 | np.random.seed(cfg.SEED) 238 | torch.backends.cudnn.deterministic = True 239 | torch.backends.cudnn.benchmark = False 240 | 241 | # loading network 242 | net = load_network(cfg) 243 | 244 | # tracking land with w&b 245 | if not cfg.DEBUG: 246 | wandb.init( 247 | name=cfg.NAME, 248 | project='urban_change_detection', 249 | tags=['run', 'change', 'detection', ], 250 | ) 251 | 252 | try: 253 | train(net, cfg) 254 | except KeyboardInterrupt: 255 | print('Training terminated') 256 | try: 257 | sys.exit(0) 258 | except SystemExit: 259 | os._exit(0) 260 | 261 | 262 | -------------------------------------------------------------------------------- /urban_extraction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data as torch_data 3 | import numpy as np 4 | 5 | from networks.ours import UNet 6 | from experiment_manager.config import new_config 7 | import datasets 8 | 9 | import matplotlib.pyplot as plt 10 | 11 | from pathlib import Path 12 | 13 | 14 | # loading cfg for inference 15 | def load_cfg(cfg_file: Path): 16 | cfg = new_config() 17 | cfg.merge_from_file(str(cfg_file)) 18 | return cfg 19 | 20 | 21 | # loading network for inference 22 | def load_net(cfg, net_file: Path, device): 23 | net = UNet(cfg) 24 | state_dict = torch.load(str(net_file), map_location=lambda storage, loc: storage) 25 | net.load_state_dict(state_dict) 26 | return net.to(device) 27 | 28 | 29 | def classify(img, net, threshold, return_numpy=True): 30 | y_logits = net(img) 31 | y_prob = torch.sigmoid(y_logits) 32 | y_pred = y_prob > threshold 33 | if return_numpy: 34 | return torch2numpy(y_pred, 'uint8'), torch2numpy(y_prob) 35 | return y_pred, y_prob 36 | 37 | 38 | def torch2numpy(tensor: torch.tensor, nptype: str = 'float32'): 39 | cpu_tensor = tensor.cpu().detach() 40 | arr = cpu_tensor.numpy().astype(nptype) 41 | if len(arr.shape) == 4: 42 | transpose = (0, 2, 3, 1) 43 | elif len(arr.shape) == 3: 44 | transpose = (1, 2, 0) 45 | else: 46 | transpose = (0, 1) 47 | arr = arr.transpose(transpose) 48 | return arr 49 | 50 | 51 | 52 | def visual_evaluation(net_cfg_file: Path, net_file: Path, ds_cfg_file: Path, dataset: str = 'test', 53 | save_path: Path = None): 54 | 55 | mode = 'cuda' if torch.cuda.is_available() else 'cpu' 56 | device = torch.device(mode) 57 | 58 | # loading network 59 | net_cfg = load_cfg(net_cfg_file) 60 | net = load_net(net_cfg, net_file, device) 61 | 62 | # loading dataset 63 | ds_cfg = load_cfg(ds_cfg_file) 64 | dataset = datasets.OSCDDataset(ds_cfg, dataset, no_augmentation=True) 65 | dataloader_kwargs = { 66 | 'batch_size': 1, 67 | 'num_workers': 0, 68 | 'shuffle': False, 69 | 'pin_memory': True, 70 | } 71 | dataloader = torch_data.DataLoader(dataset, **dataloader_kwargs) 72 | threshold = net_cfg.THRESH 73 | 74 | with torch.no_grad(): 75 | net.eval() 76 | for step, batch in enumerate(dataloader): 77 | 78 | fig, axs = plt.subplots(1, 4, figsize=(20, 10)) 79 | 80 | city = batch['city'][0] 81 | print(city) 82 | 83 | t1_img = batch['t1_img'].to(device) 84 | t2_img = batch['t2_img'].to(device) 85 | y_true = batch['label'].to(device) 86 | 87 | data = {'pred': [], 'prob': [], 'rgb': []} 88 | for i, img in enumerate([t1_img, t2_img]): 89 | img_arr = torch2numpy(img) 90 | y_pred, y_prob = classify(img, net, threshold, return_numpy=True) 91 | data['pred'].append(y_pred[0, :, :, 0]) 92 | 
data['prob'].append(y_prob[0, :, :, 0]) 93 | 94 | img_arr = img_arr[0, ...] 95 | rgb = img_arr[:, :, [2, 1, 0]] 96 | rgb = np.minimum(rgb / 0.3, 1) 97 | data['rgb'].append(rgb) 98 | 99 | axs[i].imshow(y_prob[0, :, :, 0], vmin=0, vmax=1) 100 | # axs[i*2+1].imshow(y_prob[0, :, :, 0]) 101 | 102 | label_arr = torch2numpy(y_true, 'uint8') 103 | axs[2].imshow(label_arr[0, :, :, 0]) 104 | di = data['prob'][1]-data['prob'][0] 105 | axs[3].imshow(di, vmin=0, vmax=1) 106 | 107 | for ax in axs: 108 | ax.set_axis_off() 109 | 110 | assert(save_path.exists()) 111 | file = save_path / f'urban_extraction_{city}.png' 112 | plt.savefig(file, dpi=300, bbox_inches='tight') 113 | plt.close() 114 | 115 | 116 | if __name__ == '__main__': 117 | 118 | save_path = Path('/storage/shafner/urban_change_detection/urban_extraction') 119 | 120 | # network 121 | ue_cfg = 'baseline_sentinel2' 122 | ue_cfg_path = Path('/home/shafner/urban_dl/configs/urban_extraction') 123 | ue_cfg_file = ue_cfg_path / f'{ue_cfg}.yaml' 124 | ue_net_path = Path('/storage/shafner/run_logs/unet') 125 | ue_net_file = ue_net_path / ue_cfg / 'best_net.pkl' 126 | 127 | # cfg 128 | ds_cfg = 'urban_extraction_loader' 129 | ds_cfg_file = Path.cwd() / 'configs' / f'{ds_cfg}.yaml' 130 | 131 | visual_evaluation(ue_cfg_file, ue_net_file, ds_cfg_file, dataset='test', save_path=save_path) 132 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import rasterio 3 | from pathlib import Path 4 | 5 | 6 | # reading in geotiff file as numpy array 7 | def read_tif(file: Path): 8 | 9 | if not file.exists(): 10 | raise FileNotFoundError(f'File {file} not found') 11 | 12 | with rasterio.open(file) as dataset: 13 | arr = dataset.read() # (bands X height X width) 14 | transform = dataset.transform 15 | crs = dataset.crs 16 | 17 | return arr.transpose((1, 2, 0)), transform, crs 18 | 19 | 20 | # writing an array to a geo tiff file 21 | def write_tif(file: Path, arr, transform, crs): 22 | 23 | if not file.parent.exists(): 24 | file.parent.mkdir() 25 | 26 | height, width, bands = arr.shape 27 | with rasterio.open( 28 | file, 29 | 'w', 30 | driver='GTiff', 31 | height=height, 32 | width=width, 33 | count=bands, 34 | dtype=arr.dtype, 35 | crs=crs, 36 | transform=transform, 37 | ) as dst: 38 | for i in range(bands): 39 | dst.write(arr[:, :, i], i + 1) 40 | 41 | 42 | def to_numpy(tensor:torch.Tensor): 43 | return tensor.cpu().detach().numpy() 44 | --------------------------------------------------------------------------------
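# Usage sketch for the I/O helpers in utils.py (appended annotation, not part of
# the repository). File names below are hypothetical; assumes rasterio can open them.
#
#   import numpy as np
#   from pathlib import Path
#   from utils import read_tif, write_tif
#
#   arr, transform, crs = read_tif(Path('data/example.tif'))  # (H x W x bands)
#   mask = (arr[:, :, :1] > 0).astype(np.uint8)               # toy binary mask
#   write_tif(Path('data/example_mask.tif'), mask, transform, crs)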